lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r761945296



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMultiClass.java
##########
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared multi-class param.
+ *
+ * <p>Supported options:
+ * <li>auto: selects the classification type based on the number of classes: If numClasses is one or

Review comment:
       nits: since we don't have any API or variable named `numClasses`, it may be slightly better to say `number of classes` and explain how it is derived.
   
   How about using a comment like this:
   ```
   auto: selects the classification type based on the number of classes: If the number of unique label values from the input data is one or two, set to "binomial". Otherwise, set to "multinomial".
   
   ```
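
To make the suggested wording concrete, here is a small plain-Java sketch of how the `auto` option could be resolved from the training labels. `MultiClassResolver` and `resolve` are hypothetical names used only for illustration; this is not code from the PR, and the PR's actual resolution logic may differ.

```java
import java.util.Arrays;

/** Hypothetical helper illustrating the documented "auto" rule; not part of Flink ML. */
public final class MultiClassResolver {

    private MultiClassResolver() {}

    /**
     * Resolves the classification type. For "auto", counts the unique label values:
     * one or two unique values map to "binomial", more than two map to "multinomial".
     */
    public static String resolve(String multiClass, double[] labels) {
        if (!"auto".equals(multiClass)) {
            // "binomial" or "multinomial" was chosen explicitly; keep it as-is.
            return multiClass;
        }
        long numUniqueLabels = Arrays.stream(labels).distinct().count();
        return numUniqueLabels <= 2 ? "binomial" : "multinomial";
    }
}
```

Counting distinct values needs a full pass over the labels, which in a streaming job would mean an extra aggregation before training; that cost is one reason to document exactly how the number of classes is derived.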

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##########
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData implements Serializable {

Review comment:
       Could this `implements Serializable` be removed?
   
   Passing model data as `DataStream<DenseVector>` seems more efficient and straightforward than passing it as `DataStream<LogisticRegressionModelData>`. The meaning of the stream can be conveyed by the variable name.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticGradient.java
##########
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * Utility class to compute gradient and loss for logistic loss function.
+ *
+ * <p>See http://mlwiki.org/index.php/Logistic_Regression.
+ */
+public class LogisticGradient implements Serializable {
+
+    /** L2 regularization term. */
+    private final double l2;
+
+    public LogisticGradient(double l2) {
+        this.l2 = l2;
+    }
+
+    /**
+     * Computes weight sum and loss sum on a set of samples.
+     *
+     * @param batchData A sample set of train data.
+     * @param coefficient The model parameters.
+     * @return Weight sum and loss sum of the input data.
+     */
+    public final Tuple2<Double, Double> computeLoss(
+            List<LabeledPointWithWeight> batchData, DenseVector coefficient) {

Review comment:
       nits: how about renaming the first parameter to `dataPoints` so that it is consistent with the `dataPoint` used in this method and in other methods of this class?
   
   Same for `computeGradient()`.
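
For reference, here is a plain-Java sketch of the weighted logistic loss that `computeLoss()` sums, written with the suggested `dataPoints`/`dataPoint` naming. It assumes labels in {0, 1}, uses a hypothetical `DataPoint` class in place of `LabeledPointWithWeight`, returns `{weightSum, lossSum}` as a `double[]` instead of a `Tuple2`, and leaves out the L2 term. It is not the PR's implementation.

```java
import java.util.List;

/** Hypothetical sketch of a weighted logistic loss; not the Flink ML implementation. */
public final class LogisticLossSketch {

    /** Hypothetical stand-in for LabeledPointWithWeight, with plain double[] features. */
    public static final class DataPoint {
        final double[] features;
        final double label; // assumed to be 0.0 or 1.0
        final double weight;

        public DataPoint(double[] features, double label, double weight) {
            this.features = features;
            this.label = label;
            this.weight = weight;
        }
    }

    private LogisticLossSketch() {}

    /** Returns {weightSum, lossSum} over dataPoints. The L2 regularization term is not included. */
    public static double[] computeLoss(List<DataPoint> dataPoints, double[] coefficient) {
        double weightSum = 0.0;
        double lossSum = 0.0;
        for (DataPoint dataPoint : dataPoints) {
            double dot = 0.0;
            for (int i = 0; i < coefficient.length; i++) {
                dot += coefficient[i] * dataPoint.features[i];
            }
            // Map label {0, 1} to {-1, +1}; the loss is log(1 + exp(-y * dot)).
            double margin = (2 * dataPoint.label - 1) * dot;
            lossSum += dataPoint.weight * Math.log1p(Math.exp(-margin));
            weightSum += dataPoint.weight;
        }
        return new double[] {weightSum, lossSum};
    }
}
```

With an all-zero coefficient every point contributes `weight * ln(2)`, which is a convenient sanity check for the real implementation as well.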

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##########
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData implements Serializable {
+
+    public final DenseVector coefficient;
+
+    public LogisticRegressionModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+    }
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * @param modelData The table model data.
+     * @return The data stream model data.
+     */
+    public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+        return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
+    }
+
+    /**
+     * Gets the data encoder for {@link LogisticRegressionModelData}.
+     *
+     * @return The data encoder for {@link LogisticRegressionModelData}.
+     */
+    public static ModelDataEncoder getModelDataEncoder() {
+        return new ModelDataEncoder();
+    }
+
+    /**
+     * Gets the data decoder for {@link LogisticRegressionModelData}.
+     *
+     * @return The data decoder for {@link LogisticRegressionModelData}.
+     */
+    public static ModelDataDecoder getModelDataDecoder() {
+        return new ModelDataDecoder();
+    }
+
+    /** Data encoder for {@link LogisticRegressionModel}. */
+    private static class ModelDataEncoder implements Encoder<LogisticRegressionModelData> {
+
+        @Override
+        public void encode(LogisticRegressionModelData modelData, OutputStream stream) {

Review comment:
       Could you try to reuse `DenseVectorSerializer` by doing something like this:
   
   ```
   DenseVectorSerializer serializer = new DenseVectorSerializer();
   serializer.serialize(modelData.coefficient, new DataOutputViewStreamWrapper(stream));
   ```
   
   We can create `DenseVectorSerializer` only once in this class. There is probably a way to do something similar in `ModelDataDecoder`.
   
   If it works, could you help update `KMeansModelData` as well?
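
A plain-Java analogue of the suggestion: one shared, stateless serializer instance used by both the encoder and the decoder. `VectorSerializer` is a hypothetical stand-in for `DenseVectorSerializer`, and `java.io.DataOutputStream`/`DataInputStream` stand in for Flink's `DataOutputViewStreamWrapper`/`DataInputViewStreamWrapper`. The length-prefixed layout is an assumption, not necessarily what `DenseVectorSerializer` writes.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

/** Hypothetical sketch: a model-data codec that reuses one shared vector serializer. */
public final class DenseVectorCodecSketch {

    /** Hypothetical stand-in for DenseVectorSerializer; it is stateless, so one instance suffices. */
    static final class VectorSerializer {
        void serialize(double[] values, DataOutputStream out) throws IOException {
            // Assumed layout: element count followed by the raw doubles.
            out.writeInt(values.length);
            for (double value : values) {
                out.writeDouble(value);
            }
        }

        double[] deserialize(DataInputStream in) throws IOException {
            int size = in.readInt();
            double[] values = new double[size];
            for (int i = 0; i < size; i++) {
                values[i] = in.readDouble();
            }
            return values;
        }
    }

    // Created once and shared by encode() and decode().
    private static final VectorSerializer SERIALIZER = new VectorSerializer();

    private DenseVectorCodecSketch() {}

    /** Mirrors the shape of Encoder#encode: writes one coefficient vector to the stream. */
    public static void encode(double[] coefficient, OutputStream stream) throws IOException {
        DataOutputStream out = new DataOutputStream(stream);
        SERIALIZER.serialize(coefficient, out);
        // Flush, but do not close: the caller owns the underlying stream.
        out.flush();
    }

    /** Reads back one coefficient vector written by encode(). */
    public static double[] decode(InputStream stream) throws IOException {
        return SERIALIZER.deserialize(new DataInputStream(stream));
    }
}
```

Because the serializer keeps no per-record state, a single static instance can be shared safely, which is the point of creating it only once per class.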

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##########
@@ -0,0 +1,460 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * An Estimator which implements the logistic regression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+        implements Estimator<LogisticRegression, LogisticRegressionModel>,
+                LogisticRegressionParams<LogisticRegression> {
+
+    private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public LogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    @SuppressWarnings("rawTypes")
+    public LogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String classificationType = getMultiClass();
+        Preconditions.checkArgument(
+                "auto".equals(classificationType) || "binomial".equals(classificationType),
+                "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LabeledPointWithWeight> trainData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                dataPoint -> {
+                                    Double weight =
+                                            getWeightCol() == null
+                                                    ? new Double(1.0)
+                                                    : (Double) dataPoint.getField(getWeightCol());
+                                    Double label = (Double) dataPoint.getField(getLabelCol());
+                                    assert label != null;

Review comment:
       With the current Flink code style, I believe we typically don't check for null values. It is simpler to let the code below throw `NullPointerException`.
   
   BTW, we typically don't use `assert` in production code, and assertions are not enabled at runtime by default.
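
A minimal plain-Java illustration of both points: unboxing a `null` `Double` already throws `NullPointerException`, so an explicit check adds little, while an `assert` statement only runs when the JVM is started with `-ea`. `NullLabelSketch` is a hypothetical class, not code from the PR.

```java
/** Hypothetical illustration of the review comment; not Flink ML code. */
public final class NullLabelSketch {

    private NullLabelSketch() {}

    /**
     * Reads a label the way the reviewed map function does, but without assert:
     * unboxing a null Double throws NullPointerException on its own.
     */
    public static double readLabel(Object field) {
        Double label = (Double) field;
        return label; // auto-unboxing throws NullPointerException if label is null
    }

    /** Returns true only if the JVM was started with assertions enabled (-ea). */
    public static boolean assertionsEnabled() {
        boolean enabled = false;
        // This assignment is evaluated only when assertions are enabled.
        assert enabled = true;
        return enabled;
    }
}
```

Running without `-ea`, `assertionsEnabled()` returns `false`, so an `assert label != null` in the map function would silently do nothing in a typical deployment.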

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##########
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+        implements Model<LogisticRegressionModel>,
+                LogisticRegressionModelParams<LogisticRegressionModel> {
+
+    private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    private Table modelData;
+
+    public LogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                LogisticRegressionModelData.getModelDataStream(modelData),
+                path,
+                LogisticRegressionModelData.getModelDataEncoder());
+    }
+
+    public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, LogisticRegressionModelData.getModelDataDecoder());
+        return model.setModelData(modelData);
+    }
+
+    @Override
+    public LogisticRegressionModel setModelData(Table... inputs) {
+        modelData = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelData};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+        final String broadcastModelKey = "broadcastModelKey";
+        DataStream<LogisticRegressionModelData> modelData =
+                LogisticRegressionModelData.getModelDataStream(this.modelData);
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                BasicTypeInfo.DOUBLE_TYPE_INFO,
+                                TypeInformation.of(DenseVector.class)),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol()));
+        DataStream<Row> predictionResult =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(inputStream),
+                        Collections.singletonMap(broadcastModelKey, modelData),
+                        inputList -> {
+                            DataStream inputData = inputList.get(0);
+                            return inputData.transform(
+                                    "doPrediction",
+                                    outputTypeInfo,
+                                    new PredictOp(broadcastModelKey, getFeaturesCol()));
+                        });
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictOp extends AbstractUdfStreamOperator<Row, AbstractRichFunction>

Review comment:
       nits: could we rename this to `PredictOperator` for consistency with `MapPartitionOperator`? It would be nice to use consistent names for private classes that extend `StreamOperator`.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##########
@@ -0,0 +1,460 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * An Estimator which implements the logistic regression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+        implements Estimator<LogisticRegression, LogisticRegressionModel>,
+                LogisticRegressionParams<LogisticRegression> {
+
+    private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public LogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    @SuppressWarnings("rawTypes")
+    public LogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String classificationType = getMultiClass();
+        Preconditions.checkArgument(
+                "auto".equals(classificationType) || "binomial".equals(classificationType),
+                "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LabeledPointWithWeight> trainData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                dataPoint -> {
+                                    Double weight =
+                                            getWeightCol() == null
+                                                    ? new Double(1.0)
+                                                    : (Double) dataPoint.getField(getWeightCol());
+                                    Double label = (Double) dataPoint.getField(getLabelCol());
+                                    assert label != null;
+                                    assert weight != null;
+                                    boolean isBinomial =
+                                            Double.compare(0., label) == 0
+                                                    || Double.compare(1., label) == 0;
+                                    if (!isBinomial) {
+                                        throw new RuntimeException(
+                                                "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
+                                    }
+                                    DenseVector features =
+                                            (DenseVector) dataPoint.getField(getFeaturesCol());
+                                    return new LabeledPointWithWeight(features, label, weight);
+                                });
+        DataStream<double[]> initModelData =
+                trainData.transform(
+                        "genInitModelData",
+                        PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                        new GenInitModelData());
+
+        DataStream<LogisticRegressionModelData> modelData = train(trainData, initModelData);
+        LogisticRegressionModel model =
+                new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /**
+     * Generates initialized model data. Note that the parallelism of model data is same as the
+     * input train data, not one.
+     */
+    private static class GenInitModelData extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<LabeledPointWithWeight, double[]>, BoundedOneInput {
+
+        private int dim = 0;
+
+        private ListState<Integer> dimState;
+
+        @Override
+        public void endInput() {
+            output.collect(new StreamRecord<>(new double[dim]));
+        }
+
+        @Override
+        public void processElement(StreamRecord<LabeledPointWithWeight> streamRecord) {
+            if (dim == 0) {
+                dim = streamRecord.getValue().features.size();
+            } else {
+                if (dim != streamRecord.getValue().features.size()) {
+                    throw new RuntimeException(
+                            "The training data should all have same dimensions.");
+                }
+            }
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            dimState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "dimState", BasicTypeInfo.INT_TYPE_INFO));
+            dim = OperatorStateUtils.getUniqueElement(dimState, "dimState").orElse(0);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            dimState.clear();
+            dimState.add(dim);
+        }
+    }
+
+    /**
+     * Does machine learning training on the input data with the initialized model data.
+     *
+     * @param trainData The training data.
+     * @param initModelData The initialized model.
+     * @return The trained model data.
+     */
+    private DataStream<LogisticRegressionModelData> train(
+            DataStream<LabeledPointWithWeight> trainData, DataStream<double[]> initModelData) {
+        LogisticGradient logisticGradient = new LogisticGradient(getReg());
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(initModelData),
+                        ReplayableDataStreamList.notReplay(trainData),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(
+                                logisticGradient,
+                                getGlobalBatchSize(),
+                                getLearningRate(),
+                                getMaxIter(),
+                                getTol()));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+
+        private final LogisticGradient logisticGradient;
+
+        private final int globalBatchSize;
+
+        private final double learningRate;
+
+        private final int maxIter;
+
+        private final double tol;
+
+        public TrainIterationBody(
+                LogisticGradient logisticGradient,
+                int globalBatchSize,
+                double learningRate,
+                int maxIter,
+                double tol) {
+            this.logisticGradient = logisticGradient;
+            this.globalBatchSize = globalBatchSize;
+            this.learningRate = learningRate;
+            this.maxIter = maxIter;
+            this.tol = tol;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the computed gradient, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<LogisticRegressionModelData> modelDataOutputTag =
+                    new OutputTag<LogisticRegressionModelData>("MODEL_OUTPUT") {};
+            SingleOutputStreamOperator<double[]> gradientAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(
+                                            logisticGradient,
+                                            globalBatchSize,
+                                            learningRate,
+                                            modelDataOutputTag));
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(gradientAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedGradientAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedGradientAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(maxIter, tol));
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(gradientAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * using gradients iteratively. The first input is the training data, and the second input is
+     * the initialized model data or feedback of gradient, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+
+        private final int globalBatchSize;
+
+        private int localBatchSize;
+
+        private final double learningRate;
+
+        private final LogisticGradient logisticGradient;
+
+        private DenseVector gradient;
+
+        private DenseVector coefficient;
+
+        private int coefficientDim;
+
+        private ListState<DenseVector> coefficientState;
+
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        private Random random = new Random(2021);
+
+        private List<LabeledPointWithWeight> miniBatchData;
+
+        /** The buffer for feedback record: {coefficient, weightSum, loss}. */

Review comment:
       It looks like the first part of `feedbackBuffer` is `gradient` instead of `coefficient`?
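For reference, a minimal, Flink-free sketch of the record layout that the comment in `TrainIterationBody#process` describes ({gradient, weightSum, lossSum}). `FeedbackBufferSketch` and its methods are hypothetical and not part of this PR; `sum` only mimics the element-wise addition that `DataStreamUtils.allReduceSum` is assumed to perform across subtasks. Because all three parts are additive, the reduced record still yields the weighted average loss as `value[value.length - 1] / value[value.length - 2]`, which is what the termination-criteria `map` computes.

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not part of the PR) showing the layout of one feedback record:
 * {gradient[0], ..., gradient[d - 1], weightSum, lossSum}.
 */
final class FeedbackBufferSketch {

    private FeedbackBufferSketch() {}

    /** Packs the gradient, the weight sum and the loss sum into one array. */
    static double[] pack(double[] gradient, double weightSum, double lossSum) {
        double[] buffer = Arrays.copyOf(gradient, gradient.length + 2);
        buffer[gradient.length] = weightSum;
        buffer[gradient.length + 1] = lossSum;
        return buffer;
    }

    /** Element-wise sum, standing in for the assumed behavior of allReduceSum. */
    static double[] sum(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Feedback records must have the same length.");
        }
        double[] result = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            result[i] = a[i] + b[i];
        }
        return result;
    }

    /** Returns the gradient part, i.e. everything except the last two slots. */
    static double[] gradient(double[] buffer) {
        return Arrays.copyOf(buffer, buffer.length - 2);
    }

    /** Weighted average loss: lossSum / weightSum, as in the termination-criteria map. */
    static double averageLoss(double[] buffer) {
        return buffer[buffer.length - 1] / buffer[buffer.length - 2];
    }
}
```

If the Javadoc on `feedbackBuffer` is updated to this layout, it would match both the comment in `process()` and the index arithmetic in the termination-criteria `map`.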

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##########
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+        implements Model<LogisticRegressionModel>,
+                LogisticRegressionModelParams<LogisticRegressionModel> {
+
+    private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    private Table modelData;
+
+    public LogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                LogisticRegressionModelData.getModelDataStream(modelData),
+                path,
+                LogisticRegressionModelData.getModelDataEncoder());
+    }
+
+    public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, LogisticRegressionModelData.getModelDataDecoder());
+        return model.setModelData(modelData);
+    }
+
+    @Override
+    public LogisticRegressionModel setModelData(Table... inputs) {
+        modelData = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelData};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+        final String broadcastModelKey = "broadcastModelKey";
+        DataStream<LogisticRegressionModelData> modelData =
+                LogisticRegressionModelData.getModelDataStream(this.modelData);
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                BasicTypeInfo.DOUBLE_TYPE_INFO,
+                                TypeInformation.of(DenseVector.class)),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol()));
+        DataStream<Row> predictionResult =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(inputStream),
+                        Collections.singletonMap(broadcastModelKey, modelData),
+                        inputList -> {
+                            DataStream inputData = inputList.get(0);

Review comment:
       It would be nice to let `BroadcastUtils.withBroadcastStream` take `DataStreamList` instead of `List<DataStream<?>>`. This would make our code style more consistent (i.e. use `DataStreamList` wherever it works). Then `DataStream inputData` could declare its element type here instead of being a raw type.
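A minimal, Flink-free sketch of why a `DataStreamList`-style container helps here, assuming it exposes a generic `<T> get(int)` (which `TrainIterationBody` already relies on when it assigns `variableStreams.get(0)` directly to `DataStream<double[]>`). `TypedListOfLists` is a hypothetical stand-in that holds plain `List`s instead of `DataStream`s: with `List<List<?>>`, each call site needs a raw type or an unchecked cast, whereas the generic getter confines the cast to one place and lets the caller declare the element type.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Hypothetical, Flink-free stand-in for a DataStreamList-style container: it holds lists of
 * different element types, and get() lets the caller choose the element type.
 */
final class TypedListOfLists {

    private final List<List<?>> lists;

    private TypedListOfLists(List<List<?>> lists) {
        this.lists = Collections.unmodifiableList(new ArrayList<>(lists));
    }

    static TypedListOfLists of(List<?>... lists) {
        return new TypedListOfLists(Arrays.asList(lists));
    }

    /** The unchecked cast lives here once instead of at every call site. */
    @SuppressWarnings("unchecked")
    <T> List<T> get(int index) {
        return (List<T>) lists.get(index);
    }

    int size() {
        return lists.size();
    }
}
```

With the same shape, the lambda passed to `withBroadcastStream` could presumably write `DataStream<Row> inputData = inputList.get(0);`, since `inputStream` is already a `DataStream<Row>`.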

##########
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##########
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-        implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * <p>When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration terminates if epochWatermark is greater than or equal to `maxIter` or loss

Review comment:
       Hmm... there might be some minor issues with this comment:
   
   - The operator stops emitting values if `epochWatermark + 1 >= maxIter`, but the comment says the iteration terminates if `epochWatermark >= maxIter`.
   - The comment mentions loss without explaining that the loss is the input of this operator.
   - `epochWatermark` is more of an implementation detail, so it is probably not the easiest way to explain what this operator does.
   
   How about we use comments like this:
   
   ```
    * <p>When the output of this FlatMapFunction is used as the termination criteria of an iteration
    * body, the iteration will be executed for at most the given `maxIter` iterations. And the iteration will terminate
    * once any input value is smaller than or equal to the given `tol`.
   ```
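To double-check that the suggested wording matches the condition in the first bullet, here is a minimal, Flink-free model of the rule. `TerminationSketch` is hypothetical and is not the actual `TerminateOnMaxIterOrTol` code. It assumes that rounds are numbered from 0, that a value is emitted in round `epoch` only if `epoch + 1 < maxIter` and `loss > tol`, and that the iteration stops after the first round whose termination-criteria output is empty.

```java
/**
 * Hypothetical, Flink-free model of the termination rule (not the actual
 * TerminateOnMaxIterOrTol code). Assumptions: rounds are numbered from 0, the criteria emit a
 * value in round `epoch` only if epoch + 1 < maxIter and loss > tol, and the iteration stops
 * after the first round in which nothing is emitted.
 */
final class TerminationSketch {

    private TerminationSketch() {}

    /** Whether the termination-criteria function emits a value in the given round. */
    static boolean emits(int epoch, double loss, int maxIter, double tol) {
        return epoch + 1 < maxIter && loss > tol;
    }

    /** Number of rounds that run, given the loss observed in each round. */
    static int roundsExecuted(int maxIter, double tol, double[] lossPerRound) {
        if (maxIter < 1 || lossPerRound.length < maxIter) {
            throw new IllegalArgumentException("Need maxIter >= 1 and a loss for every round.");
        }
        int epoch = 0;
        // Move to the next round while the current round still emitted something.
        while (emits(epoch, lossPerRound[epoch], maxIter, tol)) {
            epoch++;
        }
        // Round `epoch` emitted nothing, but it was still executed, hence the + 1.
        return epoch + 1;
    }
}
```

Under these assumptions, `roundsExecuted(5, 0.0, new double[] {1, 1, 1, 1, 1})` returns 5, i.e. at most `maxIter` rounds, and `roundsExecuted(5, 0.1, new double[] {1.0, 0.5, 0.05, 0.01, 0.01})` returns 3: the round that observed `0.05` still runs, but no further round starts.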

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##########
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData implements Serializable {
+
+    public final DenseVector coefficient;
+
+    public LogisticRegressionModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+    }
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * @param modelData The table model data.
+     * @return The data stream model data.
+     */
+    public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+        return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
+    }
+
+    /**
+     * Gets the data encoder for {@link LogisticRegressionModelData}.
+     *
+     * @return The data encoder for {@link LogisticRegressionModelData}.
+     */
+    public static ModelDataEncoder getModelDataEncoder() {

Review comment:
       This method seems a bit over-designed. Would it be simpler for the caller to just call `new LogisticRegressionModelData.ModelDataEncoder()`?
   
   I am OK with keeping it if there is precedent for this in Flink (i.e. a static getXXX() method whose body is a single line of code).




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
