This is an automated email from the ASF dual-hosted git repository.

jiafengzheng pushed a commit to branch master
in repository 
https://gitbox.apache.org/repos/asf/incubator-doris-flink-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new 19e24c7  [Bug] Fix row type decimal convert bug (#26)
19e24c7 is described below

commit 19e24c741e79a24acb480256bbedfe9351c5d4dd
Author: aiwenmo <32723967+aiwe...@users.noreply.github.com>
AuthorDate: Fri Apr 15 11:01:08 2022 +0800

    [Bug] Fix row type decimal convert bug (#26)
    
    * Fix row type decimal convert bug
---
 .../apache/doris/flink/serialization/RowBatch.java | 13 +++----
 .../doris/flink/table/DorisDynamicTableSource.java |  5 ++-
 .../doris/flink/table/DorisRowDataInputFormat.java | 45 ++++++++++++++++++----
 .../doris/flink/serialization/TestRowBatch.java    | 32 ++++++++-------
 4 files changed, 65 insertions(+), 30 deletions(-)

diff --git 
a/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java
 
b/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java
index 3337637..d235aa9 100644
--- 
a/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java
+++ 
b/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java
@@ -18,7 +18,6 @@
 package org.apache.doris.flink.serialization;
 
 import org.apache.arrow.memory.RootAllocator;
-
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
 import org.apache.arrow.vector.DecimalVector;
@@ -36,12 +35,7 @@ import org.apache.arrow.vector.types.Types;
 import org.apache.doris.flink.exception.DorisException;
 import org.apache.doris.flink.rest.models.Schema;
 import org.apache.doris.thrift.TScanBatchResult;
-
-import org.apache.flink.table.data.DecimalData;
-import org.apache.flink.table.data.StringData;
 import org.apache.flink.util.Preconditions;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
@@ -50,6 +44,9 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.NoSuchElementException;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 /**
  * row batch data container.
  */
@@ -243,7 +240,7 @@ public class RowBatch {
                                 continue;
                             }
                             BigDecimal value = 
decimalVector.getObject(rowIndex).stripTrailingZeros();
-                            addValueToRow(rowIndex, 
DecimalData.fromBigDecimal(value, value.precision(), value.scale()));
+                            addValueToRow(rowIndex, value);
                         }
                         break;
                     case "DATE":
@@ -261,7 +258,7 @@ public class RowBatch {
                                 continue;
                             }
                             String value = new 
String(varCharVector.get(rowIndex));
-                            addValueToRow(rowIndex, 
StringData.fromString(value));
+                            addValueToRow(rowIndex, value);
                         }
                         break;
                     default:
diff --git 
a/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisDynamicTableSource.java
 
b/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisDynamicTableSource.java
index 0262677..689aa47 100644
--- 
a/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisDynamicTableSource.java
+++ 
b/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisDynamicTableSource.java
@@ -33,6 +33,8 @@ import 
org.apache.flink.table.connector.source.ScanTableSource;
 import 
org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown;
 import 
org.apache.flink.table.connector.source.abilities.SupportsProjectionPushDown;
 import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.types.logical.RowType;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -80,7 +82,8 @@ public final class DorisDynamicTableSource implements 
ScanTableSource, LookupTab
                 .setPassword(options.getPassword())
                 .setTableIdentifier(options.getTableIdentifier())
                 .setPartitions(dorisPartitions)
-                .setReadOptions(readOptions);
+                .setReadOptions(readOptions)
+                .setRowType((RowType) 
physicalSchema.toRowDataType().getLogicalType());
         return InputFormatProvider.of(builder.build());
     }
 
diff --git 
a/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisRowDataInputFormat.java
 
b/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisRowDataInputFormat.java
index c75a88f..be1e13d 100644
--- 
a/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisRowDataInputFormat.java
+++ 
b/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisRowDataInputFormat.java
@@ -29,16 +29,23 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.io.InputSplitAssigner;
+import org.apache.flink.table.data.DecimalData;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.types.logical.DecimalType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
 
 import java.io.IOException;
+import java.math.BigDecimal;
 import java.sql.PreparedStatement;
 import java.util.ArrayList;
 import java.util.List;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 /**
  * InputFormat for {@link DorisDynamicTableSource}.
  */
