This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new 9dd57b0  [improvement] batch load retry (#148)
9dd57b0 is described below

commit 9dd57b004afd91320d493c18ff53eb91bea3125d
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Wed Oct 25 15:45:20 2023 +0800

    [improvement] batch load retry (#148)
    
    Co-authored-by: gnehil <adamlee...@gamil.com>
---
 .../doris/spark/cfg/ConfigurationOptions.java      |  2 +-
 .../apache/doris/spark/load/DorisStreamLoad.java   |  4 -
 .../org/apache/doris/spark/load/RecordBatch.java   | 21 +----
 .../doris/spark/load/RecordBatchInputStream.java   | 16 ++--
 .../spark/listener/DorisTransactionListener.scala  |  8 +-
 .../scala/org/apache/doris/spark/sql/Utils.scala   | 27 ++++---
 .../apache/doris/spark/writer/DorisWriter.scala    | 91 ++++++++++++++++++----
 7 files changed, 107 insertions(+), 62 deletions(-)

diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
index a6767f0..a144fb8 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
@@ -70,7 +70,7 @@ public interface ConfigurationOptions {
     int SINK_BATCH_SIZE_DEFAULT = 100000;
 
     String DORIS_SINK_MAX_RETRIES = "doris.sink.max-retries";
-    int SINK_MAX_RETRIES_DEFAULT = 1;
+    int SINK_MAX_RETRIES_DEFAULT = 0;
 
     String DORIS_MAX_FILTER_RATIO = "doris.max.filter.ratio";
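
Note: with the default for "doris.sink.max-retries" dropping from 1 to 0, batch retry is now opt-in. A rough sketch of how a job might turn it back on follows; only "doris.sink.max-retries" is taken from the diff above, while the format name, the other option keys, and all endpoint/credential/table values are illustrative placeholders recalled from the connector docs and should be verified there.

import org.apache.spark.sql.SparkSession

object EnableBatchRetrySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("doris-batch-retry-demo").getOrCreate()

    // Toy frame; a real job would write rows matching the target table's schema.
    val df = spark.range(0, 1000).toDF("id")

    df.write
      .format("doris")
      // key defined in ConfigurationOptions above; default is now 0 (retry disabled)
      .option("doris.sink.max-retries", "3")
      // remaining keys and values are placeholders, adapt from the connector docs
      .option("doris.fenodes", "fe-host:8030")
      .option("doris.table.identifier", "example_db.example_tbl")
      .option("user", "root")
      .option("password", "")
      .save()

    spark.stop()
  }
}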
 
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index c524a4c..338ffbe 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -98,7 +98,6 @@ public class DorisStreamLoad implements Serializable {
     private String FIELD_DELIMITER;
     private final String LINE_DELIMITER;
     private boolean streamingPassthrough = false;
-    private final Integer batchSize;
     private final boolean enable2PC;
     private final Integer txnRetries;
     private final Integer txnIntervalMs;
@@ -128,8 +127,6 @@ public class DorisStreamLoad implements Serializable {
         LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n"));
         this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
                 ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT);
-        this.batchSize = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
-                ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT);
         this.enable2PC = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
                 ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT);
         this.txnRetries = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
@@ -200,7 +197,6 @@ public class DorisStreamLoad implements Serializable {
             this.loadUrlStr = loadUrlStr;
             HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC);
             RecordBatchInputStream recodeBatchInputStream = new RecordBatchInputStream(RecordBatch.newBuilder(rows)
-                    .batchSize(batchSize)
                     .format(fileType)
                     .sep(FIELD_DELIMITER)
                     .delim(LINE_DELIMITER)
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
index 4ce297f..e471d5b 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
@@ -36,11 +36,6 @@ public class RecordBatch {
      */
     private final Iterator<InternalRow> iterator;
 
-    /**
-     * batch size for single load
-     */
-    private final int batchSize;
-
     /**
      * stream load format
      */
@@ -63,10 +58,9 @@ public class RecordBatch {
 
     private final boolean addDoubleQuotes;
 
-    private RecordBatch(Iterator<InternalRow> iterator, int batchSize, String format, String sep, byte[] delim,
+    private RecordBatch(Iterator<InternalRow> iterator, String format, String sep, byte[] delim,
                         StructType schema, boolean addDoubleQuotes) {
         this.iterator = iterator;
-        this.batchSize = batchSize;
         this.format = format;
         this.sep = sep;
         this.delim = delim;
@@ -78,10 +72,6 @@ public class RecordBatch {
         return iterator;
     }
 
-    public int getBatchSize() {
-        return batchSize;
-    }
-
     public String getFormat() {
         return format;
     }
@@ -112,8 +102,6 @@ public class RecordBatch {
 
         private final Iterator<InternalRow> iterator;
 
-        private int batchSize;
-
         private String format;
 
         private String sep;
@@ -128,11 +116,6 @@ public class RecordBatch {
             this.iterator = iterator;
         }
 
-        public Builder batchSize(int batchSize) {
-            this.batchSize = batchSize;
-            return this;
-        }
-
         public Builder format(String format) {
             this.format = format;
             return this;
@@ -159,7 +142,7 @@ public class RecordBatch {
         }
 
         public RecordBatch build() {
-            return new RecordBatch(iterator, batchSize, format, sep, delim, schema, addDoubleQuotes);
+            return new RecordBatch(iterator, format, sep, delim, schema, addDoubleQuotes);
         }
 
     }
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
index a361c39..047ac3b 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
@@ -60,11 +60,6 @@ public class RecordBatchInputStream extends InputStream {
 
     private final byte[] delim;
 
-    /**
-     * record count has been read
-     */
-    private int readCount = 0;
-
     /**
      * streaming mode pass through data without process
      */
@@ -122,12 +117,12 @@ public class RecordBatchInputStream extends InputStream {
      */
     public boolean endOfBatch() throws DorisException {
         Iterator<InternalRow> iterator = recordBatch.getIterator();
-        if (readCount >= recordBatch.getBatchSize() || !iterator.hasNext()) {
-            delimBuf = null;
-            return true;
+        if (iterator.hasNext()) {
+            readNext(iterator);
+            return false;
         }
-        readNext(iterator);
-        return false;
+        delimBuf = null;
+        return true;
     }
 
     /**
@@ -149,7 +144,6 @@ public class RecordBatchInputStream extends InputStream {
             delimBuf =  ByteBuffer.wrap(delim);
             lineBuf = ByteBuffer.wrap(rowBytes);
         }
-        readCount++;
     }
 
     /**
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
index e5991de..e670a30 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
@@ -47,8 +47,8 @@ class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Int], d
         txnIds.foreach(txnId =>
           Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) {
             dorisStreamLoad.commit(txnId)
-          } match {
-            case Success(_) =>
+          } () match {
+            case Success(_) => // do nothing
             case Failure(_) => failedTxnIds += txnId
           }
         )
@@ -68,8 +68,8 @@ class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Int], d
         txnIds.foreach(txnId =>
           Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) {
             dorisStreamLoad.abortById(txnId)
-          } match {
-            case Success(_) =>
+          } () match {
+            case Success(_) => // do nothing
             case Failure(_) => failedTxnIds += txnId
           })
         if (failedTxnIds.nonEmpty) {
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
index 54976a7..8910389 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
@@ -34,6 +34,7 @@ import scala.util.{Failure, Success, Try}
 private[spark] object Utils {
   /**
    * quote column name
+   *
    * @param colName column name
    * @return quoted column name
    */
@@ -41,8 +42,9 @@ private[spark] object Utils {
 
   /**
    * compile a filter to Doris FE filter format.
-   * @param filter filter to be compile
-   * @param dialect jdbc dialect to translate value to sql format
+   *
+   * @param filter             filter to be compile
+   * @param dialect            jdbc dialect to translate value to sql format
    * @param inValueLengthLimit max length of in value array
    * @return if Doris FE can handle this filter, return None if Doris FE can not handled it.
    */
@@ -87,6 +89,7 @@ private[spark] object Utils {
 
   /**
    * Escape special characters in SQL string literals.
+   *
    * @param value The string to be escaped.
    * @return Escaped string.
    */
@@ -95,6 +98,7 @@ private[spark] object Utils {
 
   /**
    * Converts value to SQL expression.
+   *
    * @param value The value to be converted.
    * @return Converted value.
    */
@@ -108,16 +112,17 @@ private[spark] object Utils {
 
   /**
    * check parameters validation and process it.
+   *
    * @param parameters parameters from rdd and spark conf
-   * @param logger slf4j logger
+   * @param logger     slf4j logger
    * @return processed parameters
    */
   def params(parameters: Map[String, String], logger: Logger) = {
     // '.' seems to be problematic when specifying the options
     val dottedParams = parameters.map { case (k, v) =>
-      if (k.startsWith("sink.properties.") || k.startsWith("doris.sink.properties.")){
-        (k,v)
-      }else {
+      if (k.startsWith("sink.properties.") || k.startsWith("doris.sink.properties.")) {
+        (k, v)
+      } else {
         (k.replace('_', '.'), v)
       }
     }
@@ -141,7 +146,7 @@ private[spark] object Utils {
       case (k, v) =>
         if (k.startsWith("doris.")) (k, v)
         else ("doris." + k, v)
-    }.map{
+    }.map {
       case (ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD, _) =>
         logger.error(s"${ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD} cannot use in Doris Datasource.")
         throw new DorisException(s"${ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD} cannot use in" +
@@ -165,13 +170,14 @@ private[spark] object Utils {
 
     // validate path is available
     finalParams.getOrElse(ConfigurationOptions.DORIS_TABLE_IDENTIFIER,
-        throw new DorisException("table identifier must be specified for doris table identifier."))
+      throw new DorisException("table identifier must be specified for doris table identifier."))
 
     finalParams
   }
 
   @tailrec
-  def retry[R, T <: Throwable : ClassTag](retryTimes: Int, interval: Duration, logger: Logger)(f: => R): Try[R] = {
+  def retry[R, T <: Throwable : ClassTag](retryTimes: Int, interval: Duration, logger: Logger)
+                                         (f: => R)(h: => Unit): Try[R] = {
     assert(retryTimes >= 0)
     val result = Try(f)
     result match {
@@ -182,7 +188,8 @@ private[spark] object Utils {
         logger.warn(s"Execution failed caused by: ", exception)
         logger.warn(s"$retryTimes times retry remaining, the next attempt will be in ${interval.toMillis} ms")
         LockSupport.parkNanos(interval.toNanos)
-        retry(retryTimes - 1, interval, logger)(f)
+        h
+        retry(retryTimes - 1, interval, logger)(f)(h)
       case Failure(exception) => Failure(exception)
     }
   }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index a8c414e..6498bea 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -21,7 +21,6 @@ import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
 import org.apache.doris.spark.listener.DorisTransactionListener
 import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
 import org.apache.doris.spark.sql.Utils
-import org.apache.spark.TaskContext
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types.StructType
@@ -32,10 +31,10 @@ import java.io.IOException
 import java.time.Duration
 import java.util
 import java.util.Objects
-import java.util.concurrent.locks.LockSupport
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.util.{Failure, Success, Try}
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
 
 class DorisWriter(settings: SparkSettings) extends Serializable {
 
@@ -44,9 +43,18 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
   private val sinkTaskPartitionSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE)
   private val sinkTaskUseRepartition: Boolean = settings.getProperty(ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION,
     ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION_DEFAULT.toString).toBoolean
+
+  private val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES,
+    ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT)
+  private val batchSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
+    ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT)
   private val batchInterValMs: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS,
     ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT)
 
+  if (maxRetryTimes > 0) {
+    logger.info(s"batch retry enabled, size is $batchSize, interval is $batchInterValMs")
+  }
+
   private val enable2PC: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
     ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT)
   private val sinkTxnIntervalMs: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
@@ -77,7 +85,6 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
   private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Int): Unit = {
 
     val sc = dataFrame.sqlContext.sparkContext
-    logger.info(s"applicationAttemptId: ${sc.applicationAttemptId.getOrElse(-1)}")
     val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
     if (enable2PC) {
       sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader, sinkTxnIntervalMs, sinkTxnRetries))
@@ -89,19 +96,22 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
       resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
     }
     resultRdd.foreachPartition(iterator => {
-      val intervalNanos = Duration.ofMillis(batchInterValMs.toLong).toNanos
+
       while (iterator.hasNext) {
-        Try {
-          loadFunc(iterator.asJava, schema)
-        } match {
-          case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
+        val batchIterator = new BatchIterator[InternalRow](iterator, batchSize, maxRetryTimes > 0)
+        val retry = Utils.retry[Int, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) _
+        retry(loadFunc(batchIterator.asJava, schema))(batchIterator.reset()) match {
+          case Success(txnId) =>
+            if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
+            batchIterator.close()
           case Failure(e) =>
             if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
+            batchIterator.close()
             throw new IOException(
               s"Failed to load batch data on BE: 
${dorisStreamLoader.getLoadUrlStr} node.", e)
         }
-        LockSupport.parkNanos(intervalNanos)
       }
+
     })
 
   }
@@ -120,10 +130,10 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
     }
     val abortFailedTxnIds = mutable.Buffer[Int]()
     acc.value.asScala.foreach(txnId => {
-      Utils.retry[Unit, Exception](sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), logger) {
+      Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
         dorisStreamLoader.abortById(txnId)
-      } match {
-        case Success(_) =>
+      }() match {
+        case Success(_) => // do nothing
         case Failure(_) => abortFailedTxnIds += txnId
       }
     })
@@ -131,5 +141,60 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
     acc.reset()
   }
 
+  /**
+   * iterator for batch load
+   * if retry time is greater than zero, enable batch retry and put batch data into buffer
+   *
+   * @param iterator         parent iterator
+   * @param batchSize        batch size
+   * @param batchRetryEnable whether enable batch retry
+   * @tparam T data type
+   */
+  private class BatchIterator[T](iterator: Iterator[T], batchSize: Int, batchRetryEnable: Boolean) extends Iterator[T] {
+
+    private val buffer: ArrayBuffer[T] = if (batchRetryEnable) new ArrayBuffer[T](batchSize) else ArrayBuffer.empty[T]
+
+    private var recordCount = 0
+
+    private var isReset = false
+
+    override def hasNext: Boolean = recordCount < batchSize && iterator.hasNext
+
+    override def next(): T = {
+      recordCount += 1
+      if (batchRetryEnable) {
+        if (isReset && buffer.nonEmpty) {
+          buffer(recordCount)
+        } else {
+          val elem = iterator.next
+          buffer += elem
+          elem
+        }
+      } else {
+        iterator.next
+      }
+    }
+
+    /**
+     * reset record count for re-read
+     */
+    def reset(): Unit = {
+      recordCount = 0
+      isReset = true
+      logger.info("batch iterator is reset")
+    }
+
+    /**
+     * clear buffer when buffer is not empty
+     */
+    def close(): Unit = {
+      if (buffer.nonEmpty) {
+        buffer.clear()
+        logger.info("buffer is cleared and batch iterator is closed")
+      }
+    }
+
+  }
+
 
 }
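
The reworked Utils.retry above takes a second by-name parameter list, a recovery hook that runs between attempts; DorisWriter passes batchIterator.reset() there, so the rows buffered by BatchIterator are replayed on the next attempt (retry(loadFunc(batchIterator.asJava, schema))(batchIterator.reset())). Below is a minimal, self-contained sketch of that retry-with-hook pattern, simplified relative to the diff (no throwable type parameter, no logger) and with a hypothetical flaky "load" standing in for the stream load call; the object name and numbers are illustrative only.

import java.time.Duration
import java.util.concurrent.locks.LockSupport

import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

object RetryWithHookSketch {

  // Simplified stand-in for the reworked Utils.retry: `f` is the attempt,
  // `h` is a recovery hook executed before every retry.
  @tailrec
  def retry[R](retryTimes: Int, interval: Duration)(f: => R)(h: => Unit): Try[R] =
    Try(f) match {
      case Success(result) => Success(result)
      case Failure(_) if retryTimes > 0 =>
        LockSupport.parkNanos(interval.toNanos)
        h // in the connector: batchIterator.reset(), so the buffered rows are re-sent
        retry(retryTimes - 1, interval)(f)(h)
      case failure => failure
    }

  def main(args: Array[String]): Unit = {
    var attempts = 0
    // Hypothetical flaky "load": fails twice, succeeds on the third attempt.
    val result = retry(3, Duration.ofMillis(100)) {
      attempts += 1
      if (attempts < 3) throw new RuntimeException(s"attempt $attempts failed")
      attempts
    } {
      println("recovery hook: reset the batch iterator here")
    }
    println(result) // Success(3)
  }
}

Running the sketch prints the hook message twice and then Success(3); in the connector the hook's only job is to rewind the buffered batch so the same rows go to the next stream load attempt.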


