liviazhu-db commented on code in PR #50742:
URL: https://github.com/apache/spark/pull/50742#discussion_r2067517342
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########

@@ -468,29 +511,20 @@ private[sql] class RocksDBStateStoreProvider
       case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
     }
   }

+  override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = {
+    loadStateStore(version, uniqueId, readOnly = false)
+  }
+
+  override def getWriteStore(

Review Comment:
   nit: rename to getWriteStoreFromReadStore?

##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########

@@ -565,6 +582,11 @@ trait StateStoreProvider {
       version: Long,
       stateStoreCkptId: Option[String] = None): StateStore

+  def getWriteStore(

Review Comment:
   add a doc comment?

##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala:
##########

@@ -82,19 +106,47 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     useColumnFamilies: Boolean = false,
     extraOptions: Map[String, String] = Map.empty)
   extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId,
-    sessionState, storeCoordinator, extraOptions) {
+    sessionState, storeCoordinator, extraOptions) with StateStoreRDDProvider {
+
+  // Using a ConcurrentHashMap to track state stores by partition ID
+  // and whether this store was used to create a write store or not.
+  @transient private lazy val partitionStores =
+    new java.util.concurrent.ConcurrentHashMap[Int, (ReadStateStore, Boolean)]()
+
+  override def getStateStoreForPartition(partitionId: Int): Option[ReadStateStore] = {
+    val (readStore, _) = partitionStores.get(partitionId)
+    partitionStores.put(partitionId, (readStore, true))
+    Option(readStore)
+  }

   override protected def getPartitions: Array[Partition] = dataRDD.partitions

   override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
     val storeProviderId = getStateProviderId(partition)
+    val partitionId = partition.index
     val inputIter = dataRDD.iterator(partition, ctxt)
     val store = StateStore.getReadOnly(
       storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion,
-      stateStoreCkptIds.map(_.apply(partition.index).head),
+      stateStoreCkptIds.map(_.apply(partitionId).head),
       stateSchemaBroadcast, useColumnFamilies, storeConf,
       hadoopConfBroadcast.value.value)
+
+    // Store reference for this partition
+    partitionStores.put(partitionId, (store, false))
+
+    // Register a cleanup callback to be executed when the task completes
+    ctxt.addTaskCompletionListener[Unit](_ => {

Review Comment:
   Why are we adding the listeners here? Is it different from the one in mapPartitionsWithReadStateStore?

##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala:
##########

@@ -82,19 +106,47 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     useColumnFamilies: Boolean = false,
     extraOptions: Map[String, String] = Map.empty)
   extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId,
-    sessionState, storeCoordinator, extraOptions) {
+    sessionState, storeCoordinator, extraOptions) with StateStoreRDDProvider {
+
+  // Using a ConcurrentHashMap to track state stores by partition ID
+  // and whether this store was used to create a write store or not.
+  @transient private lazy val partitionStores =

Review Comment:
   Will there ever be more than 1 store added here?
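To make the "add a doc comment?" suggestion concrete, a Scaladoc for the new `StateStoreProvider.getWriteStore` might read roughly as follows. This is a sketch only: the parameter semantics (same-version upgrade of a read store, optional checkpoint unique ID) are inferred from the diff hunks above, not confirmed by the PR author.

```scala
/**
 * Converts an already-loaded read-only store into a writable [[StateStore]]
 * at the same version, so the underlying resources (e.g. a loaded RocksDB
 * instance) can be reused instead of being loaded a second time.
 *
 * @param readStore the [[ReadStateStore]] to upgrade; assumed to be loaded
 *                  at the same `version`
 * @param version   the version of the state store data to write on top of
 * @param uniqueId  optional checkpoint unique ID, when checkpoint IDs are
 *                  enabled
 */
def getWriteStore(
    readStore: ReadStateStore,
    version: Long,
    uniqueId: Option[String] = None): StateStore
```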
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########

@@ -468,29 +511,20 @@ private[sql] class RocksDBStateStoreProvider
       case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
     }
   }

+  override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = {
+    loadStateStore(version, uniqueId, readOnly = false)
+  }
+
+  override def getWriteStore(
+      readStore: ReadStateStore,
+      version: Long,
+      uniqueId: Option[String] = None): StateStore = {
+    assert(version == readStore.version)

Review Comment:
   Can you leave a comment or a more informative error message here?

##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########

@@ -950,6 +972,30 @@ object StateStore extends Logging {
     storeProvider.getReadStore(version, stateStoreCkptId)
   }

+  def getWriteStore(
+      readStore: ReadStateStore,
+      storeProviderId: StateStoreProviderId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      version: Long,
+      stateStoreCkptId: Option[String],
+      stateSchemaBroadcast: Option[StateSchemaBroadcast],
+      useColumnFamilies: Boolean,
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false): StateStore = {
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString)
+    if (version < 0) {
+      throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
+    }
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString)

Review Comment:
   Why are we setting this twice? Can you add more comments about what is going on here?
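For the `assert(version == readStore.version)` comment above, one way to satisfy it is Scala's two-argument `assert`, which attaches a message to the thrown `AssertionError`. A minimal sketch; the exact wording is the author's call:

```scala
// Sketch only: explain why the versions must match and surface both values
// in the failure message instead of asserting a bare condition.
assert(version == readStore.version,
  s"Cannot convert the read store to a write store: requested version $version " +
    s"does not match the read store's loaded version ${readStore.version}.")
```

The duplicated `hadoopConf.set(StreamExecution.RUN_ID_KEY, ...)` flagged in the last comment looks like a copy-paste leftover; keeping only the first call before the version check would address it, assuming nothing in between depends on re-setting the key.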