[ https://issues.apache.org/jira/browse/HIVE-23434?focusedWorklogId=433181&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-433181 ]
ASF GitHub Bot logged work on HIVE-23434: ----------------------------------------- Author: ASF GitHub Bot Created on: 14/May/20 15:17 Start Date: 14/May/20 15:17 Worklog Time Spent: 10m Work Description: jcamachor commented on a change in pull request #1017: URL: https://github.com/apache/hive/pull/1017#discussion_r424859544 ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. 
+ * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); Review comment: typo. 
filedNames -> fieldNames ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import 
org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. + * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + 
+ @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes Review comment: typo. didn't made -> didn't make ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. + * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> Review comment: Thanks for adding this comment Zoltan, this is useful. Could we add the simple SQL examples too, e.g., those in the design doc? The rewritings are complex and I think it will be very useful if anyone needs to understand what is supported and how they are working. 
########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import 
org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. + * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + 
+ @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); + for (int i=0;i<vb.newProjectsBelow.size();i++ ) { + filedNames.add("ff_"+i); + } + RelNode newProjectBelow= + projectFactory.createProject(aggregate.getInput(), vb.newProjectsBelow, filedNames); + + RelNode newAgg = aggregate.copy(aggregate.getTraitSet(), newProjectBelow, aggregate.getGroupSet(), + aggregate.getGroupSets(), newAggCalls); + + RelNode newProject = projectFactory.createProject(newAgg, vb.newProjects, aggregate.getRowType().getFieldNames()); + + call.transformTo(newProject); + return; + } + + /** + * Helper class to help in building a new Aggregate and Project. 
+ */ + // NOTE: methods in this class are not re-entrant; drop-to-frame to constructor during debugging + private class VBuilder { + + private final RexBuilder rexBuilder; + + private Aggregate aggregate; + private List<AggregateCall> newAggCalls; + private List<RexNode> newProjects; + private List<RexNode> newProjectsBelow; + private List<RewriteProcedure> rewrites; + + public VBuilder(Aggregate aggregate) { + this.aggregate = aggregate; + newAggCalls = new ArrayList<AggregateCall>(); + newProjects = new ArrayList<RexNode>(); + newProjectsBelow = new ArrayList<RexNode>(); + rexBuilder = aggregate.getCluster().getRexBuilder(); + rewrites = new ArrayList<RewriteProcedure>(); + + // add identity projections + addProjectedFields(); + + if (countDistinctSketchType.isPresent()) { + rewrites.add(new CountDistinctRewrite(countDistinctSketchType.get())); + } + if (percentileContSketchType.isPresent()) { + rewrites.add(new PercentileContRewrite(percentileContSketchType.get())); + } + + for (AggregateCall aggCall : aggregate.getAggCallList()) { + processAggCall(aggCall); + } + } + + private void addProjectedFields() { + for (int i = 0; i < aggregate.getGroupCount(); i++) { + newProjects.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + Builder b = ImmutableBitSet.builder(); + b.addAll(aggregate.getGroupSet()); + for (AggregateCall aggCall: aggregate.getAggCallList()) { + b.addAll(aggCall.getArgList()); + } + ImmutableBitSet inputs = b.build(); + Integer maxIdx = Collections.max(inputs.asSet()); + for (int i = 0; i < maxIdx; i++) { Review comment: Not sure what the intention is here. Are you trying to add all existing projections below (as an identity)? If that is the case, you could rely on the number of fields in the `rowType` of the `aggregate.getInput()`. 
########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import 
org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. + * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + 
+ @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); + for (int i=0;i<vb.newProjectsBelow.size();i++ ) { + filedNames.add("ff_"+i); + } + RelNode newProjectBelow= + projectFactory.createProject(aggregate.getInput(), vb.newProjectsBelow, filedNames); + + RelNode newAgg = aggregate.copy(aggregate.getTraitSet(), newProjectBelow, aggregate.getGroupSet(), + aggregate.getGroupSets(), newAggCalls); + + RelNode newProject = projectFactory.createProject(newAgg, vb.newProjects, aggregate.getRowType().getFieldNames()); + + call.transformTo(newProject); + return; + } + + /** + * Helper class to help in building a new Aggregate and Project. 
+ */ + // NOTE: methods in this class are not re-entrant; drop-to-frame to constructor during debugging + private class VBuilder { + + private final RexBuilder rexBuilder; + + private Aggregate aggregate; + private List<AggregateCall> newAggCalls; + private List<RexNode> newProjects; + private List<RexNode> newProjectsBelow; + private List<RewriteProcedure> rewrites; + + public VBuilder(Aggregate aggregate) { + this.aggregate = aggregate; + newAggCalls = new ArrayList<AggregateCall>(); + newProjects = new ArrayList<RexNode>(); + newProjectsBelow = new ArrayList<RexNode>(); + rexBuilder = aggregate.getCluster().getRexBuilder(); + rewrites = new ArrayList<RewriteProcedure>(); + + // add identity projections + addProjectedFields(); + + if (countDistinctSketchType.isPresent()) { + rewrites.add(new CountDistinctRewrite(countDistinctSketchType.get())); + } + if (percentileContSketchType.isPresent()) { + rewrites.add(new PercentileContRewrite(percentileContSketchType.get())); + } + + for (AggregateCall aggCall : aggregate.getAggCallList()) { + processAggCall(aggCall); + } + } + + private void addProjectedFields() { + for (int i = 0; i < aggregate.getGroupCount(); i++) { + newProjects.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + Builder b = ImmutableBitSet.builder(); + b.addAll(aggregate.getGroupSet()); + for (AggregateCall aggCall: aggregate.getAggCallList()) { + b.addAll(aggCall.getArgList()); + } + ImmutableBitSet inputs = b.build(); + Integer maxIdx = Collections.max(inputs.asSet()); + for (int i = 0; i < maxIdx; i++) { + newProjectsBelow.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + } + + private void processAggCall(AggregateCall aggCall) { + for (RewriteProcedure rewrite : rewrites) { + if (rewrite.isApplicable(aggCall)) { + rewrite.rewrite(aggCall); + return; + } + } + appendAggCall(aggCall); + } + + private void appendAggCall(AggregateCall aggCall) { + RexNode projRex = rexBuilder.makeInputRef(aggCall.getType(), newProjects.size()); + 
+ newAggCalls.add(aggCall); + newProjects.add(projRex); + } + + abstract class RewriteProcedure { + + private final String sketchClass; + + public RewriteProcedure(String sketchClass) { + this.sketchClass = sketchClass; + } + + abstract boolean isApplicable(AggregateCall aggCall); + abstract void rewrite(AggregateCall aggCall); + + protected SqlOperator getSqlOperator(String fnName) { + UDFDescriptor fn = DataSketchesFunctions.INSTANCE.getSketchFunction(sketchClass, fnName); + if (!fn.getCalciteFunction().isPresent()) { + throw new RuntimeException(fn.toString() + " doesn't have a Calcite function associated with it"); + } + return fn.getCalciteFunction().get(); + } + + } + + class CountDistinctRewrite extends RewriteProcedure { + + public CountDistinctRewrite(String sketchClass) { + super(sketchClass); + } + + @Override + boolean isApplicable(AggregateCall aggCall) { + return aggCall.isDistinct() && aggCall.getArgList().size() == 1 + && aggCall.getAggregation().getName().equalsIgnoreCase("count") && !aggCall.hasFilter(); + } + + @Override + void rewrite(AggregateCall aggCall) { + RelDataType origType = aggregate.getRowType().getFieldList().get(newProjects.size()).getType(); + + Integer argIndex = aggCall.getArgList().get(0); + RexNode call = rexBuilder.makeInputRef(aggregate.getInput(), argIndex); + newProjectsBelow.add(call); + + ArrayList<Integer> newArgList = Lists.newArrayList(newProjectsBelow.size() - 1); + + SqlAggFunction aggFunction = (SqlAggFunction) getSqlOperator(DataSketchesFunctions.DATA_TO_SKETCH); + boolean distinct = false; + boolean approximate = true; + boolean ignoreNulls = true; + List<Integer> argList = newArgList; + int filterArg = aggCall.filterArg; + RelCollation collation = aggCall.getCollation(); + RelDataType type = rexBuilder.deriveReturnType(aggFunction, Collections.emptyList()); + String name = aggFunction.getName(); + + AggregateCall newAgg = AggregateCall.create(aggFunction, distinct, approximate, ignoreNulls, argList, filterArg, + 
collation, type, name); + + SqlOperator projectOperator = getSqlOperator(DataSketchesFunctions.SKETCH_TO_ESTIMATE); + RexNode projRex = rexBuilder.makeInputRef(newAgg.getType(), newProjects.size()); + projRex = rexBuilder.makeCall(projectOperator, ImmutableList.of(projRex)); + projRex = rexBuilder.makeCall(SqlStdOperatorTable.ROUND, ImmutableList.of(projRex)); + projRex = rexBuilder.makeCast(origType, projRex); + + newAggCalls.add(newAgg); + newProjects.add(projRex); + } + + } + + class PercentileContRewrite extends RewriteProcedure { + + public PercentileContRewrite(String sketchClass) { + super(sketchClass); + } + + @Override + boolean isApplicable(AggregateCall aggCall) { + // FIXME: also check that args are: ?,?,1,0 - other cases are not supported + return !aggCall.isDistinct() && aggCall.getArgList().size() == 4 + && aggCall.getAggregation().getName().equalsIgnoreCase("percentile_cont") && !aggCall.hasFilter(); + } + + @Override + void rewrite(AggregateCall aggCall) { + RelDataType origType = aggregate.getRowType().getFieldList().get(newProjects.size()).getType(); + + Integer argIndex = aggCall.getArgList().get(1); + RexNode call = rexBuilder.makeInputRef(aggregate.getInput(), argIndex); + + RelDataType floatType = rexBuilder.getTypeFactory().createSqlType(SqlTypeName.FLOAT); + call = rexBuilder.makeCast(floatType, call); + newProjectsBelow.add(call); + + ArrayList<Integer> newArgList = Lists.newArrayList(newProjectsBelow.size() - 1); + + SqlAggFunction aggFunction = (SqlAggFunction) getSqlOperator(DataSketchesFunctions.DATA_TO_SKETCH); + boolean distinct = false; + boolean approximate = true; + boolean ignoreNulls = true; + List<Integer> argList = newArgList; + int filterArg = aggCall.filterArg; + RelCollation collation = aggCall.getCollation(); + RelDataType type = rexBuilder.deriveReturnType(aggFunction, Collections.emptyList()); + String name = aggFunction.getName(); + + AggregateCall newAgg = AggregateCall.create(aggFunction, distinct, approximate, 
ignoreNulls, argList, filterArg, + collation, type, name); + + Integer origFractionIdx = aggCall.getArgList().get(0); + RexNode fraction = getProject(aggregate.getInput()).getChildExps().get(origFractionIdx); + fraction = rexBuilder.makeCast(floatType, fraction); + + SqlOperator projectOperator = getSqlOperator(DataSketchesFunctions.GET_QUANTILE); + RexNode projRex = rexBuilder.makeInputRef(newAgg.getType(), newProjects.size()); + projRex = rexBuilder.makeCall(projectOperator, ImmutableList.of(projRex, fraction)); + projRex = rexBuilder.makeCast(origType, projRex); + + newAggCalls.add(newAgg); + newProjects.add(projRex); + + } + + } + + private Project getProject(RelNode input) { + if (input instanceof Project) { + return (Project) input; + } + if (input instanceof HepRelVertex) { Review comment: Instead of doing this, just match the rule on the aggregate input and pass it as a parameter to the class. That way we reduce dependencies in case we ever use the rule with any other planner. In addition, why do we enforce a Project? Couldn't it be any other operator? ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. 
+ * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); + for (int i=0;i<vb.newProjectsBelow.size();i++ ) { + filedNames.add("ff_"+i); + } 
+ RelNode newProjectBelow= + projectFactory.createProject(aggregate.getInput(), vb.newProjectsBelow, filedNames); + + RelNode newAgg = aggregate.copy(aggregate.getTraitSet(), newProjectBelow, aggregate.getGroupSet(), + aggregate.getGroupSets(), newAggCalls); + + RelNode newProject = projectFactory.createProject(newAgg, vb.newProjects, aggregate.getRowType().getFieldNames()); + + call.transformTo(newProject); + return; + } + + /** + * Helper class to help in building a new Aggregate and Project. + */ + // NOTE: methods in this class are not re-entrant; drop-to-frame to constructor during debugging + private class VBuilder { + + private final RexBuilder rexBuilder; + + private Aggregate aggregate; + private List<AggregateCall> newAggCalls; + private List<RexNode> newProjects; + private List<RexNode> newProjectsBelow; + private List<RewriteProcedure> rewrites; + + public VBuilder(Aggregate aggregate) { + this.aggregate = aggregate; + newAggCalls = new ArrayList<AggregateCall>(); + newProjects = new ArrayList<RexNode>(); + newProjectsBelow = new ArrayList<RexNode>(); + rexBuilder = aggregate.getCluster().getRexBuilder(); + rewrites = new ArrayList<RewriteProcedure>(); + + // add identity projections + addProjectedFields(); + + if (countDistinctSketchType.isPresent()) { + rewrites.add(new CountDistinctRewrite(countDistinctSketchType.get())); + } + if (percentileContSketchType.isPresent()) { + rewrites.add(new PercentileContRewrite(percentileContSketchType.get())); + } + + for (AggregateCall aggCall : aggregate.getAggCallList()) { + processAggCall(aggCall); + } + } + + private void addProjectedFields() { + for (int i = 0; i < aggregate.getGroupCount(); i++) { + newProjects.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + Builder b = ImmutableBitSet.builder(); + b.addAll(aggregate.getGroupSet()); + for (AggregateCall aggCall: aggregate.getAggCallList()) { + b.addAll(aggCall.getArgList()); + } + ImmutableBitSet inputs = b.build(); + Integer maxIdx = 
Collections.max(inputs.asSet()); + for (int i = 0; i < maxIdx; i++) { + newProjectsBelow.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + } + + private void processAggCall(AggregateCall aggCall) { + for (RewriteProcedure rewrite : rewrites) { + if (rewrite.isApplicable(aggCall)) { + rewrite.rewrite(aggCall); + return; + } + } + appendAggCall(aggCall); + } + + private void appendAggCall(AggregateCall aggCall) { + RexNode projRex = rexBuilder.makeInputRef(aggCall.getType(), newProjects.size()); + + newAggCalls.add(aggCall); + newProjects.add(projRex); + } + + abstract class RewriteProcedure { + + private final String sketchClass; + + public RewriteProcedure(String sketchClass) { + this.sketchClass = sketchClass; + } + + abstract boolean isApplicable(AggregateCall aggCall); + abstract void rewrite(AggregateCall aggCall); + + protected SqlOperator getSqlOperator(String fnName) { + UDFDescriptor fn = DataSketchesFunctions.INSTANCE.getSketchFunction(sketchClass, fnName); + if (!fn.getCalciteFunction().isPresent()) { + throw new RuntimeException(fn.toString() + " doesn't have a Calcite function associated with it"); + } + return fn.getCalciteFunction().get(); + } + + } + + class CountDistinctRewrite extends RewriteProcedure { + + public CountDistinctRewrite(String sketchClass) { + super(sketchClass); + } + + @Override + boolean isApplicable(AggregateCall aggCall) { + return aggCall.isDistinct() && aggCall.getArgList().size() == 1 + && aggCall.getAggregation().getName().equalsIgnoreCase("count") && !aggCall.hasFilter(); + } + + @Override + void rewrite(AggregateCall aggCall) { + RelDataType origType = aggregate.getRowType().getFieldList().get(newProjects.size()).getType(); + + Integer argIndex = aggCall.getArgList().get(0); + RexNode call = rexBuilder.makeInputRef(aggregate.getInput(), argIndex); + newProjectsBelow.add(call); + + ArrayList<Integer> newArgList = Lists.newArrayList(newProjectsBelow.size() - 1); + + SqlAggFunction aggFunction = (SqlAggFunction) 
getSqlOperator(DataSketchesFunctions.DATA_TO_SKETCH); + boolean distinct = false; + boolean approximate = true; + boolean ignoreNulls = true; + List<Integer> argList = newArgList; + int filterArg = aggCall.filterArg; + RelCollation collation = aggCall.getCollation(); + RelDataType type = rexBuilder.deriveReturnType(aggFunction, Collections.emptyList()); + String name = aggFunction.getName(); + + AggregateCall newAgg = AggregateCall.create(aggFunction, distinct, approximate, ignoreNulls, argList, filterArg, + collation, type, name); + + SqlOperator projectOperator = getSqlOperator(DataSketchesFunctions.SKETCH_TO_ESTIMATE); + RexNode projRex = rexBuilder.makeInputRef(newAgg.getType(), newProjects.size()); + projRex = rexBuilder.makeCall(projectOperator, ImmutableList.of(projRex)); + projRex = rexBuilder.makeCall(SqlStdOperatorTable.ROUND, ImmutableList.of(projRex)); + projRex = rexBuilder.makeCast(origType, projRex); + + newAggCalls.add(newAgg); + newProjects.add(projRex); + } + + } + + class PercentileContRewrite extends RewriteProcedure { + + public PercentileContRewrite(String sketchClass) { + super(sketchClass); + } + + @Override + boolean isApplicable(AggregateCall aggCall) { + // FIXME: also check that args are: ?,?,1,0 - other cases are not supported + return !aggCall.isDistinct() && aggCall.getArgList().size() == 4 + && aggCall.getAggregation().getName().equalsIgnoreCase("percentile_cont") && !aggCall.hasFilter(); Review comment: Is there a `percentile_cont` in SqlKind? If there is not, maybe we should add it in Calcite. ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import 
com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. + * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + 
return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); + for (int i=0;i<vb.newProjectsBelow.size();i++ ) { + filedNames.add("ff_"+i); + } + RelNode newProjectBelow= + projectFactory.createProject(aggregate.getInput(), vb.newProjectsBelow, filedNames); + + RelNode newAgg = aggregate.copy(aggregate.getTraitSet(), newProjectBelow, aggregate.getGroupSet(), + aggregate.getGroupSets(), newAggCalls); + + RelNode newProject = projectFactory.createProject(newAgg, vb.newProjects, aggregate.getRowType().getFieldNames()); + + call.transformTo(newProject); + return; + } + + /** + * Helper class to help in building a new Aggregate and Project. + */ + // NOTE: methods in this class are not re-entrant; drop-to-frame to constructor during debugging + private class VBuilder { + + private final RexBuilder rexBuilder; + + private Aggregate aggregate; Review comment: I think aggregate can be `final`? Please consider whether you can make some of the lists below `final` too. ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. 
+ * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); + for (int i=0;i<vb.newProjectsBelow.size();i++ ) { + filedNames.add("ff_"+i); + } 
+ RelNode newProjectBelow= + projectFactory.createProject(aggregate.getInput(), vb.newProjectsBelow, filedNames); + + RelNode newAgg = aggregate.copy(aggregate.getTraitSet(), newProjectBelow, aggregate.getGroupSet(), + aggregate.getGroupSets(), newAggCalls); + + RelNode newProject = projectFactory.createProject(newAgg, vb.newProjects, aggregate.getRowType().getFieldNames()); + + call.transformTo(newProject); + return; + } + + /** + * Helper class to help in building a new Aggregate and Project. + */ + // NOTE: methods in this class are not re-entrant; drop-to-frame to constructor during debugging + private class VBuilder { + + private final RexBuilder rexBuilder; + + private Aggregate aggregate; + private List<AggregateCall> newAggCalls; Review comment: I think it would be useful to describe what each of these objects/lists will contain at the end of the builder construction. ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. 
+ * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); + for (int i=0;i<vb.newProjectsBelow.size();i++ ) { + filedNames.add("ff_"+i); + } 
+ RelNode newProjectBelow= + projectFactory.createProject(aggregate.getInput(), vb.newProjectsBelow, filedNames); + + RelNode newAgg = aggregate.copy(aggregate.getTraitSet(), newProjectBelow, aggregate.getGroupSet(), + aggregate.getGroupSets(), newAggCalls); + + RelNode newProject = projectFactory.createProject(newAgg, vb.newProjects, aggregate.getRowType().getFieldNames()); + + call.transformTo(newProject); + return; + } + + /** + * Helper class to help in building a new Aggregate and Project. + */ + // NOTE: methods in this class are not re-entrant; drop-to-frame to constructor during debugging + private class VBuilder { + + private final RexBuilder rexBuilder; + + private Aggregate aggregate; + private List<AggregateCall> newAggCalls; + private List<RexNode> newProjects; + private List<RexNode> newProjectsBelow; + private List<RewriteProcedure> rewrites; + + public VBuilder(Aggregate aggregate) { + this.aggregate = aggregate; + newAggCalls = new ArrayList<AggregateCall>(); + newProjects = new ArrayList<RexNode>(); + newProjectsBelow = new ArrayList<RexNode>(); + rexBuilder = aggregate.getCluster().getRexBuilder(); + rewrites = new ArrayList<RewriteProcedure>(); + + // add identity projections + addProjectedFields(); + + if (countDistinctSketchType.isPresent()) { + rewrites.add(new CountDistinctRewrite(countDistinctSketchType.get())); + } + if (percentileContSketchType.isPresent()) { + rewrites.add(new PercentileContRewrite(percentileContSketchType.get())); + } + + for (AggregateCall aggCall : aggregate.getAggCallList()) { + processAggCall(aggCall); + } + } + + private void addProjectedFields() { + for (int i = 0; i < aggregate.getGroupCount(); i++) { + newProjects.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + Builder b = ImmutableBitSet.builder(); + b.addAll(aggregate.getGroupSet()); + for (AggregateCall aggCall: aggregate.getAggCallList()) { + b.addAll(aggCall.getArgList()); + } + ImmutableBitSet inputs = b.build(); + Integer maxIdx = 
Collections.max(inputs.asSet()); + for (int i = 0; i < maxIdx; i++) { + newProjectsBelow.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + } + + private void processAggCall(AggregateCall aggCall) { + for (RewriteProcedure rewrite : rewrites) { + if (rewrite.isApplicable(aggCall)) { + rewrite.rewrite(aggCall); + return; + } + } + appendAggCall(aggCall); + } + + private void appendAggCall(AggregateCall aggCall) { + RexNode projRex = rexBuilder.makeInputRef(aggCall.getType(), newProjects.size()); + + newAggCalls.add(aggCall); + newProjects.add(projRex); + } + + abstract class RewriteProcedure { + + private final String sketchClass; + + public RewriteProcedure(String sketchClass) { + this.sketchClass = sketchClass; + } + + abstract boolean isApplicable(AggregateCall aggCall); + abstract void rewrite(AggregateCall aggCall); + + protected SqlOperator getSqlOperator(String fnName) { + UDFDescriptor fn = DataSketchesFunctions.INSTANCE.getSketchFunction(sketchClass, fnName); + if (!fn.getCalciteFunction().isPresent()) { + throw new RuntimeException(fn.toString() + " doesn't have a Calcite function associated with it"); + } + return fn.getCalciteFunction().get(); + } + + } + + class CountDistinctRewrite extends RewriteProcedure { + + public CountDistinctRewrite(String sketchClass) { + super(sketchClass); + } + + @Override + boolean isApplicable(AggregateCall aggCall) { + return aggCall.isDistinct() && aggCall.getArgList().size() == 1 + && aggCall.getAggregation().getName().equalsIgnoreCase("count") && !aggCall.hasFilter(); + } + + @Override + void rewrite(AggregateCall aggCall) { + RelDataType origType = aggregate.getRowType().getFieldList().get(newProjects.size()).getType(); + + Integer argIndex = aggCall.getArgList().get(0); + RexNode call = rexBuilder.makeInputRef(aggregate.getInput(), argIndex); + newProjectsBelow.add(call); + + ArrayList<Integer> newArgList = Lists.newArrayList(newProjectsBelow.size() - 1); Review comment: Can this be removed? 
If I understand correctly, it is just assigned to argList. ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import
org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. + * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + 
+ @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); + for (int i=0;i<vb.newProjectsBelow.size();i++ ) { + filedNames.add("ff_"+i); + } + RelNode newProjectBelow= + projectFactory.createProject(aggregate.getInput(), vb.newProjectsBelow, filedNames); + + RelNode newAgg = aggregate.copy(aggregate.getTraitSet(), newProjectBelow, aggregate.getGroupSet(), + aggregate.getGroupSets(), newAggCalls); + + RelNode newProject = projectFactory.createProject(newAgg, vb.newProjects, aggregate.getRowType().getFieldNames()); + + call.transformTo(newProject); + return; + } + + /** + * Helper class to help in building a new Aggregate and Project. 
+ */ + // NOTE: methods in this class are not re-entrant; drop-to-frame to constructor during debugging + private class VBuilder { + + private final RexBuilder rexBuilder; + + private Aggregate aggregate; + private List<AggregateCall> newAggCalls; + private List<RexNode> newProjects; + private List<RexNode> newProjectsBelow; + private List<RewriteProcedure> rewrites; + + public VBuilder(Aggregate aggregate) { + this.aggregate = aggregate; + newAggCalls = new ArrayList<AggregateCall>(); + newProjects = new ArrayList<RexNode>(); + newProjectsBelow = new ArrayList<RexNode>(); + rexBuilder = aggregate.getCluster().getRexBuilder(); + rewrites = new ArrayList<RewriteProcedure>(); + + // add identity projections + addProjectedFields(); + + if (countDistinctSketchType.isPresent()) { + rewrites.add(new CountDistinctRewrite(countDistinctSketchType.get())); + } + if (percentileContSketchType.isPresent()) { + rewrites.add(new PercentileContRewrite(percentileContSketchType.get())); + } + + for (AggregateCall aggCall : aggregate.getAggCallList()) { + processAggCall(aggCall); + } + } + + private void addProjectedFields() { + for (int i = 0; i < aggregate.getGroupCount(); i++) { + newProjects.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); Review comment: These projects seem to be used on top of the Aggregate later on. I think we should pass `aggregate` here instead of `aggregate.getInput()`. ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. 
+ * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); + for (int i=0;i<vb.newProjectsBelow.size();i++ ) { + filedNames.add("ff_"+i); + } 
+ RelNode newProjectBelow= + projectFactory.createProject(aggregate.getInput(), vb.newProjectsBelow, filedNames); + + RelNode newAgg = aggregate.copy(aggregate.getTraitSet(), newProjectBelow, aggregate.getGroupSet(), + aggregate.getGroupSets(), newAggCalls); + + RelNode newProject = projectFactory.createProject(newAgg, vb.newProjects, aggregate.getRowType().getFieldNames()); + + call.transformTo(newProject); + return; + } + + /** + * Helper class to help in building a new Aggregate and Project. + */ + // NOTE: methods in this class are not re-entrant; drop-to-frame to constructor during debugging + private class VBuilder { + + private final RexBuilder rexBuilder; + + private Aggregate aggregate; + private List<AggregateCall> newAggCalls; + private List<RexNode> newProjects; + private List<RexNode> newProjectsBelow; + private List<RewriteProcedure> rewrites; + + public VBuilder(Aggregate aggregate) { + this.aggregate = aggregate; + newAggCalls = new ArrayList<AggregateCall>(); + newProjects = new ArrayList<RexNode>(); + newProjectsBelow = new ArrayList<RexNode>(); + rexBuilder = aggregate.getCluster().getRexBuilder(); + rewrites = new ArrayList<RewriteProcedure>(); + + // add identity projections + addProjectedFields(); + + if (countDistinctSketchType.isPresent()) { + rewrites.add(new CountDistinctRewrite(countDistinctSketchType.get())); + } + if (percentileContSketchType.isPresent()) { + rewrites.add(new PercentileContRewrite(percentileContSketchType.get())); + } + + for (AggregateCall aggCall : aggregate.getAggCallList()) { + processAggCall(aggCall); + } + } + + private void addProjectedFields() { + for (int i = 0; i < aggregate.getGroupCount(); i++) { + newProjects.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + Builder b = ImmutableBitSet.builder(); + b.addAll(aggregate.getGroupSet()); + for (AggregateCall aggCall: aggregate.getAggCallList()) { + b.addAll(aggCall.getArgList()); + } + ImmutableBitSet inputs = b.build(); + Integer maxIdx = 
Collections.max(inputs.asSet()); + for (int i = 0; i < maxIdx; i++) { + newProjectsBelow.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + } + + private void processAggCall(AggregateCall aggCall) { + for (RewriteProcedure rewrite : rewrites) { + if (rewrite.isApplicable(aggCall)) { + rewrite.rewrite(aggCall); + return; + } + } + appendAggCall(aggCall); + } + + private void appendAggCall(AggregateCall aggCall) { + RexNode projRex = rexBuilder.makeInputRef(aggCall.getType(), newProjects.size()); + + newAggCalls.add(aggCall); + newProjects.add(projRex); + } + + abstract class RewriteProcedure { + + private final String sketchClass; + + public RewriteProcedure(String sketchClass) { + this.sketchClass = sketchClass; + } + + abstract boolean isApplicable(AggregateCall aggCall); + abstract void rewrite(AggregateCall aggCall); + + protected SqlOperator getSqlOperator(String fnName) { + UDFDescriptor fn = DataSketchesFunctions.INSTANCE.getSketchFunction(sketchClass, fnName); + if (!fn.getCalciteFunction().isPresent()) { + throw new RuntimeException(fn.toString() + " doesn't have a Calcite function associated with it"); + } + return fn.getCalciteFunction().get(); + } + + } + + class CountDistinctRewrite extends RewriteProcedure { + + public CountDistinctRewrite(String sketchClass) { + super(sketchClass); + } + + @Override + boolean isApplicable(AggregateCall aggCall) { + return aggCall.isDistinct() && aggCall.getArgList().size() == 1 + && aggCall.getAggregation().getName().equalsIgnoreCase("count") && !aggCall.hasFilter(); + } + + @Override + void rewrite(AggregateCall aggCall) { + RelDataType origType = aggregate.getRowType().getFieldList().get(newProjects.size()).getType(); + + Integer argIndex = aggCall.getArgList().get(0); + RexNode call = rexBuilder.makeInputRef(aggregate.getInput(), argIndex); + newProjectsBelow.add(call); + + ArrayList<Integer> newArgList = Lists.newArrayList(newProjectsBelow.size() - 1); + + SqlAggFunction aggFunction = (SqlAggFunction) 
getSqlOperator(DataSketchesFunctions.DATA_TO_SKETCH); + boolean distinct = false; + boolean approximate = true; + boolean ignoreNulls = true; + List<Integer> argList = newArgList; + int filterArg = aggCall.filterArg; + RelCollation collation = aggCall.getCollation(); + RelDataType type = rexBuilder.deriveReturnType(aggFunction, Collections.emptyList()); + String name = aggFunction.getName(); + + AggregateCall newAgg = AggregateCall.create(aggFunction, distinct, approximate, ignoreNulls, argList, filterArg, + collation, type, name); + + SqlOperator projectOperator = getSqlOperator(DataSketchesFunctions.SKETCH_TO_ESTIMATE); + RexNode projRex = rexBuilder.makeInputRef(newAgg.getType(), newProjects.size()); + projRex = rexBuilder.makeCall(projectOperator, ImmutableList.of(projRex)); + projRex = rexBuilder.makeCall(SqlStdOperatorTable.ROUND, ImmutableList.of(projRex)); + projRex = rexBuilder.makeCast(origType, projRex); + + newAggCalls.add(newAgg); + newProjects.add(projRex); + } + + } + + class PercentileContRewrite extends RewriteProcedure { + + public PercentileContRewrite(String sketchClass) { + super(sketchClass); + } + + @Override + boolean isApplicable(AggregateCall aggCall) { + // FIXME: also check that args are: ?,?,1,0 - other cases are not supported Review comment: Is this going to be fixed in this patch or in a follow-up? Can you explain in a bit more detail what the comment means? ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. 
+ * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); + for (int i=0;i<vb.newProjectsBelow.size();i++ ) { + filedNames.add("ff_"+i); + } 
+ RelNode newProjectBelow= + projectFactory.createProject(aggregate.getInput(), vb.newProjectsBelow, filedNames); + + RelNode newAgg = aggregate.copy(aggregate.getTraitSet(), newProjectBelow, aggregate.getGroupSet(), + aggregate.getGroupSets(), newAggCalls); + + RelNode newProject = projectFactory.createProject(newAgg, vb.newProjects, aggregate.getRowType().getFieldNames()); + + call.transformTo(newProject); + return; + } + + /** + * Helper class to help in building a new Aggregate and Project. + */ + // NOTE: methods in this class are not re-entrant; drop-to-frame to constructor during debugging + private class VBuilder { + + private final RexBuilder rexBuilder; + + private Aggregate aggregate; + private List<AggregateCall> newAggCalls; + private List<RexNode> newProjects; + private List<RexNode> newProjectsBelow; + private List<RewriteProcedure> rewrites; + + public VBuilder(Aggregate aggregate) { + this.aggregate = aggregate; + newAggCalls = new ArrayList<AggregateCall>(); + newProjects = new ArrayList<RexNode>(); + newProjectsBelow = new ArrayList<RexNode>(); + rexBuilder = aggregate.getCluster().getRexBuilder(); + rewrites = new ArrayList<RewriteProcedure>(); + + // add identity projections + addProjectedFields(); + + if (countDistinctSketchType.isPresent()) { + rewrites.add(new CountDistinctRewrite(countDistinctSketchType.get())); + } + if (percentileContSketchType.isPresent()) { + rewrites.add(new PercentileContRewrite(percentileContSketchType.get())); + } + + for (AggregateCall aggCall : aggregate.getAggCallList()) { + processAggCall(aggCall); + } + } + + private void addProjectedFields() { + for (int i = 0; i < aggregate.getGroupCount(); i++) { + newProjects.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + Builder b = ImmutableBitSet.builder(); + b.addAll(aggregate.getGroupSet()); + for (AggregateCall aggCall: aggregate.getAggCallList()) { + b.addAll(aggCall.getArgList()); + } + ImmutableBitSet inputs = b.build(); + Integer maxIdx = 
Collections.max(inputs.asSet()); + for (int i = 0; i < maxIdx; i++) { + newProjectsBelow.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + } + + private void processAggCall(AggregateCall aggCall) { + for (RewriteProcedure rewrite : rewrites) { + if (rewrite.isApplicable(aggCall)) { + rewrite.rewrite(aggCall); + return; + } + } + appendAggCall(aggCall); + } + + private void appendAggCall(AggregateCall aggCall) { + RexNode projRex = rexBuilder.makeInputRef(aggCall.getType(), newProjects.size()); + + newAggCalls.add(aggCall); + newProjects.add(projRex); + } + + abstract class RewriteProcedure { + + private final String sketchClass; + + public RewriteProcedure(String sketchClass) { + this.sketchClass = sketchClass; + } + + abstract boolean isApplicable(AggregateCall aggCall); + abstract void rewrite(AggregateCall aggCall); + + protected SqlOperator getSqlOperator(String fnName) { + UDFDescriptor fn = DataSketchesFunctions.INSTANCE.getSketchFunction(sketchClass, fnName); + if (!fn.getCalciteFunction().isPresent()) { + throw new RuntimeException(fn.toString() + " doesn't have a Calcite function associated with it"); + } + return fn.getCalciteFunction().get(); + } + + } + + class CountDistinctRewrite extends RewriteProcedure { + + public CountDistinctRewrite(String sketchClass) { + super(sketchClass); + } + + @Override + boolean isApplicable(AggregateCall aggCall) { + return aggCall.isDistinct() && aggCall.getArgList().size() == 1 + && aggCall.getAggregation().getName().equalsIgnoreCase("count") && !aggCall.hasFilter(); + } + + @Override + void rewrite(AggregateCall aggCall) { + RelDataType origType = aggregate.getRowType().getFieldList().get(newProjects.size()).getType(); + + Integer argIndex = aggCall.getArgList().get(0); + RexNode call = rexBuilder.makeInputRef(aggregate.getInput(), argIndex); + newProjectsBelow.add(call); + + ArrayList<Integer> newArgList = Lists.newArrayList(newProjectsBelow.size() - 1); + + SqlAggFunction aggFunction = (SqlAggFunction) 
getSqlOperator(DataSketchesFunctions.DATA_TO_SKETCH); + boolean distinct = false; + boolean approximate = true; + boolean ignoreNulls = true; + List<Integer> argList = newArgList; + int filterArg = aggCall.filterArg; + RelCollation collation = aggCall.getCollation(); + RelDataType type = rexBuilder.deriveReturnType(aggFunction, Collections.emptyList()); + String name = aggFunction.getName(); + + AggregateCall newAgg = AggregateCall.create(aggFunction, distinct, approximate, ignoreNulls, argList, filterArg, + collation, type, name); + + SqlOperator projectOperator = getSqlOperator(DataSketchesFunctions.SKETCH_TO_ESTIMATE); + RexNode projRex = rexBuilder.makeInputRef(newAgg.getType(), newProjects.size()); + projRex = rexBuilder.makeCall(projectOperator, ImmutableList.of(projRex)); + projRex = rexBuilder.makeCall(SqlStdOperatorTable.ROUND, ImmutableList.of(projRex)); + projRex = rexBuilder.makeCast(origType, projRex); + + newAggCalls.add(newAgg); + newProjects.add(projRex); + } + + } + + class PercentileContRewrite extends RewriteProcedure { + + public PercentileContRewrite(String sketchClass) { + super(sketchClass); + } + + @Override + boolean isApplicable(AggregateCall aggCall) { + // FIXME: also check that args are: ?,?,1,0 - other cases are not supported + return !aggCall.isDistinct() && aggCall.getArgList().size() == 4 + && aggCall.getAggregation().getName().equalsIgnoreCase("percentile_cont") && !aggCall.hasFilter(); + } + + @Override + void rewrite(AggregateCall aggCall) { + RelDataType origType = aggregate.getRowType().getFieldList().get(newProjects.size()).getType(); + + Integer argIndex = aggCall.getArgList().get(1); + RexNode call = rexBuilder.makeInputRef(aggregate.getInput(), argIndex); + + RelDataType floatType = rexBuilder.getTypeFactory().createSqlType(SqlTypeName.FLOAT); + call = rexBuilder.makeCast(floatType, call); + newProjectsBelow.add(call); + + ArrayList<Integer> newArgList = Lists.newArrayList(newProjectsBelow.size() - 1); Review comment: Can 
this be removed? If I understand correctly, it is just assigned to argList. ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. 
+ * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); + for (int i=0;i<vb.newProjectsBelow.size();i++ ) { + filedNames.add("ff_"+i); + } 
+ RelNode newProjectBelow= + projectFactory.createProject(aggregate.getInput(), vb.newProjectsBelow, filedNames); + + RelNode newAgg = aggregate.copy(aggregate.getTraitSet(), newProjectBelow, aggregate.getGroupSet(), + aggregate.getGroupSets(), newAggCalls); + + RelNode newProject = projectFactory.createProject(newAgg, vb.newProjects, aggregate.getRowType().getFieldNames()); + + call.transformTo(newProject); + return; + } + + /** + * Helper class to help in building a new Aggregate and Project. + */ + // NOTE: methods in this class are not re-entrant; drop-to-frame to constructor during debugging + private class VBuilder { + + private final RexBuilder rexBuilder; + + private Aggregate aggregate; + private List<AggregateCall> newAggCalls; + private List<RexNode> newProjects; + private List<RexNode> newProjectsBelow; + private List<RewriteProcedure> rewrites; + + public VBuilder(Aggregate aggregate) { + this.aggregate = aggregate; + newAggCalls = new ArrayList<AggregateCall>(); + newProjects = new ArrayList<RexNode>(); + newProjectsBelow = new ArrayList<RexNode>(); + rexBuilder = aggregate.getCluster().getRexBuilder(); + rewrites = new ArrayList<RewriteProcedure>(); + + // add identity projections + addProjectedFields(); + + if (countDistinctSketchType.isPresent()) { + rewrites.add(new CountDistinctRewrite(countDistinctSketchType.get())); + } + if (percentileContSketchType.isPresent()) { + rewrites.add(new PercentileContRewrite(percentileContSketchType.get())); + } + + for (AggregateCall aggCall : aggregate.getAggCallList()) { + processAggCall(aggCall); + } + } + + private void addProjectedFields() { + for (int i = 0; i < aggregate.getGroupCount(); i++) { + newProjects.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + Builder b = ImmutableBitSet.builder(); + b.addAll(aggregate.getGroupSet()); + for (AggregateCall aggCall: aggregate.getAggCallList()) { + b.addAll(aggCall.getArgList()); + } + ImmutableBitSet inputs = b.build(); + Integer maxIdx = 
Collections.max(inputs.asSet()); + for (int i = 0; i < maxIdx; i++) { + newProjectsBelow.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + } + + private void processAggCall(AggregateCall aggCall) { + for (RewriteProcedure rewrite : rewrites) { + if (rewrite.isApplicable(aggCall)) { + rewrite.rewrite(aggCall); + return; + } + } + appendAggCall(aggCall); + } + + private void appendAggCall(AggregateCall aggCall) { + RexNode projRex = rexBuilder.makeInputRef(aggCall.getType(), newProjects.size()); + + newAggCalls.add(aggCall); + newProjects.add(projRex); + } + + abstract class RewriteProcedure { + + private final String sketchClass; + + public RewriteProcedure(String sketchClass) { + this.sketchClass = sketchClass; + } + + abstract boolean isApplicable(AggregateCall aggCall); + abstract void rewrite(AggregateCall aggCall); + + protected SqlOperator getSqlOperator(String fnName) { + UDFDescriptor fn = DataSketchesFunctions.INSTANCE.getSketchFunction(sketchClass, fnName); + if (!fn.getCalciteFunction().isPresent()) { + throw new RuntimeException(fn.toString() + " doesn't have a Calcite function associated with it"); + } + return fn.getCalciteFunction().get(); + } + + } + + class CountDistinctRewrite extends RewriteProcedure { + + public CountDistinctRewrite(String sketchClass) { + super(sketchClass); + } + + @Override + boolean isApplicable(AggregateCall aggCall) { + return aggCall.isDistinct() && aggCall.getArgList().size() == 1 + && aggCall.getAggregation().getName().equalsIgnoreCase("count") && !aggCall.hasFilter(); Review comment: nit. indentation ########## File path: ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRewriteToDataSketchesRule.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.optimizer.calcite.rules; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories.ProjectFactory; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableBitSet.Builder; +import org.apache.hadoop.hive.ql.exec.DataSketchesFunctions; +import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; +import org.apache.hive.plugin.api.HiveUDFPlugin.UDFDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import 
com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +/** + * This rule could rewrite aggregate calls to be calculated using sketch based functions. + * + * <br/> + * Currently it can rewrite: + * <ul> + * <li>{@code count(distinct(x))} to distinct counting sketches</li> + * <li>{@code percentile_cont(0.2) within group (order by id)}</li> + * </ul> + * + * <p> + * The transformation here works on Aggregate nodes; the operations done are the following: + * </p> + * <ol> + * <li>Identify candidate aggregate calls</li> + * <li>A new Project is inserted below the Aggregate; to help with data pre-processing</li> + * <li>A new Aggregate is created in which the aggregation is done by the sketch function</li> + * <li>A new Project is inserted on top of the Aggregate; which unwraps the resulting + * count-distinct estimation from the sketch representation</li> + * </ol> + */ +public final class HiveRewriteToDataSketchesRule extends RelOptRule { + + protected static final Logger LOG = LoggerFactory.getLogger(HiveRewriteToDataSketchesRule.class); + private final Optional<String> countDistinctSketchType; + private final Optional<String> percentileContSketchType; + private final ProjectFactory projectFactory; + + public HiveRewriteToDataSketchesRule(Optional<String> countDistinctSketchType, + Optional<String> percentileContSketchType) { + super(operand(HiveAggregate.class, any())); + this.countDistinctSketchType = countDistinctSketchType; + this.percentileContSketchType = percentileContSketchType; + projectFactory = HiveRelFactories.HIVE_PROJECT_FACTORY; + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + + if (aggregate.getGroupSets().size() != 1) { + // not yet supported + return; + } + + List<AggregateCall> newAggCalls = new ArrayList<AggregateCall>(); + + VBuilder vb = new VBuilder(aggregate); + + if (aggregate.getAggCallList().equals(vb.newAggCalls)) { + // rule didn't made any changes + 
return; + } + + newAggCalls = vb.newAggCalls; + List<String> filedNames=new ArrayList<String>(); + for (int i=0;i<vb.newProjectsBelow.size();i++ ) { + filedNames.add("ff_"+i); + } + RelNode newProjectBelow= + projectFactory.createProject(aggregate.getInput(), vb.newProjectsBelow, filedNames); + + RelNode newAgg = aggregate.copy(aggregate.getTraitSet(), newProjectBelow, aggregate.getGroupSet(), + aggregate.getGroupSets(), newAggCalls); + + RelNode newProject = projectFactory.createProject(newAgg, vb.newProjects, aggregate.getRowType().getFieldNames()); + + call.transformTo(newProject); + return; + } + + /** + * Helper class to help in building a new Aggregate and Project. + */ + // NOTE: methods in this class are not re-entrant; drop-to-frame to constructor during debugging + private class VBuilder { + + private final RexBuilder rexBuilder; + + private Aggregate aggregate; + private List<AggregateCall> newAggCalls; + private List<RexNode> newProjects; + private List<RexNode> newProjectsBelow; + private List<RewriteProcedure> rewrites; + + public VBuilder(Aggregate aggregate) { + this.aggregate = aggregate; + newAggCalls = new ArrayList<AggregateCall>(); + newProjects = new ArrayList<RexNode>(); + newProjectsBelow = new ArrayList<RexNode>(); + rexBuilder = aggregate.getCluster().getRexBuilder(); + rewrites = new ArrayList<RewriteProcedure>(); + + // add identity projections + addProjectedFields(); + + if (countDistinctSketchType.isPresent()) { + rewrites.add(new CountDistinctRewrite(countDistinctSketchType.get())); + } + if (percentileContSketchType.isPresent()) { + rewrites.add(new PercentileContRewrite(percentileContSketchType.get())); + } + + for (AggregateCall aggCall : aggregate.getAggCallList()) { + processAggCall(aggCall); + } + } + + private void addProjectedFields() { + for (int i = 0; i < aggregate.getGroupCount(); i++) { + newProjects.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + Builder b = ImmutableBitSet.builder(); + 
b.addAll(aggregate.getGroupSet()); + for (AggregateCall aggCall: aggregate.getAggCallList()) { + b.addAll(aggCall.getArgList()); + } + ImmutableBitSet inputs = b.build(); + Integer maxIdx = Collections.max(inputs.asSet()); + for (int i = 0; i < maxIdx; i++) { + newProjectsBelow.add(rexBuilder.makeInputRef(aggregate.getInput(), i)); + } + } + + private void processAggCall(AggregateCall aggCall) { + for (RewriteProcedure rewrite : rewrites) { + if (rewrite.isApplicable(aggCall)) { + rewrite.rewrite(aggCall); + return; + } + } + appendAggCall(aggCall); + } + + private void appendAggCall(AggregateCall aggCall) { + RexNode projRex = rexBuilder.makeInputRef(aggCall.getType(), newProjects.size()); + + newAggCalls.add(aggCall); + newProjects.add(projRex); + } + + abstract class RewriteProcedure { + + private final String sketchClass; + + public RewriteProcedure(String sketchClass) { + this.sketchClass = sketchClass; + } + + abstract boolean isApplicable(AggregateCall aggCall); + abstract void rewrite(AggregateCall aggCall); + + protected SqlOperator getSqlOperator(String fnName) { + UDFDescriptor fn = DataSketchesFunctions.INSTANCE.getSketchFunction(sketchClass, fnName); + if (!fn.getCalciteFunction().isPresent()) { + throw new RuntimeException(fn.toString() + " doesn't have a Calcite function associated with it"); + } + return fn.getCalciteFunction().get(); + } + + } + + class CountDistinctRewrite extends RewriteProcedure { + + public CountDistinctRewrite(String sketchClass) { + super(sketchClass); + } + + @Override + boolean isApplicable(AggregateCall aggCall) { + return aggCall.isDistinct() && aggCall.getArgList().size() == 1 + && aggCall.getAggregation().getName().equalsIgnoreCase("count") && !aggCall.hasFilter(); Review comment: Can we rely on `SqlKind.COUNT` instead of the name? 
########## File path: ql/src/test/results/clientpositive/llap/sketches_rewrite_percentile_cont.q.out ########## @@ -0,0 +1,105 @@ +PREHOOK: query: create table sketch_input (id int, category char(1)) +STORED AS ORC +TBLPROPERTIES ('transactional'='true') +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +PREHOOK: Output: default@sketch_input +POSTHOOK: query: create table sketch_input (id int, category char(1)) +STORED AS ORC +TBLPROPERTIES ('transactional'='true') +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@sketch_input +PREHOOK: query: insert into table sketch_input values + (1,'a'),(1, 'a'), (2, 'a'), (3, 'a'), (4, 'a'), (5, 'a'), (6, 'a'), (7, 'a'), (8, 'a'), (9, 'a'), (10, 'a'), + (6,'b'),(6, 'b'), (7, 'b'), (8, 'b'), (9, 'b'), (10, 'b'), (11, 'b'), (12, 'b'), (13, 'b'), (14, 'b'), (15, 'b') +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +PREHOOK: Output: default@sketch_input +POSTHOOK: query: insert into table sketch_input values + (1,'a'),(1, 'a'), (2, 'a'), (3, 'a'), (4, 'a'), (5, 'a'), (6, 'a'), (7, 'a'), (8, 'a'), (9, 'a'), (10, 'a'), + (6,'b'),(6, 'b'), (7, 'b'), (8, 'b'), (9, 'b'), (10, 'b'), (11, 'b'), (12, 'b'), (13, 'b'), (14, 'b'), (15, 'b') +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +POSTHOOK: Output: default@sketch_input +POSTHOOK: Lineage: sketch_input.category SCRIPT [] +POSTHOOK: Lineage: sketch_input.id SCRIPT [] +PREHOOK: query: explain +select percentile_cont(0.2) within group(order by id) from sketch_input +PREHOOK: type: QUERY +PREHOOK: Input: default@sketch_input +#### A masked pattern was here #### +POSTHOOK: query: explain +select percentile_cont(0.2) within group(order by id) from sketch_input +POSTHOOK: type: QUERY +POSTHOOK: Input: default@sketch_input +#### A masked pattern was here #### +STAGE DEPENDENCIES: + Stage-1 is a root stage + Stage-0 depends on stages: Stage-1 + +STAGE PLANS: + Stage: Stage-1 + Tez +#### A masked 
pattern was here #### + Edges: + Reducer 2 <- Map 1 (CUSTOM_SIMPLE_EDGE) +#### A masked pattern was here #### + Vertices: + Map 1 + Map Operator Tree: + TableScan + alias: sketch_input + Statistics: Num rows: 22 Data size: 88 Basic stats: COMPLETE Column stats: COMPLETE + Select Operator + expressions: UDFToFloat(id) (type: float) + outputColumnNames: _col0 + Statistics: Num rows: 22 Data size: 88 Basic stats: COMPLETE Column stats: COMPLETE + Group By Operator + aggregations: ds_kll_sketch(_col0) + minReductionHashAggr: 0.95454544 + mode: hash + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 144 Basic stats: COMPLETE Column stats: COMPLETE + Reduce Output Operator + null sort order: + sort order: + Statistics: Num rows: 1 Data size: 144 Basic stats: COMPLETE Column stats: COMPLETE + value expressions: _col0 (type: binary) + Execution mode: llap + LLAP IO: may be used (ACID table) + Reducer 2 + Execution mode: llap + Reduce Operator Tree: + Group By Operator + aggregations: ds_kll_sketch(VALUE._col0) + mode: mergepartial + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 144 Basic stats: COMPLETE Column stats: COMPLETE + Select Operator + expressions: UDFToDouble(ds_kll_quantile(_col0, 0.2)) (type: double) + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 8 Basic stats: COMPLETE Column stats: COMPLETE + File Output Operator + compressed: false + Statistics: Num rows: 1 Data size: 8 Basic stats: COMPLETE Column stats: COMPLETE + table: + input format: org.apache.hadoop.mapred.SequenceFileInputFormat + output format: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat + serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe + + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + ListSink + +PREHOOK: query: select percentile_cont(0.2) within group(order by id) from sketch_input +PREHOOK: type: QUERY +PREHOOK: Input: default@sketch_input +#### A masked pattern was here #### +POSTHOOK: query: select 
percentile_cont(0.2) within group(order by id) from sketch_input +POSTHOOK: type: QUERY +POSTHOOK: Input: default@sketch_input +#### A masked pattern was here #### +4.0 Review comment: This matches percentile_disc output but not percentile_cont. Is it because of the error bound? Have you run a few tests with data and checked the results? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org Issue Time Tracking ------------------- Worklog Id: (was: 433181) Time Spent: 20m (was: 10m) > Add option to rewrite PERCENTILE_CONT to sketch functions > --------------------------------------------------------- > > Key: HIVE-23434 > URL: https://issues.apache.org/jira/browse/HIVE-23434 > Project: Hive > Issue Type: Sub-task > Reporter: Zoltan Haindrich > Assignee: Zoltan Haindrich > Priority: Major > Labels: pull-request-available > Attachments: HIVE-23434.01.patch, HIVE-23434.02.patch, > HIVE-23434.03.patch > > Time Spent: 20m > Remaining Estimate: 0h > -- This message was sent by Atlassian Jira (v8.3.4#803005)