Github user tillrohrmann commented on a diff in the pull request: https://github.com/apache/flink/pull/1883#discussion_r60906533 --- Diff: flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java --- @@ -781,43 +803,69 @@ public boolean restoreLatestCheckpointedState( long recoveryTimestamp = System.currentTimeMillis(); - if (allOrNothingState) { - Map<ExecutionJobVertex, Integer> stateCounts = new HashMap<ExecutionJobVertex, Integer>(); + for (Map.Entry<JobVertexID, StateForTaskGroup> taskGroupStateEntry: latest.getTaskGroupStates().entrySet()) { + StateForTaskGroup taskGroupState = taskGroupStateEntry.getValue(); + ExecutionJobVertex executionJobVertex = tasks.get(taskGroupStateEntry.getKey()); + + if (executionJobVertex != null) { + // check that we only restore the state if the parallelism has not been changed + if (taskGroupState.getParallelism() != executionJobVertex.getParallelism()) { + throw new RuntimeException("Cannot restore the latest checkpoint because " + + "the parallelism changed. The operator" + executionJobVertex.getJobVertexId() + + " has parallelism " + executionJobVertex.getParallelism() + " whereas the corresponding" + + "state object has a parallelism of " + taskGroupState.getParallelism()); } - for (StateForTask state : latest.getStates()) { - ExecutionJobVertex vertex = tasks.get(state.getOperatorId()); - Execution exec = vertex.getTaskVertices()[state.getSubtask()].getCurrentExecutionAttempt(); - exec.setInitialState(state.getState(), recoveryTimestamp); + int counter = 0; - Integer count = stateCounts.get(vertex); - if (count != null) { - stateCounts.put(vertex, count+1); - } else { - stateCounts.put(vertex, 1); + List<Set<Integer>> keyGroupPartitions = createKeyGroupPartitions(numberKeyGroups, executionJobVertex.getParallelism()); + + for (int i = 0; i < executionJobVertex.getParallelism(); i++) { + StateForTask stateForTask = taskGroupState.getState(i); + SerializedValue<StateHandle<?>> state = null; + + if (stateForTask != null) { + // count the number of executions for which we set a state + counter++; + state = stateForTask.getState(); + } + + Map<Integer, SerializedValue<StateHandle<?>>> kvStateForTaskMap = taskGroupState.getUnwrappedKvStates(keyGroupPartitions.get(i)); + + Execution currentExecutionAttempt = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt(); + currentExecutionAttempt.setInitialState(state, kvStateForTaskMap, recoveryTimestamp); } - } - // validate that either all task vertices have state, or none - for (Map.Entry<ExecutionJobVertex, Integer> entry : stateCounts.entrySet()) { - ExecutionJobVertex vertex = entry.getKey(); - if (entry.getValue() != vertex.getParallelism()) { - throw new IllegalStateException( - "The checkpoint contained state only for a subset of tasks for vertex " + vertex); + if (allOrNothingState && counter > 0 && counter < executionJobVertex.getParallelism()) { + throw new IllegalStateException("The checkpoint contained state only for " + + "a subset of tasks for vertex " + executionJobVertex); } - } - } - else { - for (StateForTask state : latest.getStates()) { - ExecutionJobVertex vertex = tasks.get(state.getOperatorId()); - Execution exec = vertex.getTaskVertices()[state.getSubtask()].getCurrentExecutionAttempt(); - exec.setInitialState(state.getState(), recoveryTimestamp); + } else { + throw new IllegalStateException("There is no execution job vertex for the job" + + " vertex ID " + taskGroupStateEntry.getKey()); } } return true; } } + protected List<Set<Integer>> createKeyGroupPartitions(int numberKeyGroups, int parallelism) { --- End diff -- That is true. Will add it.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. ---