HeartSaVioR commented on code in PR #50123:
URL: https://github.com/apache/spark/pull/50123#discussion_r2038644963
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala:
##########
@@ -158,281 +158,162 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
-  Seq(
-    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName),
-    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName)
-  ).foreach {
-    case (providerName, providerClassName) =>
-      test(
-        s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator"
-      ) {
-        withCoordinatorAndSQLConf(
-          sc,
-          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
-          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
-          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
-          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
-          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
-          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
-          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
-          SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2",
-          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
-        ) {
-          case (coordRef, spark) =>
-            import spark.implicits._
-            implicit val sqlContext = spark.sqlContext
-
-            // Start a query and run some data to force snapshot uploads
-            val inputData = MemoryStream[Int]
-            val aggregated = inputData.toDF().dropDuplicates()
-            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
-            val query = aggregated.writeStream
-              .format("memory")
-              .outputMode("update")
-              .queryName("query")
-              .option("checkpointLocation", checkpointLocation.toString)
-              .start()
-            // Add, commit, and wait multiple times to force snapshot versions and time difference
-            (0 until 6).foreach { _ =>
-              inputData.addData(1, 2, 3)
-              query.processAllAvailable()
-              Thread.sleep(500)
-            }
-            val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery
-            val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation
-            val latestVersion = streamingQuery.lastProgress.batchId + 1
-
-            // Verify all stores have uploaded a snapshot and it's logged by the coordinator
-            (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
-              partitionId =>
-                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
-                val providerId = StateStoreProviderId(storeId, query.runId)
-                assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
-            }
-            // Verify that we should not have any state stores lagging behind
-            assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty)
-            query.stop()
-        }
-      }
-  }
+  private val allJoinStateStoreNames: Seq[String] =
+    SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
 
-  Seq(
+  /** Lists the state store providers used for a test, and the set of lagging partition IDs */
+  private val regularStateStoreProviders = Seq(
+    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName, Set.empty[Int]),
+    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName, Set.empty[Int])
+  )
+
+  /** Lists the state store providers used for a test, and the set of lagging partition IDs */
+  private val faultyStateStoreProviders = Seq(
     (
       "RocksDBSkipMaintenanceOnCertainPartitionsProvider",
-      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName
+      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName,
+      Set(0, 1)
     ),
     (
      "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider",
-      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName
+      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName,
+      Set(0, 1)
    )
-  ).foreach {
-    case (providerName, providerClassName) =>
-      test(
-        s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator"
-      ) {
-        withCoordinatorAndSQLConf(
-          sc,
-          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
-          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
-          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
-          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
-          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
-          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
-          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
-          SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2",
-          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
-        ) {
-          case (coordRef, spark) =>
-            import spark.implicits._
-            implicit val sqlContext = spark.sqlContext
-
-            // Start a query and run some data to force snapshot uploads
-            val inputData = MemoryStream[Int]
-            val aggregated = inputData.toDF().dropDuplicates()
-            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
-            val query = aggregated.writeStream
-              .format("memory")
-              .outputMode("update")
-              .queryName("query")
-              .option("checkpointLocation", checkpointLocation.toString)
-              .start()
-            // Add, commit, and wait multiple times to force snapshot versions and time difference
-            (0 until 6).foreach { _ =>
-              inputData.addData(1, 2, 3)
-              query.processAllAvailable()
-              Thread.sleep(500)
-            }
-            val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery
-            val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation
-            val latestVersion = streamingQuery.lastProgress.batchId + 1
-
-            (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
-              partitionId =>
-                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
-                val providerId = StateStoreProviderId(storeId, query.runId)
-                if (partitionId <= 1) {
-                  // Verify state stores in partition 0/1 are lagging and didn't upload anything
-                  assert(coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0)
-                } else {
-                  // Verify other stores uploaded a snapshot and it's logged by the coordinator
-                  assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
-                }
-            }
-            // We should have two state stores (id 0 and 1) that are lagging behind at this point
-            val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
-            assert(laggingStores.size == 2)
-            assert(laggingStores.forall(_.storeId.partitionId <= 1))
-            query.stop()
+  )
+
+  private val allStateStoreProviders =
+    regularStateStoreProviders ++ faultyStateStoreProviders
+
+  /**
+   * Verifies snapshot upload RPC messages from state stores are registered and verifies
+   * the coordinator detected the correct lagging partitions.
+   */
+  private def verifySnapshotUploadEvents(
+      coordRef: StateStoreCoordinatorRef,
+      query: StreamingQuery,
+      badPartitions: Set[Int],
+      storeNames: Seq[String] = Seq(StateStoreId.DEFAULT_STORE_NAME)): Unit = {
+    val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+    val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation
+    val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+    // Verify all stores have uploaded a snapshot and it's logged by the coordinator
+    (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+      partitionId =>
+        // Verify for every store name listed
+        storeNames.foreach { storeName =>
+          val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName)
+          val providerId = StateStoreProviderId(storeId, query.runId)
+          val latestSnapshotVersion = coordRef.getLatestSnapshotVersionForTesting(providerId)
+          if (badPartitions.contains(partitionId)) {
+            assert(latestSnapshotVersion.getOrElse(0) == 0)
+          } else {
+            assert(latestSnapshotVersion.get >= 0)
+          }
         }
-      }
+    }
+    // Verify that only the bad partitions are all marked as lagging.
+    // Join queries should have all their state stores marked as lagging,
+    // which would be 4 stores per partition instead of 1.
+    val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+    assert(laggingStores.size == badPartitions.size * storeNames.size)
+    assert(laggingStores.map(_.storeId.partitionId).toSet == badPartitions)
   }
 
-  private val allJoinStateStoreNames: Seq[String] =
-    SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+  /** Sets up a stateful dropDuplicate query for testing */
+  private def setUpStatefulQuery(
+      inputData: MemoryStream[Int], queryName: String): StreamingQuery = {
+    // Set up a stateful drop duplicate query
+    val aggregated = inputData.toDF().dropDuplicates()

Review Comment:
   It's fine, we don't need to address this one, so let's leave it as a nit.
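For context, a rough sketch of how the new helpers are presumably meant to compose in the refactored tests, since the hunk cuts off before the rewritten test bodies. The wiring below (which SQLConf entries are passed to withCoordinatorAndSQLConf, the query name, and the number of commit rounds) is assumed from the removed test code above rather than taken verbatim from the PR:

  // Sketch only: assumes the refactored tests iterate over allStateStoreProviders and
  // reuse setUpStatefulQuery / verifySnapshotUploadEvents roughly like this.
  allStateStoreProviders.foreach { case (providerName, providerClassName, badPartitions) =>
    test(s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator") {
      withCoordinatorAndSQLConf(
        sc,
        SQLConf.SHUFFLE_PARTITIONS.key -> "5",
        SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
        SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true"
        // remaining maintenance/lag-report settings as in the removed tests
      ) { case (coordRef, spark) =>
        import spark.implicits._
        implicit val sqlContext = spark.sqlContext
        val inputData = MemoryStream[Int]
        val query = setUpStatefulQuery(inputData, "query")
        // Add and commit data several times so every partition produces snapshot versions
        (0 until 6).foreach { _ =>
          inputData.addData(1, 2, 3)
          query.processAllAvailable()
          Thread.sleep(500)
        }
        // Skip-maintenance providers never upload snapshots for partitions 0 and 1, so only
        // those partitions should be reported as lagging; regular providers report none.
        verifySnapshotUploadEvents(coordRef, query, badPartitions)
        query.stop()
      }
    }
  }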