1996fanrui commented on code in PR #27783:
URL: https://github.com/apache/flink/pull/27783#discussion_r3000101071


##########
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java:
##########
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
+import 
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel;
+import 
org.apache.flink.streaming.runtime.io.recovery.VirtualChannelRecordFilterFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Filters recovered channel state buffers during the channel-state-unspilling 
phase, removing
+ * records that do not belong to the current subtask after rescaling.
+ *
+ * <p>Uses a per-gate architecture: each {@link InputGate} gets its own {@link 
GateFilterHandler}
+ * with the correct serializer, so multi-input tasks (e.g., 
TwoInputStreamTask) correctly
+ * deserialize different record types on different gates.
+ */
+@Internal
+public class ChannelStateFilteringHandler {
+
+    /**
+     * Handles record filtering for a single input gate. Each gate has its own 
serializer and set of
+     * virtual channels, allowing different gates to handle different record 
types independently.
+     */
+    static class GateFilterHandler<T> {
+
+        private final Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels;
+        private final StreamElementSerializer<T> serializer;
+        private final DeserializationDelegate<StreamElement> 
deserializationDelegate;
+        private final DataOutputSerializer outputSerializer;
+        private final byte[] lengthBuffer = new byte[4];
+
+        GateFilterHandler(
+                Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels,
+                StreamElementSerializer<T> serializer) {
+            this.virtualChannels = checkNotNull(virtualChannels);
+            this.serializer = checkNotNull(serializer);
+            this.deserializationDelegate = new 
NonReusingDeserializationDelegate<>(serializer);
+            this.outputSerializer = new DataOutputSerializer(128);
+        }
+
+        /**
+         * Deserializes records from {@code sourceBuffer}, applies the virtual 
channel's record
+         * filter, and re-serializes the surviving records into new buffers.
+         */
+        List<Buffer> filterAndRewrite(
+                int oldSubtaskIndex,
+                int oldChannelIndex,
+                Buffer sourceBuffer,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+
+            SubtaskConnectionDescriptor key =
+                    new SubtaskConnectionDescriptor(oldSubtaskIndex, 
oldChannelIndex);
+            VirtualChannel<T> vc = virtualChannels.get(key);
+            if (vc == null) {
+                throw new IllegalStateException(
+                        "No VirtualChannel found for key: "
+                                + key
+                                + "; known channels are "
+                                + virtualChannels.keySet());
+            }
+
+            vc.setNextBuffer(sourceBuffer);
+
+            List<StreamElement> filteredElements = new ArrayList<>();
+
+            while (true) {
+                DeserializationResult result = 
vc.getNextRecord(deserializationDelegate);
+                if (result.isFullRecord()) {
+                    
filteredElements.add(deserializationDelegate.getInstance());
+                }
+                if (result.isBufferConsumed()) {
+                    break;
+                }
+            }
+
+            return serializeToBuffers(filteredElements, bufferSupplier);
+        }
+
+        /**
+         * Serializes stream elements into buffers using the length-prefixed 
format (4-byte
+         * big-endian length + record bytes) expected by Flink's record 
deserializers.
+         */
+        private List<Buffer> serializeToBuffers(
+                List<StreamElement> elements, BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+
+            List<Buffer> resultBuffers = new ArrayList<>();
+
+            if (elements.isEmpty()) {
+                return resultBuffers;
+            }
+
+            Buffer currentBuffer = bufferSupplier.requestBufferBlocking();
+
+            for (StreamElement element : elements) {
+                outputSerializer.clear();
+                serializer.serialize(element, outputSerializer);
+                int recordLength = outputSerializer.length();
+
+                writeLengthToBuffer(recordLength);
+                currentBuffer =
+                        writeDataToBuffer(
+                                lengthBuffer, 0, 4, currentBuffer, 
resultBuffers, bufferSupplier);
+
+                byte[] serializedData = outputSerializer.getSharedBuffer();
+                currentBuffer =
+                        writeDataToBuffer(
+                                serializedData,
+                                0,
+                                recordLength,
+                                currentBuffer,
+                                resultBuffers,
+                                bufferSupplier);
+            }
+
+            if (currentBuffer.readableBytes() > 0) {
+                resultBuffers.add(currentBuffer.retainBuffer());
+            }
+            currentBuffer.recycleBuffer();
+
+            return resultBuffers;
+        }
+
+        private void writeLengthToBuffer(int length) {
+            lengthBuffer[0] = (byte) (length >> 24);
+            lengthBuffer[1] = (byte) (length >> 16);
+            lengthBuffer[2] = (byte) (length >> 8);
+            lengthBuffer[3] = (byte) length;
+        }
+
+        /**
+         * Writes data to the current buffer, spilling into new buffers from 
{@code bufferSupplier}
+         * when the current one is full.
+         *
+         * @return the buffer to continue writing into (may differ from the 
input buffer).
+         */
+        private Buffer writeDataToBuffer(
+                byte[] data,
+                int dataOffset,
+                int dataLength,
+                Buffer currentBuffer,
+                List<Buffer> resultBuffers,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+            int offset = dataOffset;
+            int remaining = dataLength;
+
+            while (remaining > 0) {
+                int writableBytes = currentBuffer.getMaxCapacity() - 
currentBuffer.getSize();
+
+                if (writableBytes == 0) {
+                    if (currentBuffer.readableBytes() > 0) {
+                        resultBuffers.add(currentBuffer.retainBuffer());
+                    }
+                    currentBuffer.recycleBuffer();
+                    currentBuffer = bufferSupplier.requestBufferBlocking();
+                    writableBytes = currentBuffer.getMaxCapacity();
+                }
+
+                int bytesToWrite = Math.min(remaining, writableBytes);
+                currentBuffer
+                        .getMemorySegment()
+                        .put(
+                                currentBuffer.getMemorySegmentOffset() + 
currentBuffer.getSize(),
+                                data,
+                                offset,
+                                bytesToWrite);
+                currentBuffer.setSize(currentBuffer.getSize() + bytesToWrite);
+
+                offset += bytesToWrite;
+                remaining -= bytesToWrite;
+            }
+            return currentBuffer;
+        }
+
+        boolean hasPartialData() {
+            return 
virtualChannels.values().stream().anyMatch(VirtualChannel::hasPartialData);
+        }
+
+        void clear() {
+            virtualChannels.values().forEach(VirtualChannel::clear);
+        }
+    }
+
+    // Wildcard allows heterogeneous record types across gates.
+    private final GateFilterHandler<?>[] gateHandlers;
+
+    ChannelStateFilteringHandler(GateFilterHandler<?>[] gateHandlers) {
+        this.gateHandlers = checkNotNull(gateHandlers);
+    }
+
+    /**
+     * Creates a handler from the recovery context, building per-gate virtual 
channels based on
+     * rescaling descriptors. Returns {@code null} if no filtering is needed 
(e.g., source tasks or
+     * no rescaling).
+     */
+    @Nullable
+    public static ChannelStateFilteringHandler createFromContext(
+            RecordFilterContext filterContext, InputGate[] inputGates) {
+        // Source tasks have no network inputs
+        if (filterContext.getNumberOfGates() == 0) {
+            return null;
+        }
+
+        InflightDataRescalingDescriptor rescalingDescriptor =
+                filterContext.getRescalingDescriptor();
+
+        GateFilterHandler<?>[] gateHandlers = new 
GateFilterHandler<?>[inputGates.length];
+        boolean hasAnyVirtualChannels = false;
+
+        for (int gateIndex = 0; gateIndex < inputGates.length; gateIndex++) {
+            gateHandlers[gateIndex] =
+                    createGateHandler(filterContext, inputGates, 
rescalingDescriptor, gateIndex);
+            if (gateHandlers[gateIndex] != null) {
+                hasAnyVirtualChannels = true;
+            }
+        }
+
+        if (!hasAnyVirtualChannels) {
+            return null;
+        }
+
+        return new ChannelStateFilteringHandler(gateHandlers);
+    }
+
+    /**
+     * Creates a {@link GateFilterHandler} for a single gate. The method-level 
type parameter
+     * ensures type safety within each gate while allowing different gates to 
have different types.
+     */
+    @SuppressWarnings("unchecked")
+    @Nullable
+    private static <T> GateFilterHandler<T> createGateHandler(
+            RecordFilterContext filterContext,
+            InputGate[] inputGates,
+            InflightDataRescalingDescriptor rescalingDescriptor,
+            int gateIndex) {
+        RecordFilterContext.InputFilterConfig inputConfig = 
filterContext.getInputConfig(gateIndex);
+        if (inputConfig == null) {
+            throw new IllegalStateException(
+                    "No InputFilterConfig for gateIndex "
+                            + gateIndex
+                            + ". This indicates a bug in RecordFilterContext 
initialization.");
+        }
+
+        InputGate gate = inputGates[gateIndex];
+        int[] oldSubtaskIndexes = 
rescalingDescriptor.getOldSubtaskIndexes(gateIndex);
+        RescaleMappings channelMapping = 
rescalingDescriptor.getChannelMapping(gateIndex);
+
+        TypeSerializer<T> typeSerializer = (TypeSerializer<T>) 
inputConfig.getTypeSerializer();
+        StreamElementSerializer<T> elementSerializer =
+                new StreamElementSerializer<>(typeSerializer);
+
+        VirtualChannelRecordFilterFactory<T> filterFactory =
+                VirtualChannelRecordFilterFactory.fromContext(filterContext, 
gateIndex);
+
+        Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
gateVirtualChannels = new HashMap<>();
+
+        for (int oldSubtaskIndex : oldSubtaskIndexes) {
+            int numChannels = gate.getNumberOfInputChannels();
+            int[] oldChannelIndexes = getOldChannelIndexes(channelMapping, 
numChannels);
+
+            for (int oldChannelIndex : oldChannelIndexes) {
+                SubtaskConnectionDescriptor key =
+                        new SubtaskConnectionDescriptor(oldSubtaskIndex, 
oldChannelIndex);
+
+                if (gateVirtualChannels.containsKey(key)) {
+                    continue;
+                }
+
+                // Only ambiguous channels need actual filtering; 
non-ambiguous ones pass through
+                boolean isAmbiguous = 
rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex);
+
+                RecordFilter<T> recordFilter =
+                        isAmbiguous
+                                ? filterFactory.createFilter()
+                                : 
VirtualChannelRecordFilterFactory.createPassThroughFilter();
+
+                RecordDeserializer<DeserializationDelegate<StreamElement>> 
deserializer =
+                        createDeserializer(filterContext.getTmpDirectories());
+
+                VirtualChannel<T> vc = new VirtualChannel<>(deserializer, 
recordFilter);
+                gateVirtualChannels.put(key, vc);
+            }
+        }
+
+        if (gateVirtualChannels.isEmpty()) {
+            return null;
+        }
+
+        return new GateFilterHandler<>(gateVirtualChannels, elementSerializer);
+    }
+
+    /**
+     * Collects all old channel indexes that are mapped from any new channel 
index in this gate.
+     * channelMapping is new-to-old, so we iterate new indexes and collect 
their old counterparts.
+     */
+    private static int[] getOldChannelIndexes(RescaleMappings channelMapping, 
int numChannels) {
+        List<Integer> oldIndexes = new ArrayList<>();
+        for (int newIndex = 0; newIndex < numChannels; newIndex++) {
+            int[] mapped = channelMapping.getMappedIndexes(newIndex);
+            for (int oldIndex : mapped) {
+                if (!oldIndexes.contains(oldIndex)) {
+                    oldIndexes.add(oldIndex);
+                }
+            }
+        }
+        return oldIndexes.stream().mapToInt(Integer::intValue).toArray();
+    }
+
+    private static RecordDeserializer<DeserializationDelegate<StreamElement>> 
createDeserializer(
+            String[] tmpDirectories) {
+        if (tmpDirectories != null && tmpDirectories.length > 0) {
+            return new 
SpillingAdaptiveSpanningRecordDeserializer<>(tmpDirectories);
+        } else {
+            String[] defaultDirs = new String[] 
{System.getProperty("java.io.tmpdir")};
+            return new 
SpillingAdaptiveSpanningRecordDeserializer<>(defaultDirs);
+        }
+    }
+
+    /**
+     * Filters a recovered buffer from the specified virtual channel, 
returning new buffers
+     * containing only the records that belong to the current subtask.
+     *
+     * @return filtered buffers, possibly empty if all records were filtered 
out.
+     */
+    public List<Buffer> filterAndRewrite(
+            int gateIndex,
+            int oldSubtaskIndex,
+            int oldChannelIndex,
+            Buffer sourceBuffer,
+            BufferSupplier bufferSupplier)

Review Comment:
   The code comment is updated.
   
   The List<Buffer> return can contain more than 1 buffer when a spanning 
record completes in this buffer — the deserializer caches partial data from 
previous buffers, so the output may include data not present in the current 
source buffer. 
   
   This is uncommon but possible with any spanning record. In that case, it
is covered by the spilling logic if the network buffer pool is insufficient.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java:
##########
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
+import 
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel;
+import 
org.apache.flink.streaming.runtime.io.recovery.VirtualChannelRecordFilterFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Filters recovered channel state buffers during the channel-state-unspilling 
phase, removing
+ * records that do not belong to the current subtask after rescaling.
+ *
+ * <p>Uses a per-gate architecture: each {@link InputGate} gets its own {@link 
GateFilterHandler}
+ * with the correct serializer, so multi-input tasks (e.g., 
TwoInputStreamTask) correctly
+ * deserialize different record types on different gates.
+ */
+@Internal
+public class ChannelStateFilteringHandler {
+
+    /**
+     * Handles record filtering for a single input gate. Each gate has its own 
serializer and set of
+     * virtual channels, allowing different gates to handle different record 
types independently.
+     */
+    static class GateFilterHandler<T> {
+
+        private final Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels;
+        private final StreamElementSerializer<T> serializer;
+        private final DeserializationDelegate<StreamElement> 
deserializationDelegate;
+        private final DataOutputSerializer outputSerializer;
+        private final byte[] lengthBuffer = new byte[4];
+
+        GateFilterHandler(
+                Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels,
+                StreamElementSerializer<T> serializer) {
+            this.virtualChannels = checkNotNull(virtualChannels);
+            this.serializer = checkNotNull(serializer);
+            this.deserializationDelegate = new 
NonReusingDeserializationDelegate<>(serializer);
+            this.outputSerializer = new DataOutputSerializer(128);
+        }
+
+        /**
+         * Deserializes records from {@code sourceBuffer}, applies the virtual 
channel's record
+         * filter, and re-serializes the surviving records into new buffers.
+         */
+        List<Buffer> filterAndRewrite(
+                int oldSubtaskIndex,
+                int oldChannelIndex,
+                Buffer sourceBuffer,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+
+            SubtaskConnectionDescriptor key =
+                    new SubtaskConnectionDescriptor(oldSubtaskIndex, 
oldChannelIndex);
+            VirtualChannel<T> vc = virtualChannels.get(key);
+            if (vc == null) {
+                throw new IllegalStateException(
+                        "No VirtualChannel found for key: "
+                                + key
+                                + "; known channels are "
+                                + virtualChannels.keySet());
+            }
+
+            vc.setNextBuffer(sourceBuffer);
+
+            List<StreamElement> filteredElements = new ArrayList<>();
+
+            while (true) {
+                DeserializationResult result = 
vc.getNextRecord(deserializationDelegate);
+                if (result.isFullRecord()) {
+                    
filteredElements.add(deserializationDelegate.getInstance());
+                }
+                if (result.isBufferConsumed()) {
+                    break;
+                }
+            }
+
+            return serializeToBuffers(filteredElements, bufferSupplier);

Review Comment:
   Addressed in 62d42c758086f9cacf5bc6d68038e97bebbb6a08



##########
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java:
##########
@@ -95,23 +105,60 @@ public void recover(
             InputChannelInfo channelInfo,
             int oldSubtaskIndex,
             BufferWithContext<Buffer> bufferWithContext)
-            throws IOException {
+            throws IOException, InterruptedException {
         Buffer buffer = bufferWithContext.context;
         try {
             if (buffer.readableBytes() > 0) {
                 RecoveredInputChannel channel = getMappedChannels(channelInfo);
-                channel.onRecoveredStateBuffer(
-                        EventSerializer.toBuffer(
-                                new SubtaskConnectionDescriptor(
-                                        oldSubtaskIndex, 
channelInfo.getInputChannelIdx()),
-                                false));
-                channel.onRecoveredStateBuffer(buffer.retainBuffer());
+
+                if (filteringHandler != null) {
+                    // Filtering mode: filter records and rewrite to new 
buffers
+                    recoverWithFiltering(channel, channelInfo, 
oldSubtaskIndex, buffer);
+                } else {
+                    // Non-filtering mode: pass through original buffer with 
descriptor
+                    channel.onRecoveredStateBuffer(
+                            EventSerializer.toBuffer(
+                                    new SubtaskConnectionDescriptor(
+                                            oldSubtaskIndex, 
channelInfo.getInputChannelIdx()),
+                                    false));
+                    channel.onRecoveredStateBuffer(buffer.retainBuffer());
+                }
             }
         } finally {
             buffer.recycleBuffer();
         }
     }
 
+    private void recoverWithFiltering(
+            RecoveredInputChannel channel,
+            InputChannelInfo channelInfo,
+            int oldSubtaskIndex,
+            Buffer buffer)
+            throws IOException, InterruptedException {
+        checkState(filteringHandler != null, "filtering handler not set.");
+        // Extra retain: filterAndRewrite consumes one ref, caller's finally 
releases another.
+        buffer.retainBuffer();

Review Comment:
   Addressed together with the ownership concern in comment 
https://github.com/apache/flink/pull/27783/changes#r2996388666. Removed 
`retainBuffer()` and the catch block entirely. The buffer now has a single 
clean owner per path: in the filtering path, the deserializer recycles the 
buffer when consumed; the `finally` uses a defensive `isRecycled()` check only 
for the edge case where an exception occurs before the deserializer takes the 
buffer (e.g., VirtualChannel lookup failure). Added a buffer lifecycle diagram 
in the javadoc covering all paths. No extra retain/recycle needed.



##########
flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java:
##########
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Context containing all information needed for filtering recovered channel 
state buffers.
+ *
+ * <p>This context encapsulates the input configurations, rescaling 
descriptor, and subtask
+ * information required by the channel-state-unspilling thread to perform 
record filtering during
+ * recovery.
+ *
+ * <p>Supports multiple inputs (e.g., TwoInputStreamTask, 
MultipleInputStreamTask) by storing a list
+ * of {@link InputFilterConfig} instances indexed by input index.
+ *
+ * <p>Use the constructor with empty inputConfigs or enabled=false when 
filtering is not needed.
+ */
+@Internal
+public class RecordFilterContext {
+
+    /** Configuration for filtering records on a specific input. */
+    public static class InputFilterConfig {
+        private final TypeSerializer<?> typeSerializer;
+        private final StreamPartitioner<?> partitioner;
+        private final int numberOfChannels;
+
+        /**
+         * Creates a new InputFilterConfig.
+         *
+         * @param typeSerializer Serializer for the record type.
+         * @param partitioner Partitioner used to determine record ownership.
+         * @param numberOfChannels The parallelism of the current operator.
+         */
+        public InputFilterConfig(
+                TypeSerializer<?> typeSerializer,
+                StreamPartitioner<?> partitioner,
+                int numberOfChannels) {
+            this.typeSerializer = checkNotNull(typeSerializer);
+            this.partitioner = checkNotNull(partitioner);
+            this.numberOfChannels = numberOfChannels;
+        }
+
+        public TypeSerializer<?> getTypeSerializer() {
+            return typeSerializer;
+        }
+
+        public StreamPartitioner<?> getPartitioner() {
+            return partitioner;
+        }
+
+        public int getNumberOfChannels() {
+            return numberOfChannels;
+        }
+    }
+
+    /**
+     * Input configurations indexed by gate index. Array elements may be null 
for non-network inputs
+     * (e.g., SourceInputConfig). The array length equals the total number of 
input gates.
+     */
+    private final InputFilterConfig[] inputConfigs;
+
+    /** Descriptor containing rescaling information. Never null. */
+    private final InflightDataRescalingDescriptor rescalingDescriptor;
+
+    /** Current subtask index. */
+    private final int subtaskIndex;
+
+    /** Maximum parallelism for configuring partitioners. */
+    private final int maxParallelism;
+
+    /** Temporary directories for spilling spanning records. Can be empty but 
never null. */
+    private final String[] tmpDirectories;
+
+    /** Whether unaligned checkpoint during recovery is enabled. */
+    private final boolean unalignedDuringRecoveryEnabled;
+
+    /**
+     * Creates a new RecordFilterContext.
+     *
+     * @param inputConfigs Input configurations indexed by gate index. Array 
elements may be null
+     *     for non-network inputs. Not null itself.
+     * @param rescalingDescriptor Descriptor containing rescaling information. 
Not null.
+     * @param subtaskIndex Current subtask index.
+     * @param maxParallelism Maximum parallelism.
+     * @param tmpDirectories Temporary directories for spilling spanning 
records. Can be null
+     *     (converted to empty array).
+     * @param unalignedDuringRecoveryEnabled Whether unaligned checkpoint 
during recovery is
+     *     enabled.
+     */
+    public RecordFilterContext(
+            InputFilterConfig[] inputConfigs,
+            InflightDataRescalingDescriptor rescalingDescriptor,
+            int subtaskIndex,
+            int maxParallelism,
+            String[] tmpDirectories,
+            boolean unalignedDuringRecoveryEnabled) {
+        this.inputConfigs = checkNotNull(inputConfigs).clone();
+        this.rescalingDescriptor = checkNotNull(rescalingDescriptor);
+        this.subtaskIndex = subtaskIndex;
+        this.maxParallelism = maxParallelism;
+        this.tmpDirectories = tmpDirectories != null ? tmpDirectories : new 
String[0];
+        this.unalignedDuringRecoveryEnabled = unalignedDuringRecoveryEnabled;
+    }
+
+    /**
+     * Gets the input configuration for a specific gate.
+     *
+     * @param gateIndex The gate index (0-based).
+     * @return The input configuration for the specified gate, or null if the 
gate is not a network
+     *     input (e.g., SourceInputConfig).
+     * @throws IllegalArgumentException if gateIndex is out of bounds.
+     */
+    public InputFilterConfig getInputConfig(int gateIndex) {
+        checkArgument(
+                gateIndex >= 0 && gateIndex < inputConfigs.length,
+                "Invalid gate index: %s, number of gates: %s",
+                gateIndex,
+                inputConfigs.length);
+        return inputConfigs[gateIndex];
+    }
+
+    /**
+     * Gets the number of input gates.
+     *
+     * @return The number of input gates.
+     */
+    public int getNumberOfGates() {
+        return inputConfigs.length;
+    }
+
+    /**
+     * Checks whether unaligned checkpoint during recovery is enabled.
+     *
+     * @return {@code true} if enabled, {@code false} otherwise.
+     */
+    public boolean isUnalignedDuringRecoveryEnabled() {
+        return unalignedDuringRecoveryEnabled;
+    }

Review Comment:
   Done. And renamed in a separate preceding hotfix commit since the method was 
introduced in an earlier PR (FLINK-38541). All references updated.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java:
##########
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
+import 
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel;
+import 
org.apache.flink.streaming.runtime.io.recovery.VirtualChannelRecordFilterFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Filters recovered channel state buffers during the channel-state-unspilling 
phase, removing
+ * records that do not belong to the current subtask after rescaling.
+ *
+ * <p>Uses a per-gate architecture: each {@link InputGate} gets its own {@link 
GateFilterHandler}
+ * with the correct serializer, so multi-input tasks (e.g., 
TwoInputStreamTask) correctly
+ * deserialize different record types on different gates.
+ */
+@Internal
+public class ChannelStateFilteringHandler {
+
+    /**
+     * Handles record filtering for a single input gate. Each gate has its own 
serializer and set of
+     * virtual channels, allowing different gates to handle different record 
types independently.
+     */
+    static class GateFilterHandler<T> {
+
+        private final Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels;
+        private final StreamElementSerializer<T> serializer;
+        private final DeserializationDelegate<StreamElement> 
deserializationDelegate;
+        private final DataOutputSerializer outputSerializer;
+        private final byte[] lengthBuffer = new byte[4];
+
+        GateFilterHandler(
+                Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels,
+                StreamElementSerializer<T> serializer) {
+            this.virtualChannels = checkNotNull(virtualChannels);
+            this.serializer = checkNotNull(serializer);
+            this.deserializationDelegate = new 
NonReusingDeserializationDelegate<>(serializer);
+            this.outputSerializer = new DataOutputSerializer(128);
+        }
+
+        /**
+         * Deserializes records from {@code sourceBuffer}, applies the virtual 
channel's record
+         * filter, and re-serializes the surviving records into new buffers.
+         */
+        List<Buffer> filterAndRewrite(
+                int oldSubtaskIndex,
+                int oldChannelIndex,
+                Buffer sourceBuffer,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+
+            SubtaskConnectionDescriptor key =
+                    new SubtaskConnectionDescriptor(oldSubtaskIndex, 
oldChannelIndex);
+            VirtualChannel<T> vc = virtualChannels.get(key);
+            if (vc == null) {
+                throw new IllegalStateException(
+                        "No VirtualChannel found for key: "
+                                + key
+                                + "; known channels are "
+                                + virtualChannels.keySet());
+            }
+
+            vc.setNextBuffer(sourceBuffer);
+
+            List<StreamElement> filteredElements = new ArrayList<>();
+
+            while (true) {
+                DeserializationResult result = 
vc.getNextRecord(deserializationDelegate);
+                if (result.isFullRecord()) {
+                    
filteredElements.add(deserializationDelegate.getInstance());
+                }
+                if (result.isBufferConsumed()) {
+                    break;
+                }
+            }
+
+            return serializeToBuffers(filteredElements, bufferSupplier);
+        }
+
+        /**
+         * Serializes stream elements into buffers using the length-prefixed 
format (4-byte
+         * big-endian length + record bytes) expected by Flink's record 
deserializers.
+         */
+        private List<Buffer> serializeToBuffers(
+                List<StreamElement> elements, BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+
+            List<Buffer> resultBuffers = new ArrayList<>();
+
+            if (elements.isEmpty()) {
+                return resultBuffers;
+            }
+
+            Buffer currentBuffer = bufferSupplier.requestBufferBlocking();
+
+            for (StreamElement element : elements) {
+                outputSerializer.clear();
+                serializer.serialize(element, outputSerializer);
+                int recordLength = outputSerializer.length();
+
+                writeLengthToBuffer(recordLength);
+                currentBuffer =
+                        writeDataToBuffer(
+                                lengthBuffer, 0, 4, currentBuffer, 
resultBuffers, bufferSupplier);
+
+                byte[] serializedData = outputSerializer.getSharedBuffer();
+                currentBuffer =
+                        writeDataToBuffer(
+                                serializedData,
+                                0,
+                                recordLength,
+                                currentBuffer,
+                                resultBuffers,
+                                bufferSupplier);
+            }
+
+            if (currentBuffer.readableBytes() > 0) {
+                resultBuffers.add(currentBuffer.retainBuffer());
+            }
+            currentBuffer.recycleBuffer();
+
+            return resultBuffers;
+        }
+
+        private void writeLengthToBuffer(int length) {
+            lengthBuffer[0] = (byte) (length >> 24);
+            lengthBuffer[1] = (byte) (length >> 16);
+            lengthBuffer[2] = (byte) (length >> 8);
+            lengthBuffer[3] = (byte) length;
+        }
+
+        /**
+         * Writes data to the current buffer, spilling into new buffers from 
{@code bufferSupplier}
+         * when the current one is full.
+         *
+         * @return the buffer to continue writing into (may differ from the 
input buffer).
+         */
+        private Buffer writeDataToBuffer(
+                byte[] data,
+                int dataOffset,
+                int dataLength,
+                Buffer currentBuffer,
+                List<Buffer> resultBuffers,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+            int offset = dataOffset;
+            int remaining = dataLength;
+
+            while (remaining > 0) {
+                int writableBytes = currentBuffer.getMaxCapacity() - 
currentBuffer.getSize();
+
+                if (writableBytes == 0) {
+                    if (currentBuffer.readableBytes() > 0) {
+                        resultBuffers.add(currentBuffer.retainBuffer());
+                    }
+                    currentBuffer.recycleBuffer();
+                    currentBuffer = bufferSupplier.requestBufferBlocking();

Review Comment:
   Addressed in b12a097760ee267f9ef6a659682b1a3f2ecb4404



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to