This is an automated email from the ASF dual-hosted git repository.
Gabriel39 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 6bf4327a7a5 [fix](maxcompute) Estimate write block size from Arrow buffers, not per-row serialization (#64612)
6bf4327a7a5 is described below
commit 6bf4327a7a58bcff9d2b869869f7ae3067a0f80f
Author: daidai <[email protected]>
AuthorDate: Wed Jun 24 16:20:02 2026 +0800
[fix](maxcompute) Estimate write block size from Arrow buffers, not per-row serialization (#64612)
The old per-row estimateSingleRowPayloadBytes ZSTD-serialized a one-row batch for every row, which was CPU-heavy and overestimated the payload by ~25x. Instead, sum FieldVector.getBufferSize() over the whole batch, and rotate to a new block lazily, only when the next rows would overflow the current one.
---
.../doris/maxcompute/MaxComputeJniWriter.java | 184 ++++++++++++++-------
.../doris/maxcompute/MaxComputeJniWriterTest.java | 133 +++++++++++++++
2 files changed, 258 insertions(+), 59 deletions(-)
diff --git a/fe/be-java-extensions/max-compute-connector/src/main/java/org/apache/doris/maxcompute/MaxComputeJniWriter.java b/fe/be-java-extensions/max-compute-connector/src/main/java/org/apache/doris/maxcompute/MaxComputeJniWriter.java
index 9788184057e..ecb01d9092f 100644
--- a/fe/be-java-extensions/max-compute-connector/src/main/java/org/apache/doris/maxcompute/MaxComputeJniWriter.java
+++ b/fe/be-java-extensions/max-compute-connector/src/main/java/org/apache/doris/maxcompute/MaxComputeJniWriter.java
@@ -25,8 +25,6 @@ import org.apache.doris.common.maxcompute.MCUtils;
import com.aliyun.odps.Odps;
import com.aliyun.odps.OdpsType;
-import com.aliyun.odps.table.arrow.ArrowWriter;
-import com.aliyun.odps.table.arrow.ArrowWriterFactory;
import com.aliyun.odps.table.configuration.ArrowOptions;
import com.aliyun.odps.table.configuration.ArrowOptions.TimestampUnit;
import com.aliyun.odps.table.configuration.CompressionCodec;
@@ -67,7 +65,6 @@ import org.apache.log4j.Logger;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
-import java.io.OutputStream;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
@@ -125,6 +122,10 @@ public class MaxComputeJniWriter extends JniWriter {
     private List<String> columnNames;
     private long currentBlockId = -1L;
     private long currentBlockWrittenBytes = 0L;
+    // Per-row Arrow payload size observed from previously written ranges. Used to bound
+    // how many rows are materialized into a single Arrow root, so a large incoming JNI
+    // block is never copied whole before its size is known. Refined as ranges are written.
+    private long observedBytesPerRow = 0L;
     private final List<WriterCommitMessage> commitMessages = new ArrayList<>();
     // Statistics
@@ -234,7 +235,7 @@ public class MaxComputeJniWriter extends JniWriter {
         }
         try {
-            writeRowsWithRowChecks(inputTable, numRows, numCols);
+            writeBatch(inputTable, numRows, numCols);
         } catch (Exception e) {
             String errorMsg = "Failed to write data to MaxCompute table " + project + "." + tableName;
             LOG.error(errorMsg, e);
@@ -272,79 +273,146 @@ public class MaxComputeJniWriter extends JniWriter {
         openBatchWriter(requestBlockId());
     }
-    private void writeRowsWithRowChecks(VectorTable inputTable, int numRows, int numCols) throws IOException {
+    private void writeBatch(VectorTable inputTable, int numRows, int numCols) throws IOException {
         int rowStart = 0;
         while (rowStart < numRows) {
-            int rowEnd = rowStart;
-            long batchEstimatedBytes = 0L;
-            boolean rotateAfterWrite = false;
-            while (rowEnd < numRows) {
-                long rowEstimatedBytes = estimateSingleRowPayloadBytes(inputTable, numCols, rowEnd);
-                boolean exceedsHardLimit = currentBlockWrittenBytes + batchEstimatedBytes
-                        + rowEstimatedBytes > maxBlockBytes;
-                if (exceedsHardLimit) {
-                    if (rowEnd == rowStart) {
-                        if (currentBlockWrittenBytes > 0) {
-                            rotateCurrentBatchWriter();
-                            continue;
-                        }
-                        batchEstimatedBytes += rowEstimatedBytes;
-                        rowEnd++;
-                        rotateAfterWrite = true;
-                    }
-                    break;
+            // Bound the rows copied into one Arrow root using the per-row size observed so
+            // far, so an oversized incoming block is never materialized whole before we know
+            // whether it fits the current block.
+            int probeEnd = rowStart + boundedProbeRowCount(observedBytesPerRow, maxBlockBytes, numRows - rowStart);
+            try (VectorSchemaRoot root = buildRowRangeRoot(inputTable, numCols, rowStart, probeEnd)) {
+                int probeRows = probeEnd - rowStart;
+                long probeBytes = estimateBatchPayloadBytes(root);
+                // Round up, min 1 byte/row: a truncated 0 (e.g. bit-packed boolean rows) would force
+                // the next probe back to boundedProbeRowCount's single-row bootstrap path.
+                observedBytesPerRow = Math.max(1L, (probeBytes + probeRows - 1) / probeRows);
+                if (currentBlockWrittenBytes + probeBytes <= maxBlockBytes) {
+                    writeRoot(root, probeRows, probeBytes);
+                    rowStart = probeEnd;
+                    continue;
                 }
-                batchEstimatedBytes += rowEstimatedBytes;
-                rowEnd++;
-                if (currentBlockWrittenBytes + batchEstimatedBytes >= maxBlockBytes) {
-                    rotateAfterWrite = true;
-                    break;
+
+                // The probe overflows the current block. Split it WITHOUT rebuilding: the binary
+                // search measures leading-row sizes from this already-built root via
+                // getBufferSizeFor, then we slice off the prefix that fits. The remaining rows are
+                // rebuilt on the next iteration (after rotating), so no Arrow buffer outlives the
+                // current block writer.
+                RowRange rowRange = findPartialRowRange(rowStart, probeEnd, currentBlockWrittenBytes,
+                        maxBlockBytes, (rangeStart, rangeEnd) -> prefixBufferBytes(root, rangeEnd - rangeStart));
+                if (rowRange.rotateBeforeWrite) {
+                    rotateCurrentBatchWriter();
+                    continue;
                 }
-            }
-            if (rowEnd == rowStart) {
-                long rowEstimatedBytes = estimateSingleRowPayloadBytes(inputTable, numCols, rowStart);
-                batchEstimatedBytes = rowEstimatedBytes;
-                rowEnd = rowStart + 1;
-                rotateAfterWrite = true;
+                int headRows = rowRange.rowEnd - rowStart;
+                try (VectorSchemaRoot head = root.slice(0, headRows)) {
+                    writeRoot(head, headRows, rowRange.bytes);
+                }
+                rowStart = rowRange.rowEnd;
             }
-
-            try (VectorSchemaRoot root = buildRowRangeRoot(inputTable, numCols, rowStart, rowEnd)) {
-                batchWriter.write(root);
+            if (rowStart < numRows && currentBlockWrittenBytes >= maxBlockBytes) {
+                rotateCurrentBatchWriter();
             }
-            batchWriter.flush();
-            int rowsWrittenNow = rowEnd - rowStart;
-            writtenRows += rowsWrittenNow;
-            currentBlockWrittenBytes += batchEstimatedBytes;
-            writtenBytes += batchEstimatedBytes;
-            rowStart = rowEnd;
+        }
+    }
-            if (rotateAfterWrite && rowStart < numRows) {
-                rotateCurrentBatchWriter();
+    // Off-heap payload bytes of the leading rowCount rows of an already-built Arrow root,
+    // read from the existing column buffers (getBufferSizeFor) without rebuilding any vector.
+    static long prefixBufferBytes(VectorSchemaRoot root, int rowCount) {
+        long total = 0L;
+        for (FieldVector vector : root.getFieldVectors()) {
+            total += vector.getBufferSizeFor(rowCount);
+        }
+        return total;
+    }
+
+    /**
+     * Choose how many rows to materialize into the next Arrow root, bounded so a large
+     * incoming JNI block is never copied whole before its size is known. The bound targets
+     * roughly one MaxCompute block worth of payload using {@code observedBytesPerRow}; before
+     * any range has been measured it probes a single row, then sizes from that row's measured
+     * Arrow payload. The result is at least one row and never exceeds {@code remainingRows}.
+     */
+    static int boundedProbeRowCount(long observedBytesPerRow, long maxBlockBytes, int remainingRows) {
+        long cap;
+        if (observedBytesPerRow <= 0L) {
+            cap = 1L;
+        } else {
+            cap = Math.max(1L, maxBlockBytes / observedBytesPerRow);
+        }
+        if (cap >= remainingRows) {
+            return remainingRows;
+        }
+        return (int) cap;
+    }
+
+    private void writeRoot(VectorSchemaRoot root, int numRows, long batchBytes) throws IOException {
+        batchWriter.write(root);
+        batchWriter.flush();
+
+        writtenRows += numRows;
+        currentBlockWrittenBytes += batchBytes;
+        writtenBytes += batchBytes;
+    }
+
+    static RowRange findPartialRowRange(int rowStart, int numRows, long currentBlockWrittenBytes,
+            long maxBlockBytes, RowRangeByteEstimator estimator) throws IOException {
+        int low = rowStart + 1;
+        int high = numRows - 1;
+        int bestEnd = rowStart;
+        long bestBytes = 0L;
+        while (low <= high) {
+            int mid = low + (high - low) / 2;
+            long rangeBytes = estimator.estimate(rowStart, mid);
+            if (currentBlockWrittenBytes + rangeBytes <= maxBlockBytes) {
+                bestEnd = mid;
+                bestBytes = rangeBytes;
+                low = mid + 1;
+            } else {
+                high = mid - 1;
             }
         }
+
+        if (bestEnd > rowStart) {
+            return RowRange.write(bestEnd, bestBytes);
+        }
+        if (currentBlockWrittenBytes > 0) {
+            return RowRange.rotateBeforeWrite();
+        }
+        return RowRange.write(rowStart + 1, estimator.estimate(rowStart, rowStart + 1));
+    }
+
+    interface RowRangeByteEstimator {
+        long estimate(int rowStart, int rowEnd) throws IOException;
     }
-    private static class CountingDiscardOutputStream extends OutputStream {
-        @Override
-        public void write(int b) {
-            // Discard bytes while allowing WriteChannel to track payload size.
+    static class RowRange {
+        final int rowEnd;
+        final long bytes;
+        final boolean rotateBeforeWrite;
+
+        private RowRange(int rowEnd, long bytes, boolean rotateBeforeWrite) {
+            this.rowEnd = rowEnd;
+            this.bytes = bytes;
+            this.rotateBeforeWrite = rotateBeforeWrite;
+        }
+
+        static RowRange write(int rowEnd, long bytes) {
+            return new RowRange(rowEnd, bytes, false);
         }
-        @Override
-        public void write(byte[] b, int off, int len) {
-            // Discard bytes while allowing WriteChannel to track payload size.
+        static RowRange rotateBeforeWrite() {
+            return new RowRange(-1, 0L, true);
         }
     }
-    private long estimateSingleRowPayloadBytes(VectorTable inputTable, int numCols, int rowIndex)
-            throws IOException {
-        try (VectorSchemaRoot root = buildRowRangeRoot(inputTable, numCols, rowIndex, rowIndex + 1);
-                ArrowWriter estimator = ArrowWriterFactory.getRecordBatchWriter(
-                        new CountingDiscardOutputStream(), writerOptions)) {
-            estimator.writeBatch(root);
-            return estimator.bytesWritten();
+    // Estimate an Arrow batch's payload size from its column buffer sizes (O(columns)).
+    static long estimateBatchPayloadBytes(VectorSchemaRoot root) {
+        long total = 0L;
+        for (FieldVector vector : root.getFieldVectors()) {
+            total += vector.getBufferSize();
         }
+        return total;
     }
private VectorSchemaRoot buildRowRangeRoot(VectorTable inputTable, int
numCols, int rowStart, int rowEnd) {
diff --git a/fe/be-java-extensions/max-compute-connector/src/test/java/org/apache/doris/maxcompute/MaxComputeJniWriterTest.java b/fe/be-java-extensions/max-compute-connector/src/test/java/org/apache/doris/maxcompute/MaxComputeJniWriterTest.java
new file mode 100644
index 00000000000..84cc4a79d6e
--- /dev/null
+++ b/fe/be-java-extensions/max-compute-connector/src/test/java/org/apache/doris/maxcompute/MaxComputeJniWriterTest.java
@@ -0,0 +1,133 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.maxcompute;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Collections;
+
+public class MaxComputeJniWriterTest {
+    @Test
+    public void testPrefixBufferBytesMeasuresLeadingRowsWithoutRebuild() {
+        try (BufferAllocator allocator = new RootAllocator();
+                IntVector vec = new IntVector("c", allocator)) {
+            vec.allocateNew(8);
+            for (int i = 0; i < 8; i++) {
+                vec.set(i, i);
+            }
+            vec.setValueCount(8);
+            try (VectorSchemaRoot root = new VectorSchemaRoot(Collections.singletonList(vec))) {
+                // The whole-root measurement must match estimateBatchPayloadBytes...
+                Assert.assertEquals(MaxComputeJniWriter.estimateBatchPayloadBytes(root),
+                        MaxComputeJniWriter.prefixBufferBytes(root, root.getRowCount()));
+                // ...and a leading prefix must be strictly smaller, computed from the
+                // already-built buffers (no rebuild).
+                Assert.assertTrue(MaxComputeJniWriter.prefixBufferBytes(root, 4)
+                        < MaxComputeJniWriter.prefixBufferBytes(root, 8));
+            }
+        }
+    }
+
+    @Test
+    public void testFindPartialRowRangeFillsRemainingBlock() throws Exception {
+        MaxComputeJniWriter.RowRange range = MaxComputeJniWriter.findPartialRowRange(
+                0, 4, 60L, 100L, prefixEstimator(10L, 20L, 30L, 40L));
+
+        Assert.assertFalse(range.rotateBeforeWrite);
+        Assert.assertEquals(2, range.rowEnd);
+        Assert.assertEquals(30L, range.bytes);
+    }
+
+    @Test
+    public void testFindPartialRowRangeRotatesWhenNoRowFitsNonEmptyBlock() throws Exception {
+        MaxComputeJniWriter.RowRange range = MaxComputeJniWriter.findPartialRowRange(
+                0, 3, 95L, 100L, prefixEstimator(10L, 20L, 30L));
+
+        Assert.assertTrue(range.rotateBeforeWrite);
+    }
+
+    @Test
+    public void testFindPartialRowRangeKeepsSingleOversizeFallbackOnEmptyBlock() throws Exception {
+        MaxComputeJniWriter.RowRange range = MaxComputeJniWriter.findPartialRowRange(
+                0, 3, 0L, 5L, prefixEstimator(10L, 20L, 30L));
+
+        Assert.assertFalse(range.rotateBeforeWrite);
+        Assert.assertEquals(1, range.rowEnd);
+        Assert.assertEquals(10L, range.bytes);
+    }
+
+    @Test
+    public void testFindPartialRowRangeUsesRowStartOffset() throws Exception {
+        MaxComputeJniWriter.RowRange range = MaxComputeJniWriter.findPartialRowRange(
+                1, 4, 50L, 100L, prefixEstimator(999L, 30L, 30L, 50L));
+
+        Assert.assertFalse(range.rotateBeforeWrite);
+        Assert.assertEquals(2, range.rowEnd);
+        Assert.assertEquals(30L, range.bytes);
+    }
+
+    @Test
+    public void testBoundedProbeRowCountBootstrapsWithSingleRow() {
+        // No per-row estimate yet: bootstrap by measuring a single row's real Arrow payload,
+        // so an oversized input is never copied whole and we never guess a row count.
+        int probeRows = MaxComputeJniWriter.boundedProbeRowCount(0L, 64L * 1024 * 1024, 1_000_000);
+
+        Assert.assertEquals(1, probeRows);
+    }
+
+    @Test
+    public void testBoundedProbeRowCountTargetsOneBlockAfterMeasurement() {
+        // 1 KiB/row against a 64 MiB block => exactly 65536 rows fill one block.
+        int probeRows = MaxComputeJniWriter.boundedProbeRowCount(1024L, 64L * 1024 * 1024, 1_000_000);
+
+        Assert.assertEquals(65536, probeRows);
+        Assert.assertTrue(probeRows < 1_000_000);
+    }
+
+    @Test
+    public void testBoundedProbeRowCountReturnsRemainingWhenItFitsCap() {
+        // A small input that comfortably fits one block is probed in one shot.
+        int probeRows = MaxComputeJniWriter.boundedProbeRowCount(1024L, 64L * 1024 * 1024, 4096);
+
+        Assert.assertEquals(4096, probeRows);
+    }
+
+    @Test
+    public void testBoundedProbeRowCountProbesSingleRowWhenRowExceedsBlock() {
+        // A single row larger than a whole block must still make progress (never 0 rows).
+        int probeRows = MaxComputeJniWriter.boundedProbeRowCount(
+                128L * 1024 * 1024, 64L * 1024 * 1024, 1_000_000);
+
+        Assert.assertEquals(1, probeRows);
+    }
+
+    private static MaxComputeJniWriter.RowRangeByteEstimator prefixEstimator(long... rowBytes) {
+        return (rowStart, rowEnd) -> {
+            long bytes = 0L;
+            for (int i = rowStart; i < rowEnd; i++) {
+                bytes += rowBytes[i];
+            }
+            return bytes;
+        };
+    }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]