Copilot commented on code in PR #1545:
URL: https://github.com/apache/datafusion-python/pull/1545#discussion_r3274515523


##########
python/tests/test_pickle_expr.py:
##########
@@ -147,6 +147,131 @@ def test_multi_arg_udf_round_trip(self):
         assert "add_scaled" in decoded.canonical_name()
 
 
+class TestAggregateUDFCodec:
+    """Python aggregate UDFs travel inline like scalar UDFs."""
+
+    def _build_aggregate_udf(self):
+        from datafusion import udaf
+        from datafusion.user_defined import Accumulator
+
+        class CountAcc(Accumulator):
+            def __init__(self):
+                self._count = 0
+
+            def state(self):
+                return [pa.scalar(self._count, type=pa.int64())]
+
+            def update(self, values):
+                self._count += len(values)
+
+            def merge(self, states):
+                for s in states:
+                    self._count += s[0].as_py()

Review Comment:
   `CountAcc.merge` is iterating over the list of state *fields* (`states`) 
rather than the elements within the first state field (`states[0]`). For a 
single-state UDAF, `states` will have length 1 and `states[0]` contains one 
value per partition; the current implementation only adds the first partition's 
state (via `s[0]`). Update `merge` to aggregate all values from `states[0]` 
(e.g., sum over the array) so it merges correctly when there are multiple 
partitions.
   



##########
python/tests/test_pickle_expr.py:
##########
@@ -147,6 +147,131 @@ def test_multi_arg_udf_round_trip(self):
         assert "add_scaled" in decoded.canonical_name()
 
 
+class TestAggregateUDFCodec:
+    """Python aggregate UDFs travel inline like scalar UDFs."""
+
+    def _build_aggregate_udf(self):
+        from datafusion import udaf
+        from datafusion.user_defined import Accumulator
+
+        class CountAcc(Accumulator):
+            def __init__(self):
+                self._count = 0
+
+            def state(self):
+                return [pa.scalar(self._count, type=pa.int64())]
+
+            def update(self, values):
+                self._count += len(values)
+
+            def merge(self, states):
+                for s in states:
+                    self._count += s[0].as_py()
+
+            def evaluate(self):
+                return pa.scalar(self._count, type=pa.int64())
+
+        return udaf(
+            CountAcc,
+            [pa.int64()],
+            pa.int64(),
+            [pa.int64()],
+            "immutable",
+            name="count_all",
+        )
+
+    def test_agg_udf_self_contained_blob(self):
+        u = self._build_aggregate_udf()
+        e = u(col("a"))
+        blob = pickle.dumps(e)
+        assert len(blob) > 200
+
+    def test_agg_udf_decodes_into_fresh_ctx(self):
+        u = self._build_aggregate_udf()
+        e = u(col("a"))
+        blob = e.to_bytes()
+        fresh = SessionContext()
+        decoded = Expr.from_bytes(blob, ctx=fresh)
+        assert "count_all" in decoded.canonical_name()
+
+    def test_agg_udf_decodes_via_pickle_with_no_worker_ctx(self):
+        u = self._build_aggregate_udf()
+        e = u(col("a"))
+        blob = pickle.dumps(e)
+        decoded = pickle.loads(blob)  # noqa: S301
+        assert "count_all" in decoded.canonical_name()
+
+    def test_agg_udf_evaluates_after_roundtrip(self):
+        """End-to-end: the decoded aggregate UDF runs and merges across
+        partitions, exercising the round-tripped state-field schema."""
+        u = self._build_aggregate_udf()
+        e = u(col("a"))
+        decoded = pickle.loads(pickle.dumps(e))  # noqa: S301
+
+        ctx = SessionContext()
+        df = ctx.from_pydict({"a": [1, 2, 3, 4, 5]})

Review Comment:
   This test claims to "merge across partitions", but the input is built with 
`ctx.from_pydict`, which typically produces a single partition/record batch. As 
a result, `merge` may never be exercised and the test won't actually validate 
the round-tripped state schema across partitions. Consider constructing a 
DataFrame with multiple partitions (e.g., `ctx.create_dataframe([[batch1], 
[batch2]], ...)`) so `merge` is guaranteed to run.
   



##########
crates/core/src/codec.rs:
##########
@@ -126,6 +131,16 @@ use crate::udf::PythonFunctionScalarUDF;
 /// volatility).
 pub(crate) const PY_SCALAR_UDF_FAMILY: &[u8] = b"DFPYUDF";
 
+/// Family prefix for an inlined Python aggregate UDF
+/// (cloudpickled tuple of name, accumulator factory, input schema,
+/// return type, state types schema, volatility).
+pub(crate) const PY_AGG_UDF_FAMILY: &[u8] = b"DFPYUDA";
+
+/// Family prefix for an inlined Python window UDF
+/// (cloudpickled tuple of name, evaluator factory, input schema,
+/// return type, volatility).

Review Comment:
   The doc comments for `PY_AGG_UDF_FAMILY` / `PY_WINDOW_UDF_FAMILY` describe 
the payload as including a "return type", but the encoder actually writes 
`return_schema_bytes` (a single-field IPC schema) for both aggregate and window 
UDFs. Please align these comments with the actual tuple shape (schema bytes) to 
avoid confusion for anyone implementing/inspecting the wire format.
   



