szehon-ho commented on code in PR #52443:
URL: https://github.com/apache/spark/pull/52443#discussion_r2393517609


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala:
##########
@@ -1196,6 +1197,253 @@ class EnsureRequirementsSuite extends 
SharedSparkSession {
     TransformExpression(BucketFunction, expr, Some(numBuckets))
   }
 
+  test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when they 
are compatible") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, 
plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(
+            leftKeys,
+            rightKeys,
+            _,
+            _,
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            _
+            ) =>
+          assert(leftKeys === Seq(exprA))
+          assert(rightKeys === Seq(exprA))
+        case other => fail(s"We don't expect a shuffle on either side, but 
got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - different partitions") 
{
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Different number of partitions - should add shuffles
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, 
plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), 
_),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides should be shuffled to default partitions
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprA))
+          assert(p2.expressions == Seq(exprB))
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - key position 
mismatch") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Key position mismatch - should add shuffles
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
+      // Join on different keys than partitioning keys
+      val smjExec = SortMergeJoinExec(exprB :: Nil, exprD :: Nil, Inner, None, 
plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides shuffled due to key mismatch
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // ShufflePartitionIdPassThrough vs HashPartitioning - always adds 
shuffles
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, 
plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, _: DummySparkPlan, _), _) =>
+          // Left side shuffled, right side kept as-is
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, _: DummySparkPlan, _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), 
_), _) =>
+          // Right side shuffled, left side kept as-is

Review Comment:
   Agreed — I don't get why the return value should match both cases either.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to