From 146880308eb17da86cce9cc4fd0ab5b0745ad29f Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Wed, 20 May 2026 11:16:28 +0800 Subject: [PATCH 1/4] Refactor array replace invocation --- datafusion/functions-nested/src/replace.rs | 355 ++++++++++++++++++--- 1 file changed, 317 insertions(+), 38 deletions(-) diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index a9a53a3cb989f..eb9d60547df55 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -19,14 +19,14 @@ use arrow::array::{ Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData, - NullBufferBuilder, OffsetSizeTrait, new_null_array, + NullBufferBuilder, OffsetSizeTrait, Scalar, new_null_array, }; use arrow::datatypes::{DataType, Field}; use arrow::buffer::OffsetBuffer; use datafusion_common::cast::as_int64_array; use datafusion_common::utils::ListCoercion; -use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -34,7 +34,6 @@ use datafusion_expr::{ use datafusion_macros::user_doc; use crate::utils::compare_element_to_list; -use crate::utils::make_scalar_function; use std::sync::Arc; @@ -125,7 +124,27 @@ impl ScalarUDFImpl for ArrayReplace { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_replace_inner)(&args.args) + let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let to_array = to_arg.to_array(num_rows)?; + let arr_n = vec![1; num_rows]; + match from_arg { + ColumnarValue::Array(from_array) => { + let result = + array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_from) => { + let result = replace_with_scalar_needle( + &list_array, + scalar_from, + &to_array, + &arr_n, + )?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -200,7 +219,29 @@ impl ScalarUDFImpl for ArrayReplaceN { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_replace_n_inner)(&args.args) + let [list_arg, from_arg, to_arg, max_arg] = + take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let to_array = to_arg.to_array(num_rows)?; + let max_array = max_arg.to_array(num_rows)?; + let arr_n = as_int64_array(&max_array)?.values().to_vec(); + match from_arg { + ColumnarValue::Array(from_array) => { + let result = + array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_from) => { + let result = replace_with_scalar_needle( + &list_array, + scalar_from, + &to_array, + &arr_n, + )?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -273,7 +314,27 @@ impl ScalarUDFImpl for ArrayReplaceAll { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_replace_all_inner)(&args.args) + let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let to_array = to_arg.to_array(num_rows)?; + let arr_n = vec![i64::MAX; num_rows]; + match from_arg { + ColumnarValue::Array(from_array) => { + let result = + array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_from) => { + let result = replace_with_scalar_needle( + &list_array, + scalar_from, + &to_array, + &arr_n, + )?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -412,63 +473,281 @@ fn general_replace( )?)) } -fn array_replace_inner(args: &[ArrayRef]) -> Result { - let [array, from, to] = take_function_args("array_replace", args)?; +/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences +/// of `needle[0]` with `to_array[i]`. +/// +/// This is a specialized version of `general_replace` for scalar needles that +/// uses a single bulk comparison for better performance. +fn general_replace_with_scalar( + list_array: &GenericListArray, + needle: &ArrayRef, + to_array: &ArrayRef, + arr_n: &[i64], +) -> Result { + let mut offsets: Vec = vec![O::usize_as(0)]; + let values = list_array.values(); + let original_data = values.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len()); + + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); - // replace at most one occurrence for each element - let arr_n = vec![1; array.len()]; - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + let mut valid = NullBufferBuilder::new(list_array.len()); + let match_array = arrow_ord::cmp::not_distinct( + list_array.values(), + &Scalar::new(Arc::clone(needle)), + )?; + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if list_array.is_null(row_index) { + offsets.push(offsets[row_index]); + valid.append_null(); + continue; } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + + let start = offset_window[0]; + let end = offset_window[1]; + let start_usize = start.to_usize().unwrap(); + let end_usize = end.to_usize().unwrap(); + let row_len = end_usize - start_usize; + + let original_idx = O::usize_as(0); + let replace_idx = O::usize_as(1); + let n = arr_n[row_index]; + let mut counter = 0; + + let eq_array = match_array.slice(start_usize, row_len); + if n <= 0 || eq_array.true_count() == 0 { + mutable.extend(original_idx.to_usize().unwrap(), start_usize, end_usize); + offsets.push(offsets[row_index] + (end - start)); + valid.append_non_null(); + continue; } - DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => exec_err!("array_replace does not support type '{array_type}'."), + + let mut pending_retain: Option = None; + for (i, is_match) in eq_array.iter().enumerate() { + let i = O::usize_as(i); + if is_match == Some(true) && counter < n { + if let Some(rs) = pending_retain.take() { + mutable.extend( + original_idx.to_usize().unwrap(), + start_usize + rs.to_usize().unwrap(), + start_usize + i.to_usize().unwrap(), + ); + } + mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); + counter += 1; + if counter == n { + mutable.extend( + original_idx.to_usize().unwrap(), + start_usize + i.to_usize().unwrap() + 1, + end_usize, + ); + break; + } + } else if pending_retain.is_none() { + pending_retain = Some(i); + } + } + + if counter < n + && let Some(rs) = pending_retain + { + mutable.extend( + original_idx.to_usize().unwrap(), + start_usize + rs.to_usize().unwrap(), + end_usize, + ); + } + + offsets.push(offsets[row_index] + (end - start)); + valid.append_non_null(); } -} -fn array_replace_n_inner(args: &[ArrayRef]) -> Result { - let [array, from, to, max] = take_function_args("array_replace_n", args)?; + let data = mutable.freeze(); - // replace the specified number of occurrences - let arr_n = as_int64_array(max)?.values().to_vec(); + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new_list_field(list_array.value_type(), true)), + OffsetBuffer::::new(offsets.into()), + arrow::array::make_array(data), + valid.finish(), + )?)) +} + +/// Dispatches scalar-needle array replacement by list offset type. +/// +/// `needle` must be a length-1 array containing the scalar element to replace. +fn array_replace_dispatch_scalar( + array: &ArrayRef, + needle: &ArrayRef, + to: &ArrayRef, + arr_n: &[i64], +) -> Result { match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + general_replace_with_scalar::(list_array, needle, to, arr_n) } DataType::LargeList(_) => { let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + general_replace_with_scalar::(list_array, needle, to, arr_n) } DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => { - exec_err!("array_replace_n does not support type '{array_type}'.") - } + array_type => exec_err!("array_replace does not support type '{array_type}'."), } } -fn array_replace_all_inner(args: &[ArrayRef]) -> Result { - let [array, from, to] = take_function_args("array_replace_all", args)?; +fn replace_with_scalar_needle( + list_array: &ArrayRef, + scalar_from: &ScalarValue, + to_array: &ArrayRef, + arr_n: &[i64], +) -> Result { + if scalar_from.is_null() || scalar_from.data_type().is_nested() { + let from_array = scalar_from.to_array_of_size(list_array.len())?; + return array_replace_internal(list_array, &from_array, to_array, arr_n); + } - // replace all occurrences (up to "i64::MAX") - let arr_n = vec![i64::MAX; array.len()]; + let needle = scalar_from.to_array_of_size(1)?; + array_replace_dispatch_scalar(list_array, &needle, to_array, arr_n) +} + +fn array_replace_internal( + array: &ArrayRef, + from: &ArrayRef, + to: &ArrayRef, + arr_n: &[i64], +) -> Result { match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::LargeList(_) => { let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => { - exec_err!("array_replace_all does not support type '{array_type}'.") - } + array_type => exec_err!("array_replace does not support type '{array_type}'."), + } +} + +#[cfg(test)] +mod tests { + use super::{ArrayReplace, ArrayReplaceAll, ArrayReplaceN}; + use arrow::array::{ArrayRef, AsArray, ListArray}; + use arrow::datatypes::{DataType, Field, Int32Type}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; + + fn int_list(values: Vec>) -> ArrayRef { + Arc::new(ListArray::from_iter_primitive::( + values + .into_iter() + .map(|row| Some(row.into_iter().map(Some))), + )) + } + + fn invoke_replace( + udf: &dyn ScalarUDFImpl, + args: Vec, + return_type: DataType, + ) -> ColumnarValue { + let arg_fields = args + .iter() + .enumerate() + .map(|(i, arg)| { + Arc::new(Field::new(format!("arg{i}"), arg.data_type(), true)) + }) + .collect::>(); + let number_rows = args + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let return_field = Arc::new(Field::new("result", return_type, true)); + + udf.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field, + config_options: Arc::new(Default::default()), + }) + .unwrap() + } + + #[test] + fn array_replace_uses_scalar_arguments() { + let input = int_list(vec![vec![1, 2, 2, 3], vec![2, 4, 2]]); + let expected = int_list(vec![vec![1, 9, 2, 3], vec![9, 4, 2]]); + let return_type = input.data_type().clone(); + + let result = invoke_replace( + &ArrayReplace::new(), + vec![ + ColumnarValue::Array(input), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(9))), + ], + return_type, + ); + + let ColumnarValue::Array(result) = result else { + panic!("expected array result"); + }; + assert_eq!(result.as_list::(), expected.as_list::()); + } + + #[test] + fn array_replace_n_uses_scalar_arguments() { + let input = int_list(vec![vec![1, 2, 2, 3], vec![2, 4, 2]]); + let expected = int_list(vec![vec![1, 9, 9, 3], vec![9, 4, 9]]); + let return_type = input.data_type().clone(); + + let result = invoke_replace( + &ArrayReplaceN::new(), + vec![ + ColumnarValue::Array(input), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(9))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + return_type, + ); + + let ColumnarValue::Array(result) = result else { + panic!("expected array result"); + }; + assert_eq!(result.as_list::(), expected.as_list::()); + } + + #[test] + fn array_replace_all_uses_scalar_arguments() { + let input = int_list(vec![vec![1, 2, 2, 3], vec![2, 4, 2]]); + let expected = int_list(vec![vec![1, 9, 9, 3], vec![9, 4, 9]]); + let return_type = input.data_type().clone(); + + let result = invoke_replace( + &ArrayReplaceAll::new(), + vec![ + ColumnarValue::Array(input), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(9))), + ], + return_type, + ); + + let ColumnarValue::Array(result) = result else { + panic!("expected array result"); + }; + assert_eq!(result.as_list::(), expected.as_list::()); } } From 3e79a3887f8dcf6675da6873e18d1711a209d346 Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Wed, 20 May 2026 14:45:47 +0800 Subject: [PATCH 2/4] Refine array replace scalar path tests --- datafusion/functions-nested/src/replace.rs | 118 +----------------- .../test_files/array/array_replace.slt | 60 ++++++++- 2 files changed, 55 insertions(+), 123 deletions(-) diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index eb9d60547df55..ba73fce5e74ca 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -607,7 +607,7 @@ fn replace_with_scalar_needle( to_array: &ArrayRef, arr_n: &[i64], ) -> Result { - if scalar_from.is_null() || scalar_from.data_type().is_nested() { + if scalar_from.data_type().is_nested() { let from_array = scalar_from.to_array_of_size(list_array.len())?; return array_replace_internal(list_array, &from_array, to_array, arr_n); } @@ -635,119 +635,3 @@ fn array_replace_internal( array_type => exec_err!("array_replace does not support type '{array_type}'."), } } - -#[cfg(test)] -mod tests { - use super::{ArrayReplace, ArrayReplaceAll, ArrayReplaceN}; - use arrow::array::{ArrayRef, AsArray, ListArray}; - use arrow::datatypes::{DataType, Field, Int32Type}; - use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; - use std::sync::Arc; - - fn int_list(values: Vec>) -> ArrayRef { - Arc::new(ListArray::from_iter_primitive::( - values - .into_iter() - .map(|row| Some(row.into_iter().map(Some))), - )) - } - - fn invoke_replace( - udf: &dyn ScalarUDFImpl, - args: Vec, - return_type: DataType, - ) -> ColumnarValue { - let arg_fields = args - .iter() - .enumerate() - .map(|(i, arg)| { - Arc::new(Field::new(format!("arg{i}"), arg.data_type(), true)) - }) - .collect::>(); - let number_rows = args - .iter() - .find_map(|arg| match arg { - ColumnarValue::Array(array) => Some(array.len()), - ColumnarValue::Scalar(_) => None, - }) - .unwrap_or(1); - let return_field = Arc::new(Field::new("result", return_type, true)); - - udf.invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows, - return_field, - config_options: Arc::new(Default::default()), - }) - .unwrap() - } - - #[test] - fn array_replace_uses_scalar_arguments() { - let input = int_list(vec![vec![1, 2, 2, 3], vec![2, 4, 2]]); - let expected = int_list(vec![vec![1, 9, 2, 3], vec![9, 4, 2]]); - let return_type = input.data_type().clone(); - - let result = invoke_replace( - &ArrayReplace::new(), - vec![ - ColumnarValue::Array(input), - ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(9))), - ], - return_type, - ); - - let ColumnarValue::Array(result) = result else { - panic!("expected array result"); - }; - assert_eq!(result.as_list::(), expected.as_list::()); - } - - #[test] - fn array_replace_n_uses_scalar_arguments() { - let input = int_list(vec![vec![1, 2, 2, 3], vec![2, 4, 2]]); - let expected = int_list(vec![vec![1, 9, 9, 3], vec![9, 4, 9]]); - let return_type = input.data_type().clone(); - - let result = invoke_replace( - &ArrayReplaceN::new(), - vec![ - ColumnarValue::Array(input), - ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(9))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ], - return_type, - ); - - let ColumnarValue::Array(result) = result else { - panic!("expected array result"); - }; - assert_eq!(result.as_list::(), expected.as_list::()); - } - - #[test] - fn array_replace_all_uses_scalar_arguments() { - let input = int_list(vec![vec![1, 2, 2, 3], vec![2, 4, 2]]); - let expected = int_list(vec![vec![1, 9, 9, 3], vec![9, 4, 9]]); - let return_type = input.data_type().clone(); - - let result = invoke_replace( - &ArrayReplaceAll::new(), - vec![ - ColumnarValue::Array(input), - ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(9))), - ], - return_type, - ); - - let ColumnarValue::Array(result) = result else { - panic!("expected array result"); - }; - assert_eq!(result.as_list::(), expected.as_list::()); - } -} diff --git a/datafusion/sqllogictest/test_files/array/array_replace.slt b/datafusion/sqllogictest/test_files/array/array_replace.slt index 390ed4b946520..9c5ab206a26b0 100644 --- a/datafusion/sqllogictest/test_files/array/array_replace.slt +++ b/datafusion/sqllogictest/test_files/array/array_replace.slt @@ -212,6 +212,33 @@ from large_nested_arrays_with_repeating_elements; [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +# array_replace scalar arguments over multiple input rows +query ??? +select + array_replace(column1, 2, 9), + array_replace_n(column1, 2, 9, 2), + array_replace_all(column1, 2, 9) +from ( + values + (make_array(1, 2, 2, 3)), + (make_array(2, 4, 2)) +) as t(column1); +---- +[1, 9, 2, 3] [1, 9, 9, 3] [1, 9, 9, 3] +[9, 4, 2] [9, 4, 9] [9, 4, 9] + +# array_replace_n scalar max exceeding matches over multiple input rows +query ? +select array_replace_n(column1, 2, 9, 10) +from ( + values + (make_array(1, 2, 2, 3)), + (make_array(2, 4, 2)) +) as t(column1); +---- +[1, 9, 9, 3] +[9, 4, 9] + ## array_replace_n (aliases: `list_replace_n`) # array_replace_n scalar function #1 @@ -226,22 +253,35 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5] -query ???? +query ?????? select array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2), array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2), array_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3), - array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, 0); + array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, 0), + array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, -1), + array_replace_n(arrow_cast(make_array(1, 4, 1, 5), 'LargeList(Int64)'), 1, 0, 10); ---- -[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5] -query ??? +query ?????? select array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)'), 2, 3, 2), array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'FixedSizeList(7, Int64)'), 4, 0, 2), - array_replace_n(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0, 3); + array_replace_n(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0, 3), + array_replace_n(arrow_cast(make_array(1, 4, 4), 'FixedSizeList(3, Int64)'), 4, 0, 0), + array_replace_n(arrow_cast(make_array(1, 4, 4), 'FixedSizeList(3, Int64)'), 4, 0, -1), + array_replace_n(arrow_cast(make_array(1, 4, 1, 5), 'FixedSizeList(4, Int64)'), 1, 0, 10); ---- -[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5] + +# array_replace_n scalar max exceeding matches for empty arrays +query ?? +select + array_replace_n(arrow_cast(make_array(), 'List(Int64)'), 2, 9, 10), + array_replace_n(arrow_cast(make_array(), 'LargeList(Int64)'), 2, 9, 10); +---- +[] [] # array_replace_n scalar function #2 (element is list) query ?? @@ -657,6 +697,14 @@ select column1, column2, column3, column4, array_replace_n(column1, column2, col NULL 3 2 1 NULL [3, 1, 3] 3 NULL 1 [NULL, 1, 3] +query ??? +select + array_replace(make_array(3, NULL, NULL), NULL, 5), + array_replace_n(make_array(3, NULL, NULL), NULL, 5, 10), + array_replace_all(make_array(3, NULL, NULL), NULL, 5); +---- +[3, 5, NULL] [3, 5, 5] [3, 5, 5] + statement ok From 11dfa84b2fd63e0b686a5003481c2ac6397a9ab4 Mon Sep 17 00:00:00 2001 From: linfeng <33561138+lyne7-sc@users.noreply.github.com> Date: Thu, 21 May 2026 15:53:37 +0800 Subject: [PATCH 3/4] better clarity --- datafusion/functions-nested/src/replace.rs | 100 ++++++++------------- 1 file changed, 35 insertions(+), 65 deletions(-) diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index ba73fce5e74ca..819edff1f2f20 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -19,11 +19,10 @@ use arrow::array::{ Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData, - NullBufferBuilder, OffsetSizeTrait, Scalar, new_null_array, + NullBufferBuilder, OffsetBufferBuilder, OffsetSizeTrait, Scalar, new_null_array, }; -use arrow::datatypes::{DataType, Field}; - use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::as_int64_array; use datafusion_common::utils::ListCoercion; use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; @@ -474,17 +473,16 @@ fn general_replace( } /// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences -/// of `needle[0]` with `to_array[i]`. +/// of `needle` with `to_array[i]`. /// /// This is a specialized version of `general_replace` for scalar needles that /// uses a single bulk comparison for better performance. fn general_replace_with_scalar( list_array: &GenericListArray, - needle: &ArrayRef, + needle: &Scalar, to_array: &ArrayRef, arr_n: &[i64], ) -> Result { - let mut offsets: Vec = vec![O::usize_as(0)]; let values = list_array.values(); let original_data = values.to_data(); let to_data = to_array.to_data(); @@ -496,57 +494,45 @@ fn general_replace_with_scalar( capacity, ); - let mut valid = NullBufferBuilder::new(list_array.len()); - let match_array = arrow_ord::cmp::not_distinct( - list_array.values(), - &Scalar::new(Arc::clone(needle)), - )?; + let mut offsets = OffsetBufferBuilder::::new(list_array.len()); + let match_array = arrow_ord::cmp::not_distinct(list_array.values(), needle)?; for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0].to_usize().unwrap(); + let end = offset_window[1].to_usize().unwrap(); + let row_len = end - start; + if list_array.is_null(row_index) { - offsets.push(offsets[row_index]); - valid.append_null(); + offsets.push_length(0); continue; } - let start = offset_window[0]; - let end = offset_window[1]; - let start_usize = start.to_usize().unwrap(); - let end_usize = end.to_usize().unwrap(); - let row_len = end_usize - start_usize; - let original_idx = O::usize_as(0); let replace_idx = O::usize_as(1); let n = arr_n[row_index]; let mut counter = 0; - let eq_array = match_array.slice(start_usize, row_len); + let eq_array = match_array.slice(start, row_len); if n <= 0 || eq_array.true_count() == 0 { - mutable.extend(original_idx.to_usize().unwrap(), start_usize, end_usize); - offsets.push(offsets[row_index] + (end - start)); - valid.append_non_null(); + mutable.extend(original_idx.to_usize().unwrap(), start, end); + offsets.push_length(row_len); continue; } - let mut pending_retain: Option = None; + let mut pending_retain: Option = None; for (i, is_match) in eq_array.iter().enumerate() { - let i = O::usize_as(i); if is_match == Some(true) && counter < n { if let Some(rs) = pending_retain.take() { mutable.extend( original_idx.to_usize().unwrap(), - start_usize + rs.to_usize().unwrap(), - start_usize + i.to_usize().unwrap(), + start + rs, + start + i, ); } mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); counter += 1; if counter == n { - mutable.extend( - original_idx.to_usize().unwrap(), - start_usize + i.to_usize().unwrap() + 1, - end_usize, - ); + mutable.extend(original_idx.to_usize().unwrap(), start + i + 1, end); break; } } else if pending_retain.is_none() { @@ -557,50 +543,23 @@ fn general_replace_with_scalar( if counter < n && let Some(rs) = pending_retain { - mutable.extend( - original_idx.to_usize().unwrap(), - start_usize + rs.to_usize().unwrap(), - end_usize, - ); + mutable.extend(original_idx.to_usize().unwrap(), start + rs, end); } - offsets.push(offsets[row_index] + (end - start)); - valid.append_non_null(); + offsets.push_length(row_len); } let data = mutable.freeze(); Ok(Arc::new(GenericListArray::::try_new( Arc::new(Field::new_list_field(list_array.value_type(), true)), - OffsetBuffer::::new(offsets.into()), + offsets.finish(), arrow::array::make_array(data), - valid.finish(), + list_array.nulls().cloned(), )?)) } /// Dispatches scalar-needle array replacement by list offset type. -/// -/// `needle` must be a length-1 array containing the scalar element to replace. -fn array_replace_dispatch_scalar( - array: &ArrayRef, - needle: &ArrayRef, - to: &ArrayRef, - arr_n: &[i64], -) -> Result { - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace_with_scalar::(list_array, needle, to, arr_n) - } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace_with_scalar::(list_array, needle, to, arr_n) - } - DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => exec_err!("array_replace does not support type '{array_type}'."), - } -} - fn replace_with_scalar_needle( list_array: &ArrayRef, scalar_from: &ScalarValue, @@ -612,8 +571,19 @@ fn replace_with_scalar_needle( return array_replace_internal(list_array, &from_array, to_array, arr_n); } - let needle = scalar_from.to_array_of_size(1)?; - array_replace_dispatch_scalar(list_array, &needle, to_array, arr_n) + let needle = Scalar::new(scalar_from.to_array_of_size(1)?); + match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list::(); + general_replace_with_scalar::(list, &needle, to_array, arr_n) + } + DataType::LargeList(_) => { + let list = list_array.as_list::(); + general_replace_with_scalar::(list, &needle, to_array, arr_n) + } + DataType::Null => Ok(new_null_array(list_array.data_type(), 1)), + array_type => exec_err!("array_replace does not support type '{array_type}'."), + } } fn array_replace_internal( From 3bae38b17422d982f88b962c22d32315bc16799a Mon Sep 17 00:00:00 2001 From: linfeng <33561138+lyne7-sc@users.noreply.github.com> Date: Thu, 21 May 2026 17:46:09 +0800 Subject: [PATCH 4/4] handle to and max arg scalar --- datafusion/functions-nested/src/replace.rs | 78 +++++++++++++--------- 1 file changed, 45 insertions(+), 33 deletions(-) diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index 819edff1f2f20..630a24ac5440b 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -126,21 +126,17 @@ impl ScalarUDFImpl for ArrayReplace { let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; let num_rows = args.number_rows; let list_array = list_arg.to_array(num_rows)?; - let to_array = to_arg.to_array(num_rows)?; - let arr_n = vec![1; num_rows]; + let arr_n = [1i64]; match from_arg { ColumnarValue::Array(from_array) => { + let to_array = to_arg.to_array(num_rows)?; let result = array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; Ok(ColumnarValue::Array(result)) } ColumnarValue::Scalar(scalar_from) => { - let result = replace_with_scalar_needle( - &list_array, - scalar_from, - &to_array, - &arr_n, - )?; + let result = + replace_with_scalar_needle(&list_array, scalar_from, to_arg, &arr_n)?; Ok(ColumnarValue::Array(result)) } } @@ -222,22 +218,23 @@ impl ScalarUDFImpl for ArrayReplaceN { take_function_args(self.name(), &args.args)?; let num_rows = args.number_rows; let list_array = list_arg.to_array(num_rows)?; - let to_array = to_arg.to_array(num_rows)?; - let max_array = max_arg.to_array(num_rows)?; - let arr_n = as_int64_array(&max_array)?.values().to_vec(); + let arr_n = match max_arg { + ColumnarValue::Scalar(s) => { + let a = s.to_array_of_size(1)?; + as_int64_array(&a)?.values().to_vec() + } + ColumnarValue::Array(a) => as_int64_array(&a)?.values().to_vec(), + }; match from_arg { ColumnarValue::Array(from_array) => { + let to_array = to_arg.to_array(num_rows)?; let result = array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; Ok(ColumnarValue::Array(result)) } ColumnarValue::Scalar(scalar_from) => { - let result = replace_with_scalar_needle( - &list_array, - scalar_from, - &to_array, - &arr_n, - )?; + let result = + replace_with_scalar_needle(&list_array, scalar_from, to_arg, &arr_n)?; Ok(ColumnarValue::Array(result)) } } @@ -316,21 +313,17 @@ impl ScalarUDFImpl for ArrayReplaceAll { let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; let num_rows = args.number_rows; let list_array = list_arg.to_array(num_rows)?; - let to_array = to_arg.to_array(num_rows)?; - let arr_n = vec![i64::MAX; num_rows]; + let arr_n = [i64::MAX]; match from_arg { ColumnarValue::Array(from_array) => { + let to_array = to_arg.to_array(num_rows)?; let result = array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; Ok(ColumnarValue::Array(result)) } ColumnarValue::Scalar(scalar_from) => { - let result = replace_with_scalar_needle( - &list_array, - scalar_from, - &to_array, - &arr_n, - )?; + let result = + replace_with_scalar_needle(&list_array, scalar_from, to_arg, &arr_n)?; Ok(ColumnarValue::Array(result)) } } @@ -403,7 +396,11 @@ fn general_replace( let original_idx = O::usize_as(0); let replace_idx = O::usize_as(1); - let n = arr_n[row_index]; + let n = if arr_n.len() == 1 { + arr_n[0] + } else { + arr_n[row_index] + }; let mut counter = 0; // All elements are false, no need to replace, just copy original data @@ -509,7 +506,12 @@ fn general_replace_with_scalar( let original_idx = O::usize_as(0); let replace_idx = O::usize_as(1); - let n = arr_n[row_index]; + let to_index = if to_array.len() == 1 { 0 } else { row_index }; + let n = if arr_n.len() == 1 { + arr_n[0] + } else { + arr_n[row_index] + }; let mut counter = 0; let eq_array = match_array.slice(start, row_len); @@ -529,7 +531,7 @@ fn general_replace_with_scalar( start + i, ); } - mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); + mutable.extend(replace_idx.to_usize().unwrap(), to_index, to_index + 1); counter += 1; if counter == n { mutable.extend(original_idx.to_usize().unwrap(), start + i + 1, end); @@ -560,26 +562,36 @@ fn general_replace_with_scalar( } /// Dispatches scalar-needle array replacement by list offset type. +/// +/// When `to_arg` is a scalar, only a length-1 array is materialized for the +/// replacement value, avoiding unnecessary allocation of `num_rows` copies. fn replace_with_scalar_needle( list_array: &ArrayRef, scalar_from: &ScalarValue, - to_array: &ArrayRef, + to_arg: &ColumnarValue, arr_n: &[i64], ) -> Result { + let num_rows = list_array.len(); + if scalar_from.data_type().is_nested() { - let from_array = scalar_from.to_array_of_size(list_array.len())?; - return array_replace_internal(list_array, &from_array, to_array, arr_n); + let from_array = scalar_from.to_array_of_size(num_rows)?; + let to_array = to_arg.to_array(num_rows)?; + return array_replace_internal(list_array, &from_array, &to_array, arr_n); } let needle = Scalar::new(scalar_from.to_array_of_size(1)?); + let to_array = match to_arg { + ColumnarValue::Scalar(s) => s.to_array_of_size(1)?, + ColumnarValue::Array(a) => Arc::clone(a), + }; match list_array.data_type() { DataType::List(_) => { let list = list_array.as_list::(); - general_replace_with_scalar::(list, &needle, to_array, arr_n) + general_replace_with_scalar::(list, &needle, &to_array, arr_n) } DataType::LargeList(_) => { let list = list_array.as_list::(); - general_replace_with_scalar::(list, &needle, to_array, arr_n) + general_replace_with_scalar::(list, &needle, &to_array, arr_n) } DataType::Null => Ok(new_null_array(list_array.data_type(), 1)), array_type => exec_err!("array_replace does not support type '{array_type}'."),