This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push:
     new 4f2ef8d  [fix] complex type npe problem (#151)
4f2ef8d is described below

commit 4f2ef8df05b0860ab37acd6c2997d209ed4e3b3d
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Wed Nov 1 17:39:57 2023 +0800

    [fix] complex type npe problem (#151)
---
 spark-doris-connector/pom.xml                      |   6 ++
 .../apache/doris/spark/serialization/RowBatch.java |   4 +-
 .../org/apache/doris/spark/sql/SchemaUtils.scala   | 120 ++++++++++-----------
 .../doris/spark/sql/TestConnectorWriteDoris.scala  |  60 ++++++++++-
 4 files changed, 122 insertions(+), 68 deletions(-)

diff --git a/spark-doris-connector/pom.xml b/spark-doris-connector/pom.xml
index 4148a66..518a3e2 100644
--- a/spark-doris-connector/pom.xml
+++ b/spark-doris-connector/pom.xml
@@ -185,6 +185,12 @@
             <version>${fasterxml.jackson.version}</version>
         </dependency>
 
+        <dependency>
+            <groupId>com.fasterxml.jackson.module</groupId>
+            <artifactId>jackson-module-scala_${scala.version}</artifactId>
+            <version>${fasterxml.jackson.version}</version>
+        </dependency>
+
         <!-- https://mvnrepository.com/artifact/com.mysql/mysql-connector-j -->
         <dependency>
             <groupId>com.mysql</groupId>
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index b43b0a2..cb4d303 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -60,6 +60,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
+import java.util.Objects;
 
 /**
  * row batch data container.
@@ -357,7 +358,8 @@ public class RowBatch {
                 reader.setPosition(rowIndex);
                 Map<String, String> value = new HashMap<>();
                 while (reader.next()) {
-                    value.put(reader.key().readObject().toString(), reader.value().readObject().toString());
+                    value.put(Objects.toString(reader.key().readObject(), null),
+                            Objects.toString(reader.value().readObject(), null));
                 }
                 addValueToRow(rowIndex, JavaConverters.mapAsScalaMapConverter(value).asScala());
             }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index e806059..982e580 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -17,6 +17,8 @@
 
 package org.apache.doris.spark.sql
 
+import com.fasterxml.jackson.databind.json.JsonMapper
+import com.fasterxml.jackson.module.scala.DefaultScalaModule
 import org.apache.doris.sdk.thrift.TScanColumnDesc
 import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD}
 import org.apache.doris.spark.cfg.Settings
@@ -32,9 +34,11 @@ import org.slf4j.LoggerFactory
 import java.sql.Timestamp
 import java.time.{LocalDateTime, ZoneOffset}
 import scala.collection.JavaConversions._
+import scala.collection.mutable
 
 private[spark] object SchemaUtils {
   private val logger = LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$"))
+  private val MAPPER = JsonMapper.builder().addModule(DefaultScalaModule).build()
 
   /**
    * discover Doris table schema from Doris FE.
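For context (not part of this patch): the MAPPER introduced above is a standard Jackson ObjectMapper with the Scala module registered, so Scala collections serialize straight to JSON and null entries stay as JSON null instead of going through string conversion. A minimal sketch of that behavior, assuming jackson-databind and jackson-module-scala 2.x are on the classpath:

    import com.fasterxml.jackson.databind.json.JsonMapper
    import com.fasterxml.jackson.module.scala.DefaultScalaModule
    import scala.collection.mutable

    object MapperSketch {
      // Built the same way as the MAPPER added to SchemaUtils above.
      private val mapper = JsonMapper.builder().addModule(DefaultScalaModule).build()

      def main(args: Array[String]): Unit = {
        val m = mutable.HashMap[Any, Any]("k1" -> 1, "k5" -> null)
        // Prints JSON along the lines of {"k1":1,"k5":null}; quoting and null
        // handling come from Jackson rather than hand-built string concatenation.
        println(mapper.writeValueAsString(m))
      }
    }

This is what lets the map and struct branches of rowColumnValue below drop the manual quoteData helper.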
@@ -147,72 +151,60 @@ private[spark] object SchemaUtils {
 
   def rowColumnValue(row: SpecializedGetters, ordinal: Int, dataType: DataType): Any = {
-    dataType match {
-      case NullType => DataUtil.NULL_VALUE
-      case BooleanType => row.getBoolean(ordinal)
-      case ByteType => row.getByte(ordinal)
-      case ShortType => row.getShort(ordinal)
-      case IntegerType => row.getInt(ordinal)
-      case LongType => row.getLong(ordinal)
-      case FloatType => row.getFloat(ordinal)
-      case DoubleType => row.getDouble(ordinal)
-      case StringType => Option(row.getUTF8String(ordinal)).map(_.toString).getOrElse(DataUtil.NULL_VALUE)
-      case TimestampType =>
-        LocalDateTime.ofEpochSecond(row.getLong(ordinal) / 100000, (row.getLong(ordinal) % 1000).toInt, ZoneOffset.UTC)
-        new Timestamp(row.getLong(ordinal) / 1000).toString
-      case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString
-      case BinaryType => row.getBinary(ordinal)
-      case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale)
-      case at: ArrayType =>
-        val arrayData = row.getArray(ordinal)
-        if (arrayData == null) DataUtil.NULL_VALUE
-        else if(arrayData.numElements() == 0) "[]"
-        else {
-          (0 until arrayData.numElements()).map(i => {
-            if (arrayData.isNullAt(i)) null else rowColumnValue(arrayData, i, at.elementType)
-          }).mkString("[", ",", "]")
-        }
-
-      case mt: MapType =>
-        val mapData = row.getMap(ordinal)
-        val keys = mapData.keyArray()
-        val values = mapData.valueArray()
-        val sb = StringBuilder.newBuilder
-        sb.append("{")
-        var i = 0
-        while (i < keys.numElements()) {
-          rowColumnValue(keys, i, mt.keyType) -> rowColumnValue(values, i, mt.valueType)
-          sb.append(quoteData(rowColumnValue(keys, i, mt.keyType), mt.keyType))
-            .append(":").append(quoteData(rowColumnValue(values, i, mt.valueType), mt.valueType))
-            .append(",")
-          i += 1
-        }
-        if (i > 0) sb.dropRight(1)
-        sb.append("}").toString
-      case st: StructType =>
-        val structData = row.getStruct(ordinal, st.length)
-        val sb = StringBuilder.newBuilder
-        sb.append("{")
-        var i = 0
-        while (i < structData.numFields) {
-          val field = st.get(i)
-          sb.append(s""""${field.name}":""")
-            .append(quoteData(rowColumnValue(structData, i, field.dataType), field.dataType))
-            .append(",")
-          i += 1
-        }
-        if (i > 0) sb.dropRight(1)
-        sb.append("}").toString
-      case _ => throw new DorisException(s"Unsupported spark type: ${dataType.typeName}")
+    if (row.isNullAt(ordinal)) null
+    else {
+      dataType match {
+        case NullType => DataUtil.NULL_VALUE
+        case BooleanType => row.getBoolean(ordinal)
+        case ByteType => row.getByte(ordinal)
+        case ShortType => row.getShort(ordinal)
+        case IntegerType => row.getInt(ordinal)
+        case LongType => row.getLong(ordinal)
+        case FloatType => row.getFloat(ordinal)
+        case DoubleType => row.getDouble(ordinal)
+        case StringType => Option(row.getUTF8String(ordinal)).map(_.toString).getOrElse(DataUtil.NULL_VALUE)
+        case TimestampType =>
+          LocalDateTime.ofEpochSecond(row.getLong(ordinal) / 100000, (row.getLong(ordinal) % 1000).toInt, ZoneOffset.UTC)
+          new Timestamp(row.getLong(ordinal) / 1000).toString
+        case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString
+        case BinaryType => row.getBinary(ordinal)
+        case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale)
+        case at: ArrayType =>
+          val arrayData = row.getArray(ordinal)
+          if (arrayData == null) DataUtil.NULL_VALUE
+          else {
+            (0 until arrayData.numElements()).map(i => {
+              if (arrayData.isNullAt(i)) null else rowColumnValue(arrayData, i, at.elementType)
+            }).mkString("[", ",", "]")
+          }
+        case mt: MapType =>
+          val mapData = row.getMap(ordinal)
+          if (mapData.numElements() == 0) "{}"
+          else {
+            val keys = mapData.keyArray()
+            val values = mapData.valueArray()
+            val map = mutable.HashMap[Any, Any]()
+            var i = 0
+            while (i < keys.numElements()) {
+              map += rowColumnValue(keys, i, mt.keyType) -> rowColumnValue(values, i, mt.valueType)
+              i += 1
+            }
+            MAPPER.writeValueAsString(map)
+          }
+        case st: StructType =>
+          val structData = row.getStruct(ordinal, st.length)
+          val map = mutable.HashMap[String, Any]()
+          var i = 0
+          while (i < structData.numFields) {
+            val field = st.get(i)
+            map += field.name -> rowColumnValue(structData, i, field.dataType)
+            i += 1
+          }
+          MAPPER.writeValueAsString(map)
+        case _ => throw new DorisException(s"Unsupported spark type: ${dataType.typeName}")
+      }
     }
   }
 
-  private def quoteData(value: Any, dataType: DataType): Any = {
-    dataType match {
-      case StringType | TimestampType | DateType => s""""$value""""
-      case _ => value
-    }
-  }
-
 }
diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
index ae3b066..fecface 100644
--- a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
+++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
@@ -17,7 +17,8 @@
 
 package org.apache.doris.spark.sql
 
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StringType, StructField, StructType}
 import org.junit.{Ignore, Test}
 
 // This test need real connect info to run.
@@ -125,8 +126,61 @@ class TestConnectorWriteDoris {
       .option("sink.batch.size", 2)
       .option("sink.max-retries", 2)
       .option("sink.properties.format", "json")
-      // .option("sink.properties.read_json_by_line", "true")
-      .option("sink.properties.strip_outer_array", "true")
+      .save()
+    spark.stop()
+  }
+
+
+  /**
+   * correct data in doris
+   * +------+--------------+-------------+------------+
+   * | id   | a            | m           | s          |
+   * +------+--------------+-------------+------------+
+   * |    1 | [1, 2, 3]    | {"k1":1}    | {10, "ab"} |
+   * |    2 | [4, 5, 6]    | {"k2":3}    | NULL       |
+   * |    3 | [7, 8, 9]    | NULL        | {20, "cd"} |
+   * |    4 | NULL         | {"k3":5}    | {30, "ef"} |
+   * |    5 | [10, 11, 12] | {"k4":7}    | {40, NULL} |
+   * |    6 | [13, 14, 15] | {"k5":NULL} | {50, NULL} |
+   * |    7 | []           | {}          | {60, "gh"} |
+   * +------+--------------+-------------+------------+
+   */
+  @Test
+  def complexWriteTest(): Unit = {
+    val spark = SparkSession.builder().master("local[1]").getOrCreate()
+
+    val data = Array(
+      Row(1, Array(1,2,3), Map("k1" -> 1), Row(10, "ab")),
+      Row(2, Array(4,5,6), Map("k2" -> 3), null),
+      Row(3, Array(7,8,9), null, Row(20, "cd")),
+      Row(4, null, Map("k3" -> 5), Row(30, "ef")),
+      Row(5, Array(10,11,12), Map("k4" -> 7), Row(40, null)),
+      Row(6, Array(13,14,15), Map("k5" -> null), Row(50, "{10, \"ab\"}")),
+      Row(7, Array(), Map(), Row(60, "gh"))
+    )
+
+    val schema = StructType(
+      Array(
+        StructField("id", IntegerType),
+        StructField("a", ArrayType(IntegerType)),
+        StructField("m", MapType(StringType, IntegerType)),
+        StructField("s", StructType(Seq(StructField("a", IntegerType),StructField("b", StringType))))
+      )
+    )
+
+    val rdd = spark.sparkContext.parallelize(data)
+    val df = spark.createDataFrame(rdd, schema)
+    df.printSchema()
+    df.show(false)
+    df.write
+      .format("doris")
+      .option("doris.fenodes", dorisFeNodes)
.option("doris.table.identifier", dorisTable) + .option("user", dorisUser) + .option("password", dorisPwd) + .option("sink.batch.size", 2) + .option("sink.max-retries", 0) + .option("sink.properties.format", "json") .save() spark.stop() } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org