zhipeng93 commented on code in PR #132:
URL: https://github.com/apache/flink-ml/pull/132#discussion_r934122580


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java:
##########
@@ -0,0 +1,654 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats.chisqtest;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * An AlgoOperator which implements the Chi-square test algorithm.
+ *
+ * <p>Chi-square Test is an AlgoOperator that computes the statistics of 
independence of variables

Review Comment:
   How about removing `is an AlgoOperator` from this line and updating the Javadoc as follows:
   
   ```
   Chi-square Test computes the statistics of independence of variables in a contingency table,
   e.g., the chi-square statistic, p-value, and DOF (number of degrees of freedom), for each input
   feature. The contingency table is constructed from the observed categorical values.
   ```



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java:
##########
@@ -0,0 +1,654 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats.chisqtest;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * An AlgoOperator which implements the Chi-square test algorithm.
+ *
+ * <p>Chi-square Test is an AlgoOperator that computes the statistics of 
independence of variables
+ * in a contingency table. This function computes the chi-square statistic, 
p-value, and DOF(number
+ * of degrees of freedom) for every feature in the contingency table. The 
contingency table is
+ * constructed from the observed categorical values.
+ *
+ * <p>See: http://en.wikipedia.org/wiki/Chi-squared_test.
+ */
+public class ChiSqTest implements AlgoOperator<ChiSqTest>, 
ChiSqTestParams<ChiSqTest> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public ChiSqTest() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String bcCategoricalMarginsKey = "bcCategoricalMarginsKey";
+        final String bcLabelMarginsKey = "bcLabelMarginsKey";
+
+        final String[] inputCols = getInputCols();
+        String labelCol = getLabelCol();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
+        SingleOutputStreamOperator<Tuple3<String, Object, Object>> 
colAndFeatureAndLabel =
+                tEnv.toDataStream(inputs[0])
+                        .flatMap(new ExtractColAndFeatureAndLabel(inputCols, 
labelCol));
+
+        DataStream<Tuple4<String, Object, Object, Long>> observedFreq =
+                colAndFeatureAndLabel
+                        .keyBy(Tuple3::hashCode)
+                        .transform(
+                                "GenerateObservedFrequencies",
+                                TypeInformation.of(
+                                        new TypeHint<Tuple4<String, Object, 
Object, Long>>() {}),
+                                new GenerateObservedFrequencies());
+
+        SingleOutputStreamOperator<Tuple4<String, Object, Object, Long>> 
filledObservedFreq =
+                observedFreq
+                        .transform(
+                                "filledObservedFreq",
+                                Types.TUPLE(
+                                        Types.STRING,
+                                        Types.GENERIC(Object.class),
+                                        Types.GENERIC(Object.class),
+                                        Types.LONG),
+                                new FillZeroFunc())
+                        .setParallelism(1);
+
+        DataStream<Tuple3<String, Object, Long>> categoricalMargins =
+                observedFreq
+                        .keyBy(tuple -> new Tuple2<>(tuple.f0, 
tuple.f1).hashCode())
+                        .transform(
+                                "AggregateCategoricalMargins",
+                                TypeInformation.of(new TypeHint<Tuple3<String, 
Object, Long>>() {}),
+                                new AggregateCategoricalMargins());
+
+        DataStream<Tuple3<String, Object, Long>> labelMargins =
+                observedFreq
+                        .keyBy(tuple -> new Tuple2<>(tuple.f0, 
tuple.f2).hashCode())
+                        .transform(
+                                "AggregateLabelMargins",
+                                TypeInformation.of(new TypeHint<Tuple3<String, 
Object, Long>>() {}),
+                                new AggregateLabelMargins());
+
+        Function<List<DataStream<?>>, DataStream<Tuple3<String, Double, 
Integer>>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.map(new ChiSqFunc(bcCategoricalMarginsKey, 
bcLabelMarginsKey));
+                };
+
+        HashMap<String, DataStream<?>> bcMap =
+                new HashMap<String, DataStream<?>>() {
+                    {
+                        put(bcCategoricalMarginsKey, categoricalMargins);
+                        put(bcLabelMarginsKey, labelMargins);
+                    }
+                };
+
+        DataStream<Tuple3<String, Double, Integer>> categoricalStatistics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(filledObservedFreq), bcMap, 
function);
+
+        SingleOutputStreamOperator<Row> chiSqTestResult =
+                categoricalStatistics
+                        .transform(
+                                "chiSqTestResult",
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            Types.STRING, Types.DOUBLE, 
Types.DOUBLE, Types.INT
+                                        },
+                                        new String[] {
+                                            "column", "pValue", "statistic", 
"degreesOfFreedom"
+                                        }),
+                                new AggregateChiSqFunc())
+                        .setParallelism(1);
+
+        return new Table[] {tEnv.fromDataStream(chiSqTestResult)};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static ChiSqTest load(StreamTableEnvironment tEnv, String path) 
throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class ExtractColAndFeatureAndLabel
+            extends RichFlatMapFunction<Row, Tuple3<String, Object, Object>> {
+        private final String[] inputCols;
+        private final String labelCol;
+
+        public ExtractColAndFeatureAndLabel(String[] inputCols, String 
labelCol) {
+            this.inputCols = inputCols;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public void flatMap(Row row, Collector<Tuple3<String, Object, Object>> 
collector) {
+
+            Object label = row.getFieldAs(labelCol);
+
+            for (String colName : inputCols) {
+                Object value = row.getField(colName);
+                collector.collect(new Tuple3<>(colName, value, label));
+            }
+        }
+    }
+
+    /**
+     * Computes a frequency table(DataStream) of the factors(categorical 
values). The returned

Review Comment:
   How about updating the Javadoc as follows:
   ```
   Computes the frequency of each feature value in each column, grouped by label. An output record
   (columnA, featureValueB, labelC, countD) indicates that feature value {featureValueB} with
   label {labelC} in column {columnA} appeared {countD} times in the input table.
   ```



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java:
##########
@@ -0,0 +1,654 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats.chisqtest;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * An AlgoOperator which implements the Chi-square test algorithm.
+ *
+ * <p>Chi-square Test is an AlgoOperator that computes the statistics of 
independence of variables
+ * in a contingency table. This function computes the chi-square statistic, 
p-value, and DOF(number
+ * of degrees of freedom) for every feature in the contingency table. The 
contingency table is
+ * constructed from the observed categorical values.
+ *
+ * <p>See: http://en.wikipedia.org/wiki/Chi-squared_test.
+ */
+public class ChiSqTest implements AlgoOperator<ChiSqTest>, 
ChiSqTestParams<ChiSqTest> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public ChiSqTest() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String bcCategoricalMarginsKey = "bcCategoricalMarginsKey";
+        final String bcLabelMarginsKey = "bcLabelMarginsKey";
+
+        final String[] inputCols = getInputCols();
+        String labelCol = getLabelCol();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
+        SingleOutputStreamOperator<Tuple3<String, Object, Object>> 
colAndFeatureAndLabel =
+                tEnv.toDataStream(inputs[0])
+                        .flatMap(new ExtractColAndFeatureAndLabel(inputCols, 
labelCol));
+
+        DataStream<Tuple4<String, Object, Object, Long>> observedFreq =
+                colAndFeatureAndLabel
+                        .keyBy(Tuple3::hashCode)
+                        .transform(
+                                "GenerateObservedFrequencies",
+                                TypeInformation.of(
+                                        new TypeHint<Tuple4<String, Object, 
Object, Long>>() {}),
+                                new GenerateObservedFrequencies());
+
+        SingleOutputStreamOperator<Tuple4<String, Object, Object, Long>> 
filledObservedFreq =
+                observedFreq
+                        .transform(
+                                "filledObservedFreq",
+                                Types.TUPLE(
+                                        Types.STRING,
+                                        Types.GENERIC(Object.class),
+                                        Types.GENERIC(Object.class),
+                                        Types.LONG),
+                                new FillZeroFunc())
+                        .setParallelism(1);
+
+        DataStream<Tuple3<String, Object, Long>> categoricalMargins =
+                observedFreq
+                        .keyBy(tuple -> new Tuple2<>(tuple.f0, 
tuple.f1).hashCode())
+                        .transform(
+                                "AggregateCategoricalMargins",
+                                TypeInformation.of(new TypeHint<Tuple3<String, 
Object, Long>>() {}),
+                                new AggregateCategoricalMargins());
+
+        DataStream<Tuple3<String, Object, Long>> labelMargins =
+                observedFreq
+                        .keyBy(tuple -> new Tuple2<>(tuple.f0, 
tuple.f2).hashCode())
+                        .transform(
+                                "AggregateLabelMargins",
+                                TypeInformation.of(new TypeHint<Tuple3<String, 
Object, Long>>() {}),
+                                new AggregateLabelMargins());
+
+        Function<List<DataStream<?>>, DataStream<Tuple3<String, Double, 
Integer>>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.map(new ChiSqFunc(bcCategoricalMarginsKey, 
bcLabelMarginsKey));
+                };
+
+        HashMap<String, DataStream<?>> bcMap =
+                new HashMap<String, DataStream<?>>() {
+                    {
+                        put(bcCategoricalMarginsKey, categoricalMargins);
+                        put(bcLabelMarginsKey, labelMargins);
+                    }
+                };
+
+        DataStream<Tuple3<String, Double, Integer>> categoricalStatistics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(filledObservedFreq), bcMap, 
function);
+
+        SingleOutputStreamOperator<Row> chiSqTestResult =
+                categoricalStatistics
+                        .transform(
+                                "chiSqTestResult",
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            Types.STRING, Types.DOUBLE, 
Types.DOUBLE, Types.INT
+                                        },
+                                        new String[] {
+                                            "column", "pValue", "statistic", 
"degreesOfFreedom"
+                                        }),
+                                new AggregateChiSqFunc())
+                        .setParallelism(1);
+
+        return new Table[] {tEnv.fromDataStream(chiSqTestResult)};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static ChiSqTest load(StreamTableEnvironment tEnv, String path) 
throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class ExtractColAndFeatureAndLabel
+            extends RichFlatMapFunction<Row, Tuple3<String, Object, Object>> {
+        private final String[] inputCols;
+        private final String labelCol;
+
+        public ExtractColAndFeatureAndLabel(String[] inputCols, String 
labelCol) {
+            this.inputCols = inputCols;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public void flatMap(Row row, Collector<Tuple3<String, Object, Object>> 
collector) {
+
+            Object label = row.getFieldAs(labelCol);
+
+            for (String colName : inputCols) {
+                Object value = row.getField(colName);
+                collector.collect(new Tuple3<>(colName, value, label));
+            }
+        }
+    }
+
+    /**
+     * Computes a frequency table(DataStream) of the factors(categorical 
values). The returned
+     * DataStream contains the observed frequencies (i.e. number of 
occurrences) in each category.
+     */
+    private static class GenerateObservedFrequencies
+            extends AbstractStreamOperator<Tuple4<String, Object, Object, 
Long>>
+            implements OneInputStreamOperator<
+                            Tuple3<String, Object, Object>, Tuple4<String, 
Object, Object, Long>>,
+                    BoundedOneInput {
+
+        private HashMap<Tuple3<String, Object, Object>, Long> cntMap = new 
HashMap<>();
+        private ListState<HashMap<Tuple3<String, Object, Object>, Long>> 
cntMapState;
+
+        @Override
+        public void endInput() {
+            for (Tuple3<String, Object, Object> key : cntMap.keySet()) {
+                Long count = cntMap.get(key);
+                output.collect(new StreamRecord<>(new Tuple4<>(key.f0, key.f1, 
key.f2, count)));
+            }
+            cntMapState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<String, Object, 
Object>> element) {
+
+            Tuple3<String, Object, Object> colAndCategoryAndLabel = 
element.getValue();
+            cntMap.compute(colAndCategoryAndLabel, (k, v) -> (v == null ? 1 : 
v + 1));
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+            cntMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "cntMapState",
+                                            TypeInformation.of(
+                                                    new TypeHint<
+                                                            HashMap<
+                                                                    
Tuple3<String, Object, Object>,
+                                                                    Long>>() 
{})));
+
+            OperatorStateUtils.getUniqueElement(cntMapState, "cntMapState")
+                    .ifPresent(x -> cntMap = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws 
Exception {
+            super.snapshotState(context);
+            cntMapState.update(Collections.singletonList(cntMap));
+        }
+    }
+
+    /** Fills the factors which frequencies are zero in frequency table. */

Review Comment:
   How about renaming the class `FillZeroFunc` to `FillFrequencyTable` and updating the Javadoc as
   follows:
   ```
   Fills the frequency table by setting the frequency of missing elements (i.e., missing
   combinations of column, featureValue and labelValue) to zero.
   ```



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java:
##########
@@ -0,0 +1,654 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats.chisqtest;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * An AlgoOperator which implements the Chi-square test algorithm.
+ *
+ * <p>Chi-square Test is an AlgoOperator that computes the statistics of 
independence of variables
+ * in a contingency table. This function computes the chi-square statistic, 
p-value, and DOF(number
+ * of degrees of freedom) for every feature in the contingency table. The 
contingency table is
+ * constructed from the observed categorical values.
+ *
+ * <p>See: http://en.wikipedia.org/wiki/Chi-squared_test.
+ */
+public class ChiSqTest implements AlgoOperator<ChiSqTest>, 
ChiSqTestParams<ChiSqTest> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public ChiSqTest() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String bcCategoricalMarginsKey = "bcCategoricalMarginsKey";
+        final String bcLabelMarginsKey = "bcLabelMarginsKey";
+
+        final String[] inputCols = getInputCols();
+        String labelCol = getLabelCol();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
+        SingleOutputStreamOperator<Tuple3<String, Object, Object>> 
colAndFeatureAndLabel =
+                tEnv.toDataStream(inputs[0])
+                        .flatMap(new ExtractColAndFeatureAndLabel(inputCols, 
labelCol));
+
+        DataStream<Tuple4<String, Object, Object, Long>> observedFreq =
+                colAndFeatureAndLabel
+                        .keyBy(Tuple3::hashCode)
+                        .transform(
+                                "GenerateObservedFrequencies",
+                                TypeInformation.of(
+                                        new TypeHint<Tuple4<String, Object, 
Object, Long>>() {}),
+                                new GenerateObservedFrequencies());
+
+        SingleOutputStreamOperator<Tuple4<String, Object, Object, Long>> 
filledObservedFreq =
+                observedFreq
+                        .transform(
+                                "filledObservedFreq",
+                                Types.TUPLE(
+                                        Types.STRING,
+                                        Types.GENERIC(Object.class),
+                                        Types.GENERIC(Object.class),
+                                        Types.LONG),
+                                new FillZeroFunc())
+                        .setParallelism(1);
+
+        DataStream<Tuple3<String, Object, Long>> categoricalMargins =
+                observedFreq
+                        .keyBy(tuple -> new Tuple2<>(tuple.f0, 
tuple.f1).hashCode())
+                        .transform(
+                                "AggregateCategoricalMargins",
+                                TypeInformation.of(new TypeHint<Tuple3<String, 
Object, Long>>() {}),
+                                new AggregateCategoricalMargins());
+
+        DataStream<Tuple3<String, Object, Long>> labelMargins =
+                observedFreq
+                        .keyBy(tuple -> new Tuple2<>(tuple.f0, 
tuple.f2).hashCode())
+                        .transform(
+                                "AggregateLabelMargins",
+                                TypeInformation.of(new TypeHint<Tuple3<String, 
Object, Long>>() {}),
+                                new AggregateLabelMargins());
+
+        Function<List<DataStream<?>>, DataStream<Tuple3<String, Double, 
Integer>>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.map(new ChiSqFunc(bcCategoricalMarginsKey, 
bcLabelMarginsKey));
+                };
+
+        HashMap<String, DataStream<?>> bcMap =
+                new HashMap<String, DataStream<?>>() {
+                    {
+                        put(bcCategoricalMarginsKey, categoricalMargins);
+                        put(bcLabelMarginsKey, labelMargins);
+                    }
+                };
+
+        DataStream<Tuple3<String, Double, Integer>> categoricalStatistics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(filledObservedFreq), bcMap, 
function);
+
+        SingleOutputStreamOperator<Row> chiSqTestResult =
+                categoricalStatistics
+                        .transform(
+                                "chiSqTestResult",
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            Types.STRING, Types.DOUBLE, 
Types.DOUBLE, Types.INT
+                                        },
+                                        new String[] {
+                                            "column", "pValue", "statistic", 
"degreesOfFreedom"
+                                        }),
+                                new AggregateChiSqFunc())
+                        .setParallelism(1);
+
+        return new Table[] {tEnv.fromDataStream(chiSqTestResult)};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static ChiSqTest load(StreamTableEnvironment tEnv, String path) 
throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class ExtractColAndFeatureAndLabel
+            extends RichFlatMapFunction<Row, Tuple3<String, Object, Object>> {
+        private final String[] inputCols;
+        private final String labelCol;
+
+        public ExtractColAndFeatureAndLabel(String[] inputCols, String 
labelCol) {
+            this.inputCols = inputCols;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public void flatMap(Row row, Collector<Tuple3<String, Object, Object>> 
collector) {
+
+            Object label = row.getFieldAs(labelCol);
+
+            for (String colName : inputCols) {
+                Object value = row.getField(colName);
+                collector.collect(new Tuple3<>(colName, value, label));
+            }
+        }
+    }
+
+    /**
+     * Computes a frequency table(DataStream) of the factors(categorical 
values). The returned
+     * DataStream contains the observed frequencies (i.e. number of 
occurrences) in each category.
+     */
+    private static class GenerateObservedFrequencies
+            extends AbstractStreamOperator<Tuple4<String, Object, Object, 
Long>>
+            implements OneInputStreamOperator<
+                            Tuple3<String, Object, Object>, Tuple4<String, 
Object, Object, Long>>,
+                    BoundedOneInput {
+
+        private HashMap<Tuple3<String, Object, Object>, Long> cntMap = new 
HashMap<>();
+        private ListState<HashMap<Tuple3<String, Object, Object>, Long>> 
cntMapState;
+
+        @Override
+        public void endInput() {
+            for (Tuple3<String, Object, Object> key : cntMap.keySet()) {
+                Long count = cntMap.get(key);
+                output.collect(new StreamRecord<>(new Tuple4<>(key.f0, key.f1, 
key.f2, count)));
+            }
+            cntMapState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<String, Object, 
Object>> element) {
+
+            Tuple3<String, Object, Object> colAndCategoryAndLabel = 
element.getValue();
+            cntMap.compute(colAndCategoryAndLabel, (k, v) -> (v == null ? 1 : 
v + 1));
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+            cntMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "cntMapState",
+                                            TypeInformation.of(
+                                                    new TypeHint<
+                                                            HashMap<
+                                                                    
Tuple3<String, Object, Object>,
+                                                                    Long>>() 
{})));
+
+            OperatorStateUtils.getUniqueElement(cntMapState, "cntMapState")
+                    .ifPresent(x -> cntMap = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws 
Exception {
+            super.snapshotState(context);
+            cntMapState.update(Collections.singletonList(cntMap));
+        }
+    }
+
+    /** Fills the factors which frequencies are zero in frequency table. */
+    private static class FillZeroFunc
+            extends AbstractStreamOperator<Tuple4<String, Object, Object, 
Long>>
+            implements OneInputStreamOperator<
+                            Tuple4<String, Object, Object, Long>,
+                            Tuple4<String, Object, Object, Long>>,
+                    BoundedOneInput {
+
+        private HashMap<Tuple2<String, Object>, ArrayList<Tuple2<Object, 
Long>>> valuesMap =
+                new HashMap<>();
+        private HashSet<Object> distinctLabels = new HashSet<>();
+
+        private ListState<HashMap<Tuple2<String, Object>, 
ArrayList<Tuple2<Object, Long>>>>
+                valuesMapState;
+        private ListState<HashSet<Object>> distinctLabelsState;
+
+        @Override
+        public void endInput() {
+
+            for (Map.Entry<Tuple2<String, Object>, ArrayList<Tuple2<Object, 
Long>>> entry :
+                    valuesMap.entrySet()) {
+                ArrayList<Tuple2<Object, Long>> labelAndCountList = 
entry.getValue();
+                Tuple2<String, Object> categoricalKey = entry.getKey();
+
+                List<Object> existingLabels =
+                        labelAndCountList.stream().map(v -> 
v.f0).collect(Collectors.toList());
+
+                for (Object label : distinctLabels) {
+                    if (!existingLabels.contains(label)) {
+                        Tuple2<Object, Long> generatedLabelCount = new 
Tuple2<>(label, 0L);
+                        labelAndCountList.add(generatedLabelCount);
+                    }
+                }
+
+                for (Tuple2<Object, Long> labelAndCount : labelAndCountList) {
+                    output.collect(
+                            new StreamRecord<>(
+                                    new Tuple4<>(
+                                            categoricalKey.f0,
+                                            categoricalKey.f1,
+                                            labelAndCount.f0,
+                                            labelAndCount.f1)));
+                }
+            }
+
+            valuesMapState.clear();
+            distinctLabelsState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple4<String, Object, Object, 
Long>> element) {
+            Tuple4<String, Object, Object, Long> 
colAndCategoryAndLabelAndCount =
+                    element.getValue();
+            Tuple2<String, Object> key =
+                    new Tuple2<>(
+                            colAndCategoryAndLabelAndCount.f0, 
colAndCategoryAndLabelAndCount.f1);
+            Tuple2<Object, Long> labelAndCount =
+                    new Tuple2<>(
+                            colAndCategoryAndLabelAndCount.f2, 
colAndCategoryAndLabelAndCount.f3);
+            ArrayList<Tuple2<Object, Long>> labelAndCountList = 
valuesMap.get(key);
+
+            if (labelAndCountList == null) {
+                ArrayList<Tuple2<Object, Long>> value = new ArrayList<>();
+                value.add(labelAndCount);
+                valuesMap.put(key, value);
+            } else {
+                labelAndCountList.add(labelAndCount);
+            }
+
+            distinctLabels.add(colAndCategoryAndLabelAndCount.f2);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+            valuesMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "valuesMapState",
+                                            TypeInformation.of(
+                                                    new TypeHint<
+                                                            HashMap<
+                                                                    
Tuple2<String, Object>,
+                                                                    ArrayList<
+                                                                            
Tuple2<
+                                                                               
     Object,
+                                                                               
     Long>>>>() {})));
+            distinctLabelsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "distinctLabelsState",
+                                            TypeInformation.of(
+                                                    new 
TypeHint<HashSet<Object>>() {})));
+
+            OperatorStateUtils.getUniqueElement(valuesMapState, 
"valuesMapState")
+                    .ifPresent(x -> valuesMap = x);
+
+            OperatorStateUtils.getUniqueElement(distinctLabelsState, 
"distinctLabelsState")
+                    .ifPresent(x -> distinctLabels = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws 
Exception {
+            super.snapshotState(context);
+            valuesMapState.update(Collections.singletonList(valuesMap));
+            
distinctLabelsState.update(Collections.singletonList(distinctLabels));
+        }
+    }
+
+    /** Returns a DataStream of the marginal sums of the factors. */

