Skip to content
258 changes: 213 additions & 45 deletions datafusion/functions-nested/src/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,20 @@

use arrow::array::{
Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData,
NullBufferBuilder, OffsetSizeTrait, new_null_array,
NullBufferBuilder, OffsetBufferBuilder, OffsetSizeTrait, Scalar, new_null_array,
};
use arrow::datatypes::{DataType, Field};

use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, Field};
use datafusion_common::cast::as_int64_array;
use datafusion_common::utils::ListCoercion;
use datafusion_common::{Result, exec_err, utils::take_function_args};
use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;

use crate::utils::compare_element_to_list;
use crate::utils::make_scalar_function;

use std::sync::Arc;

Expand Down Expand Up @@ -125,7 +123,27 @@ impl ScalarUDFImpl for ArrayReplace {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_replace_inner)(&args.args)
let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
match (from_arg, to_arg) {
(ColumnarValue::Scalar(scalar_from), ColumnarValue::Scalar(scalar_to)) => {
let result = array_replace_with_scalar_args(
&list_array,
scalar_from,
scalar_to,
1i64,
)?;
Ok(ColumnarValue::Array(result))
}
(from_arg, to_arg) => {
let from_array = from_arg.to_array(num_rows)?;
let to_array = to_arg.to_array(num_rows)?;
let result =
array_replace_internal(&list_array, &from_array, &to_array, &[1])?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -200,7 +218,38 @@ impl ScalarUDFImpl for ArrayReplaceN {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_replace_n_inner)(&args.args)
let [list_arg, from_arg, to_arg, max_arg] =
take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
match (from_arg, to_arg, max_arg) {
(
ColumnarValue::Scalar(scalar_from),
ColumnarValue::Scalar(scalar_to),
ColumnarValue::Scalar(scalar_max),
) => {
let ScalarValue::Int64(Some(n)) = scalar_max else {
// null max means no replacements
return Ok(ColumnarValue::Array(list_array));
};
let result = array_replace_with_scalar_args(
&list_array,
scalar_from,
scalar_to,
*n,
)?;
Ok(ColumnarValue::Array(result))
}
(from_arg, to_arg, max_arg) => {
let from_array = from_arg.to_array(num_rows)?;
let to_array = to_arg.to_array(num_rows)?;
let max_array = max_arg.to_array(num_rows)?;
let arr_n = as_int64_array(&max_array)?.values().to_vec();
let result =
array_replace_internal(&list_array, &from_array, &to_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -273,7 +322,31 @@ impl ScalarUDFImpl for ArrayReplaceAll {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_replace_all_inner)(&args.args)
let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
match (from_arg, to_arg) {
(ColumnarValue::Scalar(scalar_from), ColumnarValue::Scalar(scalar_to)) => {
let result = array_replace_with_scalar_args(
&list_array,
scalar_from,
scalar_to,
i64::MAX,
)?;
Ok(ColumnarValue::Array(result))
}
(from_arg, to_arg) => {
let from_array = from_arg.to_array(num_rows)?;
let to_array = to_arg.to_array(num_rows)?;
let result = array_replace_internal(
&list_array,
&from_array,
&to_array,
&[i64::MAX],
)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -343,7 +416,11 @@ fn general_replace<O: OffsetSizeTrait>(

let original_idx = O::usize_as(0);
let replace_idx = O::usize_as(1);
let n = arr_n[row_index];
let n = if arr_n.len() == 1 {
Comment thread
Jefffrey marked this conversation as resolved.
arr_n[0]
} else {
arr_n[row_index]
};
let mut counter = 0;

// All elements are false, no need to replace, just copy original data
Expand Down Expand Up @@ -412,63 +489,154 @@ fn general_replace<O: OffsetSizeTrait>(
)?))
}

fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to] = take_function_args("array_replace", args)?;
/// Replaces up to `max_replacements` occurrences of `needle` with the single
/// element in `to_array` for each row in `list_array`.
///
/// This is a specialized fast path for the all-scalar case that uses a single
/// bulk `not_distinct` comparison over only the visible values range, then
/// iterates match positions via `set_indices` instead of scanning every bit.
fn general_replace_with_scalar<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
needle: &Scalar<ArrayRef>,
scalar_to: &ScalarValue,
max_replacements: i64,
) -> Result<ArrayRef> {
// No replacement needed - return unchanged.
if max_replacements <= 0 {
return Ok(Arc::new(list_array.clone()));
}

// replace at most one occurrence for each element
let arr_n = vec![1; array.len()];
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
let first_offset = list_array.offsets()[0].to_usize().unwrap();
let last_offset = list_array.offsets()[list_array.len()].to_usize().unwrap();
let visible_values = list_array
.values()
.slice(first_offset, last_offset - first_offset);

let to_array = scalar_to.to_array_of_size(1)?;
let original_data = visible_values.to_data();
let to_data = to_array.to_data();
let capacity = Capacities::Array(original_data.len());

let mut mutable = MutableArrayData::with_capacities(
vec![&original_data, &to_data],
false,
capacity,
);

let mut offsets = OffsetBufferBuilder::<O>::new(list_array.len());

// Single bulk comparison over the visible values only.
let match_bitmap = arrow_ord::cmp::not_distinct(&visible_values, needle)?;
let match_bits = match_bitmap.values();

for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
// Offsets relative to visible_values (subtract first_offset).
let start = offset_window[0].to_usize().unwrap() - first_offset;
let end = offset_window[1].to_usize().unwrap() - first_offset;
let row_len = end - start;

if list_array.is_null(row_index) {
offsets.push_length(0);
continue;
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)

// Slice the match bits to this row and iterate only over true positions.
let row_bits = match_bits.slice(start, row_len);
let mut match_positions = row_bits
.set_indices()
.take(max_replacements as usize)
.peekable();
if match_positions.peek().is_none() {
mutable.extend(0, start, end);
offsets.push_length(row_len);
continue;
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => exec_err!("array_replace does not support type '{array_type}'."),

// Iterate only over the positions that match using set_indices,
// which is more efficient than scanning every bit because the number
// of matches is typically much smaller than the total array size.
let mut prev_end = 0usize;
for match_pos in match_positions {
// Retain elements before this match.
if match_pos > prev_end {
mutable.extend(0, start + prev_end, start + match_pos);
}
// Emit the replacement element.
mutable.extend(1, 0, 1);
prev_end = match_pos + 1;
}

// Copy remaining elements after the last replacement.
if prev_end < row_len {
mutable.extend(0, start + prev_end, end);
}

offsets.push_length(row_len);
}

let data = mutable.freeze();

Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new_list_field(list_array.value_type(), true)),
offsets.finish(),
arrow::array::make_array(data),
list_array.nulls().cloned(),
)?))
}

fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to, max] = take_function_args("array_replace_n", args)?;
/// Fast path for `array_replace` when all arguments are scalars.
///
/// Uses a single bulk `not_distinct` comparison instead of per-row comparisons.
fn array_replace_with_scalar_args(
list_array: &ArrayRef,
scalar_from: &ScalarValue,
scalar_to: &ScalarValue,
max_replacements: i64,
) -> Result<ArrayRef> {
// `not_distinct` doesn't support nested types, fall back to the generic array path.
if scalar_from.data_type().is_nested() {
let num_rows = list_array.len();
let from_array = scalar_from.to_array_of_size(num_rows)?;
let to_array = scalar_to.to_array_of_size(num_rows)?;
return array_replace_internal(
list_array,
&from_array,
&to_array,
&vec![max_replacements; num_rows],
);
}

// replace the specified number of occurrences
let arr_n = as_int64_array(max)?.values().to_vec();
match array.data_type() {
let needle = Scalar::new(scalar_from.to_array_of_size(1)?);
match list_array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
let list = list_array.as_list::<i32>();
general_replace_with_scalar::<i32>(list, &needle, scalar_to, max_replacements)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => {
exec_err!("array_replace_n does not support type '{array_type}'.")
let list = list_array.as_list::<i64>();
general_replace_with_scalar::<i64>(list, &needle, scalar_to, max_replacements)
}
DataType::Null => Ok(new_null_array(list_array.data_type(), 1)),
array_type => exec_err!("array_replace does not support type '{array_type}'."),
}
}

fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to] = take_function_args("array_replace_all", args)?;

// replace all occurrences (up to "i64::MAX")
let arr_n = vec![i64::MAX; array.len()];
fn array_replace_internal(
array: &ArrayRef,
from: &ArrayRef,
to: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
general_replace::<i32>(list_array, from, to, arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)
general_replace::<i64>(list_array, from, to, arr_n)
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => {
exec_err!("array_replace_all does not support type '{array_type}'.")
}
array_type => exec_err!("array_replace does not support type '{array_type}'."),
}
}
Loading
Loading