diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index d0f838ddad12a..388c60c87d6e7 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -18,16 +18,17 @@ //! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions. use crate::utils; -use crate::utils::make_scalar_function; use arrow::array::{ Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait, - cast::AsArray, make_array, + Scalar, cast::AsArray, make_array, }; use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cast::as_int64_array; use datafusion_common::utils::ListCoercion; -use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args}; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_err, utils::take_function_args, +}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -113,7 +114,21 @@ impl ScalarUDFImpl for ArrayRemove { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_remove_inner)(&args.args) + let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let arr_n = vec![1; num_rows]; + match element_arg { + ColumnarValue::Array(element_array) => { + let result = array_remove_internal(&list_array, element_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_element) => { + let result = + remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -214,7 +229,23 @@ impl ScalarUDFImpl for ArrayRemoveN { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_remove_n_inner)(&args.args) + let [list_arg, element_arg, max_arg] = + take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let max_array = max_arg.to_array(num_rows)?; + let arr_n = as_int64_array(&max_array)?.values().to_vec(); + match element_arg { + ColumnarValue::Array(element_array) => { + let result = array_remove_internal(&list_array, element_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_element) => { + let result = + remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -304,7 +335,21 @@ impl ScalarUDFImpl for ArrayRemoveAll { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_remove_all_inner)(&args.args) + let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let arr_n = vec![i64::MAX; num_rows]; + match element_arg { + ColumnarValue::Array(element_array) => { + let result = array_remove_internal(&list_array, element_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_element) => { + let result = + remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -316,27 +361,6 @@ impl ScalarUDFImpl for ArrayRemoveAll { } } -fn array_remove_inner(args: &[ArrayRef]) -> Result { - let [array, element] = take_function_args("array_remove", args)?; - - let arr_n = vec![1; array.len()]; - array_remove_internal(array, element, &arr_n) -} - -fn array_remove_n_inner(args: &[ArrayRef]) -> Result { - let [array, element, max] = take_function_args("array_remove_n", args)?; - - let arr_n = as_int64_array(max)?.values().to_vec(); - array_remove_internal(array, element, &arr_n) -} - -fn array_remove_all_inner(args: &[ArrayRef]) -> Result { - let [array, element] = take_function_args("array_remove_all", args)?; - - let arr_n = vec![i64::MAX; array.len()]; - array_remove_internal(array, element, &arr_n) -} - fn array_remove_internal( array: &ArrayRef, element_array: &ArrayRef, @@ -357,6 +381,47 @@ fn array_remove_internal( } } +/// Dispatches scalar-needle array removal by list offset type. +/// +/// `needle` must be a length-1 array containing the scalar element to remove. +fn array_remove_dispatch_scalar( + array: &ArrayRef, + needle: &ArrayRef, + arr_n: &[i64], +) -> Result { + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_remove_with_scalar::(list_array, needle, arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_remove_with_scalar::(list_array, needle, arr_n) + } + array_type => exec_err!( + "array_remove/array_remove_n/array_remove_all does not support type '{array_type}'." + ), + } +} + +/// Removes elements matching a scalar needle from a list array. +/// +/// Uses a bulk `distinct` comparison for non-null, non-nested scalar elements, +/// falling back to the per-row `general_remove` path for null or nested types. +fn remove_with_scalar_needle( + list_array: &ArrayRef, + scalar_element: &ScalarValue, + arr_n: &[i64], +) -> Result { + if !scalar_element.is_null() && !scalar_element.data_type().is_nested() { + let needle = scalar_element.to_array_of_size(1)?; + array_remove_dispatch_scalar(list_array, &needle, arr_n) + } else { + let needle_array = scalar_element.to_array_of_size(list_array.len())?; + array_remove_internal(list_array, &needle_array, arr_n) + } +} + /// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences /// of `element_array[i]`. /// @@ -468,6 +533,114 @@ fn general_remove( )?)) } +/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences +/// of `needle[0]` (scalar element broadcasted). +/// +/// This is a specialized version of `general_remove` for scalar elements that +/// uses bulk comparison for better performance. +fn general_remove_with_scalar( + list_array: &GenericListArray, + needle: &ArrayRef, + arr_n: &[i64], +) -> Result { + let list_field = match list_array.data_type() { + DataType::List(field) | DataType::LargeList(field) => field, + _ => { + return exec_err!( + "Expected List or LargeList data type, got {:?}", + list_array.data_type() + ); + } + }; + + let list_offsets = list_array.offsets(); + let first_offset = list_offsets[0].to_usize().unwrap(); + let last_offset = list_offsets[list_offsets.len() - 1].to_usize().unwrap(); + let values_range_len = last_offset - first_offset; + let values_slice = list_array.values().slice(first_offset, values_range_len); + let original_data = values_slice.to_data(); + let mut offsets = Vec::::with_capacity(list_array.len() + 1); + offsets.push(OffsetSize::zero()); + + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + Capacities::Array(original_data.len()), + ); + let nulls = list_array.nulls().cloned(); + let keep_mask = + arrow_ord::cmp::distinct(&values_slice, &Scalar::new(Arc::clone(needle)))?; + let remove_bits = match keep_mask.nulls() { + Some(validity) => !(&(keep_mask.values() & validity.inner())), + None => !keep_mask.values(), + }; + + for (row_index, offset_window) in list_offsets.windows(2).enumerate() { + if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) { + offsets.push(offsets[row_index]); + continue; + } + + let start = offset_window[0].to_usize().unwrap() - first_offset; + let end = offset_window[1].to_usize().unwrap() - first_offset; + + let n = arr_n[row_index]; + + if n <= 0 { + mutable.extend(0, start, end); + offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start)); + continue; + } + + let row_len = end - start; + let row_remove_bits = remove_bits.slice(start, row_len); + let num_to_remove = row_remove_bits.count_set_bits(); + + if num_to_remove == 0 { + mutable.extend(0, start, end); + offsets.push(offsets[row_index] + OffsetSize::usize_as(row_len)); + continue; + } + + let max_removals = n.min(num_to_remove as i64) as usize; + + // Iterate only over the removal positions via set_indices. This is + // efficient when the number of removals is small relative to the row + // length (common case), since it skips over retained elements. + let mut removed = 0usize; + let mut copied = 0usize; + let mut prev_end = start; // end of last copied range (absolute index into values_slice) + for remove_pos in row_remove_bits.set_indices() { + let abs_pos = start + remove_pos; + // Copy the range before this removal position + if abs_pos > prev_end { + mutable.extend(0, prev_end, abs_pos); + copied += abs_pos - prev_end; + } + prev_end = abs_pos + 1; + removed += 1; + if removed == max_removals { + break; + } + } + // Copy the remaining tail after the last removal + if prev_end < end { + mutable.extend(0, prev_end, end); + copied += end - prev_end; + } + + offsets.push(offsets[row_index] + OffsetSize::usize_as(copied)); + } + + let new_values = make_array(mutable.freeze()); + Ok(Arc::new(GenericListArray::::try_new( + Arc::clone(list_field), + OffsetBuffer::new(offsets.into()), + new_values, + nulls, + )?)) +} + #[cfg(test)] mod tests { use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN}; diff --git a/datafusion/sqllogictest/test_files/array/array_remove.slt b/datafusion/sqllogictest/test_files/array/array_remove.slt index c3ce7073eca83..456ebb6482341 100644 --- a/datafusion/sqllogictest/test_files/array/array_remove.slt +++ b/datafusion/sqllogictest/test_files/array/array_remove.slt @@ -537,4 +537,76 @@ select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +# array_remove scalar arguments over multiple input rows +query ??? +select + array_remove(column1, 2), + array_remove_n(column1, 2, 2), + array_remove_all(column1, 2) +from ( + values + (make_array(1, 2, 2, 3, 2, 1, 4)), + (make_array(42, 2, 55, 63, 2)) +) as t(column1); +---- +[1, 2, 3, 2, 1, 4] [1, 3, 2, 1, 4] [1, 3, 1, 4] +[42, 55, 63, 2] [42, 55, 63] [42, 55, 63] + +# array_remove with elements containing NULLs — scalar path preserves NULLs +query ??? +select + array_remove(column1, 2), + array_remove_n(column1, 2, 2), + array_remove_all(column1, 2) +from ( + values + (make_array(1, 2, NULL, 3, 2, NULL, 4)), + (make_array(42, 2, NULL, 63, 2)) +) as t(column1); +---- +[1, NULL, 3, 2, NULL, 4] [1, NULL, 3, NULL, 4] [1, NULL, 3, NULL, 4] +[42, NULL, 63, 2] [42, NULL, 63] [42, NULL, 63] + +# array_remove_n with n exceeding match count +query ? +select array_remove_n(make_array(1, 2, 2, 3), 2, 100); +---- +[1, 3] + +# array_remove_n with n=0 and n=-1 (no removal) +query ?? +select + array_remove_n(make_array(1, 2, 2, 3), 2, 0), + array_remove_n(make_array(1, 2, 2, 3), 2, -1); +---- +[1, 2, 2, 3] [1, 2, 2, 3] + +# array_remove on empty arrays +query ?? +select + array_remove(arrow_cast(make_array(), 'List(Int64)'), 1), + array_remove_all(arrow_cast(make_array(), 'List(Int64)'), 1); +---- +[] [] + +# array_remove needle not found — array unchanged +query ? +select array_remove_all(make_array(1, 2, 3, 4, 5), 99); +---- +[1, 2, 3, 4, 5] + +# array_remove all elements match +query ? +select array_remove_all(make_array(7, 7, 7, 7), 7); +---- +[] + +# LargeList scalar path edge cases +query ?? +select + array_remove_all(arrow_cast(make_array(1, 1, 1), 'LargeList(Int64)'), 1), + array_remove_n(arrow_cast(make_array(1, 1, 1), 'LargeList(Int64)'), 1, 2); +---- +[] [1] + include ./cleanup.slt.part