Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 190 additions & 45 deletions datafusion/functions-nested/src/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,20 @@

use arrow::array::{
Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData,
NullBufferBuilder, OffsetSizeTrait, new_null_array,
NullBufferBuilder, OffsetBufferBuilder, OffsetSizeTrait, Scalar, new_null_array,
};
use arrow::datatypes::{DataType, Field};

use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, Field};
use datafusion_common::cast::as_int64_array;
use datafusion_common::utils::ListCoercion;
use datafusion_common::{Result, exec_err, utils::take_function_args};
use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;

use crate::utils::compare_element_to_list;
use crate::utils::make_scalar_function;

use std::sync::Arc;

Expand Down Expand Up @@ -125,7 +123,23 @@ impl ScalarUDFImpl for ArrayReplace {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_replace_inner)(&args.args)
let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
let arr_n = [1i64];
match from_arg {
ColumnarValue::Array(from_array) => {
let to_array = to_arg.to_array(num_rows)?;
let result =
array_replace_internal(&list_array, from_array, &to_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
ColumnarValue::Scalar(scalar_from) => {
let result =
replace_with_scalar_needle(&list_array, scalar_from, to_arg, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -200,7 +214,30 @@ impl ScalarUDFImpl for ArrayReplaceN {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_replace_n_inner)(&args.args)
let [list_arg, from_arg, to_arg, max_arg] =
take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
let arr_n = match max_arg {
ColumnarValue::Scalar(s) => {
let a = s.to_array_of_size(1)?;
as_int64_array(&a)?.values().to_vec()
}
ColumnarValue::Array(a) => as_int64_array(&a)?.values().to_vec(),
};
match from_arg {
ColumnarValue::Array(from_array) => {
let to_array = to_arg.to_array(num_rows)?;
let result =
array_replace_internal(&list_array, from_array, &to_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
ColumnarValue::Scalar(scalar_from) => {
let result =
replace_with_scalar_needle(&list_array, scalar_from, to_arg, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -273,7 +310,23 @@ impl ScalarUDFImpl for ArrayReplaceAll {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_replace_all_inner)(&args.args)
let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
let arr_n = [i64::MAX];
match from_arg {
ColumnarValue::Array(from_array) => {
let to_array = to_arg.to_array(num_rows)?;
let result =
array_replace_internal(&list_array, from_array, &to_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
ColumnarValue::Scalar(scalar_from) => {
let result =
replace_with_scalar_needle(&list_array, scalar_from, to_arg, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -343,7 +396,11 @@ fn general_replace<O: OffsetSizeTrait>(

let original_idx = O::usize_as(0);
let replace_idx = O::usize_as(1);
let n = arr_n[row_index];
let n = if arr_n.len() == 1 {
arr_n[0]
} else {
arr_n[row_index]
};
let mut counter = 0;

// All elements are false, no need to replace, just copy original data
Expand Down Expand Up @@ -412,63 +469,151 @@ fn general_replace<O: OffsetSizeTrait>(
)?))
}

fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to] = take_function_args("array_replace", args)?;
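/// For each non-null row `list_array[i]`, replaces up to `arr_n[i]` occurrences
/// of `needle` with `to_array[i]`. Null rows are emitted as empty and stay null.
///
/// `arr_n` and `to_array` may each have length 1, in which case that single
/// value is applied to every row. A limit `n <= 0` leaves the row unchanged.
/// Elements are matched with `not_distinct`, so a null `needle` matches null
/// elements.
///
/// This is a specialized version of `general_replace` for scalar needles that
/// uses a single bulk comparison over all list values for better performance.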
fn general_replace_with_scalar<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
needle: &Scalar<ArrayRef>,
to_array: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
let values = list_array.values();
let original_data = values.to_data();
let to_data = to_array.to_data();
let capacity = Capacities::Array(original_data.len());

// replace at most one occurrence for each element
let arr_n = vec![1; array.len()];
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
let mut mutable = MutableArrayData::with_capacities(
vec![&original_data, &to_data],
false,
capacity,
);

let mut offsets = OffsetBufferBuilder::<O>::new(list_array.len());
let match_array = arrow_ord::cmp::not_distinct(list_array.values(), needle)?;

for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
let start = offset_window[0].to_usize().unwrap();
let end = offset_window[1].to_usize().unwrap();
let row_len = end - start;

if list_array.is_null(row_index) {
offsets.push_length(0);
continue;
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)

let original_idx = O::usize_as(0);
let replace_idx = O::usize_as(1);
let to_index = if to_array.len() == 1 { 0 } else { row_index };
let n = if arr_n.len() == 1 {
arr_n[0]
} else {
arr_n[row_index]
};
let mut counter = 0;

let eq_array = match_array.slice(start, row_len);
if n <= 0 || eq_array.true_count() == 0 {
mutable.extend(original_idx.to_usize().unwrap(), start, end);
offsets.push_length(row_len);
continue;
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => exec_err!("array_replace does not support type '{array_type}'."),

let mut pending_retain: Option<usize> = None;
for (i, is_match) in eq_array.iter().enumerate() {
if is_match == Some(true) && counter < n {
if let Some(rs) = pending_retain.take() {
mutable.extend(
original_idx.to_usize().unwrap(),
start + rs,
start + i,
);
}
mutable.extend(replace_idx.to_usize().unwrap(), to_index, to_index + 1);
counter += 1;
if counter == n {
mutable.extend(original_idx.to_usize().unwrap(), start + i + 1, end);
break;
}
} else if pending_retain.is_none() {
pending_retain = Some(i);
}
}

if counter < n
&& let Some(rs) = pending_retain
{
mutable.extend(original_idx.to_usize().unwrap(), start + rs, end);
}

offsets.push_length(row_len);
}

let data = mutable.freeze();

Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new_list_field(list_array.value_type(), true)),
offsets.finish(),
arrow::array::make_array(data),
list_array.nulls().cloned(),
)?))
}

fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to, max] = take_function_args("array_replace_n", args)?;
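/// Dispatches scalar-needle array replacement by list offset type.
///
/// A needle of nested type is expanded to `num_rows` copies and routed through
/// `array_replace_internal` instead of the bulk-comparison path.
///
/// Otherwise, when `to_arg` is a scalar, only a length-1 array is materialized
/// for the replacement value, avoiding unnecessary allocation of `num_rows`
/// copies.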
fn replace_with_scalar_needle(
list_array: &ArrayRef,
scalar_from: &ScalarValue,
to_arg: &ColumnarValue,
arr_n: &[i64],
) -> Result<ArrayRef> {
let num_rows = list_array.len();

// replace the specified number of occurrences
let arr_n = as_int64_array(max)?.values().to_vec();
match array.data_type() {
if scalar_from.data_type().is_nested() {
let from_array = scalar_from.to_array_of_size(num_rows)?;
let to_array = to_arg.to_array(num_rows)?;
return array_replace_internal(list_array, &from_array, &to_array, arr_n);
}

let needle = Scalar::new(scalar_from.to_array_of_size(1)?);
let to_array = match to_arg {
ColumnarValue::Scalar(s) => s.to_array_of_size(1)?,
ColumnarValue::Array(a) => Arc::clone(a),
};
match list_array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
let list = list_array.as_list::<i32>();
general_replace_with_scalar::<i32>(list, &needle, &to_array, arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => {
exec_err!("array_replace_n does not support type '{array_type}'.")
let list = list_array.as_list::<i64>();
general_replace_with_scalar::<i64>(list, &needle, &to_array, arr_n)
}
DataType::Null => Ok(new_null_array(list_array.data_type(), 1)),
array_type => exec_err!("array_replace does not support type '{array_type}'."),
}
}

fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to] = take_function_args("array_replace_all", args)?;

// replace all occurrences (up to "i64::MAX")
let arr_n = vec![i64::MAX; array.len()];
fn array_replace_internal(
array: &ArrayRef,
from: &ArrayRef,
to: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
general_replace::<i32>(list_array, from, to, arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)
general_replace::<i64>(list_array, from, to, arr_n)
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => {
exec_err!("array_replace_all does not support type '{array_type}'.")
}
array_type => exec_err!("array_replace does not support type '{array_type}'."),
}
}
60 changes: 54 additions & 6 deletions datafusion/sqllogictest/test_files/array/array_replace.slt
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,33 @@ from large_nested_arrays_with_repeating_elements;
[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]]
[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]]

# array_replace scalar arguments over multiple input rows
query ???
select
array_replace(column1, 2, 9),
array_replace_n(column1, 2, 9, 2),
array_replace_all(column1, 2, 9)
from (
values
(make_array(1, 2, 2, 3)),
(make_array(2, 4, 2))
) as t(column1);
----
[1, 9, 2, 3] [1, 9, 9, 3] [1, 9, 9, 3]
[9, 4, 2] [9, 4, 9] [9, 4, 9]

# array_replace_n scalar max exceeding matches over multiple input rows
query ?
select array_replace_n(column1, 2, 9, 10)
from (
values
(make_array(1, 2, 2, 3)),
(make_array(2, 4, 2))
) as t(column1);
----
[1, 9, 9, 3]
[9, 4, 9]

## array_replace_n (aliases: `list_replace_n`)

# array_replace_n scalar function #1
Expand All @@ -226,22 +253,35 @@ select
----
[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5]

query ????
query ??????
select
array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2),
array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2),
array_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3),
array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, 0);
array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, 0),
array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, -1),
array_replace_n(arrow_cast(make_array(1, 4, 1, 5), 'LargeList(Int64)'), 1, 0, 10);
----
[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4]
[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5]

query ???
query ??????
select
array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)'), 2, 3, 2),
array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'FixedSizeList(7, Int64)'), 4, 0, 2),
array_replace_n(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0, 3);
array_replace_n(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0, 3),
array_replace_n(arrow_cast(make_array(1, 4, 4), 'FixedSizeList(3, Int64)'), 4, 0, 0),
array_replace_n(arrow_cast(make_array(1, 4, 4), 'FixedSizeList(3, Int64)'), 4, 0, -1),
array_replace_n(arrow_cast(make_array(1, 4, 1, 5), 'FixedSizeList(4, Int64)'), 1, 0, 10);
----
[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3]
[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5]

# array_replace_n scalar max exceeding matches for empty arrays
query ??
select
array_replace_n(arrow_cast(make_array(), 'List(Int64)'), 2, 9, 10),
array_replace_n(arrow_cast(make_array(), 'LargeList(Int64)'), 2, 9, 10);
----
[] []

# array_replace_n scalar function #2 (element is list)
query ??
Expand Down Expand Up @@ -657,6 +697,14 @@ select column1, column2, column3, column4, array_replace_n(column1, column2, col
NULL 3 2 1 NULL
[3, 1, 3] 3 NULL 1 [NULL, 1, 3]

query ???
select
array_replace(make_array(3, NULL, NULL), NULL, 5),
array_replace_n(make_array(3, NULL, NULL), NULL, 5, 10),
array_replace_all(make_array(3, NULL, NULL), NULL, 5);
----
[3, 5, NULL] [3, 5, 5] [3, 5, 5]



statement ok
Expand Down
Loading