yunfengzhou-hub commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r760737278



##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+
+    List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("f", "2.0 3.0", 1, 0, 1.47),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("m", "200.1 300.1", 1, 0, 1.5),
+                            Row.of("m", "200.2 300.2", 1, 0, 2.59),
+                            Row.of("m", "200.3 300.3", 1, 0, 2.55),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.53),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.52),
+                            Row.of("m", "200.6 300.6", 1, 0, 2.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.56),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.51),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.52),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.53),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.54),
+                            Row.of("c", "2.8 3.2", 3, 0, 1.6),
+                            Row.of("d", "300. 3.2", 5, 0, 1.5),
+                            Row.of("f", "2.2 3.2", 1, 0, 1.5),
+                            Row.of("e", "2.4 3.2", 2, 0, 1.3),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.4),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+    List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("e", "4.0 4.1", 2, 0, 1.5), Row.of("m", 
"300 42", 1, 0, 2.59)));
+
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        DataStream<Row> dataStream =
+                env.fromCollection(
+                        trainArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+        trainData = tEnv.fromDataStream(dataStream);

Review comment:
       A more concise way to create the table might be to use a schema. You can refer to how this is done in `KmeansTest`, and in `NaiveBayesTest` in the NaiveBayes PR.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+
+    List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("f", "2.0 3.0", 1, 0, 1.47),

Review comment:
       Please correct me if I'm wrong, but I remember we reached an agreement that our stages should receive Vectors as they are, instead of Vectors in string format. If this is the case, I think we should also change this trainArray data to follow this convention.
   
   Besides, we can also remove the changes to the `DenseVector.toString()` and `VectorUtils.parse()` methods, as we no longer need them.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLabelCol.java
##########
@@ -0,0 +1,29 @@
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Param of the name of the label column in the input table.
+ *
+ * @param <T>

Review comment:
       This `@param <T>` tag can be removed.
   
   The Javadoc description can be changed to `/** Interface for the shared labelCol param. */`. The same applies to the other params.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+
+    List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("f", "2.0 3.0", 1, 0, 1.47),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("m", "200.1 300.1", 1, 0, 1.5),
+                            Row.of("m", "200.2 300.2", 1, 0, 2.59),
+                            Row.of("m", "200.3 300.3", 1, 0, 2.55),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.53),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.52),
+                            Row.of("m", "200.6 300.6", 1, 0, 2.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.56),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.51),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.52),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.53),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.54),
+                            Row.of("c", "2.8 3.2", 3, 0, 1.6),
+                            Row.of("d", "300. 3.2", 5, 0, 1.5),
+                            Row.of("f", "2.2 3.2", 1, 0, 1.5),
+                            Row.of("e", "2.4 3.2", 2, 0, 1.3),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.4),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+    List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("e", "4.0 4.1", 2, 0, 1.5), Row.of("m", 
"300 42", 1, 0, 2.59)));
+
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        DataStream<Row> dataStream =
+                env.fromCollection(
+                        trainArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+        trainData = tEnv.fromDataStream(dataStream);
+
+        DataStream<Row> dataStreamStr =
+                env.fromCollection(
+                        testArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+
+        testData = tEnv.fromDataStream(dataStreamStr);
+    }
+
+    /** test knn Estimator. */
+    @Test
+    public void testKnnEstimator() throws Exception {

Review comment:
       Naming it `testKnnEstimator` might be a little too broad, as most unit tests in this file would match that name. We can rename it to `testFitAndPredict`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,152 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.VectorUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * constructor.
+     *
+     * @param params parameters for algorithm.
+     */
+    public Knn(Map<Param<?>, Object> params) {
+        this.params = params;
+    }
+
+    /**
+     * @param inputs a list of tables
+     * @return knn classification model.
+     */
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        ResolvedSchema schema = inputs[0].getResolvedSchema();
+        String[] colNames = schema.getColumnNames().toArray(new String[0]);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+
+        String labelCol = getLabelCol();
+        String vecCol = getFeaturesCol();
+
+        DataStream<Row> trainData =
+                input.map(
+                        (MapFunction<Row, Row>)
+                                value -> {
+                                    Object label = value.getField(labelCol);
+                                    DenseVector vec =
+                                            
VectorUtils.parse(value.getField(vecCol).toString());
+                                    return Row.of(label, vec);
+                                });
+
+        DataType idType = 
schema.getColumnDataTypes().get(findColIndex(colNames, labelCol));
+        DataStream<Row> model = buildModel(trainData, getParamMap(), idType);
+        KnnModel knnModel = new KnnModel(params);
+        knnModel.setModelData(tEnv.fromDataStream(model));
+        return knnModel;
+    }
+
+    /**
+     * build knn model.
+     *
+     * @param dataStream input data.
+     * @param params input parameters.
+     * @return stream format model.
+     */
+    private static DataStream<Row> buildModel(
+            DataStream<Row> dataStream, final Map<Param<?>, Object> params, 
final DataType idType) {
+        FastDistance fastDistance = new FastDistance();
+
+        return dataStream.transform(
+                "build index",
+                new RowTypeInfo(
+                        new TypeInformation[] {
+                            Types.STRING,
+                            
TypeInformation.of(idType.getLogicalType().getDefaultConversion())
+                        },
+                        new String[] {"DATA", "KNN_LABEL_TYPE"}),

Review comment:
       In `KnnModel.load()`, a schema is used when converting the DataStream to a Table; here in `buildModel`, `RowTypeInfo` is used when creating the DataStream, and no schema is passed to `tEnv.fromDataStream()`.
   
   Since the two usages above serve the same purpose, can we create a unified implementation of this conversion and place it inside the ModelData class as a static method? You can refer to the NaiveBayes PR for similar methods.

##########
File path: flink-ml-lib/pom.xml
##########
@@ -65,6 +65,12 @@ under the License.
       <scope>test</scope>
     </dependency>
 
+    <dependency>
+      <groupId>com.github.fommil.netlib</groupId>

Review comment:
       In [this JIRA](https://issues.apache.org/jira/browse/SPARK-35295), Spark replaced `com.github.fommil.netlib` with `dev.ludovic.netlib` for licensing reasons. As both Spark and Flink are under the Apache License, I believe the same licensing concern applies to us. Thus we should also use that package for BLAS operations.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */

Review comment:
       Maybe we can rewrite the comments for such classes to follow existing conventions, like the one in `KmeansTest`.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+
+    List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("f", "2.0 3.0", 1, 0, 1.47),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("m", "200.1 300.1", 1, 0, 1.5),
+                            Row.of("m", "200.2 300.2", 1, 0, 2.59),
+                            Row.of("m", "200.3 300.3", 1, 0, 2.55),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.53),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.52),
+                            Row.of("m", "200.6 300.6", 1, 0, 2.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.56),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.51),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.52),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.53),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.54),
+                            Row.of("c", "2.8 3.2", 3, 0, 1.6),
+                            Row.of("d", "300. 3.2", 5, 0, 1.5),
+                            Row.of("f", "2.2 3.2", 1, 0, 1.5),
+                            Row.of("e", "2.4 3.2", 2, 0, 1.3),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.4),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+    List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("e", "4.0 4.1", 2, 0, 1.5), Row.of("m", 
"300 42", 1, 0, 2.59)));
+
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        DataStream<Row> dataStream =
+                env.fromCollection(
+                        trainArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+        trainData = tEnv.fromDataStream(dataStream);
+
+        DataStream<Row> dataStreamStr =
+                env.fromCollection(
+                        testArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+
+        testData = tEnv.fromDataStream(dataStreamStr);
+    }
+
+    /** test knn Estimator. */
+    @Test
+    public void testKnnEstimator() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")

Review comment:
       We should test setting the value to something different from its default value; otherwise we cannot tell whether the setter actually took effect.

##########
File path: flink-ml-core/pom.xml
##########
@@ -59,11 +59,11 @@ under the License.
       <version>${flink.version}</version>
       <scope>test</scope>
     </dependency>
-
     <dependency>
-      <groupId>org.apache.flink</groupId>
-      <artifactId>flink-shaded-jackson</artifactId>
-      <scope>provided</scope>
+      <groupId>com.google.code.gson</groupId>

Review comment:
       This dependency is not used anywhere in the current PR. We can just remove it.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+
+    List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("f", "2.0 3.0", 1, 0, 1.47),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("m", "200.1 300.1", 1, 0, 1.5),
+                            Row.of("m", "200.2 300.2", 1, 0, 2.59),
+                            Row.of("m", "200.3 300.3", 1, 0, 2.55),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.53),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.52),
+                            Row.of("m", "200.6 300.6", 1, 0, 2.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.56),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.51),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.52),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.53),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.54),
+                            Row.of("c", "2.8 3.2", 3, 0, 1.6),
+                            Row.of("d", "300. 3.2", 5, 0, 1.5),
+                            Row.of("f", "2.2 3.2", 1, 0, 1.5),
+                            Row.of("e", "2.4 3.2", 2, 0, 1.3),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.4),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+    List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("e", "4.0 4.1", 2, 0, 1.5), Row.of("m", 
"300 42", 1, 0, 2.59)));
+
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        DataStream<Row> dataStream =
+                env.fromCollection(
+                        trainArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+        trainData = tEnv.fromDataStream(dataStream);
+
+        DataStream<Row> dataStreamStr =
+                env.fromCollection(
+                        testArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+
+        testData = tEnv.fromDataStream(dataStreamStr);
+    }
+
+    /** test knn Estimator. */
+    @Test
+    public void testKnnEstimator() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+
+        KnnModel knnModel = knn.fit(trainData);
+        Table result = knnModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assert (label.equals(pred));
+        }
+    }
+
+    /** test knn Estimator. */
+    @Test
+    public void testKnnEstimatorWithFeatures() throws Exception {
+        Map<Param<?>, Object> params = new HashMap<>();
+        params.put(HasLabelCol.LABEL_COL, "label");
+        params.put(HasFeaturesCol.FEATURES_COL, "vec");
+        params.put(HasK.K, 4);
+        params.put(HasPredictionCol.PREDICTION_COL, "pred");
+        Knn knn = new Knn(params);
+
+        KnnModel knnModel = knn.fit(trainData);
+        Table result = knnModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assert (label.equals(pred));
+        }
+    }
+
+    /** test knn as a pipeline stage. */
+    @Test
+    public void testKnnPipeline() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+
+        List<Stage<?>> stages = new ArrayList<>();
+        stages.add(knn);
+
+        Pipeline pipe = new Pipeline(stages);
+
+        Table result = pipe.fit(trainData).transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assert (label.equals(pred));
+        }
+    }
+
+    /** test knn model save. */
+    @Test
+    public void testKnnModelSave() throws Exception {
+        String knnPath = Files.createTempDirectory("").toString();
+        String modelPath = Files.createTempDirectory("").toString();
+        Knn knn =
+                new 
Knn().setLabelCol("f0").setFeaturesCol("vec").setK(4).setPredictionCol("pred");
+        knn.save(knnPath);
+        Knn cloneKnn = Knn.load(knnPath);
+        KnnModel knnModel = cloneKnn.fit(trainData);
+        knnModel.save(modelPath);
+        env.execute();
+    }
+
+    /** test knn model load and transform. */
+    @Test
+    public void testKnnModelLoad() throws Exception {

Review comment:
       `testKnnModelLoad` covers everything tested in `testKnnModelSave`, so I believe we can remove `testKnnModelSave`.
   
   We have also established a practice for testing the saving/loading of a stage in the NaiveBayes PR, using methods provided in `StageTestUtils`. We can use that tool in this PR as well.
   
   Besides the Model, we should also add save/load tests for the corresponding Estimator.
   
   The class name already makes it clear that the tests in this class are for Knn, so maybe we can just rename `testKnnModelLoad` to something like `testSaveLoad`.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+
+    List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("f", "2.0 3.0", 1, 0, 1.47),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("m", "200.1 300.1", 1, 0, 1.5),
+                            Row.of("m", "200.2 300.2", 1, 0, 2.59),
+                            Row.of("m", "200.3 300.3", 1, 0, 2.55),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.53),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.52),
+                            Row.of("m", "200.6 300.6", 1, 0, 2.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.56),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.51),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.52),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.53),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.54),
+                            Row.of("c", "2.8 3.2", 3, 0, 1.6),
+                            Row.of("d", "300. 3.2", 5, 0, 1.5),
+                            Row.of("f", "2.2 3.2", 1, 0, 1.5),
+                            Row.of("e", "2.4 3.2", 2, 0, 1.3),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.4),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+    List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("e", "4.0 4.1", 2, 0, 1.5), Row.of("m", 
"300 42", 1, 0, 2.59)));
+
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        DataStream<Row> dataStream =
+                env.fromCollection(
+                        trainArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+        trainData = tEnv.fromDataStream(dataStream);
+
+        DataStream<Row> dataStreamStr =
+                env.fromCollection(
+                        testArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+
+        testData = tEnv.fromDataStream(dataStreamStr);
+    }
+
+    /** test knn Estimator. */
+    @Test
+    public void testKnnEstimator() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+
+        KnnModel knnModel = knn.fit(trainData);
+        Table result = knnModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assert (label.equals(pred));

Review comment:
       In NaiveBayes's PR we established a practice of collecting the table 
outputs, extracting the supervised or unsupervised results, and comparing them 
against the expected data. I suggest that we adopt that practice in this PR as well.
   
   One risk of the current implementation is that if the output table produces 
no results (i.e. `rows.size() == 0`), the loop body never executes, so the test 
still passes and the error goes undetected.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+
+    List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("f", "2.0 3.0", 1, 0, 1.47),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("m", "200.1 300.1", 1, 0, 1.5),
+                            Row.of("m", "200.2 300.2", 1, 0, 2.59),
+                            Row.of("m", "200.3 300.3", 1, 0, 2.55),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.53),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.52),
+                            Row.of("m", "200.6 300.6", 1, 0, 2.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.56),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.51),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.52),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.53),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.54),
+                            Row.of("c", "2.8 3.2", 3, 0, 1.6),
+                            Row.of("d", "300. 3.2", 5, 0, 1.5),
+                            Row.of("f", "2.2 3.2", 1, 0, 1.5),
+                            Row.of("e", "2.4 3.2", 2, 0, 1.3),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.4),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+    List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("e", "4.0 4.1", 2, 0, 1.5), Row.of("m", 
"300 42", 1, 0, 2.59)));
+
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        DataStream<Row> dataStream =
+                env.fromCollection(
+                        trainArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+        trainData = tEnv.fromDataStream(dataStream);
+
+        DataStream<Row> dataStreamStr =
+                env.fromCollection(
+                        testArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+
+        testData = tEnv.fromDataStream(dataStreamStr);
+    }
+
+    /** test knn Estimator. */
+    @Test
+    public void testKnnEstimator() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+
+        KnnModel knnModel = knn.fit(trainData);
+        Table result = knnModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assert (label.equals(pred));
+        }
+    }
+
+    /** test knn Estimator. */
+    @Test
+    public void testKnnEstimatorWithFeatures() throws Exception {
+        Map<Param<?>, Object> params = new HashMap<>();
+        params.put(HasLabelCol.LABEL_COL, "label");
+        params.put(HasFeaturesCol.FEATURES_COL, "vec");
+        params.put(HasK.K, 4);
+        params.put(HasPredictionCol.PREDICTION_COL, "pred");
+        Knn knn = new Knn(params);
+
+        KnnModel knnModel = knn.fit(trainData);
+        Table result = knnModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assert (label.equals(pred));
+        }
+    }
+
+    /** test knn as a pipeline stage. */
+    @Test
+    public void testKnnPipeline() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+
+        List<Stage<?>> stages = new ArrayList<>();
+        stages.add(knn);
+
+        Pipeline pipe = new Pipeline(stages);
+
+        Table result = pipe.fit(trainData).transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assert (label.equals(pred));
+        }
+    }
+
+    /** test knn model save. */
+    @Test
+    public void testKnnModelSave() throws Exception {
+        String knnPath = Files.createTempDirectory("").toString();
+        String modelPath = Files.createTempDirectory("").toString();
+        Knn knn =
+                new 
Knn().setLabelCol("f0").setFeaturesCol("vec").setK(4).setPredictionCol("pred");
+        knn.save(knnPath);
+        Knn cloneKnn = Knn.load(knnPath);
+        KnnModel knnModel = cloneKnn.fit(trainData);
+        knnModel.save(modelPath);
+        env.execute();
+    }
+
+    /** test knn model load and transform. */
+    @Test
+    public void testKnnModelLoad() throws Exception {
+        String path = Files.createTempDirectory("").toString();
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+        KnnModel knnModel = knn.fit(trainData);
+        knnModel.save(path);
+        env.execute();
+
+        KnnModel newModel = KnnModel.load(env, path);
+        Table result = newModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assert (label.equals(pred));
+        }
+    }
+
+    /** Test Param */
+    @Test
+    public void testParam() {

Review comment:
       We can also add tests that check whether the params are correctly passed 
from the estimator to the corresponding model.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java
##########
@@ -0,0 +1,10 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+
+/** knn parameters. */
+public interface KnnParams<T>
+        extends HasFeaturesCol<T>, HasLabelCol<T>, HasPredictionCol<T>, 
HasK<T> {}

Review comment:
       `HasLabelCol` is only used in `Knn`, not in `KnnModel`. If that is the 
case, we should still split the params into separate `KnnParams` and `KnnModelParams` interfaces.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,489 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.VectorUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table[] modelData;
+
+    /** constructor. */
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * constructor.
+     *
+     * @param params parameters for algorithm.
+     */
+    public KnnModel(Map<Param<?>, Object> params) {
+        this.params = params;
+    }
+
+    /**
+     * Set model data for knn prediction.
+     *
+     * @param modelData knn model.
+     * @return knn model.
+     */
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelData = modelData;
+        return this;
+    }
+
+    /**
+     * get model data.
+     *
+     * @return list of tables.
+     */
+    @Override
+    public Table[] getModelData() {
+        return modelData;
+    }
+
+    /**
+     * @param inputs a list of tables.
+     * @return result.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+        final String BROADCAST_STR = "broadcastModelKey";
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+        broadcastMap.put(BROADCAST_STR, model);
+        ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+        DataType idType =
+                
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+        String[] reservedCols =
+                inputs[0].getResolvedSchema().getColumnNames().toArray(new 
String[0]);
+        DataType[] reservedTypes =
+                inputs[0].getResolvedSchema().getColumnDataTypes().toArray(new 
DataType[0]);
+        String[] resultCols = new String[] {(String) 
params.get(KnnParams.PREDICTION_COL)};
+        DataType[] resultTypes = new DataType[] {idType};
+        ResolvedSchema outputSchema =
+                ResolvedSchema.physical(
+                        ArrayUtils.addAll(reservedCols, resultCols),
+                        ArrayUtils.addAll(reservedTypes, resultTypes));

Review comment:
       > Different algorithm maybe have different output schema. For example : 
knn result may different with lr, for lr has detail info which knn not have.
   >
   > So, I think the output schema need be write by each algorithm developer.
   
   The process of extracting the input table schema, adding a prediction column, 
and converting the result to a `RowTypeInfo` is identical to that in KMeans and 
NaiveBayes, and it is supported by methods in `TableUtils`. Using those methods 
would greatly reduce the number of lines in `transform()` and make the code 
clearer.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,489 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.VectorUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table[] modelData;
+
+    /** constructor. */
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * constructor.
+     *
+     * @param params parameters for algorithm.
+     */
+    public KnnModel(Map<Param<?>, Object> params) {
+        this.params = params;
+    }
+
+    /**
+     * Set model data for knn prediction.
+     *
+     * @param modelData knn model.
+     * @return knn model.
+     */
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelData = modelData;
+        return this;
+    }
+
+    /**
+     * get model data.
+     *
+     * @return list of tables.
+     */
+    @Override
+    public Table[] getModelData() {
+        return modelData;
+    }
+
+    /**
+     * @param inputs a list of tables.
+     * @return result.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+        final String BROADCAST_STR = "broadcastModelKey";
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+        broadcastMap.put(BROADCAST_STR, model);
+        ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+        DataType idType =
+                
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+        String[] reservedCols =
+                inputs[0].getResolvedSchema().getColumnNames().toArray(new 
String[0]);
+        DataType[] reservedTypes =
+                inputs[0].getResolvedSchema().getColumnDataTypes().toArray(new 
DataType[0]);
+        String[] resultCols = new String[] {(String) 
params.get(KnnParams.PREDICTION_COL)};
+        DataType[] resultTypes = new DataType[] {idType};
+        ResolvedSchema outputSchema =
+                ResolvedSchema.physical(
+                        ArrayUtils.addAll(reservedCols, resultCols),
+                        ArrayUtils.addAll(reservedTypes, resultTypes));
+
+        DataType[] dataTypes = outputSchema.getColumnDataTypes().toArray(new 
DataType[0]);
+        TypeInformation<?>[] typeInformations = new 
TypeInformation[dataTypes.length];
+
+        for (int i = 0; i < dataTypes.length; ++i) {
+            typeInformations[i] = 
TypeInformation.of(dataTypes[i].getLogicalType().getClass());
+        }
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.transform(
+                            "mapFunc",
+                            new RowTypeInfo(
+                                    typeInformations,
+                                    outputSchema.getColumnNames().toArray(new 
String[0])),
+                            new PredictOperator(

Review comment:
       We can combine `PredictOperator` and `KnnRichFunction` into a single 
class instead of creating two. The discussion under [this 
comment](https://github.com/apache/flink-ml/pull/32#discussion_r755644135) shows 
how to achieve that.
   
   Once they are combined, we can rename the new class to `PredictLabelOperator` 
or `PredictLabelFunction`, since it is not common practice to include the 
algorithm name (knn) in the name of such a class.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to