anishshri-db commented on code in PR #50123:
URL: https://github.com/apache/spark/pull/50123#discussion_r2022124671


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala:
##########
@@ -155,16 +157,697 @@ class StateStoreCoordinatorSuite extends SparkFunSuite 
with SharedSparkContext {
       StateStore.stop()
     }
   }
+
+  Seq(
+    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName),
+    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName)
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads in $providerName are properly reported 
to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"2",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a query and run some data to force snapshot uploads
+            val inputData = MemoryStream[Int]
+            val aggregated = inputData.toDF().dropDuplicates()
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = aggregated.writeStream
+              .format("memory")
+              .outputMode("update")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 6).foreach { _ =>
+              inputData.addData(1, 2, 3)
+              query.processAllAvailable()
+              Thread.sleep(500)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            // Verify all stores have uploaded a snapshot and it's logged by 
the coordinator
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
+                val providerId = StateStoreProviderId(storeId, query.runId)
+                
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
+            }
+            // Verify that we should not have any state stores lagging behind
+            assert(coordRef.getLaggingStoresForTesting(query.runId, 
latestVersion).isEmpty)
+            query.stop()
+        }
+      }
+  }
+
+  Seq(
+    (
+      "RocksDBSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName
+    ),
+    (
+      "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName
+    )
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads in $providerName are properly reported 
to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"2",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a query and run some data to force snapshot uploads
+            val inputData = MemoryStream[Int]
+            val aggregated = inputData.toDF().dropDuplicates()
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = aggregated.writeStream
+              .format("memory")
+              .outputMode("update")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 6).foreach { _ =>
+              inputData.addData(1, 2, 3)
+              query.processAllAvailable()
+              Thread.sleep(500)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
+                val providerId = StateStoreProviderId(storeId, query.runId)
+                if (partitionId <= 1) {
+                  // Verify state stores in partition 0 and 1 are lagging and 
didn't upload anything
+                  
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty)
+                } else {
+                  // Verify other stores have uploaded a snapshot and it's 
logged by the coordinator
+                  
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
+                }
+            }
+            // We should have two state stores (id 0 and 1) that are lagging 
behind at this point
+            val laggingStores = 
coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+            assert(laggingStores.size == 2)
+            assert(laggingStores.forall(_.storeId.partitionId <= 1))
+            query.stop()
+        }
+      }
+  }
+
+  private val allJoinStateStoreNames: Seq[String] =
+    SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+
+  Seq(
+    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName),
+    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName)
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads for join queries with $providerName 
are properly " +
+        s"reported to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"5",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0",
+          SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> 
"5"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a join query and run some data to force snapshot uploads
+            val input1 = MemoryStream[Int]
+            val input2 = MemoryStream[Int]
+            val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 
2) as "leftValue")
+            val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 
3) as "rightValue")
+            val joined = df1.join(df2, expr("leftKey = rightKey"))
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = joined.writeStream
+              .format("memory")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 7).foreach { _ =>
+              input1.addData(1, 5)
+              input2.addData(1, 5, 10)
+              query.processAllAvailable()
+              Thread.sleep(500)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            // Verify all state stores for join queries are reporting snapshot 
uploads
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                allJoinStateStoreNames.foreach { storeName =>
+                  val storeId = StateStoreId(stateCheckpointDir, 0, 
partitionId, storeName)
+                  val providerId = StateStoreProviderId(storeId, query.runId)
+                  
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
+                }
+            }
+            // Verify that we should not have any state stores lagging behind
+            assert(coordRef.getLaggingStoresForTesting(query.runId, 
latestVersion).isEmpty)
+            query.stop()
+        }
+      }
+  }
+
+  Seq(
+    (
+      "RocksDBSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName
+    ),
+    (
+      "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName
+    )
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads for join queries with $providerName 
are properly " +
+        s"reported to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"5",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0",
+          SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> 
"5"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a join query and run some data to force snapshot uploads
+            val input1 = MemoryStream[Int]
+            val input2 = MemoryStream[Int]
+            val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 
2) as "leftValue")
+            val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 
3) as "rightValue")
+            val joined = df1.join(df2, expr("leftKey = rightKey"))
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = joined.writeStream
+              .format("memory")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 7).foreach { _ =>
+              input1.addData(1, 5)
+              input2.addData(1, 5, 10)
+              query.processAllAvailable()
+              Thread.sleep(500)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+            // Verify all state stores for join queries are reporting snapshot 
uploads
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                allJoinStateStoreNames.foreach { storeName =>
+                  val storeId = StateStoreId(stateCheckpointDir, 0, 
partitionId, storeName)
+                  val providerId = StateStoreProviderId(storeId, query.runId)
+                  if (partitionId <= 1) {
+                    // Verify state stores in partition 0 and 1 are lagging 
and didn't upload
+                    
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty)
+                  } else {
+                    // Verify other stores have uploaded a snapshot and it's 
properly logged
+                    
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
+                  }
+                }
+            }
+            // Verify that only stores from partition id 0 and 1 are lagging 
behind.
+            // Each partition has 4 stores for join queries, so there are 2 * 
4 = 8 lagging stores.
+            val laggingStores = 
coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+            assert(laggingStores.size == 2 * 4)
+            assert(laggingStores.forall(_.storeId.partitionId <= 1))
+        }
+      }
+  }
+
+  test("SPARK-51358: Verify coordinator properly handles simultaneous query 
runs") {
+    withCoordinatorAndSQLConf(
+      sc,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName,
+      RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+      SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+      
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"2",
+      SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
+    ) {
+      case (coordRef, spark) =>
+        import spark.implicits._
+        implicit val sqlContext = spark.sqlContext
+
+        // Start and run two queries together with some data to force snapshot 
uploads
+        val input1 = MemoryStream[Int]
+        val input2 = MemoryStream[Int]
+        val dedupe1 = input1.toDF().dropDuplicates()
+        val dedupe2 = input2.toDF().dropDuplicates()
+        val checkpointLocation1 = Utils.createTempDir().getAbsoluteFile
+        val checkpointLocation2 = Utils.createTempDir().getAbsoluteFile
+        val query1 = dedupe1.writeStream
+          .format("memory")
+          .outputMode("update")
+          .queryName("query1")
+          .option("checkpointLocation", checkpointLocation1.toString)
+          .start()
+        val query2 = dedupe2.writeStream
+          .format("memory")
+          .outputMode("update")
+          .queryName("query2")
+          .option("checkpointLocation", checkpointLocation2.toString)
+          .start()
+        // Go through several rounds of input to force snapshot uploads for 
both queries
+        (0 until 3).foreach { _ =>
+          input1.addData(1, 2, 3)
+          input2.addData(1, 2, 3)
+          query1.processAllAvailable()
+          query2.processAllAvailable()
+          // Process twice the amount of data for the first query
+          input1.addData(1, 2, 3)
+          query1.processAllAvailable()
+          Thread.sleep(1000)
+        }
+        // Verify that the coordinator logged the correct lagging stores for 
the first query
+        val streamingQuery1 = 
query1.asInstanceOf[StreamingQueryWrapper].streamingQuery
+        val latestVersion1 = streamingQuery1.lastProgress.batchId + 1
+        val laggingStores1 = coordRef.getLaggingStoresForTesting(query1.runId, 
latestVersion1)
+
+        assert(laggingStores1.size == 2)
+        assert(laggingStores1.forall(_.storeId.partitionId <= 1))
+        assert(laggingStores1.forall(_.queryRunId == query1.runId))
+
+        // Verify that the second query run hasn't reported anything yet due 
to lack of data
+        val streamingQuery2 = 
query2.asInstanceOf[StreamingQueryWrapper].streamingQuery
+        var latestVersion2 = streamingQuery2.lastProgress.batchId + 1
+        var laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, 
latestVersion2)
+        assert(laggingStores2.isEmpty)
+
+        // Process some more data for the second query to force lag reports
+        input2.addData(1, 2, 3)
+        query2.processAllAvailable()
+        Thread.sleep(500)
+
+        // Verify that the coordinator logged the correct lagging stores for 
the second query
+        latestVersion2 = streamingQuery2.lastProgress.batchId + 1
+        laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, 
latestVersion2)
+
+        assert(laggingStores2.size == 2)
+        assert(laggingStores2.forall(_.storeId.partitionId <= 1))
+        assert(laggingStores2.forall(_.queryRunId == query2.runId))
+    }
+  }
+
+  test(
+    "SPARK-51358: Snapshot uploads in RocksDB are not reported if changelog " +
+    "checkpointing is disabled"
+  ) {
+    withCoordinatorAndSQLConf(
+      sc,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> 
classOf[RocksDBStateStoreProvider].getName,
+      RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "false",
+      SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+      SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key 
-> "1",
+      
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"2",
+      SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
+    ) {
+      case (coordRef, spark) =>
+        import spark.implicits._
+        implicit val sqlContext = spark.sqlContext
+
+        // Start a query and run some data to force snapshot uploads
+        val inputData = MemoryStream[Int]
+        val aggregated = inputData.toDF().dropDuplicates()
+        val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+        val query = aggregated.writeStream
+          .format("memory")
+          .outputMode("update")
+          .queryName("query")
+          .option("checkpointLocation", checkpointLocation.toString)
+          .start()
+        // Go through several rounds of input to force snapshot uploads
+        (0 until 5).foreach { _ =>
+          inputData.addData(1, 2, 3)
+          query.processAllAvailable()
+          Thread.sleep(500)
+        }
+        val latestVersion =
+          
query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 
1
+        // Verify that no instances are marked as lagging, even when upload 
messages are sent.
+        // Since snapshot uploads are tied to commit, the lack of version 
difference should prevent
+        // the stores from being marked as lagging.
+        assert(coordRef.getLaggingStoresForTesting(query.runId, 
latestVersion).isEmpty)
+        query.stop()
+    }
+  }
+
+  test("SPARK-51358: Snapshot lag reports properly detects when all state 
stores are lagging") {

Review Comment:
   Do we have a test with the `AvailNow` trigger?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to