zhipeng93 commented on code in PR #132: URL: https://github.com/apache/flink-ml/pull/132#discussion_r934122580
########## flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java: ########## @@ -0,0 +1,654 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.stats.chisqtest; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import 
org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.math3.distribution.ChiSquaredDistribution; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Chi-square test algorithm. + * + * <p>Chi-square Test is an AlgoOperator that computes the statistics of independence of variables Review Comment: How about removing `is an AlgoOperator` from this line and updating the Javadoc as follows: ``` Chi-square Test computes the statistics of independence of variables in a contingency table, e.g., the chi-square statistic, p-value, and DOF (number of degrees of freedom) for each input feature. The contingency table is constructed from the observed categorical values. ``` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java: ########## @@ -0,0 +1,654 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.stats.chisqtest; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import 
org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.math3.distribution.ChiSquaredDistribution; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Chi-square test algorithm. + * + * <p>Chi-square Test is an AlgoOperator that computes the statistics of independence of variables + * in a contingency table. This function computes the chi-square statistic, p-value, and DOF(number + * of degrees of freedom) for every feature in the contingency table. The contingency table is + * constructed from the observed categorical values. + * + * <p>See: http://en.wikipedia.org/wiki/Chi-squared_test. + */ +public class ChiSqTest implements AlgoOperator<ChiSqTest>, ChiSqTestParams<ChiSqTest> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public ChiSqTest() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + final String bcCategoricalMarginsKey = "bcCategoricalMarginsKey"; + final String bcLabelMarginsKey = "bcLabelMarginsKey"; + + final String[] inputCols = getInputCols(); + String labelCol = getLabelCol(); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple3<String, Object, Object>> colAndFeatureAndLabel = + tEnv.toDataStream(inputs[0]) + .flatMap(new ExtractColAndFeatureAndLabel(inputCols, labelCol)); + + DataStream<Tuple4<String, Object, Object, Long>> observedFreq = + colAndFeatureAndLabel + .keyBy(Tuple3::hashCode) + .transform( + "GenerateObservedFrequencies", + TypeInformation.of( + new TypeHint<Tuple4<String, Object, Object, Long>>() {}), + new GenerateObservedFrequencies()); + + SingleOutputStreamOperator<Tuple4<String, Object, Object, Long>> filledObservedFreq = + observedFreq + .transform( + "filledObservedFreq", + Types.TUPLE( + Types.STRING, + Types.GENERIC(Object.class), + Types.GENERIC(Object.class), + Types.LONG), + new FillZeroFunc()) + .setParallelism(1); + + DataStream<Tuple3<String, Object, Long>> categoricalMargins = + observedFreq + .keyBy(tuple -> new Tuple2<>(tuple.f0, tuple.f1).hashCode()) + .transform( + "AggregateCategoricalMargins", + TypeInformation.of(new TypeHint<Tuple3<String, Object, Long>>() {}), + new AggregateCategoricalMargins()); + + DataStream<Tuple3<String, Object, Long>> labelMargins = + observedFreq + .keyBy(tuple -> new Tuple2<>(tuple.f0, tuple.f2).hashCode()) + .transform( + "AggregateLabelMargins", + TypeInformation.of(new TypeHint<Tuple3<String, Object, Long>>() {}), + new AggregateLabelMargins()); + + Function<List<DataStream<?>>, DataStream<Tuple3<String, Double, Integer>>> function = + dataStreams -> { + DataStream stream = dataStreams.get(0); + return stream.map(new ChiSqFunc(bcCategoricalMarginsKey, bcLabelMarginsKey)); + }; + + HashMap<String, DataStream<?>> 
bcMap = + new HashMap<String, DataStream<?>>() { + { + put(bcCategoricalMarginsKey, categoricalMargins); + put(bcLabelMarginsKey, labelMargins); + } + }; + + DataStream<Tuple3<String, Double, Integer>> categoricalStatistics = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(filledObservedFreq), bcMap, function); + + SingleOutputStreamOperator<Row> chiSqTestResult = + categoricalStatistics + .transform( + "chiSqTestResult", + new RowTypeInfo( + new TypeInformation[] { + Types.STRING, Types.DOUBLE, Types.DOUBLE, Types.INT + }, + new String[] { + "column", "pValue", "statistic", "degreesOfFreedom" + }), + new AggregateChiSqFunc()) + .setParallelism(1); + + return new Table[] {tEnv.fromDataStream(chiSqTestResult)}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static ChiSqTest load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class ExtractColAndFeatureAndLabel + extends RichFlatMapFunction<Row, Tuple3<String, Object, Object>> { + private final String[] inputCols; + private final String labelCol; + + public ExtractColAndFeatureAndLabel(String[] inputCols, String labelCol) { + this.inputCols = inputCols; + this.labelCol = labelCol; + } + + @Override + public void flatMap(Row row, Collector<Tuple3<String, Object, Object>> collector) { + + Object label = row.getFieldAs(labelCol); + + for (String colName : inputCols) { + Object value = row.getField(colName); + collector.collect(new Tuple3<>(colName, value, label)); + } + } + } + + /** + * Computes a frequency table(DataStream) of the factors(categorical values). The returned Review Comment: How about updating the Javadoc as follows: ``` Computes the frequency of each feature value in each column, grouped by label.
An output record (columnA, featureValueB, labelC, countD) represents that the feature value {featureValueB} with label {labelC} in column {columnA} has appeared {countD} times in the input table. ``` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java: ########## @@ -0,0 +1,654 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.stats.chisqtest; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.math3.distribution.ChiSquaredDistribution; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.RoundingMode; +import 
java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Chi-square test algorithm. + * + * <p>Chi-square Test is an AlgoOperator that computes the statistics of independence of variables + * in a contingency table. This function computes the chi-square statistic, p-value, and DOF(number + * of degrees of freedom) for every feature in the contingency table. The contingency table is + * constructed from the observed categorical values. + * + * <p>See: http://en.wikipedia.org/wiki/Chi-squared_test. + */ +public class ChiSqTest implements AlgoOperator<ChiSqTest>, ChiSqTestParams<ChiSqTest> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public ChiSqTest() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + final String bcCategoricalMarginsKey = "bcCategoricalMarginsKey"; + final String bcLabelMarginsKey = "bcLabelMarginsKey"; + + final String[] inputCols = getInputCols(); + String labelCol = getLabelCol(); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple3<String, Object, Object>> colAndFeatureAndLabel = + tEnv.toDataStream(inputs[0]) + .flatMap(new ExtractColAndFeatureAndLabel(inputCols, labelCol)); + + DataStream<Tuple4<String, Object, Object, Long>> observedFreq = + colAndFeatureAndLabel + .keyBy(Tuple3::hashCode) + .transform( + "GenerateObservedFrequencies", + TypeInformation.of( + new TypeHint<Tuple4<String, Object, Object, Long>>() {}), + new GenerateObservedFrequencies()); + + SingleOutputStreamOperator<Tuple4<String, Object, Object, Long>> filledObservedFreq = + observedFreq + .transform( + "filledObservedFreq", + Types.TUPLE( + Types.STRING, + Types.GENERIC(Object.class), + Types.GENERIC(Object.class), + Types.LONG), + new FillZeroFunc()) + .setParallelism(1); + + DataStream<Tuple3<String, Object, Long>> categoricalMargins = + observedFreq + .keyBy(tuple -> new Tuple2<>(tuple.f0, tuple.f1).hashCode()) + .transform( + "AggregateCategoricalMargins", + TypeInformation.of(new TypeHint<Tuple3<String, Object, Long>>() {}), + new AggregateCategoricalMargins()); + + DataStream<Tuple3<String, Object, Long>> labelMargins = + observedFreq + .keyBy(tuple -> new Tuple2<>(tuple.f0, tuple.f2).hashCode()) + .transform( + "AggregateLabelMargins", + TypeInformation.of(new TypeHint<Tuple3<String, Object, Long>>() {}), + new AggregateLabelMargins()); + + Function<List<DataStream<?>>, DataStream<Tuple3<String, Double, Integer>>> function = + dataStreams -> { + DataStream stream = dataStreams.get(0); + return stream.map(new ChiSqFunc(bcCategoricalMarginsKey, bcLabelMarginsKey)); + }; + + HashMap<String, DataStream<?>> 
bcMap = + new HashMap<String, DataStream<?>>() { + { + put(bcCategoricalMarginsKey, categoricalMargins); + put(bcLabelMarginsKey, labelMargins); + } + }; + + DataStream<Tuple3<String, Double, Integer>> categoricalStatistics = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(filledObservedFreq), bcMap, function); + + SingleOutputStreamOperator<Row> chiSqTestResult = + categoricalStatistics + .transform( + "chiSqTestResult", + new RowTypeInfo( + new TypeInformation[] { + Types.STRING, Types.DOUBLE, Types.DOUBLE, Types.INT + }, + new String[] { + "column", "pValue", "statistic", "degreesOfFreedom" + }), + new AggregateChiSqFunc()) + .setParallelism(1); + + return new Table[] {tEnv.fromDataStream(chiSqTestResult)}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static ChiSqTest load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class ExtractColAndFeatureAndLabel + extends RichFlatMapFunction<Row, Tuple3<String, Object, Object>> { + private final String[] inputCols; + private final String labelCol; + + public ExtractColAndFeatureAndLabel(String[] inputCols, String labelCol) { + this.inputCols = inputCols; + this.labelCol = labelCol; + } + + @Override + public void flatMap(Row row, Collector<Tuple3<String, Object, Object>> collector) { + + Object label = row.getFieldAs(labelCol); + + for (String colName : inputCols) { + Object value = row.getField(colName); + collector.collect(new Tuple3<>(colName, value, label)); + } + } + } + + /** + * Computes a frequency table(DataStream) of the factors(categorical values). The returned + * DataStream contains the observed frequencies (i.e. number of occurrences) in each category. 
+ */ + private static class GenerateObservedFrequencies + extends AbstractStreamOperator<Tuple4<String, Object, Object, Long>> + implements OneInputStreamOperator< + Tuple3<String, Object, Object>, Tuple4<String, Object, Object, Long>>, + BoundedOneInput { + + private HashMap<Tuple3<String, Object, Object>, Long> cntMap = new HashMap<>(); + private ListState<HashMap<Tuple3<String, Object, Object>, Long>> cntMapState; + + @Override + public void endInput() { + for (Tuple3<String, Object, Object> key : cntMap.keySet()) { + Long count = cntMap.get(key); + output.collect(new StreamRecord<>(new Tuple4<>(key.f0, key.f1, key.f2, count))); + } + cntMapState.clear(); + } + + @Override + public void processElement(StreamRecord<Tuple3<String, Object, Object>> element) { + + Tuple3<String, Object, Object> colAndCategoryAndLabel = element.getValue(); + cntMap.compute(colAndCategoryAndLabel, (k, v) -> (v == null ? 1 : v + 1)); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + cntMapState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "cntMapState", + TypeInformation.of( + new TypeHint< + HashMap< + Tuple3<String, Object, Object>, + Long>>() {}))); + + OperatorStateUtils.getUniqueElement(cntMapState, "cntMapState") + .ifPresent(x -> cntMap = x); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + cntMapState.update(Collections.singletonList(cntMap)); + } + } + + /** Fills the factors which frequencies are zero in frequency table. */ Review Comment: How about changing the function name to `fillFrequencyTable` and updating the Javadoc as follows: ``` Fills the frequency table by setting the frequency of missing elements (i.e., missing combinations of column, featureValue and labelValue) to zero.
``` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java: ########## @@ -0,0 +1,654 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.stats.chisqtest; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import 
org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.math3.distribution.ChiSquaredDistribution; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Chi-square test algorithm. + * + * <p>Chi-square Test is an AlgoOperator that computes the statistics of independence of variables + * in a contingency table. This function computes the chi-square statistic, p-value, and DOF(number + * of degrees of freedom) for every feature in the contingency table. The contingency table is + * constructed from the observed categorical values. + * + * <p>See: http://en.wikipedia.org/wiki/Chi-squared_test. + */ +public class ChiSqTest implements AlgoOperator<ChiSqTest>, ChiSqTestParams<ChiSqTest> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public ChiSqTest() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + final String bcCategoricalMarginsKey = "bcCategoricalMarginsKey"; + final String bcLabelMarginsKey = "bcLabelMarginsKey"; + + final String[] inputCols = getInputCols(); + String labelCol = getLabelCol(); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple3<String, Object, Object>> colAndFeatureAndLabel = + tEnv.toDataStream(inputs[0]) + .flatMap(new ExtractColAndFeatureAndLabel(inputCols, labelCol)); + + DataStream<Tuple4<String, Object, Object, Long>> observedFreq = + colAndFeatureAndLabel + .keyBy(Tuple3::hashCode) + .transform( + "GenerateObservedFrequencies", + TypeInformation.of( + new TypeHint<Tuple4<String, Object, Object, Long>>() {}), + new GenerateObservedFrequencies()); + + SingleOutputStreamOperator<Tuple4<String, Object, Object, Long>> filledObservedFreq = + observedFreq + .transform( + "filledObservedFreq", + Types.TUPLE( + Types.STRING, + Types.GENERIC(Object.class), + Types.GENERIC(Object.class), + Types.LONG), + new FillZeroFunc()) + .setParallelism(1); + + DataStream<Tuple3<String, Object, Long>> categoricalMargins = + observedFreq + .keyBy(tuple -> new Tuple2<>(tuple.f0, tuple.f1).hashCode()) + .transform( + "AggregateCategoricalMargins", + TypeInformation.of(new TypeHint<Tuple3<String, Object, Long>>() {}), + new AggregateCategoricalMargins()); + + DataStream<Tuple3<String, Object, Long>> labelMargins = + observedFreq + .keyBy(tuple -> new Tuple2<>(tuple.f0, tuple.f2).hashCode()) + .transform( + "AggregateLabelMargins", + TypeInformation.of(new TypeHint<Tuple3<String, Object, Long>>() {}), + new AggregateLabelMargins()); + + Function<List<DataStream<?>>, DataStream<Tuple3<String, Double, Integer>>> function = + dataStreams -> { + DataStream stream = dataStreams.get(0); + return stream.map(new ChiSqFunc(bcCategoricalMarginsKey, bcLabelMarginsKey)); + }; + + HashMap<String, DataStream<?>> 
bcMap = + new HashMap<String, DataStream<?>>() { + { + put(bcCategoricalMarginsKey, categoricalMargins); + put(bcLabelMarginsKey, labelMargins); + } + }; + + DataStream<Tuple3<String, Double, Integer>> categoricalStatistics = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(filledObservedFreq), bcMap, function); + + SingleOutputStreamOperator<Row> chiSqTestResult = + categoricalStatistics + .transform( + "chiSqTestResult", + new RowTypeInfo( + new TypeInformation[] { + Types.STRING, Types.DOUBLE, Types.DOUBLE, Types.INT + }, + new String[] { + "column", "pValue", "statistic", "degreesOfFreedom" + }), + new AggregateChiSqFunc()) + .setParallelism(1); + + return new Table[] {tEnv.fromDataStream(chiSqTestResult)}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static ChiSqTest load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class ExtractColAndFeatureAndLabel + extends RichFlatMapFunction<Row, Tuple3<String, Object, Object>> { + private final String[] inputCols; + private final String labelCol; + + public ExtractColAndFeatureAndLabel(String[] inputCols, String labelCol) { + this.inputCols = inputCols; + this.labelCol = labelCol; + } + + @Override + public void flatMap(Row row, Collector<Tuple3<String, Object, Object>> collector) { + + Object label = row.getFieldAs(labelCol); + + for (String colName : inputCols) { + Object value = row.getField(colName); + collector.collect(new Tuple3<>(colName, value, label)); + } + } + } + + /** + * Computes a frequency table(DataStream) of the factors(categorical values). The returned + * DataStream contains the observed frequencies (i.e. number of occurrences) in each category. 
+ */ + private static class GenerateObservedFrequencies + extends AbstractStreamOperator<Tuple4<String, Object, Object, Long>> + implements OneInputStreamOperator< + Tuple3<String, Object, Object>, Tuple4<String, Object, Object, Long>>, + BoundedOneInput { + + private HashMap<Tuple3<String, Object, Object>, Long> cntMap = new HashMap<>(); + private ListState<HashMap<Tuple3<String, Object, Object>, Long>> cntMapState; + + @Override + public void endInput() { + for (Tuple3<String, Object, Object> key : cntMap.keySet()) { + Long count = cntMap.get(key); + output.collect(new StreamRecord<>(new Tuple4<>(key.f0, key.f1, key.f2, count))); + } + cntMapState.clear(); + } + + @Override + public void processElement(StreamRecord<Tuple3<String, Object, Object>> element) { + + Tuple3<String, Object, Object> colAndCategoryAndLabel = element.getValue(); + cntMap.compute(colAndCategoryAndLabel, (k, v) -> (v == null ? 1 : v + 1)); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + cntMapState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "cntMapState", + TypeInformation.of( + new TypeHint< + HashMap< + Tuple3<String, Object, Object>, + Long>>() {}))); + + OperatorStateUtils.getUniqueElement(cntMapState, "cntMapState") + .ifPresent(x -> cntMap = x); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + cntMapState.update(Collections.singletonList(cntMap)); + } + } + + /** Fills the factors which frequencies are zero in frequency table. 
*/ + private static class FillZeroFunc + extends AbstractStreamOperator<Tuple4<String, Object, Object, Long>> + implements OneInputStreamOperator< + Tuple4<String, Object, Object, Long>, + Tuple4<String, Object, Object, Long>>, + BoundedOneInput { + + private HashMap<Tuple2<String, Object>, ArrayList<Tuple2<Object, Long>>> valuesMap = + new HashMap<>(); + private HashSet<Object> distinctLabels = new HashSet<>(); + + private ListState<HashMap<Tuple2<String, Object>, ArrayList<Tuple2<Object, Long>>>> + valuesMapState; + private ListState<HashSet<Object>> distinctLabelsState; + + @Override + public void endInput() { + + for (Map.Entry<Tuple2<String, Object>, ArrayList<Tuple2<Object, Long>>> entry : + valuesMap.entrySet()) { + ArrayList<Tuple2<Object, Long>> labelAndCountList = entry.getValue(); + Tuple2<String, Object> categoricalKey = entry.getKey(); + + List<Object> existingLabels = + labelAndCountList.stream().map(v -> v.f0).collect(Collectors.toList()); + + for (Object label : distinctLabels) { + if (!existingLabels.contains(label)) { + Tuple2<Object, Long> generatedLabelCount = new Tuple2<>(label, 0L); + labelAndCountList.add(generatedLabelCount); + } + } + + for (Tuple2<Object, Long> labelAndCount : labelAndCountList) { + output.collect( + new StreamRecord<>( + new Tuple4<>( + categoricalKey.f0, + categoricalKey.f1, + labelAndCount.f0, + labelAndCount.f1))); + } + } + + valuesMapState.clear(); + distinctLabelsState.clear(); + } + + @Override + public void processElement(StreamRecord<Tuple4<String, Object, Object, Long>> element) { + Tuple4<String, Object, Object, Long> colAndCategoryAndLabelAndCount = + element.getValue(); + Tuple2<String, Object> key = + new Tuple2<>( + colAndCategoryAndLabelAndCount.f0, colAndCategoryAndLabelAndCount.f1); + Tuple2<Object, Long> labelAndCount = + new Tuple2<>( + colAndCategoryAndLabelAndCount.f2, colAndCategoryAndLabelAndCount.f3); + ArrayList<Tuple2<Object, Long>> labelAndCountList = valuesMap.get(key); + + if 
(labelAndCountList == null) { + ArrayList<Tuple2<Object, Long>> value = new ArrayList<>(); + value.add(labelAndCount); + valuesMap.put(key, value); + } else { + labelAndCountList.add(labelAndCount); + } + + distinctLabels.add(colAndCategoryAndLabelAndCount.f2); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + valuesMapState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "valuesMapState", + TypeInformation.of( + new TypeHint< + HashMap< + Tuple2<String, Object>, + ArrayList< + Tuple2< + Object, + Long>>>>() {}))); + distinctLabelsState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "distinctLabelsState", + TypeInformation.of( + new TypeHint<HashSet<Object>>() {}))); + + OperatorStateUtils.getUniqueElement(valuesMapState, "valuesMapState") + .ifPresent(x -> valuesMap = x); + + OperatorStateUtils.getUniqueElement(distinctLabelsState, "distinctLabelsState") + .ifPresent(x -> distinctLabels = x); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + valuesMapState.update(Collections.singletonList(valuesMap)); + distinctLabelsState.update(Collections.singletonList(distinctLabels)); + } + } + + /** Returns a DataStream of the marginal sums of the factors. */ Review Comment: How about updating the Javadoc to: `Computes the marginal sums of different categories.` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java: ########## @@ -0,0 +1,654 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.stats.chisqtest; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import 
org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.math3.distribution.ChiSquaredDistribution; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * An AlgoOperator which implements the Chi-square test algorithm. + * + * <p>Chi-square Test is an AlgoOperator that computes the statistics of independence of variables + * in a contingency table. This function computes the chi-square statistic, p-value, and DOF(number + * of degrees of freedom) for every feature in the contingency table. The contingency table is + * constructed from the observed categorical values. + * + * <p>See: http://en.wikipedia.org/wiki/Chi-squared_test. + */ +public class ChiSqTest implements AlgoOperator<ChiSqTest>, ChiSqTestParams<ChiSqTest> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public ChiSqTest() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + final String bcCategoricalMarginsKey = "bcCategoricalMarginsKey"; + final String bcLabelMarginsKey = "bcLabelMarginsKey"; + + final String[] inputCols = getInputCols(); + String labelCol = getLabelCol(); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + SingleOutputStreamOperator<Tuple3<String, Object, Object>> colAndFeatureAndLabel = + tEnv.toDataStream(inputs[0]) + .flatMap(new ExtractColAndFeatureAndLabel(inputCols, labelCol)); + + DataStream<Tuple4<String, Object, Object, Long>> observedFreq = + colAndFeatureAndLabel + .keyBy(Tuple3::hashCode) + .transform( + "GenerateObservedFrequencies", + TypeInformation.of( + new TypeHint<Tuple4<String, Object, Object, Long>>() {}), + new GenerateObservedFrequencies()); + + SingleOutputStreamOperator<Tuple4<String, Object, Object, Long>> filledObservedFreq = + observedFreq + .transform( + "filledObservedFreq", + Types.TUPLE( + Types.STRING, + Types.GENERIC(Object.class), + Types.GENERIC(Object.class), + Types.LONG), + new FillZeroFunc()) + .setParallelism(1); + + DataStream<Tuple3<String, Object, Long>> categoricalMargins = + observedFreq + .keyBy(tuple -> new Tuple2<>(tuple.f0, tuple.f1).hashCode()) + .transform( + "AggregateCategoricalMargins", + TypeInformation.of(new TypeHint<Tuple3<String, Object, Long>>() {}), + new AggregateCategoricalMargins()); + + DataStream<Tuple3<String, Object, Long>> labelMargins = + observedFreq + .keyBy(tuple -> new Tuple2<>(tuple.f0, tuple.f2).hashCode()) + .transform( + "AggregateLabelMargins", + TypeInformation.of(new TypeHint<Tuple3<String, Object, Long>>() {}), + new AggregateLabelMargins()); + + Function<List<DataStream<?>>, DataStream<Tuple3<String, Double, Integer>>> function = + dataStreams -> { + DataStream stream = dataStreams.get(0); + return stream.map(new ChiSqFunc(bcCategoricalMarginsKey, bcLabelMarginsKey)); + }; + + HashMap<String, DataStream<?>> 
bcMap = + new HashMap<String, DataStream<?>>() { + { + put(bcCategoricalMarginsKey, categoricalMargins); + put(bcLabelMarginsKey, labelMargins); + } + }; + + DataStream<Tuple3<String, Double, Integer>> categoricalStatistics = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(filledObservedFreq), bcMap, function); + + SingleOutputStreamOperator<Row> chiSqTestResult = + categoricalStatistics + .transform( + "chiSqTestResult", + new RowTypeInfo( + new TypeInformation[] { + Types.STRING, Types.DOUBLE, Types.DOUBLE, Types.INT + }, + new String[] { + "column", "pValue", "statistic", "degreesOfFreedom" + }), + new AggregateChiSqFunc()) + .setParallelism(1); + + return new Table[] {tEnv.fromDataStream(chiSqTestResult)}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static ChiSqTest load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class ExtractColAndFeatureAndLabel + extends RichFlatMapFunction<Row, Tuple3<String, Object, Object>> { + private final String[] inputCols; + private final String labelCol; + + public ExtractColAndFeatureAndLabel(String[] inputCols, String labelCol) { + this.inputCols = inputCols; + this.labelCol = labelCol; + } + + @Override + public void flatMap(Row row, Collector<Tuple3<String, Object, Object>> collector) { + + Object label = row.getFieldAs(labelCol); + + for (String colName : inputCols) { + Object value = row.getField(colName); + collector.collect(new Tuple3<>(colName, value, label)); + } + } + } + + /** + * Computes a frequency table(DataStream) of the factors(categorical values). The returned + * DataStream contains the observed frequencies (i.e. number of occurrences) in each category. 
+ */ + private static class GenerateObservedFrequencies + extends AbstractStreamOperator<Tuple4<String, Object, Object, Long>> + implements OneInputStreamOperator< + Tuple3<String, Object, Object>, Tuple4<String, Object, Object, Long>>, + BoundedOneInput { + + private HashMap<Tuple3<String, Object, Object>, Long> cntMap = new HashMap<>(); + private ListState<HashMap<Tuple3<String, Object, Object>, Long>> cntMapState; + + @Override + public void endInput() { + for (Tuple3<String, Object, Object> key : cntMap.keySet()) { + Long count = cntMap.get(key); + output.collect(new StreamRecord<>(new Tuple4<>(key.f0, key.f1, key.f2, count))); + } + cntMapState.clear(); + } + + @Override + public void processElement(StreamRecord<Tuple3<String, Object, Object>> element) { + + Tuple3<String, Object, Object> colAndCategoryAndLabel = element.getValue(); + cntMap.compute(colAndCategoryAndLabel, (k, v) -> (v == null ? 1 : v + 1)); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + cntMapState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "cntMapState", + TypeInformation.of( + new TypeHint< + HashMap< + Tuple3<String, Object, Object>, + Long>>() {}))); + + OperatorStateUtils.getUniqueElement(cntMapState, "cntMapState") + .ifPresent(x -> cntMap = x); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + cntMapState.update(Collections.singletonList(cntMap)); + } + } + + /** Fills the factors which frequencies are zero in frequency table. 
*/ + private static class FillZeroFunc + extends AbstractStreamOperator<Tuple4<String, Object, Object, Long>> + implements OneInputStreamOperator< + Tuple4<String, Object, Object, Long>, + Tuple4<String, Object, Object, Long>>, + BoundedOneInput { + + private HashMap<Tuple2<String, Object>, ArrayList<Tuple2<Object, Long>>> valuesMap = + new HashMap<>(); + private HashSet<Object> distinctLabels = new HashSet<>(); + + private ListState<HashMap<Tuple2<String, Object>, ArrayList<Tuple2<Object, Long>>>> + valuesMapState; + private ListState<HashSet<Object>> distinctLabelsState; + + @Override + public void endInput() { + + for (Map.Entry<Tuple2<String, Object>, ArrayList<Tuple2<Object, Long>>> entry : + valuesMap.entrySet()) { + ArrayList<Tuple2<Object, Long>> labelAndCountList = entry.getValue(); + Tuple2<String, Object> categoricalKey = entry.getKey(); + + List<Object> existingLabels = + labelAndCountList.stream().map(v -> v.f0).collect(Collectors.toList()); + + for (Object label : distinctLabels) { + if (!existingLabels.contains(label)) { + Tuple2<Object, Long> generatedLabelCount = new Tuple2<>(label, 0L); + labelAndCountList.add(generatedLabelCount); + } + } + + for (Tuple2<Object, Long> labelAndCount : labelAndCountList) { + output.collect( + new StreamRecord<>( + new Tuple4<>( + categoricalKey.f0, + categoricalKey.f1, + labelAndCount.f0, + labelAndCount.f1))); + } + } + + valuesMapState.clear(); + distinctLabelsState.clear(); + } + + @Override + public void processElement(StreamRecord<Tuple4<String, Object, Object, Long>> element) { + Tuple4<String, Object, Object, Long> colAndCategoryAndLabelAndCount = + element.getValue(); + Tuple2<String, Object> key = + new Tuple2<>( + colAndCategoryAndLabelAndCount.f0, colAndCategoryAndLabelAndCount.f1); + Tuple2<Object, Long> labelAndCount = + new Tuple2<>( + colAndCategoryAndLabelAndCount.f2, colAndCategoryAndLabelAndCount.f3); + ArrayList<Tuple2<Object, Long>> labelAndCountList = valuesMap.get(key); + + if 
(labelAndCountList == null) { + ArrayList<Tuple2<Object, Long>> value = new ArrayList<>(); + value.add(labelAndCount); + valuesMap.put(key, value); + } else { + labelAndCountList.add(labelAndCount); + } + + distinctLabels.add(colAndCategoryAndLabelAndCount.f2); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + valuesMapState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "valuesMapState", + TypeInformation.of( + new TypeHint< + HashMap< + Tuple2<String, Object>, + ArrayList< + Tuple2< + Object, + Long>>>>() {}))); + distinctLabelsState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "distinctLabelsState", + TypeInformation.of( + new TypeHint<HashSet<Object>>() {}))); + + OperatorStateUtils.getUniqueElement(valuesMapState, "valuesMapState") + .ifPresent(x -> valuesMap = x); + + OperatorStateUtils.getUniqueElement(distinctLabelsState, "distinctLabelsState") + .ifPresent(x -> distinctLabels = x); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + valuesMapState.update(Collections.singletonList(valuesMap)); + distinctLabelsState.update(Collections.singletonList(distinctLabels)); + } + } + + /** Returns a DataStream of the marginal sums of the factors. 
*/ + private static class AggregateCategoricalMargins + extends AbstractStreamOperator<Tuple3<String, Object, Long>> + implements OneInputStreamOperator< + Tuple4<String, Object, Object, Long>, Tuple3<String, Object, Long>>, + BoundedOneInput { + + private HashMap<Tuple2<String, Object>, Long> categoricalMarginsMap = new HashMap<>(); + + private ListState<HashMap<Tuple2<String, Object>, Long>> categoricalMarginsMapState; + + @Override + public void endInput() { + for (Tuple2<String, Object> key : categoricalMarginsMap.keySet()) { + Long categoricalMargin = categoricalMarginsMap.get(key); + output.collect(new StreamRecord<>(new Tuple3<>(key.f0, key.f1, categoricalMargin))); + } + categoricalMarginsMap.clear(); + } + + @Override + public void processElement(StreamRecord<Tuple4<String, Object, Object, Long>> element) { + + Tuple4<String, Object, Object, Long> colAndCategoryAndLabelAndCnt = element.getValue(); + Tuple2<String, Object> key = + new Tuple2<>(colAndCategoryAndLabelAndCnt.f0, colAndCategoryAndLabelAndCnt.f1); + Long observedFreq = colAndCategoryAndLabelAndCnt.f3; + categoricalMarginsMap.compute( + key, (k, v) -> (v == null ? observedFreq : v + observedFreq)); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + categoricalMarginsMapState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "categoricalMarginsMapState", + TypeInformation.of( + new TypeHint< + HashMap< + Tuple2<String, Object>, + Long>>() {}))); + + OperatorStateUtils.getUniqueElement( + categoricalMarginsMapState, "categoricalMarginsMapState") + .ifPresent(x -> categoricalMarginsMap = x); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + categoricalMarginsMapState.update(Collections.singletonList(categoricalMarginsMap)); + } + } + + /** Returns a DataStream of the marginal sums of the labels. 
*/ Review Comment: How about updating the Javadoc to: `Computes the marginal sums of different labels.` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org