-
Notifications
You must be signed in to change notification settings - Fork 2.1k
perf: optimize array_replace for scalar needle
#22387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
1468803
3e79a38
47b811b
11dfa84
3bae38b
e50617c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,22 +19,21 @@ | |
|
|
||
| use arrow::array::{ | ||
| Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData, | ||
| NullBufferBuilder, OffsetSizeTrait, new_null_array, | ||
| NullBufferBuilder, OffsetSizeTrait, Scalar, new_null_array, | ||
| }; | ||
| use arrow::datatypes::{DataType, Field}; | ||
|
|
||
| use arrow::buffer::OffsetBuffer; | ||
| use datafusion_common::cast::as_int64_array; | ||
| use datafusion_common::utils::ListCoercion; | ||
| use datafusion_common::{Result, exec_err, utils::take_function_args}; | ||
| use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; | ||
| use datafusion_expr::{ | ||
| ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, | ||
| ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, | ||
| }; | ||
| use datafusion_macros::user_doc; | ||
|
|
||
| use crate::utils::compare_element_to_list; | ||
| use crate::utils::make_scalar_function; | ||
|
|
||
| use std::sync::Arc; | ||
|
|
||
|
|
@@ -125,7 +124,27 @@ impl ScalarUDFImpl for ArrayReplace { | |
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| make_scalar_function(array_replace_inner)(&args.args) | ||
| let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; | ||
| let num_rows = args.number_rows; | ||
| let list_array = list_arg.to_array(num_rows)?; | ||
| let to_array = to_arg.to_array(num_rows)?; | ||
| let arr_n = vec![1; num_rows]; | ||
| match from_arg { | ||
| ColumnarValue::Array(from_array) => { | ||
| let result = | ||
| array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| ColumnarValue::Scalar(scalar_from) => { | ||
| let result = replace_with_scalar_needle( | ||
| &list_array, | ||
| scalar_from, | ||
| &to_array, | ||
| &arr_n, | ||
| )?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn aliases(&self) -> &[String] { | ||
|
|
@@ -200,7 +219,29 @@ impl ScalarUDFImpl for ArrayReplaceN { | |
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| make_scalar_function(array_replace_n_inner)(&args.args) | ||
| let [list_arg, from_arg, to_arg, max_arg] = | ||
| take_function_args(self.name(), &args.args)?; | ||
| let num_rows = args.number_rows; | ||
| let list_array = list_arg.to_array(num_rows)?; | ||
| let to_array = to_arg.to_array(num_rows)?; | ||
| let max_array = max_arg.to_array(num_rows)?; | ||
| let arr_n = as_int64_array(&max_array)?.values().to_vec(); | ||
| match from_arg { | ||
| ColumnarValue::Array(from_array) => { | ||
| let result = | ||
| array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| ColumnarValue::Scalar(scalar_from) => { | ||
| let result = replace_with_scalar_needle( | ||
| &list_array, | ||
| scalar_from, | ||
| &to_array, | ||
| &arr_n, | ||
| )?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn aliases(&self) -> &[String] { | ||
|
|
@@ -273,7 +314,27 @@ impl ScalarUDFImpl for ArrayReplaceAll { | |
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| make_scalar_function(array_replace_all_inner)(&args.args) | ||
| let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; | ||
| let num_rows = args.number_rows; | ||
| let list_array = list_arg.to_array(num_rows)?; | ||
| let to_array = to_arg.to_array(num_rows)?; | ||
| let arr_n = vec![i64::MAX; num_rows]; | ||
| match from_arg { | ||
| ColumnarValue::Array(from_array) => { | ||
| let result = | ||
| array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| ColumnarValue::Scalar(scalar_from) => { | ||
| let result = replace_with_scalar_needle( | ||
| &list_array, | ||
| scalar_from, | ||
| &to_array, | ||
| &arr_n, | ||
| )?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn aliases(&self) -> &[String] { | ||
|
|
@@ -412,63 +473,165 @@ fn general_replace<O: OffsetSizeTrait>( | |
| )?)) | ||
| } | ||
|
|
||
| fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [array, from, to] = take_function_args("array_replace", args)?; | ||
| /// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences | ||
| /// of `needle[0]` with `to_array[i]`. | ||
| /// | ||
| /// This is a specialized version of `general_replace` for scalar needles that | ||
| /// uses a single bulk comparison for better performance. | ||
| fn general_replace_with_scalar<O: OffsetSizeTrait>( | ||
| list_array: &GenericListArray<O>, | ||
| needle: &ArrayRef, | ||
| to_array: &ArrayRef, | ||
| arr_n: &[i64], | ||
| ) -> Result<ArrayRef> { | ||
| let mut offsets: Vec<O> = vec![O::usize_as(0)]; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using |
||
| let values = list_array.values(); | ||
| let original_data = values.to_data(); | ||
| let to_data = to_array.to_data(); | ||
| let capacity = Capacities::Array(original_data.len()); | ||
|
|
||
| // replace at most one occurrence for each element | ||
| let arr_n = vec![1; array.len()]; | ||
| match array.data_type() { | ||
| DataType::List(_) => { | ||
| let list_array = array.as_list::<i32>(); | ||
| general_replace::<i32>(list_array, from, to, &arr_n) | ||
| let mut mutable = MutableArrayData::with_capacities( | ||
| vec![&original_data, &to_data], | ||
| false, | ||
| capacity, | ||
| ); | ||
|
|
||
| let mut valid = NullBufferBuilder::new(list_array.len()); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think we need a builder for nulls; we can copy the input array's null buffer as is. |
||
| let match_array = arrow_ord::cmp::not_distinct( | ||
| list_array.values(), | ||
| &Scalar::new(Arc::clone(needle)), | ||
| )?; | ||
|
|
||
| for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { | ||
| if list_array.is_null(row_index) { | ||
| offsets.push(offsets[row_index]); | ||
| valid.append_null(); | ||
| continue; | ||
| } | ||
| DataType::LargeList(_) => { | ||
| let list_array = array.as_list::<i64>(); | ||
| general_replace::<i64>(list_array, from, to, &arr_n) | ||
|
|
||
| let start = offset_window[0]; | ||
| let end = offset_window[1]; | ||
| let start_usize = start.to_usize().unwrap(); | ||
| let end_usize = end.to_usize().unwrap(); | ||
| let row_len = end_usize - start_usize; | ||
|
|
||
| let original_idx = O::usize_as(0); | ||
| let replace_idx = O::usize_as(1); | ||
| let n = arr_n[row_index]; | ||
| let mut counter = 0; | ||
|
|
||
| let eq_array = match_array.slice(start_usize, row_len); | ||
| if n <= 0 || eq_array.true_count() == 0 { | ||
| mutable.extend(original_idx.to_usize().unwrap(), start_usize, end_usize); | ||
| offsets.push(offsets[row_index] + (end - start)); | ||
| valid.append_non_null(); | ||
| continue; | ||
| } | ||
| DataType::Null => Ok(new_null_array(array.data_type(), 1)), | ||
| array_type => exec_err!("array_replace does not support type '{array_type}'."), | ||
|
|
||
| let mut pending_retain: Option<O> = None; | ||
| for (i, is_match) in eq_array.iter().enumerate() { | ||
| let i = O::usize_as(i); | ||
| if is_match == Some(true) && counter < n { | ||
| if let Some(rs) = pending_retain.take() { | ||
| mutable.extend( | ||
| original_idx.to_usize().unwrap(), | ||
| start_usize + rs.to_usize().unwrap(), | ||
| start_usize + i.to_usize().unwrap(), | ||
| ); | ||
| } | ||
| mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); | ||
| counter += 1; | ||
| if counter == n { | ||
| mutable.extend( | ||
| original_idx.to_usize().unwrap(), | ||
| start_usize + i.to_usize().unwrap() + 1, | ||
| end_usize, | ||
| ); | ||
| break; | ||
| } | ||
| } else if pending_retain.is_none() { | ||
| pending_retain = Some(i); | ||
| } | ||
| } | ||
|
|
||
| if counter < n | ||
| && let Some(rs) = pending_retain | ||
| { | ||
| mutable.extend( | ||
| original_idx.to_usize().unwrap(), | ||
| start_usize + rs.to_usize().unwrap(), | ||
| end_usize, | ||
| ); | ||
| } | ||
|
|
||
| offsets.push(offsets[row_index] + (end - start)); | ||
| valid.append_non_null(); | ||
| } | ||
| } | ||
|
|
||
| fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [array, from, to, max] = take_function_args("array_replace_n", args)?; | ||
| let data = mutable.freeze(); | ||
|
|
||
| // replace the specified number of occurrences | ||
| let arr_n = as_int64_array(max)?.values().to_vec(); | ||
| Ok(Arc::new(GenericListArray::<O>::try_new( | ||
| Arc::new(Field::new_list_field(list_array.value_type(), true)), | ||
| OffsetBuffer::<O>::new(offsets.into()), | ||
| arrow::array::make_array(data), | ||
| valid.finish(), | ||
| )?)) | ||
| } | ||
|
|
||
| /// Dispatches scalar-needle array replacement by list offset type. | ||
| /// | ||
| /// `needle` must be a length-1 array containing the scalar element to replace. | ||
| fn array_replace_dispatch_scalar( | ||
| array: &ArrayRef, | ||
| needle: &ArrayRef, | ||
| to: &ArrayRef, | ||
| arr_n: &[i64], | ||
| ) -> Result<ArrayRef> { | ||
| match array.data_type() { | ||
| DataType::List(_) => { | ||
| let list_array = array.as_list::<i32>(); | ||
| general_replace::<i32>(list_array, from, to, &arr_n) | ||
| general_replace_with_scalar::<i32>(list_array, needle, to, arr_n) | ||
| } | ||
| DataType::LargeList(_) => { | ||
| let list_array = array.as_list::<i64>(); | ||
| general_replace::<i64>(list_array, from, to, &arr_n) | ||
| general_replace_with_scalar::<i64>(list_array, needle, to, arr_n) | ||
| } | ||
| DataType::Null => Ok(new_null_array(array.data_type(), 1)), | ||
| array_type => { | ||
| exec_err!("array_replace_n does not support type '{array_type}'.") | ||
| } | ||
| array_type => exec_err!("array_replace does not support type '{array_type}'."), | ||
| } | ||
| } | ||
|
|
||
| fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let [array, from, to] = take_function_args("array_replace_all", args)?; | ||
| fn replace_with_scalar_needle( | ||
| list_array: &ArrayRef, | ||
| scalar_from: &ScalarValue, | ||
| to_array: &ArrayRef, | ||
| arr_n: &[i64], | ||
| ) -> Result<ArrayRef> { | ||
| if scalar_from.data_type().is_nested() { | ||
| let from_array = scalar_from.to_array_of_size(list_array.len())?; | ||
| return array_replace_internal(list_array, &from_array, to_array, arr_n); | ||
| } | ||
|
|
||
| // replace all occurrences (up to "i64::MAX") | ||
| let arr_n = vec![i64::MAX; array.len()]; | ||
| let needle = scalar_from.to_array_of_size(1)?; | ||
| array_replace_dispatch_scalar(list_array, &needle, to_array, arr_n) | ||
| } | ||
|
|
||
| fn array_replace_internal( | ||
| array: &ArrayRef, | ||
| from: &ArrayRef, | ||
| to: &ArrayRef, | ||
| arr_n: &[i64], | ||
| ) -> Result<ArrayRef> { | ||
| match array.data_type() { | ||
| DataType::List(_) => { | ||
| let list_array = array.as_list::<i32>(); | ||
| general_replace::<i32>(list_array, from, to, &arr_n) | ||
| general_replace::<i32>(list_array, from, to, arr_n) | ||
| } | ||
| DataType::LargeList(_) => { | ||
| let list_array = array.as_list::<i64>(); | ||
| general_replace::<i64>(list_array, from, to, &arr_n) | ||
| general_replace::<i64>(list_array, from, to, arr_n) | ||
| } | ||
| DataType::Null => Ok(new_null_array(array.data_type(), 1)), | ||
| array_type => { | ||
| exec_err!("array_replace_all does not support type '{array_type}'.") | ||
| } | ||
| array_type => exec_err!("array_replace does not support type '{array_type}'."), | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think `needle` should be a `Scalar` in the arguments here, to make it clear this is the scalar (without needing to read the docstring); it could be built in `replace_with_scalar_needle`, for example, since it's still a `ScalarValue` at that point.