diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index a9a53a3cb989f..630a24ac5440b 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -19,14 +19,13 @@ use arrow::array::{ Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData, - NullBufferBuilder, OffsetSizeTrait, new_null_array, + NullBufferBuilder, OffsetBufferBuilder, OffsetSizeTrait, Scalar, new_null_array, }; -use arrow::datatypes::{DataType, Field}; - use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::as_int64_array; use datafusion_common::utils::ListCoercion; -use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -34,7 +33,6 @@ use datafusion_expr::{ use datafusion_macros::user_doc; use crate::utils::compare_element_to_list; -use crate::utils::make_scalar_function; use std::sync::Arc; @@ -125,7 +123,23 @@ impl ScalarUDFImpl for ArrayReplace { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_replace_inner)(&args.args) + let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let arr_n = [1i64]; + match from_arg { + ColumnarValue::Array(from_array) => { + let to_array = to_arg.to_array(num_rows)?; + let result = + array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_from) => { + let result = + replace_with_scalar_needle(&list_array, scalar_from, to_arg, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -200,7 +214,30 @@ impl ScalarUDFImpl for ArrayReplaceN { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_replace_n_inner)(&args.args) + let [list_arg, from_arg, to_arg, max_arg] = + take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let arr_n = match max_arg { + ColumnarValue::Scalar(s) => { + let a = s.to_array_of_size(1)?; + as_int64_array(&a)?.values().to_vec() + } + ColumnarValue::Array(a) => as_int64_array(&a)?.values().to_vec(), + }; + match from_arg { + ColumnarValue::Array(from_array) => { + let to_array = to_arg.to_array(num_rows)?; + let result = + array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_from) => { + let result = + replace_with_scalar_needle(&list_array, scalar_from, to_arg, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -273,7 +310,23 @@ impl ScalarUDFImpl for ArrayReplaceAll { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_replace_all_inner)(&args.args) + let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let arr_n = [i64::MAX]; + match from_arg { + ColumnarValue::Array(from_array) => { + let to_array = to_arg.to_array(num_rows)?; + let result = + array_replace_internal(&list_array, from_array, &to_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_from) => { + let result = + replace_with_scalar_needle(&list_array, scalar_from, to_arg, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -343,7 +396,11 @@ fn general_replace( let original_idx = O::usize_as(0); let replace_idx = O::usize_as(1); - let n = arr_n[row_index]; + let n = if arr_n.len() == 1 { + arr_n[0] + } else { + arr_n[row_index] + }; let mut counter = 0; // All elements are false, no need to replace, just copy original data @@ -412,63 +469,151 @@ fn general_replace( )?)) } -fn array_replace_inner(args: &[ArrayRef]) -> Result { - let [array, from, to] = take_function_args("array_replace", args)?; +/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences +/// of `needle` with `to_array[i]`. +/// +/// This is a specialized version of `general_replace` for scalar needles that +/// uses a single bulk comparison for better performance. +fn general_replace_with_scalar( + list_array: &GenericListArray, + needle: &Scalar, + to_array: &ArrayRef, + arr_n: &[i64], +) -> Result { + let values = list_array.values(); + let original_data = values.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len()); - // replace at most one occurrence for each element - let arr_n = vec![1; array.len()]; - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); + + let mut offsets = OffsetBufferBuilder::::new(list_array.len()); + let match_array = arrow_ord::cmp::not_distinct(list_array.values(), needle)?; + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0].to_usize().unwrap(); + let end = offset_window[1].to_usize().unwrap(); + let row_len = end - start; + + if list_array.is_null(row_index) { + offsets.push_length(0); + continue; } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + + let original_idx = O::usize_as(0); + let replace_idx = O::usize_as(1); + let to_index = if to_array.len() == 1 { 0 } else { row_index }; + let n = if arr_n.len() == 1 { + arr_n[0] + } else { + arr_n[row_index] + }; + let mut counter = 0; + + let eq_array = match_array.slice(start, row_len); + if n <= 0 || eq_array.true_count() == 0 { + mutable.extend(original_idx.to_usize().unwrap(), start, end); + offsets.push_length(row_len); + continue; } - DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => exec_err!("array_replace does not support type '{array_type}'."), + + let mut pending_retain: Option = None; + for (i, is_match) in eq_array.iter().enumerate() { + if is_match == Some(true) && counter < n { + if let Some(rs) = pending_retain.take() { + mutable.extend( + original_idx.to_usize().unwrap(), + start + rs, + start + i, + ); + } + mutable.extend(replace_idx.to_usize().unwrap(), to_index, to_index + 1); + counter += 1; + if counter == n { + mutable.extend(original_idx.to_usize().unwrap(), start + i + 1, end); + break; + } + } else if pending_retain.is_none() { + pending_retain = Some(i); + } + } + + if counter < n + && let Some(rs) = pending_retain + { + mutable.extend(original_idx.to_usize().unwrap(), start + rs, end); + } + + offsets.push_length(row_len); } + + let data = mutable.freeze(); + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new_list_field(list_array.value_type(), true)), + offsets.finish(), + arrow::array::make_array(data), + list_array.nulls().cloned(), + )?)) } -fn array_replace_n_inner(args: &[ArrayRef]) -> Result { - let [array, from, to, max] = take_function_args("array_replace_n", args)?; +/// Dispatches scalar-needle array replacement by list offset type. +/// +/// When `to_arg` is a scalar, only a length-1 array is materialized for the +/// replacement value, avoiding unnecessary allocation of `num_rows` copies. +fn replace_with_scalar_needle( + list_array: &ArrayRef, + scalar_from: &ScalarValue, + to_arg: &ColumnarValue, + arr_n: &[i64], +) -> Result { + let num_rows = list_array.len(); - // replace the specified number of occurrences - let arr_n = as_int64_array(max)?.values().to_vec(); - match array.data_type() { + if scalar_from.data_type().is_nested() { + let from_array = scalar_from.to_array_of_size(num_rows)?; + let to_array = to_arg.to_array(num_rows)?; + return array_replace_internal(list_array, &from_array, &to_array, arr_n); + } + + let needle = Scalar::new(scalar_from.to_array_of_size(1)?); + let to_array = match to_arg { + ColumnarValue::Scalar(s) => s.to_array_of_size(1)?, + ColumnarValue::Array(a) => Arc::clone(a), + }; + match list_array.data_type() { DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + let list = list_array.as_list::(); + general_replace_with_scalar::(list, &needle, &to_array, arr_n) } DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) - } - DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => { - exec_err!("array_replace_n does not support type '{array_type}'.") + let list = list_array.as_list::(); + general_replace_with_scalar::(list, &needle, &to_array, arr_n) } + DataType::Null => Ok(new_null_array(list_array.data_type(), 1)), + array_type => exec_err!("array_replace does not support type '{array_type}'."), } } -fn array_replace_all_inner(args: &[ArrayRef]) -> Result { - let [array, from, to] = take_function_args("array_replace_all", args)?; - - // replace all occurrences (up to "i64::MAX") - let arr_n = vec![i64::MAX; array.len()]; +fn array_replace_internal( + array: &ArrayRef, + from: &ArrayRef, + to: &ArrayRef, + arr_n: &[i64], +) -> Result { match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::LargeList(_) => { let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => { - exec_err!("array_replace_all does not support type '{array_type}'.") - } + array_type => exec_err!("array_replace does not support type '{array_type}'."), } } diff --git a/datafusion/sqllogictest/test_files/array/array_replace.slt b/datafusion/sqllogictest/test_files/array/array_replace.slt index 390ed4b946520..9c5ab206a26b0 100644 --- a/datafusion/sqllogictest/test_files/array/array_replace.slt +++ b/datafusion/sqllogictest/test_files/array/array_replace.slt @@ -212,6 +212,33 @@ from large_nested_arrays_with_repeating_elements; [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +# array_replace scalar arguments over multiple input rows +query ??? +select + array_replace(column1, 2, 9), + array_replace_n(column1, 2, 9, 2), + array_replace_all(column1, 2, 9) +from ( + values + (make_array(1, 2, 2, 3)), + (make_array(2, 4, 2)) +) as t(column1); +---- +[1, 9, 2, 3] [1, 9, 9, 3] [1, 9, 9, 3] +[9, 4, 2] [9, 4, 9] [9, 4, 9] + +# array_replace_n scalar max exceeding matches over multiple input rows +query ? +select array_replace_n(column1, 2, 9, 10) +from ( + values + (make_array(1, 2, 2, 3)), + (make_array(2, 4, 2)) +) as t(column1); +---- +[1, 9, 9, 3] +[9, 4, 9] + ## array_replace_n (aliases: `list_replace_n`) # array_replace_n scalar function #1 @@ -226,22 +253,35 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5] -query ???? +query ?????? select array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2), array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2), array_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3), - array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, 0); + array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, 0), + array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, -1), + array_replace_n(arrow_cast(make_array(1, 4, 1, 5), 'LargeList(Int64)'), 1, 0, 10); ---- -[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5] -query ??? +query ?????? select array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)'), 2, 3, 2), array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'FixedSizeList(7, Int64)'), 4, 0, 2), - array_replace_n(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0, 3); + array_replace_n(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0, 3), + array_replace_n(arrow_cast(make_array(1, 4, 4), 'FixedSizeList(3, Int64)'), 4, 0, 0), + array_replace_n(arrow_cast(make_array(1, 4, 4), 'FixedSizeList(3, Int64)'), 4, 0, -1), + array_replace_n(arrow_cast(make_array(1, 4, 1, 5), 'FixedSizeList(4, Int64)'), 1, 0, 10); ---- -[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5] + +# array_replace_n scalar max exceeding matches for empty arrays +query ?? +select + array_replace_n(arrow_cast(make_array(), 'List(Int64)'), 2, 9, 10), + array_replace_n(arrow_cast(make_array(), 'LargeList(Int64)'), 2, 9, 10); +---- +[] [] # array_replace_n scalar function #2 (element is list) query ?? @@ -657,6 +697,14 @@ select column1, column2, column3, column4, array_replace_n(column1, column2, col NULL 3 2 1 NULL [3, 1, 3] 3 NULL 1 [NULL, 1, 3] +query ??? +select + array_replace(make_array(3, NULL, NULL), NULL, 5), + array_replace_n(make_array(3, NULL, NULL), NULL, 5, 10), + array_replace_all(make_array(3, NULL, NULL), NULL, 5); +---- +[3, 5, NULL] [3, 5, 5] [3, 5, 5] + statement ok