asfgit closed pull request #7048: [FLINK-10809][state] Include keyed state that is not from head operat…
URL: https://github.com/apache/flink/pull/7048
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index b0173886d57..02fc2013fb0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -71,13 +71,12 @@ public StateAssignmentOperation(
 		this.allowNonRestoredState = allowNonRestoredState;
 	}
 
-	public boolean assignStates() throws Exception {
+	public void assignStates() {
 		Map<OperatorID, OperatorState> localOperators = new HashMap<>(operatorStates);
-		Map<JobVertexID, ExecutionJobVertex> localTasks = this.tasks;
 
 		checkStateMappingCompleteness(allowNonRestoredState, operatorStates, tasks);
 
-		for (Map.Entry<JobVertexID, ExecutionJobVertex> task : localTasks.entrySet()) {
+		for (Map.Entry<JobVertexID, ExecutionJobVertex> task : this.tasks.entrySet()) {
 			final ExecutionJobVertex executionJobVertex = task.getValue();
 
 			// find the states of all operators belonging to this task
@@ -108,7 +107,6 @@ public boolean assignStates() throws Exception {
 
 			assignAttemptState(task.getValue(), operatorStates);
 		}
-		return true;
 	}
 
 	private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<OperatorState> operatorStates) {
@@ -254,10 +252,6 @@ public static OperatorSubtaskState operatorSubtaskStateFrom(
 			new StateObjectCollection<>(subRawKeyedState.getOrDefault(instanceID, Collections.emptyList())));
 	}
 
-	private static boolean isHeadOperator(int opIdx, List<OperatorID> operatorIDs) {
-		return opIdx == operatorIDs.size() - 1;
-	}
-
 	public void checkParallelismPreconditions(List<OperatorState> operatorStates, ExecutionJobVertex executionJobVertex) {
 		for (OperatorState operatorState : operatorStates) {
 			checkParallelismPreconditions(operatorState, executionJobVertex);
@@ -278,19 +272,16 @@ private void reDistributeKeyedStates(
 		for (int operatorIndex = 0; operatorIndex < newOperatorIDs.size(); operatorIndex++) {
 			OperatorState operatorState = oldOperatorStates.get(operatorIndex);
 			int oldParallelism = operatorState.getParallelism();
-
 			for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) {
 				OperatorInstanceID instanceID = OperatorInstanceID.of(subTaskIndex, newOperatorIDs.get(operatorIndex));
-				if (isHeadOperator(operatorIndex, newOperatorIDs)) {
-					Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> subKeyedStates = reAssignSubKeyedStates(
-						operatorState,
-						newKeyGroupPartitions,
-						subTaskIndex,
-						newParallelism,
-						oldParallelism);
-					newManagedKeyedState.put(instanceID, subKeyedStates.f0);
-					newRawKeyedState.put(instanceID, subKeyedStates.f1);
-				}
+				Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> subKeyedStates = reAssignSubKeyedStates(
+					operatorState,
+					newKeyGroupPartitions,
+					subTaskIndex,
+					newParallelism,
+					oldParallelism);
+				newManagedKeyedState.put(instanceID, subKeyedStates.f0);
+				newRawKeyedState.put(instanceID, subKeyedStates.f1);
 			}
 		}
 	}
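For review context: the behavioral change above is the removal of the isHeadOperator guard, so reDistributeKeyedStates now re-assigns managed and raw keyed state for every operator in a chain instead of only the head operator. The following is a minimal sketch (not code from this PR) of the per-subtask key-group assignment that this redistribution builds on, using Flink's KeyGroupRangeAssignment; the parallelism values are made-up examples:

```java
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;

public class KeyGroupAssignmentSketch {

	public static void main(String[] args) {
		int maxParallelism = 128; // hypothetical max parallelism, i.e. the number of key groups
		int newParallelism = 4;   // hypothetical parallelism after rescaling

		// With this PR, the range below is computed and state is assigned for
		// every operator of a chain on every subtask, not just for the head operator.
		for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) {
			KeyGroupRange range = KeyGroupRangeAssignment
				.computeKeyGroupRangeForOperatorIndex(maxParallelism, newParallelism, subTaskIndex);
			System.out.println("subtask " + subTaskIndex + " owns key groups " + range);
		}
	}
}
```

With these example values, each of the 4 subtasks owns 32 consecutive key groups, and restored keyed state handles whose ranges intersect a subtask's range are assigned to it.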
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java
index fc8e9971683..c94319ed561 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java
@@ -18,11 +18,18 @@ package org.apache.flink.streaming.api.datastream;
 
 import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
 import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
@@ -73,6 +80,8 @@ public void testReinterpretAsKeyedStream() throws Exception {
 		env.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime);
 		env.setMaxParallelism(maxParallelism);
 		env.setParallelism(parallelism);
+		env.enableCheckpointing(100);
+		env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0L));
 
 		final List<File> partitionFiles = new ArrayList<>(parallelism);
 		for (int i = 0; i < parallelism; ++i) {
@@ -156,15 +165,22 @@ public void invoke(Tuple2<Integer, Integer> value, Context context) throws Exception {
 		}
 	}
 
-	private static class FromPartitionFileSource extends RichParallelSourceFunction<Tuple2<Integer, Integer>> {
+	private static class FromPartitionFileSource extends RichParallelSourceFunction<Tuple2<Integer, Integer>>
+		implements CheckpointedFunction, CheckpointListener {
 
 		private static final long serialVersionUID = 1L;
 
 		private List<File> allPartitions;
 		private DataInputStream din;
 		private volatile boolean running;
 
-		public FromPartitionFileSource(List<File> allPartitons) {
-			this.allPartitions = allPartitons;
+		private long position;
+		private transient ListState<Long> positionState;
+		private transient boolean isRestored;
+
+		private transient volatile boolean canFail;
+
+		public FromPartitionFileSource(List<File> allPartitions) {
+			this.allPartitions = allPartitions;
 		}
 
 		@Override
@@ -174,6 +190,11 @@ public void open(Configuration parameters) throws Exception {
 			din = new DataInputStream(
 				new BufferedInputStream(
 					new FileInputStream(allPartitions.get(subtaskIdx))));
+
+			long toSkip = position;
+			while (toSkip > 0L) {
+				toSkip -= din.skip(toSkip);
+			}
 		}
 
 		@Override
@@ -187,11 +208,27 @@ public void run(SourceContext<Tuple2<Integer, Integer>> out) throws Exception {
 			this.running = true;
 			try {
 				while (running) {
-					Integer key = din.readInt();
-					Integer val = din.readInt();
-					out.collect(new Tuple2<>(key, val));
+
+					checkFail();
+
+					synchronized (out.getCheckpointLock()) {
+						Integer key = din.readInt();
+						Integer val = din.readInt();
+						out.collect(new Tuple2<>(key, val));
+
+						position += 2 * Integer.BYTES;
+					}
 				}
 			} catch (EOFException ignore) {
+				while (!isRestored) {
+					checkFail();
+				}
+			}
+		}
+
+		private void checkFail() throws Exception {
+			if (canFail) {
+				throw new Exception("Artificial failure.");
 			}
 		}
 
@@ -199,14 +236,43 @@ public void run(SourceContext<Tuple2<Integer, Integer>> out) throws Exception {
 		public void cancel() {
 			this.running = false;
 		}
+
+		@Override
+		public void notifyCheckpointComplete(long checkpointId) {
+			canFail = !isRestored;
+		}
+
+		@Override
+		public void snapshotState(FunctionSnapshotContext context) throws Exception {
+			positionState.add(position);
+		}
+
+		@Override
+		public void initializeState(FunctionInitializationContext context) throws Exception {
+			canFail = false;
+			position = 0L;
+			isRestored = context.isRestored();
+			positionState = context.getOperatorStateStore().getListState(
+				new ListStateDescriptor<>("posState", Long.class));
+
+			if (isRestored) {
+
+				for (Long value : positionState.get()) {
+					position += value;
+				}
+			}
+		}
 	}
 
-	private static class ValidatingSink extends RichSinkFunction<Tuple2<Integer, Integer>> {
+	private static class ValidatingSink extends RichSinkFunction<Tuple2<Integer, Integer>>
+		implements CheckpointedFunction {
 
 		private static final long serialVersionUID = 1L;
 
 		private final int expectedSum;
 		private int runningSum = 0;
 
+		private transient ListState<Integer> sumState;
+
 		private ValidatingSink(int expectedSum) {
 			this.expectedSum = expectedSum;
 		}
@@ -227,5 +293,22 @@ public void close() throws Exception {
 			Assert.assertEquals(expectedSum, runningSum);
 			super.close();
 		}
+
+		@Override
+		public void snapshotState(FunctionSnapshotContext context) throws Exception {
+			sumState.add(runningSum);
+		}
+
+		@Override
+		public void initializeState(FunctionInitializationContext context) throws Exception {
+			sumState = context.getOperatorStateStore().getListState(
+				new ListStateDescriptor<>("sumState", Integer.class));
+
+			if (context.isRestored()) {
+				for (Integer value : sumState.get()) {
+					runningSum += value;
+				}
+			}
+		}
 	}
 }
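For context, here is a sketch of the pipeline shape this IT case exercises. The actual wiring lives in unchanged parts of the test and is not shown in the diff, so everything below is an assumption: it reuses FromPartitionFileSource and ValidatingSink from the diff (which are private inner classes of the test, so this would need to sit inside that class), and partitionFiles/expectedSum are placeholders supplied by the test harness:

```java
import java.io.File;
import java.util.ArrayList;
import java.util.List;

import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class ReinterpretPipelineSketch {

	public static void main(String[] args) throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		// Same fault-injection setup the test enables: frequent checkpoints
		// plus a single restart after the source's artificial failure.
		env.enableCheckpointing(100);
		env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0L));

		List<File> partitionFiles = new ArrayList<>(); // placeholder: written by the test harness
		int expectedSum = 0;                           // placeholder: computed while writing the partitions

		DataStream<Tuple2<Integer, Integer>> prePartitioned =
			env.addSource(new FromPartitionFileSource(partitionFiles));

		// reinterpretAsKeyedStream introduces no shuffle, so the stateful reduce
		// can be chained behind the source. That puts keyed state on a non-head
		// operator, which is the case the removed isHeadOperator check silently
		// dropped on restore.
		DataStreamUtils
			.reinterpretAsKeyedStream(prePartitioned, new KeySelector<Tuple2<Integer, Integer>, Integer>() {
				@Override
				public Integer getKey(Tuple2<Integer, Integer> value) {
					return value.f0;
				}
			})
			.reduce((a, b) -> new Tuple2<>(a.f0, a.f1 + b.f1))
			.addSink(new ValidatingSink(expectedSum));

		env.execute();
	}
}
```

Before this fix, the reduce operator's keyed state would be discarded when restoring from the checkpoint, so the sink's sum check would fail after the injected restart; with the fix, the test passes because the non-head operator gets its key-group ranges back.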
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

With regards,
Apache Git Services