zhipeng93 commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r753647934



##########
File path: 
flink-ml-api/src/main/java/org/apache/flink/ml/linalg/DenseVector.java
##########
@@ -42,14 +45,26 @@ public double get(int i) {
         return values[i];
     }
 
+    @Override
+    public void set(int i, double val) {
+        values[i] = val;
+    }
+
     @Override
     public double[] toArray() {
         return values;
     }
 
     @Override
     public String toString() {
-        return Arrays.toString(values);
+        StringBuilder sbd = new StringBuilder();

Review comment:
       Why is this method needed, given that we already have `DenseVectorSerializer`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/MapPartitionFunctionWrapper.java
##########
@@ -0,0 +1,68 @@
+package org.apache.flink.ml.common;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.functions.util.FunctionUtils;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+
+/**
+ * MapPartitionFunction wrapper.
+ *
+ * @param <IN> Input element type.
+ * @param <OUT> Output element type.
+ */
+public class MapPartitionFunctionWrapper<IN, OUT> extends 
AbstractStreamOperator<OUT>

Review comment:
       Can we remove this class and use the existing `org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper` instead?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java
##########
@@ -0,0 +1,32 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasFeatureColsDefaultAsNull;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasVectorColDefaultAsNull;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** knn fit parameters. */
+public interface KnnParams<T>
+        extends WithParams<T>,
+                HasVectorColDefaultAsNull<T>,
+                HasLabelCol<T>,
+                HasFeatureColsDefaultAsNull<T>,
+                HasPredictionCol<T> {
+    /**
+     * @cn-name topK
+     * @cn topK
+     */
+    Param<Integer> K = new IntParam("k", "k", 10, ParamValidators.gt(0));

Review comment:
       Can we make `K` a shared param?

##########
File path: 
flink-ml-api/src/main/java/org/apache/flink/ml/linalg/DenseVector.java
##########
@@ -60,6 +75,74 @@ public boolean equals(Object obj) {
         return Arrays.equals(values, ((DenseVector) obj).values);
     }
 
+    /**
+     * Parse the dense vector from a formatted string.
+     *
+     * <p>The format of a dense vector is space separated values such as "1 2 
3 4".
+     *
+     * @param str A string of space separated values.
+     * @return The parsed vector.
+     */
+    public static DenseVector fromString(String str) {

Review comment:
       Could this method be a utility function instead? It may not belong in the math library.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,255 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    private static final long serialVersionUID = 5292477422193301398L;
+    private static final int ROW_SIZE = 2;
+    private static final int FASTDISTANCE_TYPE_INDEX = 0;
+    private static final int DATA_INDEX = 1;
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * constructor.
+     *
+     * @param params parameters for algorithm.
+     */
+    public Knn(Map<Param<?>, Object> params) {
+        this.params = params;
+    }
+
+    /**
+     * @param inputs a list of tables
+     * @return knn classification model.
+     */
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        ResolvedSchema schema = inputs[0].getResolvedSchema();
+        String[] colNames = schema.getColumnNames().toArray(new String[0]);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        String[] targetCols = getFeatureCols();
+        final int[] featureIndices;
+        if (targetCols == null) {
+            featureIndices = new int[colNames.length];
+            for (int i = 0; i < colNames.length; i++) {
+                featureIndices[i] = i;
+            }
+        } else {
+            featureIndices = new int[targetCols.length];
+            for (int i = 0; i < featureIndices.length; i++) {
+                featureIndices[i] = findColIndex(colNames, targetCols[i]);
+            }
+        }
+        String labelCol = getLabelCol();
+        final int labelIdx = findColIndex(colNames, labelCol);
+        final int vecIdx =
+                getVectorCol() != null
+                        ? findColIndex(
+                                inputs[0]
+                                        .getResolvedSchema()
+                                        .getColumnNames()
+                                        .toArray(new String[0]),
+                                getVectorCol())
+                        : -1;
+
+        DataStream<Row> trainData =
+                input.map(
+                        (MapFunction<Row, Row>)
+                                value -> {
+                                    Object label = value.getField(labelIdx);

Review comment:
       Can we use `value.getField(getLabelCol())` here instead of looking up the field by `labelIdx`?

##########
File path: flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java
##########
@@ -64,7 +64,7 @@ public Param(
      * @return A json-formatted string.
      */
     public String jsonEncode(T value) throws IOException {
-        return ReadWriteUtils.OBJECT_MAPPER.writeValueAsString(value);
+        return ReadWriteUtils.OBJECT_MAPPER.toJson(value);

Review comment:
       Why use `toJson` instead of `writeValueAsString`?

##########
File path: 
flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
##########
@@ -43,8 +46,12 @@
 
 /** Utility methods for reading and writing stages. */
 public class ReadWriteUtils {
-    public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
-
+    public static Gson OBJECT_MAPPER =

Review comment:
       Why use Gson here instead of `ObjectMapper`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java
##########
@@ -0,0 +1,93 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Save the data for calculating distance fast. The FastDistanceMatrixData 
saves several dense
+ * vectors in a single matrix. The vectors are organized in columns, which 
means each column is a
+ * single vector. For example, vec1: 0,1,2, vec2: 3,4,5, vec3: 6,7,8, then the 
data in matrix is
+ * organized as: vec1,vec2,vec3. And the data array in <code>vectors</code> is 
{0,1,2,3,4,5,6,7,8}.
+ */
+public class FastDistanceMatrixData implements Serializable {
+    private static final long serialVersionUID = 3093977891649431843L;
+
+    /**
+     * Stores several dense vectors in columns. For example, if the vectorSize 
is n, and matrix
+     * saves m vectors, then the number of rows of <code>vectors</code> is n 
and the number of cols
+     * of <code>vectors</code> is m.
+     */
+    public final DenseMatrix vectors;
+    /**
+     * Save the extra info besides the vector. Each vector is related to one 
row. Thus, for
+     * FastDistanceVectorData, the length of <code>rows</code> is one. And for
+     * FastDistanceMatrixData, the length of <code>rows</code> is equal to the 
number of cols of
+     * <code>matrix</code>. Besides, the order of the rows are the same with 
the vectors.
+     */
+    public final Row[] rows;
+
+    /**
+     * Stores some extra info extracted from the vector. It's also organized 
in columns. For
+     * example, if we want to save the L1 norm and L2 norm of the vector, then 
the two values are
+     * viewed as a two-dimension label vector. We organize the norm vectors 
together to get the
+     * <code>label</code>. If the number of cols of <code>vectors</code> is m, 
then in this case the
+     * dimension of <code>label</code> is 2 * m.
+     */
+    public DenseMatrix label;
+
+    public Row[] getRows() {
+        return rows;
+    }
+
+    /**
+     * Constructor, initialize the vector data and extra info.
+     *
+     * @param vectors DenseMatrix which saves vectors in columns.
+     * @param rows extra info besides the vector.
+     */
+    public FastDistanceMatrixData(DenseMatrix vectors, Row[] rows) {
+        this.rows = rows;
+        Preconditions.checkNotNull(vectors, "DenseMatrix should not be null!");
+        if (null != rows) {
+            Preconditions.checkArgument(
+                    vectors.numCols() == rows.length,
+                    "The column number of DenseMatrix must be equal to the 
rows array length!");
+        }
+        this.vectors = vectors;
+    }
+
+    /**
+     * serialization of FastDistanceMatrixData.
+     *
+     * @return json string.
+     */
+    @Override
+    public String toString() {
+        Map<String, Object> params = new HashMap<>(3);
+        params.put("vectors", ReadWriteUtils.OBJECT_MAPPER.toJson(vectors));
+        params.put("label", ReadWriteUtils.OBJECT_MAPPER.toJson(label));
+        params.put("rows", ReadWriteUtils.OBJECT_MAPPER.toJson(rows));
+        return ReadWriteUtils.OBJECT_MAPPER.toJson(params);
+    }
+
+    /**
+     * deserialization of FastDistanceMatrixData.
+     *
+     * @param modelStr string of model serialization.
+     * @return FastDistanceMatrixData
+     */
+    public static FastDistanceMatrixData fromString(String modelStr) {

Review comment:
       We probably need a separate serializer here, similar to `DenseVectorSerializer`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureColsDefaultAsNull.java
##########
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Params of the names of the feature columns used for training in the input 
table. */
+public interface HasFeatureColsDefaultAsNull<T> extends WithParams<T> {
+    /**
+     * @cn-name 特征列名数组

Review comment:
       Can we make the Javadoc all in English?
   
   Also, what is the difference between `HasFeatureColsDefaultAsNull` and `HasFeatureCols`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,255 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    private static final long serialVersionUID = 5292477422193301398L;
+    private static final int ROW_SIZE = 2;
+    private static final int FASTDISTANCE_TYPE_INDEX = 0;
+    private static final int DATA_INDEX = 1;
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * constructor.
+     *
+     * @param params parameters for algorithm.
+     */
+    public Knn(Map<Param<?>, Object> params) {
+        this.params = params;

Review comment:
       We should also call `ParamUtils.initializeMapWithDefaultValues(params, this)` here.

##########
File path: flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java
##########
@@ -75,7 +75,7 @@ public String jsonEncode(T value) throws IOException {
      */
     @SuppressWarnings("unchecked")
     public T jsonDecode(String json) throws IOException {
-        return ReadWriteUtils.OBJECT_MAPPER.readValue(json, clazz);
+        return ReadWriteUtils.OBJECT_MAPPER.fromJson(json, clazz);

Review comment:
       Why use `fromJson` here instead of `readValue`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to