anishshri-db commented on code in PR #49304:
URL: https://github.com/apache/spark/pull/49304#discussion_r1960626246


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -621,28 +700,106 @@ class RocksDB(
     }
   }
 
+  /**
+   * Function to encode state row with virtual col family id prefix
+   * @param data - passed byte array to be stored in state store
+   * @param cfName - name of column family
+   * @return - encoded byte array with virtual column family id prefix
+   */
+  private def encodeStateRowWithPrefix(
+      data: Array[Byte],
+      cfName: String): Array[Byte] = {
+    // Create result array big enough for all prefixes plus data
+    val result = new Array[Byte](StateStore.VIRTUAL_COL_FAMILY_PREFIX_BYTES + data.length)
+    val offset = Platform.BYTE_ARRAY_OFFSET + StateStore.VIRTUAL_COL_FAMILY_PREFIX_BYTES
+
+    val cfInfo = getColumnFamilyInfo(cfName)
+    Platform.putShort(result, Platform.BYTE_ARRAY_OFFSET, cfInfo.cfId)
+
+    // Write the actual data
+    Platform.copyMemory(
+      data, Platform.BYTE_ARRAY_OFFSET,
+      result, offset,
+      data.length
+    )
+
+    result
+  }
+
+  /**
+   * Function to decode state row with virtual col family id prefix
+   * @param data - passed byte array retrieved from state store
+   * @return - pair of decoded byte array without virtual column family id prefix
+   *           and name of column family
+   */
+  private def decodeStateRowWithPrefix(data: Array[Byte]): (Array[Byte], String) = {
+    val cfId = Platform.getShort(data, Platform.BYTE_ARRAY_OFFSET)
+    val cfName = getColumnFamilyNameForId(cfId)
+    val offset = Platform.BYTE_ARRAY_OFFSET + StateStore.VIRTUAL_COL_FAMILY_PREFIX_BYTES
+
+    val key = new Array[Byte](data.length - StateStore.VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+    Platform.copyMemory(
+      data, offset,
+      key, Platform.BYTE_ARRAY_OFFSET,
+      key.length
+    )
+
+    (key, cfName)
+  }
+
   /**
    * Get the value for the given key if present, or null.
    * @note This will return the last written value even if it was uncommitted.
    */
-  def get(key: Array[Byte]): Array[Byte] = {
-    db.get(readOptions, key)
+  def get(
+      key: Array[Byte],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Array[Byte] = {
+    val keyWithPrefix = if (useColumnFamilies) {
+      encodeStateRowWithPrefix(key, cfName)
+    } else {
+      key
+    }
+
+    db.get(readOptions, keyWithPrefix)
   }
 
   /**
    * Put the given value for the given key.
    * @note This update is not committed to disk until commit() is called.
    */
-  def put(key: Array[Byte], value: Array[Byte]): Unit = {
-    if (conf.trackTotalNumberOfRows) {
-      val oldValue = db.get(readOptions, key)
-      if (oldValue == null) {
-        numKeysOnWritingVersion += 1
+  def put(
+      key: Array[Byte],
+      value: Array[Byte],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    val keyWithPrefix = if (useColumnFamilies) {
+      encodeStateRowWithPrefix(key, cfName)
+    } else {
+      key
+    }
+
+    if (useColumnFamilies) {

Review Comment:
   Done



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -656,31 +813,75 @@ class RocksDB(
    *
    * @note This update is not committed to disk until commit() is called.
    */
-  def merge(key: Array[Byte], value: Array[Byte]): Unit = {
-    if (conf.trackTotalNumberOfRows) {
-      val oldValue = db.get(readOptions, key)
-      if (oldValue == null) {
-        numKeysOnWritingVersion += 1
+  def merge(
+      key: Array[Byte],
+      value: Array[Byte],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    val keyWithPrefix = if (useColumnFamilies) {
+      encodeStateRowWithPrefix(key, cfName)
+    } else {
+      key
+    }
+
+    if (useColumnFamilies) {
+      if (conf.trackTotalNumberOfRows) {
+        val oldValue = db.get(readOptions, keyWithPrefix)
+        if (oldValue == null) {
+          val cfInfo = getColumnFamilyInfo(cfName)
+          if (cfInfo.isInternal) {
+            numInternalKeysOnWritingVersion += 1
+          } else {
+            numKeysOnWritingVersion += 1
+          }
+        }
+      }
+    } else {
+      if (conf.trackTotalNumberOfRows) {
+        val oldValue = db.get(readOptions, keyWithPrefix)
+        if (oldValue == null) {
+          numKeysOnWritingVersion += 1
+        }
       }
     }
-    db.merge(writeOptions, key, value)
 
-    changelogWriter.foreach(_.merge(key, value))
+    db.merge(writeOptions, keyWithPrefix, value)
+    changelogWriter.foreach(_.merge(keyWithPrefix, value))
   }
 
   /**
    * Remove the key if present.
    * @note This update is not committed to disk until commit() is called.
    */
-  def remove(key: Array[Byte]): Unit = {
-    if (conf.trackTotalNumberOfRows) {
-      val value = db.get(readOptions, key)
-      if (value != null) {
-        numKeysOnWritingVersion -= 1
+  def remove(key: Array[Byte], cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    val keyWithPrefix = if (useColumnFamilies) {
+      encodeStateRowWithPrefix(key, cfName)
+    } else {
+      key
+    }
+
+    if (useColumnFamilies) {

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to