TanYuxin-tyx commented on code in PR #23851:
URL: https://github.com/apache/flink/pull/23851#discussion_r1462693982


##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/common/TieredStorageConfiguration.java:
##########
@@ -133,10 +157,14 @@ public static Builder builder(String remoteStorageBasePath) {
                 .setRemoteStorageBasePath(remoteStorageBasePath);
     }
 
-    public static Builder builder(int tieredStorageBufferSize, String remoteStorageBasePath) {
+    public static Builder builder(
+            int tieredStorageBufferSize,
+            String remoteStorageBasePath,
+            boolean enableMemorySafeMode) {

Review Comment:
   Since you have added a setter, we no longer need to change this `builder`
   factory method.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/SortBufferAccumulator.java:
##########
@@ -94,29 +99,58 @@ public class SortBufferAccumulator implements BufferAccumulator {
     @Nullable
     private TriConsumer<TieredStorageSubpartitionId, Buffer, Integer> accumulatedBufferFlusher;
 
+    /**
+     * An executor to periodically check the size of buffer pool. If the size is changed, the
+     * accumulated buffers should be flushed to release the buffers.
+     */
+    private final ScheduledExecutorService periodicalAccumulatorFlusher =
+            Executors.newSingleThreadScheduledExecutor(
+                    new ExecutorThreadFactory("hybrid-shuffle-periodical-accumulator-flusher"));
+
+    private final long poolSizeCheckInterval;
+
+    private AtomicInteger poolSize;
+
     /** Whether the current {@link DataBuffer} is a broadcast sort buffer. */
     private boolean isBroadcastDataBuffer;
 
     public SortBufferAccumulator(
             int numSubpartitions,
-            int numBuffers,
+            int numExpectedBuffers,
             int bufferSizeBytes,
+            long poolSizeCheckInterval,
             TieredStorageMemoryManager memoryManager,
             boolean isPartialRecordAllowed) {
         this.numSubpartitions = numSubpartitions;
         this.bufferSizeBytes = bufferSizeBytes;
-        this.numBuffers = numBuffers;
+        this.numExpectedBuffers = numExpectedBuffers;
+        this.poolSizeCheckInterval = poolSizeCheckInterval;
         this.memoryManager = memoryManager;
         this.isPartialRecordAllowed = isPartialRecordAllowed;
     }
 
     @Override
     public void setup(TriConsumer<TieredStorageSubpartitionId, Buffer, Integer> bufferFlusher) {
         this.accumulatedBufferFlusher = bufferFlusher;
+        this.poolSize = new AtomicInteger(memoryManager.getBufferPoolSize());
+
+        if (poolSizeCheckInterval > 0) {
+            periodicalAccumulatorFlusher.scheduleAtFixedRate(

Review Comment:
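   I think you want `scheduleWithFixedDelay` here, which starts counting the
   delay only after the previous run has finished. With `scheduleAtFixedRate`,
   a run that takes longer than the period makes the following runs start late
   and then execute back to back.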
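
   A minimal sketch of the difference, using only the JDK's
   `ScheduledExecutorService.scheduleWithFixedDelay`. The class name, the
   helper, and the sleep standing in for a slow flush are all hypothetical and
   not part of the PR. It measures the gap between the end of one run and the
   start of the next, which should never drop below the configured delay.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

/** Hypothetical demo class, not part of the PR. */
final class FixedDelayDemo {

    /**
     * Runs a task that sleeps for taskMillis under scheduleWithFixedDelay and returns the
     * smallest observed gap (in nanoseconds) between the end of one run and the start of
     * the next. Call it with runs >= 2, otherwise no gap is ever measured.
     */
    static long minGapNanos(int runs, long taskMillis, long delayMillis)
            throws InterruptedException {
        ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
        CountDownLatch finishedRuns = new CountDownLatch(runs);
        AtomicLong lastEnd = new AtomicLong(Long.MIN_VALUE); // MIN_VALUE means "no run yet"
        AtomicLong minGap = new AtomicLong(Long.MAX_VALUE);
        try {
            executor.scheduleWithFixedDelay(
                    () -> {
                        long start = System.nanoTime();
                        long previousEnd = lastEnd.get();
                        if (previousEnd != Long.MIN_VALUE) {
                            minGap.accumulateAndGet(start - previousEnd, Math::min);
                        }
                        try {
                            Thread.sleep(taskMillis); // stands in for a slow flush
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            return;
                        }
                        lastEnd.set(System.nanoTime());
                        finishedRuns.countDown();
                    },
                    0,
                    delayMillis,
                    TimeUnit.MILLISECONDS);
            finishedRuns.await(); // wait until the requested number of runs completed
        } finally {
            executor.shutdownNow();
        }
        return minGap.get();
    }
}
```

   For example, `FixedDelayDemo.minGapNanos(3, 20, 30)` should report a gap of
   roughly 30 ms or more, however long each run takes.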



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/SortBufferAccumulator.java:
##########
@@ -94,29 +99,58 @@ public class SortBufferAccumulator implements BufferAccumulator {
     @Nullable
     private TriConsumer<TieredStorageSubpartitionId, Buffer, Integer> accumulatedBufferFlusher;
 
+    /**
+     * An executor to periodically check the size of buffer pool. If the size is changed, the
+     * accumulated buffers should be flushed to release the buffers.
+     */
+    private final ScheduledExecutorService periodicalAccumulatorFlusher =
+            Executors.newSingleThreadScheduledExecutor(

Review Comment:
   Is this executor necessary?
   Why don't we register a listener that is notified when the buffer pool size
   changes, instead of polling for the change?
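
   A rough sketch of the listen-and-notify idea. `ObservablePool` and
   `addSizeListener` are hypothetical names made up for illustration; I am not
   claiming that Flink's `BufferPool` offers such a hook today.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.IntConsumer;

/** Hypothetical pool, not Flink's BufferPool; it only illustrates the callback idea. */
final class ObservablePool {

    private final List<IntConsumer> sizeListeners = new CopyOnWriteArrayList<>();
    private int numBuffers;

    ObservablePool(int numBuffers) {
        this.numBuffers = numBuffers;
    }

    /** Registers a callback that receives the new size whenever the size changes. */
    void addSizeListener(IntConsumer listener) {
        sizeListeners.add(listener);
    }

    /** Updates the size and notifies listeners only if the value actually changed. */
    synchronized void setNumBuffers(int newSize) {
        if (newSize == numBuffers) {
            return;
        }
        numBuffers = newSize;
        for (IntConsumer listener : sizeListeners) {
            listener.accept(newSize);
        }
    }
}
```

   The accumulator would then register its flush as a listener once in
   `setup`, and no periodic thread would be needed.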



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/common/TieredStorageConfiguration.java:
##########
@@ -344,6 +401,21 @@ public Builder setRemoteStorageBasePath(String remoteStorageBasePath) {
             return this;
         }
 
+        public Builder setMemorySafeModeEnabled(boolean memorySafeModeEnabled) {
+            this.memorySafeModeEnabled = memorySafeModeEnabled;
+            return this;
+        }
+
+        public Builder setMinBuffersPerGate(int minBuffersPerGate) {
+            this.minBuffersPerGate = minBuffersPerGate;
+            return this;
+        }
+

Review Comment:
   We should also add a setter for `poolSizeCheckInterval`, in case other
   callers need to configure it.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferManager.java:
##########
@@ -132,16 +134,40 @@ private boolean shouldContinueRequest(BufferPool bufferPool) {
         }
     }
 
-    /** Requests exclusive buffers from the provider. */
-    void requestExclusiveBuffers(int numExclusiveBuffers) throws IOException {
-        checkArgument(numExclusiveBuffers >= 0, "Num exclusive buffers must be non-negative.");
-        if (numExclusiveBuffers == 0) {
-            return;
+    private void resizeBufferQueue() {
+        int currentSize = inputChannel.inputGate.getBufferPool().getNumBuffers();
+        if (currentSize > 1 && currentSize != bufferPoolSize) {
+            int numChannels = inputChannel.inputGate.getNumberOfInputChannels();
+            int targetExclusivePerChannel =
+                    Math.min(initialCredit, (currentSize - 1) / numChannels);
+            numExclusiveBuffers = targetExclusivePerChannel;
+            bufferPoolSize = currentSize;
         }
+    }
 
-        Collection<MemorySegment> segments =
-                globalPool.requestUnpooledMemorySegments(numExclusiveBuffers);
+    /** Requests exclusive buffers from the provider. */
+    void requestExclusiveBuffers() {
         synchronized (bufferQueue) {
+            checkArgument(numExclusiveBuffers >= 0, "Num exclusive buffers must be non-negative.");

Review Comment:
   `checkArgument` -> `checkState`, since `numExclusiveBuffers` is now a field
   rather than a method argument.
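
   To illustrate the distinction, a small sketch with local stand-ins for the
   two checks. The class and its helpers are hypothetical and only mimic the
   usual convention: an argument check fails with `IllegalArgumentException`,
   a state check with `IllegalStateException`.

```java
/** Hypothetical stand-in, not Flink's Preconditions or BufferManager. */
final class PreconditionsSketch {

    int numExclusiveBuffers; // a field, updated elsewhere (e.g. when the pool is resized)

    /** Use for values passed in by the caller. */
    static void checkArgument(boolean condition, String message) {
        if (!condition) {
            throw new IllegalArgumentException(message);
        }
    }

    /** Use for invariants of the object's own fields. */
    static void checkState(boolean condition, String message) {
        if (!condition) {
            throw new IllegalStateException(message);
        }
    }

    /** Argument check: the value comes from the caller. */
    void setNumExclusiveBuffers(int newValue) {
        checkArgument(newValue >= 0, "Num exclusive buffers must be non-negative.");
        this.numExclusiveBuffers = newValue;
    }

    /** State check: there is no argument, only the field is inspected. */
    void requestExclusiveBuffers() {
        checkState(numExclusiveBuffers >= 0, "Num exclusive buffers must be non-negative.");
        // ... request the buffers ...
    }
}
```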



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPool.java:
##########
@@ -600,73 +607,97 @@ private void redistributeBuffers() {
         }
 
         // All buffers, which are not among the required ones
-        final int numAvailableMemorySegment = totalNumberOfMemorySegments - numTotalRequiredBuffers;
+        int numAvailableMemorySegment = totalNumberOfMemorySegments - numTotalRequiredBuffers;
 
         if (numAvailableMemorySegment == 0) {
             // in this case, we need to redistribute buffers so that every pool gets its minimum
             for (LocalBufferPool bufferPool : resizableBufferPools) {
-                bufferPool.setNumBuffers(bufferPool.getNumberOfRequiredMemorySegments());
+                bufferPool.setNumBuffers(bufferPool.getMinNumberOfMemorySegments());
             }
             return;
         }
 
-        /*
-         * With buffer pools being potentially limited, let's distribute the available memory
-         * segments based on the capacity of each buffer pool, i.e. the maximum number of segments
-         * an unlimited buffer pool can take is numAvailableMemorySegment, for limited buffer pools
-         * it may be less. Based on this and the sum of all these values (totalCapacity), we build
-         * a ratio that we use to distribute the buffers.
-         */
+        Map<LocalBufferPool, Integer> cachedPoolSize =
+                resizableBufferPools.stream()
+                        .collect(
+                                Collectors.toMap(
+                                        Function.identity(),
+                                        LocalBufferPool::getMinNumberOfMemorySegments));
+
+        while (true) {
+            int remaining = redistributeBuffers(numAvailableMemorySegment, cachedPoolSize);
+
+            // Stop the loop iteration when there is no remaining segments or all local buffer pools
+            // have reached the max number.
+            if (remaining == 0 || remaining == numAvailableMemorySegment) {
+                for (LocalBufferPool bufferPool : resizableBufferPools) {
+                    bufferPool.setNumBuffers(
+                            cachedPoolSize.getOrDefault(
+                                    bufferPool, bufferPool.getMinNumberOfMemorySegments()));
+                }
+                break;
+            }
+            numAvailableMemorySegment = remaining;
+        }
+    }
 
-        long totalCapacity = 0; // long to avoid int overflow
+    /**
+     * @param numBuffersToRedistribute the buffers to be redistributed.
+     * @param cachedPoolSize the map to cache the intermediate result.
+     * @return the remaining buffers that can continue to be redistributed.
+     */
+    private int redistributeBuffers(
+            int numBuffersToRedistribute, Map<LocalBufferPool, Integer> cachedPoolSize) {
+        long totalWeight = 0;
 
+        // Calculates the total weights of all local buffer pools that can be distributed
+        // buffers.
         for (LocalBufferPool bufferPool : resizableBufferPools) {
-            int excessMax =
-                    bufferPool.getMaxNumberOfMemorySegments()
-                            - bufferPool.getNumberOfRequiredMemorySegments();
-            totalCapacity += Math.min(numAvailableMemorySegment, excessMax);
-        }
-
-        // no capacity to receive additional buffers?
-        if (totalCapacity == 0) {
-            return; // necessary to avoid div by zero when nothing to re-distribute
+            if (cachedPoolSize.get(bufferPool) == bufferPool.getMaxNumberOfMemorySegments()) {
+                continue;
+            }
+            totalWeight += getWeight(bufferPool);
         }
 
-        // since one of the arguments of 'min(a,b)' is a positive int, this is actually
-        // guaranteed to be within the 'int' domain
-        // (we use a checked downCast to handle possible bugs more gracefully).
-        final int memorySegmentsToDistribute =
-                MathUtils.checkedDownCast(Math.min(numAvailableMemorySegment, totalCapacity));
+        int totalAllocated = 0;
 
-        long totalPartsUsed = 0; // of totalCapacity
-        int numDistributedMemorySegment = 0;
         for (LocalBufferPool bufferPool : resizableBufferPools) {
-            int excessMax =
-                    bufferPool.getMaxNumberOfMemorySegments()
-                            - bufferPool.getNumberOfRequiredMemorySegments();
-
-            // shortcut
-            if (excessMax == 0) {
+            if (cachedPoolSize.get(bufferPool) == bufferPool.getMaxNumberOfMemorySegments()) {
                 continue;
             }
+            double fraction = (double) getWeight(bufferPool) / totalWeight;

Review Comment:
   I think float is enough for the calculation.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/SortBufferAccumulator.java:
##########
@@ -94,29 +99,58 @@ public class SortBufferAccumulator implements BufferAccumulator {
     @Nullable
     private TriConsumer<TieredStorageSubpartitionId, Buffer, Integer> accumulatedBufferFlusher;
 
+    /**
+     * An executor to periodically check the size of buffer pool. If the size is changed, the
+     * accumulated buffers should be flushed to release the buffers.
+     */
+    private final ScheduledExecutorService periodicalAccumulatorFlusher =
+            Executors.newSingleThreadScheduledExecutor(
+                    new ExecutorThreadFactory("hybrid-shuffle-periodical-accumulator-flusher"));
+
+    private final long poolSizeCheckInterval;
+
+    private AtomicInteger poolSize;
+
     /** Whether the current {@link DataBuffer} is a broadcast sort buffer. */
     private boolean isBroadcastDataBuffer;
 
     public SortBufferAccumulator(
             int numSubpartitions,
-            int numBuffers,
+            int numExpectedBuffers,
             int bufferSizeBytes,
+            long poolSizeCheckInterval,
             TieredStorageMemoryManager memoryManager,
             boolean isPartialRecordAllowed) {
         this.numSubpartitions = numSubpartitions;
         this.bufferSizeBytes = bufferSizeBytes;
-        this.numBuffers = numBuffers;
+        this.numExpectedBuffers = numExpectedBuffers;
+        this.poolSizeCheckInterval = poolSizeCheckInterval;
         this.memoryManager = memoryManager;
         this.isPartialRecordAllowed = isPartialRecordAllowed;
     }
 
     @Override
     public void setup(TriConsumer<TieredStorageSubpartitionId, Buffer, Integer> bufferFlusher) {
         this.accumulatedBufferFlusher = bufferFlusher;
+        this.poolSize = new AtomicInteger(memoryManager.getBufferPoolSize());
+
+        if (poolSizeCheckInterval > 0) {
+            periodicalAccumulatorFlusher.scheduleAtFixedRate(
+                    () -> {
+                        int newSize = this.memoryManager.getBufferPoolSize();
+                        int oldSize = poolSize.getAndSet(newSize);
+                        if (oldSize != newSize) {

Review Comment:
   Why do we flush the buffers when the buffer pool size is updated?



##########
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/TestingBufferPool.java:
##########
@@ -21,11 +21,12 @@
 package org.apache.flink.runtime.io.network.buffer;
 
 import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
 
 import java.util.concurrent.CompletableFuture;
 
-/** No-op implementation of {@link BufferPool}. */
-public class NoOpBufferPool implements BufferPool {
+/** Implementation of {@link BufferPool} for testing. */

Review Comment:
   If we want to create a `TestingBufferPool`, we should build it the same way
   as `TestingHsDataView`, so that other developers can reuse it in any test.



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/SortBufferAccumulator.java:
##########
@@ -144,6 +178,7 @@ public void receive(
 
     @Override
     public void close() {
+        periodicalAccumulatorFlusher.shutdown();

Review Comment:
   Please refer to how the executor is shut down in
   `TieredStorageMemoryManagerImpl`.
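
   For reference, a generic sketch of a graceful shutdown using only JDK
   calls. It is not copied from `TieredStorageMemoryManagerImpl`; the helper
   name and the timeout handling are assumptions, so please check that class
   for the exact behaviour we want to mirror.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;

/** Hypothetical helper, not part of the PR. */
final class ExecutorShutdownSketch {

    /**
     * Stops accepting new tasks, waits for running ones, and falls back to interrupting them.
     *
     * @return true if the executor terminated within the given time budget.
     */
    static boolean shutdownGracefully(ExecutorService executor, long timeout, TimeUnit unit) {
        executor.shutdown(); // no new tasks, already submitted ones may finish
        try {
            if (executor.awaitTermination(timeout, unit)) {
                return true;
            }
            executor.shutdownNow(); // interrupt tasks that are still running
            return executor.awaitTermination(timeout, unit);
        } catch (InterruptedException e) {
            executor.shutdownNow();
            Thread.currentThread().interrupt(); // preserve the interrupt for the caller
            return false;
        }
    }
}
```

   A plain `shutdown()` in `close()` returns immediately, so a flush that is
   still running could outlive the accumulator.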



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMemoryManagerImpl.java:
##########
@@ -335,19 +338,21 @@ private MemorySegment requestBufferBlockingFromPool() {
     }
 
     /** @return a memory segment from the internal buffer queue. */
-    private MemorySegment requestBufferBlockingFromQueue() {
+    private MemorySegment requestBufferFromQueue() {
         CompletableFuture<Void> requestBufferFuture = new CompletableFuture<>();
         scheduleCheckRequestBufferFuture(
                 requestBufferFuture, INITIAL_REQUEST_BUFFER_TIMEOUT_FOR_RECLAIMING_MS);
 
+        hardBackpressureTimerGauge.markStart();
         MemorySegment memorySegment = null;
         try {
-            memorySegment = bufferQueue.take();
+            memorySegment = bufferQueue.poll(100, TimeUnit.MILLISECONDS);

Review Comment:
   If no segment becomes available within 100 ms, `poll` returns null, and
   `requestBufferBlocking` may then throw an NPE on the null segment. Is this
   expected?
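
   One possible way to handle the timeout, as a sketch only. The helper name,
   the attempt limit, and the exception type are assumptions and are not taken
   from the PR. `BlockingQueue.poll(timeout, unit)` returns null when the
   timeout elapses, so the caller has to either retry or fail explicitly.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;

/** Hypothetical helper, not part of the PR. */
final class PollTimeoutSketch {

    /**
     * Polls the queue in slices of timeoutMillis and never returns null: it either returns
     * an element or fails with a descriptive exception after maxAttempts empty polls.
     */
    static <T> T pollOrFail(BlockingQueue<T> queue, long timeoutMillis, int maxAttempts)
            throws InterruptedException {
        for (int attempt = 0; attempt < maxAttempts; attempt++) {
            T item = queue.poll(timeoutMillis, TimeUnit.MILLISECONDS);
            if (item != null) {
                return item;
            }
            // A null here only means the timeout elapsed; other conditions could be checked
            // at this point before the next attempt.
        }
        throw new IllegalStateException(
                "No buffer available after " + maxAttempts + " attempts.");
    }
}
```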



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
