ericm-db commented on code in PR #50123:
URL: https://github.com/apache/spark/pull/50123#discussion_r2006789543


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala:
##########
@@ -283,6 +290,12 @@ abstract class ProgressContext(
     progressReporter.lastNoExecutionProgressEventTime = 
triggerClock.getTimeMillis()
     progressReporter.updateProgress(newProgress)
 
+    // Ask the state store coordinator to log all lagging state stores
+    if (progressReporter.coordinatorReportSnapshotUploadLag) {
+      progressReporter.stateStoreCoordinator
+        .logLaggingStateStores(lastExecution.runId, lastEpochId + 1)

Review Comment:
   Let's make it really explicit and set `val batchId = lastEpochId + 1` and 
use that here.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -1472,6 +1474,13 @@ class RocksDB(
         log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
       // Compare and update with the version that was just uploaded.
       lastUploadedSnapshotVersion.updateAndGet(v => Math.max(snapshot.version, 
v))
+      // Report snapshot upload event to the coordinator.
+      if (conf.stateStoreCoordinatorReportSnapshotUploadLag) {
+        // Note that we still report uploads even when changelog checkpointing 
is enabled.

Review Comment:
   nit: did you mean "even when changelog checkpointing is disabled"?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala:
##########
@@ -168,9 +247,149 @@ private class StateStoreCoordinator(override val rpcEnv: 
RpcEnv)
         storeIdsToRemove.mkString(", "))
       context.reply(true)
 
+    case ReportSnapshotUploaded(storeId, version, timestamp) =>
+      // Ignore this upload event if the registered latest version for the 
store is more recent,
+      // since it's possible that an older version gets uploaded after a new 
executor uploads for
+      // the same state store but with a newer snapshot.
+      logDebug(s"Snapshot version $version was uploaded for state store 
$storeId")
+      if (!stateStoreLatestUploadedSnapshot.get(storeId).exists(_.version >= 
version)) {
+        stateStoreLatestUploadedSnapshot.put(storeId, 
SnapshotUploadEvent(version, timestamp))
+      }
+      context.reply(true)
+
+    case LogLaggingStateStores(queryRunId, latestVersion) =>
+      // Only log lagging instances if the snapshot report upload is enabled,
+      // otherwise all instances will be considered lagging.
+      val currentTimestamp = System.currentTimeMillis()
+      val laggingStores = findLaggingStores(queryRunId, latestVersion, 
currentTimestamp)
+      if (laggingStores.nonEmpty) {
+        logWarning(
+          log"StateStoreCoordinator Snapshot Lag Report for " +
+          log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+          log"Number of state stores falling behind: " +
+          log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}"
+        )
+        // Report all stores that are behind in snapshot uploads.
+        // Only report the full list of providers lagging behind if the last 
reported time
+        // is not recent. The lag report interval denotes the minimum time 
between these
+        // full reports.
+        val coordinatorLagReportInterval =
+          
sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL)
+        if (laggingStores.nonEmpty &&
+          currentTimestamp - lastFullSnapshotLagReportTimeMs > 
coordinatorLagReportInterval) {
+          // Mark timestamp of the full report and log the lagging instances
+          lastFullSnapshotLagReportTimeMs = currentTimestamp
+          // Only report the stores that are lagging the most behind in 
snapshot uploads.
+          laggingStores
+            .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, 
defaultSnapshotUploadEvent))
+            
.take(sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT))
+            .foreach { storeId =>
+              val logMessage = stateStoreLatestUploadedSnapshot.get(storeId) 
match {
+                case Some(snapshotEvent) =>
+                  val versionDelta = latestVersion - 
Math.max(snapshotEvent.version, 0)
+                  val timeDelta = currentTimestamp - snapshotEvent.timestamp
+
+                  log"StateStoreCoordinator Snapshot Lag Detected for " +
+                  log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+                  log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, storeId)} " +
+                  log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, 
latestVersion)}, " +
+                  log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, 
snapshotEvent)}, " +
+                  log"version delta: " +
+                  log"${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, 
versionDelta)}, " +
+                  log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, 
timeDelta)}ms)"
+                case None =>
+                  log"StateStoreCoordinator Snapshot Lag Detected for " +
+                  log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+                  log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, storeId)} " +
+                  log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, 
latestVersion)}, " +
+                  log"latest snapshot: no upload for query run)"
+              }
+              logWarning(logMessage)
+            }
+        }
+      }
+      context.reply(true)
+
+    case GetLatestSnapshotVersionForTesting(storeId) =>
+      val version = 
stateStoreLatestUploadedSnapshot.get(storeId).map(_.version)
+      logDebug(s"Got latest snapshot version of the state store $storeId: 
$version")
+      context.reply(version)
+
+    case GetLaggingStoresForTesting(queryRunId, latestVersion) =>
+      val currentTimestamp = System.currentTimeMillis()
+      val laggingStores = findLaggingStores(queryRunId, latestVersion, 
currentTimestamp)
+      logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}")
+      context.reply(laggingStores)
+
     case StopCoordinator =>
       stop() // Stop before replying to ensure that endpoint name has been 
deregistered
       logInfo("StateStoreCoordinator stopped")
       context.reply(true)
   }
+
+  case class SnapshotUploadEvent(
+      version: Long,
+      timestamp: Long
+  ) extends Ordered[SnapshotUploadEvent] {
+
+    def isLagging(latestVersion: Long, latestTimestamp: Long): Boolean = {
+      // Use version 0 for stores that have not uploaded a snapshot version 
for this run.
+      val versionDelta = latestVersion - Math.max(version, 0)
+      val timeDelta = latestTimestamp - timestamp
+
+      // Determine alert thresholds from configurations for both time and 
version differences.
+      val snapshotVersionDeltaMultiplier = sqlConf.getConf(
+        SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG)
+      val maintenanceIntervalMultiplier = sqlConf.getConf(
+        SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG)
+      val minDeltasForSnapshot = 
sqlConf.getConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
+      val maintenanceInterval = 
sqlConf.getConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL)
+
+      // Use the configured multipliers to determine the proper alert 
thresholds
+      val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * 
minDeltasForSnapshot
+      val minTimeDeltaForLogging = maintenanceIntervalMultiplier * 
maintenanceInterval
+
+      // Mark a state store as lagging if it is behind in both version and 
time.
+      // For stores that have never uploaded a snapshot, the time requirement 
will
+      // be automatically satisfied as the initial timestamp is 0.
+      versionDelta > minVersionDeltaForLogging && timeDelta > 
minTimeDeltaForLogging
+    }
+
+    override def compare(otherEvent: SnapshotUploadEvent): Int = {
+      // Compare by version first, then by timestamp as tiebreaker
+      val versionCompare = this.version.compare(otherEvent.version)
+      if (versionCompare == 0) {
+        this.timestamp.compare(otherEvent.timestamp)
+      } else {
+        versionCompare
+      }
+    }
+
+    override def toString(): String = {
+      s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)"
+    }
+  }
+
+  private def findLaggingStores(
+      queryRunId: UUID,
+      referenceVersion: Long,
+      referenceTimestamp: Long): Seq[StateStoreId] = {
+    // Do not report any instance as lagging if report snapshot upload is 
disabled.
+    if 
(!sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG)) {
+      return Seq.empty
+    }
+    // Look for state stores that are lagging behind in snapshot uploads
+    instances.keys
+      .filter { storeProviderId =>
+        // Only consider active providers that are part of this specific query 
run,
+        // but look through all state stores under this store ID, as it's 
possible that
+        // the same query re-runs with a new run ID but has already uploaded 
some snapshots.

Review Comment:
   I'm sorry, I'm still a little confused by this. `queryRunId` is the runId, 
correct?
   If so, won't the map have no entries for the new runId?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala:
##########
@@ -129,10 +194,24 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: 
RpcEndpointRef) {
  * Class for coordinating instances of [[StateStore]]s loaded in executors 
across the cluster,
  * and get their locations for job scheduling.
  */
-private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
-    extends ThreadSafeRpcEndpoint with Logging {
+private class StateStoreCoordinator(
+    override val rpcEnv: RpcEnv,
+    val sqlConf: SQLConf)
+  extends ThreadSafeRpcEndpoint with Logging {
   private val instances = new mutable.HashMap[StateStoreProviderId, 
ExecutorCacheTaskLocation]
 
+  // Stores the latest snapshot upload event for a specific state store
+  private val stateStoreLatestUploadedSnapshot =
+    new mutable.HashMap[StateStoreId, SnapshotUploadEvent]

Review Comment:
   I think we should evict on query end. @micheal-o what do you think?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala:
##########
@@ -129,10 +194,24 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: 
RpcEndpointRef) {
  * Class for coordinating instances of [[StateStore]]s loaded in executors 
across the cluster,
  * and get their locations for job scheduling.
  */
-private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
-    extends ThreadSafeRpcEndpoint with Logging {
+private class StateStoreCoordinator(
+    override val rpcEnv: RpcEnv,
+    val sqlConf: SQLConf)
+  extends ThreadSafeRpcEndpoint with Logging {
   private val instances = new mutable.HashMap[StateStoreProviderId, 
ExecutorCacheTaskLocation]
 
+  // Stores the latest snapshot upload event for a specific state store
+  private val stateStoreLatestUploadedSnapshot =
+    new mutable.HashMap[StateStoreId, SnapshotUploadEvent]
+
+  // Default snapshot upload event to use when a provider has never uploaded a 
snapshot
+  private val defaultSnapshotUploadEvent = SnapshotUploadEvent(-1, 0)
+
+  // Stores the last timestamp in milliseconds where the coordinator did a 
full report on
+  // instances lagging behind on snapshot uploads. The initial timestamp is 
defaulted to
+  // 0 milliseconds.
+  private var lastFullSnapshotLagReportTimeMs = 0L

Review Comment:
   Do we need to keep this timestamp per query? Let's say query1 has lagging 
stores and we log them, and then 2 seconds later query2 has lagging stores that 
we need to log, but the reporting interval is 5 minutes. How do we deal with 
that case?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala:
##########
@@ -129,10 +194,24 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: 
RpcEndpointRef) {
  * Class for coordinating instances of [[StateStore]]s loaded in executors 
across the cluster,
  * and get their locations for job scheduling.
  */
-private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
-    extends ThreadSafeRpcEndpoint with Logging {
+private class StateStoreCoordinator(
+    override val rpcEnv: RpcEnv,
+    val sqlConf: SQLConf)
+  extends ThreadSafeRpcEndpoint with Logging {
   private val instances = new mutable.HashMap[StateStoreProviderId, 
ExecutorCacheTaskLocation]
 
+  // Stores the latest snapshot upload event for a specific state store
+  private val stateStoreLatestUploadedSnapshot =
+    new mutable.HashMap[StateStoreId, SnapshotUploadEvent]

Review Comment:
   Do we ever evict entries from this map, or will it grow indefinitely?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala:
##########
@@ -168,9 +247,149 @@ private class StateStoreCoordinator(override val rpcEnv: 
RpcEnv)
         storeIdsToRemove.mkString(", "))
       context.reply(true)
 
+    case ReportSnapshotUploaded(storeId, version, timestamp) =>
+      // Ignore this upload event if the registered latest version for the 
store is more recent,
+      // since it's possible that an older version gets uploaded after a new 
executor uploads for
+      // the same state store but with a newer snapshot.
+      logDebug(s"Snapshot version $version was uploaded for state store 
$storeId")
+      if (!stateStoreLatestUploadedSnapshot.get(storeId).exists(_.version >= 
version)) {
+        stateStoreLatestUploadedSnapshot.put(storeId, 
SnapshotUploadEvent(version, timestamp))
+      }
+      context.reply(true)
+
+    case LogLaggingStateStores(queryRunId, latestVersion) =>
+      // Only log lagging instances if the snapshot report upload is enabled,
+      // otherwise all instances will be considered lagging.
+      val currentTimestamp = System.currentTimeMillis()
+      val laggingStores = findLaggingStores(queryRunId, latestVersion, 
currentTimestamp)
+      if (laggingStores.nonEmpty) {
+        logWarning(
+          log"StateStoreCoordinator Snapshot Lag Report for " +
+          log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+          log"Number of state stores falling behind: " +
+          log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}"
+        )
+        // Report all stores that are behind in snapshot uploads.
+        // Only report the full list of providers lagging behind if the last 
reported time
+        // is not recent. The lag report interval denotes the minimum time 
between these
+        // full reports.
+        val coordinatorLagReportInterval =
+          
sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL)
+        if (laggingStores.nonEmpty &&
+          currentTimestamp - lastFullSnapshotLagReportTimeMs > 
coordinatorLagReportInterval) {
+          // Mark timestamp of the full report and log the lagging instances
+          lastFullSnapshotLagReportTimeMs = currentTimestamp
+          // Only report the stores that are lagging the most behind in 
snapshot uploads.
+          laggingStores
+            .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, 
defaultSnapshotUploadEvent))
+            
.take(sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT))
+            .foreach { storeId =>
+              val logMessage = stateStoreLatestUploadedSnapshot.get(storeId) 
match {
+                case Some(snapshotEvent) =>
+                  val versionDelta = latestVersion - 
Math.max(snapshotEvent.version, 0)
+                  val timeDelta = currentTimestamp - snapshotEvent.timestamp
+
+                  log"StateStoreCoordinator Snapshot Lag Detected for " +
+                  log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+                  log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, storeId)} " +
+                  log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, 
latestVersion)}, " +
+                  log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, 
snapshotEvent)}, " +
+                  log"version delta: " +
+                  log"${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, 
versionDelta)}, " +
+                  log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, 
timeDelta)}ms)"
+                case None =>
+                  log"StateStoreCoordinator Snapshot Lag Detected for " +
+                  log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+                  log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, storeId)} " +
+                  log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, 
latestVersion)}, " +
+                  log"latest snapshot: no upload for query run)"
+              }
+              logWarning(logMessage)
+            }
+        }
+      }
+      context.reply(true)
+
+    case GetLatestSnapshotVersionForTesting(storeId) =>
+      val version = 
stateStoreLatestUploadedSnapshot.get(storeId).map(_.version)
+      logDebug(s"Got latest snapshot version of the state store $storeId: 
$version")
+      context.reply(version)
+
+    case GetLaggingStoresForTesting(queryRunId, latestVersion) =>
+      val currentTimestamp = System.currentTimeMillis()
+      val laggingStores = findLaggingStores(queryRunId, latestVersion, 
currentTimestamp)
+      logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}")
+      context.reply(laggingStores)
+
     case StopCoordinator =>
       stop() // Stop before replying to ensure that endpoint name has been 
deregistered
       logInfo("StateStoreCoordinator stopped")
       context.reply(true)
   }
+
+  case class SnapshotUploadEvent(
+      version: Long,
+      timestamp: Long
+  ) extends Ordered[SnapshotUploadEvent] {
+
+    def isLagging(latestVersion: Long, latestTimestamp: Long): Boolean = {
+      // Use version 0 for stores that have not uploaded a snapshot version 
for this run.
+      val versionDelta = latestVersion - Math.max(version, 0)
+      val timeDelta = latestTimestamp - timestamp
+
+      // Determine alert thresholds from configurations for both time and 
version differences.
+      val snapshotVersionDeltaMultiplier = sqlConf.getConf(
+        SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG)
+      val maintenanceIntervalMultiplier = sqlConf.getConf(
+        SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG)
+      val minDeltasForSnapshot = 
sqlConf.getConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
+      val maintenanceInterval = 
sqlConf.getConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL)
+
+      // Use the configured multipliers to determine the proper alert 
thresholds
+      val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * 
minDeltasForSnapshot
+      val minTimeDeltaForLogging = maintenanceIntervalMultiplier * 
maintenanceInterval
+
+      // Mark a state store as lagging if it is behind in both version and 
time.
+      // For stores that have never uploaded a snapshot, the time requirement 
will
+      // be automatically satisfied as the initial timestamp is 0.
+      versionDelta > minVersionDeltaForLogging && timeDelta > 
minTimeDeltaForLogging
+    }
+
+    override def compare(otherEvent: SnapshotUploadEvent): Int = {
+      // Compare by version first, then by timestamp as tiebreaker
+      val versionCompare = this.version.compare(otherEvent.version)
+      if (versionCompare == 0) {
+        this.timestamp.compare(otherEvent.timestamp)
+      } else {
+        versionCompare
+      }
+    }
+
+    override def toString(): String = {
+      s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)"
+    }
+  }
+
+  private def findLaggingStores(
+      queryRunId: UUID,
+      referenceVersion: Long,
+      referenceTimestamp: Long): Seq[StateStoreId] = {
+    // Do not report any instance as lagging if report snapshot upload is 
disabled.
+    if 
(!sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG)) {
+      return Seq.empty
+    }
+    // Look for state stores that are lagging behind in snapshot uploads
+    instances.keys
+      .filter { storeProviderId =>
+        // Only consider active providers that are part of this specific query 
run,
+        // but look through all state stores under this store ID, as it's 
possible that
+        // the same query re-runs with a new run ID but has already uploaded 
some snapshots.

Review Comment:
   I'm honestly not sure why this still wouldn't consider everything as lagging 
on restart. Would you mind giving me a concrete example with queryId and runId?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala:
##########
@@ -155,16 +182,431 @@ class StateStoreCoordinatorSuite extends SparkFunSuite 
with SharedSparkContext {
       StateStore.stop()
     }
   }
+
+  Seq(
+    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName),
+    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName)
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads in $providerName are properly reported 
to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"2",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a query and run some data to force snapshot uploads
+            val inputData = MemoryStream[Int]
+            val aggregated = inputData.toDF().dropDuplicates()
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = aggregated.writeStream
+              .format("memory")
+              .outputMode("update")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 2).foreach { _ =>
+              inputData.addData(1, 2, 3)
+              query.processAllAvailable()
+              Thread.sleep(1000)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            // Verify all stores have uploaded a snapshot and it's logged by 
the coordinator
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
+                
assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0)
+            }
+            // Verify that we should not have any state stores lagging behind
+            assert(coordRef.getLaggingStoresForTesting(query.runId, 
latestVersion).isEmpty)
+            query.stop()
+        }
+      }
+  }
+
+  Seq(
+    (
+      "RocksDBSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName
+    ),
+    (
+      "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName
+    )
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads in $providerName are properly reported 
to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"2",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a query and run some data to force snapshot uploads
+            val inputData = MemoryStream[Int]
+            val aggregated = inputData.toDF().dropDuplicates()
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = aggregated.writeStream
+              .format("memory")
+              .outputMode("update")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 3).foreach { _ =>
+              inputData.addData(1, 2, 3)
+              query.processAllAvailable()
+              Thread.sleep(1000)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
+                if (partitionId <= 1) {
+                  // Verify state stores in partition 0 and 1 are lagging and 
didn't upload anything
+                  
assert(coordRef.getLatestSnapshotVersionForTesting(storeId).isEmpty)
+                } else {
+                  // Verify other stores have uploaded a snapshot and it's 
logged by the coordinator
+                  
assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0)
+                }
+            }
+            // We should have two state stores (id 0 and 1) that are lagging 
behind at this point
+            val laggingStores = 
coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+            assert(laggingStores.size == 2)
+            assert(laggingStores.forall(_.partitionId <= 1))
+            query.stop()
+        }
+      }
+  }
+
+  private val allJoinStateStoreNames: Seq[String] =
+    SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+
+  Seq(
+    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName),
+    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName)
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads for join queries with $providerName 
are properly " +
+        s"reported to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"5",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0",
+          SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> 
"5"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a join query and run some data to force snapshot uploads
+            val input1 = MemoryStream[Int]
+            val input2 = MemoryStream[Int]
+            val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 
2) as "leftValue")
+            val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 
3) as "rightValue")
+            val joined = df1.join(df2, expr("leftKey = rightKey"))
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = joined.writeStream
+              .format("memory")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 5).foreach { _ =>
+              input1.addData(1, 5)
+              input2.addData(1, 5, 10)
+              query.processAllAvailable()
+              Thread.sleep(500)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            // Verify all state stores for join queries are reporting snapshot 
uploads
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                allJoinStateStoreNames.foreach { storeName =>
+                  val storeId = StateStoreId(stateCheckpointDir, 0, 
partitionId, storeName)
+                  
assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0)
+                }
+            }
+            // Verify that we should not have any state stores lagging behind
+            assert(coordRef.getLaggingStoresForTesting(query.runId, 
latestVersion).isEmpty)
+            query.stop()
+        }
+      }
+  }
+
+  Seq(
+    (
+      "RocksDBSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName
+    ),
+    (
+      "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName
+    )
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads for join queries with $providerName 
are properly " +
+        s"reported to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"5",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0",
+          SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> 
"5"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a join query and run some data to force snapshot uploads
+            val input1 = MemoryStream[Int]
+            val input2 = MemoryStream[Int]
+            val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 
2) as "leftValue")
+            val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 
3) as "rightValue")
+            val joined = df1.join(df2, expr("leftKey = rightKey"))
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = joined.writeStream
+              .format("memory")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 6).foreach { _ =>
+              input1.addData(1, 5)
+              input2.addData(1, 5, 10)
+              query.processAllAvailable()
+              Thread.sleep(500)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+            // Verify all state stores for join queries are reporting snapshot 
uploads
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                allJoinStateStoreNames.foreach { storeName =>
+                  val storeId = StateStoreId(stateCheckpointDir, 0, 
partitionId, storeName)
+                  if (partitionId <= 1) {
+                    // Verify state stores in partition 0 and 1 are lagging 
and didn't upload
+                    
assert(coordRef.getLatestSnapshotVersionForTesting(storeId).isEmpty)
+                  } else {
+                    // Verify other stores have uploaded a snapshot and it's 
properly logged
+                    
assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0)
+                  }
+                }
+            }
+            // Verify that only stores from partition id 0 and 1 are lagging 
behind.
+            // Each partition has 4 stores for join queries, so there are 2 * 
4 = 8 lagging stores.
+            val laggingStores = 
coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+            assert(laggingStores.size == 2 * 4)
+            assert(laggingStores.forall(_.partitionId <= 1))
+        }
+      }
+  }
+
+  test(
+    "SPARK-51358: Snapshot uploads in RocksDB are not reported if changelog " +
+    "checkpointing is disabled"
+  ) {
+    withCoordinatorAndSQLConf(
+      sc,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> 
classOf[RocksDBStateStoreProvider].getName,
+      RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "false",
+      SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+      SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key 
-> "1",
+      
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"2",
+      SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
+    ) {
+      case (coordRef, spark) =>
+        import spark.implicits._
+        implicit val sqlContext = spark.sqlContext
+
+        // Start a query and run some data to force snapshot uploads
+        val inputData = MemoryStream[Int]
+        val aggregated = inputData.toDF().dropDuplicates()
+        val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+        val query = aggregated.writeStream
+          .format("memory")
+          .outputMode("update")
+          .queryName("query")
+          .option("checkpointLocation", checkpointLocation.toString)
+          .start()
+        // Go through several rounds of input to force snapshot uploads
+        (0 until 5).foreach { _ =>
+          inputData.addData(1, 2, 3)
+          query.processAllAvailable()
+          Thread.sleep(1000)
+        }
+        val latestVersion =
+          
query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 
1
+        // Verify that no instances are marked as lagging, even when upload 
messages are sent.
+        // Since snapshot uploads are tied to commit, the lack of version 
difference should prevent
+        // the stores from being marked as lagging.
+        assert(coordRef.getLaggingStoresForTesting(query.runId, 
latestVersion).isEmpty)
+        query.stop()
+    }
+  }
+}
+
+class StateStoreCoordinatorStreamingSuite extends StreamTest {
+  import testImplicits._
+
+  test("SPARK-51358: Restarting queries do not mark state stores as lagging") {
+    withSQLConf(
+      SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> 
classOf[RocksDBStateStoreProvider].getName,
+      RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+      SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+      
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"5",
+      SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"

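Review Comment:
   Do we not need to set the maintenance interval multiplier as well (i.e. 
`SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG`, as the 
changelog-checkpointing-disabled test above does)?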



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org


Reply via email to