This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new a4bf90acf Fix: array contains null handling (#3372)
a4bf90acf is described below
commit a4bf90acfa9274fcef9e232519bf92d1fa0ce2cc
Author: Shekhar Prasad Rajak <[email protected]>
AuthorDate: Wed Feb 4 05:37:19 2026 +0530
Fix: array contains null handling (#3372)
---
.../main/scala/org/apache/comet/serde/arrays.scala | 39 +++++++++++++++++++++-
.../sql-tests/expressions/array/array_contains.sql | 21 +++++++++++-
.../apache/comet/CometArrayExpressionSuite.scala | 32 ++++++++++++++++++
3 files changed, 90 insertions(+), 2 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index b552a071d..cdaf3d5f8 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -22,6 +22,7 @@ package org.apache.comet.serde
import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains,
ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect,
ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap,
ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten,
GetArrayItem, IsNotNull, Literal, Reverse, Size}
+import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -134,7 +135,34 @@ object CometArrayContains extends
CometExpressionSerde[ArrayContains] {
val arrayContainsScalarExpr =
scalarFunctionExprToProto("array_has", arrayExprProto, keyExprProto)
- optExprWithInfo(arrayContainsScalarExpr, expr, expr.children: _*)
+
+ // Handle NULL array input - return NULL if array is NULL (matching
Spark's behavior)
+ val isNotNullExpr = createUnaryExpr(
+ expr,
+ expr.children.head,
+ inputs,
+ binding,
+ (builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
+
+ val nullLiteralProto = exprToProto(Literal(null, BooleanType), Seq.empty)
+
+ if (arrayContainsScalarExpr.isDefined && isNotNullExpr.isDefined &&
+ nullLiteralProto.isDefined) {
+ val caseWhenExpr = ExprOuterClass.CaseWhen
+ .newBuilder()
+ .addWhen(isNotNullExpr.get)
+ .addThen(arrayContainsScalarExpr.get)
+ .setElseExpr(nullLiteralProto.get)
+ .build()
+ Some(
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setCaseWhen(caseWhenExpr)
+ .build())
+ } else {
+ withInfo(expr, expr.children: _*)
+ None
+ }
}
}
@@ -395,6 +423,15 @@ object CometCreateArray extends
CometExpressionSerde[CreateArray] {
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val children = expr.children
+
+ // Handle empty array: return literal directly to avoid DataFusion
coerce_types bug
+ // when make_array is called with 0 arguments (issue #3338)
+ if (children.isEmpty) {
+ val emptyArrayLiteral =
+ Literal.create(new GenericArrayData(Array.empty[Any]), expr.dataType)
+ return exprToProtoInternal(emptyArrayLiteral, inputs, binding)
+ }
+
val childExprs = children.map(exprToProtoInternal(_, inputs, binding))
if (childExprs.forall(_.isDefined)) {
diff --git
a/spark/src/test/resources/sql-tests/expressions/array/array_contains.sql
b/spark/src/test/resources/sql-tests/expressions/array/array_contains.sql
index 86ad0cc48..cdbe3e68c 100644
--- a/spark/src/test/resources/sql-tests/expressions/array/array_contains.sql
+++ b/spark/src/test/resources/sql-tests/expressions/array/array_contains.sql
@@ -35,5 +35,24 @@ query spark_answer_only
SELECT array_contains(array(1, 2, 3), val) FROM test_array_contains
-- literal + literal
-query ignore(https://github.com/apache/datafusion-comet/issues/3345)
+-- Note: array_contains(array(), 1) still has a bug (issue #3346) so we use
spark_answer_only
+-- The NULL array case (cast(NULL as array<int>)) was fixed in issue #3345
+query spark_answer_only
SELECT array_contains(array(1, 2, 3), 2), array_contains(array(1, 2, 3), 4),
array_contains(array(), 1), array_contains(cast(NULL as array<int>), 1)
+
+-- Additional NULL array tests (issue #3345 fix verification)
+-- NULL array with integer value
+query
+SELECT array_contains(cast(NULL as array<int>), 1)
+
+-- NULL array with string value
+query
+SELECT array_contains(cast(NULL as array<string>), 'test')
+
+-- NULL array with NULL value
+query
+SELECT array_contains(cast(NULL as array<int>), cast(NULL as int))
+
+-- NULL array with column value
+query
+SELECT array_contains(cast(NULL as array<int>), val) FROM test_array_contains
diff --git
a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index cf4911736..b22d0f72d 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -325,6 +325,38 @@ class CometArrayExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper
}
}
+ test("array_contains - NULL array returns NULL") {
+ // Test that array_contains returns NULL when the array argument is NULL
+ // This matches Spark's SQL three-valued logic behavior
+ withTempDir { dir =>
+ withTempView("t1") {
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n =
100)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+ // Test NULL array with non-null value
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_contains(cast(null as array<int>), 1) FROM t1"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_contains(cast(null as array<string>), 'test') FROM
t1"))
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_contains(cast(null as array<double>), 1.5) FROM
t1"))
+
+ // Test NULL array with NULL value
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_contains(cast(null as array<int>), cast(null as
int)) FROM t1"))
+
+ // Test NULL array with column value
+ checkSparkAnswerAndOperator(
+ sql("SELECT array_contains(cast(null as array<int>), _2) FROM t1"))
+
+ // Test non-null array with values (to ensure fix doesn't break normal
operation)
+ checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, 2, 3),
2) FROM t1"))
+ checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, 2, 3),
5) FROM t1"))
+ }
+ }
+ }
+
test("array_contains - test all types (convert from Parquet)") {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]