HeartSaVioR commented on code in PR #50123:
URL: https://github.com/apache/spark/pull/50123#discussion_r2038644963


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala:
##########
@@ -158,281 +158,162 @@ class StateStoreCoordinatorSuite extends SparkFunSuite 
with SharedSparkContext {
     }
   }
 
-  Seq(
-    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName),
-    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName)
-  ).foreach {
-    case (providerName, providerClassName) =>
-      test(
-        s"SPARK-51358: Snapshot uploads in $providerName are properly reported 
to the coordinator"
-      ) {
-        withCoordinatorAndSQLConf(
-          sc,
-          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
-          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
-          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
-          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
-          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
-          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
-          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
-          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"2",
-          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0"
-        ) {
-          case (coordRef, spark) =>
-            import spark.implicits._
-            implicit val sqlContext = spark.sqlContext
-
-            // Start a query and run some data to force snapshot uploads
-            val inputData = MemoryStream[Int]
-            val aggregated = inputData.toDF().dropDuplicates()
-            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
-            val query = aggregated.writeStream
-              .format("memory")
-              .outputMode("update")
-              .queryName("query")
-              .option("checkpointLocation", checkpointLocation.toString)
-              .start()
-            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
-            (0 until 6).foreach { _ =>
-              inputData.addData(1, 2, 3)
-              query.processAllAvailable()
-              Thread.sleep(500)
-            }
-            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
-            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
-            val latestVersion = streamingQuery.lastProgress.batchId + 1
-
-            // Verify all stores have uploaded a snapshot and it's logged by 
the coordinator
-            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
-              partitionId =>
-                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
-                val providerId = StateStoreProviderId(storeId, query.runId)
-                
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
-            }
-            // Verify that we should not have any state stores lagging behind
-            assert(coordRef.getLaggingStoresForTesting(query.runId, 
latestVersion).isEmpty)
-            query.stop()
-        }
-      }
-  }
+  private val allJoinStateStoreNames: Seq[String] =
+    SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
 
-  Seq(
+  /** Lists the state store providers used for a test, and the set of lagging 
partition IDs */
+  private val regularStateStoreProviders = Seq(
+    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName, 
Set.empty[Int]),
+    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName, 
Set.empty[Int])
+  )
+
+  /** Lists the state store providers used for a test, and the set of lagging 
partition IDs */
+  private val faultyStateStoreProviders = Seq(
     (
       "RocksDBSkipMaintenanceOnCertainPartitionsProvider",
-      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName
+      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName,
+      Set(0, 1)
     ),
     (
       "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider",
-      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName
+      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName,
+      Set(0, 1)
     )
-  ).foreach {
-    case (providerName, providerClassName) =>
-      test(
-        s"SPARK-51358: Snapshot uploads in $providerName are properly reported 
to the coordinator"
-      ) {
-        withCoordinatorAndSQLConf(
-          sc,
-          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
-          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
-          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
-          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
-          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
-          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
-          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
-          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"2",
-          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0"
-        ) {
-          case (coordRef, spark) =>
-            import spark.implicits._
-            implicit val sqlContext = spark.sqlContext
-
-            // Start a query and run some data to force snapshot uploads
-            val inputData = MemoryStream[Int]
-            val aggregated = inputData.toDF().dropDuplicates()
-            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
-            val query = aggregated.writeStream
-              .format("memory")
-              .outputMode("update")
-              .queryName("query")
-              .option("checkpointLocation", checkpointLocation.toString)
-              .start()
-            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
-            (0 until 6).foreach { _ =>
-              inputData.addData(1, 2, 3)
-              query.processAllAvailable()
-              Thread.sleep(500)
-            }
-            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
-            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
-            val latestVersion = streamingQuery.lastProgress.batchId + 1
-
-            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
-              partitionId =>
-                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
-                val providerId = StateStoreProviderId(storeId, query.runId)
-                if (partitionId <= 1) {
-                  // Verify state stores in partition 0/1 are lagging and 
didn't upload anything
-                  
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 
0)
-                } else {
-                  // Verify other stores uploaded a snapshot and it's logged 
by the coordinator
-                  
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
-                }
-            }
-            // We should have two state stores (id 0 and 1) that are lagging 
behind at this point
-            val laggingStores = 
coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
-            assert(laggingStores.size == 2)
-            assert(laggingStores.forall(_.storeId.partitionId <= 1))
-            query.stop()
+  )
+
+  private val allStateStoreProviders =
+    regularStateStoreProviders ++ faultyStateStoreProviders
+
+  /**
+   *  Verifies snapshot upload RPC messages from state stores are registered 
and verifies
+   *  the coordinator detected the correct lagging partitions.
+   */
+  private def verifySnapshotUploadEvents(
+      coordRef: StateStoreCoordinatorRef,
+      query: StreamingQuery,
+      badPartitions: Set[Int],
+      storeNames: Seq[String] = Seq(StateStoreId.DEFAULT_STORE_NAME)): Unit = {
+    val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+    val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation
+    val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+    // Verify all stores have uploaded a snapshot and it's logged by the 
coordinator
+    (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+      partitionId =>
+        // Verify for every store name listed
+        storeNames.foreach { storeName =>
+          val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, 
storeName)
+          val providerId = StateStoreProviderId(storeId, query.runId)
+          val latestSnapshotVersion = 
coordRef.getLatestSnapshotVersionForTesting(providerId)
+          if (badPartitions.contains(partitionId)) {
+            assert(latestSnapshotVersion.getOrElse(0) == 0)
+          } else {
+            assert(latestSnapshotVersion.get >= 0)
+          }
         }
-      }
+    }
+    // Verify that only the bad partitions are all marked as lagging.
+    // Join queries should have all their state stores marked as lagging,
+    // which would be 4 stores per partition instead of 1.
+    val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, 
latestVersion)
+    assert(laggingStores.size == badPartitions.size * storeNames.size)
+    assert(laggingStores.map(_.storeId.partitionId).toSet == badPartitions)
   }
 
-  private val allJoinStateStoreNames: Seq[String] =
-    SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+  /** Sets up a stateful dropDuplicate query for testing */
+  private def setUpStatefulQuery(
+      inputData: MemoryStream[Int], queryName: String): StreamingQuery = {
+    // Set up a stateful drop duplicate query
+    val aggregated = inputData.toDF().dropDuplicates()

Review Comment:
   It's fine; we don't need to address this, so let's leave it as a nit.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to