chirag-s-db commented on code in PR #54330:
URL: https://github.com/apache/spark/pull/54330#discussion_r2843582827


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala:
##########
@@ -416,63 +480,43 @@ case class KeyGroupedPartitioning(
     val result = KeyGroupedShuffleSpec(this, distribution)
     if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
       // If allowing join keys to be subset of clustering keys, we should 
create a new
-      // `KeyGroupedPartitioning` here that is grouped on the join keys 
instead, and use that as
+      // `KeyedPartitioning` here that is grouped on the join keys instead, 
and use that as
       // the returned shuffle spec.
       val joinKeyPositions = 
result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
-      val projectedPartitioning = KeyGroupedPartitioning(expressions, 
joinKeyPositions,
-          partitionValues, originalPartitionValues)
+      val (projectedExpressions, projectedDataTypes, projectedKeys, 
projectedOriginalKeys) =
+        projectKeys(joinKeyPositions)
+      val projectedComparableWrapperFactory =
+        
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes)
+      val distinctPartitionKeys = 
projectedKeys.distinctBy(projectedComparableWrapperFactory)
+      val projectedPartitioning = copy(expressions = projectedExpressions,
+        partitionKeys = distinctPartitionKeys, originalPartitionKeys = 
projectedOriginalKeys)
       result.copy(partitioning = projectedPartitioning, joinKeyPositions = 
Some(joinKeyPositions))
     } else {
       result
     }
   }
 
