dianfu commented on a change in pull request #13421:
URL: https://github.com/apache/flink/pull/13421#discussion_r491258305


##########
File path: flink-python/pyflink/fn_execution/beam/beam_operations_fast.pyx
##########
@@ -173,6 +173,28 @@ cdef class 
DataStreamStatelessFunctionOperation(BeamStatelessFunctionOperation):
         func = operation_utils.extract_data_stream_stateless_funcs(udfs)
         return func, []
 
+
+cdef class PandasAggregateFunctionOperation(BeamStatelessFunctionOperation):
+    def __init__(self, name, spec, counter_factory, sampler, consumers):
+        super(PandasAggregateFunctionOperation, self).__init__(name, spec, 
counter_factory,
+                                                                   sampler, 
consumers)
+
+    def generate_func(self, udfs):
+        pandas_functions, variable_dict, user_defined_funcs = reduce(
+            lambda x, y: (
+                ','.join([x[0], y[0]]),
+                dict(chain(x[1].items(), y[1].items())),
+                x[2] + y[2]),
+            [operation_utils.extract_user_defined_function(udf) for udf in 
udfs])
+        variable_dict['wrap_pandas_result'] = 
operation_utils.wrap_pandas_result
+        mapper = eval('lambda value: wrap_pandas_result([%s])' % 
pandas_functions, variable_dict)
+        if self._is_python_coder:

Review comment:
       Isn't this always true?

##########
File path: flink-python/pyflink/table/tests/test_pandas_udaf.py
##########
@@ -0,0 +1,48 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pyflink.table.types import DataTypes
+
+from pyflink.table.udf import udaf
+from pyflink.testing import source_sink_utils
+from pyflink.testing.test_case_utils import PyFlinkBlinkBatchTableTestCase
+
+
+class BatchPandasUDAFITTests(PyFlinkBlinkBatchTableTestCase):
+    def test_group_aggregate_function(self):
+        t = self.t_env.from_elements(
+            [(1, 2, 3), (3, 2, 3), (2, 1, 3), (1, 5, 4), (1, 8, 6), (2, 3, 4)],
+            DataTypes.ROW(
+                [DataTypes.FIELD("a", DataTypes.TINYINT()),
+                 DataTypes.FIELD("b", DataTypes.SMALLINT()),
+                 DataTypes.FIELD("c", DataTypes.INT())]))
+
+        table_sink = source_sink_utils.TestAppendSink(
+            ['a', 'b'],
+            [DataTypes.TINYINT(), DataTypes.FLOAT()])
+        self.t_env.register_table_sink("Results", table_sink)
+        t.group_by("a") \
+            .select(t.a, mean_udaf(t.b)) \
+            .execute_insert("Results") \
+            .wait()
+        actual = source_sink_utils.results()
+        self.assert_equals(actual, ["1,5.0", "2,2.0", "3,2.0"])
+
+
+@udaf(result_type=DataTypes.FLOAT(), udaf_type="pandas")

Review comment:
       Does it make sense to rename udaf_type to function_type? Then udaf_type and 
udf_type could both be unified as "function_type".

##########
File path: flink-python/pyflink/table/udf.py
##########
@@ -147,6 +234,41 @@ def eval(self, *args):
         return self.func(*args)
 
 
+class DelegatingPandasAggregateFunction(AggregateFunction):
+    """
+    Helper pandas aggregate function implementation for lambda expression and 
python function.
+    It's for internal use only.
+    """
+
+    def __init__(self, func):
+        self.func = func
+
+    def get_value(self, accumulator):
+        return accumulator[0]
+
+    def create_accumulator(self):
+        return []
+
+    def accumulate(self, accumulator, *args):
+        accumulator.append(self.func(*args))
+
+
+class WrapperPandasAggregateFunction(object):

Review comment:
       Could you rename it to `PandasAggregateFunctionWrapper` and add a description 
of this class?

