Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 182 additions & 27 deletions datafusion/functions-nested/src/remove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@
//! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions.

use crate::utils;
use crate::utils::make_scalar_function;
use arrow::array::{
Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait,
cast::AsArray, make_array,
Scalar, cast::AsArray, make_array,
};
use arrow::buffer::{NullBuffer, OffsetBuffer};
use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cast::as_int64_array;
use datafusion_common::utils::ListCoercion;
use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
use datafusion_common::{
Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
Expand Down Expand Up @@ -113,7 +114,21 @@ impl ScalarUDFImpl for ArrayRemove {
}

/// Evaluates the `ArrayRemove` function for one batch.
///
/// Takes exactly two arguments (list, element). The list argument is
/// materialized to `args.number_rows` rows and every row gets a removal
/// limit of 1. An array element argument is handled by
/// `array_remove_internal`; a scalar element argument is handled by
/// `remove_with_scalar_needle`. The result is always returned as
/// `ColumnarValue::Array`.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
// NOTE(review): the next line looks like the previous one-line body that the
// diff view shows alongside its replacement below — confirm against the hunk.
make_scalar_function(array_remove_inner)(&args.args)
let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
// Per-row removal limit: one occurrence per row.
let arr_n = vec![1; num_rows];
match element_arg {
// Element supplied per row: use the general per-row path.
ColumnarValue::Array(element_array) => {
let result = array_remove_internal(&list_array, element_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
// Single element shared by all rows: use the scalar-needle path.
ColumnarValue::Scalar(scalar_element) => {
let result =
remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -214,7 +229,23 @@ impl ScalarUDFImpl for ArrayRemoveN {
}

/// Evaluates the `ArrayRemoveN` function for one batch.
///
/// Takes exactly three arguments (list, element, max). The list and max
/// arguments are materialized to `args.number_rows` rows, and the max values
/// become the per-row removal limits. An array element argument is handled by
/// `array_remove_internal`; a scalar element argument is handled by
/// `remove_with_scalar_needle`. The result is always returned as
/// `ColumnarValue::Array`.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
// NOTE(review): the next line looks like the previous one-line body that the
// diff view shows alongside its replacement below — confirm against the hunk.
make_scalar_function(array_remove_n_inner)(&args.args)
let [list_arg, element_arg, max_arg] =
take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
let max_array = max_arg.to_array(num_rows)?;
// Copies the raw Int64 values; the validity (null) bitmap of `max_array` is
// not consulted on this line, so a null max contributes whatever value is stored.
let arr_n = as_int64_array(&max_array)?.values().to_vec();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're ignoring nulls in max_arg here, though I think this may be an existing issue

match element_arg {
// Element supplied per row: use the general per-row path.
ColumnarValue::Array(element_array) => {
let result = array_remove_internal(&list_array, element_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
// Single element shared by all rows: use the scalar-needle path.
ColumnarValue::Scalar(scalar_element) => {
let result =
remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -304,7 +335,21 @@ impl ScalarUDFImpl for ArrayRemoveAll {
}

/// Evaluates the `ArrayRemoveAll` function for one batch.
///
/// Takes exactly two arguments (list, element). The list argument is
/// materialized to `args.number_rows` rows and every row gets a removal
/// limit of `i64::MAX`. An array element argument is handled by
/// `array_remove_internal`; a scalar element argument is handled by
/// `remove_with_scalar_needle`. The result is always returned as
/// `ColumnarValue::Array`.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
// NOTE(review): the next line looks like the previous one-line body that the
// diff view shows alongside its replacement below — confirm against the hunk.
make_scalar_function(array_remove_all_inner)(&args.args)
let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
// Per-row removal limit: the largest i64, so the limit never caps removals in practice.
let arr_n = vec![i64::MAX; num_rows];
match element_arg {
// Element supplied per row: use the general per-row path.
ColumnarValue::Array(element_array) => {
let result = array_remove_internal(&list_array, element_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
// Single element shared by all rows: use the scalar-needle path.
ColumnarValue::Scalar(scalar_element) => {
let result =
remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand All @@ -316,27 +361,6 @@ impl ScalarUDFImpl for ArrayRemoveAll {
}
}

/// Array-only entry point for `array_remove`.
///
/// Expects exactly two arguments (list array, element array) and forwards
/// them to `array_remove_internal` with a removal limit of 1 for every row.
fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
    let [list, needle] = take_function_args("array_remove", args)?;

    // One limit entry per input row, each allowing a single removal.
    let limits = vec![1_i64; list.len()];
    array_remove_internal(list, needle, &limits)
}

/// Array-only entry point for `array_remove_n`.
///
/// Expects exactly three arguments (list array, element array, max array).
/// The max array must be Int64; its raw values are used as the per-row
/// removal limits passed to `array_remove_internal`.
fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
    let [list, needle, max] = take_function_args("array_remove_n", args)?;

    // Only the values buffer is read here; the null bitmap is not consulted.
    let max_values = as_int64_array(max)?;
    let limits: Vec<i64> = max_values.values().iter().copied().collect();
    array_remove_internal(list, needle, &limits)
}

/// Array-only entry point for `array_remove_all`.
///
/// Expects exactly two arguments (list array, element array) and forwards
/// them to `array_remove_internal` with a removal limit of `i64::MAX` for
/// every row.
fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
    let [list, needle] = take_function_args("array_remove_all", args)?;

    // One limit entry per input row, each set to the largest i64.
    let limits = vec![i64::MAX; list.len()];
    array_remove_internal(list, needle, &limits)
}

fn array_remove_internal(
array: &ArrayRef,
element_array: &ArrayRef,
Expand All @@ -357,6 +381,45 @@ fn array_remove_internal(
}
}

/// Dispatches scalar-needle array removal by list offset type.
///
/// `needle` must be a length-1 array containing the scalar element to remove.
/// `arr_n` is forwarded unchanged as the per-row removal limits.
///
/// # Errors
///
/// Returns an execution error when `array` is neither a `List` nor a
/// `LargeList`.
fn array_remove_dispatch_scalar(
array: &ArrayRef,
needle: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_remove_with_scalar::<i32>(list_array, needle, arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_remove_with_scalar::<i64>(list_array, needle, arr_n)
}
// This helper is reached from array_remove, array_remove_n and
// array_remove_all, so the message names the whole family rather than
// only array_remove.
array_type => exec_err!(
"array_remove, array_remove_n and array_remove_all do not support type '{array_type}'."
),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is called by more than just array_remove; can we improve the error message?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, updated.

}
}

/// Removes elements matching a scalar needle from a list array.
///
/// Uses a bulk `distinct` comparison for non-null, non-nested scalar elements,
/// falling back to the per-row `general_remove` path (via
/// `array_remove_internal`) for null or nested types.
///
/// * `list_array` - the list column to remove elements from.
/// * `scalar_element` - the single element shared by every row.
/// * `arr_n` - per-row removal limits, forwarded unchanged to either path.
fn remove_with_scalar_needle(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this function a little confusing, since the code paths can then be like:

invoke with args -> remove_with_scalar_needle -> array_remove_internal (fallback)

Perhaps we can hoist this null/nested check earlier to remove need for this function? Or the benefits of centralizing the check outweighs it?

list_array: &ArrayRef,
scalar_element: &ScalarValue,
arr_n: &[i64],
) -> Result<ArrayRef> {
if !scalar_element.is_null() && !scalar_element.data_type().is_nested() {
// Fast path: the needle is materialized once as a length-1 array.
let needle = scalar_element.to_array_of_size(1)?;
array_remove_dispatch_scalar(list_array, &needle, arr_n)
} else {
// Fallback: expand the scalar to one element per list row so the
// array/array path can be reused.
let needle_array = scalar_element.to_array_of_size(list_array.len())?;
array_remove_internal(list_array, &needle_array, arr_n)
}
}

/// For each row `i`, removes up to `arr_n[i]` occurrences of
/// `element_array[i]` from `list_array[i]`.
///
Expand Down Expand Up @@ -468,6 +531,98 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
)?))
}

/// For each row `i`, removes up to `arr_n[i]` occurrences of `needle[0]`
/// (the scalar element, broadcast to every row) from `list_array[i]`.
///
/// This is a specialized version of `general_remove` for scalar elements that
/// uses bulk comparison for better performance.
///
/// A single `distinct` comparison of the list's values against the needle
/// produces `keep_mask`; within a row, `false` entries are the ones eligible
/// for removal. Null rows stay null and contribute no values to the output.
/// Rows with `arr_n[i] <= 0`, or with no `false` entries, are copied unchanged.
///
/// # Errors
///
/// Returns an execution error if the list's data type is neither `List` nor
/// `LargeList`, and propagates errors from the comparison kernel and from
/// building the output list array.
fn general_remove_with_scalar<OffsetSize: OffsetSizeTrait>(
list_array: &GenericListArray<OffsetSize>,
needle: &ArrayRef,
arr_n: &[i64],
Comment on lines +543 to +544
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar question to that of array_replace PR, in if we should just handle when needle & max are scalars only

) -> Result<ArrayRef> {
// The output list reuses the input's element field.
let list_field = match list_array.data_type() {
DataType::List(field) | DataType::LargeList(field) => field,
_ => {
return exec_err!(
"Expected List or LargeList data type, got {:?}",
list_array.data_type()
);
}
};
let original_data = list_array.values().to_data();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be inefficient for sliced arrays.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now slice the values to the range actually referenced by the offsets.

That said, I wanted to understand your concern better: when a GenericListArray is sliced, values() returns the full underlying array, and to_data() on it wraps the existing buffer references into ArrayData without copying. So the main downside I could identify is that Capacities::Array(original_data.len()) over-estimates the pre-allocation for sliced inputs. Were you thinking of a different inefficiency, or is the over-allocation what you had in mind?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The over-allocation was one part, but the bigger concern is calling the distinct kernel on the entire values buffer (see other comment).

// One output offset per row, plus the leading zero pushed below.
let mut offsets = Vec::<OffsetSize>::with_capacity(list_array.len() + 1);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

offsets.push(OffsetSize::zero());

let mut mutable = MutableArrayData::with_capacities(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if an approach using take kernel could provide even more performance gains?

vec![&original_data],
false,
Capacities::Array(original_data.len()),
);
// Row-level validity is carried over unchanged to the output list.
let nulls = list_array.nulls().cloned();
// Computed over `list_array.values()` as a whole, so the mask is indexed by
// the same positions the list offsets refer to.
let keep_mask =
arrow_ord::cmp::distinct(list_array.values(), &Scalar::new(Arc::clone(needle)))?;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will call the distinct kernel on all the elements in the value buffer, not just the ones that are visible in a sliced array.


for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
// Null row: repeat the previous offset so the output row has zero length.
if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) {
offsets.push(offsets[row_index]);
continue;
}

let start = offset_window[0].to_usize().unwrap();
let end = offset_window[1].to_usize().unwrap();

let n = arr_n[row_index];

// Non-positive limit: nothing may be removed, copy the row as-is.
if n <= 0 {
mutable.extend(0, start, end);
offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start));
continue;
}

// Per-row view of the mask; `false` marks an element eligible for removal.
let eq_array = keep_mask.slice(start, end - start);
let num_to_remove = eq_array.false_count();

// No eligible element in this row: copy the row as-is.
if num_to_remove == 0 {
mutable.extend(0, start, end);
offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start));
continue;
}

