yunfengzhou-hub commented on a change in pull request #24: URL: https://github.com/apache/flink-ml/pull/24#discussion_r760737278
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ##########
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+
+    List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("f", "2.0 3.0", 1, 0, 1.47),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("m", "200.1 300.1", 1, 0, 1.5),
+                            Row.of("m", "200.2 300.2", 1, 0, 2.59),
+                            Row.of("m", "200.3 300.3", 1, 0, 2.55),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.53),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.52),
+                            Row.of("m", "200.6 300.6", 1, 0, 2.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.56),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.51),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.52),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.53),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.54),
+                            Row.of("c", "2.8 3.2", 3, 0, 1.6),
+                            Row.of("d", "300. 3.2", 5, 0, 1.5),
+                            Row.of("f", "2.2 3.2", 1, 0, 1.5),
+                            Row.of("e", "2.4 3.2", 2, 0, 1.3),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.4),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+    List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("e", "4.0 4.1", 2, 0, 1.5), Row.of("m", "300 42", 1, 0, 2.59)));
+
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        DataStream<Row> dataStream =
+                env.fromCollection(
+                        trainArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", "f2"}));
+        trainData = tEnv.fromDataStream(dataStream);

Review comment:
A more concise way to create the table might be to use a schema.
You can refer to how that is done in `KmeansTest`, and in `NaiveBayesTest` in the NaiveBayes PR.

########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ##########
+                            Row.of("f", "2.0 3.0", 1, 0, 1.47),

Review comment:
Please correct me if I'm wrong. I remember we have reached an agreement that our stages should receive Vectors as they are, instead of Vectors in the form of strings. If this is the case, I think we should also change this trainArray data to follow that convention. Besides, we can then remove the changes to the `DenseVector.toString()` and `VectorUtils.parse()` methods, as we would no longer need them.

########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLabelCol.java ##########
@@ -0,0 +1,29 @@
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Param of the name of the label column in the input table.
+ *
+ * @param <T>

Review comment:
This `<T>` can be removed. The javadoc description can be renamed to `/** Interface for the shared labelCol param. */`. Same for the other params.
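For illustration, a minimal sketch of what the cleaned-up interface might look like after this change; the default value and validator shown here are assumptions, not taken from the PR:

/** Interface for the shared labelCol param. */
public interface HasLabelCol<T> extends WithParams<T> {
    Param<String> LABEL_COL =
            new StringParam(
                    "labelCol",
                    "Name of the label column in the input table.",
                    "label", // assumed default value
                    ParamValidators.notNull());

    default String getLabelCol() {
        return get(LABEL_COL);
    }

    default T setLabelCol(String value) {
        return set(LABEL_COL, value);
    }
}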
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ##########
+    /** test knn Estimator. */
+    @Test
+    public void testKnnEstimator() throws Exception {

Review comment:
Naming it `testKnnEstimator` might be a little too broad, as most UTs in this file would fit that name. We can rename it to `testFitAndPredict`.

########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java ##########
@@ -0,0 +1,152 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.VectorUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * constructor.
+     *
+     * @param params parameters for algorithm.
+     */
+    public Knn(Map<Param<?>, Object> params) {
+        this.params = params;
+    }
+
+    /**
+     * @param inputs a list of tables
+     * @return knn classification model.
+     */
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        ResolvedSchema schema = inputs[0].getResolvedSchema();
+        String[] colNames = schema.getColumnNames().toArray(new String[0]);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+
+        String labelCol = getLabelCol();
+        String vecCol = getFeaturesCol();
+
+        DataStream<Row> trainData =
+                input.map(
+                        (MapFunction<Row, Row>)
+                                value -> {
+                                    Object label = value.getField(labelCol);
+                                    DenseVector vec =
+                                            VectorUtils.parse(value.getField(vecCol).toString());
+                                    return Row.of(label, vec);
+                                });
+
+        DataType idType = schema.getColumnDataTypes().get(findColIndex(colNames, labelCol));
+        DataStream<Row> model = buildModel(trainData, getParamMap(), idType);
+        KnnModel knnModel = new KnnModel(params);
+        knnModel.setModelData(tEnv.fromDataStream(model));
+        return knnModel;
+    }
+
+    /**
+     * build knn model.
+     *
+     * @param dataStream input data.
+     * @param params input parameters.
+     * @return stream format model.
+     */
+    private static DataStream<Row> buildModel(
+            DataStream<Row> dataStream, final Map<Param<?>, Object> params, final DataType idType) {
+        FastDistance fastDistance = new FastDistance();
+
+        return dataStream.transform(
+                "build index",
+                new RowTypeInfo(
+                        new TypeInformation[] {
+                            Types.STRING,
+                            TypeInformation.of(idType.getLogicalType().getDefaultConversion())
+                        },
+                        new String[] {"DATA", "KNN_LABEL_TYPE"}),

Review comment:
In `KnnModel.load()`, a schema is used when converting the DataStream to a Table; here in `buildModel`, a `RowTypeInfo` is used when creating the DataStream and no schema is used in `tEnv.fromDataStream()`. Since the two usages mentioned above have the same functionality, can we create a unified implementation and place it inside the ModelData class as a static method (a rough sketch follows at the end of this section)? You can refer to the NaiveBayes PR for similar methods.

########## File path: flink-ml-lib/pom.xml ##########
@@ -65,6 +65,12 @@ under the License.
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>com.github.fommil.netlib</groupId>

Review comment:
In [this jira](https://issues.apache.org/jira/browse/SPARK-35295) Spark replaced `com.github.fommil.netlib` with `dev.ludovic.netlib` for licensing reasons. As both Spark and Flink are under the Apache license, I believe the licensing concern also applies to us. Thus we should use that package for BLAS operations as well.
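Returning to the unified model-data schema suggestion above, a rough sketch of what such a shared helper could look like; the class and method names here are hypothetical, only the `RowTypeInfo` contents are taken from the quoted `buildModel`:

/** Sketch of a static helper on the model data class, shared by buildModel and KnnModel.load(). */
public class KnnModelData {
    public static RowTypeInfo getModelDataTypeInfo(DataType labelType) {
        // Same row layout as the quoted buildModel: serialized index data plus the label type.
        return new RowTypeInfo(
                new TypeInformation[] {
                    Types.STRING,
                    TypeInformation.of(labelType.getLogicalType().getDefaultConversion())
                },
                new String[] {"DATA", "KNN_LABEL_TYPE"});
    }
}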
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ##########
+/** knn algorithm test. */

Review comment:
Maybe we can rewrite the comments of such classes to follow the existing conventions, like the one in `KmeansTest`.

########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ##########
+    /** test knn Estimator. */
+    @Test
+    public void testKnnEstimator() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")

Review comment:
We should test setting each param to a value different from its default; otherwise we cannot tell whether the setter actually took effect (see the sketch after the next comment).

########## File path: flink-ml-core/pom.xml ##########
@@ -59,11 +59,11 @@ under the License.
             <version>${flink.version}</version>
             <scope>test</scope>
         </dependency>
-        <dependency>
-            <groupId>org.apache.flink</groupId>
-            <artifactId>flink-shaded-jackson</artifactId>
-            <scope>provided</scope>
+            <groupId>com.google.code.gson</groupId>

Review comment:
This dependency is used nowhere in the current PR. We can just remove it.
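As a concrete illustration of the non-default-value suggestion above, a minimal sketch; the default values asserted here are assumptions about the param definitions, not taken from the PR:

    @Test
    public void testParamSetters() {
        Knn knn = new Knn();
        // Assumed default; adjust to whatever the param definition actually uses.
        assertEquals("label", knn.getLabelCol());
        knn.setLabelCol("test_label").setK(4);
        // Non-default values prove that the setters took effect.
        assertEquals("test_label", knn.getLabelCol());
        assertEquals(4, (int) knn.getK());
    }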
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ##########
+    /** test knn model save. */
+    @Test
+    public void testKnnModelSave() throws Exception {
+        String knnPath = Files.createTempDirectory("").toString();
+        String modelPath = Files.createTempDirectory("").toString();
+        Knn knn =
+                new Knn().setLabelCol("f0").setFeaturesCol("vec").setK(4).setPredictionCol("pred");
+        knn.save(knnPath);
+        Knn cloneKnn = Knn.load(knnPath);
+        KnnModel knnModel = cloneKnn.fit(trainData);
+        knnModel.save(modelPath);
+        env.execute();
+    }
+
+    /** test knn model load and transform. */
+    @Test
+    public void testKnnModelLoad() throws Exception {

Review comment:
`testKnnModelLoad` covers everything tested in `testKnnModelSave`, so I believe we can remove `testKnnModelSave`. We have also established a practice of testing the saving/loading of a stage in the NaiveBayes PR, using the methods provided in `StageTestUtils`; we can use that tool in this PR as well. Besides the Model, we should also add such tests for the corresponding Estimator. The class name already makes clear that the tests in this class are for Knn, so maybe we can just rename `testKnnModelLoad` to something like `testSaveLoad`.

########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ##########
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assert (label.equals(pred));

Review comment:
In the NaiveBayes PR we have established a practice of collecting the table outputs, extracting the supervised or unsupervised result, and comparing it against the expected data. I suggest we adopt that practice in this PR as well. One risk of the current implementation is that if the output table does not produce any result at all, i.e. `rows.size() == 0`, the test still passes without the error being detected.
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ##########
+    /** Test Param */
+    @Test
+    public void testParam() {

Review comment:
We can also add tests checking that the params are correctly passed on to the corresponding model.

########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java ##########
@@ -0,0 +1,10 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+
+/** knn parameters. */
+public interface KnnParams<T>
+        extends HasFeaturesCol<T>, HasLabelCol<T>, HasPredictionCol<T>, HasK<T> {}

Review comment:
`HasLabelCol` is only used in `Knn`, not in `KnnModel`. If that is the case, we should still separate `KnnParams` and `KnnModelParams`.
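For illustration, the split might look like the following sketch; `KnnModelParams` is the suggested, not-yet-existing interface:

/** Params of KnnModel, i.e. the params needed at prediction time. */
public interface KnnModelParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T>, HasK<T> {}

/** Params of Knn, adding the training-only label column on top of the model params. */
public interface KnnParams<T> extends KnnModelParams<T>, HasLabelCol<T> {}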
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ##########
@@ -0,0 +1,489 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.VectorUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table[] modelData;
+
+    /** constructor. */
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * constructor.
+     *
+     * @param params parameters for algorithm.
+     */
+    public KnnModel(Map<Param<?>, Object> params) {
+        this.params = params;
+    }
+
+    /**
+     * Set model data for knn prediction.
+     *
+     * @param modelData knn model.
+     * @return knn model.
+     */
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelData = modelData;
+        return this;
+    }
+
+    /**
+     * get model data.
+     *
+     * @return list of tables.
+     */
+    @Override
+    public Table[] getModelData() {
+        return modelData;
+    }
+
+    /**
+     * @param inputs a list of tables.
+     * @return result.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+        final String BROADCAST_STR = "broadcastModelKey";
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+        broadcastMap.put(BROADCAST_STR, model);
+        ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+        DataType idType =
+                modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+        String[] reservedCols =
+                inputs[0].getResolvedSchema().getColumnNames().toArray(new String[0]);
+        DataType[] reservedTypes =
+                inputs[0].getResolvedSchema().getColumnDataTypes().toArray(new DataType[0]);
+        String[] resultCols = new String[] {(String) params.get(KnnParams.PREDICTION_COL)};
+        DataType[] resultTypes = new DataType[] {idType};
+        ResolvedSchema outputSchema =
+                ResolvedSchema.physical(
+                        ArrayUtils.addAll(reservedCols, resultCols),
+                        ArrayUtils.addAll(reservedTypes, resultTypes));

Review comment:
> Different algorithm maybe have different output schema. For example : knn result may different with lr, for lr has detail info which knn not have.
>
> So, I think the output schema need be write by each algorithm developer.

The process of extracting the input table schema, adding a prediction column, and converting it to a `RowTypeInfo` is identical to that in Kmeans and NaiveBayes, and can be supported by the methods in `TableUtils`. If the methods in that class are used, we can greatly reduce the number of lines in `transform()` and make the code clearer.

########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ##########
+        DataType[] dataTypes = outputSchema.getColumnDataTypes().toArray(new DataType[0]);
+        TypeInformation<?>[] typeInformations = new TypeInformation[dataTypes.length];
+
+        for (int i = 0; i < dataTypes.length; ++i) {
+            typeInformations[i] = TypeInformation.of(dataTypes[i].getLogicalType().getClass());
+        }
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.transform(
+                            "mapFunc",
+                            new RowTypeInfo(
+                                    typeInformations,
+                                    outputSchema.getColumnNames().toArray(new String[0])),
+                            new PredictOperator(

Review comment:
We can combine `PredictOperator` and `KnnRichFunction` and avoid creating two classes.
Discussions under [this comment](https://github.com/apache/flink-ml/pull/32#discussion_r755644135) show how to achieve that. When combined, we can rename the resulting class to `PredictLabelOperator` or `PredictLabelFunction`, since it is not common practice to include the algorithm name (Knn) in such class names.
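For reference, a rough sketch of what the combined class could look like, following the pattern from the linked discussion; the exact shape is an assumption, and `RichMapFunction<Row, Row>` here stands for the prediction logic currently living in `KnnRichFunction`:

private static class PredictLabelOperator
        extends AbstractUdfStreamOperator<Row, RichMapFunction<Row, Row>>
        implements OneInputStreamOperator<Row, Row> {

    public PredictLabelOperator(RichMapFunction<Row, Row> userFunction) {
        super(userFunction);
    }

    @Override
    public void processElement(StreamRecord<Row> element) throws Exception {
        // Delegates to the wrapped function, so the former KnnRichFunction logic
        // lives in one class together with the operator plumbing.
        output.collect(new StreamRecord<>(userFunction.map(element.getValue())));
    }
}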