ericm-db commented on code in PR #50123:
URL: https://github.com/apache/spark/pull/50123#discussion_r2011065146


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala:
##########
@@ -164,13 +256,166 @@ private class StateStoreCoordinator(override val rpcEnv: 
RpcEnv)
       val storeIdsToRemove =
         instances.keys.filter(_.queryRunId == runId).toSeq
       instances --= storeIdsToRemove
+      // Also remove these instances from snapshot upload event tracking
+      stateStoreLatestUploadedSnapshot --= storeIdsToRemove

Review Comment:
   We also need to evict this query run's entry from the `queryRunStartingPoint` map here.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala:
##########
@@ -164,13 +256,166 @@ private class StateStoreCoordinator(override val rpcEnv: 
RpcEnv)
       val storeIdsToRemove =
         instances.keys.filter(_.queryRunId == runId).toSeq
       instances --= storeIdsToRemove
+      // Also remove these instances from snapshot upload event tracking
+      stateStoreLatestUploadedSnapshot --= storeIdsToRemove
       logDebug(s"Deactivating instances related to checkpoint location $runId: 
" +
         storeIdsToRemove.mkString(", "))
       context.reply(true)
 
+    case ReportSnapshotUploaded(providerId, version, timestamp) =>
+      // Ignore this upload event if the registered latest version for the 
store is more recent,
+      // since it's possible that an older version gets uploaded after a new 
executor uploads for
+      // the same state store but with a newer snapshot.
+      logDebug(s"Snapshot version $version was uploaded for state store 
$providerId")
+      if (!stateStoreLatestUploadedSnapshot.get(providerId).exists(_.version 
>= version)) {
+        stateStoreLatestUploadedSnapshot.put(providerId, 
SnapshotUploadEvent(version, timestamp))
+      }
+      context.reply(true)
+
+    case LogLaggingStateStores(queryRunId, latestVersion) =>
+      val currentTimestamp = System.currentTimeMillis()
+      // Mark the query run's starting timestamp and latest version if the 
coordinator
+      // has never seen this query run before.
+      if (!queryRunStartingPoint.contains(queryRunId)) {
+        queryRunStartingPoint.put(queryRunId, QueryStartInfo(latestVersion, 
currentTimestamp))
+      } else {
+        // Only log lagging instances if the snapshot report upload is enabled,
+        // otherwise all instances will be considered lagging.
+        val laggingStores = findLaggingStores(queryRunId, latestVersion, 
currentTimestamp)
+        if (laggingStores.nonEmpty) {
+          logWarning(
+            log"StateStoreCoordinator Snapshot Lag Report for " +
+            log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+            log"Number of state stores falling behind: " +
+            log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}"
+          )
+          // Report all stores that are behind in snapshot uploads.
+          // Only report the list of providers lagging behind if the last 
reported time
+          // is not recent for this query run. The lag report interval denotes 
the minimum
+          // time between these full reports.
+          val timeSinceLastReport =
+            currentTimestamp - 
lastFullSnapshotLagReportTimeMs.getOrElse(queryRunId, 0L)
+          if (timeSinceLastReport > coordinatorLagReportInterval) {
+            // Mark timestamp of the report and log the lagging instances
+            lastFullSnapshotLagReportTimeMs.put(queryRunId, currentTimestamp)
+            // Only report the stores that are lagging the most behind in 
snapshot uploads.
+            laggingStores
+              .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, 
defaultSnapshotUploadEvent))
+              .take(sqlConf.stateStoreCoordinatorMaxLaggingStoresToReport)
+              .foreach { providerId =>
+                val baseLogMessage =
+                  log"StateStoreCoordinator Snapshot Lag Detected for " +
+                  log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+                  log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, 
providerId.storeId)} " +
+                  log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, 
latestVersion)}"
+
+                val logMessage = 
stateStoreLatestUploadedSnapshot.get(providerId) match {
+                  case Some(snapshotEvent) =>
+                    val versionDelta = latestVersion - 
Math.max(snapshotEvent.version, 0)
+                    val timeDelta = currentTimestamp - snapshotEvent.timestamp
+
+                    baseLogMessage + log", " +
+                    log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, 
snapshotEvent)}, " +
+                    log"version delta: " +
+                    log"${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, 
versionDelta)}, " +
+                    log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, 
timeDelta)}ms)"
+                  case None =>
+                    baseLogMessage + log", latest snapshot: no upload for 
query run)"
+                }
+                logWarning(logMessage)
+              }
+          }
+        }
+      }
+      context.reply(true)
+
+    case GetLatestSnapshotVersionForTesting(providerId) =>
+      val version = 
stateStoreLatestUploadedSnapshot.get(providerId).map(_.version)
+      logDebug(s"Got latest snapshot version of the state store $providerId: 
$version")
+      context.reply(version)
+
+    case GetLaggingStoresForTesting(queryRunId, latestVersion) =>
+      val currentTimestamp = System.currentTimeMillis()
+      val laggingStores = findLaggingStores(queryRunId, latestVersion, 
currentTimestamp)
+      logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}")
+      context.reply(laggingStores)
+
     case StopCoordinator =>
       stop() // Stop before replying to ensure that endpoint name has been 