##########
File path: flink-python/pyflink/table/udf.py
##########
@@ -313,6 +435,70 @@ def _create_judtf(self):
         return j_table_function
 
 
+class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):

Review comment:
       Currently we have UserDefinedAggregateFunctionWrapper, 
UserDefinedTableFunctionWrapper and UserDefinedScalarFunctionWrapper, and it seems to 
me that most of their implementation is duplicated. Could we refactor them a bit?

##########
File path: flink-python/pyflink/table/udf.py
##########
@@ -121,6 +122,92 @@ def eval(self, *args):
         pass
 
 
+class AggregateFunction(UserDefinedFunction):
+    """
+

Review comment:
       remove the empty line

##########
File path: flink-python/pyflink/table/udf.py
##########
@@ -121,6 +122,92 @@ def eval(self, *args):
         pass
 
 
+class AggregateFunction(UserDefinedFunction):
+    """
+
+    Base interface for user-defined aggregate function. A user-defined 
aggregate function maps
+    scalar values of multiple rows to a new scalar value.
+
+    .. versionadded:: 1.12.0
+

Review comment:
       remove the empty line

##########
File path: flink-python/pyflink/table/udf.py
##########
@@ -421,3 +612,45 @@ def udtf(f=None, input_types=None, result_types=None, 
deterministic=None, name=N
                                  deterministic=deterministic, name=name)
     else:
         return _create_udtf(f, input_types, result_types, deterministic, name)
