pnowojski commented on code in PR #20151: URL: https://github.com/apache/flink/pull/20151#discussion_r1048286114
########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java: ########## @@ -18,280 +18,265 @@ package org.apache.flink.runtime.checkpoint.channel; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult; +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger; -import org.apache.flink.runtime.state.AbstractChannelStateHandle; +import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo; import org.apache.flink.runtime.state.CheckpointStateOutputStream; import org.apache.flink.runtime.state.CheckpointStreamFactory; -import org.apache.flink.runtime.state.InputChannelStateHandle; -import org.apache.flink.runtime.state.ResultSubpartitionStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.util.Preconditions; import org.apache.flink.util.function.RunnableWithException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nonnull; import javax.annotation.concurrent.NotThreadSafe; import java.io.DataOutputStream; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; import java.util.HashMap; -import java.util.List; +import java.util.HashSet; import java.util.Map; -import java.util.Optional; -import java.util.concurrent.CompletableFuture; +import java.util.Objects; +import java.util.Set; -import static java.util.Collections.emptyList; -import static java.util.Collections.singletonList; -import static java.util.UUID.randomUUID; +import static 
org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION; import static org.apache.flink.runtime.state.CheckpointedStateScope.EXCLUSIVE; import static org.apache.flink.util.ExceptionUtils.findThrowable; import static org.apache.flink.util.ExceptionUtils.rethrow; +import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; -/** Writes channel state for a specific checkpoint-subtask-attempt triple. */ +/** Writes channel state for multiple subtasks of the same checkpoint. */ @NotThreadSafe class ChannelStateCheckpointWriter { private static final Logger LOG = LoggerFactory.getLogger(ChannelStateCheckpointWriter.class); private final DataOutputStream dataStream; private final CheckpointStateOutputStream checkpointStream; - private final ChannelStateWriteResult result; - private final Map<InputChannelInfo, StateContentMetaInfo> inputChannelOffsets = new HashMap<>(); - private final Map<ResultSubpartitionInfo, StateContentMetaInfo> resultSubpartitionOffsets = - new HashMap<>(); + + /** + * Indicates whether the current checkpoints of all subtasks have exception. If it's not null, + * the checkpoint will fail. + */ + private Throwable throwable; + private final ChannelStateSerializer serializer; private final long checkpointId; - private boolean allInputsReceived = false; - private boolean allOutputsReceived = false; private final RunnableWithException onComplete; - private final int subtaskIndex; - private final String taskName; + + // Subtasks that have not yet register writer result. 
+ private final Set<SubtaskID> waitedSubtasks; + + private final Map<SubtaskID, ChannelStatePendingResult> pendingResults = new HashMap<>(); ChannelStateCheckpointWriter( - String taskName, - int subtaskIndex, - CheckpointStartRequest startCheckpointItem, + Set<SubtaskID> subtasks, + long checkpointId, CheckpointStreamFactory streamFactory, ChannelStateSerializer serializer, RunnableWithException onComplete) throws Exception { this( - taskName, - subtaskIndex, - startCheckpointItem.getCheckpointId(), - startCheckpointItem.getTargetResult(), + subtasks, + checkpointId, streamFactory.createCheckpointStateOutputStream(EXCLUSIVE), serializer, onComplete); } @VisibleForTesting ChannelStateCheckpointWriter( - String taskName, - int subtaskIndex, + Set<SubtaskID> subtasks, long checkpointId, - ChannelStateWriteResult result, CheckpointStateOutputStream stream, ChannelStateSerializer serializer, RunnableWithException onComplete) { - this( - taskName, - subtaskIndex, - checkpointId, - result, - serializer, - onComplete, - stream, - new DataOutputStream(stream)); + this(subtasks, checkpointId, serializer, onComplete, stream, new DataOutputStream(stream)); } @VisibleForTesting ChannelStateCheckpointWriter( - String taskName, - int subtaskIndex, + Set<SubtaskID> subtasks, long checkpointId, - ChannelStateWriteResult result, ChannelStateSerializer serializer, RunnableWithException onComplete, CheckpointStateOutputStream checkpointStateOutputStream, DataOutputStream dataStream) { - this.taskName = taskName; - this.subtaskIndex = subtaskIndex; + checkArgument(!subtasks.isEmpty(), "The subtasks cannot be empty."); + this.waitedSubtasks = new HashSet<>(subtasks); Review Comment: `waitedSubtasks` -> `subtasksToRegister`? ########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStatePendingResult.java: ########## @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.runtime.state.AbstractChannelStateHandle; +import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo; +import org.apache.flink.runtime.state.InputChannelStateHandle; +import org.apache.flink.runtime.state.ResultSubpartitionStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static java.util.UUID.randomUUID; +import static org.apache.flink.util.Preconditions.checkArgument; + +/** The pending result of channel state for a specific checkpoint-subtask. 
*/ +public class ChannelStatePendingResult { + + private static final Logger LOG = LoggerFactory.getLogger(ChannelStatePendingResult.class); + + // Subtask information + private final int subtaskIndex; + + private final long checkpointId; + + // Result related + private final ChannelStateSerializer serializer; + private final ChannelStateWriter.ChannelStateWriteResult result; + private final Map<InputChannelInfo, AbstractChannelStateHandle.StateContentMetaInfo> + inputChannelOffsets = new HashMap<>(); + private final Map<ResultSubpartitionInfo, AbstractChannelStateHandle.StateContentMetaInfo> + resultSubpartitionOffsets = new HashMap<>(); + private boolean allInputsReceived = false; + private boolean allOutputsReceived = false; + + public ChannelStatePendingResult( + int subtaskIndex, + long checkpointId, + ChannelStateWriter.ChannelStateWriteResult result, + ChannelStateSerializer serializer) { + this.subtaskIndex = subtaskIndex; + this.checkpointId = checkpointId; + this.result = result; + this.serializer = serializer; + } + + public boolean isAllInputsReceived() { + return allInputsReceived; + } + + public boolean isAllOutputsReceived() { + return allOutputsReceived; + } + + public Map<InputChannelInfo, StateContentMetaInfo> getInputChannelOffsets() { + return inputChannelOffsets; + } + + public Map<ResultSubpartitionInfo, StateContentMetaInfo> getResultSubpartitionOffsets() { + return resultSubpartitionOffsets; + } + + void completeInput() { + LOG.debug("complete input, output completed: {}", allOutputsReceived); + checkArgument(!allInputsReceived); + allInputsReceived = true; + } + + void completeOutput() { + LOG.debug("complete output, input completed: {}", allInputsReceived); + checkArgument(!allOutputsReceived); + allOutputsReceived = true; + } + + public void finishResult(StreamStateHandle stateHandle) throws IOException { + if (inputChannelOffsets.isEmpty() && resultSubpartitionOffsets.isEmpty()) { + result.inputChannelStateHandles.complete(emptyList()); 
+ result.resultSubpartitionStateHandles.complete(emptyList()); + return; + } Review Comment: Does it make sense to optimise this code case? I think without this if check, the `complete()` calls below would still work the same way? ########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java: ########## @@ -315,18 +325,73 @@ public void fail(Throwable e) { } } - private interface HandleFactory<I, H extends AbstractChannelStateHandle<I>> { - H create( - int subtaskIndex, - I info, - StreamStateHandle underlying, - List<Long> offsets, - long size); + @Nonnull + private ChannelStatePendingResult getChannelStatePendingResult( + JobID jobID, JobVertexID jobVertexID, int subtaskIndex) { + SubtaskID subtaskID = SubtaskID.of(jobID, jobVertexID, subtaskIndex); + ChannelStatePendingResult pendingResult = pendingResults.get(subtaskID); + checkNotNull(pendingResult, "The subtask[%s] is not registered yet", subtaskID); + return pendingResult; + } +} + +/** A identification for subtask. */ +class SubtaskID { + + private final JobID jobID; Review Comment: Why do we even need the `JobID` here? All of the write requests should be coming from the same job. And if we really need it, why do we have to pass `JobID` for every `writeInput`/`writeOutput` call? Should it be passed through the constructor? 
Now it suggests that a single writer has to handle writes from different jobs, which I hope is not the case 😅 ########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java: ########## @@ -90,33 +103,62 @@ public void dispatch(ChannelStateWriteRequest request) throws Exception { } private void dispatchInternal(ChannelStateWriteRequest request) throws Exception { + if (request instanceof SubtaskRegisterRequest) { + SubtaskRegisterRequest req = (SubtaskRegisterRequest) request; + SubtaskID subtaskID = + SubtaskID.of(req.getJobID(), req.getJobVertexID(), req.getSubtaskIndex()); + subtasks.add(subtaskID); Review Comment: `subtasks` -> `registeredSubtasks`? ########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java: ########## @@ -51,27 +66,40 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx private final Thread thread; private volatile Exception thrown = null; private volatile boolean wasClosed = false; - private final String taskName; + + private final Map<SubtaskID, Queue<ChannelStateWriteRequest>> unreadyQueues = + new ConcurrentHashMap<>(); + + private final JobID jobID; + private final Set<SubtaskID> subtasks; + private final AtomicBoolean isRegistering = new AtomicBoolean(true); + private final int numberOfSubtasksShareFile; Review Comment: Haven't we discussed somewhere that this (and the `dequeue` above) could be replaced with non-thread-safe versions and a single explicit lock, instead of having many different thread-safe primitives used? Or was it for a different class 🤔? 
########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java: ########## @@ -18,280 +18,265 @@ package org.apache.flink.runtime.checkpoint.channel; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult; +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger; -import org.apache.flink.runtime.state.AbstractChannelStateHandle; +import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo; import org.apache.flink.runtime.state.CheckpointStateOutputStream; import org.apache.flink.runtime.state.CheckpointStreamFactory; -import org.apache.flink.runtime.state.InputChannelStateHandle; -import org.apache.flink.runtime.state.ResultSubpartitionStateHandle; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.util.Preconditions; import org.apache.flink.util.function.RunnableWithException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nonnull; import javax.annotation.concurrent.NotThreadSafe; import java.io.DataOutputStream; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; import java.util.HashMap; -import java.util.List; +import java.util.HashSet; import java.util.Map; -import java.util.Optional; -import java.util.concurrent.CompletableFuture; +import java.util.Objects; +import java.util.Set; -import static java.util.Collections.emptyList; -import static java.util.Collections.singletonList; -import static java.util.UUID.randomUUID; +import static 
org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION; import static org.apache.flink.runtime.state.CheckpointedStateScope.EXCLUSIVE; import static org.apache.flink.util.ExceptionUtils.findThrowable; import static org.apache.flink.util.ExceptionUtils.rethrow; +import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; -/** Writes channel state for a specific checkpoint-subtask-attempt triple. */ +/** Writes channel state for multiple subtasks of the same checkpoint. */ @NotThreadSafe class ChannelStateCheckpointWriter { private static final Logger LOG = LoggerFactory.getLogger(ChannelStateCheckpointWriter.class); private final DataOutputStream dataStream; private final CheckpointStateOutputStream checkpointStream; - private final ChannelStateWriteResult result; - private final Map<InputChannelInfo, StateContentMetaInfo> inputChannelOffsets = new HashMap<>(); - private final Map<ResultSubpartitionInfo, StateContentMetaInfo> resultSubpartitionOffsets = - new HashMap<>(); + + /** + * Indicates whether the current checkpoints of all subtasks have exception. If it's not null, + * the checkpoint will fail. + */ + private Throwable throwable; + private final ChannelStateSerializer serializer; private final long checkpointId; - private boolean allInputsReceived = false; - private boolean allOutputsReceived = false; private final RunnableWithException onComplete; - private final int subtaskIndex; - private final String taskName; + + // Subtasks that have not yet register writer result. 
+ private final Set<SubtaskID> waitedSubtasks; + + private final Map<SubtaskID, ChannelStatePendingResult> pendingResults = new HashMap<>(); ChannelStateCheckpointWriter( - String taskName, - int subtaskIndex, - CheckpointStartRequest startCheckpointItem, + Set<SubtaskID> subtasks, + long checkpointId, CheckpointStreamFactory streamFactory, ChannelStateSerializer serializer, RunnableWithException onComplete) throws Exception { this( - taskName, - subtaskIndex, - startCheckpointItem.getCheckpointId(), - startCheckpointItem.getTargetResult(), + subtasks, + checkpointId, streamFactory.createCheckpointStateOutputStream(EXCLUSIVE), serializer, onComplete); } @VisibleForTesting ChannelStateCheckpointWriter( - String taskName, - int subtaskIndex, + Set<SubtaskID> subtasks, long checkpointId, - ChannelStateWriteResult result, CheckpointStateOutputStream stream, ChannelStateSerializer serializer, RunnableWithException onComplete) { - this( - taskName, - subtaskIndex, - checkpointId, - result, - serializer, - onComplete, - stream, - new DataOutputStream(stream)); + this(subtasks, checkpointId, serializer, onComplete, stream, new DataOutputStream(stream)); } @VisibleForTesting ChannelStateCheckpointWriter( - String taskName, - int subtaskIndex, + Set<SubtaskID> subtasks, long checkpointId, - ChannelStateWriteResult result, ChannelStateSerializer serializer, RunnableWithException onComplete, CheckpointStateOutputStream checkpointStateOutputStream, DataOutputStream dataStream) { - this.taskName = taskName; - this.subtaskIndex = subtaskIndex; + checkArgument(!subtasks.isEmpty(), "The subtasks cannot be empty."); + this.waitedSubtasks = new HashSet<>(subtasks); this.checkpointId = checkpointId; - this.result = checkNotNull(result); this.checkpointStream = checkNotNull(checkpointStateOutputStream); this.serializer = checkNotNull(serializer); this.dataStream = checkNotNull(dataStream); this.onComplete = checkNotNull(onComplete); runWithChecks(() -> 
serializer.writeHeader(dataStream)); } - void writeInput(InputChannelInfo info, Buffer buffer) { - write( - inputChannelOffsets, - info, - buffer, - !allInputsReceived, - "ChannelStateCheckpointWriter#writeInput"); + void registerSubtaskResult( + SubtaskID subtaskID, ChannelStateWriter.ChannelStateWriteResult result) { + // The writer shouldn't register any subtask after writer has exception or is done, + checkState(!isDone(), "The write is done."); + Preconditions.checkState( + !pendingResults.containsKey(subtaskID), + "The subtask %s has already been register before.", + subtaskID); + waitedSubtasks.remove(subtaskID); + + ChannelStatePendingResult pendingResult = + new ChannelStatePendingResult( + subtaskID.getSubtaskIndex(), checkpointId, result, serializer); + pendingResults.put(subtaskID, pendingResult); } - void writeOutput(ResultSubpartitionInfo info, Buffer buffer) { - write( - resultSubpartitionOffsets, - info, - buffer, - !allOutputsReceived, - "ChannelStateCheckpointWriter#writeOutput"); + void releaseSubtask(SubtaskID subtaskID) throws Exception { + if (waitedSubtasks.remove(subtaskID)) { + tryFinishResult(); Review Comment: Why do we have to call `tryFinishResult` on release? ########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java: ########## @@ -90,33 +103,62 @@ public void dispatch(ChannelStateWriteRequest request) throws Exception { } private void dispatchInternal(ChannelStateWriteRequest request) throws Exception { Review Comment: This method has grown a bit too much. Could you split into sth like: ``` if (isAbortedCheckpoint(...)) { handleAbortedRequest(...); } else if (request instanceof X) { handleRequestX(....); } (....) ``` ? 
########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java: ########## @@ -41,69 +46,123 @@ import static org.apache.flink.util.CloseableIterator.ofElements; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; interface ChannelStateWriteRequest { Logger LOG = LoggerFactory.getLogger(ChannelStateWriteRequest.class); + JobID getJobID(); + + JobVertexID getJobVertexID(); + + int getSubtaskIndex(); + long getCheckpointId(); void cancel(Throwable cause) throws Exception; - static CheckpointInProgressRequest completeInput(long checkpointId) { + CompletableFuture<?> getReadyFuture(); + + static CheckpointInProgressRequest completeInput( + JobID jobID, JobVertexID jobVertexID, int subtaskIndex, long checkpointId) { Review Comment: ditto about the need for JobID here ########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java: ########## @@ -172,25 +257,80 @@ static ThrowingConsumer<Throwable, Exception> recycle(Buffer[] flinkBuffers) { } } -final class CheckpointStartRequest implements ChannelStateWriteRequest { +abstract class AbstractChannelStateWriteRequest implements ChannelStateWriteRequest { + + private final JobID jobID; + + private final JobVertexID jobVertexID; + + private final int subtaskIndex; + + private final long checkpointId; + + public AbstractChannelStateWriteRequest( + JobID jobID, JobVertexID jobVertexID, int subtaskIndex, long checkpointId) { + this.jobID = jobID; + this.jobVertexID = jobVertexID; + this.subtaskIndex = subtaskIndex; + this.checkpointId = checkpointId; + } + + @Override + public final JobID getJobID() { + return jobID; + } + + @Override + public final JobVertexID getJobVertexID() { + return jobVertexID; + } + + @Override + public final int getSubtaskIndex() { + return subtaskIndex; + } + + 
@Override + public final long getCheckpointId() { + return checkpointId; + } + + @Override + public CompletableFuture<?> getReadyFuture() { + return AvailabilityProvider.AVAILABLE; + } + + @Override + public String toString() { Review Comment: If you are already adding this super method, why don't you add the "name" parameter here as well? This way there would be no need for the subclasses to override this `toString` method. ########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java: ########## @@ -172,25 +257,80 @@ static ThrowingConsumer<Throwable, Exception> recycle(Buffer[] flinkBuffers) { } } -final class CheckpointStartRequest implements ChannelStateWriteRequest { +abstract class AbstractChannelStateWriteRequest implements ChannelStateWriteRequest { Review Comment: Why do we need both the interface and the abstract class? Can we now just change the `ChannelStateWriteRequest` interface into the abstract class? ########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java: ########## @@ -41,69 +46,123 @@ import static org.apache.flink.util.CloseableIterator.ofElements; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; interface ChannelStateWriteRequest { Logger LOG = LoggerFactory.getLogger(ChannelStateWriteRequest.class); + JobID getJobID(); + + JobVertexID getJobVertexID(); + + int getSubtaskIndex(); + long getCheckpointId(); void cancel(Throwable cause) throws Exception; - static CheckpointInProgressRequest completeInput(long checkpointId) { + CompletableFuture<?> getReadyFuture(); Review Comment: Can you add a javadoc explaining what this method does and what it is used for? 
########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java: ########## @@ -152,14 +200,23 @@ private void failAndClearWriter(Throwable e) { writer = null; } + private void failAndClearWriter( + JobID jobID, JobVertexID jobVertexID, int subtaskIndex, Throwable throwable) { + if (writer == null) { + return; + } + writer.fail(jobID, jobVertexID, subtaskIndex, throwable); + writer = null; + } Review Comment: Why do we have two `failAndClearWriter` methods that are behaving differently? One is failing the `pendingResult` the other is not? 🤔 ########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactory.java: ########## @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.CheckpointStorage; + +import java.util.HashMap; +import java.util.Map; + +/** The factory of ChannelStateWriteRequestExecutor. 
*/ +public class ChannelStateWriteRequestExecutorFactory { + + private static final Map<JobID, ChannelStateWriteRequestExecutor> EXECUTORS = new HashMap<>(); Review Comment: Instead of having a singleton, with static fields, can we inject a shared instance of `ChannelStateWriteRequestExecutorFactory`? For example initialize and hold this instance in `TaskManagerServices`, and pass it to `Task` and into the `StreamTask` through `Environment` and later down into `SubtaskCheckpointCoordinatorImpl#openChannelStateWriter` (such kind of static/global variables can cause many problems down the line, including in for example tests) ########## flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java: ########## @@ -90,33 +103,62 @@ public void dispatch(ChannelStateWriteRequest request) throws Exception { } private void dispatchInternal(ChannelStateWriteRequest request) throws Exception { + if (request instanceof SubtaskRegisterRequest) { + SubtaskRegisterRequest req = (SubtaskRegisterRequest) request; + SubtaskID subtaskID = + SubtaskID.of(req.getJobID(), req.getJobVertexID(), req.getSubtaskIndex()); + subtasks.add(subtaskID); + return; + } else if (request instanceof SubtaskReleaseRequest) { + SubtaskReleaseRequest req = (SubtaskReleaseRequest) request; + SubtaskID subtaskID = + SubtaskID.of(req.getJobID(), req.getJobVertexID(), req.getSubtaskIndex()); + subtasks.remove(subtaskID); + if (writer == null) { + return; + } + writer.releaseSubtask(subtaskID); + return; + } if (isAbortedCheckpoint(request.getCheckpointId())) { - if (request.getCheckpointId() == maxAbortedCheckpointId) { + if (request.getCheckpointId() != maxAbortedCheckpointId) { + request.cancel(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED)); + return; + } + + SubtaskID requestSubtask = + SubtaskID.of( + request.getJobID(), + request.getJobVertexID(), + request.getSubtaskIndex()); + if (requestSubtask.equals(abortedSubtaskID)) { 
request.cancel(abortedCause); } else { - request.cancel(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED)); + request.cancel( + new CheckpointException( + CHANNEL_STATE_SHARED_STREAM_EXCEPTION, abortedCause)); } return; } if (request instanceof CheckpointStartRequest) { checkState( - request.getCheckpointId() > ongoingCheckpointId, + request.getCheckpointId() >= ongoingCheckpointId, String.format( "Checkpoint must be incremented, ongoingCheckpointId is %s, but the request is %s.", ongoingCheckpointId, request)); - failAndClearWriter( - new IllegalStateException( - String.format( - "Task[name=%s, subtaskIndex=%s] has uncompleted channelState writer of checkpointId=%s, " - + "but it received a new checkpoint start request of checkpointId=%s, it maybe " - + "a bug due to currently not supported concurrent unaligned checkpoint.", - taskName, - subtaskIndex, - ongoingCheckpointId, - request.getCheckpointId()))); - this.writer = buildWriter((CheckpointStartRequest) request); - this.ongoingCheckpointId = request.getCheckpointId(); + if (request.getCheckpointId() > ongoingCheckpointId) { + // Clear the previous writer. + failAndClearWriter(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED)); + } + CheckpointStartRequest req = (CheckpointStartRequest) request; + if (writer == null) { Review Comment: This can not be `checkState(writer == null)` because single dispatcher will handle 5 `CheckpointStartRequests` from 5 subtasks (assuming 5 subtasks are configured to share the same file?). If so, maybe add a comment explaining this? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org