Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 200 additions & 27 deletions datafusion/functions-nested/src/remove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@
//! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions.

use crate::utils;
use crate::utils::make_scalar_function;
use arrow::array::{
Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait,
cast::AsArray, make_array,
Scalar, cast::AsArray, make_array,
};
use arrow::buffer::{NullBuffer, OffsetBuffer};
use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cast::as_int64_array;
use datafusion_common::utils::ListCoercion;
use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
use datafusion_common::{
Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
Expand Down Expand Up @@ -113,7 +114,21 @@ impl ScalarUDFImpl for ArrayRemove {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_remove_inner)(&args.args)
let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
let arr_n = vec![1; num_rows];
match element_arg {
ColumnarValue::Array(element_array) => {
let result = array_remove_internal(&list_array, element_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
ColumnarValue::Scalar(scalar_element) => {
let result =
remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -214,7 +229,23 @@ impl ScalarUDFImpl for ArrayRemoveN {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_remove_n_inner)(&args.args)
let [list_arg, element_arg, max_arg] =
take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
let max_array = max_arg.to_array(num_rows)?;
let arr_n = as_int64_array(&max_array)?.values().to_vec();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're ignoring nulls in max_arg here, though I think this may be an existing issue

match element_arg {
ColumnarValue::Array(element_array) => {
let result = array_remove_internal(&list_array, element_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
ColumnarValue::Scalar(scalar_element) => {
let result =
remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -304,7 +335,21 @@ impl ScalarUDFImpl for ArrayRemoveAll {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_remove_all_inner)(&args.args)
let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
let arr_n = vec![i64::MAX; num_rows];
match element_arg {
ColumnarValue::Array(element_array) => {
let result = array_remove_internal(&list_array, element_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
ColumnarValue::Scalar(scalar_element) => {
let result =
remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand All @@ -316,27 +361,6 @@ impl ScalarUDFImpl for ArrayRemoveAll {
}
}

fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element] = take_function_args("array_remove", args)?;

let arr_n = vec![1; array.len()];
array_remove_internal(array, element, &arr_n)
}

fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element, max] = take_function_args("array_remove_n", args)?;

let arr_n = as_int64_array(max)?.values().to_vec();
array_remove_internal(array, element, &arr_n)
}

fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element] = take_function_args("array_remove_all", args)?;

let arr_n = vec![i64::MAX; array.len()];
array_remove_internal(array, element, &arr_n)
}

fn array_remove_internal(
array: &ArrayRef,
element_array: &ArrayRef,
Expand All @@ -357,6 +381,47 @@ fn array_remove_internal(
}
}

/// Dispatches scalar-needle array removal by list offset type.
///
/// `needle` must be a length-1 array containing the scalar element to remove.
fn array_remove_dispatch_scalar(
array: &ArrayRef,
needle: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_remove_with_scalar::<i32>(list_array, needle, arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_remove_with_scalar::<i64>(list_array, needle, arr_n)
}
array_type => exec_err!(
"array_remove/array_remove_n/array_remove_all does not support type '{array_type}'."
),
}
}

/// Removes elements matching a scalar needle from a list array.
///
/// Uses a bulk `distinct` comparison for non-null, non-nested scalar elements,
/// falling back to the per-row `general_remove` path for null or nested types.
fn remove_with_scalar_needle(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this function a little confusing, since the code paths can then be like:

invoke with args -> remove_with_scalar_needle -> array_remove_internal (fallback)

Perhaps we can hoist this null/nested check earlier to remove need for this function? Or the benefits of centralizing the check outweighs it?

list_array: &ArrayRef,
scalar_element: &ScalarValue,
arr_n: &[i64],
) -> Result<ArrayRef> {
if !scalar_element.is_null() && !scalar_element.data_type().is_nested() {
let needle = scalar_element.to_array_of_size(1)?;
array_remove_dispatch_scalar(list_array, &needle, arr_n)
} else {
let needle_array = scalar_element.to_array_of_size(list_array.len())?;
array_remove_internal(list_array, &needle_array, arr_n)
}
}

/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences
/// of `element_array[i]`.
///
Expand Down Expand Up @@ -468,6 +533,114 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
)?))
}

/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences
/// of `needle[0]` (scalar element broadcasted).
///
/// This is a specialized version of `general_remove` for scalar elements that
/// uses bulk comparison for better performance.
fn general_remove_with_scalar<OffsetSize: OffsetSizeTrait>(
list_array: &GenericListArray<OffsetSize>,
needle: &ArrayRef,
arr_n: &[i64],
Comment on lines +543 to +544
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar question to that of array_replace PR, in if we should just handle when needle & max are scalars only

) -> Result<ArrayRef> {
let list_field = match list_array.data_type() {
DataType::List(field) | DataType::LargeList(field) => field,
_ => {
return exec_err!(
"Expected List or LargeList data type, got {:?}",
list_array.data_type()
);
}
};

let list_offsets = list_array.offsets();
let first_offset = list_offsets[0].to_usize().unwrap();
let last_offset = list_offsets[list_offsets.len() - 1].to_usize().unwrap();
let values_range_len = last_offset - first_offset;
let values_slice = list_array.values().slice(first_offset, values_range_len);
let original_data = values_slice.to_data();
let mut offsets = Vec::<OffsetSize>::with_capacity(list_array.len() + 1);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

offsets.push(OffsetSize::zero());

let mut mutable = MutableArrayData::with_capacities(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if an approach using take kernel could provide even more performance gains?

vec![&original_data],
false,
Capacities::Array(original_data.len()),
);
let nulls = list_array.nulls().cloned();
let keep_mask =
arrow_ord::cmp::distinct(&values_slice, &Scalar::new(Arc::clone(needle)))?;
let remove_bits = match keep_mask.nulls() {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

distinct is similar to neq, only differing in null handling. In particular, two
operands are considered DISTINCT if they have a different value or if one of them is NULL
and the other isn’t. The result of distinct is never NULL.

https://docs.rs/arrow/latest/arrow/compute/kernels/cmp/fn.distinct.html

I don't think we should have handling around the null buffer of keep_mask according to the documentation

Some(validity) => !(&(keep_mask.values() & validity.inner())),
None => !keep_mask.values(),
};

for (row_index, offset_window) in list_offsets.windows(2).enumerate() {
if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) {
offsets.push(offsets[row_index]);
continue;
}

let start = offset_window[0].to_usize().unwrap() - first_offset;
let end = offset_window[1].to_usize().unwrap() - first_offset;

let n = arr_n[row_index];

if n <= 0 {
mutable.extend(0, start, end);
offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start));
continue;
}