// Never remove more than the row's limit or than there are eligible elements.
let max_removals = n.min(num_to_remove as i64);
let mut removed = 0i64;
let mut copied = 0usize;
// Row-relative start of the current run of retained elements that has not
// been copied to the output yet.
let mut pending_batch_to_retain: Option<usize> = None;
for (i, keep) in eq_array.iter().enumerate() {
if keep == Some(false) && removed < max_removals {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we break from the loop once we hit max_removals?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. now break early once max_removals is reached.

// Flush the retained run that ends just before this removed element.
if let Some(bs) = pending_batch_to_retain {
mutable.extend(0, start + bs, start + i);
copied += i - bs;
pending_batch_to_retain = None;
}
removed += 1;
} else if pending_batch_to_retain.is_none() {
// Element is retained and no run is open: start a new run here.
pending_batch_to_retain = Some(i);
}
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would be possible to iterate only over the "false" bits, e.g., by negating the buffer and looking at BooleanBuffer::set_indices.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great suggestion. Benchmarks show a ~20–40% improvement with this optimization.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing!


// Flush the trailing retained run, if any, up to the end of the row.
if let Some(bs) = pending_batch_to_retain {
mutable.extend(0, start + bs, end);
copied += end - start - bs;
}

offsets.push(offsets[row_index] + OffsetSize::usize_as(copied));
}