+
+
+def udaf(f=None, input_types=None, result_type=None, deterministic=None, 
name=None,
+         udaf_type="pandas"):
+    """
+    Helper method for creating a user-defined aggregate function.
+
+    Example:
+        ::
+
+            >>> # The input_types is optional.
+            >>> @udaf(result_type=DataTypes.FLOAT(), udaf_type="pandas")
+            ... def mean_udaf(v):
+            ...     return v.mean()
+
+    :param f: user-defined aggregate function.
+    :type f: function or UserDefinedFunction or type
+    :param input_types: optional, the input data types.
+    :type input_types: list[DataType] or DataType
+    :param result_type: the result data type.
+    :type result_type: DataType
+    :param deterministic: the determinism of the function's results. True if 
and only if a call to
+                          this function is guaranteed to always return the 
same result given the
+                          same parameters. (default True)
+    :type deterministic: bool
+    :param name: the function name.
+    :type name: str
+    :param udaf_type: the type of the python function, available value: 
general, pandas,
+                     (default: pandas)
+    :type udaf_type: str
+    :return: UserDefinedAggregateFunctionWrapper or function.
+    :rtype: UserDefinedAggregateFunctionWrapper or function
+

Review comment:
       only one empty line is enough

##########
File path: flink-python/pyflink/table/udf.py
##########
@@ -313,6 +435,70 @@ def _create_judtf(self):
         return j_table_function
 
 
+class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):
+    """
+    Wrapper for Python user-defined aggregate function.
+    """
+    def __init__(self, func, input_types, result_type, udaf_type, 
deterministic, name):
+        super(UserDefinedAggregateFunctionWrapper, self).__init__(
+            func, input_types, deterministic, name)
+
+        if not isinstance(result_type, DataType):
+            raise TypeError(
+                "Invalid returnType: returnType should be DataType but is 
{}".format(result_type))
+        self._result_type = result_type
+        self._udaf_type = udaf_type
+        self._judaf_placeholder = None
+
+    def java_user_defined_function(self):
+        if self._judaf_placeholder is None:
+            self._judaf_placeholder = self._create_judaf()
+        return self._judaf_placeholder
+
+    def _create_judaf(self):
+        gateway = get_gateway()
+
+        def get_python_function_kind(udf_type):
+            JPythonFunctionKind = 
gateway.jvm.org.apache.flink.table.functions.python. \
+                PythonFunctionKind
+            if udf_type == "general":
+                return JPythonFunctionKind.GENERAL
+            elif udf_type == "pandas":
+                return JPythonFunctionKind.PANDAS
+            else:
+                raise TypeError("Unsupported udf_type: %s." % udf_type)
+
+        func = self._func
+        if not isinstance(self._func, UserDefinedFunction):
+            func = DelegatingPandasAggregateFunction(self._func)
+
+        if self._udaf_type == "pandas":
+            func = WrapperPandasAggregateFunction(func)

Review comment:
       We can also create the wrapper class during execution.

##########
File path: 
flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonAggregateFunction.java
##########
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.python;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.table.types.inference.TypeStrategies;
+import org.apache.flink.table.types.utils.TypeConversions;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * The wrapper of user defined python aggregate function.
+ */
+@Internal
+public class PythonAggregateFunction extends AggregateFunction implements 
PythonFunction {
+
+       private static final long serialVersionUID = 1L;
+
+       private final String name;
+       private final byte[] serializedAggregateFunction;
+       private final TypeInformation[] inputTypes;
+       private final TypeInformation resultType;
+       private final PythonFunctionKind pythonFunctionKind;
+       private final boolean deterministic;
+       private final PythonEnv pythonEnv;
+
+       public PythonAggregateFunction(
+               String name,
+               byte[] serializedAggregateFunction,
+               TypeInformation[] inputTypes,
+               TypeInformation resultType,
+               PythonFunctionKind pythonFunctionKind,
+               boolean deterministic,
+               PythonEnv pythonEnv) {
+               this.name = name;
+               this.serializedAggregateFunction = serializedAggregateFunction;
+               this.inputTypes = inputTypes;
+               this.resultType = resultType;
+               this.pythonFunctionKind = pythonFunctionKind;
+               this.deterministic = deterministic;
+               this.pythonEnv = pythonEnv;
+       }
+
+       public void accumulate(Object accumulator, Object... args) {
+               throw new UnsupportedOperationException(
+                       "This method is a placeholder and should not be 
called.");
+       }
+
+       @Override
+       public Object getValue(Object accumulator) {
+               return null;
+       }
+
+       @Override
+       public Object createAccumulator() {
+               return null;
+       }
+
+       @Override
+       public byte[] getSerializedPythonFunction() {
+               return serializedAggregateFunction;
+       }
+
+       @Override
+       public PythonEnv getPythonEnv() {
+               return pythonEnv;
+       }
+
+       @Override
+       public PythonFunctionKind getPythonFunctionKind() {
+               return pythonFunctionKind;
+       }
+
+       @Override
+       public boolean isDeterministic() {
+               return deterministic;
+       }
+
+       @Override
+       public TypeInformation getResultType() {
+               return resultType;
+       }
+
+       @Override
+       public TypeInformation getAccumulatorType() {
+               return resultType;

Review comment:
       Why does getAccumulatorType return the resultType?

##########
File path: 
flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala
##########
@@ -168,6 +177,78 @@ class BatchExecPythonGroupAggregate(
 
   override protected def translateToPlanInternal(
       planner: BatchPlanner): Transformation[RowData] = {
-    throw new TableException("The implementation will be in FLINK-19186.")
+    val input = getInputNodes.get(0).translateToPlan(planner)
+      .asInstanceOf[Transformation[RowData]]
+    val outputType = FlinkTypeFactory.toLogicalRowType(getRowType)
+    val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType)
+
+    val ret = createPythonOneInputTransformation(
+      input,
+      inputType,
+      outputType,
+      grouping,
+      getConfig(planner.getExecEnv, planner.getTableConfig))
+    if 
(isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) {
+      ExecNode.setManagedMemoryWeight(
+        ret, 
getPythonWorkerMemory(planner.getTableConfig.getConfiguration).getBytes)
+    }
+    ret
+  }
+
+  private[this] def createPythonOneInputTransformation(
+      inputTransform: Transformation[RowData],
+      inputRowType: RowType,
+      outputRowType: RowType,
+      groupingSet: Array[Int],

Review comment:
       groupingSet seems to be never used. Is it needed?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to