anishshri-db commented on code in PR #49304:
URL: https://github.com/apache/spark/pull/49304#discussion_r1946049732


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -656,31 +803,75 @@ class RocksDB(
    *
    * @note This update is not committed to disk until commit() is called.
    */
-  def merge(key: Array[Byte], value: Array[Byte]): Unit = {
-    if (conf.trackTotalNumberOfRows) {
-      val oldValue = db.get(readOptions, key)
-      if (oldValue == null) {
-        numKeysOnWritingVersion += 1
+  def merge(
+      key: Array[Byte],
+      value: Array[Byte],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    val keyWithPrefix = if (useColumnFamilies) {
+      encodeStateRowWithPrefix(key, cfName)
+    } else {
+      key
+    }
+
+    if (useColumnFamilies) {
+      if (conf.trackTotalNumberOfRows) {
+        val oldValue = db.get(readOptions, keyWithPrefix)
+        if (oldValue == null) {
+          val cfInfo = getColumnFamilyInfo(cfName)
+          if (cfInfo.isInternal) {
+            numInternalKeysOnWritingVersion += 1
+          } else {
+            numKeysOnWritingVersion += 1
+          }
+        }
+      }
+    } else {
+      if (conf.trackTotalNumberOfRows) {
+        val oldValue = db.get(readOptions, keyWithPrefix)
+        if (oldValue == null) {
+          numKeysOnWritingVersion += 1
+        }
       }
     }
-    db.merge(writeOptions, key, value)
 
-    changelogWriter.foreach(_.merge(key, value))
+    db.merge(writeOptions, keyWithPrefix, value)
+    changelogWriter.foreach(_.merge(keyWithPrefix, value))
   }
 
   /**
    * Remove the key if present.
    * @note This update is not committed to disk until commit() is called.
    */
-  def remove(key: Array[Byte]): Unit = {
-    if (conf.trackTotalNumberOfRows) {
-      val value = db.get(readOptions, key)
-      if (value != null) {
-        numKeysOnWritingVersion -= 1
+  def remove(key: Array[Byte], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    val keyWithPrefix = if (useColumnFamilies) {
+      encodeStateRowWithPrefix(key, cfName)
+    } else {
+      key
+    }
+
+    if (useColumnFamilies) {
+      if (conf.trackTotalNumberOfRows) {
+        val oldValue = db.get(readOptions, keyWithPrefix)
+        if (oldValue != null) {
+          val cfInfo = getColumnFamilyInfo(cfName)
+          if (cfInfo.isInternal) {
+            numInternalKeysOnWritingVersion -= 1
+          } else {
+            numKeysOnWritingVersion -= 1
+          }
+        }
+      }
+    } else {
+      if (conf.trackTotalNumberOfRows) {
+        val value = db.get(readOptions, keyWithPrefix)
+        if (value != null) {
+          numKeysOnWritingVersion -= 1
+        }
       }
     }
-    db.delete(writeOptions, key)
-    changelogWriter.foreach(_.delete(key))
+
+    db.delete(writeOptions, keyWithPrefix)
+    changelogWriter.foreach(_.delete(keyWithPrefix))
   }
 
   /**

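The hunk above routes the same prefixed key to both the RocksDB write (db.merge / db.delete) and the changelog writer, so changelog replay reproduces identical physical keys. encodeStateRowWithPrefix itself is not shown in this hunk; purely as a sketch of the idea, a prefix encoder could prepend a fixed-width column family id to the raw key bytes, along these lines (the 2-byte id and all names here are assumptions, not the PR's actual layout):

    import java.nio.ByteBuffer

    // Sketch only: assumes a 2-byte column family id is prepended to the
    // raw key; the actual encodeStateRowWithPrefix in the PR may differ.
    def encodeWithPrefix(key: Array[Byte], cfId: Short): Array[Byte] = {
      val buf = ByteBuffer.allocate(java.lang.Short.BYTES + key.length)
      buf.putShort(cfId)   // fixed-width column family id prefix
      buf.put(key)         // original key bytes follow the prefix
      buf.array()
    }
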
Review Comment:
   That is at the provider layer, right? At this layer we expect colFamilyName to be passed if multiple column families are being used.
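To illustrate the layering point: with useColumnFamilies enabled, a caller above this class is expected to pass the column family name explicitly; the default cfName only covers the single-family path. A hypothetical call site (names invented for illustration, not actual provider code):

    // Hypothetical provider-layer call site: cfName is threaded down
    // explicitly when multiple column families are in use; otherwise the
    // StateStore.DEFAULT_COL_FAMILY_NAME default applies.
    def mergeForColumnFamily(
        rocksDB: RocksDB,
        keyBytes: Array[Byte],
        valueBytes: Array[Byte],
        colFamilyName: String): Unit = {
      rocksDB.merge(keyBytes, valueBytes, colFamilyName)
    }
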



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

