szehon-ho commented on code in PR #52443:
URL: https://github.com/apache/spark/pull/52443#discussion_r2450017594
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala:
##########
@@ -1196,6 +1197,191 @@ class EnsureRequirementsSuite extends
SharedSparkSession {
TransformExpression(BucketFunction, expr, Some(numBuckets))
}
+ test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when
children are compatible") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ val passThrough_a_5 =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+
+ val leftPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val rightPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val join = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None,
leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(
+ leftKeys,
+ rightKeys,
+ _,
+ _,
+ SortExec(_, _, DummySparkPlan(_, _, _:
ShufflePartitionIdPassThrough, _, _), _),
+ SortExec(_, _, DummySparkPlan(_, _, _:
ShufflePartitionIdPassThrough, _, _), _),
+ _
+ ) =>
+ assert(leftKeys === Seq(exprA))
+ assert(rightKeys === Seq(exprA))
+ case other => fail(s"We don't expect shuffle on neither sides, but
got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough incompatibility - different partitions")
{
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // Different number of partitions - should add shuffles
+ val leftPlan = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val rightPlan = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
+ val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None,
leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _),
_),
+ SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _),
_), _) =>
+ // Both sides should be shuffled to default partitions
+ assert(p1.numPartitions == 10)
+ assert(p2.numPartitions == 10)
+ assert(p1.expressions == Seq(exprA))
+ assert(p2.expressions == Seq(exprB))
+ case other => fail(s"Expected shuffles on both sides, but got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough incompatibility - key position
mismatch") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // Key position mismatch - should add shuffles
+ val leftPlan = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val rightPlan = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
+ // Join on different keys than partitioning keys
+ val join = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC ::
Nil, Inner, None,
+ leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _),
_), _) =>
+ // Both sides shuffled due to key mismatch
+ case other => fail(s"Expected shuffles on both sides, but got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // ShufflePartitionIdPassThrough vs HashPartitioning - always adds
shuffles
+ val leftPlan = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val rightPlan = DummySparkPlan(
+ outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+ val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None,
leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+ SortExec(_, _, _: DummySparkPlan, _), _) =>
+ // Left side shuffled, right side kept as-is
+ case other => fail(s"Expected shuffle on at least one side, but got:
$other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough vs SinglePartition - shuffles added") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+ // Even when compatible (numPartitions=1), shuffles added due to
canCreatePartitioning=false
+ val leftPlan = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 1))
+ val rightPlan = DummySparkPlan(outputPartitioning = SinglePartition)
+ val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None,
leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _),
_), _) =>
+ // Both sides shuffled due to canCreatePartitioning = false
+ case other => fail(s"Expected shuffles on both sides, but got: $other")
+ }
+ }
+ }
+
+
+ test("ShufflePartitionIdPassThrough - compatible with multiple clustering
keys") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ val passThrough_a_5 =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+ val passThrough_b_5 =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)
+
+ // Both partitioned by exprA, joined on (exprA, exprB)
+ // Should be compatible because exprA positions overlap
+ val leftPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val rightPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val joinA = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB ::
Nil, Inner, None,
+ leftPlanA, rightPlanA)
+
+ EnsureRequirements.apply(joinA) match {
+ case SortMergeJoinExec(
+ leftKeys,
+ rightKeys,
+ _,
+ _,
+ SortExec(_, _, DummySparkPlan(_, _, _:
ShufflePartitionIdPassThrough, _, _), _),
+ SortExec(_, _, DummySparkPlan(_, _, _:
ShufflePartitionIdPassThrough, _, _), _),
+ _
+ ) =>
+ assert(leftKeys === Seq(exprA, exprB))
+ assert(rightKeys === Seq(exprA, exprB))
+ case other => fail(s"We don't expect shuffle on neither sides with
multiple " +
Review Comment:
nit: same (on either side)
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala:
##########
@@ -1196,6 +1197,191 @@ class EnsureRequirementsSuite extends
SharedSparkSession {
TransformExpression(BucketFunction, expr, Some(numBuckets))
}
+ test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when
children are compatible") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ val passThrough_a_5 =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+
+ val leftPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val rightPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val join = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None,
leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(
+ leftKeys,
+ rightKeys,
+ _,
+ _,
+ SortExec(_, _, DummySparkPlan(_, _, _:
ShufflePartitionIdPassThrough, _, _), _),
+ SortExec(_, _, DummySparkPlan(_, _, _:
ShufflePartitionIdPassThrough, _, _), _),
+ _
+ ) =>
+ assert(leftKeys === Seq(exprA))
+ assert(rightKeys === Seq(exprA))
+ case other => fail(s"We don't expect shuffle on neither sides, but
got: $other")
Review Comment:
   nit: 'on either side' (it's a double negative)
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala:
##########
@@ -1196,6 +1197,191 @@ class EnsureRequirementsSuite extends
SharedSparkSession {
TransformExpression(BucketFunction, expr, Some(numBuckets))
}
+ test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when
children are compatible") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ val passThrough_a_5 =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+
+ val leftPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val rightPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+ val join = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None,
leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(
+ leftKeys,
+ rightKeys,
+ _,
+ _,
+ SortExec(_, _, DummySparkPlan(_, _, _:
ShufflePartitionIdPassThrough, _, _), _),
+ SortExec(_, _, DummySparkPlan(_, _, _:
ShufflePartitionIdPassThrough, _, _), _),
+ _
+ ) =>
+ assert(leftKeys === Seq(exprA))
+ assert(rightKeys === Seq(exprA))
+ case other => fail(s"We don't expect shuffle on neither sides, but
got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough incompatibility - different partitions")
{
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // Different number of partitions - should add shuffles
+ val leftPlan = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val rightPlan = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
+ val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None,
leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _),
_),
+ SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _),
_), _) =>
+ // Both sides should be shuffled to default partitions
+ assert(p1.numPartitions == 10)
+ assert(p2.numPartitions == 10)
+ assert(p1.expressions == Seq(exprA))
+ assert(p2.expressions == Seq(exprB))
+ case other => fail(s"Expected shuffles on both sides, but got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough incompatibility - key position
mismatch") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // Key position mismatch - should add shuffles
+ val leftPlan = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val rightPlan = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
+ // Join on different keys than partitioning keys
+ val join = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC ::
Nil, Inner, None,
+ leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _),
_), _) =>
+ // Both sides shuffled due to key mismatch
+ case other => fail(s"Expected shuffles on both sides, but got: $other")
+ }
+ }
+ }
+
+ test("ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ // ShufflePartitionIdPassThrough vs HashPartitioning - always adds
shuffles
+ val leftPlan = DummySparkPlan(
+ outputPartitioning =
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+ val rightPlan = DummySparkPlan(
+ outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+ val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None,
leftPlan, rightPlan)
+
+ EnsureRequirements.apply(join) match {
+ case SortMergeJoinExec(_, _, _, _,
+ SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+ SortExec(_, _, _: DummySparkPlan, _), _) =>
+ // Left side shuffled, right side kept as-is
+ case other => fail(s"Expected shuffle on at least one side, but got:
$other")
Review Comment:
nit: just say 'left side'
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]