My algorithm is roughly as follows, taking the top-K words problem as an example (the purpose of computing a local “word count” first is to deal with data imbalance):
DataStream of words -> timeWindow of 1h -> converted to DataSet of words -> random partitioning by rebalance -> local “word count” using mapPartition -> global “word count” using reduceGroup -> rebalance -> local top-K using mapPartition -> global top-K using reduceGroup Here is some (probably buggy) code to demonstrate the basic idea on DataSet: import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.GroupReduceFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.util.Collector; import java.util.Map; import java.util.SortedMap; import java.util.TreeMap; public class WordCount { public static void main(String[] args) throws Exception { // set up the execution environment final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); // get input data DataSet<String> text = env.fromElements( "14159265358979323846264338327950288419716939937510", "58209749445923078164062862089986280348253421170679", "82148086513282306647093844609550582231725359408128", "48111745028410270193852110555964462294895493038196", "44288109756659334461284756482337867831652712019091", "45648566923460348610454326648213393607260249141273", "72458700660631558817488152092096282925409171536436", "78925903600113305305488204665213841469519415116094", "33057270365759591953092186117381932611793105118548", "07446237996274956735188575272489122793818301194912", "98336733624406566430860213949463952247371907021798", "60943702770539217176293176752384674818467669405132", "00056812714526356082778577134275778960917363717872", "14684409012249534301465495853710507922796892589235", "42019956112129021960864034418159813629774771309960", "51870721134999999837297804995105973173281609631859", "50244594553469083026425223082533446850352619311881", 
"71010003137838752886587533208381420617177669147303", "59825349042875546873115956286388235378759375195778", "18577805321712268066130019278766111959092164201989" ); DataSet<Tuple2<String, Integer>> counts = text // split up the lines in pairs (2-tuples) containing: (word,1) .flatMap(new LineSplitter()) .rebalance() // local word count .mapPartition(new MapPartitionFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() { @Override public void mapPartition(Iterable<Tuple2<String, Integer>> words, Collector<Tuple2<String, Integer>> out) throws Exception { SortedMap<String, Integer> m = new TreeMap<String, Integer>(); for (Tuple2<String, Integer> w : words) { Integer current = m.get(w.f0); Integer updated = current == null ? w.f1 : current + w.f1; m.put(w.f0, updated); } for (Map.Entry<String, Integer> e : m.entrySet()) { out.collect(Tuple2.of(e.getKey(), e.getValue())); } } }) // global word count .reduceGroup(new GroupReduceFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() { @Override public void reduce(Iterable<Tuple2<String, Integer>> wordcounts, Collector<Tuple2<String, Integer>> out) throws Exception { SortedMap<String, Integer> m = new TreeMap<String, Integer>(); for (Tuple2<String, Integer> wc : wordcounts) { Integer current = m.get(wc.f0); Integer updated = current == null ? 
wc.f1 : current + wc.f1; m.put(wc.f0, updated); } for (Map.Entry<String, Integer> e : m.entrySet()) { out.collect(Tuple2.of(e.getKey(), e.getValue())); } } }); DataSet<Tuple2<String, Integer>> topK = counts .rebalance() // local top-K .mapPartition(new MapPartitionFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() { @Override public void mapPartition(Iterable<Tuple2<String, Integer>> wordcounts, Collector<Tuple2<String, Integer>> out) throws Exception { SortedMap<Integer, String> topKSoFar = new TreeMap<Integer, String>(); for (Tuple2<String, Integer> wc : wordcounts) { String w = wc.f0; Integer c = wc.f1; topKSoFar.put(c, w); if (topKSoFar.size() > 3) { topKSoFar.remove(topKSoFar.firstKey()); } } for (Map.Entry<Integer, String> cw : topKSoFar.entrySet()) { out.collect(Tuple2.of(cw.getValue(), cw.getKey())); } } }) // global top-K .reduceGroup(new GroupReduceFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() { @Override public void reduce(Iterable<Tuple2<String, Integer>> topList, Collector<Tuple2<String, Integer>> out) throws Exception { SortedMap<Integer, String> topKSoFar = new TreeMap<Integer, String>(); for (Tuple2<String, Integer> wc : topList) { String w = wc.f0; Integer c = wc.f1; topKSoFar.put(c, w); if (topKSoFar.size() > 3) { topKSoFar.remove(topKSoFar.firstKey()); } } for (Map.Entry<Integer, String> cw : topKSoFar.entrySet()) { out.collect(Tuple2.of(cw.getValue(), cw.getKey())); } } }); // execute and print result topK.print(); env.setParallelism(4); env.execute(); } public static final class LineSplitter implements FlatMapFunction<String, Tuple2<String, Integer>> { @Override public void flatMap(String value, Collector<Tuple2<String, Integer>> out) { String[] tokens = value.split(""); for (String token : tokens) { if (token.length() > 0) { out.collect(new Tuple2<String, Integer>(token, 1)); } } } } }