ericm-db commented on code in PR #49277:
URL: https://github.com/apache/spark/pull/49277#discussion_r1918069251


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala:
##########
@@ -496,6 +498,486 @@ class RocksDBStateStoreSuite extends 
StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+  test("AvroStateEncoder - add field") {
+    // Checks that a value written with the initial value schema can be decoded by an
+    // encoder built with an evolved schema that appends a nullable field, and that the
+    // evolved encoder round-trips rows written with the evolved schema.
+    val keySchema = StructType(Seq(
+      StructField("k", StringType)
+    ))
+
+    val initialValueSchema = StructType(Seq(
+      StructField("value", IntegerType, false)
+    ))
+
+    val evolvedValueSchema = StructType(Seq(
+      StructField("value", IntegerType, false),
+      StructField("timestamp", LongType, true)
+    ))
+
+    // Create test state schema provider
+    val testProvider = new TestStateSchemaProvider()
+
+    // Add initial schema version
+    testProvider.captureSchema(
+      StateStore.DEFAULT_COL_FAMILY_NAME,
+      keySchema,
+      initialValueSchema,
+      valueSchemaId = 0
+    )
+
+    // Create encoder with initial schema
+    val encoder1 = new AvroStateEncoder(
+      NoPrefixKeyStateEncoderSpec(keySchema),
+      initialValueSchema,
+      Some(testProvider),
+      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+    )
+
+    // Create test data
+    val proj = UnsafeProjection.create(initialValueSchema)
+    val row1 = proj.apply(InternalRow(1))
+
+    // Encode with schema v0
+    val encoded = encoder1.encodeValue(row1)
+
+    // Add evolved schema
+    testProvider.captureSchema(
+      StateStore.DEFAULT_COL_FAMILY_NAME,
+      keySchema,
+      evolvedValueSchema,
+      valueSchemaId = 1
+    )
+
+    // Create encoder with evolved schema
+    val encoder2 = new AvroStateEncoder(
+      NoPrefixKeyStateEncoderSpec(keySchema),
+      evolvedValueSchema,
+      Some(testProvider),
+      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+    )
+
+    // Decode with evolved schema
+    val decoded = encoder2.decodeValue(encoded)
+
+    // Should be able to read old format
+    assert(decoded.getInt(0) === 1)
+    // The added field must come back as null. Check the null bit explicitly: comparing
+    // getLong(1) against 0L cannot tell a null slot apart from a genuine zero value.
+    assert(decoded.isNullAt(1))
+
+    // Encode with new schema
+    val proj2 = UnsafeProjection.create(evolvedValueSchema)
+    val row2 = proj2.apply(InternalRow(2, 100L))
+    val encoded2 = encoder2.encodeValue(row2)
+
+    // Should write with new schema version
+    val decoded2 = encoder2.decodeValue(encoded2)
+    assert(decoded2.getInt(0) === 2)
+    assert(decoded2.getLong(1) === 100L)
+  }
+
+  test("AvroStateEncoder - add field with null") {
+    val cfName = StateStore.DEFAULT_COL_FAMILY_NAME
+    val keySchema = StructType(Seq(StructField("k", StringType)))
+
+    // Both value schemas keep "value" nullable; v1 appends a nullable "timestamp".
+    val schemaV0 = StructType(Seq(StructField("value", IntegerType, true)))
+    val schemaV1 = StructType(Seq(
+      StructField("value", IntegerType, true),
+      StructField("timestamp", LongType, true)))
+
+    val schemaProvider = new TestStateSchemaProvider()
+
+    // Builds an encoder for the given value schema on top of the shared schema provider.
+    def encoderFor(valueSchema: StructType): AvroStateEncoder =
+      new AvroStateEncoder(
+        NoPrefixKeyStateEncoderSpec(keySchema),
+        valueSchema,
+        Some(schemaProvider),
+        Some(ColumnFamilyInfo(cfName, 1)))
+
+    // Register schema version 0 and build the encoder that writes it.
+    schemaProvider.captureSchema(cfName, keySchema, schemaV0, valueSchemaId = 0)
+    val v0Encoder = encoderFor(schemaV0)
+
+    // A single-column row whose only value is null, encoded with schema v0.
+    val toV0Row = UnsafeProjection.create(schemaV0)
+    val nullV0Bytes = v0Encoder.encodeValue(toV0Row.apply(InternalRow(null)))
+
+    // The null must survive a round trip through the original encoder.
+    val v0RoundTrip = v0Encoder.decodeValue(nullV0Bytes)
+    assert(v0RoundTrip.isNullAt(0))
+
+    // Register schema version 1 and build the encoder that reads and writes it.
+    schemaProvider.captureSchema(cfName, keySchema, schemaV1, valueSchemaId = 1)
+    val v1Encoder = encoderFor(schemaV1)
+
+    // Old bytes read through the evolved encoder: both columns are expected to be null.
+    val upgraded = v1Encoder.decodeValue(nullV0Bytes)
+    assert(upgraded.isNullAt(0))
+    assert(upgraded.isNullAt(1)) // New field should be null
+
+    // Evolved schema, null "value" with a populated "timestamp".
+    val toV1Row = UnsafeProjection.create(schemaV1)
+    val nullValueBytes = v1Encoder.encodeValue(toV1Row.apply(InternalRow(null, 100L)))
+    val nullValueRow = v1Encoder.decodeValue(nullValueBytes)
+    assert(nullValueRow.isNullAt(0))
+    assert(nullValueRow.getLong(1) === 100L)
+
+    // Evolved schema, populated "value" with a null "timestamp".
+    val nullTimestampBytes = v1Encoder.encodeValue(toV1Row.apply(InternalRow(2, null)))
+    val nullTimestampRow = v1Encoder.decodeValue(nullTimestampBytes)
+    assert(nullTimestampRow.getInt(0) === 2)
+    assert(nullTimestampRow.isNullAt(1))
+  }
+
+  test("AvroStateEncoder - remove field") {
+    val keySchema = StructType(Seq(
+      StructField("k", StringType)
+    ))
+
+    val initialValueSchema = StructType(Seq(
+      StructField("value", IntegerType, false),
+      StructField("timestamp", LongType, true)
+    ))
+
+    val evolvedValueSchema = StructType(Seq(
+      StructField("value", IntegerType, false)
+    ))
+
+    val testProvider = new TestStateSchemaProvider()
+    testProvider.captureSchema(
+      StateStore.DEFAULT_COL_FAMILY_NAME,
+      keySchema,
+      initialValueSchema,
+      valueSchemaId = 0
+    )
+
+    val encoder1 = new AvroStateEncoder(
+      NoPrefixKeyStateEncoderSpec(keySchema),
+      initialValueSchema,
+      Some(testProvider),
+      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+    )
+
+    val proj = UnsafeProjection.create(initialValueSchema)
+    val row1 = proj.apply(InternalRow(1, 100L))
+    val encoded = encoder1.encodeValue(row1)
+
+    testProvider.captureSchema(
+      StateStore.DEFAULT_COL_FAMILY_NAME,
+      keySchema,
+      evolvedValueSchema,
+      valueSchemaId = 1
+    )
+
+    val encoder2 = new AvroStateEncoder(
+      NoPrefixKeyStateEncoderSpec(keySchema),
+      evolvedValueSchema,
+      Some(testProvider),
+      Some(ColumnFamilyInfo(StateStore.DEFAULT_COL_FAMILY_NAME, 1))
+    )
+
+    val decoded = encoder2.decodeValue(encoded)
+    assert(decoded.getInt(0) === 1) // Only value field should remain

Review Comment:
   Done, using numFields



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to