xianzhe-databricks commented on code in PR #52467:
URL: https://github.com/apache/spark/pull/52467#discussion_r2391967264
##########
python/pyspark/sql/tests/connect/test_connect_collection.py:
##########
@@ -291,6 +324,125 @@ def test_collect_nested_type(self):
).collect(),
)
def test_collect_binary_type(self):
    """Check that ``df.collect()`` honours ``spark.sql.execution.pyspark.binaryAsBytes``.

    The binary values are produced by a server-side SQL query, so they go through
    the server-to-client (Arrow) conversion rather than being created locally.
    Both the Spark Connect session and the classic session are exercised. For each
    session and conf value, every returned row must carry the expected Python type
    (``bytes`` for ``true``, ``bytearray`` for ``false``) and the original payload.
    """
    query = """
    SELECT * FROM VALUES
        (CAST('hello' AS BINARY)),
        (CAST('world' AS BINARY))
    AS tab(b)
    """
    expected_payloads = [b"hello", b"world"]

    for conf_value, expected_type in (("true", bytes), ("false", bytearray)):
        with self.both_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}):
            for session_name, session in (("connect", self.connect), ("classic", self.spark)):
                # subTest keeps one failing combination from hiding the others.
                with self.subTest(binaryAsBytes=conf_value, session=session_name):
                    rows = session.sql(query).collect()
                    self.assertEqual(len(rows), 2)
                    for row in rows:
                        self.assertIsInstance(row.b, expected_type)
                    # Also compare the content, so a conversion that yields the right
                    # type but the wrong bytes is caught. Sorted so the check does
                    # not depend on row order.
                    self.assertEqual(sorted(bytes(row.b) for row in rows), expected_payloads)
+
def test_to_local_iterator_binary_type(self):
    """Check that ``df.toLocalIterator()`` honours ``spark.sql.execution.pyspark.binaryAsBytes``.

    Mirrors the ``collect()`` check, but consumes rows through the streaming
    ``toLocalIterator()`` API. The binary data is created by a server-side query.
    Both the Spark Connect and classic sessions must yield rows of the expected
    type (``bytes`` for ``true``, ``bytearray`` for ``false``) with the original
    payload.
    """
    query = """
    SELECT * FROM VALUES
        (CAST('data1' AS BINARY)),
        (CAST('data2' AS BINARY))
    AS tab(b)
    """
    expected_payloads = [b"data1", b"data2"]

    for conf_value, expected_type in (("true", bytes), ("false", bytearray)):
        with self.both_conf({"spark.sql.execution.pyspark.binaryAsBytes": conf_value}):
            for session_name, session in (("connect", self.connect), ("classic", self.spark)):
                # subTest keeps one failing combination from hiding the others.
                with self.subTest(binaryAsBytes=conf_value, session=session_name):
                    # Drain the iterator fully, still inside the conf context, so
                    # every streamed row is checked.
                    rows = list(session.sql(query).toLocalIterator())
                    self.assertEqual(len(rows), 2)
                    for row in rows:
                        self.assertIsInstance(row.b, expected_type)
                    # Check the content, not just the type; sorted so the check
                    # does not depend on row order.
                    self.assertEqual(sorted(bytes(row.b) for row in rows), expected_payloads)
+
def test_foreach_partition_binary_type(self):
    """Check that ``df.foreachPartition()`` honours ``spark.sql.execution.pyspark.binaryAsBytes``.

    ``foreachPartition()`` returns nothing to the driver, so the type check runs
    inside the partition function itself. The test relies on an assertion raised
    there surfacing as a failure of the ``foreachPartition()`` call.
    """
    query = """
    SELECT * FROM VALUES
        (CAST('partition1' AS BINARY)),
        (CAST('partition2' AS BINARY))
    AS tab(b)
    """

    def make_type_checker(expected_type):
        """Return a partition function asserting that every ``row.b`` is ``expected_type``."""
        type_name = expected_type.__name__

        def check_partition(iterator):
            count = 0
            for row in iterator:
                assert isinstance(
                    row.b, expected_type
                ), f"Expected {type_name}, got {type(row.b).__name__}"
                count += 1
            # Fails if this function is handed an empty partition. It does not
            # prove the function was invoked at all.
            assert count > 0, "No rows were processed"

        return check_partition

    # binaryAsBytes=true: both sessions should hand bytes to the function.
    with self.both_conf({"spark.sql.execution.pyspark.binaryAsBytes": "true"}):
        check_bytes = make_type_checker(bytes)
        self.connect.sql(query).foreachPartition(check_bytes)
        self.spark.sql(query).foreachPartition(check_bytes)

    # binaryAsBytes=false: the function should receive bytearray.
    with self.both_conf({"spark.sql.execution.pyspark.binaryAsBytes": "false"}):
        self.connect.sql(query).foreachPartition(make_type_checker(bytearray))
        # TODO: re-enable the same bytearray check on the classic session
        # (self.spark). It currently fails, i.e. binaryAsBytes=false does not seem
        # to reach the classic foreachPartition() path. It stays disabled until
        # the conf propagation is investigated (review discussion on PR #52467).
Review Comment:
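Need to investigate why this line fails: why is the SQL conf (`spark.sql.execution.pyspark.binaryAsBytes=false`)
not propagated to the classic-session `foreachPartition()` call in this particular case?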
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]