sunxiaoguang commented on code in PR #49453: URL: https://github.com/apache/spark/pull/49453#discussion_r1917621813
########## connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala: ########## @@ -241,6 +241,56 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest assert(rows10(0).getString(0) === "amy") assert(rows10(1).getString(0) === "alex") } + + test("SPARK-50793: MySQL JDBC Connector failed to cast some types") { + val tableName = catalogName + ".test_cast_function" + withTable(tableName) { + val stringValue = "0" + val stringLiteral = "'0'" + val longValue = 0L + val binaryValue = Array[Byte](0x30) + val binaryLiteral = "x'30'" + val doubleValue = 0.0 + val doubleLiteral = "0.0" + // CREATE table to use types defined in Spark SQL + sql(s"""CREATE TABLE $tableName ( + string_col STRING, + long_col LONG, + binary_col BINARY, + double_col DOUBLE + )""") + sql( + s"INSERT INTO $tableName VALUES($stringLiteral, $longValue, $binaryLiteral, $doubleValue)") + + def testCast(castType: String, sourceCol: String, targetCol: String, + sourceValue: Any, targetValue: Any): Unit = { + val sql = + s"""SELECT $sourceCol, CAST($sourceCol AS $castType) FROM $tableName + |WHERE CAST($sourceCol AS $castType) = $targetCol""".stripMargin + val df = spark.sql(sql) + checkFilterPushed(df) + val rows = df.collect() Review Comment: After taking a look at the checkAnswer implementation, it is using `==` to compare any types other than those that need special handling. This means the check may not verify the actual runtime types, so let's double-check whether behavior like this is acceptable. 
```scala
val i = 1
val s = 1.toShort
val l = 1L
println(i == s)
println(i == l)
println(s == l)
```
The output for these lines of code is:
```
true
true
true
```
FYI: This is the implementation checkAnswer finally calls to compare:
```scala
def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
  case (null, null) => true
  case (null, _) => false
  case (_, null) => false
  case (a: Array[_], b: Array[_]) =>
    a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r) }
  case (a: Map[_, _], b: Map[_, _]) =>
    a.size == b.size && a.keys.forall { aKey =>
      b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey)))
    }
  case (a: Iterable[_], b: Iterable[_]) =>
    a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r) }
  case (a: Product, b: Product) =>
    compare(a.productIterator.toSeq, b.productIterator.toSeq)
  case (a: Row, b: Row) =>
    compare(a.toSeq, b.toSeq)
  // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0.
  case (a: Double, b: Double) =>
    java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b)
  case (a: Float, b: Float) =>
    java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b)
  case (a, b) => a == b
}
```
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org