##########
crates/core/src/codec.rs:
##########
@@ -642,6 +710,186 @@ fn cloudpickle<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
         .map(|cached| cached.bind(py).clone())
 }
 
+// =============================================================================
+// Shared Python window UDF encode / decode helpers
+//
+// Cloudpickle tuple shape: `(name, evaluator_factory, input_schema_bytes,
+// return_schema_bytes, volatility_str)`. The evaluator factory is the
+// Python callable that produces a new evaluator instance per partition.
+// =============================================================================
+
+pub(crate) fn try_encode_python_window_udf(node: &WindowUDF, buf: &mut Vec<u8>) -> Result<bool> {
+    let Some(py_udf) = node.inner().downcast_ref::<PythonFunctionWindowUDF>() else {
+        return Ok(false);
+    };
+
+    Python::attach(|py| -> Result<bool> {
+        let bytes = encode_python_window_udf(py, py_udf).map_err(to_datafusion_err)?;
+        append_framed_payload(py, buf, PY_WINDOW_UDF_FAMILY, &bytes)?;
+        Ok(true)
+    })
+}
+
+pub(crate) fn try_decode_python_window_udf(buf: &[u8]) -> Result<Option<Arc<WindowUDF>>> {
+    Python::attach(|py| -> Result<Option<Arc<WindowUDF>>> {
+        let Some(payload) = read_framed_payload(py, buf, PY_WINDOW_UDF_FAMILY, "window UDF")?
+        else {
+            return Ok(None);
+        };
+        let udf = decode_python_window_udf(py, payload).map_err(to_datafusion_err)?;
+        Ok(Some(Arc::new(WindowUDF::new_from_impl(udf))))
+    })
+}
+
+fn encode_python_window_udf(py: Python<'_>, udf: &PythonFunctionWindowUDF) -> PyResult<Vec<u8>> {
+    let signature = WindowUDFImpl::signature(udf);
+    let input_dtypes = signature_input_dtypes(signature, "PythonFunctionWindowUDF")?;
+    let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?;
+    let return_field = Field::new("result", udf.return_type().clone(), true);
+    let return_schema_bytes = build_single_field_schema_bytes(&return_field)?;
+    let volatility = volatility_wire_str(signature.volatility);
+
+    let payload = PyTuple::new(
+        py,
+        [
+            WindowUDFImpl::name(udf).into_pyobject(py)?.into_any(),
+            udf.evaluator().bind(py).clone().into_any(),
+            PyBytes::new(py, &input_schema_bytes).into_any(),
+            PyBytes::new(py, &return_schema_bytes).into_any(),
+            volatility.into_pyobject(py)?.into_any(),
+        ],
+    )?;
+
+    cloudpickle(py)?
+        .call_method1("dumps", (payload,))?
+        .extract::<Vec<u8>>()
+}
+
+fn decode_python_window_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionWindowUDF> {
+    let tuple = cloudpickle(py)?
+        .call_method1("loads", (PyBytes::new(py, payload),))?
+        .cast_into::<PyTuple>()?;
+
+    let name: String = tuple.get_item(0)?.extract()?;
+    let evaluator: Py<PyAny> = tuple.get_item(1)?.unbind();
+    let input_schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?;
+    let return_schema_bytes: Vec<u8> = tuple.get_item(3)?.extract()?;
+    let volatility_str: String = tuple.get_item(4)?.extract()?;
+
+    let input_types = read_input_dtypes(&input_schema_bytes)?;
+    let return_type = read_single_return_field(&return_schema_bytes, "PythonFunctionWindowUDF")?
+        .data_type()
+        .clone();
+    let volatility = parse_volatility_str(&volatility_str)?;
+
+    Ok(PythonFunctionWindowUDF::new(
+        name,
+        evaluator,
+        input_types,
+        return_type,
+        volatility,
+    ))
+}
+
+// =============================================================================
+// Shared Python aggregate UDF encode / decode helpers
+//
+// Cloudpickle tuple shape: `(name, accumulator_factory, input_schema_bytes,
+// return_type_bytes, state_schema_bytes, volatility_str)`. The accumulator

Review Comment:
   The aggregate UDF wire-format comment says `return_type_bytes`, but the 
payload is encoded/decoded as `return_schema_bytes` (single-field IPC schema), 
consistent with the scalar/window paths. Please correct the comment so it 
matches the actual on-wire layout.
   



##########
crates/core/src/lib.rs:
##########
@@ -59,11 +59,11 @@ mod array;
 #[cfg(feature = "substrait")]
 pub mod substrait;
 #[allow(clippy::borrow_deref_ref)]
-mod udaf;
+pub mod udaf;
 #[allow(clippy::borrow_deref_ref)]
 mod udf;
 pub mod udtf;
-mod udwf;
+pub mod udwf;
 

Review Comment:
   `udaf` / `udwf` are switched from private `mod` to `pub mod`, which expands 
the crate's public Rust API surface significantly (and is inconsistent with 
`udf` remaining private). If the intent is only to keep specific helpers/types 
available for downstream callers, consider keeping the modules private and `pub 
use`-reexporting only the needed items (e.g. `to_rust_accumulator`, 
`to_rust_partition_evaluator`, and/or the compatibility type alias) to avoid 
committing to the full module as public API.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.