Github user pnowojski commented on a diff in the pull request: https://github.com/apache/flink/pull/5481#discussion_r169060125 --- Diff: flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/KeyedProcessOperatorTest.java --- @@ -377,119 +375,114 @@ public T getKey(T value) throws Exception { } } - private static class QueryingFlatMapFunction extends ProcessFunction<Integer, String> { + private static class QueryingFlatMapFunction extends KeyedProcessFunction<Integer, Integer, String> { private static final long serialVersionUID = 1L; - private final TimeDomain timeDomain; + private final TimeDomain expectedTimeDomain; public QueryingFlatMapFunction(TimeDomain timeDomain) { - this.timeDomain = timeDomain; + this.expectedTimeDomain = timeDomain; } @Override public void processElement(Integer value, Context ctx, Collector<String> out) throws Exception { - if (timeDomain.equals(TimeDomain.EVENT_TIME)) { + if (expectedTimeDomain.equals(TimeDomain.EVENT_TIME)) { out.collect(value + "TIME:" + ctx.timerService().currentWatermark() + " TS:" + ctx.timestamp()); } else { out.collect(value + "TIME:" + ctx.timerService().currentProcessingTime() + " TS:" + ctx.timestamp()); } } @Override - public void onTimer( - long timestamp, - OnTimerContext ctx, - Collector<String> out) throws Exception { + public void onTimer(long timestamp, OnTimerContext<Integer> ctx, Collector<String> out) throws Exception { + // Do nothing } } - private static class TriggeringFlatMapFunction extends ProcessFunction<Integer, Integer> { + private static class TriggeringFlatMapFunction extends KeyedProcessFunction<Integer, Integer, Integer> { private static final long serialVersionUID = 1L; - private final TimeDomain timeDomain; + private final TimeDomain expectedTimeDomain; + + static final Integer EXPECTED_KEY = 17; public TriggeringFlatMapFunction(TimeDomain timeDomain) { - this.timeDomain = timeDomain; + this.expectedTimeDomain = timeDomain; } @Override public void processElement(Integer value, Context ctx, Collector<Integer> out) throws Exception { out.collect(value); - if (timeDomain.equals(TimeDomain.EVENT_TIME)) { + if (expectedTimeDomain.equals(TimeDomain.EVENT_TIME)) { ctx.timerService().registerEventTimeTimer(ctx.timerService().currentWatermark() + 5); } else { ctx.timerService().registerProcessingTimeTimer(ctx.timerService().currentProcessingTime() + 5); } } @Override - public void onTimer( - long timestamp, - OnTimerContext ctx, - Collector<Integer> out) throws Exception { - - assertEquals(this.timeDomain, ctx.timeDomain()); + public void onTimer(long timestamp, OnTimerContext<Integer> ctx, Collector<Integer> out) throws Exception { + assertEquals(EXPECTED_KEY, ctx.getCurrentKey()); + assertEquals(expectedTimeDomain, ctx.timeDomain()); out.collect(1777); } } - private static class TriggeringStatefulFlatMapFunction extends ProcessFunction<Integer, String> { + private static class TriggeringStatefulFlatMapFunction extends KeyedProcessFunction<Integer, Integer, String> { private static final long serialVersionUID = 1L; private final ValueStateDescriptor<Integer> state = new ValueStateDescriptor<>("seen-element", IntSerializer.INSTANCE); - private final TimeDomain timeDomain; + private final TimeDomain expectedTimeDomain; public TriggeringStatefulFlatMapFunction(TimeDomain timeDomain) { - this.timeDomain = timeDomain; + this.expectedTimeDomain = timeDomain; } @Override public void processElement(Integer value, Context ctx, Collector<String> out) throws Exception { out.collect("INPUT:" + value); getRuntimeContext().getState(state).update(value); - if (timeDomain.equals(TimeDomain.EVENT_TIME)) { + if (expectedTimeDomain.equals(TimeDomain.EVENT_TIME)) { ctx.timerService().registerEventTimeTimer(ctx.timerService().currentWatermark() + 5); } else { ctx.timerService().registerProcessingTimeTimer(ctx.timerService().currentProcessingTime() + 5); } } @Override - public void onTimer( - long timestamp, - OnTimerContext ctx, - Collector<String> out) throws Exception { - assertEquals(this.timeDomain, ctx.timeDomain()); + public void onTimer(long timestamp, OnTimerContext<Integer> ctx, Collector<String> out) throws Exception { + System.out.println(ctx.getCurrentKey()); + assertEquals(expectedTimeDomain, ctx.timeDomain()); out.collect("STATE:" + getRuntimeContext().getState(state).value()); } } - private static class BothTriggeringFlatMapFunction extends ProcessFunction<Integer, String> { + private static class BothTriggeringFlatMapFunction extends KeyedProcessFunction<Integer, Integer, String> { private static final long serialVersionUID = 1L; + static final Integer EXPECTED_KEY = 5; --- End diff -- ditto
---