This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f2ef8d  [fix] complex type npe problem (#151)
4f2ef8d is described below

commit 4f2ef8df05b0860ab37acd6c2997d209ed4e3b3d
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Wed Nov 1 17:39:57 2023 +0800

    [fix] complex type npe problem (#151)
---
 spark-doris-connector/pom.xml                      |   6 ++
 .../apache/doris/spark/serialization/RowBatch.java |   4 +-
 .../org/apache/doris/spark/sql/SchemaUtils.scala   | 120 ++++++++++-----------
 .../doris/spark/sql/TestConnectorWriteDoris.scala  |  60 ++++++++++-
 4 files changed, 122 insertions(+), 68 deletions(-)

diff --git a/spark-doris-connector/pom.xml b/spark-doris-connector/pom.xml
index 4148a66..518a3e2 100644
--- a/spark-doris-connector/pom.xml
+++ b/spark-doris-connector/pom.xml
@@ -185,6 +185,12 @@
             <version>${fasterxml.jackson.version}</version>
         </dependency>
 
+        <dependency>
+            <groupId>com.fasterxml.jackson.module</groupId>
+            <artifactId>jackson-module-scala_${scala.version}</artifactId>
+            <version>${fasterxml.jackson.version}</version>
+        </dependency>
+
         <!-- https://mvnrepository.com/artifact/com.mysql/mysql-connector-j -->
         <dependency>
             <groupId>com.mysql</groupId>
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index b43b0a2..cb4d303 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -60,6 +60,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
+import java.util.Objects;
 
 /**
  * row batch data container.
@@ -357,7 +358,8 @@ public class RowBatch {
                             reader.setPosition(rowIndex);
                             Map<String, String> value = new HashMap<>();
                             while (reader.next()) {
-                                value.put(reader.key().readObject().toString(), reader.value().readObject().toString());
+                                value.put(Objects.toString(reader.key().readObject(), null),
+                                        Objects.toString(reader.value().readObject(), null));
                             }
                             addValueToRow(rowIndex, JavaConverters.mapAsScalaMapConverter(value).asScala());
                         }
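
For context, here is a minimal standalone sketch (not part of the commit) of the null-safe pattern the RowBatch change switches to: java.util.Objects.toString with a fallback returns the fallback instead of throwing when readObject() yields null, which is presumably what raised the NPE for NULL entries in complex (MAP) columns.

import java.util.Objects

object NullToStringSketch {
  def main(args: Array[String]): Unit = {
    val present: AnyRef = Integer.valueOf(7)
    val missing: AnyRef = null                   // e.g. a NULL map value read back from Arrow

    println(Objects.toString(present, null))     // "7"
    println(Objects.toString(missing, null))     // prints "null" instead of throwing
    // missing.toString                          // would throw NullPointerException
  }
}
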
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index e806059..982e580 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -17,6 +17,8 @@
 
 package org.apache.doris.spark.sql
 
+import com.fasterxml.jackson.databind.json.JsonMapper
+import com.fasterxml.jackson.module.scala.DefaultScalaModule
 import org.apache.doris.sdk.thrift.TScanColumnDesc
 import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD}
 import org.apache.doris.spark.cfg.Settings
@@ -32,9 +34,11 @@ import org.slf4j.LoggerFactory
 import java.sql.Timestamp
 import java.time.{LocalDateTime, ZoneOffset}
 import scala.collection.JavaConversions._
+import scala.collection.mutable
 
 private[spark] object SchemaUtils {
   private val logger = LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$"))
+  private val MAPPER = JsonMapper.builder().addModule(DefaultScalaModule).build()
 
   /**
    * discover Doris table schema from Doris FE.
@@ -147,72 +151,60 @@ private[spark] object SchemaUtils {
 
   def rowColumnValue(row: SpecializedGetters, ordinal: Int, dataType: DataType): Any = {
 
-    dataType match {
-      case NullType => DataUtil.NULL_VALUE
-      case BooleanType => row.getBoolean(ordinal)
-      case ByteType => row.getByte(ordinal)
-      case ShortType => row.getShort(ordinal)
-      case IntegerType => row.getInt(ordinal)
-      case LongType => row.getLong(ordinal)
-      case FloatType => row.getFloat(ordinal)
-      case DoubleType => row.getDouble(ordinal)
-      case StringType => Option(row.getUTF8String(ordinal)).map(_.toString).getOrElse(DataUtil.NULL_VALUE)
-      case TimestampType =>
-        LocalDateTime.ofEpochSecond(row.getLong(ordinal) / 100000, (row.getLong(ordinal) % 1000).toInt, ZoneOffset.UTC)
-        new Timestamp(row.getLong(ordinal) / 1000).toString
-      case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString
-      case BinaryType => row.getBinary(ordinal)
-      case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale)
-      case at: ArrayType =>
-        val arrayData = row.getArray(ordinal)
-        if (arrayData == null) DataUtil.NULL_VALUE
-        else if(arrayData.numElements() == 0) "[]"
-        else {
-          (0 until arrayData.numElements()).map(i => {
-            if (arrayData.isNullAt(i)) null else rowColumnValue(arrayData, i, at.elementType)
-          }).mkString("[", ",", "]")
-        }
-
-      case mt: MapType =>
-        val mapData = row.getMap(ordinal)
-        val keys = mapData.keyArray()
-        val values = mapData.valueArray()
-        val sb = StringBuilder.newBuilder
-        sb.append("{")
-        var i = 0
-        while (i < keys.numElements()) {
-          rowColumnValue(keys, i, mt.keyType) -> rowColumnValue(values, i, mt.valueType)
-          sb.append(quoteData(rowColumnValue(keys, i, mt.keyType), mt.keyType))
-            .append(":").append(quoteData(rowColumnValue(values, i, mt.valueType), mt.valueType))
-            .append(",")
-          i += 1
-        }
-        if (i > 0) sb.dropRight(1)
-        sb.append("}").toString
-      case st: StructType =>
-        val structData = row.getStruct(ordinal, st.length)
-        val sb = StringBuilder.newBuilder
-        sb.append("{")
-        var i = 0
-        while (i < structData.numFields) {
-          val field = st.get(i)
-          sb.append(s""""${field.name}":""")
-            .append(quoteData(rowColumnValue(structData, i, field.dataType), field.dataType))
-            .append(",")
-          i += 1
-        }
-        if (i > 0) sb.dropRight(1)
-        sb.append("}").toString
-      case _ => throw new DorisException(s"Unsupported spark type: ${dataType.typeName}")
+    if (row.isNullAt(ordinal)) null
+    else {
+      dataType match {
+        case NullType => DataUtil.NULL_VALUE
+        case BooleanType => row.getBoolean(ordinal)
+        case ByteType => row.getByte(ordinal)
+        case ShortType => row.getShort(ordinal)
+        case IntegerType => row.getInt(ordinal)
+        case LongType => row.getLong(ordinal)
+        case FloatType => row.getFloat(ordinal)
+        case DoubleType => row.getDouble(ordinal)
+        case StringType => Option(row.getUTF8String(ordinal)).map(_.toString).getOrElse(DataUtil.NULL_VALUE)
+        case TimestampType =>
+          LocalDateTime.ofEpochSecond(row.getLong(ordinal) / 100000, (row.getLong(ordinal) % 1000).toInt, ZoneOffset.UTC)
+          new Timestamp(row.getLong(ordinal) / 1000).toString
+        case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString
+        case BinaryType => row.getBinary(ordinal)
+        case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale)
+        case at: ArrayType =>
+          val arrayData = row.getArray(ordinal)
+          if (arrayData == null) DataUtil.NULL_VALUE
+          else {
+            (0 until arrayData.numElements()).map(i => {
+              if (arrayData.isNullAt(i)) null else rowColumnValue(arrayData, i, at.elementType)
+            }).mkString("[", ",", "]")
+          }
+        case mt: MapType =>
+          val mapData = row.getMap(ordinal)
+          if (mapData.numElements() == 0) "{}"
+          else {
+            val keys = mapData.keyArray()
+            val values = mapData.valueArray()
+            val map = mutable.HashMap[Any, Any]()
+            var i = 0
+            while (i < keys.numElements()) {
+              map += rowColumnValue(keys, i, mt.keyType) -> rowColumnValue(values, i, mt.valueType)
+              i += 1
+            }
+            MAPPER.writeValueAsString(map)
+          }
+        case st: StructType =>
+          val structData = row.getStruct(ordinal, st.length)
+          val map = mutable.HashMap[String, Any]()
+          var i = 0
+          while (i < structData.numFields) {
+            val field = st.get(i)
+            map += field.name -> rowColumnValue(structData, i, field.dataType)
+            i += 1
+          }
+          MAPPER.writeValueAsString(map)
+        case _ => throw new DorisException(s"Unsupported spark type: ${dataType.typeName}")
+      }
     }
 
   }
 
-  private def quoteData(value: Any, dataType: DataType): Any = {
-    dataType match {
-      case StringType | TimestampType | DateType => s""""$value""""
-      case _ => value
-    }
-  }
-
 }
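
For reference, a minimal standalone sketch (not part of the commit, assuming the jackson-module-scala dependency added in the pom change) of how JsonMapper with DefaultScalaModule serializes the mutable maps the rewritten rowColumnValue now builds for MAP and STRUCT columns, replacing the removed quoteData string concatenation:

import com.fasterxml.jackson.databind.json.JsonMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule

import scala.collection.mutable

object ComplexTypeJsonSketch {
  private val MAPPER = JsonMapper.builder().addModule(DefaultScalaModule).build()

  def main(args: Array[String]): Unit = {
    // A MAP column entry with a null value, like the {"k5":NULL} row in the test data below.
    val map = mutable.HashMap[Any, Any]("k5" -> null, "k1" -> 1)
    // Jackson quotes the keys and emits JSON null, so a null entry no longer breaks the
    // generated JSON, e.g. {"k5":null,"k1":1} (entry order may vary).
    println(MAPPER.writeValueAsString(map))

    // A STRUCT column rendered as field name -> value, as the StructType branch builds it.
    val struct = mutable.HashMap[String, Any]("a" -> 40, "b" -> null)
    println(MAPPER.writeValueAsString(struct))   // e.g. {"a":40,"b":null}
  }
}
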
diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
index ae3b066..fecface 100644
--- a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
+++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
@@ -17,7 +17,8 @@
 
 package org.apache.doris.spark.sql
 
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StringType, StructField, StructType}
 import org.junit.{Ignore, Test}
 
 // This test need real connect info to run.
@@ -125,8 +126,61 @@ class TestConnectorWriteDoris {
       .option("sink.batch.size", 2)
       .option("sink.max-retries", 2)
       .option("sink.properties.format", "json")
-      // .option("sink.properties.read_json_by_line", "true")
-      .option("sink.properties.strip_outer_array", "true")
+      .save()
+    spark.stop()
+  }
+
+
+  /**
+   * correct data in doris
+   * +------+--------------+-------------+------------+
+   * | id   | a            | m           | s          |
+   * +------+--------------+-------------+------------+
+   * |    1 | [1, 2, 3]    | {"k1":1}    | {10, "ab"} |
+   * |    2 | [4, 5, 6]    | {"k2":3}    | NULL       |
+   * |    3 | [7, 8, 9]    | NULL        | {20, "cd"} |
+   * |    4 | NULL         | {"k3":5}    | {30, "ef"} |
+   * |    5 | [10, 11, 12] | {"k4":7}    | {40, NULL} |
+   * |    6 | [13, 14, 15] | {"k5":NULL} | {50, NULL} |
+   * |    7 | []           | {}          | {60, "gh"} |
+   * +------+--------------+-------------+------------+
+   */
+  @Test
+  def complexWriteTest(): Unit = {
+    val spark = SparkSession.builder().master("local[1]").getOrCreate()
+
+    val data = Array(
+      Row(1, Array(1,2,3), Map("k1" -> 1), Row(10, "ab")),
+      Row(2, Array(4,5,6), Map("k2" -> 3), null),
+      Row(3, Array(7,8,9), null, Row(20, "cd")),
+      Row(4, null, Map("k3" -> 5), Row(30, "ef")),
+      Row(5, Array(10,11,12), Map("k4" -> 7), Row(40, null)),
+      Row(6, Array(13,14,15), Map("k5" -> null), Row(50, "{10, \"ab\"}")),
+      Row(7, Array(), Map(), Row(60, "gh"))
+    )
+
+    val schema = StructType(
+      Array(
+        StructField("id", IntegerType),
+        StructField("a", ArrayType(IntegerType)),
+        StructField("m", MapType(StringType, IntegerType)),
+        StructField("s", StructType(Seq(StructField("a", IntegerType),StructField("b", StringType))))
+      )
+    )
+
+    val rdd = spark.sparkContext.parallelize(data)
+    val df = spark.createDataFrame(rdd, schema)
+    df.printSchema()
+    df.show(false)
+    df.write
+      .format("doris")
+      .option("doris.fenodes", dorisFeNodes)
+      .option("doris.table.identifier", dorisTable)
+      .option("user", dorisUser)
+      .option("password", dorisPwd)
+      .option("sink.batch.size", 2)
+      .option("sink.max-retries", 0)
+      .option("sink.properties.format", "json")
       .save()
     spark.stop()
   }

