micheal-o commented on code in PR #50572: URL: https://github.com/apache/spark/pull/50572#discussion_r2047964169
########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreInstanceMetricSuite.scala: ########## @@ -353,6 +353,81 @@ class StateStoreInstanceMetricSuite extends StreamTest with AlsoTestWithRocksDBF } } + testWithChangelogCheckpointingEnabled( + "SPARK-51779 Verify snapshot lag metrics are updated correctly for join " + + "using virtual column families with RocksDB" + ) { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "10", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.key -> "4", + SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3" + ) { + withTempDir { checkpointDir => + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2 + .toDF() + .select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = rightKey")) + + testStream(joined)( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(input1, 1, 5), + ProcessAllAvailable(), + AddData(input2, 1, 5, 10), + ProcessAllAvailable(), + AddData(input1, 2, 3), + ProcessAllAvailable(), + CheckNewAnswer((1, 2, 1, 3), (5, 10, 5, 15)), + AddData(input1, 2), + ProcessAllAvailable(), + AddData(input2, 3), + ProcessAllAvailable(), + AddData(input1, 4), + ProcessAllAvailable(), + Execute { q => + eventually(timeout(10.seconds)) { + // Make sure only smallest K active metrics are published. + // There are 5 metrics in total, but only 4 are published. 
+ val allInstanceMetrics = q.lastProgress + .stateOperators(0) + .customMetrics + .asScala + .filter(_._1.startsWith(SNAPSHOT_LAG_METRIC_PREFIX)) + val badInstanceMetrics = allInstanceMetrics.filter { + case (key, _) => + key.startsWith(snapshotLagMetricName(0, "")) || + key.startsWith(snapshotLagMetricName(1, "")) + } + // Determined by STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT + assert( + allInstanceMetrics.size == q.sparkSession.conf + .get(SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT) + ) + // Two ids are blocked, making two lagging stores + // However, creating a column family forces a snapshot regardless of maintenance + // Thus, the version will be 1 for this case. + assert(badInstanceMetrics.count(_._2 == 1) == 2) Review Comment: I don't see where you are checking that there is only one reported metric per partition, instead of before, where there could be 4 metrics per partition. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala: ########## @@ -776,6 +818,66 @@ case class StreamingSymmetricHashJoinExec( def numUpdatedStateRows: Long = updatedStateRowsCount } + /** + * Case class used to manage both left and right side's joiners, combining + * information from both sides when necessary. 
+ */ + private case class OneSideHashJoinerManager( + leftSideJoiner: OneSideHashJoiner, rightSideJoiner: OneSideHashJoiner) { + + def removeOldState(): Iterator[KeyToValuePair] = { + leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + } + + def metrics: StateStoreMetrics = { + if (useVirtualColumnFamiliesForJoins) { + leftSideJoiner.getMetrics Review Comment: nit: explain why ########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala: ########## @@ -302,8 +302,16 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter while (iter.hasNext) iter.next() } - def numRows(implicit manager: SymmetricHashJoinStateManager): Long = { - manager.metrics.numKeys + def assertNumRows(stateFormatVersion: Int, target: Long)( + implicit manager: SymmetricHashJoinStateManager): Unit = { + // This suite originally uses HDFSBackedStateStoreProvider, which provides instantaneous metrics + // for numRows. + // But for version 3 with virtual column families, RocksDBStateStoreProvider updates metrics + // asynchronously. This means the number of keys obtained from the metrics is very likely + // to be outdated right after a put/remove. + if (stateFormatVersion <= 2) { + assert(manager.metrics.numKeys == target) + } Review Comment: so this means we are doing no validation for v3. Or this test doesn't run with v3? ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala: ########## @@ -196,6 +196,19 @@ case class StreamingSymmetricHashJoinExec( private val allowMultipleStatefulOperators = conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE) + private val useVirtualColumnFamiliesForJoins = stateFormatVersion == 3 Review Comment: nit: can be less verbose, you're already in join code, so can call this `useVirtualColumnFamilies` -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org