Github user NicoK commented on a diff in the pull request: https://github.com/apache/flink/pull/4509#discussion_r152860104 --- Diff: flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java --- @@ -301,81 +306,388 @@ public void testProducerFailedException() throws Exception { } /** - * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying the exclusive segment is - * recycled to available buffers directly and it triggers notify of announced credit. + * Tests to verify that the input channel requests floating buffers from buffer pool + * in order to maintain backlog + initialCredit buffers available once receiving the + * sender's backlog, and registers as listener if no floating buffers available. */ @Test - public void testRecycleExclusiveBufferBeforeReleased() throws Exception { - final SingleInputGate inputGate = mock(SingleInputGate.class); - final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate)); + public void testRequestFloatingBufferOnSenderBacklog() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(12, 32, MemoryType.HEAP); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + try { + final int numFloatingBuffers = 10; + final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + + // Assign exclusive segments to the channel + final int numExclusiveBuffers = 2; + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, inputChannel.getNumberOfAvailableBuffers()); - // Recycle exclusive segment - 
inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + // Receive the producer's backlog + inputChannel.onSenderBacklog(8); - assertEquals("There should be one buffer available after recycle.", - 1, inputChannel.getNumberOfAvailableBuffers()); - verify(inputChannel, times(1)).notifyCreditAvailable(); + // Request the number of floating buffers by the formula of backlog + initialCredit - availableBuffers + verify(bufferPool, times(8)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 10 buffers available in the channel", + 10, inputChannel.getNumberOfAvailableBuffers()); - inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + inputChannel.onSenderBacklog(11); - assertEquals("There should be two buffers available after recycle.", - 2, inputChannel.getNumberOfAvailableBuffers()); - // It should be called only once when increased from zero. - verify(inputChannel, times(1)).notifyCreditAvailable(); + // Need extra three floating buffers, but only two buffers available in buffer pool, register as listener as a result + verify(bufferPool, times(11)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 12 buffers available in the channel", + 12, inputChannel.getNumberOfAvailableBuffers()); + + inputChannel.onSenderBacklog(12); + + // Already in the status of waiting for buffers and will not request any more + verify(bufferPool, times(11)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + + } finally { + // Release all the buffer resources + inputChannel.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } } /** - * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying the exclusive segment is - * recycled to global pool via input gate when channel is released. 
+ * Tests to verify that the buffer pool will distribute available floating buffers among + * all the channel listeners in a fair way. */ @Test - public void testRecycleExclusiveBufferAfterReleased() throws Exception { + public void testFairDistributionFloatingBuffers() throws Exception { // Setup - final SingleInputGate inputGate = mock(SingleInputGate.class); - final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate)); + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(12, 32, MemoryType.HEAP); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel channel1 = spy(createRemoteInputChannel(inputGate)); + final RemoteInputChannel channel2 = spy(createRemoteInputChannel(inputGate)); + final RemoteInputChannel channel3 = spy(createRemoteInputChannel(inputGate)); + try { + final int numFloatingBuffers = 3; + final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + + // Assign exclusive segments to the channels + inputGate.setInputChannel(channel1.partitionId.getPartitionId(), channel1); + inputGate.setInputChannel(channel2.partitionId.getPartitionId(), channel2); + inputGate.setInputChannel(channel3.partitionId.getPartitionId(), channel3); + final int numExclusiveBuffers = 2; + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + // Exhaust all the floating buffers + final List<Buffer> floatingBuffers = new ArrayList<>(numFloatingBuffers); + for (int i = 0; i < numFloatingBuffers; i++) { + Buffer buffer = bufferPool.requestBuffer(); + assertNotNull(buffer); + floatingBuffers.add(buffer); + } + + // Receive the producer's backlog to trigger request floating buffers from pool + // and register as listeners as a result + channel1.onSenderBacklog(8); + channel2.onSenderBacklog(8); + channel3.onSenderBacklog(8); + + verify(bufferPool, times(1)).addBufferListener(channel1); + 
verify(bufferPool, times(1)).addBufferListener(channel2); + verify(bufferPool, times(1)).addBufferListener(channel3); + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, channel1.getNumberOfAvailableBuffers()); + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, channel2.getNumberOfAvailableBuffers()); + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, channel3.getNumberOfAvailableBuffers()); + + // Recycle three floating buffers to trigger notify buffer available + for (Buffer buffer : floatingBuffers) { + buffer.recycle(); + } + + verify(channel1, times(1)).notifyBufferAvailable(any(Buffer.class)); + verify(channel2, times(1)).notifyBufferAvailable(any(Buffer.class)); + verify(channel3, times(1)).notifyBufferAvailable(any(Buffer.class)); + assertEquals("There should be 3 buffers available in the channel", 3, channel1.getNumberOfAvailableBuffers()); + assertEquals("There should be 3 buffers available in the channel", 3, channel2.getNumberOfAvailableBuffers()); + assertEquals("There should be 3 buffers available in the channel", 3, channel3.getNumberOfAvailableBuffers()); + + } finally { + // Release all the buffer resources + channel1.releaseAllResources(); + channel2.releaseAllResources(); + channel3.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } + } + + /** + * Tests to verify that there is no race condition with two things running in parallel: + * requesting floating buffers on sender backlog and some other thread releasing + * the input channel. 
+ */ + @Test + public void testConcurrentOnSenderBacklogAndRelease() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(256, 32, MemoryType.HEAP); + final ExecutorService executor = Executors.newFixedThreadPool(2); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(128, 128); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, 2); + + final Callable<Void> requestBufferTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + while (true) { + for (int j = 1; j <= 128; j++) { + inputChannel.onSenderBacklog(j); + } + + if (inputChannel.isReleased()) { + return null; + } + } + } + }; - inputChannel.releaseAllResources(); + final Callable<Void> releaseTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + inputChannel.releaseAllResources(); + + return null; + } + }; + + // Submit tasks and wait to finish + final List<Future<Void>> results = Lists.newArrayListWithCapacity(2); + results.add(executor.submit(requestBufferTask)); + results.add(executor.submit(releaseTask)); + for (Future<Void> result : results) { + result.get(); + } - // Recycle exclusive segment after channel released - inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + assertEquals("There should be no buffers available in the channel.", + 0, inputChannel.getNumberOfAvailableBuffers()); - assertEquals("Resource leak during recycling buffer after channel is released.", - 0, inputChannel.getNumberOfAvailableBuffers()); - verify(inputChannel, times(0)).notifyCreditAvailable(); - verify(inputGate, times(1)).returnExclusiveSegments(anyListOf(MemorySegment.class)); + } finally { + // 
Release all the buffer resources once exception + if (!inputChannel.isReleased()) { + inputChannel.releaseAllResources(); + } + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + + executor.shutdown(); + } } /** - * Tests {@link RemoteInputChannel#releaseAllResources()}, verifying the exclusive segments are - * recycled to global pool via input gate and no resource leak. + * Tests to verify that there is no race condition with two things running in parallel: + * requesting floating buffers on sender backlog and some other thread recycling + * floating or exclusive buffers. */ @Test - public void testReleaseExclusiveBuffers() throws Exception { + public void testConcurrentOnSenderBacklogAndRecycle() throws Exception { // Setup - final SingleInputGate inputGate = mock(SingleInputGate.class); - final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(256, 32, MemoryType.HEAP); + final ExecutorService executor = Executors.newFixedThreadPool(2); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final int numFloatingBuffers = 128; + final int numExclusiveSegments = 2; + final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + + // Exhaust all the floating buffers + final List<Buffer> floatingBuffers = new ArrayList<>(numFloatingBuffers); + for (int i = 0; i < numFloatingBuffers; i++) { + Buffer buffer = bufferPool.requestBuffer(); + assertNotNull(buffer); + floatingBuffers.add(buffer); + } - // Assign exclusive segments to channel - final List<MemorySegment> exclusiveSegments = new ArrayList<>(); - final int 
numExclusiveBuffers = 2; - for (int i = 0; i < numExclusiveBuffers; i++) { - exclusiveSegments.add(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + // Exhaust all the exclusive buffers + final List<Buffer> exclusiveBuffers = new ArrayList<>(numExclusiveSegments); + for (int i = 0; i < numExclusiveSegments; i++) { + Buffer buffer = inputChannel.requestBuffer(); + assertNotNull(buffer); + exclusiveBuffers.add(buffer); + } + + final int backlog = 128; + final Callable<Void> requestBufferTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + for (int j = 1; j <= backlog; j++) { + inputChannel.onSenderBacklog(j); + } + + return null; + } + }; + + final Callable<Void> recycleBufferTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + // Recycle all the exclusive buffers + for (Buffer buffer : exclusiveBuffers) { + buffer.recycle(); + } + + // Recycle all the floating buffers + for (Buffer buffer : floatingBuffers) { + buffer.recycle(); + } + + return null; + } + }; + + // Submit tasks and wait to finish + final List<Future<Void>> results = Lists.newArrayListWithCapacity(2); + results.add(executor.submit(requestBufferTask)); + results.add(executor.submit(recycleBufferTask)); + for (Future<Void> result : results) { + result.get(); + } + + final int numRequiredBuffers = backlog + numExclusiveSegments; + assertEquals("There should be " + numRequiredBuffers +" buffers available in channel.", + numRequiredBuffers, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be no buffers available in buffer pool.", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer resources + inputChannel.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + + executor.shutdown(); } - inputChannel.assignExclusiveSegments(exclusiveSegments); + } - assertEquals("The number of available buffers is not equal to 
the assigned amount.", - numExclusiveBuffers, inputChannel.getNumberOfAvailableBuffers()); + /** + * Tests to verify that there is no race condition with two things running in parallel: + * recycling the exclusive or floating buffers and some other thread releasing the + * input channel. + */ + @Test + public void testConcurrentRecycleAndRelease() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(256, 32, MemoryType.HEAP); + final ExecutorService executor = Executors.newFixedThreadPool(2); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final int numFloatingBuffers = 128; + final int numExclusiveSegments = 2; + final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + + // Exhaust all the floating buffers + final List<Buffer> floatingBuffers = new ArrayList<>(numFloatingBuffers); + for (int i = 0; i < numFloatingBuffers; i++) { + Buffer buffer = bufferPool.requestBuffer(); + assertNotNull(buffer); + floatingBuffers.add(buffer); + } + + // Exhaust all the exclusive buffers + final List<Buffer> exclusiveBuffers = new ArrayList<>(numExclusiveSegments); + for (int i = 0; i < numExclusiveSegments; i++) { + Buffer buffer = inputChannel.requestBuffer(); + assertNotNull(buffer); + exclusiveBuffers.add(buffer); + } + + final Callable<Void> recycleBufferTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + // Recycle all the exclusive buffers --- End diff -- I was actually hoping we could extract more into a common test method, but it's probably best the way you implemented it, to keep the actual tests easier to understand.
---