pnowojski closed pull request #6417: [FLINK-9913][runtime] Improve output serialization only once in RecordWriter
URL: https://github.com/apache/flink/pull/6417
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/RecordSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/RecordSerializer.java index 25d292771d0..6eebbbe88eb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/RecordSerializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/RecordSerializer.java @@ -66,29 +66,33 @@ public boolean isFullBuffer() { } /** - * Starts serializing and copying the given record to the target buffer - * (if available). + * Starts serializing the given record to an intermediate data buffer. * * @param record the record to serialize - * @return how much information was written to the target buffer and - * whether this buffer is full */ - SerializationResult addRecord(T record) throws IOException; + void serializeRecord(T record) throws IOException; /** - * Sets a (next) target buffer to use and continues writing remaining data - * to it until it is full. + * Copies the intermediate data serialization buffer to the given target buffer. * * @param bufferBuilder the new target buffer to use * @return how much information was written to the target buffer and * whether this buffer is full */ - SerializationResult continueWritingWithNextBufferBuilder(BufferBuilder bufferBuilder) throws IOException; + SerializationResult copyToBufferBuilder(BufferBuilder bufferBuilder); + + /** + * Checks to decrease the size of intermediate data serialization buffer after finishing the + * whole serialization process including {@link #serializeRecord(IOReadableWritable)} and + * {@link #copyToBufferBuilder(BufferBuilder)}. 
+ */ + void prune(); /** - * Clear and release internal state. + * Supports copying an intermediate data serialization buffer to multiple target buffers + * by resetting its initial position before each copying. */ - void clear(); + void reset(); /** * @return <tt>true</tt> if has some serialized data pending copying to the result {@link BufferBuilder}. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java index c4ab53f4b3a..ba2ed0133fd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java @@ -20,11 +20,8 @@ import org.apache.flink.core.io.IOReadableWritable; import org.apache.flink.core.memory.DataOutputSerializer; -import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferBuilder; -import javax.annotation.Nullable; - import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -32,7 +29,7 @@ /** * Record serializer which serializes the complete record to an intermediate * data serialization buffer and copies this buffer to target buffers - * one-by-one using {@link #continueWritingWithNextBufferBuilder(BufferBuilder)}. + * one-by-one using {@link #copyToBufferBuilder(BufferBuilder)}. * * @param <T> The type of the records that are serialized. */ @@ -50,10 +47,6 @@ /** Intermediate buffer for length serialization. */ private final ByteBuffer lengthBuffer; - /** Current target {@link Buffer} of the serializer. 
*/ - @Nullable - private BufferBuilder targetBuffer; - public SpanningRecordSerializer() { serializationBuffer = new DataOutputSerializer(128); @@ -66,15 +59,12 @@ public SpanningRecordSerializer() { } /** - * Serializes the complete record to an intermediate data serialization - * buffer and starts copying it to the target buffer (if available). + * Serializes the complete record to an intermediate data serialization buffer. * * @param record the record to serialize - * @return how much information was written to the target buffer and - * whether this buffer is full */ @Override - public SerializationResult addRecord(T record) throws IOException { + public void serializeRecord(T record) throws IOException { if (CHECKED) { if (dataBuffer.hasRemaining()) { throw new IllegalStateException("Pending serialization of previous record."); @@ -91,21 +81,17 @@ public SerializationResult addRecord(T record) throws IOException { lengthBuffer.putInt(0, len); dataBuffer = serializationBuffer.wrapAsByteBuffer(); - - // Copy from intermediate buffers to current target memory segment - if (targetBuffer != null) { - targetBuffer.append(lengthBuffer); - targetBuffer.append(dataBuffer); - targetBuffer.commit(); - } - - return getSerializationResult(); } + /** + * Copies an intermediate data serialization buffer into the target BufferBuilder. 
+ * + * @param targetBuffer the target BufferBuilder to copy to + * @return how much information was written to the target buffer and + * whether this buffer is full + */ @Override - public SerializationResult continueWritingWithNextBufferBuilder(BufferBuilder buffer) throws IOException { - targetBuffer = buffer; - + public SerializationResult copyToBufferBuilder(BufferBuilder targetBuffer) { boolean mustCommit = false; if (lengthBuffer.hasRemaining()) { targetBuffer.append(lengthBuffer); @@ -121,30 +107,28 @@ public SerializationResult continueWritingWithNextBufferBuilder(BufferBuilder bu targetBuffer.commit(); } - SerializationResult result = getSerializationResult(); - - // make sure we don't hold onto the large buffers for too long - if (result.isFullRecord()) { - serializationBuffer.clear(); - serializationBuffer.pruneBuffer(); - dataBuffer = serializationBuffer.wrapAsByteBuffer(); - } - - return result; + return getSerializationResult(targetBuffer); } - private SerializationResult getSerializationResult() { + private SerializationResult getSerializationResult(BufferBuilder targetBuffer) { if (dataBuffer.hasRemaining() || lengthBuffer.hasRemaining()) { return SerializationResult.PARTIAL_RECORD_MEMORY_SEGMENT_FULL; } return !targetBuffer.isFull() - ? SerializationResult.FULL_RECORD - : SerializationResult.FULL_RECORD_MEMORY_SEGMENT_FULL; + ? 
SerializationResult.FULL_RECORD + : SerializationResult.FULL_RECORD_MEMORY_SEGMENT_FULL; } @Override - public void clear() { - targetBuffer = null; + public void reset() { + dataBuffer.position(0); + lengthBuffer.position(0); + } + + @Override + public void prune() { + serializationBuffer.pruneBuffer(); + dataBuffer = serializationBuffer.wrapAsByteBuffer(); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java index 970795c0564..84d81837ddd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java @@ -58,10 +58,9 @@ private final int numChannels; - /** - * {@link RecordSerializer} per outgoing channel. - */ - private final RecordSerializer<T>[] serializers; + private final int[] broadcastChannels; + + private final RecordSerializer<T> serializer; private final Optional<BufferBuilder>[] bufferBuilders; @@ -89,23 +88,17 @@ public RecordWriter(ResultPartitionWriter writer, ChannelSelector<T> channelSele this.numChannels = writer.getNumberOfSubpartitions(); - /* - * The runtime exposes a channel abstraction for the produced results - * (see {@link ChannelSelector}). Every channel has an independent - * serializer. 
- */ - this.serializers = new SpanningRecordSerializer[numChannels]; + this.serializer = new SpanningRecordSerializer<T>(); this.bufferBuilders = new Optional[numChannels]; + this.broadcastChannels = new int[numChannels]; for (int i = 0; i < numChannels; i++) { - serializers[i] = new SpanningRecordSerializer<T>(); + broadcastChannels[i] = i; bufferBuilders[i] = Optional.empty(); } } public void emit(T record) throws IOException, InterruptedException { - for (int targetChannel : channelSelector.selectChannels(record, numChannels)) { - sendToTarget(record, targetChannel); - } + emit(record, channelSelector.selectChannels(record, numChannels)); } /** @@ -113,53 +106,78 @@ public void emit(T record) throws IOException, InterruptedException { * the {@link ChannelSelector}. */ public void broadcastEmit(T record) throws IOException, InterruptedException { - for (int targetChannel = 0; targetChannel < numChannels; targetChannel++) { - sendToTarget(record, targetChannel); - } + emit(record, broadcastChannels); } /** * This is used to send LatencyMarks to a random target channel. 
*/ public void randomEmit(T record) throws IOException, InterruptedException { - sendToTarget(record, rng.nextInt(numChannels)); + serializer.serializeRecord(record); + + if (copyFromSerializerToTargetChannel(rng.nextInt(numChannels))) { + serializer.prune(); + } } - private void sendToTarget(T record, int targetChannel) throws IOException, InterruptedException { - RecordSerializer<T> serializer = serializers[targetChannel]; + private void emit(T record, int[] targetChannels) throws IOException, InterruptedException { + serializer.serializeRecord(record); + + boolean pruneAfterCopying = false; + for (int channel : targetChannels) { + if (copyFromSerializerToTargetChannel(channel)) { + pruneAfterCopying = true; + } + } - SerializationResult result = serializer.addRecord(record); + // Make sure we don't hold onto the large intermediate serialization buffer for too long + if (pruneAfterCopying) { + serializer.prune(); + } + } + /** + * @param targetChannel + * @return <tt>true</tt> if the intermediate serialization buffer should be pruned + */ + private boolean copyFromSerializerToTargetChannel(int targetChannel) throws IOException, InterruptedException { + // We should reset the initial position of the intermediate serialization buffer before + // copying, so the serialization results can be copied to multiple target buffers. + serializer.reset(); + + boolean pruneTriggered = false; + BufferBuilder bufferBuilder = getBufferBuilder(targetChannel); + SerializationResult result = serializer.copyToBufferBuilder(bufferBuilder); while (result.isFullBuffer()) { - if (tryFinishCurrentBufferBuilder(targetChannel, serializer)) { - // If this was a full record, we are done. Not breaking - // out of the loop at this point will lead to another - // buffer request before breaking out (that would not be - // a problem per se, but it can lead to stalls in the - // pipeline). 
- if (result.isFullRecord()) { - break; - } + numBytesOut.inc(bufferBuilder.finish()); + numBuffersOut.inc(); + + // If this was a full record, we are done. Not breaking out of the loop at this point + // will lead to another buffer request before breaking out (that would not be a + // problem per se, but it can lead to stalls in the pipeline). + if (result.isFullRecord()) { + pruneTriggered = true; + bufferBuilders[targetChannel] = Optional.empty(); + break; } - BufferBuilder bufferBuilder = requestNewBufferBuilder(targetChannel); - result = serializer.continueWritingWithNextBufferBuilder(bufferBuilder); + bufferBuilder = requestNewBufferBuilder(targetChannel); + result = serializer.copyToBufferBuilder(bufferBuilder); } checkState(!serializer.hasSerializedData(), "All data should be written at once"); if (flushAlways) { targetPartition.flush(targetChannel); } + return pruneTriggered; } public void broadcastEvent(AbstractEvent event) throws IOException { try (BufferConsumer eventBufferConsumer = EventSerializer.toBufferConsumer(event)) { for (int targetChannel = 0; targetChannel < numChannels; targetChannel++) { - RecordSerializer<T> serializer = serializers[targetChannel]; - - tryFinishCurrentBufferBuilder(targetChannel, serializer); + tryFinishCurrentBufferBuilder(targetChannel); - // retain the buffer so that it can be recycled by each channel of targetPartition + // Retain the buffer so that it can be recycled by each channel of targetPartition targetPartition.addBufferConsumer(eventBufferConsumer.copy(), targetChannel); } @@ -175,9 +193,7 @@ public void flushAll() { public void clearBuffers() { for (int targetChannel = 0; targetChannel < numChannels; targetChannel++) { - RecordSerializer<?> serializer = serializers[targetChannel]; closeBufferBuilder(targetChannel); - serializer.clear(); } } @@ -191,25 +207,32 @@ public void setMetricGroup(TaskIOMetricGroup metrics) { /** * Marks the current {@link BufferBuilder} as finished and clears the state for next one. 
- * - * @return true if some data were written */ - private boolean tryFinishCurrentBufferBuilder(int targetChannel, RecordSerializer<T> serializer) { - + private void tryFinishCurrentBufferBuilder(int targetChannel) { if (!bufferBuilders[targetChannel].isPresent()) { - return false; + return; } BufferBuilder bufferBuilder = bufferBuilders[targetChannel].get(); bufferBuilders[targetChannel] = Optional.empty(); - numBytesOut.inc(bufferBuilder.finish()); numBuffersOut.inc(); - serializer.clear(); - return true; + } + + /** + * The {@link BufferBuilder} may already exist if not filled up last time, otherwise we need + * request a new one for this target channel. + */ + private BufferBuilder getBufferBuilder(int targetChannel) throws IOException, InterruptedException { + if (bufferBuilders[targetChannel].isPresent()) { + return bufferBuilders[targetChannel].get(); + } else { + return requestNewBufferBuilder(targetChannel); + } } private BufferBuilder requestNewBufferBuilder(int targetChannel) throws IOException, InterruptedException { - checkState(!bufferBuilders[targetChannel].isPresent()); + checkState(!bufferBuilders[targetChannel].isPresent() || bufferBuilders[targetChannel].get().isFinished()); + BufferBuilder bufferBuilder = targetPartition.getBufferProvider().requestBufferBuilderBlocking(); bufferBuilders[targetChannel] = Optional.of(bufferBuilder); targetPartition.addBufferConsumer(bufferBuilder.createBufferConsumer(), targetChannel); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java index 2e1063f5dbe..a17008a8068 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java @@ -23,6 +23,7 @@ import org.apache.flink.runtime.io.network.buffer.BufferBuilder; import org.apache.flink.runtime.io.network.buffer.BufferConsumer; import org.apache.flink.runtime.io.network.serialization.types.LargeObjectType; +import org.apache.flink.runtime.io.network.util.DeserializationUtils; import org.apache.flink.testutils.serialization.types.IntType; import org.apache.flink.testutils.serialization.types.SerializationTestType; import org.apache.flink.testutils.serialization.types.SerializationTestTypeFactory; @@ -134,7 +135,7 @@ private static void testSerializationRoundTrip( // ------------------------------------------------------------------------------------------------------------- - BufferConsumerAndSerializerResult serializationResult = setNextBufferForSerializer(serializer, segmentSize); + BufferAndSerializerResult serializationResult = setNextBufferForSerializer(serializer, segmentSize); int numRecords = 0; for (SerializationTestType record : records) { @@ -144,27 +145,16 @@ private static void testSerializationRoundTrip( numRecords++; // serialize record - if (serializer.addRecord(record).isFullBuffer()) { + serializer.serializeRecord(record); + if (serializer.copyToBufferBuilder(serializationResult.getBufferBuilder()).isFullBuffer()) { // buffer is full => start deserializing deserializer.setNextBuffer(serializationResult.buildBuffer()); - while (!serializedRecords.isEmpty()) { - SerializationTestType expected = serializedRecords.poll(); - SerializationTestType actual = expected.getClass().newInstance(); - - if (deserializer.getNextRecord(actual).isFullRecord()) { - Assert.assertEquals(expected, actual); - numRecords--; - } else { - serializedRecords.addFirst(expected); - break; - } - } + numRecords -= DeserializationUtils.deserializeRecords(serializedRecords, deserializer); // move buffers as long as necessary (for long 
records) while ((serializationResult = setNextBufferForSerializer(serializer, segmentSize)).isFullBuffer()) { deserializer.setNextBuffer(serializationResult.buildBuffer()); - serializer.clear(); } } } @@ -189,7 +179,7 @@ private static void testSerializationRoundTrip( Assert.assertFalse(deserializer.hasUnfinishedData()); } - private static BufferConsumerAndSerializerResult setNextBufferForSerializer( + private static BufferAndSerializerResult setNextBufferForSerializer( RecordSerializer<SerializationTestType> serializer, int segmentSize) throws IOException { // create a bufferBuilder with some random starting offset to properly test handling buffer slices in the @@ -199,21 +189,30 @@ private static BufferConsumerAndSerializerResult setNextBufferForSerializer( BufferConsumer bufferConsumer = bufferBuilder.createBufferConsumer(); bufferConsumer.build().recycleBuffer(); - serializer.clear(); - return new BufferConsumerAndSerializerResult( + return new BufferAndSerializerResult( + bufferBuilder, bufferConsumer, - serializer.continueWritingWithNextBufferBuilder(bufferBuilder)); + serializer.copyToBufferBuilder(bufferBuilder)); } - private static class BufferConsumerAndSerializerResult { + private static class BufferAndSerializerResult { + private final BufferBuilder bufferBuilder; private final BufferConsumer bufferConsumer; private final RecordSerializer.SerializationResult serializationResult; - public BufferConsumerAndSerializerResult(BufferConsumer bufferConsumer, RecordSerializer.SerializationResult serializationResult) { + public BufferAndSerializerResult( + BufferBuilder bufferBuilder, + BufferConsumer bufferConsumer, + RecordSerializer.SerializationResult serializationResult) { + this.bufferBuilder = bufferBuilder; this.bufferConsumer = bufferConsumer; this.serializationResult = serializationResult; } + public BufferBuilder getBufferBuilder() { + return bufferBuilder; + } + public Buffer buildBuffer() { return buildSingleBuffer(bufferConsumer); } diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java index c39b54af2f8..e5f5dfcd102 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java @@ -21,6 +21,7 @@ import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.runtime.io.network.buffer.BufferBuilder; import org.apache.flink.testutils.serialization.types.SerializationTestType; import org.apache.flink.testutils.serialization.types.SerializationTestTypeFactory; import org.apache.flink.testutils.serialization.types.Util; @@ -40,25 +41,25 @@ @Test public void testHasSerializedData() throws IOException { - final int segmentSize = 16; - final SpanningRecordSerializer<SerializationTestType> serializer = new SpanningRecordSerializer<>(); final SerializationTestType randomIntRecord = Util.randomRecord(SerializationTestTypeFactory.INT); Assert.assertFalse(serializer.hasSerializedData()); - serializer.addRecord(randomIntRecord); + serializer.serializeRecord(randomIntRecord); Assert.assertTrue(serializer.hasSerializedData()); - serializer.continueWritingWithNextBufferBuilder(createBufferBuilder(segmentSize)); + final BufferBuilder bufferBuilder1 = createBufferBuilder(16); + serializer.copyToBufferBuilder(bufferBuilder1); Assert.assertFalse(serializer.hasSerializedData()); - serializer.continueWritingWithNextBufferBuilder(createBufferBuilder(8)); - - serializer.addRecord(randomIntRecord); + final BufferBuilder bufferBuilder2 = createBufferBuilder(8); + serializer.reset(); + serializer.copyToBufferBuilder(bufferBuilder2); 
Assert.assertFalse(serializer.hasSerializedData()); - serializer.addRecord(randomIntRecord); + serializer.reset(); + serializer.copyToBufferBuilder(bufferBuilder2); // Buffer builder full! Assert.assertTrue(serializer.hasSerializedData()); } @@ -68,15 +69,10 @@ public void testEmptyRecords() throws IOException { final int segmentSize = 11; final SpanningRecordSerializer<SerializationTestType> serializer = new SpanningRecordSerializer<>(); + final BufferBuilder bufferBuilder1 = createBufferBuilder(segmentSize); - try { - Assert.assertEquals( - RecordSerializer.SerializationResult.FULL_RECORD, - serializer.continueWritingWithNextBufferBuilder(createBufferBuilder(segmentSize))); - } catch (IOException e) { - e.printStackTrace(); - Assert.fail(e.getMessage()); - } + Assert.assertEquals(RecordSerializer.SerializationResult.FULL_RECORD, + serializer.copyToBufferBuilder(bufferBuilder1)); SerializationTestType emptyRecord = new SerializationTestType() { @Override @@ -106,17 +102,19 @@ public boolean equals(Object obj) { } }; - RecordSerializer.SerializationResult result = serializer.addRecord(emptyRecord); - Assert.assertEquals(RecordSerializer.SerializationResult.FULL_RECORD, result); + serializer.serializeRecord(emptyRecord); + Assert.assertEquals(RecordSerializer.SerializationResult.FULL_RECORD, serializer.copyToBufferBuilder(bufferBuilder1)); - result = serializer.addRecord(emptyRecord); - Assert.assertEquals(RecordSerializer.SerializationResult.FULL_RECORD, result); + serializer.reset(); + Assert.assertEquals(RecordSerializer.SerializationResult.FULL_RECORD, serializer.copyToBufferBuilder(bufferBuilder1)); - result = serializer.addRecord(emptyRecord); - Assert.assertEquals(RecordSerializer.SerializationResult.PARTIAL_RECORD_MEMORY_SEGMENT_FULL, result); + serializer.reset(); + Assert.assertEquals(RecordSerializer.SerializationResult.PARTIAL_RECORD_MEMORY_SEGMENT_FULL, + serializer.copyToBufferBuilder(bufferBuilder1)); - result = 
serializer.continueWritingWithNextBufferBuilder(createBufferBuilder(segmentSize)); - Assert.assertEquals(RecordSerializer.SerializationResult.FULL_RECORD, result); + final BufferBuilder bufferBuilder2 = createBufferBuilder(segmentSize); + Assert.assertEquals(RecordSerializer.SerializationResult.FULL_RECORD, + serializer.copyToBufferBuilder(bufferBuilder2)); } @Test @@ -169,26 +167,29 @@ private void test(Util.MockRecords records, int segmentSize) throws Exception { // ------------------------------------------------------------------------------------------------------------- - serializer.continueWritingWithNextBufferBuilder(createBufferBuilder(segmentSize)); - + BufferBuilder bufferBuilder = createBufferBuilder(segmentSize); int numBytes = 0; for (SerializationTestType record : records) { - RecordSerializer.SerializationResult result = serializer.addRecord(record); + serializer.serializeRecord(record); + RecordSerializer.SerializationResult result = serializer.copyToBufferBuilder(bufferBuilder); numBytes += record.length() + serializationOverhead; if (numBytes < segmentSize) { Assert.assertEquals(RecordSerializer.SerializationResult.FULL_RECORD, result); } else if (numBytes == segmentSize) { Assert.assertEquals(RecordSerializer.SerializationResult.FULL_RECORD_MEMORY_SEGMENT_FULL, result); - serializer.continueWritingWithNextBufferBuilder(createBufferBuilder(segmentSize)); + bufferBuilder = createBufferBuilder(segmentSize); numBytes = 0; } else { Assert.assertEquals(RecordSerializer.SerializationResult.PARTIAL_RECORD_MEMORY_SEGMENT_FULL, result); while (result.isFullBuffer()) { numBytes -= segmentSize; - result = serializer.continueWritingWithNextBufferBuilder(createBufferBuilder(segmentSize)); + bufferBuilder = createBufferBuilder(segmentSize); + result = serializer.copyToBufferBuilder(bufferBuilder); } + + Assert.assertTrue(result.isFullRecord()); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java index 0b0a2366e5a..ed9f4cc3026 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java @@ -28,7 +28,9 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer.SerializationResult; +import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferBuilder; import org.apache.flink.runtime.io.network.buffer.BufferConsumer; @@ -36,11 +38,18 @@ import org.apache.flink.runtime.io.network.buffer.BufferRecycler; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; +import org.apache.flink.runtime.io.network.util.DeserializationUtils; import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider; +import org.apache.flink.testutils.serialization.types.SerializationTestType; +import org.apache.flink.testutils.serialization.types.SerializationTestTypeFactory; +import org.apache.flink.testutils.serialization.types.Util; import org.apache.flink.types.IntValue; import org.apache.flink.util.XORShiftRandom; +import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -70,6 +79,9 @@ */ public class RecordWriterTest { + @Rule + 
public TemporaryFolder tempFolder = new TemporaryFolder(); + // --------------------------------------------------------------------------------------------- // Resource release tests // --------------------------------------------------------------------------------------------- @@ -377,6 +389,77 @@ public void testBroadcastEmitBufferIndependence() throws Exception { assertEquals("Buffer 2 shares the same reader index as buffer 1", 0, buffer2.getReaderIndex()); } + /** + * Tests that records are broadcast via {@link ChannelSelector} and + * {@link RecordWriter#emit(IOReadableWritable)}. + */ + @Test + public void testEmitRecordWithBroadcastPartitioner() throws Exception { + emitRecordWithBroadcastPartitionerOrBroadcastEmitRecord(false); + } + + /** + * Tests that records are broadcast via {@link RecordWriter#broadcastEmit(IOReadableWritable)}. + */ + @Test + public void testBroadcastEmitRecord() throws Exception { + emitRecordWithBroadcastPartitionerOrBroadcastEmitRecord(true); + } + + /** + * Verifies that emitting records through a broadcasting channel selector and broadcasting them directly + * produce the same result: every target channel receives all of the emitted records. + * + * @param isBroadcastEmit whether to use {@link RecordWriter#broadcastEmit(IOReadableWritable)} or not + */ + private void emitRecordWithBroadcastPartitionerOrBroadcastEmitRecord(boolean isBroadcastEmit) throws Exception { + final int numChannels = 4; + final int bufferSize = 32; + final int numValues = 8; + final int serializationLength = 4; + + @SuppressWarnings("unchecked") + final Queue<BufferConsumer>[] queues = new Queue[numChannels]; + for (int i = 0; i < numChannels; i++) { + queues[i] = new ArrayDeque<>(); + } + + final TestPooledBufferProvider bufferProvider = new TestPooledBufferProvider(Integer.MAX_VALUE, bufferSize); + final ResultPartitionWriter partitionWriter = new CollectingPartitionWriter(queues, bufferProvider); + final RecordWriter<SerializationTestType> writer = isBroadcastEmit ? 
+ new RecordWriter<>(partitionWriter) : + new RecordWriter<>(partitionWriter, new Broadcast<>()); + final RecordDeserializer<SerializationTestType> deserializer = new SpillingAdaptiveSpanningRecordDeserializer<>( + new String[]{ tempFolder.getRoot().getAbsolutePath() }); + + final ArrayDeque<SerializationTestType> serializedRecords = new ArrayDeque<>(); + final Iterable<SerializationTestType> records = Util.randomRecords(numValues, SerializationTestTypeFactory.INT); + for (SerializationTestType record : records) { + serializedRecords.add(record); + + if (isBroadcastEmit) { + writer.broadcastEmit(record); + } else { + writer.emit(record); + } + } + + final int requiredBuffers = numValues / (bufferSize / (4 + serializationLength)); + for (int i = 0; i < numChannels; i++) { + assertEquals(requiredBuffers, queues[i].size()); + + final ArrayDeque<SerializationTestType> expectedRecords = serializedRecords.clone(); + int assertRecords = 0; + for (int j = 0; j < requiredBuffers; j++) { + Buffer buffer = buildSingleBuffer(queues[i].remove()); + deserializer.setNextBuffer(buffer); + + assertRecords += DeserializationUtils.deserializeRecords(expectedRecords, deserializer); + } + Assert.assertEquals(numValues, assertRecords); + } + } + // --------------------------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------------------------- @@ -524,6 +607,27 @@ public void read(DataInputView in) throws IOException { } } + /** + * Broadcast channel selector that selects all the output channels. 
+ */ + private static class Broadcast<T extends IOReadableWritable> implements ChannelSelector<T> { + + private int[] returnChannel; + + @Override + public int[] selectChannels(final T record, final int numberOfOutputChannels) { + if (returnChannel != null && returnChannel.length == numberOfOutputChannels) { + return returnChannel; + } else { + this.returnChannel = new int[numberOfOutputChannels]; + for (int i = 0; i < numberOfOutputChannels; i++) { + returnChannel[i] = i; + } + return returnChannel; + } + } + } + private static class TrackingBufferRecycler implements BufferRecycler { private final ArrayList<MemorySegment> recycledMemorySegments = new ArrayList<>(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/IteratorWrappingTestSingleInputGate.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/IteratorWrappingTestSingleInputGate.java index a91473327d0..3ce5b7019a0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/IteratorWrappingTestSingleInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/IteratorWrappingTestSingleInputGate.java @@ -70,10 +70,9 @@ public IteratorWrappingTestSingleInputGate(int bufferSize, Class<T> recordType, @Override public Optional<BufferAndAvailability> getBufferAvailability() throws IOException { if (hasData) { - serializer.clear(); + serializer.serializeRecord(reuse); BufferBuilder bufferBuilder = createBufferBuilder(bufferSize); - serializer.continueWritingWithNextBufferBuilder(bufferBuilder); - serializer.addRecord(reuse); + serializer.copyToBufferBuilder(bufferBuilder); hasData = inputIterator.next(reuse) != null; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/DeserializationUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/DeserializationUtils.java new file mode 100644 index 
00000000000..da103237ca8 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/DeserializationUtils.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.util; + +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; +import org.apache.flink.testutils.serialization.types.SerializationTestType; + +import org.junit.Assert; + +import java.util.ArrayDeque; + +/** + * Utility class that helps with deserialization in tests. + */ +public final class DeserializationUtils { + + /** + * Iterates over the provided records to deserialize, verifies the results, and counts + * the number of fully deserialized records. 
+ * + * @param records records to be deserialized + * @param deserializer the record deserializer + * @return the number of full deserialized records + */ + public static int deserializeRecords( + ArrayDeque<SerializationTestType> records, + RecordDeserializer<SerializationTestType> deserializer) throws Exception { + int deserializedRecords = 0; + + while (!records.isEmpty()) { + SerializationTestType expected = records.poll(); + SerializationTestType actual = expected.getClass().newInstance(); + + if (deserializer.getNextRecord(actual).isFullRecord()) { + Assert.assertEquals(expected, actual); + deserializedRecords++; + } else { + records.addFirst(expected); + break; + } + } + + return deserializedRecords; + } +} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java index ea38382bb5d..dbb81ab2683 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java @@ -105,10 +105,10 @@ private void setupInputChannels() throws IOException, InterruptedException { } else if (input != null && input.isStreamRecord()) { Object inputElement = input.getStreamRecord(); - BufferBuilder bufferBuilder = createBufferBuilder(bufferSize); - recordSerializer.continueWritingWithNextBufferBuilder(bufferBuilder); delegate.setInstance(inputElement); - recordSerializer.addRecord(delegate); + recordSerializer.serializeRecord(delegate); + BufferBuilder bufferBuilder = createBufferBuilder(bufferSize); + recordSerializer.copyToBufferBuilder(bufferBuilder); bufferBuilder.finish(); // Call getCurrentBuffer to ensure size is set ---------------------------------------------------------------- This is an 
automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services