zhipeng93 commented on a change in pull request #24: URL: https://github.com/apache/flink-ml/pull/24#discussion_r753647934
########## File path: flink-ml-api/src/main/java/org/apache/flink/ml/linalg/DenseVector.java ########## @@ -42,14 +45,26 @@ public double get(int i) { return values[i]; } + @Override + public void set(int i, double val) { + values[i] = val; + } + @Override public double[] toArray() { return values; } @Override public String toString() { - return Arrays.toString(values); + StringBuilder sbd = new StringBuilder(); Review comment: why is this method needed? Given that we already have `DenseVectorSerializer`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/MapPartitionFunctionWrapper.java ########## @@ -0,0 +1,68 @@ +package org.apache.flink.ml.common; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; + +/** + * MapPartitionFunction wrapper. + * + * @param <IN> Input element type. + * @param <OUT> Output element type. + */ +public class MapPartitionFunctionWrapper<IN, OUT> extends AbstractStreamOperator<OUT> Review comment: Can we remove this class and use the existing one `org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper`? 
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java ########## @@ -0,0 +1,32 @@ +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.ml.common.param.HasFeatureColsDefaultAsNull; +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.common.param.HasVectorColDefaultAsNull; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** knn fit parameters. */ +public interface KnnParams<T> + extends WithParams<T>, + HasVectorColDefaultAsNull<T>, + HasLabelCol<T>, + HasFeatureColsDefaultAsNull<T>, + HasPredictionCol<T> { + /** + * @cn-name topK + * @cn topK + */ + Param<Integer> K = new IntParam("k", "k", 10, ParamValidators.gt(0)); Review comment: Can we make `K` a shared param? ########## File path: flink-ml-api/src/main/java/org/apache/flink/ml/linalg/DenseVector.java ########## @@ -60,6 +75,74 @@ public boolean equals(Object obj) { return Arrays.equals(values, ((DenseVector) obj).values); } + /** + * Parse the dense vector from a formatted string. + * + * <p>The format of a dense vector is space separated values such as "1 2 3 4". + * + * @param str A string of space separated values. + * @return The parsed vector. + */ + public static DenseVector fromString(String str) { Review comment: Could this method be a utility function? It may not be part of the math library. 
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java ########## @@ -0,0 +1,255 @@ +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.core.Estimator; +import org.apache.flink.ml.common.MapPartitionFunctionWrapper; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * KNN is to classify unlabeled observations by assigning them to the class of the most similar + * labeled examples. + */ +public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> { + + private static final long serialVersionUID = 5292477422193301398L; + private static final int ROW_SIZE = 2; + private static final int FASTDISTANCE_TYPE_INDEX = 0; + private static final int DATA_INDEX = 1; + + protected Map<Param<?>, Object> params = new HashMap<>(); + + /** constructor. */ + public Knn() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + /** + * constructor. 
+ * + * @param params parameters for algorithm. + */ + public Knn(Map<Param<?>, Object> params) { + this.params = params; + } + + /** + * @param inputs a list of tables + * @return knn classification model. + */ + @Override + public KnnModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + ResolvedSchema schema = inputs[0].getResolvedSchema(); + String[] colNames = schema.getColumnNames().toArray(new String[0]); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + String[] targetCols = getFeatureCols(); + final int[] featureIndices; + if (targetCols == null) { + featureIndices = new int[colNames.length]; + for (int i = 0; i < colNames.length; i++) { + featureIndices[i] = i; + } + } else { + featureIndices = new int[targetCols.length]; + for (int i = 0; i < featureIndices.length; i++) { + featureIndices[i] = findColIndex(colNames, targetCols[i]); + } + } + String labelCol = getLabelCol(); + final int labelIdx = findColIndex(colNames, labelCol); + final int vecIdx = + getVectorCol() != null + ? findColIndex( + inputs[0] + .getResolvedSchema() + .getColumnNames() + .toArray(new String[0]), + getVectorCol()) + : -1; + + DataStream<Row> trainData = + input.map( + (MapFunction<Row, Row>) + value -> { + Object label = value.getField(labelIdx); Review comment: Can we use `value.getField(getLabelCol())` here? ########## File path: flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java ########## @@ -64,7 +64,7 @@ public Param( * @return A json-formatted string. */ public String jsonEncode(T value) throws IOException { - return ReadWriteUtils.OBJECT_MAPPER.writeValueAsString(value); + return ReadWriteUtils.OBJECT_MAPPER.toJson(value); Review comment: why use `toJson` instead of writeValueAsString? 
########## File path: flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java ########## @@ -43,8 +46,12 @@ /** Utility methods for reading and writing stages. */ public class ReadWriteUtils { - public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - + public static Gson OBJECT_MAPPER = Review comment: why use Gson instead here? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java ########## @@ -0,0 +1,93 @@ +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +/** + * Save the data for calculating distance fast. The FastDistanceMatrixData saves several dense + * vectors in a single matrix. The vectors are organized in columns, which means each column is a + * single vector. For example, vec1: 0,1,2, vec2: 3,4,5, vec3: 6,7,8, then the data in matrix is + * organized as: vec1,vec2,vec3. And the data array in <code>vectors</code> is {0,1,2,3,4,5,6,7,8}. + */ +public class FastDistanceMatrixData implements Serializable { + private static final long serialVersionUID = 3093977891649431843L; + + /** + * Stores several dense vectors in columns. For example, if the vectorSize is n, and matrix + * saves m vectors, then the number of rows of <code>vectors</code> is n and the number of cols + * of <code>vectors</code> is m. + */ + public final DenseMatrix vectors; + /** + * Save the extra info besides the vector. Each vector is related to one row. Thus, for + * FastDistanceVectorData, the length of <code>rows</code> is one. And for + * FastDistanceMatrixData, the length of <code>rows</code> is equal to the number of cols of + * <code>matrix</code>. Besides, the order of the rows are the same with the vectors. 
+ */ + public final Row[] rows; + + /** + * Stores some extra info extracted from the vector. It's also organized in columns. For + * example, if we want to save the L1 norm and L2 norm of the vector, then the two values are + * viewed as a two-dimension label vector. We organize the norm vectors together to get the + * <code>label</code>. If the number of cols of <code>vectors</code> is m, then in this case the + * dimension of <code>label</code> is 2 * m. + */ + public DenseMatrix label; + + public Row[] getRows() { + return rows; + } + + /** + * Constructor, initialize the vector data and extra info. + * + * @param vectors DenseMatrix which saves vectors in columns. + * @param rows extra info besides the vector. + */ + public FastDistanceMatrixData(DenseMatrix vectors, Row[] rows) { + this.rows = rows; + Preconditions.checkNotNull(vectors, "DenseMatrix should not be null!"); + if (null != rows) { + Preconditions.checkArgument( + vectors.numCols() == rows.length, + "The column number of DenseMatrix must be equal to the rows array length!"); + } + this.vectors = vectors; + } + + /** + * serialization of FastDistanceMatrixData. + * + * @return json string. + */ + @Override + public String toString() { + Map<String, Object> params = new HashMap<>(3); + params.put("vectors", ReadWriteUtils.OBJECT_MAPPER.toJson(vectors)); + params.put("label", ReadWriteUtils.OBJECT_MAPPER.toJson(label)); + params.put("rows", ReadWriteUtils.OBJECT_MAPPER.toJson(rows)); + return ReadWriteUtils.OBJECT_MAPPER.toJson(params); + } + + /** + * deserialization of FastDistanceMatrixData. + * + * @param modelStr string of model serialization. + * @return FastDistanceMatrixData + */ + public static FastDistanceMatrixData fromString(String modelStr) { Review comment: We probably need a separate serializer here, similar to DenseVectorSerializer. 
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureColsDefaultAsNull.java ########## @@ -0,0 +1,28 @@ +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Params of the names of the feature columns used for training in the input table. */ +public interface HasFeatureColsDefaultAsNull<T> extends WithParams<T> { + /** + * @cn-name 特征列名数组 Review comment: Can we make the Javadoc all in English? Also, what is the difference between `HasFeatureColsDefaultAsNull` and `HasFeatureCols`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java ########## @@ -0,0 +1,255 @@ +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.core.Estimator; +import org.apache.flink.ml.common.MapPartitionFunctionWrapper; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; 
+import java.util.List; +import java.util.Map; + +/** + * KNN is to classify unlabeled observations by assigning them to the class of the most similar + * labeled examples. + */ +public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> { + + private static final long serialVersionUID = 5292477422193301398L; + private static final int ROW_SIZE = 2; + private static final int FASTDISTANCE_TYPE_INDEX = 0; + private static final int DATA_INDEX = 1; + + protected Map<Param<?>, Object> params = new HashMap<>(); + + /** constructor. */ + public Knn() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + /** + * constructor. + * + * @param params parameters for algorithm. + */ + public Knn(Map<Param<?>, Object> params) { + this.params = params; Review comment: We should also call `ParamUtils.initializeMapWithDefaultValues(params, this)` here. ########## File path: flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java ########## @@ -75,7 +75,7 @@ public String jsonEncode(T value) throws IOException { */ @SuppressWarnings("unchecked") public T jsonDecode(String json) throws IOException { - return ReadWriteUtils.OBJECT_MAPPER.readValue(json, clazz); + return ReadWriteUtils.OBJECT_MAPPER.fromJson(json, clazz); Review comment: why use `fromJson` here? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org