let row_len = end - start;
let row_remove_bits = remove_bits.slice(start, row_len);
let num_to_remove = row_remove_bits.count_set_bits();

if num_to_remove == 0 {
mutable.extend(0, start, end);
offsets.push(offsets[row_index] + OffsetSize::usize_as(row_len));
continue;
}

let max_removals = n.min(num_to_remove as i64) as usize;

// Iterate only over the removal positions via set_indices. This is
// efficient when the number of removals is small relative to the row
// length (common case), since it skips over retained elements.
let mut removed = 0usize;
let mut copied = 0usize;
let mut prev_end = start; // end of last copied range (absolute index into values_slice)
for remove_pos in row_remove_bits.set_indices() {
let abs_pos = start + remove_pos;
// Copy the range before this removal position
if abs_pos > prev_end {
mutable.extend(0, prev_end, abs_pos);
copied += abs_pos - prev_end;
}
prev_end = abs_pos + 1;
removed += 1;
if removed == max_removals {
break;
}
}
// Copy the remaining tail after the last removal
if prev_end < end {
mutable.extend(0, prev_end, end);
copied += end - prev_end;
}

offsets.push(offsets[row_index] + OffsetSize::usize_as(copied));
}

let new_values = make_array(mutable.freeze());
Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
Arc::clone(list_field),
OffsetBuffer::new(offsets.into()),
new_values,
nulls,
)?))
}

#[cfg(test)]
mod tests {
use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
Expand Down
72 changes: 72 additions & 0 deletions datafusion/sqllogictest/test_files/array/array_remove.slt
Original file line number Diff line number Diff line change
Expand Up @@ -537,4 +537,76 @@ select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12]
[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]]


# array_remove scalar arguments over multiple input rows
query ???
select
array_remove(column1, 2),
array_remove_n(column1, 2, 2),
array_remove_all(column1, 2)
from (
values
(make_array(1, 2, 2, 3, 2, 1, 4)),
(make_array(42, 2, 55, 63, 2))
) as t(column1);
----
[1, 2, 3, 2, 1, 4] [1, 3, 2, 1, 4] [1, 3, 1, 4]
[42, 55, 63, 2] [42, 55, 63] [42, 55, 63]

# array_remove with elements containing NULLs — scalar path preserves NULLs
query ???
select
array_remove(column1, 2),
array_remove_n(column1, 2, 2),
array_remove_all(column1, 2)
from (
values
(make_array(1, 2, NULL, 3, 2, NULL, 4)),
(make_array(42, 2, NULL, 63, 2))
) as t(column1);
----
[1, NULL, 3, 2, NULL, 4] [1, NULL, 3, NULL, 4] [1, NULL, 3, NULL, 4]
[42, NULL, 63, 2] [42, NULL, 63] [42, NULL, 63]

# array_remove_n with n exceeding match count
query ?
select array_remove_n(make_array(1, 2, 2, 3), 2, 100);
----
[1, 3]

# array_remove_n with n=0 and n=-1 (no removal)
query ??
select
array_remove_n(make_array(1, 2, 2, 3), 2, 0),
array_remove_n(make_array(1, 2, 2, 3), 2, -1);
----
[1, 2, 2, 3] [1, 2, 2, 3]

# array_remove on empty arrays
query ??
select
array_remove(arrow_cast(make_array(), 'List(Int64)'), 1),
array_remove_all(arrow_cast(make_array(), 'List(Int64)'), 1);
----
[] []

# array_remove needle not found — array unchanged
query ?
select array_remove_all(make_array(1, 2, 3, 4, 5), 99);
----
[1, 2, 3, 4, 5]

# array_remove all elements match
query ?
select array_remove_all(make_array(7, 7, 7, 7), 7);
----
[]

# LargeList scalar path edge cases
query ??
select
array_remove_all(arrow_cast(make_array(1, 1, 1), 'LargeList(Int64)'), 1),
array_remove_n(arrow_cast(make_array(1, 1, 1), 'LargeList(Int64)'), 1, 2);
----
[] [1]

include ./cleanup.slt.part
Loading