andygrove commented on code in PR #4591: URL: https://github.com/apache/datafusion-comet/pull/4591#discussion_r3500090924
########## spark/src/test/scala/org/apache/comet/exec/CometInMemoryCacheSuite.scala: ########## @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.exec + +import java.{util => ju} + +import org.apache.spark.CometDriverPlugin +import org.apache.spark.SparkConf +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.catalyst.expressions.{And, Expression, GreaterThanOrEqual, LessThan, Literal} +import org.apache.spark.sql.columnar.SimpleMetricsCachedBatch +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} + +import org.apache.comet.CometConf + +class CometInMemoryCacheSuite extends CometTestBase { + override protected def sparkConf: SparkConf = { + val conf = new SparkConf() + conf.set("spark.driver.memory", "1G") + conf.set("spark.executor.memory", "1G") + conf.set("spark.executor.memoryOverhead", "2G") + conf.set("spark.plugins", "org.apache.spark.CometPlugin") + conf.set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + conf.set("spark.comet.enabled", "true") + conf.set("spark.comet.exec.enabled", "true") + conf.set("spark.comet.exec.onHeap.enabled", "true") + 
conf.set("spark.comet.metrics.enabled", "true") + conf.set( + "spark.sql.cache.serializer", + "org.apache.spark.sql.comet.execution.arrow.ArrowCachedBatchSerializer") + conf + } + + private def cachedBatchTypes(table: String): Array[String] = { + val ds = spark.table(table).asInstanceOf[org.apache.spark.sql.classic.Dataset[_]] Review Comment: This suite lives under `src/test/scala` (built for all Spark versions) but uses `org.apache.spark.sql.classic.Dataset`, which only exists in Spark 4.x. In this repo that package is only referenced from `spark/src/test/spark-4.x/`. The default build is now Spark 4.1 so it compiles locally, but CI also builds 3.4 and 3.5 and this suite would fail to compile there. Could the cache lookup avoid the `classic.Dataset` cast, or move the version-specific helper into a `spark-4.x` shim? Same cast appears again around line 284. ########## spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowCachedBatchSerializer.scala: ########## @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.execution.arrow + +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.execution.columnar.{DefaultCachedBatch, DefaultCachedBatchSerializer} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.util.io.ChunkedByteBuffer + +import org.apache.comet.CometConf + +/** + * Cached batch format used when Comet writes Spark in-memory cache data. + * + * `bytes` contains compressed Arrow stream data produced by `Utils.serializeBatches`. The cache + * manager still owns storage and eviction; this class only changes the cached payload. + */ +private case class CometCachedBatch( + override val numRows: Int, + override val sizeInBytes: Long, + override val stats: InternalRow, + bytes: ChunkedByteBuffer) + extends SimpleMetricsCachedBatch + +/** + * Cache serializer that stores Comet-compatible Arrow batches in Spark's in-memory cache. + * + * When Comet cache support is disabled, row-based cache writes and default cache reads are + * delegated to Spark's `DefaultCachedBatchSerializer`. + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + private val fallback = new DefaultCachedBatchSerializer() + + // Cache writes use Comet format only when both Comet and the in-memory cache scan are enabled. 
+ private def enabled(conf: SQLConf): Boolean = { + CometConf.COMET_ENABLED.get(conf) && + CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.get(conf) + } + + // Row-to-Arrow conversion needs a StructType, while cache APIs pass attributes. + private def toStructType(schema: Seq[Attribute]): StructType = { + StructType(schema.map { attr => + StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + }) + } + + // Build the statistics row expected by SimpleMetricsCachedBatchSerializer. + // For each cached column Spark expects five values in this order: + // lower bound, upper bound, null count, row count, and size in bytes. + private def computeStats(batch: ColumnarBatch, attrs: Seq[Attribute]): InternalRow = { + val numCols = attrs.length + val lower = new Array[Any](numCols) + val upper = new Array[Any](numCols) + val nulls = Array.fill[Int](numCols)(0) + val numRows = batch.numRows() + + var c = 0 + while (c < numCols) { + val dt = attrs(c).dataType + val col = batch.column(c) + var r = 0 + while (r < numRows) { + if (col.isNullAt(r)) { + nulls(c) += 1 + } else if (tracksBounds(dt)) { + val value = readValue(col, dt, r) + if (lower(c) == null || compare(dt, value, lower(c)) < 0) { + lower(c) = value + } + if (upper(c) == null || compare(dt, value, upper(c)) > 0) { + upper(c) = value + } + } + r += 1 + } + c += 1 + } + + val values = new Array[Any](numCols * 5) + c = 0 + while (c < numCols) { + val base = c * 5 + values(base) = lower(c) + values(base + 1) = upper(c) + values(base + 2) = nulls(c) + values(base + 3) = numRows + values(base + 4) = 0L + c += 1 + } + + new GenericInternalRow(values) + } + + // Spark can prune cache batches only for types whose bounds can be compared. + // Other types still report null count and row count but leave bounds as null. 
+ private def tracksBounds(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + _: DecimalType | StringType | DateType | TimestampType | TimestampNTZType => + true + case _ => false + } + + // Read a non-null value from a ColumnVector using Spark's internal value type + // for the corresponding DataType. + private def readValue(col: ColumnVector, dt: DataType, rowId: Int): Any = dt match { + case BooleanType => col.getBoolean(rowId) + case ByteType => col.getByte(rowId) + case ShortType => col.getShort(rowId) + case IntegerType | DateType => col.getInt(rowId) + case LongType | TimestampType | TimestampNTZType => col.getLong(rowId) + case FloatType => col.getFloat(rowId) + case DoubleType => col.getDouble(rowId) + case d: DecimalType => col.getDecimal(rowId, d.precision, d.scale) + case StringType => col.getUTF8String(rowId).copy() + case _ => null + } + + // Compare values using the same physical representation used in the stats row. 
+ private def compare(dt: DataType, left: Any, right: Any): Int = dt match { + case BooleanType => + java.lang.Boolean.compare(left.asInstanceOf[Boolean], right.asInstanceOf[Boolean]) + case ByteType => + java.lang.Byte.compare(left.asInstanceOf[Byte], right.asInstanceOf[Byte]) + case ShortType => + java.lang.Short.compare(left.asInstanceOf[Short], right.asInstanceOf[Short]) + case IntegerType | DateType => + java.lang.Integer.compare(left.asInstanceOf[Int], right.asInstanceOf[Int]) + case LongType | TimestampType | TimestampNTZType => + java.lang.Long.compare(left.asInstanceOf[Long], right.asInstanceOf[Long]) + case FloatType => + java.lang.Float.compare(left.asInstanceOf[Float], right.asInstanceOf[Float]) + case DoubleType => + java.lang.Double.compare(left.asInstanceOf[Double], right.asInstanceOf[Double]) + case _: DecimalType => + left.asInstanceOf[Decimal].compare(right.asInstanceOf[Decimal]) + case StringType => + ByteArray.compareBinary( + left.asInstanceOf[UTF8String].getBytes, + right.asInstanceOf[UTF8String].getBytes) + case other => + throw new IllegalStateException(s"compare called for unsupported type $other") + } + + // Compute Spark-compatible cache stats before serializing each batch to Arrow. + // The stats are stored beside the Arrow bytes so Spark's cache filter can prune + // CometCachedBatch without decoding the batch first. 
+ private def encodeBatches( + batches: Iterator[ColumnarBatch], + attrs: Seq[Attribute]): Iterator[CachedBatch] = { + batches.flatMap { batch => + val stats = computeStats(batch, attrs) + + Utils.serializeBatches(Iterator.single(batch)).map { case (rows, buffer) => + CometCachedBatch( + numRows = rows.toInt, + sizeInBytes = buffer.size, + stats = stats, + bytes = buffer) + } + } + } + + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = { + val activeConf = SQLConf.get + activeConf != null && enabled(activeConf) + } + override def supportsColumnarOutput(schema: StructType): Boolean = true + + // Columnar Comet output is stored as compressed Arrow stream bytes. + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + + input.mapPartitions { batches => + encodeBatches(batches, schema) + } + } + + override def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + + // Resolve requested columns by exprId, not by name, because aliases may reuse names. + val selectedIndices = + if (selectedAttributes.isEmpty) { + cacheAttributes.indices.toArray + } else { + val byExprId = cacheAttributes.zipWithIndex.map { case (attr, idx) => + attr.exprId -> idx + }.toMap + + selectedAttributes.map { attr => + byExprId.getOrElse( + attr.exprId, + throw new IllegalStateException( + s"Could not resolve selected attribute ${attr.name} from cache attributes")) + }.toArray + } + + val batchTypes = input.map(_.getClass.getName).distinct().collect() Review Comment: `collect()` is an action, so this kicks off a full distinct-and-collect job over the entire cached RDD on every scan, just to detect the batch type, before the real scan RDD is even returned. 
For a large cache that is an extra pass (plus a shuffle for `distinct`) per read, which undercuts the caching benefit. Since the batch type is homogeneous per relation, could this be handled lazily inside `mapPartitions` (pattern-match per batch, or inspect the first batch per partition) so no driver-side job is needed? The mixed-type guard is reasonable, but it should not cost a job. Same pattern in `convertCachedBatchToInternalRow` around line 315. ########## spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala: ########## @@ -282,6 +285,47 @@ case class CometExecRule(session: SparkSession) case op if isCometScan(op) => convertToComet(op, CometScanWrapper).getOrElse(op) + case scan: InMemoryTableScanExec => + val cachedBuffers = scan.relation.cacheBuilder.cachedColumnBuffers + val firstBatchOpt = cachedBuffers.take(1).headOption Review Comment: `cachedColumnBuffers` builds the cache RDD and `take(1)` runs a job to fetch the first batch, during plan transformation. This rule can fire more than once under AQE re-planning, so we would run a job at plan time on each fire just to read the class of the first cached batch. Is there a way to detect the Comet cache format without materializing a batch, for example checking `relation.cacheBuilder.serializer.isInstanceOf[ArrowCachedBatchSerializer]` together with the enabled flag? ########## spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowCachedBatchSerializer.scala: ########## @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.execution.columnar.{DefaultCachedBatch, DefaultCachedBatchSerializer} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.util.io.ChunkedByteBuffer + +import org.apache.comet.CometConf + +/** + * Cached batch format used when Comet writes Spark in-memory cache data. + * + * `bytes` contains compressed Arrow stream data produced by `Utils.serializeBatches`. The cache + * manager still owns storage and eviction; this class only changes the cached payload. + */ +private case class CometCachedBatch( + override val numRows: Int, + override val sizeInBytes: Long, + override val stats: InternalRow, + bytes: ChunkedByteBuffer) + extends SimpleMetricsCachedBatch + +/** + * Cache serializer that stores Comet-compatible Arrow batches in Spark's in-memory cache. 
+ * + * When Comet cache support is disabled, row-based cache writes and default cache reads are + * delegated to Spark's `DefaultCachedBatchSerializer`. + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + private val fallback = new DefaultCachedBatchSerializer() + + // Cache writes use Comet format only when both Comet and the in-memory cache scan are enabled. + private def enabled(conf: SQLConf): Boolean = { + CometConf.COMET_ENABLED.get(conf) && + CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.get(conf) + } + + // Row-to-Arrow conversion needs a StructType, while cache APIs pass attributes. + private def toStructType(schema: Seq[Attribute]): StructType = { + StructType(schema.map { attr => + StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + }) + } + + // Build the statistics row expected by SimpleMetricsCachedBatchSerializer. + // For each cached column Spark expects five values in this order: + // lower bound, upper bound, null count, row count, and size in bytes. 
+ private def computeStats(batch: ColumnarBatch, attrs: Seq[Attribute]): InternalRow = { + val numCols = attrs.length + val lower = new Array[Any](numCols) + val upper = new Array[Any](numCols) + val nulls = Array.fill[Int](numCols)(0) + val numRows = batch.numRows() + + var c = 0 + while (c < numCols) { + val dt = attrs(c).dataType + val col = batch.column(c) + var r = 0 + while (r < numRows) { + if (col.isNullAt(r)) { + nulls(c) += 1 + } else if (tracksBounds(dt)) { + val value = readValue(col, dt, r) + if (lower(c) == null || compare(dt, value, lower(c)) < 0) { + lower(c) = value + } + if (upper(c) == null || compare(dt, value, upper(c)) > 0) { + upper(c) = value + } + } + r += 1 + } + c += 1 + } + + val values = new Array[Any](numCols * 5) + c = 0 + while (c < numCols) { + val base = c * 5 + values(base) = lower(c) + values(base + 1) = upper(c) + values(base + 2) = nulls(c) + values(base + 3) = numRows + values(base + 4) = 0L + c += 1 + } + + new GenericInternalRow(values) + } + + // Spark can prune cache batches only for types whose bounds can be compared. + // Other types still report null count and row count but leave bounds as null. + private def tracksBounds(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + _: DecimalType | StringType | DateType | TimestampType | TimestampNTZType => + true + case _ => false + } + + // Read a non-null value from a ColumnVector using Spark's internal value type + // for the corresponding DataType. 
+ private def readValue(col: ColumnVector, dt: DataType, rowId: Int): Any = dt match { + case BooleanType => col.getBoolean(rowId) + case ByteType => col.getByte(rowId) + case ShortType => col.getShort(rowId) + case IntegerType | DateType => col.getInt(rowId) + case LongType | TimestampType | TimestampNTZType => col.getLong(rowId) + case FloatType => col.getFloat(rowId) + case DoubleType => col.getDouble(rowId) + case d: DecimalType => col.getDecimal(rowId, d.precision, d.scale) + case StringType => col.getUTF8String(rowId).copy() + case _ => null + } + + // Compare values using the same physical representation used in the stats row. + private def compare(dt: DataType, left: Any, right: Any): Int = dt match { + case BooleanType => + java.lang.Boolean.compare(left.asInstanceOf[Boolean], right.asInstanceOf[Boolean]) + case ByteType => + java.lang.Byte.compare(left.asInstanceOf[Byte], right.asInstanceOf[Byte]) + case ShortType => + java.lang.Short.compare(left.asInstanceOf[Short], right.asInstanceOf[Short]) + case IntegerType | DateType => + java.lang.Integer.compare(left.asInstanceOf[Int], right.asInstanceOf[Int]) + case LongType | TimestampType | TimestampNTZType => + java.lang.Long.compare(left.asInstanceOf[Long], right.asInstanceOf[Long]) + case FloatType => + java.lang.Float.compare(left.asInstanceOf[Float], right.asInstanceOf[Float]) + case DoubleType => + java.lang.Double.compare(left.asInstanceOf[Double], right.asInstanceOf[Double]) + case _: DecimalType => + left.asInstanceOf[Decimal].compare(right.asInstanceOf[Decimal]) + case StringType => + ByteArray.compareBinary( + left.asInstanceOf[UTF8String].getBytes, + right.asInstanceOf[UTF8String].getBytes) + case other => + throw new IllegalStateException(s"compare called for unsupported type $other") + } + + // Compute Spark-compatible cache stats before serializing each batch to Arrow. 
+ // The stats are stored beside the Arrow bytes so Spark's cache filter can prune + // CometCachedBatch without decoding the batch first. + private def encodeBatches( + batches: Iterator[ColumnarBatch], + attrs: Seq[Attribute]): Iterator[CachedBatch] = { + batches.flatMap { batch => + val stats = computeStats(batch, attrs) + + Utils.serializeBatches(Iterator.single(batch)).map { case (rows, buffer) => + CometCachedBatch( + numRows = rows.toInt, + sizeInBytes = buffer.size, + stats = stats, + bytes = buffer) + } + } + } + + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = { + val activeConf = SQLConf.get + activeConf != null && enabled(activeConf) + } + override def supportsColumnarOutput(schema: StructType): Boolean = true + + // Columnar Comet output is stored as compressed Arrow stream bytes. + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + + input.mapPartitions { batches => + encodeBatches(batches, schema) + } + } + + override def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + + // Resolve requested columns by exprId, not by name, because aliases may reuse names. 
+ val selectedIndices = + if (selectedAttributes.isEmpty) { + cacheAttributes.indices.toArray + } else { + val byExprId = cacheAttributes.zipWithIndex.map { case (attr, idx) => + attr.exprId -> idx + }.toMap + + selectedAttributes.map { attr => + byExprId.getOrElse( + attr.exprId, + throw new IllegalStateException( + s"Could not resolve selected attribute ${attr.name} from cache attributes")) + }.toArray + } + + val batchTypes = input.map(_.getClass.getName).distinct().collect() + + if (batchTypes.isEmpty) { + input.sparkContext.emptyRDD[ColumnarBatch] + } else if (batchTypes.length > 1) { + throw new IllegalStateException( + s"Mixed cached batch types are not supported: ${batchTypes.mkString(", ")}") + } else if (batchTypes.head == classOf[CometCachedBatch].getName) { + input.mapPartitions { it => + it.flatMap { + case cb: CometCachedBatch => + Utils.decodeBatches(cb.bytes, "CometCache").map { batch => + if (selectedIndices.length == batch.numCols()) { Review Comment: This treats "selected count equals cached column count" as an identity projection and returns the decoded batch unchanged. For a full-width but reordered projection (cache is `[key, value]`, scan output is `[value, key]`), `selectedIndices` would be `[1, 0]`, length 2 equals `numCols` 2, and we would return the batch in the wrong column order. Spark's `DefaultCachedBatchSerializer` always projects by computed indices rather than taking a length shortcut. Could this guard be `selectedIndices.sameElements(0 until batch.numCols())` instead? If `InMemoryTableScanExec` is guaranteed to always emit in relation order with a separate `Project` on top this is moot, but the length-only check is fragile. Same shortcut in `convertCachedBatchToInternalRow` around line 330. ########## spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowCachedBatchSerializer.scala: ########## @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.execution.columnar.{DefaultCachedBatch, DefaultCachedBatchSerializer} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.util.io.ChunkedByteBuffer + +import org.apache.comet.CometConf + +/** + * Cached batch format used when Comet writes Spark in-memory cache data. + * + * `bytes` contains compressed Arrow stream data produced by `Utils.serializeBatches`. The cache + * manager still owns storage and eviction; this class only changes the cached payload. 
+ */ +private case class CometCachedBatch( + override val numRows: Int, + override val sizeInBytes: Long, + override val stats: InternalRow, + bytes: ChunkedByteBuffer) + extends SimpleMetricsCachedBatch + +/** + * Cache serializer that stores Comet-compatible Arrow batches in Spark's in-memory cache. + * + * When Comet cache support is disabled, row-based cache writes and default cache reads are + * delegated to Spark's `DefaultCachedBatchSerializer`. + */ +class ArrowCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { + + private val fallback = new DefaultCachedBatchSerializer() + + // Cache writes use Comet format only when both Comet and the in-memory cache scan are enabled. + private def enabled(conf: SQLConf): Boolean = { + CometConf.COMET_ENABLED.get(conf) && + CometConf.COMET_EXEC_IN_MEMORY_CACHE_ENABLED.get(conf) + } + + // Row-to-Arrow conversion needs a StructType, while cache APIs pass attributes. + private def toStructType(schema: Seq[Attribute]): StructType = { + StructType(schema.map { attr => + StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + }) + } + + // Build the statistics row expected by SimpleMetricsCachedBatchSerializer. + // For each cached column Spark expects five values in this order: + // lower bound, upper bound, null count, row count, and size in bytes. 
+ private def computeStats(batch: ColumnarBatch, attrs: Seq[Attribute]): InternalRow = { + val numCols = attrs.length + val lower = new Array[Any](numCols) + val upper = new Array[Any](numCols) + val nulls = Array.fill[Int](numCols)(0) + val numRows = batch.numRows() + + var c = 0 + while (c < numCols) { + val dt = attrs(c).dataType + val col = batch.column(c) + var r = 0 + while (r < numRows) { + if (col.isNullAt(r)) { + nulls(c) += 1 + } else if (tracksBounds(dt)) { + val value = readValue(col, dt, r) + if (lower(c) == null || compare(dt, value, lower(c)) < 0) { + lower(c) = value + } + if (upper(c) == null || compare(dt, value, upper(c)) > 0) { + upper(c) = value + } + } + r += 1 + } + c += 1 + } + + val values = new Array[Any](numCols * 5) + c = 0 + while (c < numCols) { + val base = c * 5 + values(base) = lower(c) + values(base + 1) = upper(c) + values(base + 2) = nulls(c) + values(base + 3) = numRows + values(base + 4) = 0L + c += 1 + } + + new GenericInternalRow(values) + } + + // Spark can prune cache batches only for types whose bounds can be compared. + // Other types still report null count and row count but leave bounds as null. + private def tracksBounds(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + _: DecimalType | StringType | DateType | TimestampType | TimestampNTZType => + true + case _ => false + } + + // Read a non-null value from a ColumnVector using Spark's internal value type + // for the corresponding DataType. 
+ private def readValue(col: ColumnVector, dt: DataType, rowId: Int): Any = dt match { + case BooleanType => col.getBoolean(rowId) + case ByteType => col.getByte(rowId) + case ShortType => col.getShort(rowId) + case IntegerType | DateType => col.getInt(rowId) + case LongType | TimestampType | TimestampNTZType => col.getLong(rowId) + case FloatType => col.getFloat(rowId) + case DoubleType => col.getDouble(rowId) + case d: DecimalType => col.getDecimal(rowId, d.precision, d.scale) + case StringType => col.getUTF8String(rowId).copy() + case _ => null + } + + // Compare values using the same physical representation used in the stats row. + private def compare(dt: DataType, left: Any, right: Any): Int = dt match { + case BooleanType => + java.lang.Boolean.compare(left.asInstanceOf[Boolean], right.asInstanceOf[Boolean]) + case ByteType => + java.lang.Byte.compare(left.asInstanceOf[Byte], right.asInstanceOf[Byte]) + case ShortType => + java.lang.Short.compare(left.asInstanceOf[Short], right.asInstanceOf[Short]) + case IntegerType | DateType => + java.lang.Integer.compare(left.asInstanceOf[Int], right.asInstanceOf[Int]) + case LongType | TimestampType | TimestampNTZType => + java.lang.Long.compare(left.asInstanceOf[Long], right.asInstanceOf[Long]) + case FloatType => + java.lang.Float.compare(left.asInstanceOf[Float], right.asInstanceOf[Float]) + case DoubleType => + java.lang.Double.compare(left.asInstanceOf[Double], right.asInstanceOf[Double]) + case _: DecimalType => + left.asInstanceOf[Decimal].compare(right.asInstanceOf[Decimal]) + case StringType => + ByteArray.compareBinary( + left.asInstanceOf[UTF8String].getBytes, + right.asInstanceOf[UTF8String].getBytes) + case other => + throw new IllegalStateException(s"compare called for unsupported type $other") + } + + // Compute Spark-compatible cache stats before serializing each batch to Arrow. 
+ // The stats are stored beside the Arrow bytes so Spark's cache filter can prune + // CometCachedBatch without decoding the batch first. + private def encodeBatches( + batches: Iterator[ColumnarBatch], + attrs: Seq[Attribute]): Iterator[CachedBatch] = { + batches.flatMap { batch => + val stats = computeStats(batch, attrs) + + Utils.serializeBatches(Iterator.single(batch)).map { case (rows, buffer) => + CometCachedBatch( + numRows = rows.toInt, + sizeInBytes = buffer.size, + stats = stats, + bytes = buffer) + } + } + } + + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = { + val activeConf = SQLConf.get + activeConf != null && enabled(activeConf) + } + override def supportsColumnarOutput(schema: StructType): Boolean = true + + // Columnar Comet output is stored as compressed Arrow stream bytes. + override def convertColumnarBatchToCachedBatch( + input: RDD[ColumnarBatch], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + + input.mapPartitions { batches => + encodeBatches(batches, schema) + } + } + + override def convertCachedBatchToColumnarBatch( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[ColumnarBatch] = { + + // Resolve requested columns by exprId, not by name, because aliases may reuse names. 
+ val selectedIndices = + if (selectedAttributes.isEmpty) { + cacheAttributes.indices.toArray + } else { + val byExprId = cacheAttributes.zipWithIndex.map { case (attr, idx) => + attr.exprId -> idx + }.toMap + + selectedAttributes.map { attr => + byExprId.getOrElse( + attr.exprId, + throw new IllegalStateException( + s"Could not resolve selected attribute ${attr.name} from cache attributes")) + }.toArray + } + + val batchTypes = input.map(_.getClass.getName).distinct().collect() + + if (batchTypes.isEmpty) { + input.sparkContext.emptyRDD[ColumnarBatch] + } else if (batchTypes.length > 1) { + throw new IllegalStateException( + s"Mixed cached batch types are not supported: ${batchTypes.mkString(", ")}") + } else if (batchTypes.head == classOf[CometCachedBatch].getName) { + input.mapPartitions { it => + it.flatMap { + case cb: CometCachedBatch => + Utils.decodeBatches(cb.bytes, "CometCache").map { batch => + if (selectedIndices.length == batch.numCols()) { + batch + } else { + val cols = + selectedIndices.map(i => batch.column(i).asInstanceOf[ColumnVector]) + new ColumnarBatch(cols, batch.numRows()) Review Comment: When projecting a strict subset, the new `ColumnarBatch` holds only the selected `CometVector`s, so closing it releases only those columns. Do the dropped columns get closed anywhere? If `decodeBatches` hands ownership of the off-heap Arrow buffers to the `ColumnarBatch`, this projection would drop that ownership for the unselected columns and leak their buffers until GC. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
