Github user fhueske commented on a diff in the pull request:

    https://github.com/apache/flink/pull/4873#discussion_r147142386
  
    --- Diff: 
flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/DecomposeGroupingSetRule.scala
 ---
    @@ -0,0 +1,140 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.flink.table.plan.rules.logical
    +
    +import org.apache.calcite.plan.RelOptRule._
    +import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall}
    +import org.apache.calcite.rel.core.AggregateCall
    +import org.apache.calcite.rel.logical._
    +import org.apache.calcite.rex.RexNode
    +import org.apache.calcite.sql.SqlKind
    +import org.apache.calcite.tools.RelBuilder
    +import org.apache.calcite.util.ImmutableBitSet
    +
    +import scala.collection.JavaConversions._
    +
    +class DecomposeGroupingSetRule
    +  extends RelOptRule(
    +    operand(classOf[LogicalAggregate], any),
    +  "DecomposeGroupingSetRule") {
    +
    +  override def matches(call: RelOptRuleCall): Boolean = {
    +    val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
    +    !agg.getGroupSets.isEmpty &&
    +      
DecomposeGroupingSetRule.getGroupIdExprIndexes(agg.getAggCallList).nonEmpty
    +  }
    +
    +  override def onMatch(call: RelOptRuleCall): Unit = {
    +    val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
    +    val groupIdExprs = 
DecomposeGroupingSetRule.getGroupIdExprIndexes(agg.getAggCallList).toSet
    +
    +    val subAggs = agg.groupSets.map(set =>
    +      DecomposeGroupingSetRule.decompose(call.builder(), agg, 
groupIdExprs, set))
    +
    +    val union = subAggs.reduce((l, r) => new LogicalUnion(
    +      agg.getCluster,
    +      agg.getTraitSet,
    +      Seq(l, r),
    +      true
    +    ))
    +    call.transformTo(union)
    +  }
    +}
    +
    +object DecomposeGroupingSetRule {
    +  val INSTANCE = new DecomposeGroupingSetRule
    +
    +  private def getGroupIdExprIndexes(aggCalls: Seq[AggregateCall]) = {
    +    aggCalls.zipWithIndex.filter { case (call, _) =>
    +        call.getAggregation.getKind match {
    +          case SqlKind.GROUP_ID | SqlKind.GROUPING | SqlKind.GROUPING_ID =>
    +            true
    +          case _ =>
    +            false
    +        }
    +    }.map { case (_, idx) => idx}
    +  }
    +
    +  private def decompose(
    +     relBuilder: RelBuilder,
    +     agg: LogicalAggregate,
    +     groupExprIndexes : Set[Int],
    +     groupSet: ImmutableBitSet
    +  ) = {
    +    val aggsWithIndexes = agg.getAggCallList.zipWithIndex
    +    val subAgg = new LogicalAggregate(
    +      agg.getCluster,
    +      agg.getTraitSet,
    +      agg.getInput,
    +      false,
    +      groupSet,
    +      Seq(),
    +      aggsWithIndexes
    +        .filter { case (_, idx) => !groupExprIndexes.contains(idx) }
    +        .map { case (call, _) => call}
    +    )
    +
    +    val rexBuilder = relBuilder.getRexBuilder
    +    relBuilder.push(subAgg)
    +
    +    val groupingFields = new Array[RexNode](agg.getGroupCount)
    +    val groupingFieldsName = Seq.range(0, agg.getGroupCount).map(
    +      x => agg.getRowType.getFieldNames.get(x)
    +    )
    +    Seq.range(0, agg.getGroupCount).foreach(x =>
    +      groupingFields(x) = rexBuilder.makeNullLiteral(
    +        agg.getRowType.getFieldList.get(x).getType)
    +    )
    +
    +    groupSet.toList.zipWithIndex.foreach { case (group, idx) =>
    +      groupingFields(group) = rexBuilder.makeInputRef(relBuilder.peek(), 
idx)
    +    }
    +
    +    val aggFields = aggsWithIndexes.map { case (call, idx) =>
    +      if (groupExprIndexes.contains(idx)) {
    +        lowerGroupExpr(agg.getCluster, call, groupSet)
    +      } else {
    +        rexBuilder.makeInputRef(subAgg, idx + subAgg.getGroupCount)
    --- End diff --
    
    This will break if there is a group id expression with a smaller index than 
an aggregation function.
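    For example, when we have `SELECT SUM(x), GROUP_ID(), AVG(y)`, `AVG(y)` will 
have index 2 because it is at position 3 in the aggregate call list, but it will 
be at position `groupCnt + 1` in the result of the sub-aggregation, because 
`GROUP_ID()` is filtered out of the sub-aggregate's calls. The input reference 
computed here (`idx + subAgg.getGroupCount`, i.e. `groupCnt + 2`) would therefore 
point one field too far.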


---

Reply via email to