This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris-spark-connector.git

commit cd46034a84ef5e6ef9e359a7d3afb8f1b2535854
Author: Youngwb <yangwenbo_mail...@163.com>
AuthorDate: Thu Mar 26 21:34:37 2020 +0800

    [Spark] Support converting Arrow data to RowBatch asynchronously in Spark-Doris-Connector (#3186)
    
    Currently, when Spark iterates over each row of data in the
    Spark-Doris-Connector, it has to synchronously convert the Arrow-format
    data into the row format required by Spark. To speed up this conversion,
    this change adds an asynchronous thread to the Connector that is
    responsible for fetching the Arrow-format data from BE and converting it
    into the row format required by the Spark computation.
    
    In our test environment, the Doris cluster consisted of 1 FE and 7 BEs
    (32C+128G). When using the Spark-Doris-Connector to query a table with 67
    columns, the original query returned 69 million rows of data in about
    2.5 min; after this improvement it takes about 1.6 min, cutting the query
    time by roughly 30%.
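    
    The async path is a single-producer/single-consumer pattern: a background
    thread fetches Arrow batches from BE and converts them, while the Spark
    iterator only takes finished RowBatches from a bounded blocking queue. It
    is switched on by doris.deserialize.arrow.async, and the queue length is
    controlled by doris.deserialize.queue.size. A minimal sketch of that
    pattern (not the Connector's actual classes; plain Seq[Int] batches stand
    in for RowBatch and the queue size is hard-coded) could look like:
    
        import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}
        import java.util.concurrent.atomic.AtomicBoolean
    
        object AsyncConvertSketch {
          def main(args: Array[String]): Unit = {
            // bounded queue decouples fetching/converting from iteration
            val queue: BlockingQueue[Seq[Int]] = new ArrayBlockingQueue(64)
            val eos = new AtomicBoolean(false)
    
            // producer: stand-in for fetching Arrow data from BE and
            // converting it to rows off the iteration path
            val producer = new Thread {
              override def run(): Unit = {
                for (batchId <- 1 to 5) queue.put(Seq.fill(3)(batchId))
                eos.set(true)
              }
            }
            producer.start()
    
            // consumer: keep taking batches until the producer signals eos
            // and the queue has drained
            while (!eos.get || !queue.isEmpty) {
              if (!queue.isEmpty) println(s"got batch: ${queue.take()}")
              else Thread.sleep(5) // wait for a batch or for eos to flip
            }
            producer.join()
          }
        }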
---
 README.md                                          |   2 +
 .../doris/spark/cfg/ConfigurationOptions.java      |   6 ++
 .../apache/doris/spark/serialization/RowBatch.java |  67 ++++++-------
 .../org/apache/doris/spark/util/ErrorMessages.java |   1 +
 .../apache/doris/spark/rdd/ScalaValueReader.scala  | 106 ++++++++++++++++++---
 5 files changed, 129 insertions(+), 53 deletions(-)

diff --git a/README.md b/README.md
index d32db83..3c41b93 100644
--- a/README.md
+++ b/README.md
@@ -102,6 +102,8 @@ dorisSparkRDD.collect()
 | doris.request.tablet.size        | Integer.MAX_VALUE | The number of Doris Tablets corresponding to one RDD Partition.<br />The smaller this value, the more Partitions are generated,<br />which increases parallelism on the Spark side but also puts more pressure on Doris. |
 | doris.batch.size                 | 1024              | The maximum number of rows read from BE at a time.<br />Increasing this value reduces the number of connections established between Spark and Doris,<br />thereby reducing the extra overhead caused by network latency. |
 | doris.exec.mem.limit             | 2147483648        | Memory limit for a single query. Defaults to 2GB, in bytes. |
+| doris.deserialize.arrow.async    | false             | Whether to support asynchronous conversion of the Arrow format to the RowBatches needed for spark-doris-connector iteration. |
+| doris.deserialize.queue.size     | 64                | Size of the internal processing queue for asynchronous Arrow conversion; takes effect when doris.deserialize.arrow.async is true. |
 
 ### SQL and Dataframe Only
 
diff --git a/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
index 742c3eb..1bb5dfc 100644
--- a/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
+++ b/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
@@ -57,4 +57,10 @@ public interface ConfigurationOptions {
     long DORIS_EXEC_MEM_LIMIT_DEFAULT = 2147483648L;
 
     String DORIS_VALUE_READER_CLASS = "doris.value.reader.class";
+
+    String DORIS_DESERIALIZE_ARROW_ASYNC = "doris.deserialize.arrow.async";
+    boolean DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT = false;
+
+    String DORIS_DESERIALIZE_QUEUE_SIZE = "doris.deserialize.queue.size";
+    int DORIS_DESERIALIZE_QUEUE_SIZE_DEFAULT = 64;
 }
diff --git a/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index d710fbb..0781f1e 100644
--- a/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -70,7 +70,8 @@ public class RowBatch {
         }
     }
 
-    private int offsetInOneBatch = 0;
+    // offset used when iterating over the rowBatch
+    private int offsetInRowBatch = 0;
     private int rowCountInOneBatch = 0;
     private int readRowCount = 0;
     private List<Row> rowBatch = new ArrayList<>();
@@ -87,50 +88,40 @@ public class RowBatch {
                 new ByteArrayInputStream(nextResult.getRows()),
                 rootAllocator
                 );
+        this.offsetInRowBatch = 0;
         try {
             this.root = arrowStreamReader.getVectorSchemaRoot();
+            while (arrowStreamReader.loadNextBatch()) {
+                fieldVectors = root.getFieldVectors();
+                if (fieldVectors.size() != schema.size()) {
+                    logger.error("Schema size '{}' is not equal to arrow field 
size '{}'.",
+                            fieldVectors.size(), schema.size());
+                    throw new DorisException("Load Doris data failed, schema 
size of fetch data is wrong.");
+                }
+                if (fieldVectors.size() == 0 || root.getRowCount() == 0) {
+                    logger.debug("One batch in arrow has no data.");
+                    continue;
+                }
+                rowCountInOneBatch = root.getRowCount();
+                // init the rowBatch
+                for (int i = 0; i < rowCountInOneBatch; ++i) {
+                    rowBatch.add(new Row(fieldVectors.size()));
+                }
+                convertArrowToRowBatch();
+                readRowCount += root.getRowCount();
+            }
         } catch (Exception e) {
             logger.error("Read Doris Data failed because: ", e);
-            close();
             throw new DorisException(e.getMessage());
+        } finally {
+            close();
         }
     }
 
-    public boolean hasNext() throws DorisException {
-        if (offsetInOneBatch < rowCountInOneBatch) {
+    public boolean hasNext() {
+        if (offsetInRowBatch < readRowCount) {
             return true;
         }
-        try {
-            try {
-                while (arrowStreamReader.loadNextBatch()) {
-                    fieldVectors = root.getFieldVectors();
-                    readRowCount += root.getRowCount();
-                    if (fieldVectors.size() != schema.size()) {
-                        logger.error("Schema size '{}' is not equal to arrow 
field size '{}'.",
-                                fieldVectors.size(), schema.size());
-                        throw new DorisException("Load Doris data failed, 
schema size of fetch data is wrong.");
-                    }
-                    if (fieldVectors.size() == 0 || root.getRowCount() == 0) {
-                        logger.debug("One batch in arrow has no data.");
-                        continue;
-                    }
-                    offsetInOneBatch = 0;
-                    rowCountInOneBatch = root.getRowCount();
-                    // init the rowBatch
-                    for (int i = 0; i < rowCountInOneBatch; ++i) {
-                        rowBatch.add(new Row(fieldVectors.size()));
-                    }
-                    convertArrowToRowBatch();
-                    return true;
-                }
-            } catch (IOException e) {
-                logger.error("Load arrow next batch failed.", e);
-                throw new DorisException("Cannot load arrow next batch 
fetching from Doris.");
-            }
-        } catch (Exception e) {
-            close();
-            throw e;
-        }
         return false;
     }
 
@@ -141,7 +132,7 @@ public class RowBatch {
             logger.error(errMsg);
             throw new NoSuchElementException(errMsg);
         }
-        rowBatch.get(rowIndex).put(obj);
+        rowBatch.get(readRowCount + rowIndex).put(obj);
     }
 
     public void convertArrowToRowBatch() throws DorisException {
@@ -295,11 +286,11 @@ public class RowBatch {
 
     public List<Object> next() throws DorisException {
         if (!hasNext()) {
-            String errMsg = "Get row offset:" + offsetInOneBatch + " larger 
than row size: " + rowCountInOneBatch;
+            String errMsg = "Get row offset:" + offsetInRowBatch + " larger 
than row size: " + readRowCount;
             logger.error(errMsg);
             throw new NoSuchElementException(errMsg);
         }
-        return rowBatch.get(offsetInOneBatch++).getCols();
+        return rowBatch.get(offsetInRowBatch++).getCols();
     }
 
     private String typeMismatchMessage(final String sparkType, final Types.MinorType arrowType) {
diff --git a/src/main/java/org/apache/doris/spark/util/ErrorMessages.java b/src/main/java/org/apache/doris/spark/util/ErrorMessages.java
index 92a04e9..44ca28b 100644
--- a/src/main/java/org/apache/doris/spark/util/ErrorMessages.java
+++ b/src/main/java/org/apache/doris/spark/util/ErrorMessages.java
@@ -19,6 +19,7 @@ package org.apache.doris.spark.util;
 
 public abstract class ErrorMessages {
     public static final String PARSE_NUMBER_FAILED_MESSAGE = "Parse '{}' to number failed. Original string is '{}'.";
+    public static final String PARSE_BOOL_FAILED_MESSAGE = "Parse '{}' to boolean failed. Original string is '{}'.";
     public static final String CONNECT_FAILED_MESSAGE = "Connect to doris {} failed.";
     public static final String ILLEGAL_ARGUMENT_MESSAGE = "argument '{}' is illegal, value is '{}'.";
     public static final String SHOULD_NOT_HAPPEN_MESSAGE = "Should not come here.";
diff --git a/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala b/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
index 13a955a..1d22c42 100644
--- a/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
+++ b/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
@@ -17,9 +17,11 @@
 
 package org.apache.doris.spark.rdd
 
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent._
+
 import scala.collection.JavaConversions._
 import scala.util.Try
-
 import org.apache.doris.spark.backend.BackendClient
 import org.apache.doris.spark.cfg.ConfigurationOptions._
 import org.apache.doris.spark.cfg.Settings
@@ -31,9 +33,10 @@ import org.apache.doris.spark.sql.SchemaUtils
 import org.apache.doris.spark.util.ErrorMessages
 import org.apache.doris.spark.util.ErrorMessages.SHOULD_NOT_HAPPEN_MESSAGE
 import org.apache.doris.thrift.{TScanCloseParams, TScanNextBatchParams, TScanOpenParams, TScanOpenResult}
-
 import org.apache.log4j.Logger
 
+import scala.util.control.Breaks
+
 /**
  * read data from Doris BE to array.
  * @param partition Doris RDD partition
@@ -44,8 +47,30 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) {
 
   protected val client = new BackendClient(new Routing(partition.getBeAddress), settings)
   protected var offset = 0
-  protected var eos: Boolean = false
+  protected var eos: AtomicBoolean = new AtomicBoolean(false)
   protected var rowBatch: RowBatch = _
+  // flag indicating whether Arrow is deserialized to RowBatch asynchronously
+  protected var deserializeArrowToRowBatchAsync: Boolean = Try {
+    settings.getProperty(DORIS_DESERIALIZE_ARROW_ASYNC, DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT.toString).toBoolean
+  } getOrElse {
+    logger.warn(ErrorMessages.PARSE_BOOL_FAILED_MESSAGE, DORIS_DESERIALIZE_ARROW_ASYNC, settings.getProperty(DORIS_DESERIALIZE_ARROW_ASYNC))
+    DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT
+  }
+
+  protected var rowBatchBlockingQueue: BlockingQueue[RowBatch] = {
+    val blockingQueueSize = Try {
+      settings.getProperty(DORIS_DESERIALIZE_QUEUE_SIZE, DORIS_DESERIALIZE_QUEUE_SIZE_DEFAULT.toString).toInt
+    } getOrElse {
+      logger.warn(ErrorMessages.PARSE_NUMBER_FAILED_MESSAGE, DORIS_DESERIALIZE_QUEUE_SIZE, settings.getProperty(DORIS_DESERIALIZE_QUEUE_SIZE))
+      DORIS_DESERIALIZE_QUEUE_SIZE_DEFAULT
+    }
+
+    var queue: BlockingQueue[RowBatch] = null
+    if (deserializeArrowToRowBatchAsync) {
+      queue = new ArrayBlockingQueue(blockingQueueSize)
+    }
+    queue
+  }
 
   private val openParams: TScanOpenParams = {
     val params = new TScanOpenParams
@@ -103,6 +128,33 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) {
   protected val schema: Schema =
     SchemaUtils.convertToSchema(openResult.getSelected_columns)
 
+  protected val asyncThread: Thread = new Thread {
+    override def run {
+      val nextBatchParams = new TScanNextBatchParams
+      nextBatchParams.setContext_id(contextId)
+      while (!eos.get) {
+        nextBatchParams.setOffset(offset)
+        val nextResult = client.getNext(nextBatchParams)
+        eos.set(nextResult.isEos)
+        if (!eos.get) {
+          val rowBatch = new RowBatch(nextResult, schema)
+          offset += rowBatch.getReadRowCount
+          rowBatch.close
+          rowBatchBlockingQueue.put(rowBatch)
+        }
+      }
+    }
+  }
+
+  protected val asyncThreadStarted: Boolean = {
+    var started = false
+    if (deserializeArrowToRowBatchAsync) {
+      asyncThread.start
+      started = true
+    }
+    started
+  }
+
   logger.debug(s"Open scan result is, contextId: $contextId, schema: $schema.")
 
   /**
@@ -110,21 +162,45 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) {
   * @return true if has next value
    */
   def hasNext: Boolean = {
-    if (!eos && (rowBatch == null || !rowBatch.hasNext)) {
-      if (rowBatch != null) {
-        offset += rowBatch.getReadRowCount
-        rowBatch.close
+    var hasNext = false
+    if (deserializeArrowToRowBatchAsync && asyncThreadStarted) {
+      // Arrow is deserialized to RowBatch asynchronously by the background thread
+      if (rowBatch == null || !rowBatch.hasNext) {
+        val loop = new Breaks
+        loop.breakable {
+          while (!eos.get || !rowBatchBlockingQueue.isEmpty) {
+            if (!rowBatchBlockingQueue.isEmpty) {
+              rowBatch = rowBatchBlockingQueue.take
+              hasNext = true
+              loop.break
+            } else {
+              // wait for a rowBatch to be put in the queue or for eos to change
+              Thread.sleep(5)
+            }
+          }
+        }
+      } else {
+        hasNext = true
       }
-      val nextBatchParams = new TScanNextBatchParams
-      nextBatchParams.setContext_id(contextId)
-      nextBatchParams.setOffset(offset)
-      val nextResult = client.getNext(nextBatchParams)
-      eos = nextResult.isEos
-      if (!eos) {
-        rowBatch = new RowBatch(nextResult, schema)
+    } else {
+      // Arrow data is acquired synchronously during iteration
+      if (!eos.get && (rowBatch == null || !rowBatch.hasNext)) {
+        if (rowBatch != null) {
+          offset += rowBatch.getReadRowCount
+          rowBatch.close
+        }
+        val nextBatchParams = new TScanNextBatchParams
+        nextBatchParams.setContext_id(contextId)
+        nextBatchParams.setOffset(offset)
+        val nextResult = client.getNext(nextBatchParams)
+        eos.set(nextResult.isEos)
+        if (!eos.get) {
+          rowBatch = new RowBatch(nextResult, schema)
+        }
       }
+      hasNext = !eos.get
     }
-    !eos
+    hasNext
   }
 
   /**
