yunfengzhou-hub commented on code in PR #156:
URL: https://github.com/apache/flink-ml/pull/156#discussion_r984140802


##########
docs/content/docs/operators/feature/vectorassembler.md:
##########
@@ -27,8 +27,10 @@ under the License.
 

Review Comment:
   I found that `VectorSizeHint` is also used in the Normalizer and DCT algorithms, 
as shown in the links below. Could you please add this parameter to those algorithms as well?
   
   
https://github.com/apache/spark/blob/0494dc90af48ce7da0625485a4dc6917a244d580/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala#L90
   
https://github.com/apache/spark/blob/0494dc90af48ce7da0625485a4dc6917a244d580/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala#L95



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java:
##########
@@ -21,11 +21,31 @@
 import org.apache.flink.ml.common.param.HasHandleInvalid;
 import org.apache.flink.ml.common.param.HasInputCols;
 import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.IntArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+import java.util.Arrays;
 
 /**
  * Params of {@link VectorAssembler}.
  *
  * @param <T> The class type of this instance.
  */
 public interface VectorAssemblerParams<T>
-        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {}
+        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {
+    Param<Integer[]> SIZES =
+            new IntArrayParam(
+                    "sizeArray",
+                    "Sizes of the assembling elements.",
+                    null,
+                    ParamValidators.notNull());
+
+    default int[] getSizes() {
+        return Arrays.stream(get(SIZES)).mapToInt(Integer::intValue).toArray();
+    }
+
+    default T setSizes(Integer... value) {

Review Comment:
   nit: `T setSizes(int... value)`



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java:
##########
@@ -74,38 +76,65 @@ public Table[] transform(Table... inputs) {
         DataStream<Row> output =
                 tEnv.toDataStream(inputs[0])
                         .flatMap(
-                                new AssemblerFunc(getInputCols(), 
getHandleInvalid()),
+                                new AssemblerFunction(
+                                        getInputCols(), getHandleInvalid(), 
getSizes()),
                                 outputTypeInfo);
         Table outputTable = tEnv.fromDataStream(output);
         return new Table[] {outputTable};
     }
 
-    private static class AssemblerFunc implements FlatMapFunction<Row, Row> {
+    private static class AssemblerFunction implements FlatMapFunction<Row, 
Row> {
         private final String[] inputCols;
         private final String handleInvalid;
+        private final int[] sizeArray;
 
-        public AssemblerFunc(String[] inputCols, String handleInvalid) {
+        public AssemblerFunction(String[] inputCols, String handleInvalid, 
int[] sizeArray) {
             this.inputCols = inputCols;
             this.handleInvalid = handleInvalid;
+            this.sizeArray = sizeArray;
         }
 
         @Override
         public void flatMap(Row value, Collector<Row> out) {
             int nnz = 0;
             int vectorSize = 0;
             try {
-                for (String inputCol : inputCols) {
+                for (int i = 0; i < inputCols.length; ++i) {

Review Comment:
   If `sizeArray.length > inputCols.length`, the code still seems to work, but 
an exception should be thrown in that case.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java:
##########
@@ -74,38 +76,65 @@ public Table[] transform(Table... inputs) {
         DataStream<Row> output =
                 tEnv.toDataStream(inputs[0])
                         .flatMap(
-                                new AssemblerFunc(getInputCols(), 
getHandleInvalid()),
+                                new AssemblerFunction(
+                                        getInputCols(), getHandleInvalid(), 
getSizes()),
                                 outputTypeInfo);
         Table outputTable = tEnv.fromDataStream(output);
         return new Table[] {outputTable};
     }
 
-    private static class AssemblerFunc implements FlatMapFunction<Row, Row> {
+    private static class AssemblerFunction implements FlatMapFunction<Row, 
Row> {
         private final String[] inputCols;
         private final String handleInvalid;
+        private final int[] sizeArray;
 
-        public AssemblerFunc(String[] inputCols, String handleInvalid) {
+        public AssemblerFunction(String[] inputCols, String handleInvalid, 
int[] sizeArray) {
             this.inputCols = inputCols;
             this.handleInvalid = handleInvalid;
+            this.sizeArray = sizeArray;
         }
 
         @Override
         public void flatMap(Row value, Collector<Row> out) {
             int nnz = 0;
             int vectorSize = 0;
             try {
-                for (String inputCol : inputCols) {
+                for (int i = 0; i < inputCols.length; ++i) {
+                    String inputCol = inputCols[i];
                     Object object = value.getField(inputCol);
                     Preconditions.checkNotNull(object, "Input column value 
should not be null.");
                     if (object instanceof Number) {
+                        Preconditions.checkArgument(
+                                sizeArray[i] == 1,
+                                "Inconsistent vector size, setSize is "
+                                        + sizeArray[i]

Review Comment:
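   String concatenation can cost more CPU time than assembling the vectors themselves, 
because `checkArgument` builds its message eagerly on every call, even when the check 
passes. We should avoid this unless necessary. For example, the following code only 
concatenates the message when the check actually fails, so it should perform much better: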
   ```java
   if (sizeArray[i] != 1) {
       throw new IllegalArgumentException("..." + "...");
   }
   ```



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java:
##########
@@ -21,11 +21,31 @@
 import org.apache.flink.ml.common.param.HasHandleInvalid;
 import org.apache.flink.ml.common.param.HasInputCols;
 import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.IntArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+import java.util.Arrays;
 
 /**
  * Params of {@link VectorAssembler}.
  *
  * @param <T> The class type of this instance.
  */
 public interface VectorAssemblerParams<T>
-        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {}
+        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {
+    Param<Integer[]> SIZES =

Review Comment:
   In Spark's equivalent algorithm, `getSize` has a clear context: it belongs to 
`VectorSizeHint` and refers to the size of the vector. In Flink ML's 
VectorAssembler that context is missing, so the name may be ambiguous. Thus I 
would still prefer to add a qualifier to this parameter's name, like 
`setInputSizes`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java:
##########
@@ -74,38 +76,65 @@ public Table[] transform(Table... inputs) {
         DataStream<Row> output =
                 tEnv.toDataStream(inputs[0])
                         .flatMap(
-                                new AssemblerFunc(getInputCols(), 
getHandleInvalid()),
+                                new AssemblerFunction(
+                                        getInputCols(), getHandleInvalid(), 
getSizes()),
                                 outputTypeInfo);
         Table outputTable = tEnv.fromDataStream(output);
         return new Table[] {outputTable};
     }
 
-    private static class AssemblerFunc implements FlatMapFunction<Row, Row> {
+    private static class AssemblerFunction implements FlatMapFunction<Row, 
Row> {
         private final String[] inputCols;
         private final String handleInvalid;
+        private final int[] sizeArray;
 
-        public AssemblerFunc(String[] inputCols, String handleInvalid) {
+        public AssemblerFunction(String[] inputCols, String handleInvalid, 
int[] sizeArray) {
             this.inputCols = inputCols;
             this.handleInvalid = handleInvalid;
+            this.sizeArray = sizeArray;
         }
 
         @Override
         public void flatMap(Row value, Collector<Row> out) {
             int nnz = 0;
             int vectorSize = 0;
             try {
-                for (String inputCol : inputCols) {
+                for (int i = 0; i < inputCols.length; ++i) {
+                    String inputCol = inputCols[i];
                     Object object = value.getField(inputCol);
                     Preconditions.checkNotNull(object, "Input column value 
should not be null.");
                     if (object instanceof Number) {
+                        Preconditions.checkArgument(
+                                sizeArray[i] == 1,
+                                "Inconsistent vector size, setSize is "
+                                        + sizeArray[i]
+                                        + ", but current size is "
+                                        + 1
+                                        + ".");
                         nnz += 1;
-                        vectorSize += 1;
+                        vectorSize += sizeArray[i];
                     } else if (object instanceof SparseVector) {
+                        int localSize = ((SparseVector) object).size();
+                        Preconditions.checkArgument(
+                                sizeArray[i] == localSize,
+                                "Inconsistent vector size, setSize is "
+                                        + sizeArray[i]
+                                        + ", but current vector size is "
+                                        + localSize
+                                        + ".");
                         nnz += ((SparseVector) object).indices.length;
-                        vectorSize += ((SparseVector) object).size();
+                        vectorSize += sizeArray[i];
                     } else if (object instanceof DenseVector) {
+                        int localSize = ((DenseVector) object).size();
+                        Preconditions.checkArgument(
+                                sizeArray[i] == localSize,
+                                "Inconsistent vector size, setSize is "
+                                        + sizeArray[i]
+                                        + ", but current vector size is "
+                                        + localSize

Review Comment:
   The three `checkArgument` calls are nearly identical, so we could merge them into 
one common helper function.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java:
##########
@@ -74,38 +76,65 @@ public Table[] transform(Table... inputs) {
         DataStream<Row> output =
                 tEnv.toDataStream(inputs[0])
                         .flatMap(
-                                new AssemblerFunc(getInputCols(), 
getHandleInvalid()),
+                                new AssemblerFunction(
+                                        getInputCols(), getHandleInvalid(), 
getSizes()),
                                 outputTypeInfo);
         Table outputTable = tEnv.fromDataStream(output);
         return new Table[] {outputTable};
     }
 
-    private static class AssemblerFunc implements FlatMapFunction<Row, Row> {
+    private static class AssemblerFunction implements FlatMapFunction<Row, 
Row> {
         private final String[] inputCols;
         private final String handleInvalid;
+        private final int[] sizeArray;
 
-        public AssemblerFunc(String[] inputCols, String handleInvalid) {
+        public AssemblerFunction(String[] inputCols, String handleInvalid, 
int[] sizeArray) {
             this.inputCols = inputCols;
             this.handleInvalid = handleInvalid;
+            this.sizeArray = sizeArray;
         }
 
         @Override
         public void flatMap(Row value, Collector<Row> out) {
             int nnz = 0;
             int vectorSize = 0;
             try {
-                for (String inputCol : inputCols) {
+                for (int i = 0; i < inputCols.length; ++i) {
+                    String inputCol = inputCols[i];
                     Object object = value.getField(inputCol);
                     Preconditions.checkNotNull(object, "Input column value 
should not be null.");
                     if (object instanceof Number) {

Review Comment:
   As I recall from our offline discussion, the purpose of adding the 
parameter `size` is to output a vector of size `sum(size)` even when some input 
values are null. If that is the case, this check and the corresponding 
behavior of this algorithm need to be changed.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java:
##########
@@ -21,11 +21,31 @@
 import org.apache.flink.ml.common.param.HasHandleInvalid;
 import org.apache.flink.ml.common.param.HasInputCols;
 import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.IntArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+import java.util.Arrays;
 
 /**
  * Params of {@link VectorAssembler}.
  *
  * @param <T> The class type of this instance.
  */
 public interface VectorAssemblerParams<T>
-        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {}
+        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {
+    Param<Integer[]> SIZES =
+            new IntArrayParam(
+                    "sizeArray",
+                    "Sizes of the assembling elements.",
+                    null,
+                    ParamValidators.notNull());
+
+    default int[] getSizes() {
+        return Arrays.stream(get(SIZES)).mapToInt(Integer::intValue).toArray();

Review Comment:
   nit: `ArrayUtils.toPrimitive`



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java:
##########
@@ -47,7 +47,9 @@
 
 /**
  * A Transformer which combines a given list of input columns into a vector 
column. Types of input
- * columns must be either vector or numerical value.
+ * columns must be either vector or numerical types. The elements assembled in 
the same column must
+ * have the same size. If the size of the element is not equal to 
sizes[columnIdx], it will throw an
+ * IllegalArgumentException.

Review Comment:
   Do we need to throw an exception when an input element is null?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java:
##########
@@ -21,11 +21,31 @@
 import org.apache.flink.ml.common.param.HasHandleInvalid;
 import org.apache.flink.ml.common.param.HasInputCols;
 import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.IntArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+import java.util.Arrays;
 
 /**
  * Params of {@link VectorAssembler}.
  *
  * @param <T> The class type of this instance.
  */
 public interface VectorAssemblerParams<T>
-        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {}
+        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {
+    Param<Integer[]> SIZES =
+            new IntArrayParam(
+                    "sizeArray",

Review Comment:
   Let's keep this name consistent with the name of the parameter variable.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java:
##########
@@ -21,11 +21,31 @@
 import org.apache.flink.ml.common.param.HasHandleInvalid;
 import org.apache.flink.ml.common.param.HasInputCols;
 import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.IntArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+import java.util.Arrays;
 
 /**
  * Params of {@link VectorAssembler}.
  *
  * @param <T> The class type of this instance.
  */
 public interface VectorAssemblerParams<T>
-        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {}
+        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {
+    Param<Integer[]> SIZES =
+            new IntArrayParam(
+                    "sizeArray",
+                    "Sizes of the assembling elements.",

Review Comment:
   nit: "Sizes of the input elements to be assembled."



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java:
##########
@@ -74,38 +76,65 @@ public Table[] transform(Table... inputs) {
         DataStream<Row> output =
                 tEnv.toDataStream(inputs[0])
                         .flatMap(
-                                new AssemblerFunc(getInputCols(), 
getHandleInvalid()),
+                                new AssemblerFunction(
+                                        getInputCols(), getHandleInvalid(), 
getSizes()),
                                 outputTypeInfo);
         Table outputTable = tEnv.fromDataStream(output);
         return new Table[] {outputTable};
     }
 
-    private static class AssemblerFunc implements FlatMapFunction<Row, Row> {
+    private static class AssemblerFunction implements FlatMapFunction<Row, 
Row> {
         private final String[] inputCols;
         private final String handleInvalid;
+        private final int[] sizeArray;
 
-        public AssemblerFunc(String[] inputCols, String handleInvalid) {
+        public AssemblerFunction(String[] inputCols, String handleInvalid, 
int[] sizeArray) {

Review Comment:
   Let's use either `sizeArray` or `sizes` consistently across all references to this 
parameter.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to