This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a45abe  [feature] support read map and struct type (#116)
7a45abe is described below

commit 7a45abe3f955e938f6161e95c43a911464996b5a
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Tue Oct 17 11:58:35 2023 +0800

    [feature] support read map and struct type (#116)
---
 .../apache/doris/spark/serialization/RowBatch.java |  37 +++++
 .../org/apache/doris/spark/sql/SchemaUtils.scala   |   3 +-
 .../doris/spark/serialization/TestRowBatch.java    | 180 +++++++++++++++++++--
 .../doris/spark/sql/TestSparkConnector.scala       |   1 +
 4 files changed, 203 insertions(+), 18 deletions(-)
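For orientation, a minimal read sketch of how the new types surface to Spark users, assuming the connector's usual options; the table identifier, FE endpoint, and credentials are placeholders, and col_map/col_struct are the column names used in the tests below:

    // Hypothetical read: after this change a Doris MAP column arrives as
    // map<string,string> and a STRUCT column as a plain string
    // (see the SchemaUtils.scala change below).
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    val df = spark.read.format("doris")
      .option("doris.table.identifier", "example_db.example_tbl") // placeholder
      .option("doris.fenodes", "127.0.0.1:8030")                  // placeholder
      .option("user", "root")                                     // placeholder
      .option("password", "")                                     // placeholder
      .load()
    df.printSchema() // col_map: map<string,string>, col_struct: string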
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index 3d66db5..b43b0a2 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -37,12 +37,16 @@ import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.impl.UnionMapReader;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.arrow.vector.types.Types;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.spark.sql.types.Decimal;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.collection.JavaConverters;
 
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
@@ -52,7 +56,9 @@ import java.nio.charset.StandardCharsets;
 import java.sql.Date;
 import java.time.LocalDate;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.NoSuchElementException;
 
 /**
@@ -338,6 +344,37 @@ public class RowBatch {
                         addValueToRow(rowIndex, value);
                     }
                     break;
+                case "MAP":
+                    Preconditions.checkArgument(mt.equals(Types.MinorType.MAP),
+                            typeMismatchMessage(currentType, mt));
+                    MapVector mapVector = (MapVector) curFieldVector;
+                    UnionMapReader reader = mapVector.getReader();
+                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                        if (mapVector.isNull(rowIndex)) {
+                            addValueToRow(rowIndex, null);
+                            continue;
+                        }
+                        reader.setPosition(rowIndex);
+                        Map<String, String> value = new HashMap<>();
+                        while (reader.next()) {
+                            value.put(reader.key().readObject().toString(), reader.value().readObject().toString());
+                        }
+                        addValueToRow(rowIndex, JavaConverters.mapAsScalaMapConverter(value).asScala());
+                    }
+                    break;
+                case "STRUCT":
+                    Preconditions.checkArgument(mt.equals(Types.MinorType.STRUCT),
+                            typeMismatchMessage(currentType, mt));
+                    StructVector structVector = (StructVector) curFieldVector;
+                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                        if (structVector.isNull(rowIndex)) {
+                            addValueToRow(rowIndex, null);
+                            continue;
+                        }
+                        String value = structVector.getObject(rowIndex).toString();
+                        addValueToRow(rowIndex, value);
+                    }
+                    break;
                 default:
                     String errMsg = "Unsupported type " + schema.get(col).getType();
                     logger.error(errMsg);
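A note on the STRUCT branch above: the row value is whatever structVector.getObject(rowIndex).toString() returns. For the sample covered by testStruct below that string is valid JSON ({"a":"a1","b":1}), so, under that assumption, it can be re-typed on the Spark side; the schema and column name here are illustrative, continuing from the read sketch above:

    // Sketch: re-parse a STRUCT column that arrived as a JSON-formatted
    // string. Assumes the toString() form is valid JSON, as in testStruct.
    import org.apache.spark.sql.functions.{col, from_json}
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val structSchema = StructType(Seq(
      StructField("a", StringType),
      StructField("b", IntegerType)))
    val parsed = df.withColumn("col_struct",
      from_json(col("col_struct"), structSchema))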
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index 677cc2e..44baa95 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -32,7 +32,6 @@ import org.slf4j.LoggerFactory
 import java.sql.Timestamp
 import java.time.{LocalDateTime, ZoneOffset}
 import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 private[spark] object SchemaUtils {
@@ -126,6 +125,8 @@ private[spark] object SchemaUtils {
       case "TIME" => DataTypes.DoubleType
       case "STRING" => DataTypes.StringType
       case "ARRAY" => DataTypes.StringType
+      case "MAP" => MapType(DataTypes.StringType, DataTypes.StringType)
+      case "STRUCT" => DataTypes.StringType
       case "HLL" =>
         throw new DorisException("Unsupported type " + dorisType)
       case _ =>
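One consequence of this MAP mapping: RowBatch stringifies both keys and values (readObject().toString()), so the INT map values in testMap below come back as "0", "1", "2" rather than integers. If typed values are needed, a higher-order SQL function can cast them back; a sketch, assuming a Spark version where transform_values is a SQL built-in (2.4+) and an illustrative column name:

    // Sketch: cast stringified map values back to int on the Spark side.
    import org.apache.spark.sql.functions.expr

    val typed = df.withColumn("col_map",
      expr("transform_values(col_map, (k, v) -> cast(v as int))"))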
diff --git a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
index ace928f..cb7e0b8 100644
--- a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
+++ b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
@@ -25,6 +25,8 @@ import org.apache.doris.spark.rest.RestService;
 import org.apache.doris.spark.rest.models.Schema;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.arrow.memory.ArrowBuf;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
@@ -39,6 +41,10 @@ import org.apache.arrow.vector.TinyIntVector;
 import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.impl.NullableStructWriter;
+import org.apache.arrow.vector.complex.impl.UnionMapWriter;
 import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.ipc.ArrowStreamWriter;
 import org.apache.arrow.vector.types.FloatingPointPrecision;
@@ -53,11 +59,13 @@ import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.collection.JavaConverters;
 
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.nio.charset.StandardCharsets;
 import java.sql.Date;
 import java.util.Arrays;
 import java.util.List;
@@ -100,7 +108,7 @@ public class TestRowBatch {
         root.setRowCount(3);
 
         FieldVector vector = root.getVector("k0");
-        BitVector bitVector = (BitVector)vector;
+        BitVector bitVector = (BitVector) vector;
         bitVector.setInitialCapacity(3);
         bitVector.allocateNew(3);
         bitVector.setSafe(0, 1);
@@ -109,7 +117,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k1");
-        TinyIntVector tinyIntVector = (TinyIntVector)vector;
+        TinyIntVector tinyIntVector = (TinyIntVector) vector;
         tinyIntVector.setInitialCapacity(3);
         tinyIntVector.allocateNew(3);
         tinyIntVector.setSafe(0, 1);
@@ -118,7 +126,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k2");
-        SmallIntVector smallIntVector = (SmallIntVector)vector;
+        SmallIntVector smallIntVector = (SmallIntVector) vector;
         smallIntVector.setInitialCapacity(3);
         smallIntVector.allocateNew(3);
         smallIntVector.setSafe(0, 1);
@@ -127,7 +135,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k3");
-        IntVector intVector = (IntVector)vector;
+        IntVector intVector = (IntVector) vector;
         intVector.setInitialCapacity(3);
         intVector.allocateNew(3);
         intVector.setSafe(0, 1);
@@ -136,7 +144,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k4");
-        BigIntVector bigIntVector = (BigIntVector)vector;
+        BigIntVector bigIntVector = (BigIntVector) vector;
         bigIntVector.setInitialCapacity(3);
         bigIntVector.allocateNew(3);
         bigIntVector.setSafe(0, 1);
@@ -145,7 +153,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k5");
-        VarCharVector varCharVector = (VarCharVector)vector;
+        VarCharVector varCharVector = (VarCharVector) vector;
         varCharVector.setInitialCapacity(3);
         varCharVector.allocateNew();
         varCharVector.setIndexDefined(0);
@@ -160,7 +168,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k6");
-        VarCharVector charVector = (VarCharVector)vector;
+        VarCharVector charVector = (VarCharVector) vector;
         charVector.setInitialCapacity(3);
         charVector.allocateNew();
         charVector.setIndexDefined(0);
@@ -175,7 +183,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k8");
-        Float8Vector float8Vector = (Float8Vector)vector;
+        Float8Vector float8Vector = (Float8Vector) vector;
        float8Vector.setInitialCapacity(3);
         float8Vector.allocateNew(3);
         float8Vector.setSafe(0, 1.1);
@@ -184,7 +192,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k9");
-        Float4Vector float4Vector = (Float4Vector)vector;
+        Float4Vector float4Vector = (Float4Vector) vector;
         float4Vector.setInitialCapacity(3);
         float4Vector.allocateNew(3);
         float4Vector.setSafe(0, 1.1f);
@@ -193,7 +201,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k10");
-        VarCharVector datecharVector = (VarCharVector)vector;
+        VarCharVector datecharVector = (VarCharVector) vector;
         datecharVector.setInitialCapacity(3);
         datecharVector.allocateNew();
         datecharVector.setIndexDefined(0);
@@ -208,7 +216,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k11");
-        VarCharVector timecharVector = (VarCharVector)vector;
+        VarCharVector timecharVector = (VarCharVector) vector;
         timecharVector.setInitialCapacity(3);
         timecharVector.allocateNew();
         timecharVector.setIndexDefined(0);
@@ -364,15 +372,15 @@ public class TestRowBatch {
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow0 = rowBatch.next();
-        Assert.assertArrayEquals(binaryRow0, (byte[])actualRow0.get(0));
+        Assert.assertArrayEquals(binaryRow0, (byte[]) actualRow0.get(0));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow1 = rowBatch.next();
-        Assert.assertArrayEquals(binaryRow1, (byte[])actualRow1.get(0));
+        Assert.assertArrayEquals(binaryRow1, (byte[]) actualRow1.get(0));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow2 = rowBatch.next();
-        Assert.assertArrayEquals(binaryRow2, (byte[])actualRow2.get(0));
+        Assert.assertArrayEquals(binaryRow2, (byte[]) actualRow2.get(0));
 
         Assert.assertFalse(rowBatch.hasNext());
         thrown.expect(NoSuchElementException.class);
@@ -428,15 +436,15 @@ public class TestRowBatch {
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow0 = rowBatch.next();
-        Assert.assertEquals(Decimal.apply(12340000000L, 11, 9), (Decimal)actualRow0.get(0));
+        Assert.assertEquals(Decimal.apply(12340000000L, 11, 9), (Decimal) actualRow0.get(0));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow1 = rowBatch.next();
-        Assert.assertEquals(Decimal.apply(88880000000L, 11, 9), (Decimal)actualRow1.get(0));
+        Assert.assertEquals(Decimal.apply(88880000000L, 11, 9), (Decimal) actualRow1.get(0));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow2 = rowBatch.next();
-        Assert.assertEquals(Decimal.apply(10000000000L, 11, 9), (Decimal)actualRow2.get(0));
+        Assert.assertEquals(Decimal.apply(10000000000L, 11, 9), (Decimal) actualRow2.get(0));
 
         Assert.assertFalse(rowBatch.hasNext());
         thrown.expect(NoSuchElementException.class);
@@ -591,4 +599,142 @@ public class TestRowBatch {
 
     }
 
+    @Test
+    public void testMap() throws IOException, DorisException {
+
+        ImmutableList<Field> mapChildren = ImmutableList.of(
+                new Field("child", new FieldType(false, new ArrowType.Struct(), null),
+                        ImmutableList.of(
+                                new Field("key", new FieldType(false, new ArrowType.Utf8(), null), null),
+                                new Field("value", new FieldType(false, new ArrowType.Int(32, true), null),
+                                        null)
+                        )
+                ));
+
+        ImmutableList<Field> fields = ImmutableList.of(
+                new Field("col_map", new FieldType(false, new ArrowType.Map(false), null),
+                        mapChildren)
+        );
+
+        RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+        VectorSchemaRoot root = VectorSchemaRoot.create(
+                new org.apache.arrow.vector.types.pojo.Schema(fields, null), allocator);
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+                root,
+                new DictionaryProvider.MapDictionaryProvider(),
+                outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(3);
+
+        MapVector mapVector = (MapVector) root.getVector("col_map");
+        mapVector.allocateNew();
+        UnionMapWriter mapWriter = mapVector.getWriter();
+        for (int i = 0; i < 3; i++) {
+            mapWriter.setPosition(i);
+            mapWriter.startMap();
+            mapWriter.startEntry();
+            String key = "k" + (i + 1);
+            byte[] bytes = key.getBytes(StandardCharsets.UTF_8);
+            ArrowBuf buffer = allocator.buffer(bytes.length);
+            buffer.setBytes(0, bytes);
+            mapWriter.key().varChar().writeVarChar(0, bytes.length, buffer);
+            buffer.close();
+            mapWriter.value().integer().writeInt(i);
+            mapWriter.endEntry();
+            mapWriter.endMap();
+        }
+        mapWriter.setValueCount(3);
+
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr = "{\"properties\":[{\"type\":\"MAP\",\"name\":\"col_map\",\"comment\":\"\"}"
+                + "], \"status\":200}";
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+
+        RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(JavaConverters.mapAsScalaMapConverter(ImmutableMap.of("k1", "0")).asScala(),
+                rowBatch.next().get(0));
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(JavaConverters.mapAsScalaMapConverter(ImmutableMap.of("k2", "1")).asScala(),
+                rowBatch.next().get(0));
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(JavaConverters.mapAsScalaMapConverter(ImmutableMap.of("k3", "2")).asScala(),
+                rowBatch.next().get(0));
+        Assert.assertFalse(rowBatch.hasNext());
+
+    }
+
+    @Test
+    public void testStruct() throws IOException, DorisException {
+
+        ImmutableList<Field> fields = ImmutableList.of(
+                new Field("col_struct", new FieldType(false, new ArrowType.Struct(), null),
+                        ImmutableList.of(new Field("a", new FieldType(false, new ArrowType.Utf8(), null), null),
+                                new Field("b", new FieldType(false, new ArrowType.Int(32, true), null), null))
+                ));
+
+        RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+        VectorSchemaRoot root = VectorSchemaRoot.create(
+                new org.apache.arrow.vector.types.pojo.Schema(fields, null), allocator);
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+                root,
+                new DictionaryProvider.MapDictionaryProvider(),
+                outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(3);
+
+        StructVector structVector = (StructVector) root.getVector("col_struct");
+        structVector.allocateNew();
+        NullableStructWriter writer = structVector.getWriter();
+        writer.setPosition(0);
+        writer.start();
+        byte[] bytes = "a1".getBytes(StandardCharsets.UTF_8);
+        ArrowBuf buffer = allocator.buffer(bytes.length);
+        buffer.setBytes(0, bytes);
+        writer.varChar("a").writeVarChar(0, bytes.length, buffer);
+        buffer.close();
+        writer.integer("b").writeInt(1);
+        writer.end();
+        writer.setValueCount(1);
+
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr = "{\"properties\":[{\"type\":\"STRUCT\",\"name\":\"col_struct\",\"comment\":\"\"}"
+                + "], \"status\":200}";
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+
+        RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals("{\"a\":\"a1\",\"b\":1}", rowBatch.next().get(0));
+
+    }
+
 }
diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala
index 54771df..3f05da2 100644
--- a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala
+++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala
@@ -115,5 +115,6 @@ class TestSparkConnector {
       .start().awaitTermination()
     spark.stop()
   }
+
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org