deregistered
       logInfo("StateStoreCoordinator stopped")
       context.reply(true)
   }
+
+  private def findLaggingStores(
+      queryRunId: UUID,
+      referenceVersion: Long,
+      referenceTimestamp: Long): Seq[StateStoreProviderId] = {
+    // Do not report any instance as lagging if report snapshot upload is 
disabled.
+    if (!sqlConf.stateStoreCoordinatorReportSnapshotUploadLag) {
+      return Seq.empty
+    }
+
+    // Determine alert thresholds from configurations for both time and 
version differences.
+    val snapshotVersionDeltaMultiplier =
+      sqlConf.stateStoreCoordinatorMultiplierForMinVersionDiffToLog
+    val maintenanceIntervalMultiplier = 
sqlConf.stateStoreCoordinatorMultiplierForMinTimeDiffToLog
+    val minDeltasForSnapshot = sqlConf.stateStoreMinDeltasForSnapshot
+    val maintenanceInterval = sqlConf.streamingMaintenanceInterval
+
+    // Use the configured multipliers to determine the proper alert thresholds
+    val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * 
minDeltasForSnapshot
+    val minTimeDeltaForLogging = maintenanceIntervalMultiplier * 
maintenanceInterval
+
+    // Do not report any instance as lagging if this query run started 
recently, since the
+    // coordinator may be missing some information from the state stores.
+    // A run is considered recent if the time between now and the start of the 
run does not pass
+    // the time requirement for lagging instances.
+    // Similarly, the run is also considered too recent if not enough versions 
have passed
+    // since the start of the run.
+    val queryStartInfo = queryRunStartingPoint(queryRunId)
+
+    if (referenceTimestamp - queryStartInfo.startTimestamp <= 
minTimeDeltaForLogging ||

Review Comment:
   Let's move this check to line 283. I think it makes the code more readable.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala:
##########
@@ -155,16 +182,499 @@ class StateStoreCoordinatorSuite extends SparkFunSuite 
with SharedSparkContext {
       StateStore.stop()
     }
   }
+
+  Seq(
+    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName),
+    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName)
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads in $providerName are properly reported 
to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"2",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a query and run some data to force snapshot uploads
+            val inputData = MemoryStream[Int]
+            val aggregated = inputData.toDF().dropDuplicates()
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = aggregated.writeStream
+              .format("memory")
+              .outputMode("update")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 4).foreach { _ =>
+              inputData.addData(1, 2, 3)
+              query.processAllAvailable()
+              Thread.sleep(1000)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            // Verify all stores have uploaded a snapshot and it's logged by 
the coordinator
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
+                val providerId = StateStoreProviderId(storeId, query.runId)
+                
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
+            }
+            // Verify that we should not have any state stores lagging behind
+            assert(coordRef.getLaggingStoresForTesting(query.runId, 
latestVersion).isEmpty)
+            query.stop()
+        }
+      }
+  }
+
+  Seq(
+    (
+      "RocksDBSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName
+    ),
+    (
+      "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName
+    )
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads in $providerName are properly reported 
to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"2",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a query and run some data to force snapshot uploads
+            val inputData = MemoryStream[Int]
+            val aggregated = inputData.toDF().dropDuplicates()
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = aggregated.writeStream
+              .format("memory")
+              .outputMode("update")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 4).foreach { _ =>
+              inputData.addData(1, 2, 3)
+              query.processAllAvailable()
+              Thread.sleep(1000)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
+                val providerId = StateStoreProviderId(storeId, query.runId)
+                if (partitionId <= 1) {
+                  // Verify state stores in partition 0 and 1 are lagging and 
didn't upload anything
+                  
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty)
+                } else {
+                  // Verify other stores have uploaded a snapshot and it's 
logged by the coordinator
+                  
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
+                }
+            }
+            // We should have two state stores (id 0 and 1) that are lagging 
behind at this point
+            val laggingStores = 
coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+            assert(laggingStores.size == 2)
+            assert(laggingStores.forall(_.storeId.partitionId <= 1))
+            query.stop()
+        }
+      }
+  }
+
+  private val allJoinStateStoreNames: Seq[String] =
+    SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+
+  Seq(
+    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName),
+    ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName)
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads for join queries with $providerName 
are properly " +
+        s"reported to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"5",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0",
+          SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> 
"5"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a join query and run some data to force snapshot uploads
+            val input1 = MemoryStream[Int]
+            val input2 = MemoryStream[Int]
+            val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 
2) as "leftValue")
+            val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 
3) as "rightValue")
+            val joined = df1.join(df2, expr("leftKey = rightKey"))
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = joined.writeStream
+              .format("memory")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 7).foreach { _ =>
+              input1.addData(1, 5)
+              input2.addData(1, 5, 10)
+              query.processAllAvailable()
+              Thread.sleep(500)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+            // Verify all state stores for join queries are reporting snapshot 
uploads
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                allJoinStateStoreNames.foreach { storeName =>
+                  val storeId = StateStoreId(stateCheckpointDir, 0, 
partitionId, storeName)
+                  val providerId = StateStoreProviderId(storeId, query.runId)
+                  
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
+                }
+            }
+            // Verify that we should not have any state stores lagging behind
+            assert(coordRef.getLaggingStoresForTesting(query.runId, 
latestVersion).isEmpty)
+            query.stop()
+        }
+      }
+  }
+
+  Seq(
+    (
+      "RocksDBSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName
+    ),
+    (
+      "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider",
+      classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName
+    )
+  ).foreach {
+    case (providerName, providerClassName) =>
+      test(
+        s"SPARK-51358: Snapshot uploads for join queries with $providerName 
are properly " +
+        s"reported to the coordinator"
+      ) {
+        withCoordinatorAndSQLConf(
+          sc,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+          SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+          SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+          SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+          RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true",
+          SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> 
"true",
+          
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> 
"5",
+          SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> 
"0",
+          SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> 
"5"
+        ) {
+          case (coordRef, spark) =>
+            import spark.implicits._
+            implicit val sqlContext = spark.sqlContext
+
+            // Start a join query and run some data to force snapshot uploads
+            val input1 = MemoryStream[Int]
+            val input2 = MemoryStream[Int]
+            val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 
2) as "leftValue")
+            val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 
3) as "rightValue")
+            val joined = df1.join(df2, expr("leftKey = rightKey"))
+            val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+            val query = joined.writeStream
+              .format("memory")
+              .queryName("query")
+              .option("checkpointLocation", checkpointLocation.toString)
+              .start()
+            // Add, commit, and wait multiple times to force snapshot versions 
and time difference
+            (0 until 7).foreach { _ =>
+              input1.addData(1, 5)
+              input2.addData(1, 5, 10)
+              query.processAllAvailable()
+              Thread.sleep(500)
+            }
+            val streamingQuery = 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+            val stateCheckpointDir = 
streamingQuery.lastExecution.checkpointLocation
+            val latestVersion = streamingQuery.lastProgress.batchId + 1
+            // Verify all state stores for join queries are reporting snapshot 
uploads
+            (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+              partitionId =>
+                allJoinStateStoreNames.foreach { storeName =>
+                  val storeId = StateStoreId(stateCheckpointDir, 0, 
partitionId, storeName)
+                  val providerId = StateStoreProviderId(storeId, query.runId)
+                  if (partitionId <= 1) {
+                    // Verify state stores in partition 0 and 1 are lagging 
and didn't upload
+                    
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty)
+                  } else {
+                    // Verify other stores have uploaded a snapshot and it's 
properly logged
+                    
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
+                  }
+                }
+            }
+            // Verify that only stores from partition id 0 and 1 are lagging 
behind.
+            // Each partition has 4 stores for join queries, so there are 2 * 
4 = 8 lagging stores.
+            val laggingStores = 
coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+            assert(laggingStores.size == 2 * 4)
+            assert(laggingStores.forall(_.storeId.partitionId <= 1))
+        }
+      }
+  }
+
+  test(

Review Comment:
   Can we add a test for two queries running at the same time?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala:
##########
@@ -164,13 +256,166 @@ private class StateStoreCoordinator(override val rpcEnv: 
RpcEnv)
       val storeIdsToRemove =
         instances.keys.filter(_.queryRunId == runId).toSeq
       instances --= storeIdsToRemove
+      // Also remove these instances from snapshot upload event tracking
+      stateStoreLatestUploadedSnapshot --= storeIdsToRemove
       logDebug(s"Deactivating instances related to checkpoint location $runId: 
" +
         storeIdsToRemove.mkString(", "))
       context.reply(true)
 
+    case ReportSnapshotUploaded(providerId, version, timestamp) =>
+      // Ignore this upload event if the registered latest version for the 
store is more recent,
+      // since it's possible that an older version gets uploaded after a new 
executor uploads for
+      // the same state store but with a newer snapshot.
+      logDebug(s"Snapshot version $version was uploaded for state store 
$providerId")
+      if (!stateStoreLatestUploadedSnapshot.get(providerId).exists(_.version 
>= version)) {
+        stateStoreLatestUploadedSnapshot.put(providerId, 
SnapshotUploadEvent(version, timestamp))
+      }
+      context.reply(true)
+
+    case LogLaggingStateStores(queryRunId, latestVersion) =>
+      val currentTimestamp = System.currentTimeMillis()
+      // Mark the query run's starting timestamp and latest version if the 
coordinator
+      // has never seen this query run before.
+      if (!queryRunStartingPoint.contains(queryRunId)) {
+        queryRunStartingPoint.put(queryRunId, QueryStartInfo(latestVersion, 
currentTimestamp))
+      } else {
+        // Only log lagging instances if the snapshot report upload is enabled,
+        // otherwise all instances will be considered lagging.
+        val laggingStores = findLaggingStores(queryRunId, latestVersion, 
currentTimestamp)
+        if (laggingStores.nonEmpty) {
+          logWarning(
+            log"StateStoreCoordinator Snapshot Lag Report for " +
+            log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+            log"Number of state stores falling behind: " +
+            log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}"
+          )
+          // Report all stores that are behind in snapshot uploads.
+          // Only report the list of providers lagging behind if the last 
reported time
+          // is not recent for this query run. The lag report interval denotes 
the minimum
+          // time between these full reports.
+          val timeSinceLastReport =
+            currentTimestamp - 
lastFullSnapshotLagReportTimeMs.getOrElse(queryRunId, 0L)
+          if (timeSinceLastReport > coordinatorLagReportInterval) {
+            // Mark timestamp of the report and log the lagging instances
+            lastFullSnapshotLagReportTimeMs.put(queryRunId, currentTimestamp)
+            // Only report the stores that are lagging the most behind in 
snapshot uploads.
+            laggingStores
+              .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, 
defaultSnapshotUploadEvent))
+              .take(sqlConf.stateStoreCoordinatorMaxLaggingStoresToReport)
+              .foreach { providerId =>
+                val baseLogMessage =
+                  log"StateStoreCoordinator Snapshot Lag Detected for " +
+                  log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " +
+                  log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, 
providerId.storeId)} " +
+                  log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, 
latestVersion)}"
+
+                val logMessage = 
stateStoreLatestUploadedSnapshot.get(providerId) match {
+                  case Some(snapshotEvent) =>
+                    val versionDelta = latestVersion - 
Math.max(snapshotEvent.version, 0)
+                    val timeDelta = currentTimestamp - snapshotEvent.timestamp
+
+                    baseLogMessage + log", " +
+                    log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, 
snapshotEvent)}, " +
+                    log"version delta: " +
+                    log"${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, 
versionDelta)}, " +
+                    log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, 
timeDelta)}ms)"
+                  case None =>
+                    baseLogMessage + log", latest snapshot: no upload for 
query run)"
+                }
+                logWarning(logMessage)
+              }
+          }
+        }
+      }
+      context.reply(true)
+
+    case GetLatestSnapshotVersionForTesting(providerId) =>
+      val version = 
stateStoreLatestUploadedSnapshot.get(providerId).map(_.version)
+      logDebug(s"Got latest snapshot version of the state store $providerId: 
$version")
+      context.reply(version)
+
+    case GetLaggingStoresForTesting(queryRunId, latestVersion) =>
+      val currentTimestamp = System.currentTimeMillis()
+      val laggingStores = findLaggingStores(queryRunId, latestVersion, 
currentTimestamp)
+      logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}")
+      context.reply(laggingStores)
+
     case StopCoordinator =>
       stop() // Stop before replying to ensure that endpoint name has been 
