sunxiaoguang commented on code in PR #49335:
URL: https://github.com/apache/spark/pull/49335#discussion_r1910376515


##########
connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala:
##########
@@ -241,6 +241,323 @@ class MySQLIntegrationSuite extends 
DockerJDBCIntegrationV2Suite with V2JDBCTest
     assert(rows10(0).getString(0) === "amy")
     assert(rows10(1).getString(0) === "alex")
   }
+
+  test("SPARK-50704: Test SQL function push down with different types and 
casts in WHERE clause") {
+    withTable(s"$catalogName.test_pushdown") {
+      // Define test values for different data types
+      val boolean = true
+      val int = 1
+      val long = 0x1_ff_ff_ff_ffL
+      val str = "TeSt SpArK"
+      val float = 0.123
+      val binary = "X'123456'"
+      val decimal = "-.001234567E+2BD"
+      val tableName = "test_pushdown"
+
+      // Create a table with various data types
+      sql(s"""CREATE TABLE $catalogName.$tableName (
+        boolean_col BOOLEAN, byte_col BYTE, tinyint_col TINYINT, short_col 
SHORT,
+        smallint_col SMALLINT, int_col INT, integer_col INTEGER, long_col LONG,
+        bigint_col BIGINT, float_col FLOAT, real_col REAL, double_col DOUBLE,
+        str_col STRING, binary_col BINARY, decimal_col DECIMAL(10, 7), dec_col 
DEC(10, 7),
+        numeric_col NUMERIC(10, 7))""")
+
+      // Insert test values into the table
+      sql(s"""INSERT INTO $catalogName.$tableName VALUES ($boolean, $int, 
$int, $int,
+        $int, $int, $int, $long, $long, $float, $float, $float, '$str', 
$binary, $decimal,
+        $decimal, $decimal)""")
+
+      // Helper function to generate test cases for a given function and 
columns
+      def generateTests(
+          function: String,
+          template: String,
+          columns: Seq[(Seq[String], Any)],
+          valueTransformer: Option[Any => String] = None
+      ): Seq[(String, String)] = {
+        columns.flatMap { case (cols, value) =>
+          val valueLiteral =
+            valueTransformer
+              .map(transform => transform(value))
+              .getOrElse(value.toString)
+          cols.map(column => {
+            (
+              function,
+              template
+                .replaceAll("COLUMN", column)
+                .replaceAll("VALUE", valueLiteral)
+            )
+          })
+        }
+      }
+
+      // Helper function to convert a value to a string literal
+      def toStringLiteral(any: Any): String = any match {
+        case stringValue: String =>
+          if (stringValue == decimal) "'-0.1234567'" else stringValue
+        case _ => s"'$any'"
+      }
+
+      // Define columns and their corresponding test values
+      var booleanColumns = (Seq("boolean_col"), boolean)
+      var intColumns = (
+        Seq(
+          "byte_col",
+          "tinyint_col",
+          "short_col",
+          "smallint_col",
+          "int_col",
+          "integer_col"
+        ),
+        int
+      )
+      var longColumns = (Seq("long_col", "bigint_col"), long)
+      var floatColumns = (Seq("float_col", "real_col", "double_col"), float)
+      var strColumns = (Seq("str_col"), s"'${str}'")
+      var binaryColumns = (Seq("binary_col"), binary)
+      var decimalColumns = (Seq("decimal_col", "dec_col", "numeric_col"), 
decimal)
+
+      // Generate test cases for various functions
+      val functions = Seq(
+        generateTests(
+          "ABS",
+          "ABS(COLUMN) = ABS(VALUE)",
+          Seq(intColumns, longColumns)
+        ),
+        generateTests(
+          "ABS",
+          "ABS(ABS(COLUMN) - ABS(VALUE)) <= 0.00001",
+          Seq(floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "COALESCE",
+          "COALESCE(COLUMN, NULL, VALUE) = VALUE",
+          Seq(
+            booleanColumns,
+            intColumns,
+            longColumns,
+            strColumns,
+            binaryColumns
+          )
+        ),
+        generateTests(
+          "COALESCE",
+          "ABS(ABS(COALESCE(COLUMN, NULL, VALUE)) - ABS(VALUE)) <= 0.00001",
+          Seq(
+            floatColumns,
+            decimalColumns
+          )
+        ),
+        generateTests(
+          "GREATEST",
+          "GREATEST(COLUMN, VALUE) = VALUE",
+          Seq(
+            booleanColumns,
+            intColumns,
+            longColumns,
+            strColumns,
+            binaryColumns
+          )
+        ),
+        generateTests(
+          "GREATEST",
+          "ABS(ABS(GREATEST(COLUMN, VALUE)) - ABS(VALUE)) <= 0.00001",
+          Seq(
+            floatColumns,
+            decimalColumns
+          )
+        ),
+        generateTests(
+          "LEAST",
+          "LEAST(COLUMN, VALUE) = VALUE",
+          Seq(
+            booleanColumns,
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          )
+        ),
+        generateTests(
+          "LOG10",
+          "ABS(LOG10(ABS(COLUMN)) - LOG10(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "LOG2",
+          "ABS(LOG2(ABS(COLUMN)) - LOG2(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "LN",
+          "ABS(LN(ABS(COLUMN)) - LN(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "EXP",
+          "ABS(EXP(ABS(COLUMN) - ABS(COLUMN)) - EXP(ABS(VALUE) - ABS(VALUE))) 
<= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "POWER",
+          "ABS(POWER(COLUMN, 2) - POWER(VALUE, 2)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "SQRT",
+          "ABS(SQRT(ABS(COLUMN)) - SQRT(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "SIN",
+          "ABS(SIN(COLUMN) - SIN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "COS",
+          "ABS(COS(COLUMN) - COS(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "TAN",
+          "ABS(TAN(COLUMN) - TAN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "COT",
+          "ABS(COT(COLUMN) - COT(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ASIN",
+          "ABS(ASIN(COLUMN/COLUMN) - ASIN(VALUE/VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ACOS",
+          "ABS(ACOS(COLUMN/COLUMN) - ACOS(VALUE/VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ATAN",
+          "ABS(ATAN(COLUMN) - ATAN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ATAN2",
+          "ABS(ATAN2(COLUMN, 1) - ATAN2(VALUE, 1)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "DEGREES",
+          "ABS(DEGREES(COLUMN) - DEGREES(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "RADIANS",
+          "ABS(RADIANS(COLUMN) - RADIANS(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "SIGN",
+          "ABS(SIGN(COLUMN) - SIGN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "UPPER",
+          "UPPER(COLUMN) = VALUE",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "LOWER",
+          "LOWER(COLUMN) = VALUE",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "SHA1",
+          "SHA1(COLUMN) = SHA1(VALUE)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "SHA2",
+          "SHA2(COLUMN, 256) = SHA2(VALUE, 256)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "MD5",
+          "MD5(COLUMN) = MD5(VALUE)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "CRC32",
+          "CRC32(COLUMN) = CRC32(VALUE)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "BIT_LENGTH",
+          "BIT_LENGTH(COLUMN) = BIT_LENGTH(VALUE)",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "CHAR_LENGTH",
+          "CHAR_LENGTH(COLUMN) = CHAR_LENGTH(VALUE)",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "CONCAT",
+          "CONCAT(COLUMN) = VALUE",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        )
+      ).flatten
+

Review Comment:
   But I will also add cast tests for the other changed types. This should be
pretty quick.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to