Dandandan commented on code in PR #13581:
URL: https://github.com/apache/datafusion/pull/13581#discussion_r1878563046


##########
datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs:
##########
@@ -371,6 +371,69 @@ pub fn accumulate<T, F>(
     }
 }
 
+/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`)
+///
+/// This method assumes that for any input record index, if any of the value column
+/// is null, or it's filtered out by `opt_filter`, then the record would be ignored.
+/// (won't be accumulated by `value_fn`)
+pub fn accumulate_multiple<T, F>(
+    group_indices: &[usize],
+    value_columns: &[&PrimitiveArray<T>],
+    opt_filter: Option<&BooleanArray>,
+    mut value_fn: F,
+) where
+    T: ArrowPrimitiveType + Send,
+    F: FnMut(usize, &[T::Native]) + Send,
+{
+    // Calculate `valid_indices` to accumulate, non-valid indices are ignored.
+    // `valid_indices` is a bit mask corresponding to the `group_indices`. An index
+    // is considered valid if:
+    // 1. All columns are non-null at this index.
+    // 2. Not filtered out by `opt_filter`
+
+    // Take AND from all null buffers of `value_columns`.
+    let combined_nulls = value_columns
+        .iter()
+        .map(|arr| arr.nulls())
+        .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
+
+    // Take AND from previous combined nulls and `opt_filter`.
+    let valid_indices = match (combined_nulls, opt_filter) {
+        (None, None) => None,
+        (None, Some(filter)) => Some(filter.clone()),
+        (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), 
None)),
+        (Some(nulls), Some(filter)) => {
+            let combined = nulls.inner() & filter.values();
+            Some(BooleanArray::new(combined, None))
+        }
+    };
+
+    for col in value_columns.iter() {
+        assert_eq!(col.len(), group_indices.len());
+    }
+
+    match valid_indices {
+        None => {
+            for (idx, &group_idx) in group_indices.iter().enumerate() {
+                // Get `idx`-th row from all value(accumulate) columns
+                let row_values: Vec<_> =
+                    value_columns.iter().map(|col| col.value(idx)).collect();

Review Comment:
   Nice!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to