liviazhu-db commented on code in PR #50742:
URL: https://github.com/apache/spark/pull/50742#discussion_r2067517342


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -468,29 +511,20 @@ private[sql] class RocksDBStateStoreProvider
       case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
     }
   }
+  override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = {
+    loadStateStore(version, uniqueId, readOnly = false)
+  }
+
+  override def getWriteStore(

Review Comment:
   nit: rename to getWriteStoreFromReadStore?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -565,6 +582,11 @@ trait StateStoreProvider {
       version: Long,
       stateStoreCkptId: Option[String] = None): StateStore
 
+  def getWriteStore(

Review Comment:
   add a doc comment?
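   Something like this, for example (wording is just a suggestion, and I'm assuming the signature matches the RocksDBStateStoreProvider override elsewhere in this PR):

   ```scala
   /**
    * Returns a [[StateStore]] for `version` that reuses the state already loaded by
    * the given `readStore` instead of reloading it from the checkpoint location.
    * The `readStore` must have been loaded for the same `version`.
    */
   def getWriteStore(
       readStore: ReadStateStore,
       version: Long,
       uniqueId: Option[String] = None): StateStore
   ```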



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala:
##########
@@ -82,19 +106,47 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     useColumnFamilies: Boolean = false,
     extraOptions: Map[String, String] = Map.empty)
   extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId,
-    sessionState, storeCoordinator, extraOptions) {
+    sessionState, storeCoordinator, extraOptions) with StateStoreRDDProvider {
+
+  // Using a ConcurrentHashMap to track state stores by partition ID
+  // and whether this store was used to create a write store or not.
+  @transient private lazy val partitionStores =
+    new java.util.concurrent.ConcurrentHashMap[Int, (ReadStateStore, Boolean)]()
+
+  override def getStateStoreForPartition(partitionId: Int): Option[ReadStateStore] = {
+    val (readStore, _) = partitionStores.get(partitionId)
+    partitionStores.put(partitionId, (readStore, true))
+    Option(readStore)
+  }
 
   override protected def getPartitions: Array[Partition] = dataRDD.partitions
 
   override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
     val storeProviderId = getStateProviderId(partition)
+    val partitionId = partition.index
 
     val inputIter = dataRDD.iterator(partition, ctxt)
     val store = StateStore.getReadOnly(
      storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion,
-      stateStoreCkptIds.map(_.apply(partition.index).head),
+      stateStoreCkptIds.map(_.apply(partitionId).head),
       stateSchemaBroadcast,
       useColumnFamilies, storeConf, hadoopConfBroadcast.value.value)
+
+    // Store reference for this partition
+    partitionStores.put(partitionId, (store, false))
+
+    // Register a cleanup callback to be executed when the task completes
+    ctxt.addTaskCompletionListener[Unit](_ => {

Review Comment:
   Why are we adding the listener here? Is it different from the one in mapPartitionsWithReadStateStore?
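   For reference, the cleanup shape I'd expect here (just a sketch; I'm assuming `ReadStateStore.abort()` is the right call and that the boolean in `partitionStores` means "already handed off to a write store"):

   ```scala
   ctxt.addTaskCompletionListener[Unit] { _ =>
     Option(partitionStores.remove(partitionId)).foreach { case (readStore, usedForWrite) =>
       // Only abort if the read store was never upgraded to a write store;
       // otherwise the write-store path owns its commit/abort lifecycle.
       if (!usedForWrite) {
         readStore.abort()
       }
     }
   }
   ```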



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala:
##########
@@ -82,19 +106,47 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     useColumnFamilies: Boolean = false,
     extraOptions: Map[String, String] = Map.empty)
   extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId,
-    sessionState, storeCoordinator, extraOptions) {
+    sessionState, storeCoordinator, extraOptions) with StateStoreRDDProvider {
+
+  // Using a ConcurrentHashMap to track state stores by partition ID
+  // and whether this store was used to create a write store or not.
+  @transient private lazy val partitionStores =

Review Comment:
   Will there ever be more than 1 store added here?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -468,29 +511,20 @@ private[sql] class RocksDBStateStoreProvider
       case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
     }
   }
+  override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = {
+    loadStateStore(version, uniqueId, readOnly = false)
+  }
+
+  override def getWriteStore(
+      readStore: ReadStateStore,
+      version: Long,
+      uniqueId: Option[String] = None): StateStore = {
+    assert(version == readStore.version)

Review Comment:
   Can you leave a comment or a more informative error message here?
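   Something like this, for example (message wording is just a suggestion):

   ```scala
   assert(version == readStore.version,
     s"Cannot create a write store from a read store loaded at a different version: " +
       s"requested version $version, read store version ${readStore.version}")
   ```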



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -950,6 +972,30 @@ object StateStore extends Logging {
     storeProvider.getReadStore(version, stateStoreCkptId)
   }
 
+  def getWriteStore(
+      readStore: ReadStateStore,
+      storeProviderId: StateStoreProviderId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      version: Long,
+      stateStoreCkptId: Option[String],
+      stateSchemaBroadcast: Option[StateSchemaBroadcast],
+      useColumnFamilies: Boolean,
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false): StateStore = {
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString)
+    if (version < 0) {
+      throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
+    }
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString)

Review Comment:
   Why are we setting this twice? Can you add more comments about what is going on here?
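   If the second call is just a leftover, I'd expect something like this (sketch only):

   ```scala
   if (version < 0) {
     throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
   }
   // Set the run id once before handing the conf to the provider.
   hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString)
   ```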



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

