pnowojski commented on code in PR #20151: URL: https://github.com/apache/flink/pull/20151#discussion_r1052295272
########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java: ########## @@ -18,280 +18,265 @@ package org.apache.flink.runtime.checkpoint.channel; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult; +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger; -import org.apache.flink.runtime.state.AbstractChannelStateHandle; +import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo; import org.apache.flink.runtime.state.CheckpointStateOutputStream; import org.apache.flink.runtime.state.CheckpointStreamFactory; -import org.apache.flink.runtime.state.InputChannelStateHandle; -import org.apache.flink.runtime.state.ResultSubpartitionStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.util.Preconditions; import org.apache.flink.util.function.RunnableWithException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nonnull; import javax.annotation.concurrent.NotThreadSafe; import java.io.DataOutputStream; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; import java.util.HashMap; -import java.util.List; +import java.util.HashSet; import java.util.Map; -import java.util.Optional; -import java.util.concurrent.CompletableFuture; +import java.util.Objects; +import java.util.Set; -import static java.util.Collections.emptyList; -import static java.util.Collections.singletonList; -import static java.util.UUID.randomUUID; +import static 
org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION; import static org.apache.flink.runtime.state.CheckpointedStateScope.EXCLUSIVE; import static org.apache.flink.util.ExceptionUtils.findThrowable; import static org.apache.flink.util.ExceptionUtils.rethrow; +import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; -/** Writes channel state for a specific checkpoint-subtask-attempt triple. */ +/** Writes channel state for multiple subtasks of the same checkpoint. */ @NotThreadSafe class ChannelStateCheckpointWriter { private static final Logger LOG = LoggerFactory.getLogger(ChannelStateCheckpointWriter.class); private final DataOutputStream dataStream; private final CheckpointStateOutputStream checkpointStream; - private final ChannelStateWriteResult result; - private final Map<InputChannelInfo, StateContentMetaInfo> inputChannelOffsets = new HashMap<>(); - private final Map<ResultSubpartitionInfo, StateContentMetaInfo> resultSubpartitionOffsets = - new HashMap<>(); + + /** + * Indicates whether the current checkpoints of all subtasks have exception. If it's not null, + * the checkpoint will fail. + */ + private Throwable throwable; + private final ChannelStateSerializer serializer; private final long checkpointId; - private boolean allInputsReceived = false; - private boolean allOutputsReceived = false; private final RunnableWithException onComplete; - private final int subtaskIndex; - private final String taskName; + + // Subtasks that have not yet register writer result. 
+ private final Set<SubtaskID> waitedSubtasks; + + private final Map<SubtaskID, ChannelStatePendingResult> pendingResults = new HashMap<>(); ChannelStateCheckpointWriter( - String taskName, - int subtaskIndex, - CheckpointStartRequest startCheckpointItem, + Set<SubtaskID> subtasks, + long checkpointId, CheckpointStreamFactory streamFactory, ChannelStateSerializer serializer, RunnableWithException onComplete) throws Exception { this( - taskName, - subtaskIndex, - startCheckpointItem.getCheckpointId(), - startCheckpointItem.getTargetResult(), + subtasks, + checkpointId, streamFactory.createCheckpointStateOutputStream(EXCLUSIVE), serializer, onComplete); } @VisibleForTesting ChannelStateCheckpointWriter( - String taskName, - int subtaskIndex, + Set<SubtaskID> subtasks, long checkpointId, - ChannelStateWriteResult result, CheckpointStateOutputStream stream, ChannelStateSerializer serializer, RunnableWithException onComplete) { - this( - taskName, - subtaskIndex, - checkpointId, - result, - serializer, - onComplete, - stream, - new DataOutputStream(stream)); + this(subtasks, checkpointId, serializer, onComplete, stream, new DataOutputStream(stream)); } @VisibleForTesting ChannelStateCheckpointWriter( - String taskName, - int subtaskIndex, + Set<SubtaskID> subtasks, long checkpointId, - ChannelStateWriteResult result, ChannelStateSerializer serializer, RunnableWithException onComplete, CheckpointStateOutputStream checkpointStateOutputStream, DataOutputStream dataStream) { - this.taskName = taskName; - this.subtaskIndex = subtaskIndex; + checkArgument(!subtasks.isEmpty(), "The subtasks cannot be empty."); + this.waitedSubtasks = new HashSet<>(subtasks); this.checkpointId = checkpointId; - this.result = checkNotNull(result); this.checkpointStream = checkNotNull(checkpointStateOutputStream); this.serializer = checkNotNull(serializer); this.dataStream = checkNotNull(dataStream); this.onComplete = checkNotNull(onComplete); runWithChecks(() -> 
serializer.writeHeader(dataStream)); } - void writeInput(InputChannelInfo info, Buffer buffer) { - write( - inputChannelOffsets, - info, - buffer, - !allInputsReceived, - "ChannelStateCheckpointWriter#writeInput"); + void registerSubtaskResult( + SubtaskID subtaskID, ChannelStateWriter.ChannelStateWriteResult result) { + // The writer shouldn't register any subtask after writer has exception or is done, + checkState(!isDone(), "The write is done."); + Preconditions.checkState( + !pendingResults.containsKey(subtaskID), + "The subtask %s has already been register before.", + subtaskID); + waitedSubtasks.remove(subtaskID); + + ChannelStatePendingResult pendingResult = + new ChannelStatePendingResult( + subtaskID.getSubtaskIndex(), checkpointId, result, serializer); + pendingResults.put(subtaskID, pendingResult); } - void writeOutput(ResultSubpartitionInfo info, Buffer buffer) { - write( - resultSubpartitionOffsets, - info, - buffer, - !allOutputsReceived, - "ChannelStateCheckpointWriter#writeOutput"); + void releaseSubtask(SubtaskID subtaskID) throws Exception { + if (waitedSubtasks.remove(subtaskID)) { + tryFinishResult(); Review Comment: 👍 Could you maybe explain this in a Javadoc? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org