Github user fhueske commented on a diff in the pull request: https://github.com/apache/flink/pull/4873#discussion_r147142386 --- Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/DecomposeGroupingSetRule.scala --- @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.RelOptRule._ +import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall} +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.logical._ +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.SqlKind +import org.apache.calcite.tools.RelBuilder +import org.apache.calcite.util.ImmutableBitSet + +import scala.collection.JavaConversions._ + +class DecomposeGroupingSetRule + extends RelOptRule( + operand(classOf[LogicalAggregate], any), + "DecomposeGroupingSetRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate] + !agg.getGroupSets.isEmpty && + DecomposeGroupingSetRule.getGroupIdExprIndexes(agg.getAggCallList).nonEmpty + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate] + val groupIdExprs = DecomposeGroupingSetRule.getGroupIdExprIndexes(agg.getAggCallList).toSet + + val subAggs = agg.groupSets.map(set => + DecomposeGroupingSetRule.decompose(call.builder(), agg, groupIdExprs, set)) + + val union = subAggs.reduce((l, r) => new LogicalUnion( + agg.getCluster, + agg.getTraitSet, + Seq(l, r), + true + )) + call.transformTo(union) + } +} + +object DecomposeGroupingSetRule { + val INSTANCE = new DecomposeGroupingSetRule + + private def getGroupIdExprIndexes(aggCalls: Seq[AggregateCall]) = { + aggCalls.zipWithIndex.filter { case (call, _) => + call.getAggregation.getKind match { + case SqlKind.GROUP_ID | SqlKind.GROUPING | SqlKind.GROUPING_ID => + true + case _ => + false + } + }.map { case (_, idx) => idx} + } + + private def decompose( + relBuilder: RelBuilder, + agg: LogicalAggregate, + groupExprIndexes : Set[Int], + groupSet: ImmutableBitSet + ) = { + val aggsWithIndexes = agg.getAggCallList.zipWithIndex + val subAgg = new LogicalAggregate( + agg.getCluster, + 
agg.getTraitSet, + agg.getInput, + false, + groupSet, + Seq(), + aggsWithIndexes + .filter { case (_, idx) => !groupExprIndexes.contains(idx) } + .map { case (call, _) => call} + ) + + val rexBuilder = relBuilder.getRexBuilder + relBuilder.push(subAgg) + + val groupingFields = new Array[RexNode](agg.getGroupCount) + val groupingFieldsName = Seq.range(0, agg.getGroupCount).map( + x => agg.getRowType.getFieldNames.get(x) + ) + Seq.range(0, agg.getGroupCount).foreach(x => + groupingFields(x) = rexBuilder.makeNullLiteral( + agg.getRowType.getFieldList.get(x).getType) + ) + + groupSet.toList.zipWithIndex.foreach { case (group, idx) => + groupingFields(group) = rexBuilder.makeInputRef(relBuilder.peek(), idx) + } + + val aggFields = aggsWithIndexes.map { case (call, idx) => + if (groupExprIndexes.contains(idx)) { + lowerGroupExpr(agg.getCluster, call, groupSet) + } else { + rexBuilder.makeInputRef(subAgg, idx + subAgg.getGroupCount) --- End diff -- This will break if there is a group id expression with a smaller index than an aggregation function. For example, when we have `SELECT SUM(x), GROUP_ID(), AVG(y)`, `AVG(y)` will have index 2 because it is the third expression, but it will be at position `groupCnt + 1` in the result of the sub-aggregation, because `GROUP_ID()` is filtered out of the sub-aggregation's aggregate calls and every later call shifts left by one.
---