sunxiaoguang commented on code in PR #49335:
URL: https://github.com/apache/spark/pull/49335#discussion_r1910491403
##########
connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala:
##########
@@ -241,6 +241,323 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
     assert(rows10(0).getString(0) === "amy")
     assert(rows10(1).getString(0) === "alex")
   }
+
+  test("SPARK-50704: Test SQL function push down with different types and casts in WHERE clause") {
+    withTable(s"$catalogName.test_pushdown") {
+      // Define test values for different data types
+      val boolean = true
+      val int = 1
+      val long = 0x1_ff_ff_ff_ffL
+      val str = "TeSt SpArK"
+      val float = 0.123
+      val binary = "X'123456'"
+      val decimal = "-.001234567E+2BD"
+      val tableName = "test_pushdown"
+
+      // Create a table with various data types
+      sql(s"""CREATE TABLE $catalogName.$tableName (
+        boolean_col BOOLEAN, byte_col BYTE, tinyint_col TINYINT, short_col SHORT,
+        smallint_col SMALLINT, int_col INT, integer_col INTEGER, long_col LONG,
+        bigint_col BIGINT, float_col FLOAT, real_col REAL, double_col DOUBLE,
+        str_col STRING, binary_col BINARY, decimal_col DECIMAL(10, 7), dec_col DEC(10, 7),
+        numeric_col NUMERIC(10, 7))""")
+
+      // Insert test values into the table
+      sql(s"""INSERT INTO $catalogName.$tableName VALUES ($boolean, $int, $int, $int,
+        $int, $int, $int, $long, $long, $float, $float, $float, '$str', $binary, $decimal,
+        $decimal, $decimal)""")
+
+      // Helper function to generate test cases for a given function and columns
+      def generateTests(
+          function: String,
+          template: String,
+          columns: Seq[(Seq[String], Any)],
+          valueTransformer: Option[Any => String] = None
+      ): Seq[(String, String)] = {
+        columns.flatMap { case (cols, value) =>
+          val valueLiteral =
+            valueTransformer
+              .map(transform => transform(value))
+              .getOrElse(value.toString)
+          cols.map(column => {
+            (
+              function,
+              template
+                .replaceAll("COLUMN", column)
+                .replaceAll("VALUE", valueLiteral)
+            )
+          })
+        }
+      }
+
+      // Helper function to convert a value to a string literal
+      def toStringLiteral(any: Any): String = any match {
+        case stringValue: String =>
+          if (stringValue == decimal) "'-0.1234567'" else stringValue
+        case _ => s"'$any'"
+      }
+
+      // Define columns and their corresponding test values
+      var booleanColumns = (Seq("boolean_col"), boolean)
+      var intColumns = (
+        Seq(
+          "byte_col",
+          "tinyint_col",
+          "short_col",
+          "smallint_col",
+          "int_col",
+          "integer_col"
+        ),
+        int
+      )
+      var longColumns = (Seq("long_col", "bigint_col"), long)
+      var floatColumns = (Seq("float_col", "real_col", "double_col"), float)
+      var strColumns = (Seq("str_col"), s"'${str}'")
+      var binaryColumns = (Seq("binary_col"), binary)
+      var decimalColumns = (Seq("decimal_col", "dec_col", "numeric_col"), decimal)
+
+      // Generate test cases for various functions
+      val functions = Seq(
+        generateTests(
+          "ABS",
+          "ABS(COLUMN) = ABS(VALUE)",
+          Seq(intColumns, longColumns)
+        ),
+        generateTests(
+          "ABS",
+          "ABS(ABS(COLUMN) - ABS(VALUE)) <= 0.00001",
+          Seq(floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "COALESCE",
+          "COALESCE(COLUMN, NULL, VALUE) = VALUE",
+          Seq(
+            booleanColumns,
+            intColumns,
+            longColumns,
+            strColumns,
+            binaryColumns
+          )
+        ),
+        generateTests(
+          "COALESCE",
+          "ABS(ABS(COALESCE(COLUMN, NULL, VALUE)) - ABS(VALUE)) <= 0.00001",
+          Seq(
+            floatColumns,
+            decimalColumns
+          )
+        ),
+        generateTests(
+          "GREATEST",
+          "GREATEST(COLUMN, VALUE) = VALUE",
+          Seq(
+            booleanColumns,
+            intColumns,
+            longColumns,
+            strColumns,
+            binaryColumns
+          )
+        ),
+        generateTests(
+          "GREATEST",
+          "ABS(ABS(GREATEST(COLUMN, VALUE)) - ABS(VALUE)) <= 0.00001",
+          Seq(
+            floatColumns,
+            decimalColumns
+          )
+        ),
+        generateTests(
+          "LEAST",
+          "LEAST(COLUMN, VALUE) = VALUE",
+          Seq(
+            booleanColumns,
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          )
+        ),
+        generateTests(
+          "LOG10",
+          "ABS(LOG10(ABS(COLUMN)) - LOG10(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "LOG2",
+          "ABS(LOG2(ABS(COLUMN)) - LOG2(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "LN",
+          "ABS(LN(ABS(COLUMN)) - LN(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "EXP",
+          "ABS(EXP(ABS(COLUMN) - ABS(COLUMN)) - EXP(ABS(VALUE) - ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "POWER",
+          "ABS(POWER(COLUMN, 2) - POWER(VALUE, 2)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "SQRT",
+          "ABS(SQRT(ABS(COLUMN)) - SQRT(ABS(VALUE))) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "SIN",
+          "ABS(SIN(COLUMN) - SIN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "COS",
+          "ABS(COS(COLUMN) - COS(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "TAN",
+          "ABS(TAN(COLUMN) - TAN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "COT",
+          "ABS(COT(COLUMN) - COT(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ASIN",
+          "ABS(ASIN(COLUMN/COLUMN) - ASIN(VALUE/VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ACOS",
+          "ABS(ACOS(COLUMN/COLUMN) - ACOS(VALUE/VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ATAN",
+          "ABS(ATAN(COLUMN) - ATAN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "ATAN2",
+          "ABS(ATAN2(COLUMN, 1) - ATAN2(VALUE, 1)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "DEGREES",
+          "ABS(DEGREES(COLUMN) - DEGREES(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "RADIANS",
+          "ABS(RADIANS(COLUMN) - RADIANS(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "SIGN",
+          "ABS(SIGN(COLUMN) - SIGN(VALUE)) <= 0.00001",
+          Seq(intColumns, longColumns, floatColumns, decimalColumns)
+        ),
+        generateTests(
+          "UPPER",
+          "UPPER(COLUMN) = VALUE",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "LOWER",
+          "LOWER(COLUMN) = VALUE",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "SHA1",
+          "SHA1(COLUMN) = SHA1(VALUE)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "SHA2",
+          "SHA2(COLUMN, 256) = SHA2(VALUE, 256)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "MD5",
+          "MD5(COLUMN) = MD5(VALUE)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "CRC32",
+          "CRC32(COLUMN) = CRC32(VALUE)",
+          Seq(strColumns, binaryColumns)
+        ),
+        generateTests(
+          "BIT_LENGTH",
+          "BIT_LENGTH(COLUMN) = BIT_LENGTH(VALUE)",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "CHAR_LENGTH",
+          "CHAR_LENGTH(COLUMN) = CHAR_LENGTH(VALUE)",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        ),
+        generateTests(
+          "CONCAT",
+          "CONCAT(COLUMN) = VALUE",
+          Seq(
+            intColumns,
+            longColumns,
+            floatColumns,
+            strColumns,
+            binaryColumns,
+            decimalColumns
+          ),
+          Some(toStringLiteral)
+        )
+      ).flatten
+

Review Comment:
   Added new tests for cast, PTAL

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org