This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris-spark-connector.git
commit cd46034a84ef5e6ef9e359a7d3afb8f1b2535854
Author: Youngwb <yangwenbo_mail...@163.com>
AuthorDate: Thu Mar 26 21:34:37 2020 +0800

    [Spark] Support convert Arrow data to RowBatch asynchronously in Spark-Doris-Connector (#3186)

    Currently, when Spark iterates over each row of data in the Spark-Doris-Connector, it has to
    synchronously convert the Arrow-format data into the row format required by Spark. To speed up
    this conversion, we can add an asynchronous thread in the Connector that is responsible for
    fetching the Arrow-format data from BE and converting it into the row format required by Spark.

    In our test environment, the Doris cluster had 1 FE and 7 BEs (32C + 128G). When using the
    Spark-Doris-Connector to query a table with 67 columns, a query returning 69 million rows
    originally took about 2.5 min; after this improvement it takes about 1.6 min, a reduction of
    roughly 30%.
---
 README.md                                          |   2 +
 .../doris/spark/cfg/ConfigurationOptions.java      |   6 ++
 .../apache/doris/spark/serialization/RowBatch.java |  67 ++++++-------
 .../org/apache/doris/spark/util/ErrorMessages.java |   1 +
 .../apache/doris/spark/rdd/ScalaValueReader.scala  | 106 ++++++++++++++++++---
 5 files changed, 129 insertions(+), 53 deletions(-)

diff --git a/README.md b/README.md
index d32db83..3c41b93 100644
--- a/README.md
+++ b/README.md
@@ -102,6 +102,8 @@ dorisSparkRDD.collect()
 | doris.request.tablet.size | Integer.MAX_VALUE | The number of Doris Tablets corresponding to one RDD Partition.<br />The smaller this value, the more Partitions are generated,<br />which raises parallelism on the Spark side but also puts more pressure on Doris. |
 | doris.batch.size | 1024 | The maximum number of rows read from BE at a time.<br />Increasing this value reduces the number of connections established between Spark and Doris,<br />thereby reducing the extra overhead caused by network latency. |
 | doris.exec.mem.limit | 2147483648 | Memory limit for a single query. Defaults to 2GB, in bytes. |
+| doris.deserialize.arrow.async | false | Whether to convert the Arrow format asynchronously into the RowBatch needed for spark-doris-connector iteration. |
+| doris.deserialize.queue.size | 64 | Size of the internal processing queue for asynchronous Arrow conversion; effective when doris.deserialize.arrow.async is true. |
 
 ### SQL and Dataframe Only
 
diff --git a/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
index 742c3eb..1bb5dfc 100644
--- a/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
+++ b/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
@@ -57,4 +57,10 @@ public interface ConfigurationOptions {
     long DORIS_EXEC_MEM_LIMIT_DEFAULT = 2147483648L;
 
     String DORIS_VALUE_READER_CLASS = "doris.value.reader.class";
+
+    String DORIS_DESERIALIZE_ARROW_ASYNC = "doris.deserialize.arrow.async";
+    boolean DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT = false;
+
+    String DORIS_DESERIALIZE_QUEUE_SIZE = "doris.deserialize.queue.size";
+    int DORIS_DESERIALIZE_QUEUE_SIZE_DEFAULT = 64;
 }
diff --git a/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index d710fbb..0781f1e 100644
--- a/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -70,7 +70,8 @@ public class RowBatch {
         }
     }
 
-    private int offsetInOneBatch = 0;
+    // offset for iterate the rowBatch
+    private int offsetInRowBatch = 0;
     private int rowCountInOneBatch = 0;
     private int readRowCount = 0;
     private List<Row> rowBatch = new ArrayList<>();
@@ -87,50 +88,40 @@ public class RowBatch {
                 new ByteArrayInputStream(nextResult.getRows()),
                 rootAllocator
                 );
+        this.offsetInRowBatch = 0;
         try {
             this.root = arrowStreamReader.getVectorSchemaRoot();
+            while (arrowStreamReader.loadNextBatch()) {
+                fieldVectors = root.getFieldVectors();
+                if (fieldVectors.size() != schema.size()) {
+                    logger.error("Schema size '{}' is not equal to arrow field size '{}'.",
+                            fieldVectors.size(), schema.size());
+                    throw new DorisException("Load Doris data failed, schema size of fetch data is wrong.");
+                }
+                if (fieldVectors.size() == 0 || root.getRowCount() == 0) {
+                    logger.debug("One batch in arrow has no data.");
+                    continue;
+                }
+                rowCountInOneBatch = root.getRowCount();
+                // init the rowBatch
+                for (int i = 0; i < rowCountInOneBatch; ++i) {
+                    rowBatch.add(new Row(fieldVectors.size()));
+                }
+                convertArrowToRowBatch();
+                readRowCount += root.getRowCount();
+            }
         } catch (Exception e) {
             logger.error("Read Doris Data failed because: ", e);
-            close();
             throw new DorisException(e.getMessage());
+        } finally {
+            close();
         }
     }
 
-    public boolean hasNext() throws DorisException {
-        if (offsetInOneBatch < rowCountInOneBatch) {
+    public boolean hasNext() {
+        if (offsetInRowBatch < readRowCount) {
             return true;
         }
-        try {
-            try {
-                while (arrowStreamReader.loadNextBatch()) {
-                    fieldVectors = root.getFieldVectors();
-                    readRowCount += root.getRowCount();
-                    if (fieldVectors.size() != schema.size()) {
-                        logger.error("Schema size '{}' is not equal to arrow field size '{}'.",
-                                fieldVectors.size(), schema.size());
-                        throw new DorisException("Load Doris data failed, schema size of fetch data is wrong.");
-                    }
-                    if (fieldVectors.size() == 0 || root.getRowCount() == 0) {
-                        logger.debug("One batch in arrow has no data.");
-                        continue;
-                    }
-                    offsetInOneBatch = 0;
-                    rowCountInOneBatch = root.getRowCount();
-                    // init the rowBatch
-                    for (int i = 0; i < rowCountInOneBatch; ++i) {
-                        rowBatch.add(new Row(fieldVectors.size()));
-                    }
-                    convertArrowToRowBatch();
-                    return true;
-                }
-            } catch (IOException e) {
-                logger.error("Load arrow next batch failed.", e);
-                throw new DorisException("Cannot load arrow next batch fetching from Doris.");
-            }
-        } catch (Exception e) {
-            close();
-            throw e;
-        }
         return false;
     }
 
@@ -141,7 +132,7 @@ public class RowBatch {
             logger.error(errMsg);
             throw new NoSuchElementException(errMsg);
         }
-        rowBatch.get(rowIndex).put(obj);
+        rowBatch.get(readRowCount + rowIndex).put(obj);
     }
 
     public void convertArrowToRowBatch() throws DorisException {
@@ -295,11 +286,11 @@
     public List<Object> next() throws DorisException {
         if (!hasNext()) {
-            String errMsg = "Get row offset:" + offsetInOneBatch + " larger than row size: " + rowCountInOneBatch;
+            String errMsg = "Get row offset:" + offsetInRowBatch + " larger than row size: " + readRowCount;
             logger.error(errMsg);
             throw new NoSuchElementException(errMsg);
         }
-        return rowBatch.get(offsetInOneBatch++).getCols();
+        return rowBatch.get(offsetInRowBatch++).getCols();
     }
 
     private String typeMismatchMessage(final String sparkType, final Types.MinorType arrowType) {
diff --git a/src/main/java/org/apache/doris/spark/util/ErrorMessages.java b/src/main/java/org/apache/doris/spark/util/ErrorMessages.java
index 92a04e9..44ca28b 100644
--- a/src/main/java/org/apache/doris/spark/util/ErrorMessages.java
+++ b/src/main/java/org/apache/doris/spark/util/ErrorMessages.java
@@ -19,6 +19,7 @@ package org.apache.doris.spark.util;
 
 public abstract class ErrorMessages {
     public static final String PARSE_NUMBER_FAILED_MESSAGE = "Parse '{}' to number failed. Original string is '{}'.";
+    public static final String PARSE_BOOL_FAILED_MESSAGE = "Parse '{}' to boolean failed. Original string is '{}'.";
     public static final String CONNECT_FAILED_MESSAGE = "Connect to doris {} failed.";
     public static final String ILLEGAL_ARGUMENT_MESSAGE = "argument '{}' is illegal, value is '{}'.";
     public static final String SHOULD_NOT_HAPPEN_MESSAGE = "Should not come here.";
diff --git a/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala b/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
index 13a955a..1d22c42 100644
--- a/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
+++ b/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
@@ -17,9 +17,11 @@
 package org.apache.doris.spark.rdd
 
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent._
+
 import scala.collection.JavaConversions._
 import scala.util.Try
-
 import org.apache.doris.spark.backend.BackendClient
 import org.apache.doris.spark.cfg.ConfigurationOptions._
 import org.apache.doris.spark.cfg.Settings
@@ -31,9 +33,10 @@ import org.apache.doris.spark.sql.SchemaUtils
 import org.apache.doris.spark.util.ErrorMessages
 import org.apache.doris.spark.util.ErrorMessages.SHOULD_NOT_HAPPEN_MESSAGE
 import org.apache.doris.thrift.{TScanCloseParams, TScanNextBatchParams, TScanOpenParams, TScanOpenResult}
-
 import org.apache.log4j.Logger
 
+import scala.util.control.Breaks
+
 /**
  * read data from Doris BE to array.
  * @param partition Doris RDD partition
@@ -44,8 +47,30 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) {
 
   protected val client = new BackendClient(new Routing(partition.getBeAddress), settings)
   protected var offset = 0
-  protected var eos: Boolean = false
+  protected var eos: AtomicBoolean = new AtomicBoolean(false)
   protected var rowBatch: RowBatch = _
+  // flag indicate if support deserialize Arrow to RowBatch asynchronously
+  protected var deserializeArrowToRowBatchAsync: Boolean = Try {
+    settings.getProperty(DORIS_DESERIALIZE_ARROW_ASYNC, DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT.toString).toBoolean
+  } getOrElse {
+    logger.warn(ErrorMessages.PARSE_BOOL_FAILED_MESSAGE, DORIS_DESERIALIZE_ARROW_ASYNC, settings.getProperty(DORIS_DESERIALIZE_ARROW_ASYNC))
+    DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT
+  }
+
+  protected var rowBatchBlockingQueue: BlockingQueue[RowBatch] = {
+    val blockingQueueSize = Try {
+      settings.getProperty(DORIS_DESERIALIZE_QUEUE_SIZE, DORIS_DESERIALIZE_QUEUE_SIZE_DEFAULT.toString).toInt
+    } getOrElse {
+      logger.warn(ErrorMessages.PARSE_NUMBER_FAILED_MESSAGE, DORIS_DESERIALIZE_QUEUE_SIZE, settings.getProperty(DORIS_DESERIALIZE_QUEUE_SIZE))
+      DORIS_DESERIALIZE_QUEUE_SIZE_DEFAULT
+    }
+
+    var queue: BlockingQueue[RowBatch] = null
+    if (deserializeArrowToRowBatchAsync) {
+      queue = new ArrayBlockingQueue(blockingQueueSize)
+    }
+    queue
+  }
 
   private val openParams: TScanOpenParams = {
     val params = new TScanOpenParams
@@ -103,6 +128,33 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) {
 
   protected val schema: Schema = SchemaUtils.convertToSchema(openResult.getSelected_columns)
 
+  protected val asyncThread: Thread = new Thread {
+    override def run {
+      val nextBatchParams = new TScanNextBatchParams
+      nextBatchParams.setContext_id(contextId)
+      while (!eos.get) {
+        nextBatchParams.setOffset(offset)
+        val nextResult = client.getNext(nextBatchParams)
+        eos.set(nextResult.isEos)
+        if (!eos.get) {
+          val rowBatch = new RowBatch(nextResult, schema)
+          offset += rowBatch.getReadRowCount
+          rowBatch.close
+          rowBatchBlockingQueue.put(rowBatch)
+        }
+      }
+    }
+  }
+
+  protected val asyncThreadStarted: Boolean = {
+    var started = false
+    if (deserializeArrowToRowBatchAsync) {
+      asyncThread.start
+      started = true
+    }
+    started
+  }
+
   logger.debug(s"Open scan result is, contextId: $contextId, schema: $schema.")
 
   /**
@@ -110,21 +162,45 @@
    * @return true if hax next value
    */
   def hasNext: Boolean = {
-    if (!eos && (rowBatch == null || !rowBatch.hasNext)) {
-      if (rowBatch != null) {
-        offset += rowBatch.getReadRowCount
-        rowBatch.close
+    var hasNext = false
+    if (deserializeArrowToRowBatchAsync && asyncThreadStarted) {
+      // support deserialize Arrow to RowBatch asynchronously
+      if (rowBatch == null || !rowBatch.hasNext) {
+        val loop = new Breaks
+        loop.breakable {
+          while (!eos.get || !rowBatchBlockingQueue.isEmpty) {
+            if (!rowBatchBlockingQueue.isEmpty) {
+              rowBatch = rowBatchBlockingQueue.take
+              hasNext = true
+              loop.break
+            } else {
+              // wait for rowBatch put in queue or eos change
+              Thread.sleep(5)
+            }
+          }
+        }
+      } else {
+        hasNext = true
       }
-      val nextBatchParams = new TScanNextBatchParams
-      nextBatchParams.setContext_id(contextId)
-      nextBatchParams.setOffset(offset)
-      val nextResult = client.getNext(nextBatchParams)
-      eos = nextResult.isEos
-      if (!eos) {
-        rowBatch = new RowBatch(nextResult, schema)
+    } else {
+      // Arrow data was acquired synchronously during the iterative process
+      if (!eos.get && (rowBatch == null || !rowBatch.hasNext)) {
+        if (rowBatch != null) {
+          offset += rowBatch.getReadRowCount
+          rowBatch.close
+        }
+        val nextBatchParams = new TScanNextBatchParams
+        nextBatchParams.setContext_id(contextId)
+        nextBatchParams.setOffset(offset)
+        val nextResult = client.getNext(nextBatchParams)
+        eos.set(nextResult.isEos)
+        if (!eos.get) {
+          rowBatch = new RowBatch(nextResult, schema)
+        }
       }
+      hasNext = !eos.get
     }
-    !eos
+    hasNext
   }
 
   /**
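
For anyone who wants to try the new settings, here is a short usage sketch showing how they could be
passed through Spark's DataFrame reader. Only the two doris.deserialize.* options come from this
commit; the "doris" format name, the table identifier, and the FE address below are illustrative
assumptions rather than part of the change.

    import org.apache.spark.sql.SparkSession

    object DorisAsyncReadSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("doris-async-read").getOrCreate()

        val df = spark.read
          .format("doris")                                    // assumed data source name
          .option("doris.table.identifier", "example_db.tbl") // hypothetical table
          .option("doris.fenodes", "fe_host:8030")            // hypothetical FE endpoint
          // new in this commit: convert Arrow batches to RowBatch on a background thread
          .option("doris.deserialize.arrow.async", "true")
          // bounded queue between the fetch thread and the iterator; used only when async is true
          .option("doris.deserialize.queue.size", "64")
          .load()

        println(df.count())
        spark.stop()
      }
    }

Leaving doris.deserialize.arrow.async at its default of false keeps the original synchronous
behavior, so existing jobs are unaffected unless they opt in.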