-
Notifications
You must be signed in to change notification settings - Fork 2.1k
perf: optimize array_remove for scalar needle #22390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
9b3f32f
0c787f4
9697c3a
a6e6ad1
b6023a0
2e7cd40
cdb6021
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,16 +18,17 @@ | |
| //! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions. | ||
|
|
||
| use crate::utils; | ||
| use crate::utils::make_scalar_function; | ||
| use arrow::array::{ | ||
| Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait, | ||
| cast::AsArray, make_array, | ||
| Scalar, cast::AsArray, make_array, | ||
| }; | ||
| use arrow::buffer::{NullBuffer, OffsetBuffer}; | ||
| use arrow::datatypes::{DataType, FieldRef}; | ||
| use datafusion_common::cast::as_int64_array; | ||
| use datafusion_common::utils::ListCoercion; | ||
| use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args}; | ||
| use datafusion_common::{ | ||
| Result, ScalarValue, exec_err, internal_err, utils::take_function_args, | ||
| }; | ||
| use datafusion_expr::{ | ||
| ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, | ||
| ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, | ||
|
|
@@ -113,7 +114,21 @@ impl ScalarUDFImpl for ArrayRemove { | |
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| make_scalar_function(array_remove_inner)(&args.args) | ||
| let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?; | ||
| let num_rows = args.number_rows; | ||
| let list_array = list_arg.to_array(num_rows)?; | ||
| let arr_n = vec![1; num_rows]; | ||
| match element_arg { | ||
| ColumnarValue::Array(element_array) => { | ||
| let result = array_remove_internal(&list_array, element_array, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| ColumnarValue::Scalar(scalar_element) => { | ||
| let result = | ||
| remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn aliases(&self) -> &[String] { | ||
|
|
@@ -214,7 +229,23 @@ impl ScalarUDFImpl for ArrayRemoveN { | |
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| make_scalar_function(array_remove_n_inner)(&args.args) | ||
| let [list_arg, element_arg, max_arg] = | ||
| take_function_args(self.name(), &args.args)?; | ||
| let num_rows = args.number_rows; | ||
| let list_array = list_arg.to_array(num_rows)?; | ||
| let max_array = max_arg.to_array(num_rows)?; | ||
| let arr_n = as_int64_array(&max_array)?.values().to_vec(); | ||
| match element_arg { | ||
| ColumnarValue::Array(element_array) => { | ||
| let result = array_remove_internal(&list_array, element_array, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| ColumnarValue::Scalar(scalar_element) => { | ||
| let result = | ||
| remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn aliases(&self) -> &[String] { | ||
|
|
@@ -304,7 +335,21 @@ impl ScalarUDFImpl for ArrayRemoveAll { | |
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| make_scalar_function(array_remove_all_inner)(&args.args) | ||
| let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?; | ||
| let num_rows = args.number_rows; | ||
| let list_array = list_arg.to_array(num_rows)?; | ||
| let arr_n = vec![i64::MAX; num_rows]; | ||
| match element_arg { | ||
| ColumnarValue::Array(element_array) => { | ||
| let result = array_remove_internal(&list_array, element_array, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| ColumnarValue::Scalar(scalar_element) => { | ||
| let result = | ||
| remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn aliases(&self) -> &[String] { | ||
|
|
@@ -316,27 +361,6 @@ impl ScalarUDFImpl for ArrayRemoveAll { | |
| } | ||
| } | ||
|
|
||
| fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [array, element] = take_function_args("array_remove", args)?; | ||
|
|
||
| let arr_n = vec![1; array.len()]; | ||
| array_remove_internal(array, element, &arr_n) | ||
| } | ||
|
|
||
| fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [array, element, max] = take_function_args("array_remove_n", args)?; | ||
|
|
||
| let arr_n = as_int64_array(max)?.values().to_vec(); | ||
| array_remove_internal(array, element, &arr_n) | ||
| } | ||
|
|
||
| fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [array, element] = take_function_args("array_remove_all", args)?; | ||
|
|
||
| let arr_n = vec![i64::MAX; array.len()]; | ||
| array_remove_internal(array, element, &arr_n) | ||
| } | ||
|
|
||
| fn array_remove_internal( | ||
| array: &ArrayRef, | ||
| element_array: &ArrayRef, | ||
|
|
@@ -357,6 +381,45 @@ fn array_remove_internal( | |
| } | ||
| } | ||
|
|
||
| /// Dispatches scalar-needle array removal by list offset type. | ||
| /// | ||
| /// `needle` must be a length-1 array containing the scalar element to remove. | ||
| fn array_remove_dispatch_scalar( | ||
| array: &ArrayRef, | ||
| needle: &ArrayRef, | ||
| arr_n: &[i64], | ||
| ) -> Result<ArrayRef> { | ||
| match array.data_type() { | ||
| DataType::List(_) => { | ||
| let list_array = array.as_list::<i32>(); | ||
| general_remove_with_scalar::<i32>(list_array, needle, arr_n) | ||
| } | ||
| DataType::LargeList(_) => { | ||
| let list_array = array.as_list::<i64>(); | ||
| general_remove_with_scalar::<i64>(list_array, needle, arr_n) | ||
| } | ||
| array_type => exec_err!("array_remove does not support type '{array_type}'."), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is called by more than just `array_remove`
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure, updated. |
||
| } | ||
| } | ||
|
|
||
| /// Removes elements matching a scalar needle from a list array. | ||
| /// | ||
| /// Uses a bulk `distinct` comparison for non-null, non-nested scalar elements, | ||
| /// falling back to the per-row `general_remove` path for null or nested types. | ||
| fn remove_with_scalar_needle( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I find this function a little confusing, since the code paths can then be like: Perhaps we can hoist this null/nested check earlier to remove need for this function? Or the benefits of centralizing the check outweighs it? |
||
| list_array: &ArrayRef, | ||
| scalar_element: &ScalarValue, | ||
| arr_n: &[i64], | ||
| ) -> Result<ArrayRef> { | ||
| if !scalar_element.is_null() && !scalar_element.data_type().is_nested() { | ||
| let needle = scalar_element.to_array_of_size(1)?; | ||
| array_remove_dispatch_scalar(list_array, &needle, arr_n) | ||
| } else { | ||
| let needle_array = scalar_element.to_array_of_size(list_array.len())?; | ||
| array_remove_internal(list_array, &needle_array, arr_n) | ||
| } | ||
| } | ||
|
|
||
| /// For each element of `list_array[i]`, removes up to `arr_n[i]` occurrences | ||
| /// of `element_array[i]`. | ||
| /// | ||
|
|
@@ -468,6 +531,98 @@ fn general_remove<OffsetSize: OffsetSizeTrait>( | |
| )?)) | ||
| } | ||
|
|
||
| /// For each element of `list_array[i]`, removes up to `arr_n[i]` occurrences | ||
| /// of `needle[0]` (scalar element broadcasted). | ||
| /// | ||
| /// This is a specialized version of `general_remove` for scalar elements that | ||
| /// uses bulk comparison for better performance. | ||
| fn general_remove_with_scalar<OffsetSize: OffsetSizeTrait>( | ||
| list_array: &GenericListArray<OffsetSize>, | ||
| needle: &ArrayRef, | ||
| arr_n: &[i64], | ||
|
Comment on lines +543 to +544
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar question to that of array_replace PR, in if we should just handle when needle & max are scalars only |
||
| ) -> Result<ArrayRef> { | ||
| let list_field = match list_array.data_type() { | ||
| DataType::List(field) | DataType::LargeList(field) => field, | ||
| _ => { | ||
| return exec_err!( | ||
| "Expected List or LargeList data type, got {:?}", | ||
| list_array.data_type() | ||
| ); | ||
| } | ||
| }; | ||
| let original_data = list_array.values().to_data(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will be inefficient for sliced arrays.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I now slice the values to the range actually referenced by the offsets. That said, I wanted to understand your concern better: when a
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The over-allocation was one part, but the bigger concern is calling the |
||
| let mut offsets = Vec::<OffsetSize>::with_capacity(list_array.len() + 1); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||
| offsets.push(OffsetSize::zero()); | ||
|
|
||
| let mut mutable = MutableArrayData::with_capacities( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if an approach using |
||
| vec![&original_data], | ||
| false, | ||
| Capacities::Array(original_data.len()), | ||
| ); | ||
| let nulls = list_array.nulls().cloned(); | ||
| let keep_mask = | ||
| arrow_ord::cmp::distinct(list_array.values(), &Scalar::new(Arc::clone(needle)))?; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will call the distinct kernel on all the elements in the value buffer, not just the ones that are visible in a sliced array. |
||
|
|
||
| for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { | ||
| if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) { | ||
| offsets.push(offsets[row_index]); | ||
| continue; | ||
| } | ||
|
|
||
| let start = offset_window[0].to_usize().unwrap(); | ||
| let end = offset_window[1].to_usize().unwrap(); | ||
|
|
||
| let n = arr_n[row_index]; | ||
|
|
||
| if n <= 0 { | ||
| mutable.extend(0, start, end); | ||
| offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start)); | ||
| continue; | ||
| } | ||
|
|
||
| let eq_array = keep_mask.slice(start, end - start); | ||
| let num_to_remove = eq_array.false_count(); | ||
|
|
||
| if num_to_remove == 0 { | ||
| mutable.extend(0, start, end); | ||
| offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start)); | ||
| continue; | ||
| } | ||
|
|
||
| let max_removals = n.min(num_to_remove as i64); | ||
| let mut removed = 0i64; | ||
| let mut copied = 0usize; | ||
| let mut pending_batch_to_retain: Option<usize> = None; | ||
| for (i, keep) in eq_array.iter().enumerate() { | ||
| if keep == Some(false) && removed < max_removals { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we break from the loop once we hit
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. now break early once |
||
| if let Some(bs) = pending_batch_to_retain { | ||
| mutable.extend(0, start + bs, start + i); | ||
| copied += i - bs; | ||
| pending_batch_to_retain = None; | ||
| } | ||
| removed += 1; | ||
| } else if pending_batch_to_retain.is_none() { | ||
| pending_batch_to_retain = Some(i); | ||
| } | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if it would be possible to iterate only over the "false" bits, e.g., by negating the buffer and looking at
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great suggestion. Benchmarks show a ~20–40% improvement with this optimization.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Amazing! |
||
|
|
||
| if let Some(bs) = pending_batch_to_retain { | ||
| mutable.extend(0, start + bs, end); | ||
| copied += end - start - bs; | ||
| } | ||
|
|
||
| offsets.push(offsets[row_index] + OffsetSize::usize_as(copied)); | ||
| } | ||
|
|
||
| let new_values = make_array(mutable.freeze()); | ||
| Ok(Arc::new(GenericListArray::<OffsetSize>::try_new( | ||
| Arc::clone(list_field), | ||
| OffsetBuffer::new(offsets.into()), | ||
| new_values, | ||
| nulls, | ||
| )?)) | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN}; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're ignoring nulls in `max_arg` here, though I think this may be an existing issue.