jayzhan211 commented on code in PR #13581:
URL: https://github.com/apache/datafusion/pull/13581#discussion_r1865684095


##########
datafusion/functions-aggregate/src/correlation.rs:
##########
@@ -263,3 +283,307 @@ impl Accumulator for CorrelationAccumulator {
         Ok(())
     }
 }
+
+#[derive(Default)]
+pub struct CorrelationGroupsAccumulator {
+    // Number of elements for each group
+    // This is also used to track nulls: if a group has 0 valid values accumulated,
+    // final aggregation result will be null.
+    count: Vec<u64>,
+    // Sum of x values for each group
+    sum_x: Vec<f64>,
+    // Sum of y
+    sum_y: Vec<f64>,
+    // Sum of x*y
+    sum_xy: Vec<f64>,
+    // Sum of x^2
+    sum_xx: Vec<f64>,
+    // Sum of y^2
+    sum_yy: Vec<f64>,
+}
+
+impl CorrelationGroupsAccumulator {
+    pub fn new() -> Self {
+        Default::default()
+    }
+}
+
+/// Specialized version of `accumulate_multiple` for correlation's merge_batch
+///
+/// Note: Arrays in `state_arrays` should not have null values, because they are all
+/// intermediate states created within the accumulator, instead of inputs from
+/// outside.
+fn accumulate_correlation_states(
+    group_indices: &[usize],
+    state_arrays: (
+        &UInt64Array,  // count
+        &Float64Array, // sum_x
+        &Float64Array, // sum_y
+        &Float64Array, // sum_xy
+        &Float64Array, // sum_xx
+        &Float64Array, // sum_yy
+    ),
+    mut value_fn: impl FnMut(usize, u64, &[f64]),
+) {
+    let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays;
+
+    assert_eq!(counts.null_count(), 0);
+    assert_eq!(sum_x.null_count(), 0);
+    assert_eq!(sum_y.null_count(), 0);
+    assert_eq!(sum_xy.null_count(), 0);
+    assert_eq!(sum_xx.null_count(), 0);
+    assert_eq!(sum_yy.null_count(), 0);
+
+    let counts_values = counts.values().as_ref();
+    let sum_x_values = sum_x.values().as_ref();
+    let sum_y_values = sum_y.values().as_ref();
+    let sum_xy_values = sum_xy.values().as_ref();
+    let sum_xx_values = sum_xx.values().as_ref();
+    let sum_yy_values = sum_yy.values().as_ref();
+
+    let mut row = [0.0; 5];
+    for (idx, &group_idx) in group_indices.iter().enumerate() {
+        row[0] = sum_x_values[idx];
+        row[1] = sum_y_values[idx];
+        row[2] = sum_xy_values[idx];
+        row[3] = sum_xx_values[idx];
+        row[4] = sum_yy_values[idx];
+        value_fn(group_idx, counts_values[idx], &row);
+    }
+}
+
+/// GroupsAccumulator implementation for `corr(x, y)` that computes the Pearson correlation coefficient
+/// between two numeric columns.
+///
+/// Online algorithm for correlation:
+///
+/// r = (n * sum_xy - sum_x * sum_y) / sqrt((n * sum_xx - sum_x^2) * (n * sum_yy - sum_y^2))
+/// where:
+/// n = number of observations
+/// sum_x = sum of x values
+/// sum_y = sum of y values  
+/// sum_xy = sum of (x * y)
+/// sum_xx = sum of x^2 values
+/// sum_yy = sum of y^2 values
+///
+/// Reference: <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient#For_a_sample>
+impl GroupsAccumulator for CorrelationGroupsAccumulator {
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.count.resize(total_num_groups, 0);
+        self.sum_x.resize(total_num_groups, 0.0);
+        self.sum_y.resize(total_num_groups, 0.0);
+        self.sum_xy.resize(total_num_groups, 0.0);
+        self.sum_xx.resize(total_num_groups, 0.0);
+        self.sum_yy.resize(total_num_groups, 0.0);
+
+        let array_x = &cast(&values[0], &DataType::Float64)?;

Review Comment:
   I think casting should be handled in the logical optimizer. Fixing the
   signature of `Correlation` might help.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to