Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 201 additions & 38 deletions datafusion/functions-nested/src/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,21 @@

use arrow::array::{
Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData,
NullBufferBuilder, OffsetSizeTrait, new_null_array,
NullBufferBuilder, OffsetSizeTrait, Scalar, new_null_array,
};
use arrow::datatypes::{DataType, Field};

use arrow::buffer::OffsetBuffer;
use datafusion_common::cast::as_int64_array;
use datafusion_common::utils::ListCoercion;
use datafusion_common::{Result, exec_err, utils::take_function_args};
use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;

use crate::utils::compare_element_to_list;
use crate::utils::make_scalar_function;

use std::sync::Arc;

Expand Down Expand Up @@ -125,7 +124,27 @@ impl ScalarUDFImpl for ArrayReplace {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_replace_inner)(&args.args)
let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
let to_array = to_arg.to_array(num_rows)?;
let arr_n = vec![1; num_rows];
match from_arg {
ColumnarValue::Array(from_array) => {
let result =
array_replace_internal(&list_array, from_array, &to_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
ColumnarValue::Scalar(scalar_from) => {
let result = replace_with_scalar_needle(
&list_array,
scalar_from,
&to_array,
&arr_n,
)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -200,7 +219,29 @@ impl ScalarUDFImpl for ArrayReplaceN {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_replace_n_inner)(&args.args)
let [list_arg, from_arg, to_arg, max_arg] =
take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
let to_array = to_arg.to_array(num_rows)?;
let max_array = max_arg.to_array(num_rows)?;
let arr_n = as_int64_array(&max_array)?.values().to_vec();
match from_arg {
ColumnarValue::Array(from_array) => {
let result =
array_replace_internal(&list_array, from_array, &to_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
ColumnarValue::Scalar(scalar_from) => {
let result = replace_with_scalar_needle(
&list_array,
scalar_from,
&to_array,
&arr_n,
)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -273,7 +314,27 @@ impl ScalarUDFImpl for ArrayReplaceAll {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_replace_all_inner)(&args.args)
let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
let to_array = to_arg.to_array(num_rows)?;
let arr_n = vec![i64::MAX; num_rows];
match from_arg {
ColumnarValue::Array(from_array) => {
let result =
array_replace_internal(&list_array, from_array, &to_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
ColumnarValue::Scalar(scalar_from) => {
let result = replace_with_scalar_needle(
&list_array,
scalar_from,
&to_array,
&arr_n,
)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -412,63 +473,165 @@ fn general_replace<O: OffsetSizeTrait>(
)?))
}

fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to] = take_function_args("array_replace", args)?;
/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences
/// of `needle[0]` with `to_array[i]`.
///
/// This is a specialized version of `general_replace` for scalar needles that
/// uses a single bulk comparison for better performance.
fn general_replace_with_scalar<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
needle: &ArrayRef,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think needle should be a Scalar in the arguments here, to make it clear this is the scalar (without needing to read the docstring)

  • This can be passed in all the way from replace_with_scalar_needle for example, since its still a ScalarValue at that point

to_array: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
let mut offsets: Vec<O> = vec![O::usize_as(0)];
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using OffsetBufferBuilder provides a nicer API for doing these operations (can just push length)

let values = list_array.values();
let original_data = values.to_data();
let to_data = to_array.to_data();
let capacity = Capacities::Array(original_data.len());

// replace at most one occurrence for each element
let arr_n = vec![1; array.len()];
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
let mut mutable = MutableArrayData::with_capacities(
vec![&original_data, &to_data],
false,
capacity,
);

let mut valid = NullBufferBuilder::new(list_array.len());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need a builder for nulls, we can copy the input array null buffer as is

let match_array = arrow_ord::cmp::not_distinct(
list_array.values(),
&Scalar::new(Arc::clone(needle)),
)?;

for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
if list_array.is_null(row_index) {
offsets.push(offsets[row_index]);
valid.append_null();
continue;
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)

let start = offset_window[0];
let end = offset_window[1];
let start_usize = start.to_usize().unwrap();
let end_usize = end.to_usize().unwrap();
let row_len = end_usize - start_usize;

let original_idx = O::usize_as(0);
let replace_idx = O::usize_as(1);
let n = arr_n[row_index];
let mut counter = 0;

let eq_array = match_array.slice(start_usize, row_len);
if n <= 0 || eq_array.true_count() == 0 {
mutable.extend(original_idx.to_usize().unwrap(), start_usize, end_usize);
offsets.push(offsets[row_index] + (end - start));
valid.append_non_null();
continue;
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => exec_err!("array_replace does not support type '{array_type}'."),

let mut pending_retain: Option<O> = None;
for (i, is_match) in eq_array.iter().enumerate() {
let i = O::usize_as(i);
if is_match == Some(true) && counter < n {
if let Some(rs) = pending_retain.take() {
mutable.extend(
original_idx.to_usize().unwrap(),
start_usize + rs.to_usize().unwrap(),
start_usize + i.to_usize().unwrap(),
);
}
mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1);
counter += 1;
if counter == n {
mutable.extend(
original_idx.to_usize().unwrap(),
start_usize + i.to_usize().unwrap() + 1,
end_usize,
);
break;
}
} else if pending_retain.is_none() {
pending_retain = Some(i);
}
}

if counter < n
&& let Some(rs) = pending_retain
{
mutable.extend(
original_idx.to_usize().unwrap(),
start_usize + rs.to_usize().unwrap(),
end_usize,
);
}

offsets.push(offsets[row_index] + (end - start));
valid.append_non_null();
}
}

fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to, max] = take_function_args("array_replace_n", args)?;
let data = mutable.freeze();

// replace the specified number of occurrences
let arr_n = as_int64_array(max)?.values().to_vec();
Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new_list_field(list_array.value_type(), true)),
OffsetBuffer::<O>::new(offsets.into()),
arrow::array::make_array(data),
valid.finish(),
)?))
}

/// Dispatches scalar-needle array replacement by list offset type.
///
/// `needle` must be a length-1 array containing the scalar element to replace.
fn array_replace_dispatch_scalar(
array: &ArrayRef,
needle: &ArrayRef,
to: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
general_replace_with_scalar::<i32>(list_array, needle, to, arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)
general_replace_with_scalar::<i64>(list_array, needle, to, arr_n)
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => {
exec_err!("array_replace_n does not support type '{array_type}'.")
}
array_type => exec_err!("array_replace does not support type '{array_type}'."),
}
}

fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to] = take_function_args("array_replace_all", args)?;
fn replace_with_scalar_needle(
list_array: &ArrayRef,
scalar_from: &ScalarValue,
to_array: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
if scalar_from.data_type().is_nested() {
let from_array = scalar_from.to_array_of_size(list_array.len())?;
return array_replace_internal(list_array, &from_array, to_array, arr_n);
}

// replace all occurrences (up to "i64::MAX")
let arr_n = vec![i64::MAX; array.len()];
let needle = scalar_from.to_array_of_size(1)?;
array_replace_dispatch_scalar(list_array, &needle, to_array, arr_n)
}

fn array_replace_internal(
array: &ArrayRef,
from: &ArrayRef,
to: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
general_replace::<i32>(list_array, from, to, arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)
general_replace::<i64>(list_array, from, to, arr_n)
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => {
exec_err!("array_replace_all does not support type '{array_type}'.")
}
array_type => exec_err!("array_replace does not support type '{array_type}'."),
}
}
Loading
Loading