This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new 5410651  Fix data loss due to internal retries (#145)
5410651 is described below

commit 5410651e3fdcdc03ce14c09ae1b11a75f4a773ad
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Sun Oct 8 18:25:03 2023 +0800

    Fix data loss due to internal retries (#145)
---
 .../apache/doris/spark/load/DorisStreamLoad.java   | 96 +++++++++++++++++++---
 .../spark/listener/DorisTransactionListener.scala  |  2 +-
 .../apache/doris/spark/writer/DorisWriter.scala    | 32 +++++---
 3 files changed, 106 insertions(+), 24 deletions(-)
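Note on the change below (a reading of the fix, not spelled out in the commit
message): DorisWriter used to wrap each stream load in Utils.retry, but the
load function consumes the partition iterator as it streams rows out, so a
connector-internal retry only saw whatever the failed attempt had not yet
consumed and the already-consumed rows were silently dropped. The commit
removes that internal retry, lets the Spark task fail instead (a re-attempted
task starts from a complete iterator), and aborts any pre-committed 2PC
transactions. The following self-contained Scala sketch reproduces the
iterator pitfall; the names and data are illustrative only, not connector code.

    // Sketch of the "internal retry over a shared iterator" pitfall.
    object IteratorRetryPitfall {
      def main(args: Array[String]): Unit = {
        val rows = Iterator(1, 2, 3, 4, 5)
        var attempt = 0

        // Reads up to three rows from the shared iterator, failing on the first attempt.
        def load(it: Iterator[Int]): List[Int] = {
          attempt += 1
          val consumed = it.take(3).toList
          if (attempt == 1) throw new RuntimeException("simulated stream load failure")
          consumed
        }

        val loaded =
          try load(rows)
          catch { case _: RuntimeException => load(rows) } // internal retry on the same iterator

        println(loaded) // List(4, 5): rows 1..3 were eaten by the failed attempt and are lost
      }
    }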

diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index 0b506b0..c524a4c 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -51,6 +51,7 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.io.Serializable;
 import java.nio.charset.StandardCharsets;
+import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Base64;
@@ -64,13 +65,14 @@ import java.util.Properties;
 import java.util.UUID;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.LockSupport;
+import java.util.function.Consumer;
 
 
 /**
  * DorisStreamLoad
  **/
 public class DorisStreamLoad implements Serializable {
-    private static final String NULL_VALUE = "\\N";
 
     private static final Logger LOG = LoggerFactory.getLogger(DorisStreamLoad.class);
 
@@ -97,7 +99,9 @@ public class DorisStreamLoad implements Serializable {
     private final String LINE_DELIMITER;
     private boolean streamingPassthrough = false;
     private final Integer batchSize;
-    private boolean enable2PC;
+    private final boolean enable2PC;
+    private final Integer txnRetries;
+    private final Integer txnIntervalMs;
 
     public DorisStreamLoad(SparkSettings settings) {
         String[] dbTable = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.");
@@ -128,6 +132,10 @@ public class DorisStreamLoad implements Serializable {
                 ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT);
         this.enable2PC = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
                 ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT);
+        this.txnRetries = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
+                ConfigurationOptions.DORIS_SINK_TXN_RETRIES_DEFAULT);
+        this.txnIntervalMs = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
+                ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS_DEFAULT);
     }
 
     public String getLoadUrlStr() {
@@ -202,7 +210,19 @@ public class DorisStreamLoad implements Serializable {
             HttpResponse httpResponse = httpClient.execute(httpPut);
             loadResponse = new LoadResponse(httpResponse);
         } catch (IOException e) {
-            throw new RuntimeException(e);
+            if (enable2PC) {
+                int retries = txnRetries;
+                while (retries > 0) {
+                    try {
+                        abortByLabel(label);
+                        retries = 0;
+                    } catch (StreamLoadException ex) {
+                        LockSupport.parkNanos(Duration.ofMillis(txnIntervalMs).toNanos());
+                        retries--;
+                    }
+                }
+            }
+            throw new StreamLoadException("load execute failed", e);
         }
 
         if (loadResponse.status != HttpStatus.SC_OK) {
@@ -274,22 +294,68 @@ public class DorisStreamLoad implements Serializable {
 
     }
 
-    public void abort(int txnId) throws StreamLoadException {
+    /**
+     * abort transaction by id
+     *
+     * @param txnId transaction id
+     * @throws StreamLoadException
+     */
+    public void abortById(int txnId) throws StreamLoadException {
 
         LOG.info("start abort transaction {}.", txnId);
 
+        try {
+            doAbort(httpPut -> httpPut.setHeader("txn_id", String.valueOf(txnId)));
+        } catch (StreamLoadException e) {
+            LOG.error("abort transaction by id: {} failed.", txnId);
+            throw e;
+        }
+
+        LOG.info("abort transaction {} succeed.", txnId);
+
+    }
+
+    /**
+     * abort transaction by label
+     *
+     * @param label label
+     * @throws StreamLoadException
+     */
+    public void abortByLabel(String label) throws StreamLoadException {
+
+        LOG.info("start abort transaction by label: {}.", label);
+
+        try {
+            doAbort(httpPut -> httpPut.setHeader("label", label));
+        } catch (StreamLoadException e) {
+            LOG.error("abort transaction by label: {} failed.", label);
+            throw e;
+        }
+
+        LOG.info("abort transaction by label {} succeed.", label);
+
+    }
+
+    /**
+     * execute abort
+     *
+     * @param putConsumer http put process function
+     * @throws StreamLoadException
+     */
+    private void doAbort(Consumer<HttpPut> putConsumer) throws StreamLoadException {
+
         try (CloseableHttpClient client = getHttpClient()) {
             String abortUrl = String.format(abortUrlPattern, getBackend(), db, tbl);
             HttpPut httpPut = new HttpPut(abortUrl);
             addCommonHeader(httpPut);
             httpPut.setHeader("txn_operation", "abort");
-            httpPut.setHeader("txn_id", String.valueOf(txnId));
+            putConsumer.accept(httpPut);
 
             CloseableHttpResponse response = client.execute(httpPut);
             int statusCode = response.getStatusLine().getStatusCode();
             if (statusCode != 200 || response.getEntity() == null) {
-                LOG.warn("abort transaction response: " + response.getStatusLine().toString());
-                throw new StreamLoadException("Fail to abort transaction " + txnId + " with url " + abortUrl);
+                LOG.error("abort transaction response: " + response.getStatusLine().toString());
+                throw new IOException("Fail to abort transaction with url " + abortUrl);
             }
 
             String loadResult = EntityUtils.toString(response.getEntity());
@@ -297,17 +363,16 @@ public class DorisStreamLoad implements Serializable {
             });
             if (!"Success".equals(res.get("status"))) {
                 if (ResponseUtil.isCommitted(res.get("msg"))) {
-                    throw new StreamLoadException("try abort committed transaction, " + "do you recover from old savepoint?");
+                    throw new IOException("try abort committed transaction");
                 }
-                LOG.warn("Fail to abort transaction. txnId: {}, error: {}", txnId, res.get("msg"));
+                LOG.error("Fail to abort transaction. error: {}", res.get("msg"));
+                throw new IOException(String.format("Fail to abort transaction. error: %s", res.get("msg")));
             }
 
         } catch (IOException e) {
             throw new StreamLoadException(e);
         }
 
-        LOG.info("abort transaction {} succeed.", txnId);
-
     }
 
     public Map<String, String> getStreamLoadProp(SparkSettings sparkSettings) {
@@ -386,12 +451,21 @@ public class DorisStreamLoad implements Serializable {
         return hexData;
     }
 
+    /**
+     * add common header to http request
+     *
+     * @param httpReq http request
+     */
     private void addCommonHeader(HttpRequestBase httpReq) {
         httpReq.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded);
         httpReq.setHeader(HttpHeaders.EXPECT, "100-continue");
         httpReq.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8");
     }
 
+    /**
+     * handle stream sink data pass through
+     * if load format is json, set read_json_by_line to true and remove strip_outer_array parameter
+     */
     private void handleStreamPassThrough() {
 
         if ("json".equalsIgnoreCase(fileType)) {
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
index 262ad19..e5991de 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
@@ -67,7 +67,7 @@ class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Int], d
         logger.info("job run failed, start aborting transactions")
         txnIds.foreach(txnId =>
           Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) {
-            dorisStreamLoad.abort(txnId)
+            dorisStreamLoad.abortById(txnId)
           } match {
             case Success(_) =>
             case Failure(_) => failedTxnIds += txnId
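DorisTransactionListener itself only needed the rename to abortById. For
orientation, it is a SparkListener that, when the job ends, commits the
pre-committed transaction ids collected in the accumulator if the job
succeeded and aborts them otherwise, using the same Utils.retry/interval
handling as above. A stripped-down Scala sketch of that pattern, with
commitTxn/abortTxn as illustrative stand-ins and without the retry and
failed-txn logging the real listener performs:

    import org.apache.spark.scheduler.{JobSucceeded, SparkListener, SparkListenerJobEnd}
    import org.apache.spark.util.CollectionAccumulator
    import scala.collection.JavaConverters._

    class TwoPhaseCommitListener(preCommittedTxnAcc: CollectionAccumulator[Int],
                                 commitTxn: Int => Unit,
                                 abortTxn: Int => Unit) extends SparkListener {
      override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
        val txnIds = preCommittedTxnAcc.value.asScala
        jobEnd.jobResult match {
          case JobSucceeded => txnIds.foreach(commitTxn) // job succeeded: finish the 2PC commit
          case _            => txnIds.foreach(abortTxn)  // job failed: roll the transactions back
        }
      }
    }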
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index 3fdfb79..a8c414e 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -21,6 +21,7 @@ import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
 import org.apache.doris.spark.listener.DorisTransactionListener
 import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
 import org.apache.doris.spark.sql.Utils
+import org.apache.spark.TaskContext
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types.StructType
@@ -31,18 +32,15 @@ import java.io.IOException
 import java.time.Duration
 import java.util
 import java.util.Objects
+import java.util.concurrent.locks.LockSupport
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.util.{Failure, Success}
+import scala.util.{Failure, Success, Try}
 
 class DorisWriter(settings: SparkSettings) extends Serializable {
 
   private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriter])
 
-  val batchSize: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
-    ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT)
-  private val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES,
-    ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT)
   private val sinkTaskPartitionSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE)
   private val sinkTaskUseRepartition: Boolean = settings.getProperty(ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION,
     ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION_DEFAULT.toString).toBoolean
@@ -50,7 +48,7 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
     ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT)
 
   private val enable2PC: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
-    ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT);
+    ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT)
   private val sinkTxnIntervalMs: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
     ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS_DEFAULT)
   private val sinkTxnRetries: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
@@ -58,19 +56,28 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
 
   private val dorisStreamLoader: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings)
 
+  /**
+   * write data in batch mode
+   *
+   * @param dataFrame source dataframe
+   */
   def write(dataFrame: DataFrame): Unit = {
     doWrite(dataFrame, dorisStreamLoader.load)
   }
 
+  /**
+   * write data in stream mode
+   *
+   * @param dataFrame source dataframe
+   */
   def writeStream(dataFrame: DataFrame): Unit = {
     doWrite(dataFrame, dorisStreamLoader.loadStream)
   }
 
   private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Int): Unit = {
 
-
-
     val sc = dataFrame.sqlContext.sparkContext
+    logger.info(s"applicationAttemptId: ${sc.applicationAttemptId.getOrElse(-1)}")
     val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
     if (enable2PC) {
       sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader, sinkTxnIntervalMs, sinkTxnRetries))
@@ -82,17 +89,18 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
       resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
     }
     resultRdd.foreachPartition(iterator => {
+      val intervalNanos = Duration.ofMillis(batchInterValMs.toLong).toNanos
       while (iterator.hasNext) {
-        // do load batch with retries
-        Utils.retry[Int, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
+        Try {
           loadFunc(iterator.asJava, schema)
         } match {
           case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
           case Failure(e) =>
             if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
             throw new IOException(
-              s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
+              s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node.", e)
         }
+        LockSupport.parkNanos(intervalNanos)
       }
     })
 
@@ -113,7 +121,7 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
     val abortFailedTxnIds = mutable.Buffer[Int]()
     acc.value.asScala.foreach(txnId => {
       Utils.retry[Unit, Exception](sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), logger) {
-        dorisStreamLoader.abort(txnId)
+        dorisStreamLoader.abortById(txnId)
       } match {
         case Success(_) =>
         case Failure(_) => abortFailedTxnIds += txnId
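
Net effect on the write path: each batch is now attempted exactly once inside
the partition loop; a success records the txnId for the later 2PC commit, a
failure aborts the transactions pre-committed so far and rethrows so Spark
fails the task (a re-attempted task starts from a complete iterator), and
LockSupport.parkNanos paces consecutive batches. A hedged Scala sketch of
that loop, with loadBatch/recordTxn/abortRecordedTxns as illustrative
stand-ins for the connector internals:

    import java.io.IOException
    import java.time.Duration
    import java.util.concurrent.locks.LockSupport
    import scala.util.{Failure, Success, Try}

    object PartitionWriteLoop {
      def writePartition(batches: Iterator[Seq[String]],
                         enable2PC: Boolean,
                         batchIntervalMs: Long)
                        (loadBatch: Seq[String] => Int,
                         recordTxn: Int => Unit,
                         abortRecordedTxns: () => Unit): Unit = {
        val intervalNanos = Duration.ofMillis(batchIntervalMs).toNanos
        while (batches.hasNext) {
          // One attempt per batch: no connector-internal retry over the shared iterator.
          Try(loadBatch(batches.next())) match {
            case Success(txnId) => if (enable2PC) recordTxn(txnId)
            case Failure(e) =>
              if (enable2PC) abortRecordedTxns() // roll back what was pre-committed so far
              throw new IOException("Failed to load batch data on BE node.", e)
          }
          LockSupport.parkNanos(intervalNanos) // pace batches without busy-waiting
        }
      }
    }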

