asfgit closed pull request #7048: [FLINK-10809][state] Include keyed state that 
is not from head operat…
URL: https://github.com/apache/flink/pull/7048
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub does not show its
diff after the merge, so the diff is reproduced below:

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index b0173886d57..02fc2013fb0 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -71,13 +71,12 @@ public StateAssignmentOperation(
                this.allowNonRestoredState = allowNonRestoredState;
        }
 
-       public boolean assignStates() throws Exception {
+       public void assignStates() {
                Map<OperatorID, OperatorState> localOperators = new 
HashMap<>(operatorStates);
-               Map<JobVertexID, ExecutionJobVertex> localTasks = this.tasks;
 
                checkStateMappingCompleteness(allowNonRestoredState, 
operatorStates, tasks);
 
-               for (Map.Entry<JobVertexID, ExecutionJobVertex> task : 
localTasks.entrySet()) {
+               for (Map.Entry<JobVertexID, ExecutionJobVertex> task : 
this.tasks.entrySet()) {
                        final ExecutionJobVertex executionJobVertex = 
task.getValue();
 
                        // find the states of all operators belonging to this 
task
@@ -108,7 +107,6 @@ public boolean assignStates() throws Exception {
                        assignAttemptState(task.getValue(), operatorStates);
                }
 
-               return true;
        }
 
        private void assignAttemptState(ExecutionJobVertex executionJobVertex, 
List<OperatorState> operatorStates) {
@@ -254,10 +252,6 @@ public static OperatorSubtaskState 
operatorSubtaskStateFrom(
                        new 
StateObjectCollection<>(subRawKeyedState.getOrDefault(instanceID, 
Collections.emptyList())));
        }
 
-       private static boolean isHeadOperator(int opIdx, List<OperatorID> 
operatorIDs) {
-               return opIdx == operatorIDs.size() - 1;
-       }
-
        public void checkParallelismPreconditions(List<OperatorState> 
operatorStates, ExecutionJobVertex executionJobVertex) {
                for (OperatorState operatorState : operatorStates) {
                        checkParallelismPreconditions(operatorState, 
executionJobVertex);
@@ -278,19 +272,16 @@ private void reDistributeKeyedStates(
                for (int operatorIndex = 0; operatorIndex < 
newOperatorIDs.size(); operatorIndex++) {
                        OperatorState operatorState = 
oldOperatorStates.get(operatorIndex);
                        int oldParallelism = operatorState.getParallelism();
-
                        for (int subTaskIndex = 0; subTaskIndex < 
newParallelism; subTaskIndex++) {
                                OperatorInstanceID instanceID = 
OperatorInstanceID.of(subTaskIndex, newOperatorIDs.get(operatorIndex));
-                               if (isHeadOperator(operatorIndex, 
newOperatorIDs)) {
-                                       Tuple2<List<KeyedStateHandle>, 
List<KeyedStateHandle>> subKeyedStates = reAssignSubKeyedStates(
-                                               operatorState,
-                                               newKeyGroupPartitions,
-                                               subTaskIndex,
-                                               newParallelism,
-                                               oldParallelism);
-                                       newManagedKeyedState.put(instanceID, 
subKeyedStates.f0);
-                                       newRawKeyedState.put(instanceID, 
subKeyedStates.f1);
-                               }
+                               Tuple2<List<KeyedStateHandle>, 
List<KeyedStateHandle>> subKeyedStates = reAssignSubKeyedStates(
+                                       operatorState,
+                                       newKeyGroupPartitions,
+                                       subTaskIndex,
+                                       newParallelism,
+                                       oldParallelism);
+                               newManagedKeyedState.put(instanceID, 
subKeyedStates.f0);
+                               newRawKeyedState.put(instanceID, 
subKeyedStates.f1);
                        }
                }
        }
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java
index fc8e9971683..c94319ed561 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/datastream/ReinterpretDataStreamAsKeyedStreamITCase.java
@@ -18,11 +18,18 @@
 package org.apache.flink.streaming.api.datastream;
 
 import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
 import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
@@ -73,6 +80,8 @@ public void testReinterpretAsKeyedStream() throws Exception {
                
env.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime);
                env.setMaxParallelism(maxParallelism);
                env.setParallelism(parallelism);
+               env.enableCheckpointing(100);
+               env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 
0L));
 
                final List<File> partitionFiles = new ArrayList<>(parallelism);
                for (int i = 0; i < parallelism; ++i) {
@@ -156,15 +165,22 @@ public void invoke(Tuple2<Integer, Integer> value, 
Context context) throws Excep
                }
        }
 
-       private static class FromPartitionFileSource extends 
RichParallelSourceFunction<Tuple2<Integer, Integer>> {
+       private static class FromPartitionFileSource extends 
RichParallelSourceFunction<Tuple2<Integer, Integer>>
+       implements CheckpointedFunction, CheckpointListener {
                private static final long serialVersionUID = 1L;
 
                private List<File> allPartitions;
                private DataInputStream din;
                private volatile boolean running;
 
-               public FromPartitionFileSource(List<File> allPartitons) {
-                       this.allPartitions = allPartitons;
+               private long position;
+               private transient ListState<Long> positionState;
+               private transient boolean isRestored;
+
+               private transient volatile boolean canFail;
+
+               public FromPartitionFileSource(List<File> allPartitions) {
+                       this.allPartitions = allPartitions;
                }
 
                @Override
@@ -174,6 +190,11 @@ public void open(Configuration parameters) throws 
Exception {
                        din = new DataInputStream(
                                new BufferedInputStream(
                                        new 
FileInputStream(allPartitions.get(subtaskIdx))));
+
+                       long toSkip = position;
+                       while (toSkip > 0L) {
+                               toSkip -= din.skip(toSkip);
+                       }
                }
 
                @Override
@@ -187,11 +208,27 @@ public void run(SourceContext<Tuple2<Integer, Integer>> 
out) throws Exception {
                        this.running = true;
                        try {
                                while (running) {
-                                       Integer key = din.readInt();
-                                       Integer val = din.readInt();
-                                       out.collect(new Tuple2<>(key, val));
+
+                                       checkFail();
+
+                                       synchronized (out.getCheckpointLock()) {
+                                               Integer key = din.readInt();
+                                               Integer val = din.readInt();
+                                               out.collect(new Tuple2<>(key, 
val));
+
+                                               position += 2 * Integer.BYTES;
+                                       }
                                }
                        } catch (EOFException ignore) {
+                               while (!isRestored) {
+                                       checkFail();
+                               }
+                       }
+               }
+
+               private void checkFail() throws Exception {
+                       if (canFail) {
+                               throw new Exception("Artificial failure.");
                        }
                }
 
@@ -199,14 +236,43 @@ public void run(SourceContext<Tuple2<Integer, Integer>> 
out) throws Exception {
                public void cancel() {
                        this.running = false;
                }
+
+               @Override
+               public void notifyCheckpointComplete(long checkpointId) {
+                       canFail = !isRestored;
+               }
+
+               @Override
+               public void snapshotState(FunctionSnapshotContext context) 
throws Exception {
+                       positionState.add(position);
+               }
+
+               @Override
+               public void initializeState(FunctionInitializationContext 
context) throws Exception {
+                       canFail = false;
+                       position = 0L;
+                       isRestored = context.isRestored();
+                       positionState = 
context.getOperatorStateStore().getListState(
+                               new ListStateDescriptor<>("posState", 
Long.class));
+
+                       if (isRestored) {
+
+                               for (Long value : positionState.get()) {
+                                       position += value;
+                               }
+                       }
+               }
        }
 
-       private static class ValidatingSink extends 
RichSinkFunction<Tuple2<Integer, Integer>> {
+       private static class ValidatingSink extends 
RichSinkFunction<Tuple2<Integer, Integer>>
+               implements CheckpointedFunction {
 
                private static final long serialVersionUID = 1L;
                private final int expectedSum;
                private int runningSum = 0;
 
+               private transient ListState<Integer> sumState;
+
                private ValidatingSink(int expectedSum) {
                        this.expectedSum = expectedSum;
                }
@@ -227,5 +293,22 @@ public void close() throws Exception {
                        Assert.assertEquals(expectedSum, runningSum);
                        super.close();
                }
+
+               @Override
+               public void snapshotState(FunctionSnapshotContext context) 
throws Exception {
+                       sumState.add(runningSum);
+               }
+
+               @Override
+               public void initializeState(FunctionInitializationContext 
context) throws Exception {
+                       sumState = context.getOperatorStateStore().getListState(
+                               new ListStateDescriptor<>("sumState", 
Integer.class));
+
+                       if (context.isRestored()) {
+                               for (Integer value : sumState.get()) {
+                                       runningSum += value;
+                               }
+                       }
+               }
        }
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to