This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a45abe  [feature] support read map and struct type (#116)
7a45abe is described below

commit 7a45abe3f955e938f6161e95c43a911464996b5a
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Tue Oct 17 11:58:35 2023 +0800

    [feature] support read map and struct type (#116)
---
 .../apache/doris/spark/serialization/RowBatch.java |  37 +++++
 .../org/apache/doris/spark/sql/SchemaUtils.scala   |   3 +-
 .../doris/spark/serialization/TestRowBatch.java    | 180 +++++++++++++++++++--
 .../doris/spark/sql/TestSparkConnector.scala       |   1 +
 4 files changed, 203 insertions(+), 18 deletions(-)
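
For reference, a minimal read sketch in Scala showing what this feature enables (not part of the commit; the table identifier, FE address, and credentials below are placeholders):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("doris-map-struct-read")
      .getOrCreate()

    // Standard connector read; MAP and STRUCT columns can now be selected directly.
    val df = spark.read
      .format("doris")
      .option("doris.table.identifier", "example_db.example_tbl")
      .option("doris.fenodes", "127.0.0.1:8030")
      .option("user", "root")
      .option("password", "")
      .load()

    // With this change, MAP columns surface as MapType(StringType, StringType)
    // and STRUCT columns as their string representation (see SchemaUtils.scala below).
    df.printSchema()
    df.show(false)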

diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index 3d66db5..b43b0a2 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -37,12 +37,16 @@ import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.impl.UnionMapReader;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.arrow.vector.types.Types;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.spark.sql.types.Decimal;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.collection.JavaConverters;
 
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
@@ -52,7 +56,9 @@ import java.nio.charset.StandardCharsets;
 import java.sql.Date;
 import java.time.LocalDate;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.NoSuchElementException;
 
 /**
@@ -338,6 +344,37 @@ public class RowBatch {
                             addValueToRow(rowIndex, value);
                         }
                         break;
+                    case "MAP":
+                        Preconditions.checkArgument(mt.equals(Types.MinorType.MAP),
+                                typeMismatchMessage(currentType, mt));
+                        MapVector mapVector = (MapVector) curFieldVector;
+                        UnionMapReader reader = mapVector.getReader();
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            if (mapVector.isNull(rowIndex)) {
+                                addValueToRow(rowIndex, null);
+                                continue;
+                            }
+                            reader.setPosition(rowIndex);
+                            Map<String, String> value = new HashMap<>();
+                            while (reader.next()) {
+                                value.put(reader.key().readObject().toString(), reader.value().readObject().toString());
+                            }
+                            addValueToRow(rowIndex, JavaConverters.mapAsScalaMapConverter(value).asScala());
+                        }
+                        break;
+                    case "STRUCT":
+                        Preconditions.checkArgument(mt.equals(Types.MinorType.STRUCT),
+                                typeMismatchMessage(currentType, mt));
+                        StructVector structVector = (StructVector) curFieldVector;
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            if (structVector.isNull(rowIndex)) {
+                                addValueToRow(rowIndex, null);
+                                continue;
+                            }
+                            String value = structVector.getObject(rowIndex).toString();
+                            addValueToRow(rowIndex, value);
+                        }
+                        break;
                     default:
                         String errMsg = "Unsupported type " + schema.get(col).getType();
                         logger.error(errMsg);
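
Note on the change above: each MAP cell is materialized as a Scala Map with string keys and values (via JavaConverters), so the column behaves like an ordinary Spark MapType column. A short sketch, assuming the hypothetical DataFrame `df` from the example near the top and a column named `col_map` as in the tests below:

    import org.apache.spark.sql.functions.{element_at, map_keys}

    // Per-row keys of the map, and the value stored under key "k1" (null if absent).
    df.select(map_keys(df("col_map"))).show(false)
    df.select(element_at(df("col_map"), "k1")).show(false)
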
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index 677cc2e..44baa95 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -32,7 +32,6 @@ import org.slf4j.LoggerFactory
 import java.sql.Timestamp
 import java.time.{LocalDateTime, ZoneOffset}
 import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 private[spark] object SchemaUtils {
@@ -126,6 +125,8 @@ private[spark] object SchemaUtils {
       case "TIME"            => DataTypes.DoubleType
       case "STRING"          => DataTypes.StringType
       case "ARRAY"           => DataTypes.StringType
+      case "MAP"             => MapType(DataTypes.StringType, DataTypes.StringType)
+      case "STRUCT"          => DataTypes.StringType
       case "HLL"             =>
         throw new DorisException("Unsupported type " + dorisType)
       case _                 =>
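
With this mapping, Doris MAP columns arrive as MapType(StringType, StringType) and Doris STRUCT columns arrive as StringType. Because the struct value is rendered as a JSON-like string (see the assertion in TestRowBatch below), a caller that needs typed fields can parse it back with from_json. A sketch under the assumption that the string form is valid JSON; the field names "a" and "b" and the column name col_struct are taken from the test case:

    import org.apache.spark.sql.functions.from_json
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    // Hypothetical struct layout matching the test data below.
    val structSchema = StructType(Seq(
      StructField("a", StringType),
      StructField("b", IntegerType)))

    val typed = df.withColumn("col_struct_typed",
      from_json(df("col_struct"), structSchema))
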
diff --git a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
index ace928f..cb7e0b8 100644
--- a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
+++ b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
@@ -25,6 +25,8 @@ import org.apache.doris.spark.rest.RestService;
 import org.apache.doris.spark.rest.models.Schema;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.arrow.memory.ArrowBuf;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
@@ -39,6 +41,10 @@ import org.apache.arrow.vector.TinyIntVector;
 import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.impl.NullableStructWriter;
+import org.apache.arrow.vector.complex.impl.UnionMapWriter;
 import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.ipc.ArrowStreamWriter;
 import org.apache.arrow.vector.types.FloatingPointPrecision;
@@ -53,11 +59,13 @@ import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.collection.JavaConverters;
 
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.nio.charset.StandardCharsets;
 import java.sql.Date;
 import java.util.Arrays;
 import java.util.List;
@@ -100,7 +108,7 @@ public class TestRowBatch {
         root.setRowCount(3);
 
         FieldVector vector = root.getVector("k0");
-        BitVector bitVector = (BitVector)vector;
+        BitVector bitVector = (BitVector) vector;
         bitVector.setInitialCapacity(3);
         bitVector.allocateNew(3);
         bitVector.setSafe(0, 1);
@@ -109,7 +117,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k1");
-        TinyIntVector tinyIntVector = (TinyIntVector)vector;
+        TinyIntVector tinyIntVector = (TinyIntVector) vector;
         tinyIntVector.setInitialCapacity(3);
         tinyIntVector.allocateNew(3);
         tinyIntVector.setSafe(0, 1);
@@ -118,7 +126,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k2");
-        SmallIntVector smallIntVector = (SmallIntVector)vector;
+        SmallIntVector smallIntVector = (SmallIntVector) vector;
         smallIntVector.setInitialCapacity(3);
         smallIntVector.allocateNew(3);
         smallIntVector.setSafe(0, 1);
@@ -127,7 +135,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k3");
-        IntVector intVector = (IntVector)vector;
+        IntVector intVector = (IntVector) vector;
         intVector.setInitialCapacity(3);
         intVector.allocateNew(3);
         intVector.setSafe(0, 1);
@@ -136,7 +144,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k4");
-        BigIntVector bigIntVector = (BigIntVector)vector;
+        BigIntVector bigIntVector = (BigIntVector) vector;
         bigIntVector.setInitialCapacity(3);
         bigIntVector.allocateNew(3);
         bigIntVector.setSafe(0, 1);
@@ -145,7 +153,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k5");
-        VarCharVector varCharVector = (VarCharVector)vector;
+        VarCharVector varCharVector = (VarCharVector) vector;
         varCharVector.setInitialCapacity(3);
         varCharVector.allocateNew();
         varCharVector.setIndexDefined(0);
@@ -160,7 +168,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k6");
-        VarCharVector charVector = (VarCharVector)vector;
+        VarCharVector charVector = (VarCharVector) vector;
         charVector.setInitialCapacity(3);
         charVector.allocateNew();
         charVector.setIndexDefined(0);
@@ -175,7 +183,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k8");
-        Float8Vector float8Vector = (Float8Vector)vector;
+        Float8Vector float8Vector = (Float8Vector) vector;
         float8Vector.setInitialCapacity(3);
         float8Vector.allocateNew(3);
         float8Vector.setSafe(0, 1.1);
@@ -184,7 +192,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k9");
-        Float4Vector float4Vector = (Float4Vector)vector;
+        Float4Vector float4Vector = (Float4Vector) vector;
         float4Vector.setInitialCapacity(3);
         float4Vector.allocateNew(3);
         float4Vector.setSafe(0, 1.1f);
@@ -193,7 +201,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k10");
-        VarCharVector datecharVector = (VarCharVector)vector;
+        VarCharVector datecharVector = (VarCharVector) vector;
         datecharVector.setInitialCapacity(3);
         datecharVector.allocateNew();
         datecharVector.setIndexDefined(0);
@@ -208,7 +216,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k11");
-        VarCharVector timecharVector = (VarCharVector)vector;
+        VarCharVector timecharVector = (VarCharVector) vector;
         timecharVector.setInitialCapacity(3);
         timecharVector.allocateNew();
         timecharVector.setIndexDefined(0);
@@ -364,15 +372,15 @@ public class TestRowBatch {
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow0 = rowBatch.next();
-        Assert.assertArrayEquals(binaryRow0, (byte[])actualRow0.get(0));
+        Assert.assertArrayEquals(binaryRow0, (byte[]) actualRow0.get(0));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow1 = rowBatch.next();
-        Assert.assertArrayEquals(binaryRow1, (byte[])actualRow1.get(0));
+        Assert.assertArrayEquals(binaryRow1, (byte[]) actualRow1.get(0));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow2 = rowBatch.next();
-        Assert.assertArrayEquals(binaryRow2, (byte[])actualRow2.get(0));
+        Assert.assertArrayEquals(binaryRow2, (byte[]) actualRow2.get(0));
 
         Assert.assertFalse(rowBatch.hasNext());
         thrown.expect(NoSuchElementException.class);
@@ -428,15 +436,15 @@ public class TestRowBatch {
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow0 = rowBatch.next();
-        Assert.assertEquals(Decimal.apply(12340000000L, 11, 9), (Decimal)actualRow0.get(0));
+        Assert.assertEquals(Decimal.apply(12340000000L, 11, 9), (Decimal) actualRow0.get(0));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow1 = rowBatch.next();
-        Assert.assertEquals(Decimal.apply(88880000000L, 11, 9), (Decimal)actualRow1.get(0));
+        Assert.assertEquals(Decimal.apply(88880000000L, 11, 9), (Decimal) actualRow1.get(0));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow2 = rowBatch.next();
-        Assert.assertEquals(Decimal.apply(10000000000L, 11, 9), (Decimal)actualRow2.get(0));
+        Assert.assertEquals(Decimal.apply(10000000000L, 11, 9), (Decimal) actualRow2.get(0));
 
         Assert.assertFalse(rowBatch.hasNext());
         thrown.expect(NoSuchElementException.class);
@@ -591,4 +599,142 @@ public class TestRowBatch {
 
     }
 
+    @Test
+    public void testMap() throws IOException, DorisException {
+
+        ImmutableList<Field> mapChildren = ImmutableList.of(
+                new Field("child", new FieldType(false, new ArrowType.Struct(), null),
+                        ImmutableList.of(
+                                new Field("key", new FieldType(false, new ArrowType.Utf8(), null), null),
+                                new Field("value", new FieldType(false, new ArrowType.Int(32, true), null),
+                                        null)
+                        )
+                ));
+
+        ImmutableList<Field> fields = ImmutableList.of(
+                new Field("col_map", new FieldType(false, new ArrowType.Map(false), null),
+                        mapChildren)
+        );
+
+        RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+        VectorSchemaRoot root = VectorSchemaRoot.create(
+                new org.apache.arrow.vector.types.pojo.Schema(fields, null), allocator);
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+                root,
+                new DictionaryProvider.MapDictionaryProvider(),
+                outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(3);
+
+        MapVector mapVector = (MapVector) root.getVector("col_map");
+        mapVector.allocateNew();
+        UnionMapWriter mapWriter = mapVector.getWriter();
+        for (int i = 0; i < 3; i++) {
+            mapWriter.setPosition(i);
+            mapWriter.startMap();
+            mapWriter.startEntry();
+            String key = "k" + (i + 1);
+            byte[] bytes = key.getBytes(StandardCharsets.UTF_8);
+            ArrowBuf buffer = allocator.buffer(bytes.length);
+            buffer.setBytes(0, bytes);
+            mapWriter.key().varChar().writeVarChar(0, bytes.length, buffer);
+            buffer.close();
+            mapWriter.value().integer().writeInt(i);
+            mapWriter.endEntry();
+            mapWriter.endMap();
+        }
+        mapWriter.setValueCount(3);
+
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr = "{\"properties\":[{\"type\":\"MAP\",\"name\":\"col_map\",\"comment\":\"\"}" +
+                "], \"status\":200}";
+
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+
+        RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(JavaConverters.mapAsScalaMapConverter(ImmutableMap.of("k1", "0")).asScala(),
+                rowBatch.next().get(0));
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(JavaConverters.mapAsScalaMapConverter(ImmutableMap.of("k2", "1")).asScala(),
+                rowBatch.next().get(0));
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(JavaConverters.mapAsScalaMapConverter(ImmutableMap.of("k3", "2")).asScala(),
+                rowBatch.next().get(0));
+        Assert.assertFalse(rowBatch.hasNext());
+
+    }
+
+    @Test
+    public void testStruct() throws IOException, DorisException {
+
+        ImmutableList<Field> fields = ImmutableList.of(
+                new Field("col_struct", new FieldType(false, new ArrowType.Struct(), null),
+                        ImmutableList.of(new Field("a", new FieldType(false, new ArrowType.Utf8(), null), null),
+                                new Field("b", new FieldType(false, new ArrowType.Int(32, true), null), null))
+                ));
+
+        RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+        VectorSchemaRoot root = VectorSchemaRoot.create(
+                new org.apache.arrow.vector.types.pojo.Schema(fields, null), allocator);
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+                root,
+                new DictionaryProvider.MapDictionaryProvider(),
+                outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(3);
+
+        StructVector structVector = (StructVector) root.getVector("col_struct");
+        structVector.allocateNew();
+        NullableStructWriter writer = structVector.getWriter();
+        writer.setPosition(0);
+        writer.start();
+        byte[] bytes = "a1".getBytes(StandardCharsets.UTF_8);
+        ArrowBuf buffer = allocator.buffer(bytes.length);
+        buffer.setBytes(0, bytes);
+        writer.varChar("a").writeVarChar(0, bytes.length, buffer);
+        buffer.close();
+        writer.integer("b").writeInt(1);
+        writer.end();
+        writer.setValueCount(1);
+
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr = "{\"properties\":[{\"type\":\"STRUCT\",\"name\":\"col_struct\",\"comment\":\"\"}" +
+                "], \"status\":200}";
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+
+        RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals("{\"a\":\"a1\",\"b\":1}", rowBatch.next().get(0));
+
+    }
+
 }
diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala
index 54771df..3f05da2 100644
--- a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala
+++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala
@@ -115,5 +115,6 @@ class TestSparkConnector {
       .start().awaitTermination()
     spark.stop()
   }
+
 }
 