deregistered
       logInfo("StateStoreCoordinator stopped")
       context.reply(true)
   }
+
+  private def findLaggingStores(
+      queryRunId: UUID,
+      referenceVersion: Long,
+      referenceTimestamp: Long): Seq[StateStoreProviderId] = {
+    // Do not report any instance as lagging if report snapshot upload is 
disabled.
+    if (!sqlConf.stateStoreCoordinatorReportSnapshotUploadLag) {
+      return Seq.empty
+    }
+
+    // Determine alert thresholds from configurations for both time and 
version differences.
+    val snapshotVersionDeltaMultiplier =
+      sqlConf.stateStoreCoordinatorMultiplierForMinVersionDiffToLog
+    val maintenanceIntervalMultiplier = 
sqlConf.stateStoreCoordinatorMultiplierForMinTimeDiffToLog
+    val minDeltasForSnapshot = sqlConf.stateStoreMinDeltasForSnapshot
+    val maintenanceInterval = sqlConf.streamingMaintenanceInterval
+
+    // Use the configured multipliers to determine the proper alert thresholds
+    val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * 
minDeltasForSnapshot
+    val minTimeDeltaForLogging = maintenanceIntervalMultiplier * 
maintenanceInterval
+
+    // Do not report any instance as lagging if this query run started 
recently, since the
+    // coordinator may be missing some information from the state stores.
+    // A run is considered recent if the time between now and the start of the 
run does not pass
+    // the time requirement for lagging instances.
+    // Similarly, the run is also considered too recent if not enough versions 
have passed
+    // since the start of the run.
+    val queryStartInfo = queryRunStartingPoint(queryRunId)
+
+    if (referenceTimestamp - queryStartInfo.startTimestamp <= 
minTimeDeltaForLogging ||
+        referenceVersion - queryStartInfo.version <= 
minVersionDeltaForLogging) {
+      return Seq.empty
+    }
+    // Look for active state store providers that are lagging behind in 
snapshot uploads
+    instances.keys.filter { storeProviderId =>
+      // Only consider providers that are part of this specific query run
+      val latestSnapshot = stateStoreLatestUploadedSnapshot.getOrElse(
+        storeProviderId,
+        defaultSnapshotUploadEvent
+      )
+      storeProviderId.queryRunId == queryRunId && (
+        // Mark a state store as lagging if it's behind in both version and 
time.
+        // A state store is considered lagging if it's behind in both version 
and time according
+        // to the configured thresholds.
+        // Stores that didn't upload a snapshot will be treated as a store 
with a snapshot of
+        // version 0.
+        referenceVersion - Math.max(latestSnapshot.version, 0) > 
minVersionDeltaForLogging &&
+        referenceTimestamp - latestSnapshot.timestamp > minTimeDeltaForLogging
+      )
+    }.toSeq
+  }
 }
+
+case class SnapshotUploadEvent(
+    version: Long,
+    timestamp: Long
+) extends Ordered[SnapshotUploadEvent] {
+
+  override def compare(otherEvent: SnapshotUploadEvent): Int = {
+    // Compare by version first, then by timestamp as tiebreaker
+    val versionCompare = this.version.compare(otherEvent.version)
+    if (versionCompare == 0) {
+      this.timestamp.compare(otherEvent.timestamp)
+    } else {
+      versionCompare
+    }
+  }
+
+  override def toString(): String = {
+    s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)"
+  }
+}
+
+case class QueryStartInfo(version: Long, startTimestamp: Long)

Review Comment:
   nit: add a trailing newline at the end of the file



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to