voonhous commented on code in PR #18065:
URL: https://github.com/apache/hudi/pull/18065#discussion_r3354381436
##########
hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/schema/TestVariantDataType.scala:
##########
@@ -189,4 +193,406 @@ class TestVariantDataType extends HoodieSparkSqlTestBase {
spark.sql(s"drop table $tableName")
}
+
+ test("Test Shredded Variant with Multiple Variant Columns") {
+ assume(HoodieSparkUtils.gteqSpark4_0, "Variant type requires Spark 4.0 or
higher")
+
+ withRecordType()(withTempDir { tmp =>
+ val tableName = generateTableName
+ spark.sql(
+ s"""
+ |create table $tableName (
+ | id int,
+ | v1 variant,
+ | v2 variant,
+ | ts long
+ |) using hudi
+ | location '${tmp.getCanonicalPath}'
+ | tblproperties (
+ | primaryKey = 'id',
+ | type = 'cow',
+ | preCombineField = 'ts'
+ | )
+ """.stripMargin)
+
+ spark.sql("set hoodie.parquet.variant.write.shredding.enabled = true")
+ spark.sql("set hoodie.parquet.variant.allow.reading.shredded = true")
+ spark.sql("set hoodie.parquet.variant.force.shredding.schema.for.test =
a int, b string")
+
+ spark.sql(
+ s"""
+ |insert into $tableName values
+ | (1, parse_json('{"a": 10, "b": "first"}'), parse_json('{"a": 20,
"b": "second"}'), 1000),
+ | (2, parse_json('{"a": 30, "b": "third"}'), parse_json('{"a": 40,
"b": "fourth"}'), 2000)
+ """.stripMargin)
+
+ checkAnswer(s"select id, cast(v1 as string), cast(v2 as string) from
$tableName order by id")(
+ Seq(1, "{\"a\":10,\"b\":\"first\"}", "{\"a\":20,\"b\":\"second\"}"),
+ Seq(2, "{\"a\":30,\"b\":\"third\"}", "{\"a\":40,\"b\":\"fourth\"}")
+ )
+ })
+ }
+
+ test("Test Variant Shredding with Update Operation") {
+ assume(HoodieSparkUtils.gteqSpark4_0, "Variant type requires Spark 4.0 or
higher")
+
+ withRecordType()(withTempDir { tmp =>
+ val tableName = generateTableName
+ spark.sql(
+ s"""
+ |create table $tableName (
+ | id int,
+ | v variant,
+ | ts long
+ |) using hudi
+ | location '${tmp.getCanonicalPath}'
+ | tblproperties (
+ | primaryKey = 'id',
+ | type = 'cow',
+ | preCombineField = 'ts'
+ | )
+ """.stripMargin)
+
+ spark.sql("set hoodie.parquet.variant.write.shredding.enabled = true")
+ spark.sql("set hoodie.parquet.variant.allow.reading.shredded = true")
+ spark.sql("set hoodie.parquet.variant.force.shredding.schema.for.test =
a int, b string")
+
+ // Initial insert
+ spark.sql(
+ s"""
+ |insert into $tableName values
+ | (1, parse_json('{"a": 1, "b": "initial"}'), 1000)
+ """.stripMargin)
+
+ // Update the variant value
+ spark.sql(
+ s"""
+ |update $tableName set v = parse_json('{"a": 999, "b":
"updated"}'), ts = 2000 where id = 1
+ """.stripMargin)
+
+ checkAnswer(s"select id, cast(v as string), ts from $tableName")(
+ Seq(1, "{\"a\":999,\"b\":\"updated\"}", 2000)
+ )
+
+ // Verify parquet schema has shredded structure with typed_value
+ val parquetFiles = listDataParquetFiles(tmp.getCanonicalPath)
+ assert(parquetFiles.nonEmpty, "Should have at least one data parquet
file")
+
+ parquetFiles.foreach { filePath =>
+ val schema = readParquetSchema(filePath)
+ val variantGroup = getFieldAsGroup(schema, "v")
+ assert(groupContainsField(variantGroup, "typed_value"),
+ s"Shredded variant should have typed_value field.
Schema:\n$variantGroup")
+ val valueField =
variantGroup.getType(variantGroup.getFieldIndex("value"))
+ assert(valueField.getRepetition == Type.Repetition.OPTIONAL,
+ "Shredded variant value field should be OPTIONAL")
+ val metadataField =
variantGroup.getType(variantGroup.getFieldIndex("metadata"))
+ assert(metadataField.getRepetition == Type.Repetition.REQUIRED,
+ "Shredded variant metadata field should be REQUIRED")
+ }
+ })
+ }
+
+ test("Test Variant Shredding with Merge Operation") {
+ assume(HoodieSparkUtils.gteqSpark4_0, "Variant type requires Spark 4.0 or
higher")
+
+ withRecordType()(withTempDir { tmp =>
+ val tableName = generateTableName
+ spark.sql(
+ s"""
+ |create table $tableName (
+ | id int,
+ | v variant,
+ | ts long
+ |) using hudi
+ | location '${tmp.getCanonicalPath}'
+ | tblproperties (
+ | primaryKey = 'id',
+ | type = 'cow',
+ | preCombineField = 'ts'
+ | )
+ """.stripMargin)
+
+ spark.sql("set hoodie.parquet.variant.write.shredding.enabled = true")
+ spark.sql("set hoodie.parquet.variant.allow.reading.shredded = true")
+ spark.sql("set hoodie.parquet.variant.force.shredding.schema.for.test =
a int, b string")
+
+ // Initial data
+ spark.sql(
+ s"""
+ |insert into $tableName values
+ | (1, parse_json('{"a": 1, "b": "first"}'), 1000)
+ """.stripMargin)
+
+ // Merge in updates and inserts
+ spark.sql(
+ s"""
+ |merge into $tableName as target
+ |using (
+ | select 1 as id, parse_json('{"a": 100, "b": "merged"}') as v,
2000L as ts
+ | union all
+ | select 2 as id, parse_json('{"a": 200, "b": "new"}') as v, 3000L
as ts
+ |) as source
+ |on target.id = source.id
+ |when matched then update set target.v = source.v, target.ts =
source.ts
+ |when not matched then insert *
+ """.stripMargin)
+
+ checkAnswer(s"select id, cast(v as string) from $tableName order by id")(
+ Seq(1, "{\"a\":100,\"b\":\"merged\"}"),
+ Seq(2, "{\"a\":200,\"b\":\"new\"}")
+ )
+ })
+ }
+
+ test("Test Variant Shredding with Null Values") {
+ assume(HoodieSparkUtils.gteqSpark4_0, "Variant type requires Spark 4.0 or
higher")
+
+ withRecordType()(withTempDir { tmp =>
+ val tableName = generateTableName
+ spark.sql(
+ s"""
+ |create table $tableName (
+ | id int,
+ | v variant,
+ | ts long
+ |) using hudi
+ | location '${tmp.getCanonicalPath}'
+ | tblproperties (
+ | primaryKey = 'id',
+ | type = 'cow',
+ | preCombineField = 'ts'
+ | )
+ """.stripMargin)
+
+ spark.sql("set hoodie.parquet.variant.write.shredding.enabled = true")
+ spark.sql("set hoodie.parquet.variant.allow.reading.shredded = true")
+ spark.sql("set hoodie.parquet.variant.force.shredding.schema.for.test =
a int, b string")
+
+ spark.sql(
+ s"""
+ |insert into $tableName values
+ | (1, parse_json('{"a": null, "b": "test"}'), 1000),
+ | (2, null, 2000)
+ """.stripMargin)
+
+ val result = spark.sql(s"select id, v from $tableName order by
id").collect()
+ assert(result.length == 2)
+ assert(!result(0).isNullAt(1), "First row should have non-null variant")
+ assert(result(1).isNullAt(1), "Second row should have null variant")
+ })
+ }
+
+ test("Test Variant with Different Numeric Types") {
+ assume(HoodieSparkUtils.gteqSpark4_0, "Variant type requires Spark 4.0 or
higher")
+
+ withRecordType()(withTempDir { tmp =>
+ val tableName = generateTableName
+ spark.sql(
+ s"""
+ |create table $tableName (
+ | id int,
+ | v variant,
+ | ts long
+ |) using hudi
+ | location '${tmp.getCanonicalPath}'
+ | tblproperties (
+ | primaryKey = 'id',
+ | type = 'cow',
+ | preCombineField = 'ts'
+ | )
+ """.stripMargin)
+
+ spark.sql("set hoodie.parquet.variant.write.shredding.enabled = true")
+ spark.sql("set hoodie.parquet.variant.allow.reading.shredded = true")
+ spark.sql("set hoodie.parquet.variant.force.shredding.schema.for.test =
price decimal(10,2), quantity long, rating double")
+
+ spark.sql(
+ s"""
+ |insert into $tableName values
+ | (1, parse_json('{"price": 99.99, "quantity": 5, "rating":
4.5}'), 1000)
+ """.stripMargin)
+
+ checkAnswer(s"select id from $tableName")(
+ Seq(1)
+ )
+ })
+ }
+
+ test("Test Variant Shredding Toggle") {
+ assume(HoodieSparkUtils.gteqSpark4_0, "Variant type requires Spark 4.0 or
higher")
+
+ withRecordType()(withTempDir { tmp =>
+ val tableName = generateTableName
+ spark.sql(
+ s"""
+ |create table $tableName (
+ | id int,
+ | v variant,
+ | ts long
+ |) using hudi
+ | location '${tmp.getCanonicalPath}'
+ | tblproperties (
+ | primaryKey = 'id',
+ | type = 'cow',
+ | preCombineField = 'ts'
+ | )
+ """.stripMargin)
+
+ // First write: shredding disabled
+ spark.sql("set hoodie.parquet.variant.write.shredding.enabled = false")
+
+ spark.sql(
+ s"""
+ |insert into $tableName values
+ | (1, parse_json('{"a": 1, "b": "hello"}'), 1000)
+ """.stripMargin)
+
+ // Second write: shredding enabled
+ spark.sql("set hoodie.parquet.variant.write.shredding.enabled = true")
+ spark.sql("set hoodie.parquet.variant.force.shredding.schema.for.test =
a int, b string")
+
+ spark.sql(
+ s"""
+ |insert into $tableName values
+ | (2, parse_json('{"a": 2, "b": "world"}'), 2000)
+ """.stripMargin)
+
+ // Both records should be readable
+ checkAnswer(s"select id, cast(v as string) from $tableName order by id")(
+ Seq(1, "{\"a\":1,\"b\":\"hello\"}"),
+ Seq(2, "{\"a\":2,\"b\":\"world\"}")
+ )
+ })
+ }
+
+ test("Test Variant Shredding with Complex Field Names") {
+ assume(HoodieSparkUtils.gteqSpark4_0, "Variant type requires Spark 4.0 or
higher")
+
+ withRecordType()(withTempDir { tmp =>
+ val tableName = generateTableName
+ spark.sql(
+ s"""
+ |create table $tableName (
+ | id int,
+ | v variant,
+ | ts long
+ |) using hudi
+ | location '${tmp.getCanonicalPath}'
+ | tblproperties (
+ | primaryKey = 'id',
+ | type = 'cow',
+ | preCombineField = 'ts'
+ | )
+ """.stripMargin)
+
+ spark.sql("set hoodie.parquet.variant.write.shredding.enabled = true")
+ spark.sql("set hoodie.parquet.variant.allow.reading.shredded = true")
+ spark.sql("set hoodie.parquet.variant.force.shredding.schema.for.test =
field1 int, field2 string, field3 boolean")
+
+ spark.sql(
+ s"""
+ |insert into $tableName values
+ | (1, parse_json('{"field1": 100, "field2": "value", "field3":
true, "extra": "ignored"}'), 1000)
+ """.stripMargin)
+
+ val result = spark.sql(s"select cast(v as string) from
$tableName").collect()
+ assert(result.length == 1)
+ // Verify the JSON contains all expected fields
+ val json = result(0).getString(0)
+ assert(json.contains("field1"))
+ assert(json.contains("field2"))
+ assert(json.contains("field3"))
+ })
+ }
+
+ test("Test Variant with Empty Object") {
+ assume(HoodieSparkUtils.gteqSpark4_0, "Variant type requires Spark 4.0 or
higher")
+
+ withRecordType()(withTempDir { tmp =>
+ val tableName = generateTableName
+ spark.sql(
+ s"""
+ |create table $tableName (
+ | id int,
+ | v variant,
+ | ts long
+ |) using hudi
+ | location '${tmp.getCanonicalPath}'
+ | tblproperties (
+ | primaryKey = 'id',
+ | type = 'cow',
+ | preCombineField = 'ts'
+ | )
+ """.stripMargin)
+
+ spark.sql("set hoodie.parquet.variant.write.shredding.enabled = true")
+ spark.sql("set hoodie.parquet.variant.allow.reading.shredded = true")
+ spark.sql("set hoodie.parquet.variant.force.shredding.schema.for.test =
a int")
+
+ spark.sql(
+ s"""
+ |insert into $tableName values
+ | (1, parse_json('{}'), 1000)
+ """.stripMargin)
+
+ checkAnswer(s"select id, cast(v as string) from $tableName")(
+ Seq(1, "{}")
+ )
+ })
+ }
+
+ /**
+ * Lists data parquet files in the table directory, excluding Hudi metadata
files.
+ */
+ private def listDataParquetFiles(tablePath: String): Seq[String] = {
+ val conf = spark.sparkContext.hadoopConfiguration
+ val fs = FileSystem.get(new HadoopPath(tablePath).toUri, conf)
+ val iter = fs.listFiles(new HadoopPath(tablePath), true)
+ val files = scala.collection.mutable.ArrayBuffer[String]()
+ while (iter.hasNext) {
+ val file = iter.next()
+ val path = file.getPath.toString
+ if (path.endsWith(".parquet") && !path.contains(".hoodie")) {
+ files += path
+ }
+ }
+ files.toSeq
+ }
+
+ /**
+ * Reads the Parquet schema (MessageType) from a parquet file.
+ */
+ private def readParquetSchema(filePath: String): MessageType = {
+ val conf = spark.sparkContext.hadoopConfiguration
+ val inputFile = HadoopInputFile.fromPath(new HadoopPath(filePath), conf)
+ val reader = ParquetFileReader.open(inputFile)
+ try {
+ reader.getFooter.getFileMetaData.getSchema
+ } finally {
+ reader.close()
+ }
+ }
+
+ /**
+ * Gets a named field from a GroupType (MessageType) and returns it as a
GroupType.
+ * Uses getFieldIndex(String) + getType(int) to avoid Scala overload
resolution issues.
+ */
+ private def getFieldAsGroup(parent: GroupType, fieldName: String): GroupType
= {
+ val idx: Int = parent.getFieldIndex(fieldName)
+ parent.getType(idx).asGroupType()
+ }
+
+ /**
+ * Checks whether a GroupType contains a field with the given name.
+ * Uses try/catch on getFieldIndex to avoid Scala-Java collection converter
dependencies.
+ */
+ private def groupContainsField(group: GroupType, fieldName: String): Boolean
= {
+ try {
+ group.getFieldIndex(fieldName)
Review Comment:
Good catch. I replaced the try/catch entirely with Parquet's
`GroupType.containsField(fieldName)`, which returns a boolean directly, so no
exception handling is needed. This also matches how the rest of the codebase
(e.g. `Spark40HoodieParquetReadSupport`) checks for variant fields. I verified
that `containsField` is present in every Parquet version the Spark profiles
resolve: 1.12.3, 1.13.1, 1.15.2 (Spark 4.0), and 1.16.0.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]