szehon-ho commented on code in PR #52443:
URL: https://github.com/apache/spark/pull/52443#discussion_r2383910676


##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala:
##########
@@ -479,4 +480,72 @@ class ShuffleSpecSuite extends SparkFunSuite with 
SQLHelper {
         "methodName" -> "createPartitioning$",
         "className" -> 
"org.apache.spark.sql.catalyst.plans.physical.ShuffleSpec"))
   }
+
+  test("compatibility: ShufflePartitionIdPassThroughSpec on both sides") {
+    val dist = ClusteredDistribution(Seq($"a", $"b"))
+    val p1 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10)
+    val p2 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10)
+
+    // Identical specs should be compatible
+    checkCompatible(
+      p1.createShuffleSpec(dist),
+      p2.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))),
+      expected = true
+    )
+
+    // Different number of partitions should be incompatible
+    val p3 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 5)
+    checkCompatible(
+      p1.createShuffleSpec(dist),
+      p3.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))),
+      expected = false
+    )
+
+    // Mismatched key positions should be incompatible
+    val dist1 = ClusteredDistribution(Seq($"a", $"b"))
+    val p4 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"b"), 10) 
// Key at pos 1
+    val dist2 = ClusteredDistribution(Seq($"c", $"d"))
+    val p5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10) 
// Key at pos 0
+    checkCompatible(
+      p4.createShuffleSpec(dist1),
+      p5.createShuffleSpec(dist2),
+      expected = false
+    )
+
+    // Mismatched clustering keys
+    val dist3 = ClusteredDistribution(Seq($"e", $"b"))
+    checkCompatible(
+      p1.createShuffleSpec(dist3),
+      p2.createShuffleSpec(dist),
+      expected = false
+    )
+  }
+
+  test("compatibility: ShufflePartitionIdPassThroughSpec vs other specs") {
+    val dist = ClusteredDistribution(Seq($"a", $"b"))
+    val p = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10)

Review Comment:
   nit: should we just define this right above where it's used (the second check)? I 
think it's not used in the first check below.
   
   Also, maybe we could just make `p.createShuffleSpec(dist)` a variable, so that we 
don't have to define `p` at all.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala:
##########
@@ -966,6 +969,51 @@ object KeyGroupedShuffleSpec {
   }
 }
 
+case class ShufflePartitionIdPassThroughSpec(
+    partitioning: ShufflePartitionIdPassThrough,
+    distribution: ClusteredDistribution) extends ShuffleSpec {
+
+  /**
+   * A sequence where each element is a set of positions of the partition key 
to the cluster
+   * keys. Similar to HashShuffleSpec, this maps the partitioning expression 
to positions
+   * in the distribution clustering keys.
+   */
+  lazy val keyPositions: mutable.BitSet = {
+    val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
+    distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) 
=>
+      distKeyToPos.getOrElseUpdate(distKey.canonicalized, 
mutable.BitSet.empty).add(distKeyPos)
+    }
+    distKeyToPos.getOrElse(partitioning.expr.child.canonicalized, 
mutable.BitSet.empty)
+  }
+
+  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
+    case SinglePartitionShuffleSpec =>
+      partitioning.numPartitions == 1
+    case otherPassThroughSpec @ ShufflePartitionIdPassThroughSpec(
+        otherPartitioning, otherDistribution) =>
+      // As ShufflePartitionIdPassThrough only allows a single expression
+      // as the partitioning expression, we check compatibility as follows:
+      // 1. Same number of clustering expressions
+      // 2. Same number of partitions
+      // 3. each pair of partitioning expression from both sides has 
overlapping positions in their

Review Comment:
   nit: since each side has only one partitioning expression, can we delete 'each pair'?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to