beliefer commented on code in PR #49335:
URL: https://github.com/apache/spark/pull/49335#discussion_r1910325099


##########
sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala:
##########
@@ -112,6 +151,19 @@ private case class MySQLDialect() extends JdbcDialect with 
SQLConfHelper with No
       } else {
         super.visitAggregateFunction(funcName, isDistinct, inputs)
       }
+
+    override def visitCast(expr: String, exprDataType: DataType, dataType: 
DataType): String = {
+      val databaseTypeDefinition = dataType match {
+        // MySQL uses CHAR in the cast function for the type LONGTEXT
+        case StringType => "CHAR"

Review Comment:
   Is `SELECT CAST('123' AS LONGTEXT)` invalid in MySQL?



##########
sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala:
##########
@@ -374,6 +374,7 @@ abstract class JdbcDialect extends Serializable with 
Logging {
     case dateValue: Date => "'" + dateValue + "'"
     case dateValue: LocalDate => s"'${DateFormatter().format(dateValue)}'"
     case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+    case binaryValue: Array[Byte] => 
binaryValue.map("%02X".format(_)).mkString("X'", "", "'")

Review Comment:
   Is this `X'..'` hex literal format valid for all databases?



##########
sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala:
##########
@@ -112,6 +151,19 @@ private case class MySQLDialect() extends JdbcDialect with 
SQLConfHelper with No
       } else {
         super.visitAggregateFunction(funcName, isDistinct, inputs)
       }
+
+    override def visitCast(expr: String, exprDataType: DataType, dataType: 
DataType): String = {
+      val databaseTypeDefinition = dataType match {
+        // MySQL uses CHAR in the cast function for the type LONGTEXT
+        case StringType => "CHAR"
+        // MySQL uses SIGNED INTEGER in the cast function for the types 
SMALLINT, INTEGER and BIGINT
+        case ShortType | IntegerType | LongType => "SIGNED INTEGER"

Review Comment:
   Is `SELECT CAST(123 AS SMALLINT)` invalid in MySQL?



##########
connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala:
##########
@@ -241,6 +241,323 @@ class MySQLIntegrationSuite extends 
DockerJDBCIntegrationV2Suite with V2JDBCTest
     assert(rows10(0).getString(0) === "amy")
     assert(rows10(1).getString(0) === "alex")
   }
+
+  test("SPARK-50704: Test SQL function push down with different types and 
casts in WHERE clause") {
+    withTable(s"$catalogName.test_pushdown") {
+      // Define test values for different data types
+      val boolean = true
+      val int = 1
+      val long = 0x1_ff_ff_ff_ffL
+      val str = "TeSt SpArK"
+      val float = 0.123
+      val binary = "X'123456'"
+      val decimal = "-.001234567E+2BD"
+      val tableName = "test_pushdown"
+
+      // Create a table with various data types
+      sql(s"""CREATE TABLE $catalogName.$tableName (
+        boolean_col BOOLEAN, byte_col BYTE, tinyint_col TINYINT, short_col 
SHORT,
+        smallint_col SMALLINT, int_col INT, integer_col INTEGER, long_col LONG,
+        bigint_col BIGINT, float_col FLOAT, real_col REAL, double_col DOUBLE,
+        str_col STRING, binary_col BINARY, decimal_col DECIMAL(10, 7), dec_col 
DEC(10, 7),
+        numeric_col NUMERIC(10, 7))""")
+
+      // Insert test values into the table
+      sql(s"""INSERT INTO $catalogName.$tableName VALUES ($boolean, $int, 
$int, $int,
+        $int, $int, $int, $long, $long, $float, $float, $float, '$str', 
$binary, $decimal,
+        $decimal, $decimal)""")
+
+      // Helper function to generate test cases for a given function and 
columns
+      def generateTests(
+          function: String,
+          template: String,
+          columns: Seq[(Seq[String], Any)],
+          valueTransformer: Option[Any => String] = None
+      ): Seq[(String, String)] = {
+        columns.flatMap { case (cols, value) =>
+          val valueLiteral =
+            valueTransformer
+              .map(transform => transform(value))
+              .getOrElse(value.toString)
+          cols.map(column => {
+            (
+              function,
+              template
+                .replaceAll("COLUMN", column)
+                .replaceAll("VALUE", valueLiteral)
+            )
+          })
+        }
+      }
+
+      // Helper function to convert a value to a string literal
+      def toStringLiteral(any: Any): String = any match {
+        case stringValue: String =>
+          if (stringValue == decimal) "'-0.1234567'" else stringValue
+        case _ => s"'$any'"
+      }
+
+      // Define columns and their corresponding test values
+      var booleanColumns = (Seq("boolean_col"), boolean)
+      var intColumns = (
+        Seq(
+          "byte_col",
+          "tinyint_col",
+          "short_col",
+          "smallint_col",
+          "int_col",
+          "integer_col"
+        ),
+        int
+      )
+      var longColumns = (Seq("long_col", "bigint_col"), long)
+      var floatColumns = (Seq("float_col", "real_col", "double_col"), float)
+      var strColumns = (Seq("str_col"), s"'${str}'")
+      var binaryColumns = (Seq("binary_col"), binary)
+      var decimalColumns = (Seq("decimal_col", "dec_col", "numeric_col"), 
decimal)
+
+      // Generate test cases for various functions
+      val functions = Seq(
+        generateTests(
+          "ABS",
+          "ABS(COLUMN) = ABS(VALUE)",
+          Seq(intColumns, longColumns)
+        ),
+        generateTests(
+          "ABS",
+          "ABS(ABS(COLUMN) - ABS(VALUE)) <= 0.00001",
+          Seq(floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "COALESCE",
+          "COALESCE(COLUMN, NULL, VALUE) = VALUE",
+          Seq(
+            booleanColumns,
+            intColumns,
+            longColumns,
+            strColumns,
+            binaryColumns
+          )
+        ),
+        generateTests(
+          "COALESCE",
+          "ABS(ABS(COALESCE(COLUMN, NULL, VALUE)) - ABS(VALUE)) <= 0.00001",
+          Seq(
+            floatColumns,
+            decimalColumns
+          )
+        ),
+        generateTests(
+          "GREATEST",
+          "GREATEST(COLUMN, VALUE) = VALUE",
+          Seq(
+            booleanColumns,
+            intColumns,
+            longColumns,
+            strColumns,
+            binaryColumns
+          )
+        ),
+        generateTests(
+          "GREATEST",
+          "ABS(ABS(GREATEST(COLUMN, VALUE)) - ABS(VALUE)) <= 0.00001",
+          Seq(
+            floatColumns,
+            decimalColumns
+          )
+        ),
+        generateTests(
+          "LEAST",
+          "LEAST(COLUMN, VALUE) = VALUE",
+          Seq(
+            booleanColumns,
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          )
+        ),
+        generateTests(
+          "LOG10",
+          "ABS(LOG10(ABS(COLUMN)) - LOG10(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "LOG2",
+          "ABS(LOG2(ABS(COLUMN)) - LOG2(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "LN",
+          "ABS(LN(ABS(COLUMN)) - LN(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "EXP",
+          "ABS(EXP(ABS(COLUMN) - ABS(COLUMN)) - EXP(ABS(VALUE) - ABS(VALUE))) 
<= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "POWER",
+          "ABS(POWER(COLUMN, 2) - POWER(VALUE, 2)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "SQRT",
+          "ABS(SQRT(ABS(COLUMN)) - SQRT(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "SIN",
+          "ABS(SIN(COLUMN) - SIN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "COS",
+          "ABS(COS(COLUMN) - COS(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "TAN",
+          "ABS(TAN(COLUMN) - TAN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "COT",
+          "ABS(COT(COLUMN) - COT(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ASIN",
+          "ABS(ASIN(COLUMN/COLUMN) - ASIN(VALUE/VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ACOS",
+          "ABS(ACOS(COLUMN/COLUMN) - ACOS(VALUE/VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ATAN",
+          "ABS(ATAN(COLUMN) - ATAN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ATAN2",
+          "ABS(ATAN2(COLUMN, 1) - ATAN2(VALUE, 1)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "DEGREES",
+          "ABS(DEGREES(COLUMN) - DEGREES(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "RADIANS",
+          "ABS(RADIANS(COLUMN) - RADIANS(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "SIGN",
+          "ABS(SIGN(COLUMN) - SIGN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "UPPER",
+          "UPPER(COLUMN) = VALUE",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "LOWER",
+          "LOWER(COLUMN) = VALUE",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "SHA1",
+          "SHA1(COLUMN) = SHA1(VALUE)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "SHA2",
+          "SHA2(COLUMN, 256) = SHA2(VALUE, 256)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "MD5",
+          "MD5(COLUMN) = MD5(VALUE)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "CRC32",
+          "CRC32(COLUMN) = CRC32(VALUE)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "BIT_LENGTH",
+          "BIT_LENGTH(COLUMN) = BIT_LENGTH(VALUE)",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "CHAR_LENGTH",
+          "CHAR_LENGTH(COLUMN) = CHAR_LENGTH(VALUE)",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "CONCAT",
+          "CONCAT(COLUMN) = VALUE",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        )
+      ).flatten
+

Review Comment:
   Please add the test cases for cast here.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to