pnowojski commented on code in PR #27783:
URL: https://github.com/apache/flink/pull/27783#discussion_r2996614859


##########
flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java:
##########
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Context containing all information needed for filtering recovered channel 
state buffers.
+ *
+ * <p>This context encapsulates the input configurations, rescaling 
descriptor, and subtask
+ * information required by the channel-state-unspilling thread to perform 
record filtering during
+ * recovery.
+ *
+ * <p>Supports multiple inputs (e.g., TwoInputStreamTask, 
MultipleInputStreamTask) by storing a list
+ * of {@link InputFilterConfig} instances indexed by input index.
+ *
+ * <p>Use the constructor with empty inputConfigs or enabled=false when 
filtering is not needed.
+ */
+@Internal
+public class RecordFilterContext {
+
+    /** Configuration for filtering records on a specific input. */
+    public static class InputFilterConfig {
+        private final TypeSerializer<?> typeSerializer;
+        private final StreamPartitioner<?> partitioner;
+        private final int numberOfChannels;
+
+        /**
+         * Creates a new InputFilterConfig.
+         *
+         * @param typeSerializer Serializer for the record type.
+         * @param partitioner Partitioner used to determine record ownership.
+         * @param numberOfChannels The parallelism of the current operator.
+         */
+        public InputFilterConfig(
+                TypeSerializer<?> typeSerializer,
+                StreamPartitioner<?> partitioner,
+                int numberOfChannels) {
+            this.typeSerializer = checkNotNull(typeSerializer);
+            this.partitioner = checkNotNull(partitioner);
+            this.numberOfChannels = numberOfChannels;
+        }
+
+        public TypeSerializer<?> getTypeSerializer() {
+            return typeSerializer;
+        }
+
+        public StreamPartitioner<?> getPartitioner() {
+            return partitioner;
+        }
+
+        public int getNumberOfChannels() {
+            return numberOfChannels;
+        }
+    }
+
+    /**
+     * Input configurations indexed by gate index. Array elements may be null 
for non-network inputs
+     * (e.g., SourceInputConfig). The array length equals the total number of 
input gates.
+     */
+    private final InputFilterConfig[] inputConfigs;
+
+    /** Descriptor containing rescaling information. Never null. */
+    private final InflightDataRescalingDescriptor rescalingDescriptor;
+
+    /** Current subtask index. */
+    private final int subtaskIndex;
+
+    /** Maximum parallelism for configuring partitioners. */
+    private final int maxParallelism;
+
+    /** Temporary directories for spilling spanning records. Can be empty but 
never null. */
+    private final String[] tmpDirectories;
+
+    /** Whether unaligned checkpoint during recovery is enabled. */
+    private final boolean unalignedDuringRecoveryEnabled;
+
+    /**
+     * Creates a new RecordFilterContext.
+     *
+     * @param inputConfigs Input configurations indexed by gate index. Array 
elements may be null
+     *     for non-network inputs. Not null itself.
+     * @param rescalingDescriptor Descriptor containing rescaling information. 
Not null.
+     * @param subtaskIndex Current subtask index.
+     * @param maxParallelism Maximum parallelism.
+     * @param tmpDirectories Temporary directories for spilling spanning 
records. Can be null
+     *     (converted to empty array).
+     * @param unalignedDuringRecoveryEnabled Whether unaligned checkpoint 
during recovery is
+     *     enabled.
+     */
+    public RecordFilterContext(
+            InputFilterConfig[] inputConfigs,
+            InflightDataRescalingDescriptor rescalingDescriptor,
+            int subtaskIndex,
+            int maxParallelism,
+            String[] tmpDirectories,
+            boolean unalignedDuringRecoveryEnabled) {
+        this.inputConfigs = checkNotNull(inputConfigs).clone();
+        this.rescalingDescriptor = checkNotNull(rescalingDescriptor);
+        this.subtaskIndex = subtaskIndex;
+        this.maxParallelism = maxParallelism;
+        this.tmpDirectories = tmpDirectories != null ? tmpDirectories : new 
String[0];
+        this.unalignedDuringRecoveryEnabled = unalignedDuringRecoveryEnabled;
+    }
+
+    /**
+     * Gets the input configuration for a specific gate.
+     *
+     * @param gateIndex The gate index (0-based).
+     * @return The input configuration for the specified gate, or null if the 
gate is not a network
+     *     input (e.g., SourceInputConfig).
+     * @throws IllegalArgumentException if gateIndex is out of bounds.
+     */
+    public InputFilterConfig getInputConfig(int gateIndex) {
+        checkArgument(
+                gateIndex >= 0 && gateIndex < inputConfigs.length,
+                "Invalid gate index: %s, number of gates: %s",
+                gateIndex,
+                inputConfigs.length);
+        return inputConfigs[gateIndex];
+    }
+
+    /**
+     * Gets the number of input gates.
+     *
+     * @return The number of input gates.
+     */
+    public int getNumberOfGates() {
+        return inputConfigs.length;
+    }
+
+    /**
+     * Checks whether unaligned checkpoint during recovery is enabled.
+     *
+     * @return {@code true} if enabled, {@code false} otherwise.
+     */
+    public boolean isUnalignedDuringRecoveryEnabled() {
+        return unalignedDuringRecoveryEnabled;
+    }

Review Comment:
   rename to `isCheckpointingDuringRecoveryEnabled`? And adjust the Javadoc accordingly:
   
   > Checks whether unaligned checkpointING during recovery is enabled.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java:
##########
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
+import 
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel;
+import 
org.apache.flink.streaming.runtime.io.recovery.VirtualChannelRecordFilterFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Filters recovered channel state buffers during the channel-state-unspilling 
phase, removing
+ * records that do not belong to the current subtask after rescaling.
+ *
+ * <p>Uses a per-gate architecture: each {@link InputGate} gets its own {@link 
GateFilterHandler}
+ * with the correct serializer, so multi-input tasks (e.g., 
TwoInputStreamTask) correctly
+ * deserialize different record types on different gates.
+ */
+@Internal
+public class ChannelStateFilteringHandler {
+
+    /**
+     * Handles record filtering for a single input gate. Each gate has its own 
serializer and set of
+     * virtual channels, allowing different gates to handle different record 
types independently.
+     */
+    static class GateFilterHandler<T> {
+
+        private final Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels;
+        private final StreamElementSerializer<T> serializer;
+        private final DeserializationDelegate<StreamElement> 
deserializationDelegate;
+        private final DataOutputSerializer outputSerializer;
+        private final byte[] lengthBuffer = new byte[4];
+
+        GateFilterHandler(
+                Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels,
+                StreamElementSerializer<T> serializer) {
+            this.virtualChannels = checkNotNull(virtualChannels);
+            this.serializer = checkNotNull(serializer);
+            this.deserializationDelegate = new 
NonReusingDeserializationDelegate<>(serializer);
+            this.outputSerializer = new DataOutputSerializer(128);
+        }
+
+        /**
+         * Deserializes records from {@code sourceBuffer}, applies the virtual 
channel's record
+         * filter, and re-serializes the surviving records into new buffers.
+         */
+        List<Buffer> filterAndRewrite(

Review Comment:
   could you re-order the methods in this class? Public methods first. Private methods either below all public ones, or just below their first usage?



##########
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java:
##########
@@ -72,6 +85,18 @@ public void readInputData(InputGate[] inputGates) throws 
IOException, Interrupte
                     groupByDelegate(
                             streamSubtaskStates(),
                             
OperatorSubtaskState::getUpstreamOutputBufferState));
+
+            if (filteringHandler != null) {
+                checkState(
+                        !filteringHandler.hasPartialData(),
+                        "Not all data has been fully consumed during 
filtering");
+            }
+        } finally {
+            // Clean up filtering handler resources (e.g., temp files from
+            // SpillingAdaptiveSpanningRecordDeserializer) on both success and 
error paths
+            if (filteringHandler != null) {
+                filteringHandler.clear();

Review Comment:
   nit: make it `AutoCloseable`, so the cleanup can be done via try-with-resources?



##########
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java:
##########
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
+import 
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel;
+import 
org.apache.flink.streaming.runtime.io.recovery.VirtualChannelRecordFilterFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Filters recovered channel state buffers during the channel-state-unspilling 
phase, removing
+ * records that do not belong to the current subtask after rescaling.
+ *
+ * <p>Uses a per-gate architecture: each {@link InputGate} gets its own {@link 
GateFilterHandler}
+ * with the correct serializer, so multi-input tasks (e.g., 
TwoInputStreamTask) correctly
+ * deserialize different record types on different gates.
+ */
+@Internal
+public class ChannelStateFilteringHandler {
+
+    /**
+     * Handles record filtering for a single input gate. Each gate has its own 
serializer and set of
+     * virtual channels, allowing different gates to handle different record 
types independently.
+     */
+    static class GateFilterHandler<T> {
+
+        private final Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels;
+        private final StreamElementSerializer<T> serializer;
+        private final DeserializationDelegate<StreamElement> 
deserializationDelegate;
+        private final DataOutputSerializer outputSerializer;
+        private final byte[] lengthBuffer = new byte[4];
+
+        GateFilterHandler(
+                Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels,
+                StreamElementSerializer<T> serializer) {
+            this.virtualChannels = checkNotNull(virtualChannels);
+            this.serializer = checkNotNull(serializer);
+            this.deserializationDelegate = new 
NonReusingDeserializationDelegate<>(serializer);
+            this.outputSerializer = new DataOutputSerializer(128);
+        }
+
+        /**
+         * Deserializes records from {@code sourceBuffer}, applies the virtual 
channel's record
+         * filter, and re-serializes the surviving records into new buffers.
+         */
+        List<Buffer> filterAndRewrite(
+                int oldSubtaskIndex,
+                int oldChannelIndex,
+                Buffer sourceBuffer,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+
+            SubtaskConnectionDescriptor key =
+                    new SubtaskConnectionDescriptor(oldSubtaskIndex, 
oldChannelIndex);
+            VirtualChannel<T> vc = virtualChannels.get(key);
+            if (vc == null) {
+                throw new IllegalStateException(
+                        "No VirtualChannel found for key: "
+                                + key
+                                + "; known channels are "
+                                + virtualChannels.keySet());
+            }
+
+            vc.setNextBuffer(sourceBuffer);
+
+            List<StreamElement> filteredElements = new ArrayList<>();
+
+            while (true) {
+                DeserializationResult result = 
vc.getNextRecord(deserializationDelegate);
+                if (result.isFullRecord()) {
+                    
filteredElements.add(deserializationDelegate.getInstance());
+                }
+                if (result.isBufferConsumed()) {
+                    break;
+                }
+            }
+
+            return serializeToBuffers(filteredElements, bufferSupplier);

Review Comment:
   ditto about the `List` in `List<StreamElement> filteredElements`. It would be safer to process the records iteratively instead of collecting them all first. The current implementation risks OOMs if the deserialized records use more memory than the serialized ones. This is not very common, but it could happen.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java:
##########
@@ -95,23 +105,60 @@ public void recover(
             InputChannelInfo channelInfo,
             int oldSubtaskIndex,
             BufferWithContext<Buffer> bufferWithContext)
-            throws IOException {
+            throws IOException, InterruptedException {
         Buffer buffer = bufferWithContext.context;
         try {
             if (buffer.readableBytes() > 0) {
                 RecoveredInputChannel channel = getMappedChannels(channelInfo);
-                channel.onRecoveredStateBuffer(
-                        EventSerializer.toBuffer(
-                                new SubtaskConnectionDescriptor(
-                                        oldSubtaskIndex, 
channelInfo.getInputChannelIdx()),
-                                false));
-                channel.onRecoveredStateBuffer(buffer.retainBuffer());
+
+                if (filteringHandler != null) {
+                    // Filtering mode: filter records and rewrite to new 
buffers
+                    recoverWithFiltering(channel, channelInfo, 
oldSubtaskIndex, buffer);
+                } else {
+                    // Non-filtering mode: pass through original buffer with 
descriptor
+                    channel.onRecoveredStateBuffer(
+                            EventSerializer.toBuffer(
+                                    new SubtaskConnectionDescriptor(
+                                            oldSubtaskIndex, 
channelInfo.getInputChannelIdx()),
+                                    false));
+                    channel.onRecoveredStateBuffer(buffer.retainBuffer());
+                }
             }
         } finally {
             buffer.recycleBuffer();
         }
     }
 
+    private void recoverWithFiltering(
+            RecoveredInputChannel channel,
+            InputChannelInfo channelInfo,
+            int oldSubtaskIndex,
+            Buffer buffer)
+            throws IOException, InterruptedException {
+        checkState(filteringHandler != null, "filtering handler not set.");
+        // Extra retain: filterAndRewrite consumes one ref, caller's finally 
releases another.
+        buffer.retainBuffer();
+
+        List<Buffer> filteredBuffers;
+        try {
+            filteredBuffers =
+                    filteringHandler.filterAndRewrite(
+                            channelInfo.getGateIdx(),
+                            oldSubtaskIndex,
+                            channelInfo.getInputChannelIdx(),
+                            buffer,
+                            channel::requestBufferBlocking);
+        } catch (Throwable t) {
+            // filterAndRewrite didn't consume the buffer, release the extra 
ref.
+            buffer.recycleBuffer();
+            throw t;
+        }

Review Comment:
   Hmm, that's a bit strange? It sounds like it's not clear who the owner of this buffer is. There should be a single, clear owner that's always responsible for cleaning it up, no matter what happens.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java:
##########
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
+import 
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel;
+import 
org.apache.flink.streaming.runtime.io.recovery.VirtualChannelRecordFilterFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Filters recovered channel state buffers during the channel-state-unspilling 
phase, removing
+ * records that do not belong to the current subtask after rescaling.
+ *
+ * <p>Uses a per-gate architecture: each {@link InputGate} gets its own {@link 
GateFilterHandler}
+ * with the correct serializer, so multi-input tasks (e.g., 
TwoInputStreamTask) correctly
+ * deserialize different record types on different gates.
+ */
+@Internal
+public class ChannelStateFilteringHandler {
+
+    /**
+     * Handles record filtering for a single input gate. Each gate has its own 
serializer and set of
+     * virtual channels, allowing different gates to handle different record 
types independently.
+     */
+    static class GateFilterHandler<T> {
+
+        private final Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels;
+        private final StreamElementSerializer<T> serializer;
+        private final DeserializationDelegate<StreamElement> 
deserializationDelegate;
+        private final DataOutputSerializer outputSerializer;
+        private final byte[] lengthBuffer = new byte[4];
+
+        GateFilterHandler(
+                Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels,
+                StreamElementSerializer<T> serializer) {
+            this.virtualChannels = checkNotNull(virtualChannels);
+            this.serializer = checkNotNull(serializer);
+            this.deserializationDelegate = new 
NonReusingDeserializationDelegate<>(serializer);
+            this.outputSerializer = new DataOutputSerializer(128);
+        }
+
+        /**
+         * Deserializes records from {@code sourceBuffer}, applies the virtual 
channel's record
+         * filter, and re-serializes the surviving records into new buffers.
+         */
+        List<Buffer> filterAndRewrite(
+                int oldSubtaskIndex,
+                int oldChannelIndex,
+                Buffer sourceBuffer,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+
+            SubtaskConnectionDescriptor key =
+                    new SubtaskConnectionDescriptor(oldSubtaskIndex, 
oldChannelIndex);
+            VirtualChannel<T> vc = virtualChannels.get(key);
+            if (vc == null) {
+                throw new IllegalStateException(
+                        "No VirtualChannel found for key: "
+                                + key
+                                + "; known channels are "
+                                + virtualChannels.keySet());
+            }
+
+            vc.setNextBuffer(sourceBuffer);
+
+            List<StreamElement> filteredElements = new ArrayList<>();
+
+            while (true) {
+                DeserializationResult result = 
vc.getNextRecord(deserializationDelegate);
+                if (result.isFullRecord()) {
+                    
filteredElements.add(deserializationDelegate.getInstance());
+                }
+                if (result.isBufferConsumed()) {
+                    break;
+                }
+            }
+
+            return serializeToBuffers(filteredElements, bufferSupplier);
+        }
+
+        /**
+         * Serializes stream elements into buffers using the length-prefixed 
format (4-byte
+         * big-endian length + record bytes) expected by Flink's record 
deserializers.
+         */
+        private List<Buffer> serializeToBuffers(
+                List<StreamElement> elements, BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+
+            List<Buffer> resultBuffers = new ArrayList<>();
+
+            if (elements.isEmpty()) {
+                return resultBuffers;
+            }
+
+            Buffer currentBuffer = bufferSupplier.requestBufferBlocking();
+
+            for (StreamElement element : elements) {
+                outputSerializer.clear();
+                serializer.serialize(element, outputSerializer);
+                int recordLength = outputSerializer.length();
+
+                writeLengthToBuffer(recordLength);
+                currentBuffer =
+                        writeDataToBuffer(
+                                lengthBuffer, 0, 4, currentBuffer, 
resultBuffers, bufferSupplier);
+
+                byte[] serializedData = outputSerializer.getSharedBuffer();
+                currentBuffer =
+                        writeDataToBuffer(
+                                serializedData,
+                                0,
+                                recordLength,
+                                currentBuffer,
+                                resultBuffers,
+                                bufferSupplier);
+            }
+
+            if (currentBuffer.readableBytes() > 0) {
+                resultBuffers.add(currentBuffer.retainBuffer());
+            }
+            currentBuffer.recycleBuffer();
+
+            return resultBuffers;
+        }
+
+        private void writeLengthToBuffer(int length) {
+            lengthBuffer[0] = (byte) (length >> 24);
+            lengthBuffer[1] = (byte) (length >> 16);
+            lengthBuffer[2] = (byte) (length >> 8);
+            lengthBuffer[3] = (byte) length;
+        }
+
+        /**
+         * Writes data to the current buffer, spilling into new buffers from 
{@code bufferSupplier}
+         * when the current one is full.
+         *
+         * @return the buffer to continue writing into (may differ from the 
input buffer).
+         */
+        private Buffer writeDataToBuffer(
+                byte[] data,
+                int dataOffset,
+                int dataLength,
+                Buffer currentBuffer,
+                List<Buffer> resultBuffers,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+            int offset = dataOffset;
+            int remaining = dataLength;
+
+            while (remaining > 0) {
+                int writableBytes = currentBuffer.getMaxCapacity() - 
currentBuffer.getSize();
+
+                if (writableBytes == 0) {
+                    if (currentBuffer.readableBytes() > 0) {
+                        resultBuffers.add(currentBuffer.retainBuffer());
+                    }
+                    currentBuffer.recycleBuffer();
+                    currentBuffer = bufferSupplier.requestBufferBlocking();
+                    writableBytes = currentBuffer.getMaxCapacity();
+                }
+
+                int bytesToWrite = Math.min(remaining, writableBytes);
+                currentBuffer
+                        .getMemorySegment()
+                        .put(
+                                currentBuffer.getMemorySegmentOffset() + 
currentBuffer.getSize(),
+                                data,
+                                offset,
+                                bytesToWrite);
+                currentBuffer.setSize(currentBuffer.getSize() + bytesToWrite);
+
+                offset += bytesToWrite;
+                remaining -= bytesToWrite;
+            }
+            return currentBuffer;
+        }
+
+        boolean hasPartialData() {
+            return 
virtualChannels.values().stream().anyMatch(VirtualChannel::hasPartialData);
+        }
+
+        void clear() {
+            virtualChannels.values().forEach(VirtualChannel::clear);
+        }
+    }
+
+    // Wildcard allows heterogeneous record types across gates.
+    private final GateFilterHandler<?>[] gateHandlers;
+
+    ChannelStateFilteringHandler(GateFilterHandler<?>[] gateHandlers) {
+        this.gateHandlers = checkNotNull(gateHandlers);
+    }
+
+    /**
+     * Creates a handler from the recovery context, building per-gate virtual 
channels based on
+     * rescaling descriptors. Returns {@code null} if no filtering is needed 
(e.g., source tasks or
+     * no rescaling).
+     */
+    @Nullable
+    public static ChannelStateFilteringHandler createFromContext(
+            RecordFilterContext filterContext, InputGate[] inputGates) {
+        // Source tasks have no network inputs
+        if (filterContext.getNumberOfGates() == 0) {
+            return null;
+        }
+
+        InflightDataRescalingDescriptor rescalingDescriptor =
+                filterContext.getRescalingDescriptor();
+
+        GateFilterHandler<?>[] gateHandlers = new 
GateFilterHandler<?>[inputGates.length];
+        boolean hasAnyVirtualChannels = false;
+
+        for (int gateIndex = 0; gateIndex < inputGates.length; gateIndex++) {
+            gateHandlers[gateIndex] =
+                    createGateHandler(filterContext, inputGates, 
rescalingDescriptor, gateIndex);
+            if (gateHandlers[gateIndex] != null) {
+                hasAnyVirtualChannels = true;
+            }
+        }
+
+        if (!hasAnyVirtualChannels) {
+            return null;
+        }
+
+        return new ChannelStateFilteringHandler(gateHandlers);
+    }
+
+    /**
+     * Creates a {@link GateFilterHandler} for a single gate. The method-level 
type parameter
+     * ensures type safety within each gate while allowing different gates to 
have different types.
+     */
+    @SuppressWarnings("unchecked")
+    @Nullable
+    private static <T> GateFilterHandler<T> createGateHandler(
+            RecordFilterContext filterContext,
+            InputGate[] inputGates,
+            InflightDataRescalingDescriptor rescalingDescriptor,
+            int gateIndex) {
+        RecordFilterContext.InputFilterConfig inputConfig = 
filterContext.getInputConfig(gateIndex);
+        if (inputConfig == null) {
+            throw new IllegalStateException(
+                    "No InputFilterConfig for gateIndex "
+                            + gateIndex
+                            + ". This indicates a bug in RecordFilterContext 
initialization.");
+        }
+
+        InputGate gate = inputGates[gateIndex];
+        int[] oldSubtaskIndexes = 
rescalingDescriptor.getOldSubtaskIndexes(gateIndex);
+        RescaleMappings channelMapping = 
rescalingDescriptor.getChannelMapping(gateIndex);
+
+        TypeSerializer<T> typeSerializer = (TypeSerializer<T>) 
inputConfig.getTypeSerializer();
+        StreamElementSerializer<T> elementSerializer =
+                new StreamElementSerializer<>(typeSerializer);
+
+        VirtualChannelRecordFilterFactory<T> filterFactory =
+                VirtualChannelRecordFilterFactory.fromContext(filterContext, 
gateIndex);
+
+        Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
gateVirtualChannels = new HashMap<>();
+
+        for (int oldSubtaskIndex : oldSubtaskIndexes) {
+            int numChannels = gate.getNumberOfInputChannels();
+            int[] oldChannelIndexes = getOldChannelIndexes(channelMapping, 
numChannels);
+
+            for (int oldChannelIndex : oldChannelIndexes) {
+                SubtaskConnectionDescriptor key =
+                        new SubtaskConnectionDescriptor(oldSubtaskIndex, 
oldChannelIndex);
+
+                if (gateVirtualChannels.containsKey(key)) {
+                    continue;
+                }
+
+                // Only ambiguous channels need actual filtering; 
non-ambiguous ones pass through
+                boolean isAmbiguous = 
rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex);
+
+                RecordFilter<T> recordFilter =
+                        isAmbiguous
+                                ? filterFactory.createFilter()
+                                : 
VirtualChannelRecordFilterFactory.createPassThroughFilter();
+
+                RecordDeserializer<DeserializationDelegate<StreamElement>> 
deserializer =
+                        createDeserializer(filterContext.getTmpDirectories());
+
+                VirtualChannel<T> vc = new VirtualChannel<>(deserializer, 
recordFilter);
+                gateVirtualChannels.put(key, vc);
+            }
+        }
+
+        if (gateVirtualChannels.isEmpty()) {
+            return null;
+        }
+
+        return new GateFilterHandler<>(gateVirtualChannels, elementSerializer);
+    }
+
+    /**
+     * Collects all old channel indexes that are mapped from any new channel 
index in this gate.
+     * channelMapping is new-to-old, so we iterate new indexes and collect 
their old counterparts.
+     */
+    private static int[] getOldChannelIndexes(RescaleMappings channelMapping, 
int numChannels) {
+        List<Integer> oldIndexes = new ArrayList<>();
+        for (int newIndex = 0; newIndex < numChannels; newIndex++) {
+            int[] mapped = channelMapping.getMappedIndexes(newIndex);
+            for (int oldIndex : mapped) {
+                if (!oldIndexes.contains(oldIndex)) {
+                    oldIndexes.add(oldIndex);
+                }
+            }
+        }
+        return oldIndexes.stream().mapToInt(Integer::intValue).toArray();
+    }
+
+    private static RecordDeserializer<DeserializationDelegate<StreamElement>> 
createDeserializer(
+            String[] tmpDirectories) {
+        if (tmpDirectories != null && tmpDirectories.length > 0) {
+            return new 
SpillingAdaptiveSpanningRecordDeserializer<>(tmpDirectories);
+        } else {
+            String[] defaultDirs = new String[] 
{System.getProperty("java.io.tmpdir")};
+            return new 
SpillingAdaptiveSpanningRecordDeserializer<>(defaultDirs);
+        }
+    }
+
+    /**
+     * Filters a recovered buffer from the specified virtual channel, 
returning new buffers
+     * containing only the records that belong to the current subtask.
+     *
+     * @return filtered buffers, possibly empty if all records were filtered 
out.
+     */
+    public List<Buffer> filterAndRewrite(
+            int gateIndex,
+            int oldSubtaskIndex,
+            int oldChannelIndex,
+            Buffer sourceBuffer,
+            BufferSupplier bufferSupplier)

Review Comment:
   Why does it return a `List` from one single `sourceBuffer`? Could you explain
this in the Javadoc? And how many `Buffers` can that be? If it can be a lot,
shouldn't this be an `Iterator`?



##########
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java:
##########
    public void recover(
            InputChannelInfo channelInfo,
            int oldSubtaskIndex,
            BufferWithContext<Buffer> bufferWithContext)
            throws IOException, InterruptedException {
        Buffer buffer = bufferWithContext.context;
        try {
            if (buffer.readableBytes() > 0) {
                RecoveredInputChannel channel = getMappedChannels(channelInfo);

                if (filteringHandler != null) {
                    // Filtering mode: filter records and rewrite to new buffers
                    recoverWithFiltering(channel, channelInfo, oldSubtaskIndex, buffer);
                } else {
                    // Non-filtering mode: pass through original buffer with descriptor.
                    // The SubtaskConnectionDescriptor event identifies the old
                    // (subtask, channel) pair the following buffer was recovered from.
                    channel.onRecoveredStateBuffer(
                            EventSerializer.toBuffer(
                                    new SubtaskConnectionDescriptor(
                                            oldSubtaskIndex, channelInfo.getInputChannelIdx()),
                                    false));
                    // Retain: the channel takes over one reference; the finally block below
                    // releases this method's own reference.
                    channel.onRecoveredStateBuffer(buffer.retainBuffer());
                }
            }
        } finally {
            // Release this method's reference on every path (empty buffer, filtering,
            // pass-through, or exception).
            buffer.recycleBuffer();
        }
    }
 
+    private void recoverWithFiltering(
+            RecoveredInputChannel channel,
+            InputChannelInfo channelInfo,
+            int oldSubtaskIndex,
+            Buffer buffer)
+            throws IOException, InterruptedException {
+        checkState(filteringHandler != null, "filtering handler not set.");
+        // Extra retain: filterAndRewrite consumes one ref, caller's finally 
releases another.
+        buffer.retainBuffer();

Review Comment:
   nit: I think it would be slightly cleaner to call `buffer.retainBuffer` from
the outside; the contract would then be that this method always takes over
ownership of this buffer.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java:
##########
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
+import 
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel;
+import 
org.apache.flink.streaming.runtime.io.recovery.VirtualChannelRecordFilterFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+
+import javax.annotation.Nullable;
+
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Filters recovered channel state buffers during the channel-state-unspilling 
phase, removing
+ * records that do not belong to the current subtask after rescaling.
+ *
+ * <p>Uses a per-gate architecture: each {@link InputGate} gets its own {@link 
GateFilterHandler}
+ * with the correct serializer, so multi-input tasks (e.g., 
TwoInputStreamTask) correctly
+ * deserialize different record types on different gates.
+ */
+@Internal
+public class ChannelStateFilteringHandler {
+
+    /**
+     * Handles record filtering for a single input gate. Each gate has its own 
serializer and set of
+     * virtual channels, allowing different gates to handle different record 
types independently.
+     */
+    static class GateFilterHandler<T> {
+
+        private final Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels;
+        private final StreamElementSerializer<T> serializer;
+        private final DeserializationDelegate<StreamElement> 
deserializationDelegate;
+        private final DataOutputSerializer outputSerializer;
+        private final byte[] lengthBuffer = new byte[4];
+
+        GateFilterHandler(
+                Map<SubtaskConnectionDescriptor, VirtualChannel<T>> 
virtualChannels,
+                StreamElementSerializer<T> serializer) {
+            this.virtualChannels = checkNotNull(virtualChannels);
+            this.serializer = checkNotNull(serializer);
+            this.deserializationDelegate = new 
NonReusingDeserializationDelegate<>(serializer);
+            this.outputSerializer = new DataOutputSerializer(128);
+        }
+
+        /**
+         * Deserializes records from {@code sourceBuffer}, applies the virtual 
channel's record
+         * filter, and re-serializes the surviving records into new buffers.
+         */
+        List<Buffer> filterAndRewrite(
+                int oldSubtaskIndex,
+                int oldChannelIndex,
+                Buffer sourceBuffer,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+
+            SubtaskConnectionDescriptor key =
+                    new SubtaskConnectionDescriptor(oldSubtaskIndex, 
oldChannelIndex);
+            VirtualChannel<T> vc = virtualChannels.get(key);
+            if (vc == null) {
+                throw new IllegalStateException(
+                        "No VirtualChannel found for key: "
+                                + key
+                                + "; known channels are "
+                                + virtualChannels.keySet());
+            }
+
+            vc.setNextBuffer(sourceBuffer);
+
+            List<StreamElement> filteredElements = new ArrayList<>();
+
+            while (true) {
+                DeserializationResult result = 
vc.getNextRecord(deserializationDelegate);
+                if (result.isFullRecord()) {
+                    
filteredElements.add(deserializationDelegate.getInstance());
+                }
+                if (result.isBufferConsumed()) {
+                    break;
+                }
+            }
+
+            return serializeToBuffers(filteredElements, bufferSupplier);
+        }
+
+        /**
+         * Serializes stream elements into buffers using the length-prefixed 
format (4-byte
+         * big-endian length + record bytes) expected by Flink's record 
deserializers.
+         */
+        private List<Buffer> serializeToBuffers(
+                List<StreamElement> elements, BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+
+            List<Buffer> resultBuffers = new ArrayList<>();
+
+            if (elements.isEmpty()) {
+                return resultBuffers;
+            }
+
+            Buffer currentBuffer = bufferSupplier.requestBufferBlocking();
+
+            for (StreamElement element : elements) {
+                outputSerializer.clear();
+                serializer.serialize(element, outputSerializer);
+                int recordLength = outputSerializer.length();
+
+                writeLengthToBuffer(recordLength);
+                currentBuffer =
+                        writeDataToBuffer(
+                                lengthBuffer, 0, 4, currentBuffer, 
resultBuffers, bufferSupplier);
+
+                byte[] serializedData = outputSerializer.getSharedBuffer();
+                currentBuffer =
+                        writeDataToBuffer(
+                                serializedData,
+                                0,
+                                recordLength,
+                                currentBuffer,
+                                resultBuffers,
+                                bufferSupplier);
+            }
+
+            if (currentBuffer.readableBytes() > 0) {
+                resultBuffers.add(currentBuffer.retainBuffer());
+            }
+            currentBuffer.recycleBuffer();
+
+            return resultBuffers;
+        }
+
+        private void writeLengthToBuffer(int length) {
+            lengthBuffer[0] = (byte) (length >> 24);
+            lengthBuffer[1] = (byte) (length >> 16);
+            lengthBuffer[2] = (byte) (length >> 8);
+            lengthBuffer[3] = (byte) length;
+        }
+
+        /**
+         * Writes data to the current buffer, spilling into new buffers from 
{@code bufferSupplier}
+         * when the current one is full.
+         *
+         * @return the buffer to continue writing into (may differ from the 
input buffer).
+         */
+        private Buffer writeDataToBuffer(
+                byte[] data,
+                int dataOffset,
+                int dataLength,
+                Buffer currentBuffer,
+                List<Buffer> resultBuffers,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+            int offset = dataOffset;
+            int remaining = dataLength;
+
+            while (remaining > 0) {
+                int writableBytes = currentBuffer.getMaxCapacity() - 
currentBuffer.getSize();
+
+                if (writableBytes == 0) {
+                    if (currentBuffer.readableBytes() > 0) {
+                        resultBuffers.add(currentBuffer.retainBuffer());
+                    }
+                    currentBuffer.recycleBuffer();
+                    currentBuffer = bufferSupplier.requestBufferBlocking();

Review Comment:
   Is it safe to block here? 🤔 Can this lead to deadlocks? I think we were 
discussing this, but AFAIR this code works differently to what we were 
discussing offline (either using unpooled buffer or create two different pools, 
or filter records in-place without requesting new buffer)?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to