cloud-fan commented on code in PR #52443:
URL: https://github.com/apache/spark/pull/52443#discussion_r2378386945
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala:
##########
@@ -1196,6 +1197,253 @@ class EnsureRequirementsSuite extends
SharedSparkSession {
TransformExpression(BucketFunction, expr, Some(numBuckets))
}
+  test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when they are compatible") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Both children use the same partition-ID expression (exprA) and the same number of
+      // partitions (5), and both are joined on exprA, so no shuffle should be inserted: each
+      // child is expected to be wrapped only in a SortExec.
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(
+            leftKeys,
+            rightKeys,
+            _,
+            _,
+            SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+            SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+            _) =>
+          assert(leftKeys === Seq(exprA))
+          assert(rightKeys === Seq(exprA))
+        case other => fail(s"Expected no shuffle on either side, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - different number of partitions") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Each child is partitioned by its own join key (exprA on the left, exprB on the right),
+      // so the partition keys line up positionally. The only difference is the number of
+      // partitions (5 vs 8), which makes the children incompatible: both sides are expected to
+      // be re-shuffled with HashPartitioning using spark.sql.shuffle.partitions (10).
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+            SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
+            SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
+          // Both sides are re-hashed on their own join key with the default partition count.
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprA))
+          assert(p2.expressions == Seq(exprB))
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - key position mismatch") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Join on (exprA, exprB) = (exprC, exprD). The left child is partitioned by exprA, the
+      // join key at position 0, while the right child is partitioned by exprD, the join key at
+      // position 1. Each partition key is one of its side's join keys, but they sit at
+      // different join-key positions, so rows with equal join keys are not co-partitioned and
+      // both sides are expected to be shuffled.
+      // (Partition keys that are not join keys at all are covered by the
+      // "incompatible when partition key not in join keys" test.)
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprD), 5))
+      val smjExec = SortMergeJoinExec(
+        exprA :: exprB :: Nil, exprC :: exprD :: Nil, Inner, None, plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+            SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
+            SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
+          // Neither pass-through side can be reused, so both fall back to the default count.
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough vs HashPartitioning - " +
+      "only the pass-through side is shuffled") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // The two partitionings are incompatible. HashPartitioning can create a matching
+      // partitioning for the other child, whereas ShufflePartitionIdPassThrough cannot
+      // (canCreatePartitioning = false). So the hash-partitioned right child is expected to be
+      // kept as-is, and the left child is re-shuffled to match the right side's 5 partitions.
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+            SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
+            SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) =>
+          assert(p1.numPartitions == 5)
+          assert(p1.expressions == Seq(exprA))
+        case other =>
+          fail(s"Expected a shuffle only on the pass-through side, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough vs SinglePartition - shuffles added") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      // Both children have exactly one partition, yet shuffles are still expected. The two
+      // partitionings are not considered compatible (otherwise no shuffle would be inserted).
+      // Neither ShufflePartitionIdPassThrough nor SinglePartition can create a partitioning for
+      // the other side (canCreatePartitioning = false), so both sides are re-hashed using
+      // spark.sql.shuffle.partitions (5).
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 1))
+      val plan2 = DummySparkPlan(outputPartitioning = SinglePartition)
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+            SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
+            SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
+          assert(p1.numPartitions == 5)
+          assert(p2.numPartitions == 5)
+          assert(p1.expressions == Seq(exprA))
+          assert(p2.expressions == Seq(exprB))
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - incompatible due to different " +
+      "expressions with same base column") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Both children derive their partition ID from exprA and report 5 partitions, but they
+      // use different Pmod moduli (10 vs 5). The partition-ID expressions therefore differ, and
+      // the children are expected to be treated as incompatible and shuffled.
+      def passThroughOn(modulus: Int): DummySparkPlan = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(
+          DirectShufflePartitionID(Pmod(exprA, Literal(modulus))), 5))
+      val join = SortMergeJoinExec(
+        exprA :: Nil, exprA :: Nil, Inner, None, passThroughOn(10), passThroughOn(5))
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+            SortExec(_, _, ShuffleExchangeExec(leftPart: HashPartitioning, _, _, _), _),
+            SortExec(_, _, ShuffleExchangeExec(rightPart: HashPartitioning, _, _, _), _), _) =>
+          // Each side is re-hashed on its join key with the session's shuffle partition count.
+          Seq(leftPart, rightPart).foreach { part =>
+            assert(part.numPartitions == 10)
+            assert(part.expressions == Seq(exprA))
+          }
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - compatible with multiple clustering keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Both children are partitioned by exprA with 5 partitions and joined on (exprA, exprB).
+      // On both sides the partition key sits at the same join-key position (0), so the
+      // children are expected to be compatible and no shuffle should be inserted.
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val smjExec = SortMergeJoinExec(
+        exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(
+            leftKeys,
+            rightKeys,
+            _,
+            _,
+            SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+            SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+            _) =>
+          assert(leftKeys === Seq(exprA, exprB))
+          assert(rightKeys === Seq(exprA, exprB))
+        case other =>
+          fail("Expected no shuffle on either side with multiple clustering keys, " +
+            s"but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - incompatible when partition key not in join keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // The left child is partitioned by exprA and the right child by exprB, but the join is on
+      // exprC = exprD. Neither partition key appears among the join keys, so both children are
+      // expected to be re-hashed on their own join key.
+      val leftChild = DummySparkPlan(outputPartitioning =
+        ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightChild = DummySparkPlan(outputPartitioning =
+        ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
+      val planned = EnsureRequirements.apply(
+        SortMergeJoinExec(exprC :: Nil, exprD :: Nil, Inner, None, leftChild, rightChild))
+
+      planned match {
+        case SortMergeJoinExec(_, _, _, _,
+            SortExec(_, _, ShuffleExchangeExec(l: HashPartitioning, _, _, _), _),
+            SortExec(_, _, ShuffleExchangeExec(r: HashPartitioning, _, _, _), _), _) =>
+          assert(Seq(l.numPartitions, r.numPartitions) == Seq(10, 10))
+          assert(l.expressions == Seq(exprC))
+          assert(r.expressions == Seq(exprD))
+        case other =>
+          fail(s"Expected shuffles on both sides due to key mismatch, but got: $other")
+      }
+    }
+  }
+
+ test("ShufflePartitionIdPassThrough - cross position matching behavior") {
Review Comment:
This looks the same as `ShufflePartitionIdPassThrough incompatibility - key
position mismatch`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]