HeartSaVioR commented on code in PR #50123:
URL: https://github.com/apache/spark/pull/50123#discussion_r2034532096
########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala: ########## @@ -155,16 +157,766 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStore.stop() } } + + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all stores have uploaded a snapshot and it's logged by the coordinator + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + query.stop() + } + } + } + + Seq( + ( + "RocksDBSkipMaintenanceOnCertainPartitionsProvider", + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName + ), + ( + "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName + ) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( Review Comment: Probably good to deduplicate this with the test above, since the test setup code is the same (only the verification differs).
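For example, something along these lines could work (a rough, uncompiled sketch; the helper name `testSnapshotUploadReporting`, the `numBatches` parameter, and the `verify` signature are all made up here, while the confs and the query setup are lifted from this diff):

```scala
// Hypothetical helper: runs the shared setup and delegates the assertions to the caller,
// so the "all stores uploaded" and "partitions 0/1 are lagging" tests only differ in the
// verify function they pass in.
private def testSnapshotUploadReporting(
    providerClassName: String,
    numBatches: Int = 6)(
    verify: (StateStoreCoordinatorRef, StreamingQuery, String, Long) => Unit): Unit = {
  withCoordinatorAndSQLConf(
    sc,
    SQLConf.SHUFFLE_PARTITIONS.key -> "5",
    SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
    SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
    SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
    SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
    RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
    SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
    SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2",
    SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
  ) {
    case (coordRef, spark) =>
      import spark.implicits._
      implicit val sqlContext = spark.sqlContext
      // Shared setup: dropDuplicates query over a MemoryStream, several committed batches
      val inputData = MemoryStream[Int]
      val aggregated = inputData.toDF().dropDuplicates()
      val checkpointLocation = Utils.createTempDir().getAbsoluteFile
      val query = aggregated.writeStream
        .format("memory")
        .outputMode("update")
        .queryName("query")
        .option("checkpointLocation", checkpointLocation.toString)
        .start()
      (0 until numBatches).foreach { _ =>
        inputData.addData(1, 2, 3)
        query.processAllAvailable()
        Thread.sleep(500)
      }
      val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery
      val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation
      val latestVersion = streamingQuery.lastProgress.batchId + 1
      // Test-specific assertions live in the caller
      verify(coordRef, query, stateCheckpointDir, latestVersion)
      query.stop()
  }
}

// Usage for the happy-path case (called from inside a test(...) block);
// the skip-maintenance variant would only swap the verify body.
testSnapshotUploadReporting(classOf[RocksDBStateStoreProvider].getName) {
  (coordRef, query, stateCheckpointDir, latestVersion) =>
    (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId =>
      val providerId =
        StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), query.runId)
      assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
    }
    assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty)
}
```

The join-query tests could follow the same shape with their own query builder and extra confs, since only the store names and the expected lagging count change.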
########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala: ########## @@ -155,16 +157,766 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStore.stop() } } + + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all stores have uploaded a snapshot and it's logged by the coordinator + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + query.stop() + } + } + } + + Seq( + ( + "RocksDBSkipMaintenanceOnCertainPartitionsProvider", + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName + ), + ( + "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName + ) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + 
SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + if (partitionId <= 1) { + // Verify state stores in partition 0/1 are lagging and didn't upload anything + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0) + } else { + // Verify other stores uploaded a snapshot and it's logged by the coordinator + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + } + // We should have two state stores (id 0 and 1) that are lagging behind at this point + val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + query.stop() + } + } + } + + private val allJoinStateStoreNames: Seq[String] = + SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) + + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + + s"reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as 
"leftValue") + val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 7).foreach { _ => + input1.addData(1, 5) + input2.addData(1, 5, 10) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all state stores for join queries are reporting snapshot uploads Review Comment: Looks like the verification code is also very similar for successful case (and probably for failure case). Shall we look at deduplicating? ########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala: ########## @@ -155,16 +157,766 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStore.stop() } } + + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all stores have uploaded a snapshot and it's logged by the coordinator + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + 
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + query.stop() + } + } + } + + Seq( + ( + "RocksDBSkipMaintenanceOnCertainPartitionsProvider", + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName + ), + ( + "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName + ) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + if (partitionId <= 1) { + // Verify state stores in partition 0/1 are lagging and didn't upload anything + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0) + } else { + // Verify other stores uploaded a snapshot and it's logged by the coordinator + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + } + // We should have two state stores (id 0 and 1) that are lagging behind at this point + val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + query.stop() + } + } + } + + private val allJoinStateStoreNames: Seq[String] = + SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) + + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads for join queries with 
$providerName are properly " + + s"reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 7).foreach { _ => + input1.addData(1, 5) + input2.addData(1, 5, 10) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all state stores for join queries are reporting snapshot uploads + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + allJoinStateStoreNames.foreach { storeName => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + } + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + query.stop() + } + } + } + + Seq( + ( + "RocksDBSkipMaintenanceOnCertainPartitionsProvider", + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName + ), + ( + "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName + ) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + + s"reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + 
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 7).foreach { _ => + input1.addData(1, 5) + input2.addData(1, 5, 10) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + // Verify all state stores for join queries are reporting snapshot uploads + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + allJoinStateStoreNames.foreach { storeName => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) + val providerId = StateStoreProviderId(storeId, query.runId) + if (partitionId <= 1) { + // Verify state stores in partition 0 and 1 are lagging and didn't upload + assert( + coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0 + ) + } else { + // Verify other stores have uploaded a snapshot and it's properly logged + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + } + } + // Verify that only stores from partition id 0 and 1 are lagging behind. + // Each partition has 4 stores for join queries, so there are 2 * 4 = 8 lagging stores. 
+ val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2 * 4) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + } + } + } + + test("SPARK-51358: Verify coordinator properly handles simultaneous query runs") { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start and run two queries together with some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val dedupe1 = input1.toDF().dropDuplicates() + val dedupe2 = input2.toDF().dropDuplicates() + val checkpointLocation1 = Utils.createTempDir().getAbsoluteFile + val checkpointLocation2 = Utils.createTempDir().getAbsoluteFile + val query1 = dedupe1.writeStream + .format("memory") + .outputMode("update") + .queryName("query1") + .option("checkpointLocation", checkpointLocation1.toString) + .start() + val query2 = dedupe2.writeStream + .format("memory") + .outputMode("update") + .queryName("query2") + .option("checkpointLocation", checkpointLocation2.toString) + .start() + // Go through several rounds of input to force snapshot uploads for both queries + (0 until 2).foreach { _ => + input1.addData(1, 2, 3) + input2.addData(1, 2, 3) + query1.processAllAvailable() + query2.processAllAvailable() + // Process twice the amount of data for the first query + input1.addData(1, 2, 3) + query1.processAllAvailable() + Thread.sleep(1000) + } + // Verify that the coordinator logged the correct lagging stores for the first query + val streamingQuery1 = query1.asInstanceOf[StreamingQueryWrapper].streamingQuery + val latestVersion1 = streamingQuery1.lastProgress.batchId + 1 + val laggingStores1 = coordRef.getLaggingStoresForTesting(query1.runId, latestVersion1) + + assert(laggingStores1.size == 2) + assert(laggingStores1.forall(_.storeId.partitionId <= 1)) + assert(laggingStores1.forall(_.queryRunId == query1.runId)) + + // Verify that the second query run hasn't reported anything yet due to lack of data + val streamingQuery2 = query2.asInstanceOf[StreamingQueryWrapper].streamingQuery + var latestVersion2 = streamingQuery2.lastProgress.batchId + 1 + var laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) + assert(laggingStores2.isEmpty) + + // Process some more data for the second query to force lag reports + input2.addData(1, 2, 3) + query2.processAllAvailable() + Thread.sleep(500) + + // Verify that the coordinator logged the correct lagging stores for the second query + latestVersion2 = streamingQuery2.lastProgress.batchId + 1 + laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) + + assert(laggingStores2.size == 2) + assert(laggingStores2.forall(_.storeId.partitionId <= 1)) + assert(laggingStores2.forall(_.queryRunId == query2.runId)) + } 
+ } + + test( + "SPARK-51358: Snapshot uploads in RocksDB are not reported if changelog " + + "checkpointing is disabled" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "false", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "1", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Go through several rounds of input to force snapshot uploads Review Comment: It just needs a single version to commit if the query does not use changelog checkpointing. Let's reduce the test execution time if feasible. ########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala: ########## @@ -155,16 +157,766 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStore.stop() } } + + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all stores have uploaded a snapshot and it's logged by the coordinator + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + query.stop() + } + } + } + + Seq( + ( + "RocksDBSkipMaintenanceOnCertainPartitionsProvider", + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName + ), + ( + "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName + ) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + if (partitionId <= 1) { + // Verify state stores in partition 0/1 are lagging and didn't upload anything + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0) + } else { + // Verify other stores uploaded a snapshot and it's logged by the coordinator + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + } + // We should have two state stores (id 0 and 1) that are lagging behind at this point + val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 
2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + query.stop() + } + } + } + + private val allJoinStateStoreNames: Seq[String] = + SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) + + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + + s"reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 7).foreach { _ => + input1.addData(1, 5) + input2.addData(1, 5, 10) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all state stores for join queries are reporting snapshot uploads + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + allJoinStateStoreNames.foreach { storeName => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + } + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + query.stop() + } + } + } + + Seq( + ( + "RocksDBSkipMaintenanceOnCertainPartitionsProvider", + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName + ), + ( + "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName + ) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + + s"reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + 
SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 7).foreach { _ => + input1.addData(1, 5) + input2.addData(1, 5, 10) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + // Verify all state stores for join queries are reporting snapshot uploads + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { Review Comment: ditto ########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala: ########## @@ -155,16 +157,766 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStore.stop() } } + + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = 
aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all stores have uploaded a snapshot and it's logged by the coordinator + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + query.stop() + } + } + } + + Seq( + ( + "RocksDBSkipMaintenanceOnCertainPartitionsProvider", + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName + ), + ( + "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName + ) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + if (partitionId <= 1) { + // Verify state stores in partition 0/1 are lagging and didn't upload anything + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) 
== 0) + } else { + // Verify other stores uploaded a snapshot and it's logged by the coordinator + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + } + // We should have two state stores (id 0 and 1) that are lagging behind at this point + val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + query.stop() + } + } + } + + private val allJoinStateStoreNames: Seq[String] = + SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) + + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + + s"reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" + ) { + case (coordRef, spark) => + import spark.implicits._ Review Comment: Same, looks like the same code is used for successful case vs failure case on setting up test. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org