rluvaton commented on code in PR #20498:
URL: https://github.com/apache/datafusion/pull/20498#discussion_r2907332726
##########
datafusion/physical-expr-common/src/utils.rs:
##########
@@ -50,30 +62,332 @@ impl ExprPropertiesNode {
}
}
+/// If the mask selects more than this fraction of rows, use
+/// `set_slices()` to copy contiguous ranges. Otherwise iterate
+/// over individual positions using `set_indices()`.
+const SCATTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
+
/// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, the
/// next value of `truthy` is taken; when the mask evaluates `false`, a null
/// value is filled in.
///
/// # Arguments
/// * `mask` - Boolean values used to determine where to put the `truthy` values
/// * `truthy` - All values of this array are scattered according to `mask` into
///   the final result.
pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result<ArrayRef> {
- let truthy = truthy.to_data();
+ let mask = match mask.null_count() {
+ 0 => Cow::Borrowed(mask),
+ n if n == mask.len() => {
+ return Ok(new_null_array(truthy.data_type(), mask.len()));
+ }
+ _ => Cow::Owned(prep_null_mask_filter(mask)),
+ };
+
+ let output_len = mask.len();
+ let count = mask.true_count();
+
+ // Fast path: no true values mean all-null object
+ if count == 0 {
+ return Ok(new_null_array(truthy.data_type(), output_len));
+ }
+
+ // Fast path: all true means output = truthy
+ if count == output_len {
+ return Ok(truthy.slice(0, truthy.len()));
+ }
+
+ let selectivity = count as f64 / output_len as f64;
+ let mask_buffer = mask.values();
+
+ scatter_array(truthy, mask_buffer, output_len, selectivity)
+}
+
+fn scatter_array(
+ truthy: &dyn Array,
+ mask: &BooleanBuffer,
+ output_len: usize,
+ selectivity: f64,
+) -> Result<ArrayRef> {
+ downcast_primitive_array! {
+ truthy => Ok(Arc::new(scatter_primitive(truthy, mask, output_len,
selectivity))),
+ DataType::Boolean => {
+ Ok(Arc::new(scatter_boolean(truthy.as_boolean(), mask, output_len,
selectivity)))
+ }
+ DataType::Utf8 => {
+ Ok(Arc::new(scatter_bytes(truthy.as_string::<i32>(), mask,
output_len, selectivity)))
+ }
+ DataType::LargeUtf8 => {
+ Ok(Arc::new(scatter_bytes(truthy.as_string::<i64>(), mask,
output_len, selectivity)))
+ }
+ DataType::Utf8View => {
+ Ok(Arc::new(scatter_byte_view(truthy.as_string_view(), mask,
output_len, selectivity)))
+ }
+ DataType::Binary => {
+ Ok(Arc::new(scatter_bytes(truthy.as_binary::<i32>(), mask,
output_len, selectivity)))
+ }
+ DataType::LargeBinary => {
+ Ok(Arc::new(scatter_bytes(truthy.as_binary::<i64>(), mask,
output_len, selectivity)))
+ }
+ DataType::BinaryView => {
+ Ok(Arc::new(scatter_byte_view(truthy.as_binary_view(), mask,
output_len, selectivity)))
+ }
+ DataType::FixedSizeBinary(_) => {
+ Ok(Arc::new(scatter_fixed_size_binary(
+ truthy.as_fixed_size_binary(), mask, output_len, selectivity,
+ )))
+ }
+ DataType::Dictionary(_, _) => {
+ downcast_dictionary_array! {
+ truthy => Ok(Arc::new(scatter_dict(truthy, mask, output_len,
selectivity))),
+ _t => scatter_fallback(truthy, mask, output_len)
+ }
+ }
+ _ => scatter_fallback(truthy, mask, output_len)
+ }
+}
+
+#[inline(never)]
+fn scatter_native<T: ArrowNativeType>(
+ src: &[T],
+ mask: &BooleanBuffer,
+ output_len: usize,
+ selectivity: f64,
+) -> Buffer {
+ let mut output = vec![T::default(); output_len];
+ let mut src_offset = 0;
+
+ if selectivity > SCATTER_SLICES_SELECTIVITY_THRESHOLD {
+ for (start, end) in mask.set_slices() {
+ let len = end - start;
+ output[start..end].copy_from_slice(&src[src_offset..src_offset +
len]);
+ src_offset += len;
+ }
+ } else {
+ for dst_idx in mask.set_indices() {
+ output[dst_idx] = src[src_offset];
+ src_offset += 1;
+ }
+ }
+
+ output.into()
+}
+
+fn scatter_bits(
+ src: &BooleanBuffer,
+ mask: &BooleanBuffer,
+ output_len: usize,
+ selectivity: f64,
+) -> Buffer {
+ let mut builder = BooleanBufferBuilder::new(output_len);
+ builder.advance(output_len);
+ let mut src_offset = 0;
+
+ if selectivity > SCATTER_SLICES_SELECTIVITY_THRESHOLD {
+ for (start, end) in mask.set_slices() {
+ for i in start..end {
+ if src.value(src_offset) {
+ builder.set_bit(i, true);
+ }
+ src_offset += 1;
+ }
+ }
+ } else {
+ for dst_idx in mask.set_indices() {
+ if src.value(src_offset) {
+ builder.set_bit(dst_idx, true);
+ }
+ src_offset += 1;
+ }
+ }
+
+ builder.finish().into_inner()
+}
+
+fn scatter_null_mask(
+ src_nulls: Option<&NullBuffer>,
+ mask: &BooleanBuffer,
+ output_len: usize,
+ selectivity: f64,
+) -> Option<(usize, Buffer)> {
+ let false_count = output_len - mask.count_set_bits();
+ let src_null_count = src_nulls.map(|n| n.null_count()).unwrap_or(0);
+
+ if src_null_count == 0 {
+ if false_count == 0 {
+ None
+ } else {
+ Some((false_count, mask.inner().clone()))
+ }
+ } else {
+ let src_nulls = src_nulls.unwrap();
+ let scattered = scatter_bits(src_nulls.inner(), mask, output_len,
selectivity);
+ let valid_count = scattered.count_set_bits_offset(0, output_len);
+ let null_count = output_len - valid_count;
+ if null_count == 0 {
+ None
+ } else {
+ Some((null_count, scattered))
+ }
Review Comment:
you can then do:
```rust
Some(NullBuffer::new(scattered)).filter(|n| n.null_count() > 0)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]