cloud-fan commented on code in PR #52443:
URL: https://github.com/apache/spark/pull/52443#discussion_r2378374215
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala:
##########
@@ -1196,6 +1197,253 @@ class EnsureRequirementsSuite extends
SharedSparkSession {
TransformExpression(BucketFunction, expr, Some(numBuckets))
}
+ test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when they
are compatible") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ val plan1 = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val plan2 = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val smjExec = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None,
plan1, plan2)
+
+ EnsureRequirements.apply(smjExec) match {
+ case SortMergeJoinExec(
+ leftKeys,
+ rightKeys,
+ _,
+ _,
+ SortExec(_, _, DummySparkPlan(_, _, _:
ShufflePartitionIdPassThrough, _, _), _),
+ SortExec(_, _, DummySparkPlan(_, _, _:
ShufflePartitionIdPassThrough, _, _), _),
+ _
+ ) =>
+ assert(leftKeys === Seq(exprA))
+ assert(rightKeys === Seq(exprA))
+ case other => fail(s"We don't expect a shuffle on either side, but
got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough incompatibility - different partitions")
{
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // Different number of partitions - should add shuffles
+ val plan1 = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val plan2 = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
+ val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None,
plan1, plan2)
+
+ EnsureRequirements.apply(smjExec) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _),
_),
+ SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _),
_), _) =>
+ // Both sides should be shuffled to default partitions
+ assert(p1.numPartitions == 10)
+ assert(p2.numPartitions == 10)
+ assert(p1.expressions == Seq(exprA))
+ assert(p2.expressions == Seq(exprB))
+ case other => fail(s"Expected shuffles on both sides, but got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough incompatibility - key position
mismatch") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // Key position mismatch - should add shuffles
+ val plan1 = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val plan2 = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
+ // Join on different keys than partitioning keys
+ val smjExec = SortMergeJoinExec(exprB :: Nil, exprD :: Nil, Inner, None,
plan1, plan2)
Review Comment:
```suggestion
val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC
:: Nil, Inner, None, plan1, plan2)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]