zecookiez commented on code in PR #50123:
URL: https://github.com/apache/spark/pull/50123#discussion_r2007888175


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala:
##########
@@ -155,16 +182,431 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
       StateStore.stop()
     }
   }
+
+  Seq(
+    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName),
+    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName)
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads in $providerName are properly reported 
to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a query and run some data to force snapshot uploads
+            val inputData = MemoryStream[Int]
+            val aggregated = inputData.toDF().dropDuplicates()
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = aggregated.writeStream
+              .format("memory")
+              .outputMode("update")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions and time difference
+            (0 until 2).foreach { _ =>
+              inputData.addData(1, 2, 3)
+              query.processAllAvailable()
+              Thread.sleep(1000)
+            }
+            val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            // Verify all stores have uploaded a snapshot and it's logged by the coordinator
+            (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
+                assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0)
+            }
+            // Verify that we should not have any state stores lagging behind
+            assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty)
+            query.stop()
+        }
+      }
+  }
+
+  Seq(
+    (
+      "RocksDBSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName
+    ),
+    (
+      "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName
+    )
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads in $providerName are properly reported 
to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a query and run some data to force snapshot uploads
+            val inputData = MemoryStream[Int]
+            val aggregated = inputData.toDF().dropDuplicates()
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = aggregated.writeStream
+              .format("memory")
+              .outputMode("update")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions and time difference
+            (0 until 3).foreach { _ =>
+              inputData.addData(1, 2, 3)
+              query.processAllAvailable()
+              Thread.sleep(1000)
+            }
+            val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
+                if (partitionId <= 1) {
+                  // Verify state stores in partition 0 and 1 are lagging and didn't upload anything
+                  assert(coordRef.getLatestSnapshotVersionForTesting(storeId).isEmpty)
+                } else {
+                  // Verify other stores have uploaded a snapshot and it's logged by the coordinator
+                  assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0)
+                }
+            }
+            // We should have two state stores (id 0 and 1) that are lagging behind at this point
+            val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+            assert(laggingStores.size == 2)
+            assert(laggingStores.forall(_.partitionId <= 1))
+            query.stop()
+        }
+      }
+  }
+
+  private val allJoinStateStoreNames: Seq[String] =
+    SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+
+  Seq(
+    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName),
+    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName)
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads for join queries with $providerName 
are properly " +
+        s"reported to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0",
+          SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a join query and run some data to force snapshot uploads
+            val input1 = MemoryStream[Int]
+            val input2 = MemoryStream[Int]
+            val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue")
+            val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue")
+            val joined = df1.join(df2, expr("leftKey = rightKey"))
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = joined.writeStream
+              .format("memory")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions and time difference
+            (0 until 5).foreach { _ =>
+              input1.addData(1, 5)
+              input2.addData(1, 5, 10)
+              query.processAllAvailable()
+              Thread.sleep(500)
+            }
+            val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            // Verify all state stores for join queries are reporting snapshot uploads
+            (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                allJoinStateStoreNames.foreach { storeName =>
+                  val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName)
+                  assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0)
+                }
+            }
+            // Verify that we should not have any state stores lagging behind
+            assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty)
+            query.stop()
+        }
+      }
+  }
+
+  Seq(
+    (
+      "RocksDBSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName
+    ),
+    (
+      "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName
+    )
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads for join queries with $providerName 
are properly " +
+        s"reported to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0",
+          SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a join query and run some data to force snapshot uploads
+            val input1 = MemoryStream[Int]
+            val input2 = MemoryStream[Int]
+            val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue")
+            val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue")
+            val joined = df1.join(df2, expr("leftKey = rightKey"))
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = joined.writeStream
+              .format("memory")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions and time difference
+            (0 until 6).foreach { _ =>
+              input1.addData(1, 5)
+              input2.addData(1, 5, 10)
+              query.processAllAvailable()
+              Thread.sleep(500)
+            }
+            val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+            // Verify all state stores for join queries are reporting snapshot uploads
+            (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                allJoinStateStoreNames.foreach { storeName =>
+                  val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName)
+                  if (partitionId <= 1) {
+                    // Verify state stores in partition 0 and 1 are lagging and didn't upload
+                    assert(coordRef.getLatestSnapshotVersionForTesting(storeId).isEmpty)
+                  } else {
+                    // Verify other stores have uploaded a snapshot and it's properly logged
+                    assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0)
+                  }
+                }
+            }
+            // Verify that only stores from partition id 0 and 1 are lagging behind.
+            // Each partition has 4 stores for join queries, so there are 2 * 4 = 8 lagging stores.
+            val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+            assert(laggingStores.size == 2 * 4)
+            assert(laggingStores.forall(_.partitionId <= 1))
+        }
+      }
+  }
+
+  test(
+    "SPARK-51358: Snapshot uploads in RocksDB are not reported if changelog " +
+    "checkpointing is disabled"
+  ) {
+    withCoordinatorAndSQLConf(
+      sc,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
+      RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "false",
+      SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+      SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "1",
+      SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2",
+      SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
+    ) {
+      case (coordRef, spark) =>
+        import spark.implicits._
+        implicit val sqlContext = spark.sqlContext
+
+        // Start a query and run some data to force snapshot uploads
+        val inputData = MemoryStream[Int]
+        val aggregated = inputData.toDF().dropDuplicates()
+        val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+        val query = aggregated.writeStream
+          .format("memory")
+          .outputMode("update")
+          .queryName("query")
+          .option("checkpointLocation", checkpointLocation.toString)
+          .start()
+        // Go through several rounds of input to force snapshot uploads
+        (0 until 5).foreach { _ =>
+          inputData.addData(1, 2, 3)
+          query.processAllAvailable()
+          Thread.sleep(1000)
+        }
+        val latestVersion =
+          query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1
+        // Verify that no instances are marked as lagging, even when upload messages are sent.
+        // Since snapshot uploads are tied to commit, the lack of version difference should prevent
+        // the stores from being marked as lagging.
+        assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty)
+        query.stop()
+    }
+  }
+}
+
+class StateStoreCoordinatorStreamingSuite extends StreamTest {
+  import testImplicits._
+
+  test("SPARK-51358: Restarting queries do not mark state stores as lagging") {
+    withSQLConf(
+      SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
+      RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
+      SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+      SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5",
+      SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"

Review Comment:
   If the last uploaded snapshot is unavailable for that store after the restart, the default timestamp of 0 would cause the time check to pass anyway, but for consistency I will set this to a low value :+1: 
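   To illustrate (a rough sketch with made-up names such as `SnapshotUploadEvent` and `isTimeLagging`, not the actual coordinator code): when no upload has been recorded for a store, its timestamp defaults to 0, so the elapsed-time check passes regardless of the configured threshold.

   ```scala
   // Rough sketch, not the coordinator's real code: why a missing snapshot
   // entry trivially satisfies the time-based part of the lag check.
   case class SnapshotUploadEvent(version: Long, timestampMs: Long)

   def isTimeLagging(
       lastUpload: Option[SnapshotUploadEvent],
       nowMs: Long,
       minTimeDiffMs: Long): Boolean = {
     // With no recorded upload, the timestamp falls back to 0, so the elapsed
     // time (nowMs - 0) exceeds any reasonable threshold and the check passes.
     val lastUploadTimeMs = lastUpload.map(_.timestampMs).getOrElse(0L)
     nowMs - lastUploadTimeMs > minTimeDiffMs
   }
   ```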


