[ https://issues.apache.org/jira/browse/FLINK-10809?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16684129#comment-16684129 ]
ASF GitHub Bot commented on FLINK-10809:
----------------------------------------

asfgit closed pull request #7048: [FLINK-10809][state] Include keyed state that is not from head operat…
URL: https://github.com/apache/flink/pull/7048

This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance. As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic):

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index b0173886d57..02fc2013fb0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -71,13 +71,12 @@ public StateAssignmentOperation(
         this.allowNonRestoredState = allowNonRestoredState;
     }

-    public boolean assignStates() throws Exception {
+    public void assignStates() {
         Map<OperatorID, OperatorState> localOperators = new HashMap<>(operatorStates);
-        Map<JobVertexID, ExecutionJobVertex> localTasks = this.tasks;

         checkStateMappingCompleteness(allowNonRestoredState, operatorStates, tasks);

-        for (Map.Entry<JobVertexID, ExecutionJobVertex> task : localTasks.entrySet()) {
+        for (Map.Entry<JobVertexID, ExecutionJobVertex> task : this.tasks.entrySet()) {
             final ExecutionJobVertex executionJobVertex = task.getValue();

             // find the states of all operators belonging to this task
@@ -108,7 +107,6 @@ public boolean assignStates() throws Exception {
             assignAttemptState(task.getValue(), operatorStates);
         }

-        return true;
     }

     private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<OperatorState> operatorStates) {
@@ -254,10 +252,6 @@ public static OperatorSubtaskState operatorSubtaskStateFrom(
             new StateObjectCollection<>(subRawKeyedState.getOrDefault(instanceID, Collections.emptyList())));
     }

-    private static boolean isHeadOperator(int opIdx, List<OperatorID> operatorIDs) {
-        return opIdx == operatorIDs.size() - 1;
-    }
-
     public void checkParallelismPreconditions(List<OperatorState> operatorStates, ExecutionJobVertex executionJobVertex) {
         for (OperatorState operatorState : operatorStates) {
             checkParallelismPreconditions(operatorState, executionJobVertex);
@@ -278,19 +272,16 @@ private void reDistributeKeyedStates(
         for (int operatorIndex = 0; operatorIndex < newOperatorIDs.size(); operatorIndex++) {
             OperatorState operatorState = oldOperatorStates.get(operatorIndex);
             int oldParallelism = operatorState.getParallelism();
-
             for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) {
                 OperatorInstanceID instanceID = OperatorInstanceID.of(subTaskIndex, newOperatorIDs.get(operatorIndex));
-                if (isHeadOperator(operatorIndex, newOperatorIDs)) {
-                    Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> subKeyedStates = reAssignSubKeyedStates(
-                        operatorState,
-                        newKeyGroupPartitions,
-                        subTaskIndex,
-                        newParallelism,
-                        oldParallelism);
-                    newManagedKeyedState.put(instanceID, subKeyedStates.f0);
-                    newRawKeyedState.put(instanceID, subKeyedStates.f1);
-                }
+                Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> subKeyedStates = reAssignSubKeyedStates(
+                    operatorState,
+                    newKeyGroupPartitions,
+                    subTaskIndex,
+                    newParallelism,
+                    oldParallelism);
+                newManagedKeyedState.put(instanceID, subKeyedStates.f0);
+                newRawKeyedState.put(instanceID, subKeyedStates.f1);
             }
         }
     }
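The removed isHeadOperator() guard is the substance of this fix: it restricted reAssignSubKeyedStates() to the operator at the last index of the chain's operator-ID list (the chain's head operator, per that helper's definition), so keyed state held by any other operator in the chain was never re-assigned and was effectively dropped on restore. Without the guard, every operator index goes through the same key-group re-assignment. As a rough illustration of what the newKeyGroupPartitions argument used above contains — a hedged sketch, not code from this PR, with made-up parallelism values:

{code}
// Hedged illustration (not from this PR): how per-subtask key-group ranges such as
// "newKeyGroupPartitions" are typically computed before reAssignSubKeyedStates()
// intersects the old keyed-state handles with them. Parallelism values are made up.
import java.util.ArrayList;
import java.util.List;

import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;

public class KeyGroupPartitionSketch {

    public static void main(String[] args) {
        int maxParallelism = 128;  // total number of key groups
        int newParallelism = 4;    // parallelism after the restore / rescale

        List<KeyGroupRange> newKeyGroupPartitions = new ArrayList<>(newParallelism);
        for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) {
            // Each new subtask owns one contiguous range of key groups.
            newKeyGroupPartitions.add(
                KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
                    maxParallelism, newParallelism, subTaskIndex));
        }

        // With the guard gone, this re-assignment is applied to every operator of the
        // chain, not only to the head operator.
        for (KeyGroupRange range : newKeyGroupPartitions) {
            System.out.println(range.getStartKeyGroup() + " - " + range.getEndKeyGroup());
        }
    }
}
{code}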
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java
index fc8e9971683..c94319ed561 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java
@@ -18,11 +18,18 @@
 package org.apache.flink.streaming.api.datastream;

 import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
 import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
@@ -73,6 +80,8 @@ public void testReinterpretAsKeyedStream() throws Exception {
         env.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime);
         env.setMaxParallelism(maxParallelism);
         env.setParallelism(parallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0L));

         final List<File> partitionFiles = new ArrayList<>(parallelism);
         for (int i = 0; i < parallelism; ++i) {
@@ -156,15 +165,22 @@ public void invoke(Tuple2<Integer, Integer> value, Context context) throws Excep
         }
     }

-    private static class FromPartitionFileSource extends RichParallelSourceFunction<Tuple2<Integer, Integer>> {
+    private static class FromPartitionFileSource extends RichParallelSourceFunction<Tuple2<Integer, Integer>>
+        implements CheckpointedFunction, CheckpointListener {
+
         private static final long serialVersionUID = 1L;

         private List<File> allPartitions;
         private DataInputStream din;
         private volatile boolean running;

-        public FromPartitionFileSource(List<File> allPartitons) {
-            this.allPartitions = allPartitons;
+        private long position;
+        private transient ListState<Long> positionState;
+        private transient boolean isRestored;
+
+        private transient volatile boolean canFail;
+
+        public FromPartitionFileSource(List<File> allPartitions) {
+            this.allPartitions = allPartitions;
         }

         @Override
@@ -174,6 +190,11 @@ public void open(Configuration parameters) throws Exception {
             din = new DataInputStream(
                 new BufferedInputStream(
                     new FileInputStream(allPartitions.get(subtaskIdx))));
+
+            long toSkip = position;
+            while (toSkip > 0L) {
+                toSkip -= din.skip(toSkip);
+            }
         }

         @Override
@@ -187,11 +208,27 @@ public void run(SourceContext<Tuple2<Integer, Integer>> out) throws Exception {
             this.running = true;
             try {
                 while (running) {
-                    Integer key = din.readInt();
-                    Integer val = din.readInt();
-                    out.collect(new Tuple2<>(key, val));
+
+                    checkFail();
+
+                    synchronized (out.getCheckpointLock()) {
+                        Integer key = din.readInt();
+                        Integer val = din.readInt();
+                        out.collect(new Tuple2<>(key, val));
+
+                        position += 2 * Integer.BYTES;
+                    }
                 }
             } catch (EOFException ignore) {
+                while (!isRestored) {
+                    checkFail();
+                }
+            }
+        }
+
+        private void checkFail() throws Exception {
+            if (canFail) {
+                throw new Exception("Artificial failure.");
             }
         }

@@ -199,14 +236,43 @@ public void run(SourceContext<Tuple2<Integer, Integer>> out) throws Exception {
         public void cancel() {
             this.running = false;
         }
+
+        @Override
+        public void notifyCheckpointComplete(long checkpointId) {
+            canFail = !isRestored;
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext context) throws Exception {
+            positionState.add(position);
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            canFail = false;
+            position = 0L;
+            isRestored = context.isRestored();
+            positionState = context.getOperatorStateStore().getListState(
+                new ListStateDescriptor<>("posState", Long.class));
+
+            if (isRestored) {
+
+                for (Long value : positionState.get()) {
+                    position += value;
+                }
+            }
+        }
     }

-    private static class ValidatingSink extends RichSinkFunction<Tuple2<Integer, Integer>> {
+    private static class ValidatingSink extends RichSinkFunction<Tuple2<Integer, Integer>>
+        implements CheckpointedFunction {
+
         private static final long serialVersionUID = 1L;

         private final int expectedSum;

         private int runningSum = 0;

+        private transient ListState<Integer> sumState;
+
         private ValidatingSink(int expectedSum) {
             this.expectedSum = expectedSum;
         }
@@ -227,5 +293,22 @@ public void close() throws Exception {
             Assert.assertEquals(expectedSum, runningSum);
             super.close();
         }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext context) throws Exception {
+            sumState.add(runningSum);
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            sumState = context.getOperatorStateStore().getListState(
+                new ListStateDescriptor<>("sumState", Integer.class));
+
+            if (context.isRestored()) {
+                for (Integer value : sumState.get()) {
+                    runningSum += value;
+                }
+            }
+        }
     }
 }

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
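A note on the mechanics of the test change above: the ITCase now enables checkpointing with a single fixed-delay restart, and the source arms its artificial failure only in notifyCheckpointComplete(), so the job is guaranteed to fail after at least one checkpoint has completed and the restart therefore restores real state. A condensed, hedged sketch of that fail-once-after-a-checkpoint pattern — class and method names here are illustrative, not the ITCase's:

{code}
import org.apache.flink.runtime.state.CheckpointListener;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;

/**
 * Illustrative sketch only: fail exactly once, and only after a completed checkpoint,
 * so the restarted job must restore state. The data-emitting run() loop is left to
 * subclasses, which should call maybeFail() periodically.
 */
public abstract class FailOnceAfterCheckpointSource<T> extends RichParallelSourceFunction<T>
        implements CheckpointedFunction, CheckpointListener {

    private static final long serialVersionUID = 1L;

    private transient volatile boolean canFail;
    private transient boolean isRestored;

    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        canFail = false;
        isRestored = context.isRestored();
    }

    @Override
    public void notifyCheckpointComplete(long checkpointId) {
        // Arm the failure only on the first execution; after the restore the job
        // runs to completion so the sink can validate the restored state.
        canFail = !isRestored;
    }

    /** Called from the subclass read loop, e.g. once per emitted record. */
    protected void maybeFail() throws Exception {
        if (canFail) {
            throw new Exception("Artificial failure after a completed checkpoint.");
        }
    }

    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        // Subclasses snapshot their own reading position here.
    }
}
{code}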
> Using DataStreamUtils.reinterpretAsKeyedStream produces corrupted keyed state after restore
> --------------------------------------------------------------------------------------------
>
>                 Key: FLINK-10809
>                 URL: https://issues.apache.org/jira/browse/FLINK-10809
>             Project: Flink
>          Issue Type: Bug
>          Components: DataStream API, State Backends, Checkpointing
>    Affects Versions: 1.7.0
>            Reporter: Dawid Wysakowicz
>            Assignee: Stefan Richter
>            Priority: Major
>              Labels: pull-request-available
>
> I've tried using {{DataStreamUtils.reinterpretAsKeyedStream}} on the results of a windowed aggregation:
>
{code}
DataStream<Tuple2<Integer, List<Event>>> eventStream4 = eventStream2.keyBy(Event::getKey)
    .window(SlidingEventTimeWindows.of(Time.milliseconds(150 * 3), Time.milliseconds(150)))
    .apply(new WindowFunction<Event, Tuple2<Integer, List<Event>>, Integer, TimeWindow>() {

        private static final long serialVersionUID = 3166250579972849440L;

        @Override
        public void apply(
                Integer key, TimeWindow window,
                Iterable<Event> input,
                Collector<Tuple2<Integer, List<Event>>> out) throws Exception {
            out.collect(Tuple2.of(key, StreamSupport.stream(input.spliterator(), false).collect(Collectors.toList())));
        }
    });

DataStreamUtils.reinterpretAsKeyedStream(eventStream4, events -> events.f0)
    .flatMap(createSlidingWindowCheckMapper(pt))
    .addSink(new PrintSinkFunction<>());
{code}
> and then in the createSlidingWindowCheckMapper I verify that each event belongs to 3 consecutive windows, for which I keep the contents of the last window in ValueState. In a non-failure setup this check runs fine, but it misses a few windows at the beginning, right after a restore.
{code}
public class SlidingWindowCheckMapper extends RichFlatMapFunction<Tuple2<Integer, List<Event>>, String> {

    private static final long serialVersionUID = -744070793650644485L;

    /** This value state tracks previously seen events with the number of windows they appeared in. */
    private transient ValueState<List<Tuple2<Event, Integer>>> previousWindow;

    private final int slideFactor;

    SlidingWindowCheckMapper(int slideFactor) {
        this.slideFactor = slideFactor;
    }

    @Override
    public void open(Configuration parameters) throws Exception {
        ValueStateDescriptor<List<Tuple2<Event, Integer>>> previousWindowDescriptor =
            new ValueStateDescriptor<>("previousWindow",
                new ListTypeInfo<>(new TupleTypeInfo<>(TypeInformation.of(Event.class), BasicTypeInfo.INT_TYPE_INFO)));

        previousWindow = getRuntimeContext().getState(previousWindowDescriptor);
    }

    @Override
    public void flatMap(Tuple2<Integer, List<Event>> value, Collector<String> out) throws Exception {
        List<Tuple2<Event, Integer>> previousWindowValues =
            Optional.ofNullable(previousWindow.value()).orElseGet(Collections::emptyList);

        List<Event> newValues = value.f1;
        newValues.stream().reduce(new BinaryOperator<Event>() {
            @Override
            public Event apply(Event event, Event event2) {
                if (event2.getSequenceNumber() - 1 != event.getSequenceNumber()) {
                    out.collect("Alert: events in window out of order!");
                }
                return event2;
            }
        });

        List<Tuple2<Event, Integer>> newWindow = new ArrayList<>();
        for (Tuple2<Event, Integer> windowValue : previousWindowValues) {
            if (!newValues.contains(windowValue.f0)) {
                out.collect(String.format(
                    "Alert: event %s did not belong to %d consecutive windows. Event seen so far %d times. Current window: %s",
                    windowValue.f0, slideFactor, windowValue.f1, value.f1));
            } else {
                newValues.remove(windowValue.f0);
                if (windowValue.f1 + 1 != slideFactor) {
                    newWindow.add(Tuple2.of(windowValue.f0, windowValue.f1 + 1));
                }
            }
        }
        newValues.forEach(e -> newWindow.add(Tuple2.of(e, 1)));

        previousWindow.update(newWindow);
    }
}
{code}

--
This message was sent by Atlassian JIRA
(v7.6.3#76005)
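Reading the report together with the change above: {{DataStreamUtils.reinterpretAsKeyedStream}} introduces no shuffle, so the flatMap holding the previousWindow ValueState can be chained directly behind the window operator and is then not the head operator of its task; before this fix such non-head keyed state was skipped during re-assignment, which would explain why the check only misses windows right after a restore. A minimal, self-contained job with the same shape (keyed state in a chained, non-head operator) might look like the sketch below — purely illustrative, not the reporter's code, and the element data is made up:

{code}
// Illustrative sketch only: the keyed reduce below holds keyed state but chains behind
// the source, because reinterpretAsKeyedStream() adds no shuffle. It is therefore not
// the head operator of its task, which is the case FLINK-10809 fixes for restores.
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class NonHeadKeyedStateJob {

    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(1);        // single subtask, so the "pre-partitioned" assumption holds trivially
        env.enableCheckpointing(100);

        // Pretend this stream is already partitioned by f0, which is the precondition
        // for reinterpretAsKeyedStream(); with parallelism 1 that is trivially true.
        DataStream<Tuple2<Integer, Integer>> prePartitioned =
            env.fromElements(Tuple2.of(1, 1), Tuple2.of(1, 2), Tuple2.of(2, 3));

        // No keyBy() and hence no shuffle: the keyed reduce chains behind the source
        // and keeps its running aggregate in keyed state as a non-head operator.
        DataStreamUtils.reinterpretAsKeyedStream(prePartitioned, value -> value.f0)
            .reduce((a, b) -> Tuple2.of(a.f0, a.f1 + b.f1))
            .print();

        env.execute("keyed state in a non-head chained operator");
    }
}
{code}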