This is an automated email from the ASF dual-hosted git repository.
xushiyan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git
The following commit(s) were added to refs/heads/master by this push:
new 3042b2c7c5 [HUDI-4525] Fixing Spark 3.3 `AvroSerializer` implementation (#6279)
3042b2c7c5 is described below
commit 3042b2c7c54c91578dba69a1a814563fb00718d5
Author: Alexey Kudinkin <[email protected]>
AuthorDate: Wed Aug 3 14:27:21 2022 -0700
[HUDI-4525] Fixing Spark 3.3 `AvroSerializer` implementation (#6279)
---
.github/workflows/bot.yml | 2 +-
.../org/apache/hudi/io/HoodieAppendHandle.java | 6 +-
.../scala/org/apache/hudi/HoodieSparkUtils.scala | 1 +
.../TestConvertFilterToCatalystExpression.scala | 4 +-
.../org/apache/hudi/TestHoodieSparkSqlWriter.scala | 34 +++++-
.../org/apache/spark/sql/avro/AvroSerializer.scala | 27 ++++-
.../apache/spark/sql/avro/AvroDeserializer.scala | 35 +++---
.../org/apache/spark/sql/avro/AvroSerializer.scala | 121 ++++++++++++++++-----
8 files changed, 172 insertions(+), 58 deletions(-)
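For context on the patch below: Hudi vendors Spark's `AvroSerializer`, and the Spark 3.3 copy was missing the struct-over-UNION conversion that the Spark 3.2 copy already carried, so writes against Avro schemas containing multi-branch unions failed on Spark 3.3. A minimal sketch of the kind of schema affected, built with Avro's `SchemaBuilder` (the record and field names here are hypothetical, for illustration only):

    import org.apache.avro.SchemaBuilder

    // A union of two non-null record branches. Hudi maps a Catalyst struct
    // with one field per branch onto such a union; serializing that struct
    // is what the restored union converter in this patch handles.
    val union = SchemaBuilder.unionOf()
      .record("IntWrapper").fields().requiredInt("value").endRecord()
      .and()
      .record("StrWrapper").fields().requiredString("value").endRecord()
      .endUnion()
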
diff --git a/.github/workflows/bot.yml b/.github/workflows/bot.yml
index 26c07b96bf..3aa9bdbcc6 100644
--- a/.github/workflows/bot.yml
+++ b/.github/workflows/bot.yml
@@ -69,4 +69,4 @@ jobs:
FLINK_PROFILE: ${{ matrix.flinkProfile }}
if: ${{ !endsWith(env.SPARK_PROFILE, '2.4') }} # skip test spark 2.4 as it's covered by Azure CI
run:
- mvn test -Punit-tests -D"$SCALA_PROFILE" -D"$SPARK_PROFILE" -D"$FLINK_PROFILE" '-Dtest=org.apache.spark.sql.hudi.Test*' -pl hudi-spark-datasource/hudi-spark
+ mvn test -Punit-tests -D"$SCALA_PROFILE" -D"$SPARK_PROFILE" -D"$FLINK_PROFILE" '-Dtest=Test*' -pl hudi-spark-datasource/hudi-spark
diff --git a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/io/HoodieAppendHandle.java b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/io/HoodieAppendHandle.java
index 426e20f83b..e0d40642a6 100644
--- a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/io/HoodieAppendHandle.java
+++ b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/io/HoodieAppendHandle.java
@@ -471,10 +471,12 @@ public class HoodieAppendHandle<T extends HoodieRecordPayload, I, K, O> extends
return HoodieLogFormat.newWriterBuilder()
.onParentPath(FSUtils.getPartitionPath(hoodieTable.getMetaClient().getBasePath(), partitionPath))
- .withFileId(fileId).overBaseCommit(baseCommitTime)
+ .withFileId(fileId)
+ .overBaseCommit(baseCommitTime)
.withLogVersion(latestLogFile.map(HoodieLogFile::getLogVersion).orElse(HoodieLogFile.LOGFILE_BASE_VERSION))
.withFileSize(latestLogFile.map(HoodieLogFile::getFileSize).orElse(0L))
- .withSizeThreshold(config.getLogFileMaxSize()).withFs(fs)
+ .withSizeThreshold(config.getLogFileMaxSize())
+ .withFs(fs)
.withRolloverLogWriteToken(writeToken)
.withLogWriteToken(latestLogFile.map(x -> FSUtils.getWriteTokenFromLogPath(x.getPath())).orElse(writeToken))
.withFileExtension(HoodieLogFile.DELTA_EXTENSION).build();
diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
index 97bbe3e79b..a2f5d1ce97 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
@@ -54,6 +54,7 @@ private[hudi] trait SparkVersionsSupport {
def isSpark3_2: Boolean = getSparkVersion.startsWith("3.2")
def isSpark3_3: Boolean = getSparkVersion.startsWith("3.3")
+ def gteqSpark3_0: Boolean = getSparkVersion >= "3.0"
def gteqSpark3_1: Boolean = getSparkVersion >= "3.1"
def gteqSpark3_1_3: Boolean = getSparkVersion >= "3.1.3"
def gteqSpark3_2: Boolean = getSparkVersion >= "3.2"
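A side note on the predicate added above: these checks compare version strings lexicographically, which is correct for the single-digit minor versions compared here but would misorder a hypothetical two-digit minor such as "3.10". A quick illustration:

    // Lexicographic ordering on version strings:
    assert("3.3.0" >= "3.0")   // true, so gteqSpark3_0 holds on Spark 3.3
    assert("3.2.1" >= "3.1")   // true, as intended
    assert(!("3.10" >= "3.2")) // caveat: "3.10" sorts below "3.2"
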
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala
index 8aa47ffc2f..2d4498ac28 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala
@@ -69,7 +69,7 @@ class TestConvertFilterToCatalystExpression {
private def checkConvertFilter(filter: Filter, expectExpression: String): Unit = {
// [SPARK-25769][SPARK-34636][SPARK-34626][SQL] sql method in UnresolvedAttribute,
// AttributeReference and Alias don't quote qualified names properly
- val removeQuotesIfNeed = if (expectExpression != null && HoodieSparkUtils.isSpark3_2) {
+ val removeQuotesIfNeed = if (expectExpression != null && HoodieSparkUtils.gteqSpark3_2) {
expectExpression.replace("`", "")
} else {
expectExpression
@@ -86,7 +86,7 @@ class TestConvertFilterToCatalystExpression {
private def checkConvertFilters(filters: Array[Filter], expectExpression: String): Unit = {
// [SPARK-25769][SPARK-34636][SPARK-34626][SQL] sql method in UnresolvedAttribute,
// AttributeReference and Alias don't quote qualified names properly
- val removeQuotesIfNeed = if (expectExpression != null && HoodieSparkUtils.isSpark3_2) {
+ val removeQuotesIfNeed = if (expectExpression != null && HoodieSparkUtils.gteqSpark3_2) {
expectExpression.replace("`", "")
} else {
expectExpression
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala
index 4829c44932..93469f2796 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala
@@ -22,6 +22,7 @@ import java.time.Instant
import java.util.{Collections, Date, UUID}
import org.apache.commons.io.FileUtils
import org.apache.hudi.DataSourceWriteOptions._
+import org.apache.hudi.HoodieSparkUtils.gteqSpark3_0
import org.apache.hudi.client.SparkRDDWriteClient
import org.apache.hudi.common.model._
import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient, TableSchemaResolver}
@@ -41,7 +42,8 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue, fail}
import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.{CsvSource, EnumSource, ValueSource}
+import org.junit.jupiter.params.provider.Arguments.arguments
+import org.junit.jupiter.params.provider.{Arguments, CsvSource, EnumSource, MethodSource, ValueSource}
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{spy, times, verify}
import org.scalatest.Assertions.assertThrows
@@ -485,11 +487,8 @@ class TestHoodieSparkSqlWriter {
* @param populateMetaFields Flag for populating meta fields
*/
@ParameterizedTest
- @CsvSource(
-   Array("COPY_ON_WRITE,parquet,true", "COPY_ON_WRITE,parquet,false", "MERGE_ON_READ,parquet,true", "MERGE_ON_READ,parquet,false",
-     "COPY_ON_WRITE,orc,true", "COPY_ON_WRITE,orc,false", "MERGE_ON_READ,orc,true", "MERGE_ON_READ,orc,false"
-   ))
- def testDatasourceInsertForTableTypeBaseFileMetaFields(tableType: String, baseFileFormat: String, populateMetaFields: Boolean): Unit = {
+ @MethodSource(Array("testDatasourceInsert"))
+ def testDatasourceInsertForTableTypeBaseFileMetaFields(tableType: String, populateMetaFields: Boolean, baseFileFormat: String): Unit = {
val hoodieFooTableName = "hoodie_foo_tbl"
val fooTableModifier = Map("path" -> tempBasePath, HoodieWriteConfig.TBL_NAME.key -> hoodieFooTableName,
@@ -1069,3 +1068,26 @@ class TestHoodieSparkSqlWriter {
assertTrue(kg2 == classOf[SimpleKeyGenerator].getName)
}
}
+
+object TestHoodieSparkSqlWriter {
+ def testDatasourceInsert: java.util.stream.Stream[Arguments] = {
+ val scenarios = Array(
+ Seq("COPY_ON_WRITE", true),
+ Seq("COPY_ON_WRITE", false),
+ Seq("MERGE_ON_READ", true),
+ Seq("MERGE_ON_READ", false)
+ )
+
+ val parquetScenarios = scenarios.map { _ :+ "parquet" }
+ val orcScenarios = scenarios.map { _ :+ "orc" }
+
+ // TODO(HUDI-4496) Fix Orc support in Spark 3.x
+ val targetScenarios = if (gteqSpark3_0) {
+ parquetScenarios
+ } else {
+ parquetScenarios ++ orcScenarios
+ }
+
+ java.util.Arrays.stream(targetScenarios.map(as => arguments(as.map(_.asInstanceOf[AnyRef]):_*)))
+ }
+}
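On the test refactor above: JUnit 5 resolves `@MethodSource` against a static factory method, which in Scala means a method on the companion object, hence the new `object TestHoodieSparkSqlWriter`. A self-contained sketch of the same pattern (class, method, and parameter names here are hypothetical):

    import java.util.stream.Stream
    import org.junit.jupiter.params.ParameterizedTest
    import org.junit.jupiter.params.provider.Arguments.arguments
    import org.junit.jupiter.params.provider.{Arguments, MethodSource}

    class MyParamTest {
      @ParameterizedTest
      @MethodSource(Array("cases")) // looked up as a static method, i.e. on the companion object
      def runCase(tableType: String, populateMetaFields: Boolean): Unit = {
        assert(tableType.nonEmpty) // parameter order must match the factory's Arguments order
      }
    }

    object MyParamTest {
      // Companion-object methods compile to static forwarders on the class,
      // which is what JUnit's @MethodSource lookup requires.
      def cases: Stream[Arguments] = Stream.of(
        arguments("COPY_ON_WRITE", Boolean.box(true)),
        arguments("MERGE_ON_READ", Boolean.box(false))
      )
    }
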
diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 73267f4147..ba9812b026 100644
--- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -45,8 +45,13 @@ import java.util.TimeZone
* A serializer to serialize data in catalyst format to data in avro format.
*
* NOTE: This code is borrowed from Spark 3.2.1
- *      This code is borrowed, so that we can better control compatibility w/in Spark minor
- *      branches (3.2.x, 3.1.x, etc)
+ *       This code is borrowed, so that we can better control compatibility w/in Spark minor
+ *       branches (3.2.x, 3.1.x, etc)
+ *
+ * NOTE: THIS IMPLEMENTATION HAS BEEN MODIFIED FROM ITS ORIGINAL VERSION WITH THE MODIFICATION
+ *       BEING EXPLICITLY ANNOTATED INLINE. PLEASE MAKE SURE TO UNDERSTAND PROPERLY ALL THE
+ *       MODIFICATIONS.
+ *
*
* PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
*/
@@ -211,11 +216,20 @@ private[sql] class AvroSerializer(rootCatalystType: DataType,
val numFields = st.length
(getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
+ ////////////////////////////////////////////////////////////////////////////////////////////
+ // Following section is amended to the original (Spark's) implementation
+ // >>> BEGINS
+ ////////////////////////////////////////////////////////////////////////////////////////////
+
case (st: StructType, UNION) =>
val unionConverter = newUnionConverter(st, avroType, catalystPath, avroPath)
val numFields = st.length
(getter, ordinal) => unionConverter(getter.getStruct(ordinal, numFields))
+ ////////////////////////////////////////////////////////////////////////////////////////////
+ // <<< ENDS
+ ////////////////////////////////////////////////////////////////////////////////////////////
+
case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
val valueConverter = newConverter(
vt, resolveNullableType(avroType.getValueType, valueContainsNull),
@@ -293,6 +307,11 @@ private[sql] class AvroSerializer(rootCatalystType: DataType,
result
}
+ ////////////////////////////////////////////////////////////////////////////////////////////
+ // Following section is amended to the original (Spark's) implementation
+ // >>> BEGINS
+ ////////////////////////////////////////////////////////////////////////////////////////////
+
private def newUnionConverter(catalystStruct: StructType,
avroUnion: Schema,
catalystPath: Seq[String],
@@ -337,6 +356,10 @@ private[sql] class AvroSerializer(rootCatalystType: DataType,
avroStruct.getTypes.size() - 1 == catalystStruct.length) || avroStruct.getTypes.size() == catalystStruct.length
}
+ ////////////////////////////////////////////////////////////////////////////////////////////
+ // <<< ENDS
+ ////////////////////////////////////////////////////////////////////////////////////////////
+
/**
* Resolve a possibly nullable Avro Type.
*
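For reference, the converter annotated above (already present in this 3.2.x copy; the same code is ported to the 3.3.x copy below) maps a Catalyst `StructType` onto an Avro UNION: a leading `null` branch, if present, marks the union nullable, each remaining branch pairs with one struct field, and exactly one field may be non-null per row. A sketch of a schema pair that `canMapUnion` accepts (the `memberN` field naming follows the convention Spark uses when converting complex unions to structs):

    import org.apache.avro.{Schema, SchemaBuilder}
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    // Nullable union: ["null", "int", "string"], i.e. two non-null branches.
    val avroUnion: Schema = SchemaBuilder.unionOf()
      .nullType().and().intType().and().stringType().endUnion()

    // One struct field per non-null branch; accepted because
    // union.getTypes.size() - 1 == struct.length.
    val catalystStruct = StructType(Seq(
      StructField("member0", IntegerType, nullable = true),
      StructField("member1", StringType, nullable = true)))
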
diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index fbefb36ddc..5e7bab3e51 100644
--- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -48,17 +48,15 @@ import java.util.TimeZone
*
* PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
*/
-private[sql] class AvroDeserializer(
- rootAvroType: Schema,
- rootCatalystType: DataType,
- positionalFieldMatch: Boolean,
- datetimeRebaseSpec: RebaseSpec,
- filters: StructFilters) {
-
- def this(
- rootAvroType: Schema,
- rootCatalystType: DataType,
- datetimeRebaseMode: String) = {
+private[sql] class AvroDeserializer(rootAvroType: Schema,
+ rootCatalystType: DataType,
+ positionalFieldMatch: Boolean,
+ datetimeRebaseSpec: RebaseSpec,
+ filters: StructFilters) {
+
+ def this(rootAvroType: Schema,
+ rootCatalystType: DataType,
+ datetimeRebaseMode: String) = {
this(
rootAvroType,
rootCatalystType,
@@ -69,11 +67,9 @@ private[sql] class AvroDeserializer(
private lazy val decimalConversions = new DecimalConversion()
- private val dateRebaseFunc = createDateRebaseFuncInRead(
- datetimeRebaseSpec.mode, "Avro")
+ private val dateRebaseFunc = createDateRebaseFuncInRead(datetimeRebaseSpec.mode, "Avro")
- private val timestampRebaseFunc = createTimestampRebaseFuncInRead(
- datetimeRebaseSpec, "Avro")
+ private val timestampRebaseFunc = createTimestampRebaseFuncInRead(datetimeRebaseSpec, "Avro")
private val converter: Any => Option[Any] = try {
rootCatalystType match {
@@ -112,11 +108,10 @@ private[sql] class AvroDeserializer(
* Creates a writer to write avro values to Catalyst values at the given ordinal with the given
* updater.
*/
- private def newWriter(
- avroType: Schema,
- catalystType: DataType,
- avroPath: Seq[String],
- catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = {
+ private def newWriter(avroType: Schema,
+ catalystType: DataType,
+ avroPath: Seq[String],
+ catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = {
val errorPrefix = s"Cannot convert Avro ${toFieldStr(avroPath)} to " +
s"SQL ${toFieldStr(catalystPath)} because "
val incompatibleMsg = errorPrefix +
diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 73d245d42d..450d9d7346 100644
--- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -29,6 +29,7 @@ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
import org.apache.avro.generic.GenericData.Record
import org.apache.avro.util.Utf8
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.avro.AvroSerializer.{createDateRebaseFuncInWrite, createTimestampRebaseFuncInWrite}
import org.apache.spark.sql.avro.AvroUtils.{AvroMatchedField, toFieldStr}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
@@ -44,17 +45,20 @@ import java.util.TimeZone
* A serializer to serialize data in catalyst format to data in avro format.
*
* NOTE: This code is borrowed from Spark 3.3.0
- *      This code is borrowed, so that we can better control compatibility w/in Spark minor
- *      branches (3.2.x, 3.1.x, etc)
+ *       This code is borrowed, so that we can better control compatibility w/in Spark minor
+ *       branches (3.2.x, 3.1.x, etc)
+ *
+ * NOTE: THIS IMPLEMENTATION HAS BEEN MODIFIED FROM ITS ORIGINAL VERSION WITH THE MODIFICATION
+ *       BEING EXPLICITLY ANNOTATED INLINE. PLEASE MAKE SURE TO UNDERSTAND PROPERLY ALL THE
+ *       MODIFICATIONS.
*
* PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
*/
-private[sql] class AvroSerializer(
- rootCatalystType: DataType,
- rootAvroType: Schema,
- nullable: Boolean,
- positionalFieldMatch: Boolean,
- datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging {
+private[sql] class AvroSerializer(rootCatalystType: DataType,
+ rootAvroType: Schema,
+ nullable: Boolean,
+ positionalFieldMatch: Boolean,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging {
def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) = {
this(rootCatalystType, rootAvroType, nullable, positionalFieldMatch = false,
@@ -65,10 +69,10 @@ private[sql] class AvroSerializer(
converter.apply(catalystData)
}
- private val dateRebaseFunc = DataSourceUtils.createDateRebaseFuncInWrite(
+ private val dateRebaseFunc = createDateRebaseFuncInWrite(
datetimeRebaseMode, "Avro")
- private val timestampRebaseFunc = DataSourceUtils.createTimestampRebaseFuncInWrite(
+ private val timestampRebaseFunc = createTimestampRebaseFuncInWrite(
datetimeRebaseMode, "Avro")
private val converter: Any => Any = {
@@ -104,11 +108,10 @@ private[sql] class AvroSerializer(
private lazy val decimalConversions = new DecimalConversion()
- private def newConverter(
- catalystType: DataType,
- avroType: Schema,
- catalystPath: Seq[String],
- avroPath: Seq[String]): Converter = {
+ private def newConverter(catalystType: DataType,
+ avroType: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): Converter = {
val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " +
s"to Avro ${toFieldStr(avroPath)} because "
(catalystType, avroType.getType) match {
@@ -162,6 +165,7 @@ private[sql] class AvroSerializer(
val data: Array[Byte] = getter.getBinary(ordinal)
if (data.length != size) {
def len2str(len: Int): String = s"$len ${if (len > 1) "bytes" else "byte"}"
+
throw new IncompatibleSchemaException(errorPrefix + len2str(data.length) +
" of binary data cannot be written into FIXED type with size of " + len2str(size))
}
@@ -223,6 +227,20 @@ private[sql] class AvroSerializer(
val numFields = st.length
(getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
+ ////////////////////////////////////////////////////////////////////////////////////////////
+ // Following section is amended to the original (Spark's) implementation
+ // >>> BEGINS
+ ////////////////////////////////////////////////////////////////////////////////////////////
+
+ case (st: StructType, UNION) =>
+ val unionConverter = newUnionConverter(st, avroType, catalystPath, avroPath)
+ val numFields = st.length
+ (getter, ordinal) => unionConverter(getter.getStruct(ordinal, numFields))
+
+ ////////////////////////////////////////////////////////////////////////////////////////////
+ // <<< ENDS
+ ////////////////////////////////////////////////////////////////////////////////////////////
+
case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
val valueConverter = newConverter(
vt, resolveNullableType(avroType.getValueType, valueContainsNull),
@@ -257,11 +275,10 @@ private[sql] class AvroSerializer(
}
}
- private def newStructConverter(
- catalystStruct: StructType,
- avroStruct: Schema,
- catalystPath: Seq[String],
- avroPath: Seq[String]): InternalRow => Record = {
+ private def newStructConverter(catalystStruct: StructType,
+ avroStruct: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): InternalRow => Record = {
val avroSchemaHelper = new AvroUtils.AvroSchemaHelper(
avroStruct, catalystStruct, avroPath, catalystPath, positionalFieldMatch)
@@ -292,6 +309,60 @@ private[sql] class AvroSerializer(
result
}
+ ////////////////////////////////////////////////////////////////////////////////////////////
+ // Following section is amended to the original (Spark's) implementation
+ // >>> BEGINS
+ ////////////////////////////////////////////////////////////////////////////////////////////
+
+ private def newUnionConverter(catalystStruct: StructType,
+ avroUnion: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): InternalRow => Any = {
+ if (avroUnion.getType != UNION || !canMapUnion(catalystStruct, avroUnion)) {
+ throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " +
+ s"Avro type $avroUnion.")
+ }
+ val nullable = avroUnion.getTypes.size() > 0 && avroUnion.getTypes.get(0).getType == Type.NULL
+ val avroInnerTypes = if (nullable) {
+ avroUnion.getTypes.asScala.tail
+ } else {
+ avroUnion.getTypes.asScala
+ }
+ val fieldConverters = catalystStruct.zip(avroInnerTypes).map {
+ case (f1, f2) => newConverter(f1.dataType, f2, catalystPath, avroPath)
+ }
+ val numFields = catalystStruct.length
+ (row: InternalRow) =>
+ var i = 0
+ var result: Any = null
+ while (i < numFields) {
+ if (!row.isNullAt(i)) {
+ if (result != null) {
+ throw new IncompatibleSchemaException(s"Cannot convert Catalyst record $catalystStruct to " +
+ s"Avro union $avroUnion. Record has more than one optional values set")
+ }
+ result = fieldConverters(i).apply(row, i)
+ }
+ i += 1
+ }
+ if (!nullable && result == null) {
+ throw new IncompatibleSchemaException(s"Cannot convert Catalyst record $catalystStruct to " +
+ s"Avro union $avroUnion. Record has no values set, while should have exactly one")
+ }
+ result
+ }
+
+ private def canMapUnion(catalystStruct: StructType, avroStruct: Schema): Boolean = {
+ (avroStruct.getTypes.size() > 0 &&
+ avroStruct.getTypes.get(0).getType == Type.NULL &&
+ avroStruct.getTypes.size() - 1 == catalystStruct.length) || avroStruct.getTypes.size() == catalystStruct.length
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////////////////
+ // <<< ENDS
+ ////////////////////////////////////////////////////////////////////////////////////////////
+
+
/**
* Resolve a possibly nullable Avro Type.
*
@@ -319,12 +390,12 @@ private[sql] class AvroSerializer(
if (avroType.getType == Type.UNION) {
val fields = avroType.getTypes.asScala
val actualType = fields.filter(_.getType != Type.NULL)
- if (fields.length != 2 || actualType.length != 1) {
- throw new UnsupportedAvroTypeException(
- s"Unsupported Avro UNION type $avroType: Only UNION of a null type
and a non-null " +
- "type is supported")
+ if (fields.length == 2 && actualType.length == 1) {
+ (true, actualType.head)
+ } else {
+ // This is just a normal union, not used to designate nullability
+ (false, avroType)
}
- (true, actualType.head)
} else {
(false, avroType)
}
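The net effect of this last hunk: `resolveNullableType` still unwraps a two-branch `["null", T]` union into a nullable `T`, but any other union is now passed through untouched (reported as non-nullable) so that the struct-over-UNION case added above can convert it, where previously it threw `UnsupportedAvroTypeException`. A small illustration of the two cases:

    import org.apache.avro.Schema

    // Still unwrapped: treated as a nullable "string".
    val nullableStr = new Schema.Parser().parse("""["null", "string"]""")

    // Previously rejected; now handed back as-is for the union converter.
    val multiBranch = new Schema.Parser().parse("""["null", "int", "string"]""")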