walterddr commented on a change in pull request #11344: [FLINK-16250][python][ml] Add interfaces for PipelineStage and Pipeline
URL: https://github.com/apache/flink/pull/11344#discussion_r391169569
 
 

 ##########
 File path: flink-python/pyflink/ml/api/base.py
 ##########
 @@ -0,0 +1,275 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import re
+
+from abc import ABCMeta, abstractmethod
+
+from pyflink.table.table_environment import TableEnvironment
+from pyflink.table.table import Table
+from pyflink.ml.api.param import WithParams, Params
+from py4j.java_gateway import get_field
+
+
+class PipelineStage(WithParams):
+    """
+    Base class for a stage in a pipeline. The interface is only a concept, and does not have any
+    actual functionality. Its subclasses must be either Estimator or Transformer. No other classes
+    should inherit this interface directly.
+
+    Each pipeline stage is with parameters, and requires a public empty constructor for
+    restoration in Pipeline.
+    """
+
+    def __init__(self, params=None):
+        if params is None:
+            self._params = Params()
+        else:
+            self._params = params
+
+    def get_params(self) -> Params:
+        return self._params
+
+    def _convert_params_to_java(self, j_pipeline_stage):
+        for param in self._params._param_map:
+            java_param = self._make_java_param(j_pipeline_stage, param)
+            java_value = self._make_java_value(self._params._param_map[param])
+            j_pipeline_stage.set(java_param, java_value)
+
+    @staticmethod
+    def _make_java_param(j_pipeline_stage, param):
+        # camel case to snake case
+        name = re.sub(r'(?<!^)(?=[A-Z])', '_', param.name).upper()
+        return get_field(j_pipeline_stage, name)
+
+    @staticmethod
+    def _make_java_value(obj):
+        """ Convert Python object into Java """
+        if isinstance(obj, list):
+            obj = [PipelineStage._make_java_value(x) for x in obj]
+        return obj
+
+    def to_json(self) -> str:
+        return self.get_params().to_json()
+
+    def load_json(self, json: str) -> None:
+        self.get_params().load_json(json)
+
+
+class Transformer(PipelineStage):
+    """
+    A transformer is a PipelineStage that transforms an input Table to a result Table.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def transform(self, table_env: TableEnvironment, table: Table) -> Table:
+        """
+        Applies the transformer on the input table, and returns the result table.
+
+        :param table_env: the table environment to which the input table is bound.
+        :param table: the table to be transformed
+        :returns: the transformed table
+        """
+        raise NotImplementedError()
+
+
+class JavaTransformer(Transformer):
+    """
+    Base class for Transformer that wrap Java implementations. Subclasses should
+    ensure they have the transformer Java object available as j_obj.
+    """
+
+    def __init__(self, j_obj):
+        super().__init__()
+        self._j_obj = j_obj
+
+    def transform(self, table_env: TableEnvironment, table: Table) -> Table:
+        """
+        Applies the transformer on the input table, and returns the result table.
+
+        :param table_env: the table environment to which the input table is bound.
+        :param table: the table to be transformed
+        :returns: the transformed table
+        """
+        self._convert_params_to_java(self._j_obj)
+        return Table(self._j_obj.transform(table_env._j_tenv, table._j_table))
+
+
+class Model(Transformer):
+    """
+    Abstract class for models that are fitted by estimators.
+
+    A model is an ordinary Transformer except how it is created. While ordinary transformers
+    are defined by specifying the parameters directly, a model is usually generated by an Estimator
+    when Estimator.fit(table_env, table) is invoked.
+    """
+
+    __metaclass__ = ABCMeta
+
+
+class JavaModel(JavaTransformer, Model):
+    """
+    Base class for JavaTransformer that wrap Java implementations.
+    Subclasses should ensure they have the model Java object available as j_obj.
+    """
+
+
+class Estimator(PipelineStage):
+    """
+    Estimators are PipelineStages responsible for training and generating machine learning models.
+
+    The implementations are expected to take an input table as training samples and generate a
+    Model which fits these samples.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def fit(self, table_env: TableEnvironment, table: Table) -> Model:
+        """
+        Train and produce a Model which fits the records in the given Table.
+
+        :param table_env: the table environment to which the input table is bound.
+        :param table: the table with records to train the Model.
+        :returns: a model trained to fit on the given Table.
+        """
+        raise NotImplementedError()
+
+
+class JavaEstimator(Estimator):
+    """
+    Base class for Estimator that wrap Java implementations.
+    Subclasses should ensure they have the estimator Java object available as j_obj.
+    """
+
+    def __init__(self, j_obj):
+        super().__init__()
+        self._j_obj = j_obj
+
+    def fit(self, table_env: TableEnvironment, table: Table) -> JavaModel:
+        """
+        Train and produce a Model which fits the records in the given Table.
+
+        :param table_env: the table environment to which the input table is bound.
+        :param table: the table with records to train the Model.
+        :returns: a model trained to fit on the given Table.
+        """
+        self._convert_params_to_java(self._j_obj)
+        return JavaModel(self._j_obj.fit(table_env._j_tenv, table._j_table))
+
+
+class Pipeline(Estimator, Model):
 
 Review comment:
   It would be better to keep this consistent with the Java API definition.
   Adding Transformer here shouldn't change anything since, as you said, it's
   already included in the Estimator.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to