let new_values = make_array(mutable.freeze());
Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
Arc::clone(list_field),
OffsetBuffer::new(offsets.into()),
new_values,
nulls,
)?))
}

#[cfg(test)]
mod tests {
use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
Expand Down
72 changes: 72 additions & 0 deletions datafusion/sqllogictest/test_files/array/array_remove.slt
Original file line number Diff line number Diff line change
Expand Up @@ -537,4 +537,76 @@ select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12]
[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]]


# array_remove scalar arguments over multiple input rows
query ???
select
array_remove(column1, 2),
array_remove_n(column1, 2, 2),
array_remove_all(column1, 2)
from (
values
(make_array(1, 2, 2, 3, 2, 1, 4)),
(make_array(42, 2, 55, 63, 2))
) as t(column1);
----
[1, 2, 3, 2, 1, 4] [1, 3, 2, 1, 4] [1, 3, 1, 4]
[42, 55, 63, 2] [42, 55, 63] [42, 55, 63]

# array_remove with elements containing NULLs — scalar path preserves NULLs
query ???
select
array_remove(column1, 2),
array_remove_n(column1, 2, 2),
array_remove_all(column1, 2)
from (
values
(make_array(1, 2, NULL, 3, 2, NULL, 4)),
(make_array(42, 2, NULL, 63, 2))
) as t(column1);
----
[1, NULL, 3, 2, NULL, 4] [1, NULL, 3, NULL, 4] [1, NULL, 3, NULL, 4]
[42, NULL, 63, 2] [42, NULL, 63] [42, NULL, 63]

# array_remove_n with n exceeding match count
query ?
select array_remove_n(make_array(1, 2, 2, 3), 2, 100);
----
[1, 3]

# array_remove_n with n=0 and n=-1 (no removal)
query ??
select
array_remove_n(make_array(1, 2, 2, 3), 2, 0),
array_remove_n(make_array(1, 2, 2, 3), 2, -1);
----
[1, 2, 2, 3] [1, 2, 2, 3]

# array_remove on empty arrays
query ??
select
array_remove(arrow_cast(make_array(), 'List(Int64)'), 1),
array_remove_all(arrow_cast(make_array(), 'List(Int64)'), 1);
----
[] []

# array_remove needle not found — array unchanged
query ?
select array_remove_all(make_array(1, 2, 3, 4, 5), 99);
----
[1, 2, 3, 4, 5]

# array_remove all elements match
query ?
select array_remove_all(make_array(7, 7, 7, 7), 7);
----
[]

# LargeList scalar path edge cases
query ??
select
array_remove_all(arrow_cast(make_array(1, 1, 1), 'LargeList(Int64)'), 1),
array_remove_n(arrow_cast(make_array(1, 1, 1), 'LargeList(Int64)'), 1, 2);
----
[] [1]

include ./cleanup.slt.part
Loading