Github user zhijiangW commented on a diff in the pull request: https://github.com/apache/flink/pull/4509#discussion_r152737920 --- Diff: flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java --- @@ -301,81 +306,388 @@ public void testProducerFailedException() throws Exception { } /** - * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying the exclusive segment is - * recycled to available buffers directly and it triggers notify of announced credit. + * Tests to verify that the input channel requests floating buffers from buffer pool + * in order to maintain backlog + initialCredit buffers available once receiving the + * sender's backlog, and registers as listener if no floating buffers available. */ @Test - public void testRecycleExclusiveBufferBeforeReleased() throws Exception { - final SingleInputGate inputGate = mock(SingleInputGate.class); - final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate)); + public void testRequestFloatingBufferOnSenderBacklog() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(12, 32, MemoryType.HEAP); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + try { + final int numFloatingBuffers = 10; + final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + + // Assign exclusive segments to the channel + final int numExclusiveBuffers = 2; + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, inputChannel.getNumberOfAvailableBuffers()); - // Recycle exclusive segment - 
inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + // Receive the producer's backlog + inputChannel.onSenderBacklog(8); - assertEquals("There should be one buffer available after recycle.", - 1, inputChannel.getNumberOfAvailableBuffers()); - verify(inputChannel, times(1)).notifyCreditAvailable(); + // Request the number of floating buffers by the formula of backlog + initialCredit - availableBuffers + verify(bufferPool, times(8)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 10 buffers available in the channel", + 10, inputChannel.getNumberOfAvailableBuffers()); - inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + inputChannel.onSenderBacklog(11); - assertEquals("There should be two buffers available after recycle.", - 2, inputChannel.getNumberOfAvailableBuffers()); - // It should be called only once when increased from zero. - verify(inputChannel, times(1)).notifyCreditAvailable(); + // Need extra three floating buffers, but only two buffers available in buffer pool, register as listener as a result + verify(bufferPool, times(11)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 12 buffers available in the channel", + 12, inputChannel.getNumberOfAvailableBuffers()); + + inputChannel.onSenderBacklog(12); + + // Already in the status of waiting for buffers and will not request any more + verify(bufferPool, times(11)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + + } finally { + // Release all the buffer resources + inputChannel.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } } /** - * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying the exclusive segment is - * recycled to global pool via input gate when channel is released. 
+ * Tests to verify that the buffer pool will distribute available floating buffers among + * all the channel listeners in a fair way. */ @Test - public void testRecycleExclusiveBufferAfterReleased() throws Exception { + public void testFairDistributionFloatingBuffers() throws Exception { // Setup - final SingleInputGate inputGate = mock(SingleInputGate.class); - final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate)); + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(12, 32, MemoryType.HEAP); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel channel1 = spy(createRemoteInputChannel(inputGate)); + final RemoteInputChannel channel2 = spy(createRemoteInputChannel(inputGate)); + final RemoteInputChannel channel3 = spy(createRemoteInputChannel(inputGate)); + try { + final int numFloatingBuffers = 3; + final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + + // Assign exclusive segments to the channels + inputGate.setInputChannel(channel1.partitionId.getPartitionId(), channel1); + inputGate.setInputChannel(channel2.partitionId.getPartitionId(), channel2); + inputGate.setInputChannel(channel3.partitionId.getPartitionId(), channel3); + final int numExclusiveBuffers = 2; + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + // Exhaust all the floating buffers + final List<Buffer> floatingBuffers = new ArrayList<>(numFloatingBuffers); + for (int i = 0; i < numFloatingBuffers; i++) { + Buffer buffer = bufferPool.requestBuffer(); + assertNotNull(buffer); + floatingBuffers.add(buffer); + } + + // Receive the producer's backlog to trigger request floating buffers from pool + // and register as listeners as a result + channel1.onSenderBacklog(8); + channel2.onSenderBacklog(8); + channel3.onSenderBacklog(8); + + verify(bufferPool, times(1)).addBufferListener(channel1); + 
verify(bufferPool, times(1)).addBufferListener(channel2); + verify(bufferPool, times(1)).addBufferListener(channel3); + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, channel1.getNumberOfAvailableBuffers()); + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, channel2.getNumberOfAvailableBuffers()); + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, channel3.getNumberOfAvailableBuffers()); + + // Recycle three floating buffers to trigger notify buffer available + for (Buffer buffer : floatingBuffers) { + buffer.recycle(); + } + + verify(channel1, times(1)).notifyBufferAvailable(any(Buffer.class)); + verify(channel2, times(1)).notifyBufferAvailable(any(Buffer.class)); + verify(channel3, times(1)).notifyBufferAvailable(any(Buffer.class)); + assertEquals("There should be 3 buffers available in the channel", 3, channel1.getNumberOfAvailableBuffers()); + assertEquals("There should be 3 buffers available in the channel", 3, channel2.getNumberOfAvailableBuffers()); + assertEquals("There should be 3 buffers available in the channel", 3, channel3.getNumberOfAvailableBuffers()); + + } finally { + // Release all the buffer resources + channel1.releaseAllResources(); + channel2.releaseAllResources(); + channel3.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } + } + + /** + * Tests to verify that there is no race condition with two things running in parallel: + * requesting floating buffers on sender backlog and some other thread releasing + * the input channel. 
+ */ + @Test + public void testConcurrentOnSenderBacklogAndRelease() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(256, 32, MemoryType.HEAP); + final ExecutorService executor = Executors.newFixedThreadPool(2); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(128, 128); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, 2); + + final Callable<Void> requestBufferTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + while (true) { + for (int j = 1; j <= 128; j++) { + inputChannel.onSenderBacklog(j); + } + + if (inputChannel.isReleased()) { + return null; + } + } + } + }; - inputChannel.releaseAllResources(); + final Callable<Void> releaseTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + inputChannel.releaseAllResources(); + + return null; + } + }; + + // Submit tasks and wait to finish + final List<Future<Void>> results = Lists.newArrayListWithCapacity(2); + results.add(executor.submit(requestBufferTask)); + results.add(executor.submit(releaseTask)); + for (Future<Void> result : results) { + result.get(); + } - // Recycle exclusive segment after channel released - inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + assertEquals("There should be no buffers available in the channel.", + 0, inputChannel.getNumberOfAvailableBuffers()); - assertEquals("Resource leak during recycling buffer after channel is released.", - 0, inputChannel.getNumberOfAvailableBuffers()); - verify(inputChannel, times(0)).notifyCreditAvailable(); - verify(inputGate, times(1)).returnExclusiveSegments(anyListOf(MemorySegment.class)); + } finally { + // 
Release all the buffer resources once exception + if (!inputChannel.isReleased()) { + inputChannel.releaseAllResources(); + } + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + + executor.shutdown(); + } } /** - * Tests {@link RemoteInputChannel#releaseAllResources()}, verifying the exclusive segments are - * recycled to global pool via input gate and no resource leak. + * Tests to verify that there is no race condition with two things running in parallel: + * requesting floating buffers on sender backlog and some other thread recycling + * floating or exclusive buffers. */ @Test - public void testReleaseExclusiveBuffers() throws Exception { + public void testConcurrentOnSenderBacklogAndRecycle() throws Exception { // Setup - final SingleInputGate inputGate = mock(SingleInputGate.class); - final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(256, 32, MemoryType.HEAP); + final ExecutorService executor = Executors.newFixedThreadPool(2); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final int numFloatingBuffers = 128; + final int numExclusiveSegments = 2; + final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + + // Exhaust all the floating buffers + final List<Buffer> floatingBuffers = new ArrayList<>(numFloatingBuffers); + for (int i = 0; i < numFloatingBuffers; i++) { + Buffer buffer = bufferPool.requestBuffer(); + assertNotNull(buffer); + floatingBuffers.add(buffer); + } - // Assign exclusive segments to channel - final List<MemorySegment> exclusiveSegments = new ArrayList<>(); - final int 
numExclusiveBuffers = 2; - for (int i = 0; i < numExclusiveBuffers; i++) { - exclusiveSegments.add(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + // Exhaust all the exclusive buffers + final List<Buffer> exclusiveBuffers = new ArrayList<>(numExclusiveSegments); + for (int i = 0; i < numExclusiveSegments; i++) { + Buffer buffer = inputChannel.requestBuffer(); + assertNotNull(buffer); + exclusiveBuffers.add(buffer); + } + + final int backlog = 128; + final Callable<Void> requestBufferTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + for (int j = 1; j <= backlog; j++) { + inputChannel.onSenderBacklog(j); + } + + return null; + } + }; + + final Callable<Void> recycleBufferTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + // Recycle all the exclusive buffers + for (Buffer buffer : exclusiveBuffers) { + buffer.recycle(); + } + + // Recycle all the floating buffers + for (Buffer buffer : floatingBuffers) { + buffer.recycle(); + } + + return null; + } + }; + + // Submit tasks and wait to finish + final List<Future<Void>> results = Lists.newArrayListWithCapacity(2); + results.add(executor.submit(requestBufferTask)); + results.add(executor.submit(recycleBufferTask)); + for (Future<Void> result : results) { + result.get(); + } + + final int numRequiredBuffers = backlog + numExclusiveSegments; + assertEquals("There should be " + numRequiredBuffers +" buffers available in channel.", + numRequiredBuffers, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be no buffers available in buffer pool.", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer resources + inputChannel.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + + executor.shutdown(); } - inputChannel.assignExclusiveSegments(exclusiveSegments); + } - assertEquals("The number of available buffers is not equal to 
the assigned amount.", - numExclusiveBuffers, inputChannel.getNumberOfAvailableBuffers()); + /** + * Tests to verify that there is no race condition with two things running in parallel: + * recycling the exclusive or floating buffers and some other thread releasing the + * input channel. + */ + @Test + public void testConcurrentRecycleAndRelease() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(256, 32, MemoryType.HEAP); + final ExecutorService executor = Executors.newFixedThreadPool(2); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final int numFloatingBuffers = 128; + final int numExclusiveSegments = 2; + final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + + // Exhaust all the floating buffers + final List<Buffer> floatingBuffers = new ArrayList<>(numFloatingBuffers); + for (int i = 0; i < numFloatingBuffers; i++) { + Buffer buffer = bufferPool.requestBuffer(); + assertNotNull(buffer); + floatingBuffers.add(buffer); + } + + // Exhaust all the exclusive buffers + final List<Buffer> exclusiveBuffers = new ArrayList<>(numExclusiveSegments); + for (int i = 0; i < numExclusiveSegments; i++) { + Buffer buffer = inputChannel.requestBuffer(); + assertNotNull(buffer); + exclusiveBuffers.add(buffer); + } + + final Callable<Void> recycleBufferTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + // Recycle all the exclusive buffers + for (Buffer buffer : exclusiveBuffers) { + buffer.recycle(); + } + + // Recycle all the floating buffers + for (Buffer buffer : floatingBuffers) { + buffer.recycle(); + } + + return null; + } + }; + + final 
Callable<Void> releaseTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + inputChannel.releaseAllResources(); + + return null; + } + }; + + // Submit tasks and wait to finish + final List<Future<Void>> results = Lists.newArrayListWithCapacity(2); + results.add(executor.submit(recycleBufferTask)); + results.add(executor.submit(releaseTask)); + for (Future<Void> result : results) { + result.get(); + } + + assertEquals("There should be no buffers available in the channel.", + 0, inputChannel.getNumberOfAvailableBuffers()); --- End diff -- When the channel is released by one thread, we cannot be sure whether another thread has already requested floating buffers, or how many it has requested. If the number of floating buffers requested so far is less than the capacity of the pool, the test cannot verify how many floating buffers are currently available in the pool after the channel is released; the same applies to the `NetworkBufferPool`.
---