Review Comment:
   How about updating the Javadoc to:
   `Computes the marginal sums of different categories.`



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java:
##########
@@ -0,0 +1,654 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats.chisqtest;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * An AlgoOperator which implements the Chi-square test algorithm.
+ *
+ * <p>Chi-square Test is an AlgoOperator that computes the statistics of 
independence of variables
+ * in a contingency table. This function computes the chi-square statistic, 
p-value, and DOF(number
+ * of degrees of freedom) for every feature in the contingency table. The 
contingency table is
+ * constructed from the observed categorical values.
+ *
+ * <p>See: http://en.wikipedia.org/wiki/Chi-squared_test.
+ */
+public class ChiSqTest implements AlgoOperator<ChiSqTest>, 
ChiSqTestParams<ChiSqTest> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public ChiSqTest() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String bcCategoricalMarginsKey = "bcCategoricalMarginsKey";
+        final String bcLabelMarginsKey = "bcLabelMarginsKey";
+
+        final String[] inputCols = getInputCols();
+        String labelCol = getLabelCol();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
+        SingleOutputStreamOperator<Tuple3<String, Object, Object>> 
colAndFeatureAndLabel =
+                tEnv.toDataStream(inputs[0])
+                        .flatMap(new ExtractColAndFeatureAndLabel(inputCols, 
labelCol));
+
+        DataStream<Tuple4<String, Object, Object, Long>> observedFreq =
+                colAndFeatureAndLabel
+                        .keyBy(Tuple3::hashCode)
+                        .transform(
+                                "GenerateObservedFrequencies",
+                                TypeInformation.of(
+                                        new TypeHint<Tuple4<String, Object, 
Object, Long>>() {}),
+                                new GenerateObservedFrequencies());
+
+        SingleOutputStreamOperator<Tuple4<String, Object, Object, Long>> 
filledObservedFreq =
+                observedFreq
+                        .transform(
+                                "filledObservedFreq",
+                                Types.TUPLE(
+                                        Types.STRING,
+                                        Types.GENERIC(Object.class),
+                                        Types.GENERIC(Object.class),
+                                        Types.LONG),
+                                new FillZeroFunc())
+                        .setParallelism(1);
+
+        DataStream<Tuple3<String, Object, Long>> categoricalMargins =
+                observedFreq
+                        .keyBy(tuple -> new Tuple2<>(tuple.f0, 
tuple.f1).hashCode())
+                        .transform(
+                                "AggregateCategoricalMargins",
+                                TypeInformation.of(new TypeHint<Tuple3<String, 
Object, Long>>() {}),
+                                new AggregateCategoricalMargins());
+
+        DataStream<Tuple3<String, Object, Long>> labelMargins =
+                observedFreq
+                        .keyBy(tuple -> new Tuple2<>(tuple.f0, 
tuple.f2).hashCode())
+                        .transform(
+                                "AggregateLabelMargins",
+                                TypeInformation.of(new TypeHint<Tuple3<String, 
Object, Long>>() {}),
+                                new AggregateLabelMargins());
+
+        Function<List<DataStream<?>>, DataStream<Tuple3<String, Double, 
Integer>>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.map(new ChiSqFunc(bcCategoricalMarginsKey, 
bcLabelMarginsKey));
+                };
+
+        HashMap<String, DataStream<?>> bcMap =
+                new HashMap<String, DataStream<?>>() {
+                    {
+                        put(bcCategoricalMarginsKey, categoricalMargins);
+                        put(bcLabelMarginsKey, labelMargins);
+                    }
+                };
+
+        DataStream<Tuple3<String, Double, Integer>> categoricalStatistics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(filledObservedFreq), bcMap, 
function);
+
+        SingleOutputStreamOperator<Row> chiSqTestResult =
+                categoricalStatistics
+                        .transform(
+                                "chiSqTestResult",
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            Types.STRING, Types.DOUBLE, 
Types.DOUBLE, Types.INT
+                                        },
+                                        new String[] {
+                                            "column", "pValue", "statistic", 
"degreesOfFreedom"
+                                        }),
+                                new AggregateChiSqFunc())
+                        .setParallelism(1);
+
+        return new Table[] {tEnv.fromDataStream(chiSqTestResult)};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static ChiSqTest load(StreamTableEnvironment tEnv, String path) 
throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class ExtractColAndFeatureAndLabel
+            extends RichFlatMapFunction<Row, Tuple3<String, Object, Object>> {
+        private final String[] inputCols;
+        private final String labelCol;
+
+        public ExtractColAndFeatureAndLabel(String[] inputCols, String 
labelCol) {
+            this.inputCols = inputCols;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public void flatMap(Row row, Collector<Tuple3<String, Object, Object>> 
collector) {
+
+            Object label = row.getFieldAs(labelCol);
+
+            for (String colName : inputCols) {
+                Object value = row.getField(colName);
+                collector.collect(new Tuple3<>(colName, value, label));
+            }
+        }
+    }
+
+    /**
+     * Computes a frequency table(DataStream) of the factors(categorical 
values). The returned
+     * DataStream contains the observed frequencies (i.e. number of 
occurrences) in each category.
+     */
+    private static class GenerateObservedFrequencies
+            extends AbstractStreamOperator<Tuple4<String, Object, Object, 
Long>>
+            implements OneInputStreamOperator<
+                            Tuple3<String, Object, Object>, Tuple4<String, 
Object, Object, Long>>,
+                    BoundedOneInput {
+
+        private HashMap<Tuple3<String, Object, Object>, Long> cntMap = new 
HashMap<>();
+        private ListState<HashMap<Tuple3<String, Object, Object>, Long>> 
cntMapState;
+
+        @Override
+        public void endInput() {
+            for (Tuple3<String, Object, Object> key : cntMap.keySet()) {
+                Long count = cntMap.get(key);
+                output.collect(new StreamRecord<>(new Tuple4<>(key.f0, key.f1, 
key.f2, count)));
+            }
+            cntMapState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<String, Object, 
Object>> element) {
+
+            Tuple3<String, Object, Object> colAndCategoryAndLabel = 
element.getValue();
+            cntMap.compute(colAndCategoryAndLabel, (k, v) -> (v == null ? 1 : 
v + 1));
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+            cntMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "cntMapState",
+                                            TypeInformation.of(
+                                                    new TypeHint<
+                                                            HashMap<
+                                                                    
Tuple3<String, Object, Object>,
+                                                                    Long>>() 
{})));
+
+            OperatorStateUtils.getUniqueElement(cntMapState, "cntMapState")
+                    .ifPresent(x -> cntMap = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws 
Exception {
+            super.snapshotState(context);
+            cntMapState.update(Collections.singletonList(cntMap));
+        }
+    }
+
+    /** Fills the factors which frequencies are zero in frequency table. */
+    private static class FillZeroFunc
+            extends AbstractStreamOperator<Tuple4<String, Object, Object, 
Long>>
+            implements OneInputStreamOperator<
+                            Tuple4<String, Object, Object, Long>,
+                            Tuple4<String, Object, Object, Long>>,
+                    BoundedOneInput {
+
+        private HashMap<Tuple2<String, Object>, ArrayList<Tuple2<Object, 
Long>>> valuesMap =
+                new HashMap<>();
+        private HashSet<Object> distinctLabels = new HashSet<>();
+
+        private ListState<HashMap<Tuple2<String, Object>, 
ArrayList<Tuple2<Object, Long>>>>
+                valuesMapState;
+        private ListState<HashSet<Object>> distinctLabelsState;
+
+        @Override
+        public void endInput() {
+
+            for (Map.Entry<Tuple2<String, Object>, ArrayList<Tuple2<Object, 
Long>>> entry :
+                    valuesMap.entrySet()) {
+                ArrayList<Tuple2<Object, Long>> labelAndCountList = 
entry.getValue();
+                Tuple2<String, Object> categoricalKey = entry.getKey();
+
+                List<Object> existingLabels =
+                        labelAndCountList.stream().map(v -> 
v.f0).collect(Collectors.toList());
+
+                for (Object label : distinctLabels) {
+                    if (!existingLabels.contains(label)) {
+                        Tuple2<Object, Long> generatedLabelCount = new 
Tuple2<>(label, 0L);
+                        labelAndCountList.add(generatedLabelCount);
+                    }
+                }
+
+                for (Tuple2<Object, Long> labelAndCount : labelAndCountList) {
+                    output.collect(
+                            new StreamRecord<>(
+                                    new Tuple4<>(
+                                            categoricalKey.f0,
+                                            categoricalKey.f1,
+                                            labelAndCount.f0,
+                                            labelAndCount.f1)));
+                }
+            }
+
+            valuesMapState.clear();
+            distinctLabelsState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple4<String, Object, Object, 
Long>> element) {
+            Tuple4<String, Object, Object, Long> 
colAndCategoryAndLabelAndCount =
+                    element.getValue();
+            Tuple2<String, Object> key =
+                    new Tuple2<>(
+                            colAndCategoryAndLabelAndCount.f0, 
colAndCategoryAndLabelAndCount.f1);
+            Tuple2<Object, Long> labelAndCount =
+                    new Tuple2<>(
+                            colAndCategoryAndLabelAndCount.f2, 
colAndCategoryAndLabelAndCount.f3);
+            ArrayList<Tuple2<Object, Long>> labelAndCountList = 
valuesMap.get(key);
+
+            if (labelAndCountList == null) {
+                ArrayList<Tuple2<Object, Long>> value = new ArrayList<>();
+                value.add(labelAndCount);
+                valuesMap.put(key, value);
+            } else {
+                labelAndCountList.add(labelAndCount);
+            }
+
+            distinctLabels.add(colAndCategoryAndLabelAndCount.f2);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+            valuesMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "valuesMapState",
+                                            TypeInformation.of(
+                                                    new TypeHint<
+                                                            HashMap<
+                                                                    
Tuple2<String, Object>,
+                                                                    ArrayList<
+                                                                            
Tuple2<
+                                                                               
     Object,
+                                                                               
     Long>>>>() {})));
+            distinctLabelsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "distinctLabelsState",
+                                            TypeInformation.of(
+                                                    new 
TypeHint<HashSet<Object>>() {})));
+
+            OperatorStateUtils.getUniqueElement(valuesMapState, 
"valuesMapState")
+                    .ifPresent(x -> valuesMap = x);
+
+            OperatorStateUtils.getUniqueElement(distinctLabelsState, 
"distinctLabelsState")
+                    .ifPresent(x -> distinctLabels = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws 
Exception {
+            super.snapshotState(context);
+            valuesMapState.update(Collections.singletonList(valuesMap));
+            
distinctLabelsState.update(Collections.singletonList(distinctLabels));
+        }
+    }
+
+    /** Returns a DataStream of the marginal sums of the factors. */
+    private static class AggregateCategoricalMargins
+            extends AbstractStreamOperator<Tuple3<String, Object, Long>>
+            implements OneInputStreamOperator<
+                            Tuple4<String, Object, Object, Long>, 
Tuple3<String, Object, Long>>,
+                    BoundedOneInput {
+
+        private HashMap<Tuple2<String, Object>, Long> categoricalMarginsMap = 
new HashMap<>();
+
+        private ListState<HashMap<Tuple2<String, Object>, Long>> 
categoricalMarginsMapState;
+
+        @Override
+        public void endInput() {
+            for (Tuple2<String, Object> key : categoricalMarginsMap.keySet()) {
+                Long categoricalMargin = categoricalMarginsMap.get(key);
+                output.collect(new StreamRecord<>(new Tuple3<>(key.f0, key.f1, 
categoricalMargin)));
+            }
+            categoricalMarginsMap.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple4<String, Object, Object, 
Long>> element) {
+
+            Tuple4<String, Object, Object, Long> colAndCategoryAndLabelAndCnt 
= element.getValue();
+            Tuple2<String, Object> key =
+                    new Tuple2<>(colAndCategoryAndLabelAndCnt.f0, 
colAndCategoryAndLabelAndCnt.f1);
+            Long observedFreq = colAndCategoryAndLabelAndCnt.f3;
+            categoricalMarginsMap.compute(
+                    key, (k, v) -> (v == null ? observedFreq : v + 
observedFreq));
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+            categoricalMarginsMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "categoricalMarginsMapState",
+                                            TypeInformation.of(
+                                                    new TypeHint<
+                                                            HashMap<
+                                                                    
Tuple2<String, Object>,
+                                                                    Long>>() 
{})));
+
+            OperatorStateUtils.getUniqueElement(
+                            categoricalMarginsMapState, 
"categoricalMarginsMapState")
+                    .ifPresent(x -> categoricalMarginsMap = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws 
Exception {
+            super.snapshotState(context);
+            
categoricalMarginsMapState.update(Collections.singletonList(categoricalMarginsMap));
+        }
+    }
+
+    /** Returns a DataStream of the marginal sums of the labels. */

Review Comment:
   How about updating the Javadoc to:
   `Computes the marginal sums of different labels.`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to