cloud-fan commented on code in PR #52443:
URL: https://github.com/apache/spark/pull/52443#discussion_r2378388541


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala:
##########
@@ -1196,6 +1197,253 @@ class EnsureRequirementsSuite extends 
SharedSparkSession {
     TransformExpression(BucketFunction, expr, Some(numBuckets))
   }
 
+  test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when they are 
compatible") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, 
plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(
+            leftKeys,
+            rightKeys,
+            _,
+            _,
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            _
+            ) =>
+          assert(leftKeys === Seq(exprA))
+          assert(rightKeys === Seq(exprA))
+        case other => fail(s"We don't expect a shuffle on either side, but 
got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - different partitions") 
{
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Different number of partitions - should add shuffles
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, 
plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), 
_),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides should be shuffled to default partitions
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprA))
+          assert(p2.expressions == Seq(exprB))
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - key position 
mismatch") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Key position mismatch - should add shuffles
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
+      // Join on different keys than partitioning keys
+      val smjExec = SortMergeJoinExec(exprB :: Nil, exprD :: Nil, Inner, None, 
plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides shuffled due to key mismatch
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // ShufflePartitionIdPassThrough vs HashPartitioning - always adds 
shuffles
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, 
plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, _: DummySparkPlan, _), _) =>
+          // Left side shuffled, right side kept as-is
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, _: DummySparkPlan, _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), 
_), _) =>
+          // Right side shuffled, left side kept as-is
+        case other => fail(s"Expected shuffle on at least one side, but got: 
$other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough vs SinglePartition - shuffles added") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      // Even when compatible (numPartitions=1), shuffles added due to 
canCreatePartitioning=false
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 1))
+      val plan2 = DummySparkPlan(outputPartitioning = SinglePartition)
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, 
plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides shuffled due to canCreatePartitioning = false
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - incompatible due to different 
expressions " +
+       "with same base column") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Even though both use exprA as base and have same numPartitions,
+      // different Pmod operations make them incompatible
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(
+          DirectShufflePartitionID(Pmod(exprA, Literal(10))), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(
+          DirectShufflePartitionID(Pmod(exprA, Literal(5))), 5))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, 
plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), 
_),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides should be shuffled due to expression mismatch
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprA))
+          assert(p2.expressions == Seq(exprA))
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - compatible with multiple clustering 
keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Both partitioned by exprA, joined on (exprA, exprB)
+      // Should be compatible because exprA positions overlap
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: 
Nil, Inner, None,
+        plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(
+            leftKeys,
+            rightKeys,
+            _,
+            _,
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            _
+            ) =>
+          assert(leftKeys === Seq(exprA, exprB))
+          assert(rightKeys === Seq(exprA, exprB))
+        case other => fail(s"We don't expect a shuffle on either side with 
multiple " +
+          s"clustering keys, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - incompatible when partition key not in 
join keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Partitioned by exprA and exprB respectively, but joining on 
completely different keys
+      // Should require shuffles because partition keys don't match join keys
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
+      val smjExec = SortMergeJoinExec(exprC :: Nil, exprD :: Nil, Inner, None, 
plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), 
_),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides should be shuffled because partition keys not in join 
keys
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprC))
+          assert(p2.expressions == Seq(exprD))
+        case other => fail(s"Expected shuffles on both sides due to key 
mismatch, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - cross position matching behavior") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Left partitioned by exprA, right partitioned by exprB
+      // Both sides join on (exprA, exprB)
+      // Test if cross-position matching works: left partition key exprA 
matches right join key
+      // exprA (pos 0)
+      // and right partition key exprB matches left join key exprB (pos 1)
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
+      val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: 
Nil, Inner, None,
+        plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), 
_),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), 
_), _) =>
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprA, exprB))
+          assert(p2.expressions == Seq(exprA, exprB))
+        case other => fail(s"Expected either no shuffles (if compatible) or 
shuffles on " +
+          s"both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - compatible when partition key matches 
at any position") {

Review Comment:
   can we merge this test case into `ShufflePartitionIdPassThrough - compatible 
with multiple clustering keys`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to