yunfengzhou-hub commented on code in PR #156: URL: https://github.com/apache/flink-ml/pull/156#discussion_r984140802
########## docs/content/docs/operators/feature/vectorassembler.md: ########## @@ -27,8 +27,10 @@ under the License. Review Comment: I found that `VectorSizeHint` is also used in the Normalizer and DCT algorithms, as shown in the links below. Could you please add this parameter to those algorithms as well? https://github.com/apache/spark/blob/0494dc90af48ce7da0625485a4dc6917a244d580/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala#L90 https://github.com/apache/spark/blob/0494dc90af48ce7da0625485a4dc6917a244d580/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala#L95 ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java: ########## @@ -21,11 +21,31 @@ import org.apache.flink.ml.common.param.HasHandleInvalid; import org.apache.flink.ml.common.param.HasInputCols; import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.IntArrayParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +import java.util.Arrays; /** * Params of {@link VectorAssembler}. * * @param <T> The class type of this instance. */ public interface VectorAssemblerParams<T> - extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {} + extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> { + Param<Integer[]> SIZES = + new IntArrayParam( + "sizeArray", + "Sizes of the assembling elements.", + null, + ParamValidators.notNull()); + + default int[] getSizes() { + return Arrays.stream(get(SIZES)).mapToInt(Integer::intValue).toArray(); + } + + default T setSizes(Integer... value) { Review Comment: nit: `T setSizes(int... value)` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java: ########## @@ -74,38 +76,65 @@ public Table[] transform(Table... 
inputs) { DataStream<Row> output = tEnv.toDataStream(inputs[0]) .flatMap( - new AssemblerFunc(getInputCols(), getHandleInvalid()), + new AssemblerFunction( + getInputCols(), getHandleInvalid(), getSizes()), outputTypeInfo); Table outputTable = tEnv.fromDataStream(output); return new Table[] {outputTable}; } - private static class AssemblerFunc implements FlatMapFunction<Row, Row> { + private static class AssemblerFunction implements FlatMapFunction<Row, Row> { private final String[] inputCols; private final String handleInvalid; + private final int[] sizeArray; - public AssemblerFunc(String[] inputCols, String handleInvalid) { + public AssemblerFunction(String[] inputCols, String handleInvalid, int[] sizeArray) { this.inputCols = inputCols; this.handleInvalid = handleInvalid; + this.sizeArray = sizeArray; } @Override public void flatMap(Row value, Collector<Row> out) { int nnz = 0; int vectorSize = 0; try { - for (String inputCol : inputCols) { + for (int i = 0; i < inputCols.length; ++i) { Review Comment: If `sizeArray.length > inputCols.length`, the code seems to still work, but an exception is expected. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java: ########## @@ -74,38 +76,65 @@ public Table[] transform(Table... 
inputs) { DataStream<Row> output = tEnv.toDataStream(inputs[0]) .flatMap( - new AssemblerFunc(getInputCols(), getHandleInvalid()), + new AssemblerFunction( + getInputCols(), getHandleInvalid(), getSizes()), outputTypeInfo); Table outputTable = tEnv.fromDataStream(output); return new Table[] {outputTable}; } - private static class AssemblerFunc implements FlatMapFunction<Row, Row> { + private static class AssemblerFunction implements FlatMapFunction<Row, Row> { private final String[] inputCols; private final String handleInvalid; + private final int[] sizeArray; - public AssemblerFunc(String[] inputCols, String handleInvalid) { + public AssemblerFunction(String[] inputCols, String handleInvalid, int[] sizeArray) { this.inputCols = inputCols; this.handleInvalid = handleInvalid; + this.sizeArray = sizeArray; } @Override public void flatMap(Row value, Collector<Row> out) { int nnz = 0; int vectorSize = 0; try { - for (String inputCol : inputCols) { + for (int i = 0; i < inputCols.length; ++i) { + String inputCol = inputCols[i]; Object object = value.getField(inputCol); Preconditions.checkNotNull(object, "Input column value should not be null."); if (object instanceof Number) { + Preconditions.checkArgument( + sizeArray[i] == 1, + "Inconsistent vector size, setSize is " + + sizeArray[i] Review Comment: String concatenation could cost more cpu time than assembling vectors, and we should try to avoid this unless necessary. For example, the following code might work much better. ```java if (sizeArray[i] != 1) { throw new IllegalArgumentException("..." 
+ "..."); } ``` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java: ########## @@ -21,11 +21,31 @@ import org.apache.flink.ml.common.param.HasHandleInvalid; import org.apache.flink.ml.common.param.HasInputCols; import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.IntArrayParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +import java.util.Arrays; /** * Params of {@link VectorAssembler}. * * @param <T> The class type of this instance. */ public interface VectorAssemblerParams<T> - extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {} + extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> { + Param<Integer[]> SIZES = Review Comment: In Spark's equivalent algorithm, `getSize` has a clear context that it is about `VectorSizeHint`, or the size of the vector. But in Flink ML's VectorAssembler, there might be ambiguity because this context is lacking. Thus I would still prefer to add a qualifier to this parameter's name, like `setInputSizes`. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java: ########## @@ -74,38 +76,65 @@ public Table[] transform(Table... 
inputs) { DataStream<Row> output = tEnv.toDataStream(inputs[0]) .flatMap( - new AssemblerFunc(getInputCols(), getHandleInvalid()), + new AssemblerFunction( + getInputCols(), getHandleInvalid(), getSizes()), outputTypeInfo); Table outputTable = tEnv.fromDataStream(output); return new Table[] {outputTable}; } - private static class AssemblerFunc implements FlatMapFunction<Row, Row> { + private static class AssemblerFunction implements FlatMapFunction<Row, Row> { private final String[] inputCols; private final String handleInvalid; + private final int[] sizeArray; - public AssemblerFunc(String[] inputCols, String handleInvalid) { + public AssemblerFunction(String[] inputCols, String handleInvalid, int[] sizeArray) { this.inputCols = inputCols; this.handleInvalid = handleInvalid; + this.sizeArray = sizeArray; } @Override public void flatMap(Row value, Collector<Row> out) { int nnz = 0; int vectorSize = 0; try { - for (String inputCol : inputCols) { + for (int i = 0; i < inputCols.length; ++i) { + String inputCol = inputCols[i]; Object object = value.getField(inputCol); Preconditions.checkNotNull(object, "Input column value should not be null."); if (object instanceof Number) { + Preconditions.checkArgument( + sizeArray[i] == 1, + "Inconsistent vector size, setSize is " + + sizeArray[i] + + ", but current size is " + + 1 + + "."); nnz += 1; - vectorSize += 1; + vectorSize += sizeArray[i]; } else if (object instanceof SparseVector) { + int localSize = ((SparseVector) object).size(); + Preconditions.checkArgument( + sizeArray[i] == localSize, + "Inconsistent vector size, setSize is " + + sizeArray[i] + + ", but current vector size is " + + localSize + + "."); nnz += ((SparseVector) object).indices.length; - vectorSize += ((SparseVector) object).size(); + vectorSize += sizeArray[i]; } else if (object instanceof DenseVector) { + int localSize = ((DenseVector) object).size(); + Preconditions.checkArgument( + sizeArray[i] == localSize, + "Inconsistent vector size, setSize is 
" + + sizeArray[i] + + ", but current vector size is " + + localSize Review Comment: The three `checkArgument`s are almost the same and we may merge them into one common function. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java: ########## @@ -74,38 +76,65 @@ public Table[] transform(Table... inputs) { DataStream<Row> output = tEnv.toDataStream(inputs[0]) .flatMap( - new AssemblerFunc(getInputCols(), getHandleInvalid()), + new AssemblerFunction( + getInputCols(), getHandleInvalid(), getSizes()), outputTypeInfo); Table outputTable = tEnv.fromDataStream(output); return new Table[] {outputTable}; } - private static class AssemblerFunc implements FlatMapFunction<Row, Row> { + private static class AssemblerFunction implements FlatMapFunction<Row, Row> { private final String[] inputCols; private final String handleInvalid; + private final int[] sizeArray; - public AssemblerFunc(String[] inputCols, String handleInvalid) { + public AssemblerFunction(String[] inputCols, String handleInvalid, int[] sizeArray) { this.inputCols = inputCols; this.handleInvalid = handleInvalid; + this.sizeArray = sizeArray; } @Override public void flatMap(Row value, Collector<Row> out) { int nnz = 0; int vectorSize = 0; try { - for (String inputCol : inputCols) { + for (int i = 0; i < inputCols.length; ++i) { + String inputCol = inputCols[i]; Object object = value.getField(inputCol); Preconditions.checkNotNull(object, "Input column value should not be null."); if (object instanceof Number) { Review Comment: During offline discussion, I remembered that the purpose of adding the parameter `size` is to output a vector of size `sum(size)` when some input values might be null. If this is the case, this check and the corresponding behavior in this algorithm need to be changed. 
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java: ########## @@ -21,11 +21,31 @@ import org.apache.flink.ml.common.param.HasHandleInvalid; import org.apache.flink.ml.common.param.HasInputCols; import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.IntArrayParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +import java.util.Arrays; /** * Params of {@link VectorAssembler}. * * @param <T> The class type of this instance. */ public interface VectorAssemblerParams<T> - extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {} + extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> { + Param<Integer[]> SIZES = + new IntArrayParam( + "sizeArray", + "Sizes of the assembling elements.", + null, + ParamValidators.notNull()); + + default int[] getSizes() { + return Arrays.stream(get(SIZES)).mapToInt(Integer::intValue).toArray(); Review Comment: nit: `ArrayUtils.toPrimitive` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java: ########## @@ -47,7 +47,9 @@ /** * A Transformer which combines a given list of input columns into a vector column. Types of input - * columns must be either vector or numerical value. + * columns must be either vector or numerical types. The elements assembled in the same column must + * have the same size. If the size of the element is not equal to sizes[columnIdx], it will throw an + * IllegalArgumentException. Review Comment: Do we need to throw an exception when an input element is null? 
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java: ########## @@ -21,11 +21,31 @@ import org.apache.flink.ml.common.param.HasHandleInvalid; import org.apache.flink.ml.common.param.HasInputCols; import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.IntArrayParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +import java.util.Arrays; /** * Params of {@link VectorAssembler}. * * @param <T> The class type of this instance. */ public interface VectorAssemblerParams<T> - extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {} + extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> { + Param<Integer[]> SIZES = + new IntArrayParam( + "sizeArray", Review Comment: Let's keep this name consistent with the name of the parameter variable. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java: ########## @@ -21,11 +21,31 @@ import org.apache.flink.ml.common.param.HasHandleInvalid; import org.apache.flink.ml.common.param.HasInputCols; import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.IntArrayParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +import java.util.Arrays; /** * Params of {@link VectorAssembler}. * * @param <T> The class type of this instance. */ public interface VectorAssemblerParams<T> - extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {} + extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> { + Param<Integer[]> SIZES = + new IntArrayParam( + "sizeArray", + "Sizes of the assembling elements.", Review Comment: nit: "Sizes of the input elements to be assembled." ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java: ########## @@ -74,38 +76,65 @@ public Table[] transform(Table... 
inputs) { DataStream<Row> output = tEnv.toDataStream(inputs[0]) .flatMap( - new AssemblerFunc(getInputCols(), getHandleInvalid()), + new AssemblerFunction( + getInputCols(), getHandleInvalid(), getSizes()), outputTypeInfo); Table outputTable = tEnv.fromDataStream(output); return new Table[] {outputTable}; } - private static class AssemblerFunc implements FlatMapFunction<Row, Row> { + private static class AssemblerFunction implements FlatMapFunction<Row, Row> { private final String[] inputCols; private final String handleInvalid; + private final int[] sizeArray; - public AssemblerFunc(String[] inputCols, String handleInvalid) { + public AssemblerFunction(String[] inputCols, String handleInvalid, int[] sizeArray) { Review Comment: Let's either use `sizeArray` or `sizes` across all references to this parameter. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org