Hi, I'm trying to calculate stateful counts per key with checkpoints following the example in https://ci.apache.org/projects/flink/flink-docs-release-1.1/apis/streaming/state.html#checkpointing-instance-fields. I would expect my test program to calculate the counts per key, but it seems to group the data by task rather than by key. Is this a Flink bug or have I misunderstood something?
The output of inputData.keyBy(0).flatMap(new TestCounters).print is

1> (A,count=1)
1> (F,count=2)
2> (B,count=1)
2> (C,count=2)
2> (D,count=3)
2> (E,count=4)
2> (E,count=5)
2> (E,count=6)
2> (H,count=7)
4> (G,count=1)

while the output of inputData.keyBy(0).flatMapWithState(...).print is (as I would expect)

2> (B,1)
4> (G,1)
1> (A,1)
2> (C,1)
1> (F,1)
2> (D,1)
2> (E,1)
2> (E,2)
2> (E,3)
2> (H,1)

I would expect both to give the same results.

The full code:

import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.streaming.api.scala._
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.checkpoint.Checkpointed
import org.apache.flink.util.Collector

object FlinkStreamingTest {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.createLocalEnvironment()
    val checkpointIntervalMillis = 10000
    env.enableCheckpointing(checkpointIntervalMillis)

    val inputData = env.fromElements(("A",0),("B",0),("C",0),("D",0),
      ("E",0),("E",0),("E",0),
      ("F",0),("G",0),("H",0))

    inputData.keyBy(0).flatMap(new TestCounters).print

    /*
    inputData.keyBy(0).flatMapWithState((keyAndCount: (String, Int), count: Option[Int]) =>
      count match {
        case None => (Iterator((keyAndCount._1, 1)), Some(1))
        case Some(c) => (Iterator((keyAndCount._1, c+1)), Some(c+1))
      }).print
    */

    env.execute("Counters test")
  }
}

case class CounterClass(var count: Int)

class TestCounters extends RichFlatMapFunction[(String, Int), (String, String)] with Checkpointed[CounterClass] {
  var counterValue: CounterClass = null

  override def flatMap(in: (String, Int), out: Collector[(String, String)]) = {
    counterValue.count = counterValue.count + 1
    out.collect((in._1,"count="+counterValue.count))
  }

  override def open(config: Configuration): Unit = {
    if(counterValue == null) {
      counterValue = new CounterClass(0)
    }
  }

  override def snapshotState(l: Long, l1: Long): CounterClass = {
    counterValue
  }

  override def restoreState(state: CounterClass): Unit = {
    counterValue = state
  }
}
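The difference between the two outputs is consistent with how each kind of state is scoped. The Checkpointed interface checkpoints a field of the function object, and there is one such object per parallel task, so every key routed to that task shares the same field. flatMapWithState, by contrast, keeps its state per key. The plain-Scala simulation below (not Flink code) sketches that difference under stated assumptions: the object and method names (CountingSketch, taskFor, countPerTask, countPerKey) are hypothetical, and keys are assigned to tasks with a simple hashCode modulo, which is not Flink's actual key-to-task assignment, so the task numbers will not match the output above.

```scala
import scala.collection.mutable

// Hypothetical simulation, NOT Flink code: contrasts one counter per parallel
// task (like the Checkpointed instance field) with one counter per key
// (like the state kept by flatMapWithState).
object CountingSketch {
  // Assumption: keys are routed to tasks by hashCode modulo parallelism.
  // This is NOT how Flink assigns keys to tasks; it only stands in for "some
  // deterministic key-to-task mapping".
  def taskFor(key: String, parallelism: Int): Int =
    math.abs(key.hashCode % parallelism)

  // One mutable counter per task, shared by every key routed to that task,
  // mirroring a single counterValue field per function instance.
  def countPerTask(keys: Seq[String], parallelism: Int): Seq[(String, Int)] = {
    val counters = Array.fill(parallelism)(0)
    keys.map { k =>
      val t = taskFor(k, parallelism)
      counters(t) += 1
      (k, counters(t))
    }
  }

  // One counter per key, independent of which task processes the key,
  // mirroring per-key state.
  def countPerKey(keys: Seq[String]): Seq[(String, Int)] = {
    val counters = mutable.Map.empty[String, Int].withDefaultValue(0)
    keys.map { k =>
      counters(k) += 1
      (k, counters(k))
    }
  }
}
```

With parallelism 1, every element lands on the same task, so countPerTask numbers the ten input elements 1 to 10 in arrival order. countPerKey instead yields E at 1, 2, 3 and 1 for every other key, which is the pattern of the flatMapWithState output.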