-
Notifications
You must be signed in to change notification settings - Fork 2.1k
perf: optimize array_remove for scalar needle #22390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
9b3f32f
0c787f4
9697c3a
a6e6ad1
b6023a0
2e7cd40
cdb6021
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,16 +18,17 @@ | |
| //! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions. | ||
|
|
||
| use crate::utils; | ||
| use crate::utils::make_scalar_function; | ||
| use arrow::array::{ | ||
| Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait, | ||
| cast::AsArray, make_array, | ||
| Scalar, cast::AsArray, make_array, | ||
| }; | ||
| use arrow::buffer::{NullBuffer, OffsetBuffer}; | ||
| use arrow::datatypes::{DataType, FieldRef}; | ||
| use datafusion_common::cast::as_int64_array; | ||
| use datafusion_common::utils::ListCoercion; | ||
| use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args}; | ||
| use datafusion_common::{ | ||
| Result, ScalarValue, exec_err, internal_err, utils::take_function_args, | ||
| }; | ||
| use datafusion_expr::{ | ||
| ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, | ||
| ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, | ||
|
|
@@ -113,7 +114,21 @@ impl ScalarUDFImpl for ArrayRemove { | |
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| make_scalar_function(array_remove_inner)(&args.args) | ||
| let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?; | ||
| let num_rows = args.number_rows; | ||
| let list_array = list_arg.to_array(num_rows)?; | ||
| let arr_n = vec![1; num_rows]; | ||
| match element_arg { | ||
| ColumnarValue::Array(element_array) => { | ||
| let result = array_remove_internal(&list_array, element_array, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| ColumnarValue::Scalar(scalar_element) => { | ||
| let result = | ||
| remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn aliases(&self) -> &[String] { | ||
|
|
@@ -214,7 +229,23 @@ impl ScalarUDFImpl for ArrayRemoveN { | |
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| make_scalar_function(array_remove_n_inner)(&args.args) | ||
| let [list_arg, element_arg, max_arg] = | ||
| take_function_args(self.name(), &args.args)?; | ||
| let num_rows = args.number_rows; | ||
| let list_array = list_arg.to_array(num_rows)?; | ||
| let max_array = max_arg.to_array(num_rows)?; | ||
| let arr_n = as_int64_array(&max_array)?.values().to_vec(); | ||
| match element_arg { | ||
| ColumnarValue::Array(element_array) => { | ||
| let result = array_remove_internal(&list_array, element_array, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| ColumnarValue::Scalar(scalar_element) => { | ||
| let result = | ||
| remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn aliases(&self) -> &[String] { | ||
|
|
@@ -304,7 +335,21 @@ impl ScalarUDFImpl for ArrayRemoveAll { | |
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| make_scalar_function(array_remove_all_inner)(&args.args) | ||
| let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?; | ||
| let num_rows = args.number_rows; | ||
| let list_array = list_arg.to_array(num_rows)?; | ||
| let arr_n = vec![i64::MAX; num_rows]; | ||
| match element_arg { | ||
| ColumnarValue::Array(element_array) => { | ||
| let result = array_remove_internal(&list_array, element_array, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| ColumnarValue::Scalar(scalar_element) => { | ||
| let result = | ||
| remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn aliases(&self) -> &[String] { | ||
|
|
@@ -316,27 +361,6 @@ impl ScalarUDFImpl for ArrayRemoveAll { | |
| } | ||
| } | ||
|
|
||
| fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [array, element] = take_function_args("array_remove", args)?; | ||
|
|
||
| let arr_n = vec![1; array.len()]; | ||
| array_remove_internal(array, element, &arr_n) | ||
| } | ||
|
|
||
| fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [array, element, max] = take_function_args("array_remove_n", args)?; | ||
|
|
||
| let arr_n = as_int64_array(max)?.values().to_vec(); | ||
| array_remove_internal(array, element, &arr_n) | ||
| } | ||
|
|
||
| fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [array, element] = take_function_args("array_remove_all", args)?; | ||
|
|
||
| let arr_n = vec![i64::MAX; array.len()]; | ||
| array_remove_internal(array, element, &arr_n) | ||
| } | ||
|
|
||
| fn array_remove_internal( | ||
| array: &ArrayRef, | ||
| element_array: &ArrayRef, | ||
|
|
@@ -357,6 +381,47 @@ fn array_remove_internal( | |
| } | ||
| } | ||
|
|
||
| /// Dispatches scalar-needle removal on the list offset width: `List` is handled with `i32` offsets, `LargeList` with `i64`, and any other type returns an execution error. | ||
| /// | ||
| /// `needle` must be a length-1 array containing the scalar element to remove; `arr_n` is forwarded unchanged to `general_remove_with_scalar`. | ||
| fn array_remove_dispatch_scalar( | ||
| array: &ArrayRef, | ||
| needle: &ArrayRef, | ||
| arr_n: &[i64], | ||
| ) -> Result<ArrayRef> { | ||
| match array.data_type() { | ||
| DataType::List(_) => { | ||
| let list_array = array.as_list::<i32>(); | ||
| general_remove_with_scalar::<i32>(list_array, needle, arr_n) | ||
| } | ||
| DataType::LargeList(_) => { | ||
| let list_array = array.as_list::<i64>(); | ||
| general_remove_with_scalar::<i64>(list_array, needle, arr_n) | ||
| } | ||
| array_type => exec_err!( | ||
| "array_remove/array_remove_n/array_remove_all does not support type '{array_type}'." | ||
| ), | ||
| } | ||
| } | ||
|
|
||
| /// Removes elements matching a scalar needle from a list array. | ||
| /// | ||
| /// Uses a bulk `distinct` comparison for non-null, non-nested scalar elements. | ||
| /// Null or nested needles are broadcast to one entry per row and fall back to `array_remove_internal`. | ||
| fn remove_with_scalar_needle( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I find this function a little confusing, since the code paths can then be like: [code-path example not captured in this extract]. Perhaps we can hoist this null/nested check earlier to remove the need for this function? Or do the benefits of centralizing the check outweigh it? |
||
| list_array: &ArrayRef, | ||
| scalar_element: &ScalarValue, | ||
| arr_n: &[i64], | ||
| ) -> Result<ArrayRef> { | ||
| if !scalar_element.is_null() && !scalar_element.data_type().is_nested() { | ||
| let needle = scalar_element.to_array_of_size(1)?; | ||
| array_remove_dispatch_scalar(list_array, &needle, arr_n) | ||
| } else { | ||
| let needle_array = scalar_element.to_array_of_size(list_array.len())?; | ||
| array_remove_internal(list_array, &needle_array, arr_n) | ||
| } | ||
| } | ||
|
|
||
| /// For each element of `list_array[i]`, removes up to `arr_n[i]` occurrences | ||
| /// of `element_array[i]`. | ||
| /// | ||
|
|
@@ -468,6 +533,113 @@ fn general_remove<OffsetSize: OffsetSizeTrait>( | |
| )?)) | ||
| } | ||
|
|
||
| /// For each row `list_array[i]`, removes up to `arr_n[i]` occurrences | ||
| /// of `needle[0]` (the scalar element, broadcast to every row). | ||
| /// | ||
| /// This is a specialized version of `general_remove` for scalar elements that | ||
| /// uses bulk comparison for better performance. | ||
| fn general_remove_with_scalar<OffsetSize: OffsetSizeTrait>( | ||
| list_array: &GenericListArray<OffsetSize>, | ||
| needle: &ArrayRef, | ||
| arr_n: &[i64], | ||
|
Comment on lines
+543
to
+544
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar question to the one on the array_replace PR: should we only handle the case where both needle & max are scalars? |
||
| ) -> Result<ArrayRef> { | ||
| let list_field = match list_array.data_type() { | ||
| DataType::List(field) | DataType::LargeList(field) => field, | ||
| _ => { | ||
| return exec_err!( | ||
| "Expected List or LargeList data type, got {:?}", | ||
| list_array.data_type() | ||
| ); | ||
| } | ||
| }; | ||
|
|
||
| let list_offsets = list_array.offsets(); | ||
| let first_offset = list_offsets[0].to_usize().unwrap(); | ||
| let last_offset = list_offsets[list_offsets.len() - 1].to_usize().unwrap(); | ||
| let values_range_len = last_offset - first_offset; | ||
| let values_slice = list_array.values().slice(first_offset, values_range_len); | ||
| let original_data = values_slice.to_data(); | ||
| let mut offsets = Vec::<OffsetSize>::with_capacity(list_array.len() + 1); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||
| offsets.push(OffsetSize::zero()); | ||
|
|
||
| let mut mutable = MutableArrayData::with_capacities( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if an approach using |
||
| vec![&original_data], | ||
| false, | ||
| Capacities::Array(original_data.len()), | ||
| ); | ||
| let nulls = list_array.nulls().cloned(); | ||
| let keep_mask = | ||
| arrow_ord::cmp::distinct(list_array.values(), &Scalar::new(Arc::clone(needle)))?; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will call the distinct kernel on all the elements in the value buffer, not just the ones that are visible in a sliced array. |
||
| let remove_bits = match keep_mask.nulls() { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
https://docs.rs/arrow/latest/arrow/compute/kernels/cmp/fn.distinct.html I don't think we should have handling around the null buffer of |
||
| Some(validity) => !(&(keep_mask.values() & validity.inner())), | ||
| None => !keep_mask.values(), | ||
| }; | ||
|
|
||
| for (row_index, offset_window) in list_offsets.windows(2).enumerate() { | ||
| if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) { | ||
| offsets.push(offsets[row_index]); | ||
| continue; | ||
| } | ||
|
|
||
| let start = offset_window[0].to_usize().unwrap() - first_offset; | ||
| let end = offset_window[1].to_usize().unwrap() - first_offset; | ||
|
|
||
| let n = arr_n[row_index]; | ||
|
|
||
| if n <= 0 { | ||
| mutable.extend(0, start, end); | ||
| offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start)); | ||
| continue; | ||
| } | ||
|
|
||
| let row_len = end - start; | ||
| let row_remove_bits = remove_bits.slice(first_offset + start, row_len); | ||
| let num_to_remove = row_remove_bits.count_set_bits(); | ||
|
|
||
| if num_to_remove == 0 { | ||
| mutable.extend(0, start, end); | ||
| offsets.push(offsets[row_index] + OffsetSize::usize_as(row_len)); | ||
| continue; | ||
| } | ||
|
|
||
| let max_removals = n.min(num_to_remove as i64) as usize; | ||
|
|
||
| // Iterate only over the positions that need removal using set_indices, | ||
| // which is more efficient than scanning every bit. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Might be worth elaborating that the win here is mostly because we expect the number of values to remove to be a lot smaller than the total array size, which it usually (but not always) will be. |
||
| let mut removed = 0usize; | ||
| let mut copied = 0usize; | ||
| let mut prev_end = start; // end of last copied range (absolute index into values_slice) | ||
| for remove_pos in row_remove_bits.set_indices() { | ||
| let abs_pos = start + remove_pos; | ||
| // Copy the range before this removal position | ||
| if abs_pos > prev_end { | ||
| mutable.extend(0, prev_end, abs_pos); | ||
| copied += abs_pos - prev_end; | ||
| } | ||
| prev_end = abs_pos + 1; | ||
| removed += 1; | ||
| if removed == max_removals { | ||
| break; | ||
| } | ||
| } | ||
| // Copy the remaining tail after the last removal | ||
| if prev_end < end { | ||
| mutable.extend(0, prev_end, end); | ||
| copied += end - prev_end; | ||
| } | ||
|
|
||
| offsets.push(offsets[row_index] + OffsetSize::usize_as(copied)); | ||
| } | ||
|
|
||
| let new_values = make_array(mutable.freeze()); | ||
| Ok(Arc::new(GenericListArray::<OffsetSize>::try_new( | ||
| Arc::clone(list_field), | ||
| OffsetBuffer::new(offsets.into()), | ||
| new_values, | ||
| nulls, | ||
| )?)) | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN}; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're ignoring nulls in `max_arg` here, though I think this may be an existing issue