-  lazy val uniquePartitionValues: Seq[InternalRow] = {
-    val internalRowComparableFactory =
-      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
-        expressions.map(_.dataType))
-    partitionValues
-        .map(internalRowComparableFactory)
-        .distinct
-        .map(_.row)
-  }
+  override def equals(that: Any): Boolean = that match {
+    case k: KeyedPartitioning if this.expressions == k.expressions =>

Review Comment:
   Should we compare the expressions semantically (e.g. on their canonicalized forms), so that 
purely cosmetic differences don't make otherwise-equal partitionings compare as unequal?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala:
##########
@@ -416,63 +480,43 @@ case class KeyGroupedPartitioning(
     val result = KeyGroupedShuffleSpec(this, distribution)
     if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
       // If allowing join keys to be subset of clustering keys, we should 
create a new
-      // `KeyGroupedPartitioning` here that is grouped on the join keys 
instead, and use that as
+      // `KeyedPartitioning` here that is grouped on the join keys instead, 
and use that as
       // the returned shuffle spec.
       val joinKeyPositions = 
result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
-      val projectedPartitioning = KeyGroupedPartitioning(expressions, 
joinKeyPositions,
-          partitionValues, originalPartitionValues)
+      val (projectedExpressions, projectedDataTypes, projectedKeys, 
projectedOriginalKeys) =
+        projectKeys(joinKeyPositions)
+      val projectedComparableWrapperFactory =
+        
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes)
+      val distinctPartitionKeys = 
projectedKeys.distinctBy(projectedComparableWrapperFactory)

Review Comment:
   Are we guaranteed that these projected keys will still be in sorted order here? 
Suppose we had 2 partition columns and partitions with keys [0, 1] and [2, 0]. If we 
projected onto only the 2nd expression, yielding [1] and [0], I think the projected keys 
would no longer be in ascending order?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala:
##########
@@ -78,49 +74,65 @@ case class BatchScanExec(
       val newPartitions = scan.toBatch.planInputPartitions()
 
       originalPartitioning match {
-        case p: KeyGroupedPartitioning =>
+        case p: KeyedPartitioning =>
           if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
             throw new SparkException("Data source must have preserved the 
original partitioning " +
                 "during runtime filtering: not all partitions implement 
HasPartitionKey after " +
                 "filtering")
           }
-          val newPartitionValues = newPartitions.map(partition =>
+          val newPartitionKeys = newPartitions.map(partition =>
               
InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], 
p.expressions))
             .toSet
-          val oldPartitionValues = p.partitionValues
+          val oldPartitionKeys = p.partitionKeys
             .map(partition => InternalRowComparableWrapper(partition, 
p.expressions)).toSet
           // We require the new number of partition values to be equal or less 
than the old number
           // of partition values here. In the case of less than, empty 
partitions will be added for
           // those missing values that are not present in the new input 
partitions.
-          if (oldPartitionValues.size < newPartitionValues.size) {
+          if (oldPartitionKeys.size < newPartitionKeys.size) {
             throw new SparkException("During runtime filtering, data source 
must either report " +
                 "the same number of partition values, or a subset of partition 
values from the " +
-                s"original. Before: ${oldPartitionValues.size} partition 
values. " +
-                s"After: ${newPartitionValues.size} partition values")
+                s"original. Before: ${oldPartitionKeys.size} partition values. 
" +
+                s"After: ${newPartitionKeys.size} partition values")
           }
 
-          if (!newPartitionValues.forall(oldPartitionValues.contains)) {
+          if (!newPartitionKeys.forall(oldPartitionKeys.contains)) {
             throw new SparkException("During runtime filtering, data source 
must not report new " +
                 "partition values that are not present in the original 
partitioning.")
           }
 
-          groupPartitions(newPartitions.toImmutableArraySeq)
-            .map(_.groupedParts.map(_.parts)).getOrElse(Seq.empty)
+          val dataTypes = p.expressions.map(_.dataType)

Review Comment:
   Would it be possible to put this logic in `KeyGroupedPartitionedScan`, along 
with the other SPJ (storage-partitioned join) logic for `BatchScanExec`?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala:
##########
@@ -416,63 +480,43 @@ case class KeyGroupedPartitioning(
     val result = KeyGroupedShuffleSpec(this, distribution)
     if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
       // If allowing join keys to be subset of clustering keys, we should 
create a new
-      // `KeyGroupedPartitioning` here that is grouped on the join keys 
instead, and use that as
+      // `KeyedPartitioning` here that is grouped on the join keys instead, 
and use that as
       // the returned shuffle spec.
       val joinKeyPositions = 
result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
-      val projectedPartitioning = KeyGroupedPartitioning(expressions, 
joinKeyPositions,
-          partitionValues, originalPartitionValues)
+      val (projectedExpressions, projectedDataTypes, projectedKeys, 
projectedOriginalKeys) =
+        projectKeys(joinKeyPositions)
+      val projectedComparableWrapperFactory =
+        
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes)
+      val distinctPartitionKeys = 
projectedKeys.distinctBy(projectedComparableWrapperFactory)
+      val projectedPartitioning = copy(expressions = projectedExpressions,
+        partitionKeys = distinctPartitionKeys, originalPartitionKeys = 
projectedOriginalKeys)
       result.copy(partitioning = projectedPartitioning, joinKeyPositions = 
Some(joinKeyPositions))
     } else {
       result
     }
   }
 
-  lazy val uniquePartitionValues: Seq[InternalRow] = {
-    val internalRowComparableFactory =
-      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
-        expressions.map(_.dataType))
-    partitionValues
-        .map(internalRowComparableFactory)
-        .distinct
-        .map(_.row)
-  }
+  override def equals(that: Any): Boolean = that match {
+    case k: KeyedPartitioning if this.expressions == k.expressions =>

Review Comment:
   I suppose not, since the hash code below doesn't take it into account, and `equals` must 
stay consistent with `hashCode` (unless we also hashed the canonicalized expressions and 
whether they're deterministic). Probably not worth the effort...



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala:
##########
@@ -78,49 +74,65 @@ case class BatchScanExec(
       val newPartitions = scan.toBatch.planInputPartitions()
 
       originalPartitioning match {
-        case p: KeyGroupedPartitioning =>
+        case p: KeyedPartitioning =>
           if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
             throw new SparkException("Data source must have preserved the 
original partitioning " +
                 "during runtime filtering: not all partitions implement 
HasPartitionKey after " +
                 "filtering")
           }
-          val newPartitionValues = newPartitions.map(partition =>
+          val newPartitionKeys = newPartitions.map(partition =>
               
InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], 
p.expressions))
             .toSet
-          val oldPartitionValues = p.partitionValues
+          val oldPartitionKeys = p.partitionKeys
             .map(partition => InternalRowComparableWrapper(partition, 
p.expressions)).toSet
           // We require the new number of partition values to be equal or less 
than the old number
           // of partition values here. In the case of less than, empty 
partitions will be added for
           // those missing values that are not present in the new input 
partitions.
-          if (oldPartitionValues.size < newPartitionValues.size) {
+          if (oldPartitionKeys.size < newPartitionKeys.size) {
             throw new SparkException("During runtime filtering, data source 
must either report " +
                 "the same number of partition values, or a subset of partition 
values from the " +
-                s"original. Before: ${oldPartitionValues.size} partition 
values. " +
-                s"After: ${newPartitionValues.size} partition values")
+                s"original. Before: ${oldPartitionKeys.size} partition values. 
" +
+                s"After: ${newPartitionKeys.size} partition values")
           }
 
-          if (!newPartitionValues.forall(oldPartitionValues.contains)) {
+          if (!newPartitionKeys.forall(oldPartitionKeys.contains)) {
             throw new SparkException("During runtime filtering, data source 
must not report new " +
                 "partition values that are not present in the original 
partitioning.")
           }
 
-          groupPartitions(newPartitions.toImmutableArraySeq)
-            .map(_.groupedParts.map(_.parts)).getOrElse(Seq.empty)
+          val dataTypes = p.expressions.map(_.dataType)

Review Comment:
   Ah, I see the `KeyGroupedPartitionedScan` trait is removed. Would it be possible to keep 
the trait, but have it contain only this runtime-filtering logic?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala:
##########
@@ -41,17 +39,16 @@ case class BatchScanExec(
     runtimeFilters: Seq[Expression],
     ordering: Option[Seq[SortOrder]] = None,
     @transient table: Table,
-    spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams()
-  ) extends DataSourceV2ScanExecBase with 
KeyGroupedPartitionedScan[InputPartition] {
+    keyGroupedPartitioning: Option[Seq[Expression]] = None
+  ) extends DataSourceV2ScanExecBase {
 
   @transient lazy val batch: Batch = if (scan == null) null else scan.toBatch
 
   // TODO: unify the equal/hashCode implementation for all data source v2 
query plans.
   override def equals(other: Any): Boolean = other match {
     case other: BatchScanExec =>
       this.batch != null && this.batch == other.batch &&
-          this.runtimeFilters == other.runtimeFilters &&
-          this.spjParams == other.spjParams

Review Comment:
   For my understanding, do we need `keyGroupedPartitioning` in the equality check 
here to distinguish scans with different partition keys (i.e., whether or not one 
partitioning has projected keys)? Or are we safe because we project at the 
`GroupPartitionsExec` level?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala:
##########
@@ -346,43 +348,105 @@ case class CoalescedHashPartitioning(from: 
HashPartitioning, partitions: Seq[Coa
 }
 
 /**
- * Represents a partitioning where rows are split across partitions based on 
transforms defined
- * by `expressions`. `partitionValues`, if defined, should contain value of 
partition key(s) in
- * ascending order, after evaluated by the transforms in `expressions`, for 
each input partition.
- * In addition, its length must be the same as the number of Spark partitions 
(and thus is a 1-1
- * mapping), and each row in `partitionValues` must be unique.
+ * Represents a partitioning where rows are split across partitions based on 
transforms defined by
+ * `expressions`.
+ *
+ * == Partition Keys ==
+ * This partitioning has two sets of partition keys:
+ *
+ * - `partitionKeys`: The current partition key for each partition, in 
ascending order. May contain
+ *   duplicates when first created from a data source, but becomes unique 
after grouping.
+ *
+ * - `originalPartitionKeys`: The original partition keys from the data 
source, in ascending order.
+ *   Always preserves the original values, even after grouping. Used to track 
the original
+ *   distribution for optimization purposes.
  *
- * The `originalPartitionValues`, on the other hand, are partition values from 
the original input
- * splits returned by data sources. It may contain duplicated values.
+ * == Grouping State ==
+ * A KeyedPartitioning can be in two states:
  *
- * For example, if a data source reports partition transform expressions 
`[years(ts_col)]` with 4
- * input splits whose corresponding partition values are `[0, 1, 2, 2]`, then 
the `expressions`
- * in this case is `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, 
which
- * represents 3 input partitions with distinct partition values. All rows in 
each partition have
- * the same value for column `ts_col` (which is of timestamp type), after 
being applied by the
- * `years` transform. This is generated after combining the two splits with 
partition value `2`
- * into a single Spark partition.
+ * - '''Ungrouped''' (when `isGrouped == false`): `partitionKeys` contains 
duplicates. Multiple
+ *   input partitions share the same key. This is the initial state when 
created from a data source.
  *
- * On the other hand, in this example `[0, 1, 2, 2]` is the value of 
`originalPartitionValues`
- * which is calculated from the original input splits.
+ * - '''Grouped''' (when `isGrouped == true`): `partitionKeys` contains only 
unique values. Each
+ *   partition has a distinct key. This state is achieved by applying 
`GroupPartitionsExec`, which
+ *   coalesces partitions with the same key.
  *
- * @param expressions partition expressions for the partitioning.
- * @param numPartitions the number of partitions
- * @param partitionValues the values for the final cluster keys (that is, 
after applying grouping
- *                        on the input splits according to `expressions`) of 
the distribution,
- *                        must be in ascending order, and must NOT contain 
duplicated values.
- * @param originalPartitionValues the original input partition values before 
any grouping has been
- *                                applied, must be in ascending order, and may 
contain duplicated
- *                                values
+ * == Example ==
+ * Consider a data source with partition transform `[years(ts_col)]` and 4 
input splits:
+ *
+ * '''Before GroupPartitionsExec''' (ungrouped):
+ * {{{
+ *   expressions:           [years(ts_col)]
+ *   partitionKeys:         [0, 1, 2, 2]    // partition 2 and 3 have the same 
key
+ *   originalPartitionKeys: [0, 1, 2, 2]
+ *   numPartitions:         4
+ *   isGrouped:             false
+ * }}}
+ *
+ * '''After GroupPartitionsExec''' (grouped):
+ * {{{
+ *   expressions:           [years(ts_col)]
+ *   partitionKeys:         [0, 1, 2]       // duplicates removed, partitions 
coalesced
+ *   originalPartitionKeys: [0, 1, 2, 2]    // unchanged, preserves original 
distribution
+ *   numPartitions:         3
+ *   isGrouped:             true
+ * }}}
+ *
+ * @param expressions Partition transform expressions (e.g., `years(col)`, 
`bucket(10, col)`).
+ * @param partitionKeys Current partition keys, one per partition, in 
ascending order.
+ *                      May contain duplicates before grouping.
+ * @param originalPartitionKeys Original partition keys from the data source, 
in ascending order.
+ *                              Preserves the initial distribution even after 
grouping.
  */
-case class KeyGroupedPartitioning(
+case class KeyedPartitioning(
     expressions: Seq[Expression],
-    numPartitions: Int,
-    partitionValues: Seq[InternalRow] = Seq.empty,
-    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends 
HashPartitioningLike {
+    partitionKeys: Seq[InternalRow],
+    originalPartitionKeys: Seq[InternalRow]) extends Expression with 
Partitioning with Unevaluable {
+  override val numPartitions = partitionKeys.length
+
+  override def children: Seq[Expression] = expressions
+  override def nullable: Boolean = false
+  override def dataType: DataType = IntegerType
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): KeyedPartitioning =
+    copy(expressions = newChildren)
+
+  @transient private lazy val dataTypes: Seq[DataType] = 
expressions.map(_.dataType)
+
+  @transient private lazy val comparableWrapperFactory =
+    
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes)
+
+  @transient private lazy val rowOrdering = 
RowOrdering.createNaturalAscendingOrdering(dataTypes)
+
+  @transient lazy val isGrouped: Boolean = {
+    partitionKeys.map(comparableWrapperFactory).distinct.size == 
partitionKeys.size
+  }
+
+  def toGrouped: KeyedPartitioning = {
+    val distinctSortedPartitionKeys =
+      partitionKeys.distinctBy(comparableWrapperFactory).sorted(rowOrdering)

Review Comment:
   Since the input `partitionKeys` are already sorted in ascending order, won't the output 
of `distinctBy` also be sorted (Scala's `distinctBy` keeps the first occurrence of each key 
in its original order)? If so, do we still need the sort here?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to