@@ -56,10 +63,13 @@ public class DorisRowDataInputFormat extends 
RichInputFormat<RowData, DorisTable
     private ScalaValueReader scalaValueReader;
     private transient boolean hasNext;
 
-    public DorisRowDataInputFormat(DorisOptions options, 
List<PartitionDefinition> dorisPartitions, DorisReadOptions readOptions) {
+    /** Physical row type of the scanned table; field i's logical type drives conversion of column i. */
+    private RowType rowType;
+
+    /**
+     * Creates an input format over the given Doris partitions.
+     *
+     * @param options Doris connection options
+     * @param dorisPartitions the partition definitions this format reads
+     * @param readOptions Doris read options
+     * @param rowType physical row type of the table; each produced row has exactly
+     *     {@code rowType.getFieldCount()} fields, and the logical type of field {@code i}
+     *     decides how the raw Doris value in column {@code i} is converted to Flink internal data
+     */
+    public DorisRowDataInputFormat(DorisOptions options, 
List<PartitionDefinition> dorisPartitions, DorisReadOptions readOptions, 
RowType rowType) {
         this.options = options;
         this.dorisPartitions = dorisPartitions;
         this.readOptions = readOptions;
+        this.rowType = rowType;
     }
 
     @Override
@@ -136,15 +146,30 @@ public class DorisRowDataInputFormat extends 
RichInputFormat<RowData, DorisTable
             return null;
         }
         List next = (List) scalaValueReader.next();
-        GenericRowData genericRowData = new GenericRowData(next.size());
-        for (int i = 0; i < next.size(); i++) {
-            genericRowData.setField(i, next.get(i));
+        GenericRowData genericRowData = new 
GenericRowData(rowType.getFieldCount());
+        for (int i = 0; i < next.size() && i < rowType.getFieldCount(); i++) {
+            Object value = deserialize(rowType.getTypeAt(i), next.get(i));
+            genericRowData.setField(i, value);
         }
         //update hasNext after we've read the record
         hasNext = scalaValueReader.hasNext();
         return genericRowData;
     }
 
+    /**
+     * Converts a value produced by {@code RowBatch} into Flink's internal data format
+     * for the given logical type.
+     *
+     * <p>{@code RowBatch} emits {@link BigDecimal} for decimal columns and {@code String}
+     * for character columns; Flink's {@code RowData} expects {@link DecimalData} and
+     * {@link StringData} respectively. Other values are passed through unchanged.
+     *
+     * @param type the Flink logical type of the target field
+     * @param val the raw value read from Doris; may be {@code null} for SQL NULL
+     * @return the converted value, or {@code null} when {@code val} is {@code null}
+     */
+    private Object deserialize(LogicalType type, Object val) {
+        // SQL NULL must stay null: DecimalData.fromBigDecimal dereferences its argument.
+        if (val == null) {
+            return null;
+        }
+        switch (type.getTypeRoot()) {
+            case DECIMAL:
+                final DecimalType decimalType = (DecimalType) type;
+                return DecimalData.fromBigDecimal(
+                        (BigDecimal) val, decimalType.getPrecision(), decimalType.getScale());
+            case CHAR:
+            case VARCHAR:
+                // Both CHAR(n) and VARCHAR/STRING fields are backed by StringData internally.
+                return StringData.fromString((String) val);
+            default:
+                return val;
+        }
+    }
+
     @Override
     public BaseStatistics getStatistics(BaseStatistics cachedStatistics) 
throws IOException {
         return cachedStatistics;
@@ -182,6 +207,7 @@ public class DorisRowDataInputFormat extends 
RichInputFormat<RowData, DorisTable
         private DorisOptions.Builder optionsBuilder;
         private List<PartitionDefinition> partitions;
         private DorisReadOptions readOptions;
+        private RowType rowType;
 
 
         public Builder() {
@@ -218,9 +244,14 @@ public class DorisRowDataInputFormat extends 
RichInputFormat<RowData, DorisTable
             return this;
         }
 
+        /**
+         * Sets the physical row type of the scanned table.
+         *
+         * <p>Required: the record-reading path sizes each output row from this type and
+         * converts every field according to it, without a null check.
+         *
+         * @param rowType the physical row type
+         * @return this builder
+         */
+        public Builder setRowType(RowType rowType) {
+            this.rowType = rowType;
+            return this;
+        }
+
         public DorisRowDataInputFormat build() {
             return new DorisRowDataInputFormat(
-                    optionsBuilder.build(), partitions, readOptions
+                optionsBuilder.build(), partitions, readOptions, rowType
             );
         }
     }
