diff --git a/arrow-arith/src/temporal.rs b/arrow-arith/src/temporal.rs index 12970f4498fa..869fba29cd15 100644 --- a/arrow-arith/src/temporal.rs +++ b/arrow-arith/src/temporal.rs @@ -23,6 +23,7 @@ use arrow_array::cast::AsArray; use cast::as_primitive_array; use chrono::{Datelike, TimeZone, Timelike, Utc}; +use arrow_array::ree_map; use arrow_array::temporal_conversions::{ MICROSECONDS, MICROSECONDS_IN_DAY, MILLISECONDS, MILLISECONDS_IN_DAY, NANOSECONDS, NANOSECONDS_IN_DAY, SECONDS_IN_DAY, date32_to_datetime, date64_to_datetime, @@ -194,6 +195,15 @@ pub fn date_part(array: &dyn Array, part: DatePart) -> Result match k.data_type() { + DataType::Int16 => ree_map!(array, Int16Type, |a| date_part(a, part)), + DataType::Int32 => ree_map!(array, Int32Type, |a| date_part(a, part)), + DataType::Int64 => ree_map!(array, Int64Type, |a| date_part(a, part)), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid run-end type: {:?}", + k.data_type() + ))), + }, t => return_compute_error_with!(format!("{part} does not support"), t), ) } @@ -2040,4 +2050,33 @@ mod tests { assert_eq!(2015, actual.value(1)); assert_eq!(2016, actual.value(2)); } + + #[test] + fn test_ree_timestamp_year() { + let vals: TimestampSecondArray = + vec![Some(1514764800), Some(1550636625), Some(1550636625)].into(); + let run_ends = Int32Array::from(vec![1, 2, 3]); + let ree = RunArray::try_new(&run_ends, &vals).unwrap(); + + let b = date_part(&ree, DatePart::Year).unwrap(); + let ree_result = b.as_run_opt::().unwrap(); + let values = ree_result.values().as_primitive::(); + assert_eq!(2018, values.value(0)); + assert_eq!(2019, values.value(1)); + assert_eq!(2019, values.value(2)); + } + + #[test] + fn test_ree_date64_month() { + let vals: PrimitiveArray = + vec![Some(1514764800000), Some(1550636625000)].into(); + let run_ends = Int64Array::from(vec![2, 4]); + let ree = RunArray::try_new(&run_ends, &vals).unwrap(); + + let b = date_part(&ree, DatePart::Month).unwrap(); + let ree_result = b.as_run_opt::().unwrap(); + let values = ree_result.values().as_primitive::(); + assert_eq!(1, values.value(0)); + assert_eq!(2, values.value(1)); + } } diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index f317af6a10f0..02bc730b32df 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -30,6 +30,25 @@ use crate::{ types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}, }; +/// Recursively applies a function to the values of a RunEndEncoded array, preserving the run structure. +/// +/// # Example +/// +/// ```ignore +/// let result = ree_recurse!(array, Int32Type, my_function)?; +/// ``` +/// +/// This macro is useful for implementing functions that should work on the logical values +/// of a REE array while preserving the run-end encoding structure. +#[macro_export] +macro_rules! ree_map { + ($array:expr, $run_type:ty, $func:expr) => {{ + let ree = $array.as_run_opt::<$run_type>().unwrap(); + let inner_values = $func(ree.values().as_ref())?; + Ok(std::sync::Arc::new(ree.with_values(inner_values))) + }}; +} + /// An array of [run-end encoded values]. /// /// This encoding is variation on [run-length encoding (RLE)] and is good for representing @@ -200,6 +219,46 @@ impl RunArray { &self.values } + /// Returns a new [`RunArray`] with the same `run_ends` and the supplied `values`. + /// + /// # Panics + /// + /// Panics if `values.len()` does not equal `self.values().len()`. + /// + /// # Example + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::{RunArray, Int32Array, StringArray, ArrayRef,Array}; + /// # use arrow_array::types::Int32Type; + /// // A RunArray logically representing ["a", "a", "b", "c", "c"] + /// let run_ends = Int32Array::from(vec![2, 3, 5]); + /// let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + /// let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + /// + /// // Swap in new values while keeping the same run pattern. + /// // The result logically represents ["x", "x", "y", "z", "z"]. + /// let new_values: ArrayRef = Arc::new(StringArray::from(vec!["x", "y", "z"])); + /// let new_run_array = run_array.with_values(new_values); + /// + /// assert_eq!(new_run_array.len(), 5); + /// assert_eq!(new_run_array.run_ends().values(), &[2, 3, 5]); + /// ``` + pub fn with_values(&self, values: ArrayRef) -> Self { + assert_eq!(values.len(), self.values().len()); + let (run_ends_field, values_field) = match &self.data_type { + DataType::RunEndEncoded(r, v) => (r, v), + _ => unreachable!("RunArray should have type RunEndEncoded"), + }; + let data_type = + DataType::RunEndEncoded(Arc::clone(run_ends_field), Arc::clone(values_field)); + Self { + data_type, + run_ends: self.run_ends.clone(), + values, + } + } + /// Similar to [`values`] but accounts for logical slicing, returning only the values /// that are part of the logical slice of this array. /// diff --git a/arrow-string/src/length.rs b/arrow-string/src/length.rs index ff98c5632b74..feefe1247e2c 100644 --- a/arrow-string/src/length.rs +++ b/arrow-string/src/length.rs @@ -17,25 +17,12 @@ //! Defines kernel for length of string arrays and binary arrays +use arrow_array::ree_map; use arrow_array::*; use arrow_array::{cast::AsArray, types::*}; use arrow_buffer::{ArrowNativeType, NullBuffer, OffsetBuffer}; use arrow_schema::{ArrowError, DataType}; use std::sync::Arc; -macro_rules! ree_length { - ($array:expr, $run_type:ty, $k:expr, $v:expr) => {{ - let ree = $array.as_run_opt::<$run_type>().unwrap(); - let inner_value_lengths = length(ree.values().as_ref())?; - let out_ree = unsafe { - RunArray::<$run_type>::new_unchecked( - DataType::RunEndEncoded(Arc::clone($k), Arc::clone($v)), - ree.run_ends().clone(), - inner_value_lengths, - ) - }; - Ok(Arc::new(out_ree) as ArrayRef) - }}; -} fn length_impl( offsets: &OffsetBuffer, @@ -130,10 +117,10 @@ pub fn length(array: &dyn Array) -> Result { list.nulls().cloned(), )?)) } - DataType::RunEndEncoded(k, v) => match k.data_type() { - DataType::Int16 => ree_length!(array, Int16Type, &k, &v), - DataType::Int32 => ree_length!(array, Int32Type, &k, &v), - DataType::Int64 => ree_length!(array, Int64Type, &k, &v), + DataType::RunEndEncoded(k, _) => match k.data_type() { + DataType::Int16 => ree_map!(array, Int16Type, length), + DataType::Int32 => ree_map!(array, Int32Type, length), + DataType::Int64 => ree_map!(array, Int64Type, length), _ => Err(ArrowError::InvalidArgumentError(format!( "Invalid run-end type: {:?}", k.data_type() @@ -149,7 +136,7 @@ pub fn length(array: &dyn Array) -> Result { /// /// * this only accepts StringArray/Utf8, LargeString/LargeUtf8, StringViewArray/Utf8View, /// BinaryArray, LargeBinaryArray, BinaryViewArray, and FixedSizeBinaryArray, -/// or DictionaryArray with above Arrays as values +/// or DictionaryArray/REE with above Arrays as values /// * bit_length of null is null. /// * bit_length is in number of bits pub fn bit_length(array: &dyn Array) -> Result { @@ -203,6 +190,15 @@ pub fn bit_length(array: &dyn Array) -> Result { array.nulls().cloned(), )?)) } + DataType::RunEndEncoded(k, _) => match k.data_type() { + DataType::Int16 => ree_map!(array, Int16Type, bit_length), + DataType::Int32 => ree_map!(array, Int32Type, bit_length), + DataType::Int64 => ree_map!(array, Int64Type, bit_length), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid run-end type: {:?}", + k.data_type() + ))), + }, other => Err(ArrowError::ComputeError(format!( "bit_length not supported for {other:?}" ))), @@ -903,4 +899,27 @@ mod tests { assert!(length(&ree_array).is_err()); } + + #[test] + fn bit_length_test_ree_utf8() { + use arrow_array::RunArray; + use arrow_array::types::Int32Type; + + let strings = StringArray::from(vec!["hello", "world", "test"]); + let run_ends = PrimitiveArray::::from(vec![1i32, 2, 3]); + let ree_array = RunArray::::try_new(&run_ends, &strings).unwrap(); + + let result = bit_length(&ree_array).unwrap(); + let result_values = result + .as_any() + .downcast_ref::>() + .unwrap() + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + let expected: Int32Array = vec![40, 40, 32].into(); + assert_eq!(&expected, result_values); + } }