diff --git 
a/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java
 
b/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java
index 424a7be..8b66e01 100644
--- 
a/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java
+++ 
b/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java
@@ -44,7 +44,6 @@ import org.apache.doris.thrift.TStatusCode;
 import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;
 import org.apache.flink.calcite.shaded.com.google.common.collect.Lists;
 import org.apache.flink.table.data.DecimalData;
-import org.apache.flink.table.data.StringData;
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
@@ -248,10 +247,10 @@ public class TestRowBatch {
                 1L,
                 (float) 1.1,
                 (double) 1.1,
-                StringData.fromString("2008-08-08"),
-                StringData.fromString("2008-08-08 00:00:00"),
+                "2008-08-08",
+                "2008-08-08 00:00:00",
                 DecimalData.fromBigDecimal(new BigDecimal(12.34), 4, 2),
-                StringData.fromString("char1")
+                "char1"
         );
 
         List<Object> expectedRow2 = Arrays.asList(
@@ -262,10 +261,10 @@ public class TestRowBatch {
                 2L,
                 (float) 2.2,
                 (double) 2.2,
-                StringData.fromString("1900-08-08"),
-                StringData.fromString("1900-08-08 00:00:00"),
+                "1900-08-08",
+                "1900-08-08 00:00:00",
                 DecimalData.fromBigDecimal(new BigDecimal(88.88), 4, 2),
-                StringData.fromString("char2")
+                "char2"
         );
 
         List<Object> expectedRow3 = Arrays.asList(
@@ -276,22 +275,25 @@ public class TestRowBatch {
                 3L,
                 (float) 3.3,
                 (double) 3.3,
-                StringData.fromString("2100-08-08"),
-                StringData.fromString("2100-08-08 00:00:00"),
+                "2100-08-08",
+                "2100-08-08 00:00:00",
                 DecimalData.fromBigDecimal(new BigDecimal(10.22), 4, 2),
-                StringData.fromString("char3")
+                "char3"
         );
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow1 = rowBatch.next();
+        actualRow1.set(9, DecimalData.fromBigDecimal((BigDecimal) 
actualRow1.get(9), 4, 2));
         Assert.assertEquals(expectedRow1, actualRow1);
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow2 = rowBatch.next();
+        actualRow2.set(9, DecimalData.fromBigDecimal((BigDecimal) 
actualRow2.get(9), 4, 2));
         Assert.assertEquals(expectedRow2, actualRow2);
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow3 = rowBatch.next();
+        actualRow3.set(9, DecimalData.fromBigDecimal((BigDecimal) 
actualRow3.get(9), 4, 2));
         Assert.assertEquals(expectedRow3, actualRow3);
 
         Assert.assertFalse(rowBatch.hasNext());
@@ -420,16 +422,18 @@ public class TestRowBatch {
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow0 = rowBatch.next();
-        Assert.assertEquals(DecimalData.fromBigDecimal(new 
BigDecimal(12.340000000), 11, 9), actualRow0.get(0));
+        Assert.assertEquals(DecimalData.fromBigDecimal(new 
BigDecimal(12.340000000), 11, 9),
+            DecimalData.fromBigDecimal((BigDecimal) actualRow0.get(0), 11, 9));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow1 = rowBatch.next();
-
-        Assert.assertEquals(DecimalData.fromBigDecimal(new 
BigDecimal(88.880000000), 11, 9),  actualRow1.get(0));
+        Assert.assertEquals(DecimalData.fromBigDecimal(new 
BigDecimal(88.880000000), 11, 9),
+            DecimalData.fromBigDecimal((BigDecimal) actualRow1.get(0), 11, 9));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow2 = rowBatch.next();
-        Assert.assertEquals(DecimalData.fromBigDecimal(new 
BigDecimal(10.000000000),11, 9), actualRow2.get(0));
+        Assert.assertEquals(DecimalData.fromBigDecimal(new 
BigDecimal(10.000000000), 11, 9),
+            DecimalData.fromBigDecimal((BigDecimal) actualRow2.get(0), 11, 9));
 
         Assert.assertFalse(rowBatch.hasNext());
         thrown.expect(NoSuchElementException.class);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to