This is an automated email from the ASF dual-hosted git repository. morrysnow pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 7379cdc995 [feature](nereids) support subquery in select list (#23271) 7379cdc995 is described below commit 7379cdc9953f9ee16733ed67e5e12434cd38b263 Author: starocean999 <40539150+starocean...@users.noreply.github.com> AuthorDate: Thu Aug 31 15:51:32 2023 +0800 [feature](nereids) support subquery in select list (#23271) 1. add scalar subquery's output to LogicalApply's output 2. for in and exists subquery's, add mark join slot into LogicalApply's output 3. forbid push down alias through join if the project list have any mark join slots. 4. move normalize aggregate rule to analysis phase --- .../doris/nereids/jobs/executor/Analyzer.java | 11 +- .../doris/nereids/jobs/executor/Rewriter.java | 9 +- .../EliminateGroupByConstant.java | 3 +- .../{rewrite => analysis}/NormalizeAggregate.java | 4 +- .../nereids/rules/analysis/SubqueryToApply.java | 92 +++--- .../rules/implementation/AggregateStrategies.java | 2 +- .../rules/rewrite/PushdownAliasThroughJoin.java | 5 +- .../doris/nereids/trees/expressions/CaseWhen.java | 2 +- .../trees/expressions/literal/DoubleLiteral.java | 5 + .../trees/expressions/literal/FloatLiteral.java | 9 + .../nereids/rules/analysis/AnalyzeCTETest.java | 2 +- .../rules/analysis/AnalyzeWhereSubqueryTest.java | 70 ++-- .../rules/analysis/BindSlotReferenceTest.java | 7 +- .../EliminateGroupByConstantTest.java | 3 +- .../rules/analysis/FillUpMissingSlotsTest.java | 354 +++++++++++---------- .../NormalizeAggregateTest.java | 2 +- .../rules/rewrite/AggregateStrategiesTest.java | 1 + .../nereids/rules/rewrite/ColumnPruningTest.java | 17 +- .../ExtractAndNormalizeWindowExpressionTest.java | 1 + .../rewrite/PushdownAliasThroughJoinTest.java | 23 ++ .../PushdownExpressionsInHashConditionTest.java | 34 +- .../rules/rewrite/mv/SelectRollupIndexTest.java | 4 +- .../subquery/test_subquery_in_project.out | 50 +++ .../nereids_tpcds_shape_sf100_p0/shape/query1.out | 15 +- .../nereids_tpcds_shape_sf100_p0/shape/query30.out | 15 +- .../nereids_tpcds_shape_sf100_p0/shape/query51.out | 42 +-- .../nereids_tpcds_shape_sf100_p0/shape/query81.out | 15 +- .../nereids_tpch_shape_sf1000_p0/shape/q20.out | 13 +- .../data/nereids_tpch_shape_sf500_p0/shape/q20.out | 13 +- .../subquery/test_subquery_in_project.groovy | 120 +++++++ 30 files changed, 610 insertions(+), 333 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java index 1fb1d7eecd..ed67b44f1f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java @@ -30,7 +30,9 @@ import org.apache.doris.nereids.rules.analysis.BindSink; import org.apache.doris.nereids.rules.analysis.CheckAnalysis; import org.apache.doris.nereids.rules.analysis.CheckBound; import org.apache.doris.nereids.rules.analysis.CheckPolicy; +import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant; import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.analysis.NormalizeRepeat; import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate; import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate; @@ -110,9 +112,14 @@ public class Analyzer extends AbstractBatchJobExecutor { // LogicalProject for normalize. This rule depends on FillUpMissingSlots to fill up slots. new NormalizeRepeat() ), - bottomUp(new SubqueryToApply()), bottomUp(new AdjustAggregateNullableForEmptySet()), - bottomUp(new CheckAnalysis()) + // run CheckAnalysis before EliminateGroupByConstant in order to report error message correctly like bellow + // select SUM(lo_tax) FROM lineorder group by 1; + // errCode = 2, detailMessage = GROUP BY expression must not contain aggregate functions: sum(lo_tax) + bottomUp(new CheckAnalysis()), + topDown(new EliminateGroupByConstant()), + topDown(new NormalizeAggregate()), + bottomUp(new SubqueryToApply()) ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index ec2ea06eac..986947c262 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -25,7 +25,9 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet; import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount; import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite; +import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant; import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite; import org.apache.doris.nereids.rules.expression.ExpressionNormalization; import org.apache.doris.nereids.rules.expression.ExpressionOptimization; @@ -52,7 +54,6 @@ import org.apache.doris.nereids.rules.rewrite.EliminateAggregate; import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition; import org.apache.doris.nereids.rules.rewrite.EliminateEmptyRelation; import org.apache.doris.nereids.rules.rewrite.EliminateFilter; -import org.apache.doris.nereids.rules.rewrite.EliminateGroupByConstant; import org.apache.doris.nereids.rules.rewrite.EliminateLimit; import org.apache.doris.nereids.rules.rewrite.EliminateNotNull; import org.apache.doris.nereids.rules.rewrite.EliminateNullAwareLeftAntiJoin; @@ -74,12 +75,12 @@ import org.apache.doris.nereids.rules.rewrite.MergeFilters; import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion; import org.apache.doris.nereids.rules.rewrite.MergeProjects; import org.apache.doris.nereids.rules.rewrite.MergeSetOperations; -import org.apache.doris.nereids.rules.rewrite.NormalizeAggregate; import org.apache.doris.nereids.rules.rewrite.NormalizeSort; import org.apache.doris.nereids.rules.rewrite.PruneFileScanPartition; import org.apache.doris.nereids.rules.rewrite.PruneOlapScanPartition; import org.apache.doris.nereids.rules.rewrite.PruneOlapScanTablet; import org.apache.doris.nereids.rules.rewrite.PullUpCteAnchor; +import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderApply; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan; import org.apache.doris.nereids.rules.rewrite.PushFilterInsideJoin; @@ -139,6 +140,10 @@ public class Rewriter extends AbstractBatchJobExecutor { ), // subquery unnesting relay on ExpressionNormalization to extract common factor expression topic("Subquery unnesting", + // after doing NormalizeAggregate in analysis job + // we need run the following 2 rules to make AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION work + bottomUp(new PullUpProjectUnderApply()), + topDown(new PushdownFilterThroughProject()), costBased( custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION, AggScalarSubQueryToWindowFunction::new) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstant.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java similarity index 96% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstant.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java index f5a01fe530..e7fa14e5cb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstant.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.Plan; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java similarity index 98% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index eb683e8b58..6a141dce7a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java index 6dfe95c116..c28e82f680 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java @@ -21,9 +21,7 @@ import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.BinaryOperator; -import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Exists; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InSubquery; @@ -47,7 +45,6 @@ import com.google.common.collect.ImmutableSet; import java.util.Collection; import java.util.LinkedHashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -68,8 +65,8 @@ public class SubqueryToApply implements AnalysisRuleFactory { logicalFilter().thenApply(ctx -> { LogicalFilter<Plan> filter = ctx.root; - ImmutableList<Set> subqueryExprsList = filter.getConjuncts().stream() - .map(e -> (Set) e.collect(SubqueryExpr.class::isInstance)) + ImmutableList<Set<SubqueryExpr>> subqueryExprsList = filter.getConjuncts().stream() + .map(e -> (Set<SubqueryExpr>) e.collect(SubqueryExpr.class::isInstance)) .collect(ImmutableList.toImmutableList()); if (subqueryExprsList.stream() .flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) { @@ -104,8 +101,7 @@ public class SubqueryToApply implements AnalysisRuleFactory { tmpPlan = applyPlan; newConjuncts.add(conjunct); } - Set<Expression> conjuncts = new LinkedHashSet<>(); - conjuncts.addAll(newConjuncts.build()); + Set<Expression> conjuncts = ImmutableSet.copyOf(newConjuncts.build()); Plan newFilter = new LogicalFilter<>(conjuncts, applyPlan); if (conjuncts.stream().flatMap(c -> c.children().stream()) .anyMatch(MarkJoinSlotReference.class::isInstance)) { @@ -116,36 +112,44 @@ public class SubqueryToApply implements AnalysisRuleFactory { return new LogicalFilter<>(conjuncts, applyPlan); }) ), - RuleType.PROJECT_SUBQUERY_TO_APPLY.build( - logicalProject().thenApply(ctx -> { - LogicalProject<Plan> project = ctx.root; - Set<SubqueryExpr> subqueryExprs = new LinkedHashSet<>(); - project.getProjects().stream() - .filter(Alias.class::isInstance) - .map(Alias.class::cast) - .filter(alias -> alias.child() instanceof CaseWhen) - .forEach(alias -> alias.child().children().stream() - .forEach(e -> - subqueryExprs.addAll(e.collect(SubqueryExpr.class::isInstance)))); - if (subqueryExprs.isEmpty()) { - return project; - } - - SubqueryContext context = new SubqueryContext(subqueryExprs); - return new LogicalProject(project.getProjects().stream() - .map(p -> p.withChildren( - new ReplaceSubquery(ctx.statementContext, true) - .replace(p, context))) - .collect(ImmutableList.toImmutableList()), - subqueryToApply( - subqueryExprs.stream().collect(ImmutableList.toImmutableList()), - (LogicalPlan) project.child(), - context.getSubqueryToMarkJoinSlot(), - ctx.cascadesContext, - Optional.empty(), true - )); - }) - ) + RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> { + LogicalProject<Plan> project = ctx.root; + ImmutableList<Set<SubqueryExpr>> subqueryExprsList = project.getProjects().stream() + .map(e -> (Set<SubqueryExpr>) e.collect(SubqueryExpr.class::isInstance)) + .collect(ImmutableList.toImmutableList()); + if (subqueryExprsList.stream().flatMap(Collection::stream).count() == 0) { + return project; + } + List<NamedExpression> oldProjects = ImmutableList.copyOf(project.getProjects()); + ImmutableList.Builder<NamedExpression> newProjects = new ImmutableList.Builder<>(); + LogicalPlan childPlan = (LogicalPlan) project.child(); + LogicalPlan applyPlan; + for (int i = 0; i < subqueryExprsList.size(); ++i) { + Set<SubqueryExpr> subqueryExprs = subqueryExprsList.get(i); + if (subqueryExprs.isEmpty()) { + newProjects.add(oldProjects.get(i)); + continue; + } + + // first step: Replace the subquery in logcialProject's project list + // second step: Replace subquery with LogicalApply + ReplaceSubquery replaceSubquery = + new ReplaceSubquery(ctx.statementContext, true); + SubqueryContext context = new SubqueryContext(subqueryExprs); + Expression newProject = + replaceSubquery.replace(oldProjects.get(i), context); + + applyPlan = subqueryToApply( + subqueryExprs.stream().collect(ImmutableList.toImmutableList()), + childPlan, context.getSubqueryToMarkJoinSlot(), + ctx.cascadesContext, + Optional.of(newProject), true); + childPlan = applyPlan; + newProjects.add((NamedExpression) newProject); + } + + return project.withProjectsAndChild(newProjects.build(), childPlan); + })) ); } @@ -249,28 +253,30 @@ public class SubqueryToApply implements AnalysisRuleFactory { // The result set when NULL is specified in the subquery and still evaluates to TRUE by using EXISTS // When the number of rows returned is empty, agg will return null, so if there is more agg, // it will always consider the returned result to be true + boolean needCreateMarkJoinSlot = isMarkJoin || isProject; MarkJoinSlotReference markJoinSlotReference = null; - if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && isMarkJoin) { + if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && needCreateMarkJoinSlot) { markJoinSlotReference = new MarkJoinSlotReference(statementContext.generateColumnName(), true); - } else if (isMarkJoin) { + } else if (needCreateMarkJoinSlot) { markJoinSlotReference = new MarkJoinSlotReference(statementContext.generateColumnName()); } - if (isMarkJoin) { + if (needCreateMarkJoinSlot) { context.setSubqueryToMarkJoinSlot(exists, Optional.of(markJoinSlotReference)); } - return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE; + return needCreateMarkJoinSlot ? markJoinSlotReference : BooleanLiteral.TRUE; } @Override public Expression visitInSubquery(InSubquery in, SubqueryContext context) { MarkJoinSlotReference markJoinSlotReference = new MarkJoinSlotReference(statementContext.generateColumnName()); - if (isMarkJoin) { + boolean needCreateMarkJoinSlot = isMarkJoin || isProject; + if (needCreateMarkJoinSlot) { context.setSubqueryToMarkJoinSlot(in, Optional.of(markJoinSlotReference)); } - return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE; + return needCreateMarkJoinSlot ? markJoinSlotReference : BooleanLiteral.TRUE; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 03962ff752..f0971c94ba 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -29,8 +29,8 @@ import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.properties.RequireProperties; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE; -import org.apache.doris.nereids.rules.rewrite.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Cast; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java index 42649cbac1..7839fcfe95 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoin.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; @@ -45,7 +46,9 @@ public class PushdownAliasThroughJoin extends OneRewriteRuleFactory { public Rule build() { return logicalProject(logicalJoin()) .when(project -> project.getProjects().stream().allMatch(expr -> - (expr instanceof Slot) || (expr instanceof Alias && ((Alias) expr).child() instanceof Slot))) + (expr instanceof Slot && !(expr instanceof MarkJoinSlotReference)) + || (expr instanceof Alias && ((Alias) expr).child() instanceof Slot + && !(((Alias) expr).child() instanceof MarkJoinSlotReference)))) .when(project -> project.getProjects().stream().anyMatch(expr -> expr instanceof Alias)) .then(project -> { LogicalJoin<? extends Plan, ? extends Plan> join = project.child(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java index c9233d5c14..11456e8f94 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java @@ -94,7 +94,7 @@ public class CaseWhen extends Expression { StringBuilder output = new StringBuilder("CASE"); for (Expression child : children()) { if (child instanceof WhenClause) { - output.append(child); + output.append(child.toString()); } else { output.append(" ELSE ").append(child.toString()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java index bdd26460c0..b155fe3075 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DoubleLiteral.java @@ -58,4 +58,9 @@ public class DoubleLiteral extends Literal { nf.setGroupingUsed(false); return nf.format(value); } + + @Override + public String getStringValue() { + return toString(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java index 4fff7445ef..95549901dd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/FloatLiteral.java @@ -22,6 +22,8 @@ import org.apache.doris.catalog.Type; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.FloatType; +import java.text.NumberFormat; + /** * float type literal */ @@ -48,4 +50,11 @@ public class FloatLiteral extends Literal { public LiteralExpr toLegacyLiteral() { return new org.apache.doris.analysis.FloatLiteral((double) value, Type.FLOAT); } + + @Override + public String getStringValue() { + NumberFormat nf = NumberFormat.getInstance(); + nf.setGroupingUsed(false); + return nf.format(value); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java index 522f198e3f..ef5a32e2d3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java @@ -140,7 +140,7 @@ public class AnalyzeCTETest extends TestWithFeService implements MemoPatternMatc logicalFilter( logicalProject( logicalJoin( - logicalAggregate(), + logicalProject(logicalAggregate()), logicalProject() ) ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java index bf060d7e5d..73422bee70 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java @@ -156,18 +156,20 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP .matchesNotCheck( logicalApply( any(), - logicalAggregate( - logicalFilter() - ).when(FieldChecker.check("outputExpressions", ImmutableList.of( - new Alias(new ExprId(7), - (new Sum( - new SlotReference(new ExprId(4), "k3", - BigIntType.INSTANCE, true, - ImmutableList.of( - "default_cluster:test", - "t7")))).withAlwaysNullable( - true), - "sum(k3)")))) + logicalProject( + logicalAggregate( + logicalProject() + ).when(FieldChecker.check("outputExpressions", ImmutableList.of( + new Alias(new ExprId(7), + (new Sum( + new SlotReference(new ExprId(4), "k3", + BigIntType.INSTANCE, true, + ImmutableList.of( + "default_cluster:test", + "t7")))).withAlwaysNullable( + true), + "sum(k3)")))) + ) ).when(FieldChecker.check("correlationSlot", ImmutableList.of( new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "t6")) @@ -383,28 +385,32 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP logicalProject( logicalApply( any(), - logicalAggregate( - logicalSubQueryAlias( + logicalProject( + logicalAggregate( logicalProject( - logicalFilter() - ).when(p -> p.getProjects().equals(ImmutableList.of( - new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE, - true, - ImmutableList.of("default_cluster:test", "t7")), "aa") - ))) - ) - .when(a -> a.getAlias().equals("t2")) - .when(a -> a.getOutput().equals(ImmutableList.of( - new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, - true, ImmutableList.of("t2")) + logicalSubQueryAlias( + logicalProject( + logicalFilter() + ).when(p -> p.getProjects().equals(ImmutableList.of( + new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE, + true, + ImmutableList.of("default_cluster:test", "t7")), "aa") + ))) + ) + .when(a -> a.getAlias().equals("t2")) + .when(a -> a.getOutput().equals(ImmutableList.of( + new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, + true, ImmutableList.of("t2")) + ))) + ) + ).when(agg -> agg.getOutputExpressions().equals(ImmutableList.of( + new Alias(new ExprId(8), + (new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, + true, + ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)") ))) - ).when(agg -> agg.getOutputExpressions().equals(ImmutableList.of( - new Alias(new ExprId(8), - (new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, - true, - ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)") - ))) - .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of())) + .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of())) + ) ) .when(apply -> apply.getCorrelationSlot().equals(ImmutableList.of( new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java index 0a3334b4cf..dc05ec0626 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java @@ -90,10 +90,11 @@ class BindSlotReferenceTest { join ); PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate); - LogicalAggregate plan = (LogicalAggregate) checker.getCascadesContext().getMemo().copyOut(); + LogicalAggregate plan = (LogicalAggregate) ((LogicalProject) checker.getCascadesContext() + .getMemo().copyOut()).child(); SlotReference groupByKey = (SlotReference) plan.getGroupByExpressions().get(0); - SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child()).left().getOutput().get(0); - SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child()).right().getOutput().get(0); + SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child().child(0)).left().getOutput().get(0); + SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child().child(0)).right().getOutput().get(0); Assertions.assertEquals(groupByKey.getExprId(), t1id.getExprId()); Assertions.assertNotEquals(t1id.getExprId(), t2id.getExprId()); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java similarity index 98% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java index 3fca54eed9..c35b983911 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.catalog.AggregateType; import org.apache.doris.catalog.Column; @@ -23,7 +23,6 @@ import org.apache.doris.catalog.KeysType; import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.PartitionInfo; import org.apache.doris.catalog.Type; -import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Slot; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java index 03cc549bc2..8cacb46091 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java @@ -35,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.util.FieldChecker; import org.apache.doris.nereids.util.MemoPatternMatchSupported; @@ -45,8 +46,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import org.junit.jupiter.api.Test; -import java.util.stream.Collectors; - public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements MemoPatternMatchSupported { @Override @@ -86,35 +85,35 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo ); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1))))); + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))))); sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0"; - a1 = new SlotReference( - new ExprId(1), "a1", TinyIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") - ); - Alias value = new Alias(new ExprId(3), a1, "value"); + SlotReference value = new SlotReference(new ExprId(3), "value", TinyIntType.INSTANCE, true, + ImmutableList.of()); PlanChecker.from(connectContext).analyze(sql) .applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE)) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))))); sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING value > 0"; PlanChecker.from(connectContext).analyze(sql) .applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE)) .matches( logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))) ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))); sql = "SELECT SUM(a2) FROM t1 GROUP BY a1 HAVING a1 > 0"; @@ -130,13 +129,14 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo PlanChecker.from(connectContext).analyze(sql) .applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE)) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(sumA2, a1))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot())))); } @Test @@ -153,24 +153,28 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo Alias sumA2 = new Alias(new ExprId(3), new Sum(a2), "sum(a2)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING SUM(a2) > 0"; sumA2 = new Alias(new ExprId(3), new Sum(a2), "SUM(a2)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan() + ) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L))))))); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING SUM(a2) > 0"; a1 = new SlotReference( @@ -184,20 +188,24 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo Alias value = new Alias(new ExprId(3), new Sum(a2), "value"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))))); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING value > 0"; PlanChecker.from(connectContext).analyze(sql) .matches( logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))) ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0"; @@ -217,49 +225,53 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo Alias minPK = new Alias(new ExprId(4), new Min(pk), "min(pk)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2) > 0"; Alias sumA1A2 = new Alias(new ExprId(3), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L))))))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L))))))); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0"; Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new TinyIntLiteral((byte) 3))), "sum(((a1 + a2) + 3))"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot())))); sql = "SELECT a1 FROM t1 GROUP BY a1 HAVING COUNT(*) > 0"; Alias countStar = new Alias(new ExprId(3), new Count(), "count(*)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L))))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); + logicalProject( + logicalFilter( + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))) + ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L))))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); } @Test @@ -281,19 +293,21 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo Alias sumB1 = new Alias(new ExprId(7), new Sum(b1), "sum(b1)"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( + logicalProject( logicalFilter( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) - ) - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1))) - ).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE), - sumB1.toSlot())))) - ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); + logicalProject( + logicalAggregate( + logicalProject( + logicalFilter( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + )) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1))) + )).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE), + sumB1.toSlot())))) + ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); } @Test @@ -331,6 +345,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo new ExprId(0), "pk", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); + SlotReference pk1 = new SlotReference( + new ExprId(6), "(pk + 1)", IntegerType.INSTANCE, true, + ImmutableList.of() + ); SlotReference a1 = new SlotReference( new ExprId(1), "a1", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") @@ -339,40 +357,42 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo new ExprId(2), "a2", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); - Alias pk1 = new Alias(new ExprId(6), new Add(pk, Literal.of((byte) 1)), "(pk + 1)"); Alias pk11 = new Alias(new ExprId(7), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)"); Alias pk2 = new Alias(new ExprId(8), new Add(pk, Literal.of((byte) 2)), "(pk + 2)"); Alias sumA1 = new Alias(new ExprId(9), new Sum(a1), "SUM(a1)"); - Alias countA11 = new Alias(new ExprId(10), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)"); + Alias countA1 = new Alias(new ExprId(13), new Count(a1), "count(a1)"); + Alias countA11 = new Alias(new ExprId(10), new Add(countA1.toSlot(), Literal.of((byte) 1)), "(COUNT(a1) + 1)"); Alias sumA1A2 = new Alias(new ExprId(11), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1"); PlanChecker.from(connectContext).analyze(sql) .matches( - logicalProject( - logicalFilter( - logicalAggregate( + logicalProject( logicalFilter( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) + logicalProject( + logicalAggregate( + logicalProject( + logicalFilter( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + )) + ).when(FieldChecker.check("outputExpressions", + Lists.newArrayList(pk, pk1, sumA1, countA1, sumA1A2, v1)))) + ).when(FieldChecker.check("conjuncts", + ImmutableSet.of( + new GreaterThan(pk.toSlot(), Literal.of((byte) 0)), + new GreaterThan(countA11.toSlot(), Literal.of(0L)), + new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), + new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), + new GreaterThan(v1.toSlot(), Literal.of(0L)) + )) ) - ).when(FieldChecker.check("outputExpressions", - Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk))) - ).when(FieldChecker.check("conjuncts", - ImmutableSet.of( - new GreaterThan(pk.toSlot(), Literal.of((byte) 0)), - new GreaterThan(countA11.toSlot(), Literal.of(0L)), - new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), - new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of(0L)), - new GreaterThan(v1.toSlot(), Literal.of(0L)) - )) - ) - ).when(FieldChecker.check( - "projects", Lists.newArrayList( - pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream() - .map(Alias::toSlot).collect(Collectors.toList())) - )); + ).when(FieldChecker.check( + "projects", Lists.newArrayList( + pk1, pk11.toSlot(), pk2.toSlot(), sumA1.toSlot(), countA11.toSlot(), sumA1A2.toSlot(), v1.toSlot()) + ) + )); } @Test @@ -391,9 +411,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); @@ -402,9 +423,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo PlanChecker.from(connectContext).analyze(sql) .matches( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 ORDER BY SUM(a2)"; @@ -420,9 +442,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo PlanChecker.from(connectContext).analyze(sql) .matches( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 ORDER BY MIN(pk)"; @@ -444,9 +467,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(minPK.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot())))); @@ -455,9 +479,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo PlanChecker.from(connectContext).analyze(sql) .matches( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A2.toSlot(), true, true))))); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 ORDER BY SUM(a1 + a2 + 3)"; @@ -467,9 +492,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A23.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot())))); @@ -479,9 +505,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo .matches( logicalProject( logicalSort( - logicalAggregate( - logicalOlapScan() - ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))) + logicalProject( + logicalAggregate( + logicalProject(logicalOlapScan()) + ).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))) ).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(countStar.toSlot(), true, true)))) ).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot())))); } @@ -495,6 +522,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo new ExprId(0), "pk", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); + SlotReference pk1 = new SlotReference( + new ExprId(6), "(pk + 1)", IntegerType.INSTANCE, true, + ImmutableList.of() + ); SlotReference a1 = new SlotReference( new ExprId(1), "a1", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") @@ -503,40 +534,41 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo new ExprId(2), "a2", TinyIntType.INSTANCE, true, ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1") ); - Alias pk1 = new Alias(new ExprId(6), new Add(pk, Literal.of((byte) 1)), "(pk + 1)"); Alias pk11 = new Alias(new ExprId(7), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)"); Alias pk2 = new Alias(new ExprId(8), new Add(pk, Literal.of((byte) 2)), "(pk + 2)"); Alias sumA1 = new Alias(new ExprId(9), new Sum(a1), "SUM(a1)"); + Alias countA1 = new Alias(new ExprId(13), new Count(a1), "count(a1)"); Alias countA11 = new Alias(new ExprId(10), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)"); Alias sumA1A2 = new Alias(new ExprId(11), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1"); PlanChecker.from(connectContext).analyze(sql) - .matches( - logicalProject( - logicalSort( - logicalAggregate( - logicalFilter( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) - ) - ).when(FieldChecker.check("outputExpressions", - Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk))) - ).when(FieldChecker.check("orderKeys", - ImmutableList.of( - new OrderKey(pk, true, true), - new OrderKey(countA11.toSlot(), true, true), - new OrderKey(new Add(sumA1A2.toSlot(), new TinyIntLiteral((byte) 1)), true, true), - new OrderKey(new Add(v1.toSlot(), new TinyIntLiteral((byte) 1)), true, true), - new OrderKey(v1.toSlot(), true, true) - ) - )) - ).when(FieldChecker.check( - "projects", Lists.newArrayList( - pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream() - .map(Alias::toSlot).collect(Collectors.toList())) - )); + .matches(logicalProject(logicalSort(logicalProject(logicalAggregate(logicalProject( + logicalFilter(logicalJoin(logicalOlapScan(), logicalOlapScan())))).when( + FieldChecker.check("outputExpressions", Lists.newArrayList(pk, pk1, + sumA1, countA1, sumA1A2, v1))))).when(FieldChecker.check( + "orderKeys", + ImmutableList.of(new OrderKey(pk, true, true), + new OrderKey( + countA11.toSlot(), true, true), + new OrderKey( + new Add(sumA1A2.toSlot(), + new TinyIntLiteral( + (byte) 1)), + true, true), + new OrderKey( + new Add(v1.toSlot(), + new TinyIntLiteral( + (byte) 1)), + true, true), + new OrderKey(v1.toSlot(), true, true))))) + .when(FieldChecker.check("projects", + Lists.newArrayList(pk1, + pk11.toSlot(), + pk2.toSlot(), + sumA1.toSlot(), + countA11.toSlot(), + sumA1A2.toSlot(), + v1.toSlot())))); } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java similarity index 99% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java index 32f7b324f9..3808fd1842 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.rewrite; +package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java index 6f3bfaa7e5..34c1630918 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.implementation.AggregateStrategies; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.AggregateExpression; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java index 04e84ab8e8..5c43d7274d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java @@ -299,15 +299,16 @@ public class ColumnPruningTest extends TestWithFeService implements MemoPatternM .matches( logicalProject( logicalSubQueryAlias( - logicalAggregate( logicalProject( - logicalOlapScan() - ).when(p -> getOutputQualifiedNames(p).equals( - ImmutableList.of("default_cluster:test.student.id") - )) - ).when(agg -> getOutputQualifiedNames(agg.getOutputs()).equals( - ImmutableList.of("default_cluster:test.student.id") - )) + logicalAggregate( + logicalProject( + logicalOlapScan() + ).when(p -> getOutputQualifiedNames(p).equals( + ImmutableList.of("default_cluster:test.student.id") + )) + ).when(agg -> getOutputQualifiedNames(agg.getOutputs()).equals( + ImmutableList.of("default_cluster:test.student.id") + ))) ) ) ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java index e676caa37a..476131e6b0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpressionTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java index 5667f3f2c5..5a98b07bcf 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownAliasThroughJoinTest.java @@ -18,6 +18,10 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -30,6 +34,8 @@ import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; +import java.util.List; + class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported { private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); @@ -99,4 +105,21 @@ class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported { && project.getProjects().get(1).toSql().equals("2name")) ); } + + @Test + void testNoPushdownMarkJoin() { + List<NamedExpression> projects = + ImmutableList.of(new MarkJoinSlotReference(new ExprId(101), "markSlot1", false), + new Alias(new MarkJoinSlotReference(new ExprId(102), "markSlot2", false), + "markSlot2")); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)).projectExprs(projects).build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushdownAliasThroughJoin()) + .matches(logicalProject(logicalJoin(logicalOlapScan(), logicalOlapScan())) + .when(project -> project.getProjects().get(0).toSql().equals("markSlot1") + && project.getProjects().get(1).toSql() + .equals("markSlot2 AS `markSlot2`"))); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java index 29cc509d95..dfad75d5d8 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashConditionTest.java @@ -135,20 +135,22 @@ public class PushdownExpressionsInHashConditionTest extends TestWithFeService im .applyTopDown(new FindHashConditionForJoin()) .applyTopDown(new PushdownExpressionsInHashCondition()) .matches( - logicalProject( - logicalJoin( - logicalProject( - logicalOlapScan() - ), - logicalProject( - logicalSubQueryAlias( - logicalAggregate( - logicalOlapScan() - ) + logicalProject( + logicalJoin( + logicalProject( + logicalOlapScan() + ), + logicalProject( + logicalSubQueryAlias( + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan() + ))) + ) + ) ) - ) ) - ) ); } @@ -168,8 +170,12 @@ public class PushdownExpressionsInHashConditionTest extends TestWithFeService im logicalProject( logicalSubQueryAlias( logicalSort( - logicalAggregate( - logicalOlapScan() + logicalProject( + logicalAggregate( + logicalProject( + logicalOlapScan() + ) + ) ) ) ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java index a3bd46eb4f..ed5a96933d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite.mv; import org.apache.doris.common.FeConstants; import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject; import org.apache.doris.nereids.rules.rewrite.MergeProjects; +import org.apache.doris.nereids.rules.rewrite.PushdownFilterThroughProject; import org.apache.doris.nereids.trees.plans.PreAggStatus; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; @@ -188,7 +189,8 @@ class SelectRollupIndexTest extends BaseMaterializedIndexSelectTest implements M PlanChecker.from(connectContext) .analyze(sql) .applyBottomUp(new LogicalSubQueryAliasToLogicalProject()) - .applyTopDown(new MergeProjects()) + .applyTopDown(new PushdownFilterThroughProject()) + .applyBottomUp(new MergeProjects()) .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) .matches(logicalOlapScan().when(scan -> { diff --git a/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out new file mode 100644 index 0000000000..5b97935639 --- /dev/null +++ b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out @@ -0,0 +1,50 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql1 -- +3 + +-- !sql2 -- +3 + +-- !sql3 -- +3 + +-- !sql4 -- +false + +-- !sql5 -- +false + +-- !sql6 -- +true + +-- !sql7 -- +2 + +-- !sql8 -- +4 +4 + +-- !sql9 -- +4 +4 + +-- !sql10 -- +false +true + +-- !sql11 -- +false +true + +-- !sql12 -- +true +true + +-- !sql13 -- +2 +2 + +-- !sql14 -- +\N 2.0 +2020-09-09 2.0 + diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out index 8c934fb187..0aa36ae310 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out @@ -23,7 +23,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------------PhysicalProject ------------------PhysicalOlapScan[customer] --------------PhysicalDistribute -----------------hashJoin[INNER_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +----------------hashJoin[INNER_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE)) ------------------PhysicalProject --------------------hashJoin[INNER_JOIN](store.s_store_sk = ctr1.ctr_store_sk) ----------------------PhysicalDistribute @@ -32,11 +32,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------------PhysicalProject --------------------------filter((cast(s_state as VARCHAR(*)) = 'SD')) ----------------------------PhysicalOlapScan[store] -------------------PhysicalProject ---------------------hashAgg[GLOBAL] -----------------------PhysicalDistribute -------------------------hashAgg[LOCAL] ---------------------------PhysicalDistribute -----------------------------PhysicalProject -------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute +----------------------hashAgg[LOCAL] +------------------------PhysicalDistribute +--------------------------PhysicalProject +----------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out index df28c5bee4..83982f3782 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out @@ -24,7 +24,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------PhysicalDistribute --------PhysicalTopN ----------PhysicalProject -------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE)) --------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk) ----------------PhysicalDistribute ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) @@ -38,11 +38,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------------------------filter((cast(ca_state as VARCHAR(*)) = 'IN')) ----------------------------PhysicalOlapScan[customer_address] --------------PhysicalDistribute -----------------PhysicalProject -------------------hashAgg[GLOBAL] ---------------------PhysicalDistribute -----------------------hashAgg[LOCAL] -------------------------PhysicalDistribute ---------------------------PhysicalProject -----------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) +----------------hashAgg[GLOBAL] +------------------PhysicalDistribute +--------------------hashAgg[LOCAL] +----------------------PhysicalDistribute +------------------------PhysicalProject +--------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out index 8ba49dc8d6..b8d6435601 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query51.out @@ -14,30 +14,32 @@ PhysicalResultSink ----------------------PhysicalWindow ------------------------PhysicalQuickSort --------------------------PhysicalDistribute -----------------------------hashAgg[GLOBAL] -------------------------------PhysicalDistribute ---------------------------------hashAgg[LOCAL] -----------------------------------PhysicalProject -------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk = date_dim.d_date_sk) ---------------------------------------PhysicalProject -----------------------------------------PhysicalOlapScan[store_sales] ---------------------------------------PhysicalDistribute +----------------------------PhysicalProject +------------------------------hashAgg[GLOBAL] +--------------------------------PhysicalDistribute +----------------------------------hashAgg[LOCAL] +------------------------------------PhysicalProject +--------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk = date_dim.d_date_sk) ----------------------------------------PhysicalProject -------------------------------------------filter((date_dim.d_month_seq <= 1227)(date_dim.d_month_seq >= 1216)) ---------------------------------------------PhysicalOlapScan[date_dim] +------------------------------------------PhysicalOlapScan[store_sales] +----------------------------------------PhysicalDistribute +------------------------------------------PhysicalProject +--------------------------------------------filter((date_dim.d_month_seq <= 1227)(date_dim.d_month_seq >= 1216)) +----------------------------------------------PhysicalOlapScan[date_dim] --------------------PhysicalProject ----------------------PhysicalWindow ------------------------PhysicalQuickSort --------------------------PhysicalDistribute -----------------------------hashAgg[GLOBAL] -------------------------------PhysicalDistribute ---------------------------------hashAgg[LOCAL] -----------------------------------PhysicalProject -------------------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = date_dim.d_date_sk) ---------------------------------------PhysicalProject -----------------------------------------PhysicalOlapScan[web_sales] ---------------------------------------PhysicalDistribute +----------------------------PhysicalProject +------------------------------hashAgg[GLOBAL] +--------------------------------PhysicalDistribute +----------------------------------hashAgg[LOCAL] +------------------------------------PhysicalProject +--------------------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = date_dim.d_date_sk) ----------------------------------------PhysicalProject -------------------------------------------filter((date_dim.d_month_seq >= 1216)(date_dim.d_month_seq <= 1227)) ---------------------------------------------PhysicalOlapScan[date_dim] +------------------------------------------PhysicalOlapScan[web_sales] +----------------------------------------PhysicalDistribute +------------------------------------------PhysicalProject +--------------------------------------------filter((date_dim.d_month_seq >= 1216)(date_dim.d_month_seq <= 1227)) +----------------------------------------------PhysicalOlapScan[date_dim] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out index 77c7b273ba..bfcec6ce41 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out @@ -24,7 +24,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------PhysicalDistribute --------PhysicalTopN ----------PhysicalProject -------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE)) --------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk) ----------------PhysicalDistribute ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) @@ -38,11 +38,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------------------------filter((cast(ca_state as VARCHAR(*)) = 'CA')) ----------------------------PhysicalOlapScan[customer_address] --------------PhysicalDistribute -----------------PhysicalProject -------------------hashAgg[GLOBAL] ---------------------PhysicalDistribute -----------------------hashAgg[LOCAL] -------------------------PhysicalDistribute ---------------------------PhysicalProject -----------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) +----------------hashAgg[GLOBAL] +------------------PhysicalDistribute +--------------------hashAgg[LOCAL] +----------------------PhysicalDistribute +------------------------PhysicalProject +--------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out index 6114877bc9..9913e27f5f 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out @@ -9,13 +9,12 @@ PhysicalResultSink ------------PhysicalDistribute --------------PhysicalProject ----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity))) -------------------PhysicalProject ---------------------hashAgg[GLOBAL] -----------------------PhysicalDistribute -------------------------hashAgg[LOCAL] ---------------------------PhysicalProject -----------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) -------------------------------PhysicalOlapScan[lineitem] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute +----------------------hashAgg[LOCAL] +------------------------PhysicalProject +--------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) +----------------------------PhysicalOlapScan[lineitem] ------------------PhysicalDistribute --------------------hashJoin[LEFT_SEMI_JOIN](partsupp.ps_partkey = part.p_partkey) ----------------------PhysicalProject diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out index 6114877bc9..9913e27f5f 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out +++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out @@ -9,13 +9,12 @@ PhysicalResultSink ------------PhysicalDistribute --------------PhysicalProject ----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity))) -------------------PhysicalProject ---------------------hashAgg[GLOBAL] -----------------------PhysicalDistribute -------------------------hashAgg[LOCAL] ---------------------------PhysicalProject -----------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) -------------------------------PhysicalOlapScan[lineitem] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute +----------------------hashAgg[LOCAL] +------------------------PhysicalProject +--------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01)) +----------------------------PhysicalOlapScan[lineitem] ------------------PhysicalDistribute --------------------hashJoin[LEFT_SEMI_JOIN](partsupp.ps_partkey = part.p_partkey) ----------------------PhysicalProject diff --git a/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy new file mode 100644 index 0000000000..0521334d8a --- /dev/null +++ b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_subquery_in_project") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + sql """drop table if exists test_sql;""" + sql """ + CREATE TABLE `test_sql` ( + `user_id` varchar(10) NULL, + `dt` date NULL, + `city` varchar(20) NULL, + `age` int(11) NULL + ) ENGINE=OLAP + UNIQUE KEY(`user_id`) + COMMENT 'test' + DISTRIBUTED BY HASH(`user_id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "is_being_synced" = "false", + "storage_format" = "V2", + "light_schema_change" = "true", + "disable_auto_compaction" = "false", + "enable_single_replica_compaction" = "false" + ); + """ + + sql """ insert into test_sql values (1,'2020-09-09',2,3);""" + + qt_sql1 """ + select (select age from test_sql) col from test_sql order by col; + """ + + qt_sql2 """ + select (select sum(age) from test_sql) col from test_sql order by col; + """ + + qt_sql3 """ + select (select sum(age) from test_sql t2 where t2.dt = t1.dt ) col from test_sql t1 order by col; + """ + + qt_sql4 """ + select age in (select user_id from test_sql) col from test_sql order by col; + """ + + qt_sql5 """ + select age in (select user_id from test_sql t2 where t2.user_id = t1.age) col from test_sql t1 order by col; + """ + + qt_sql6 """ + select exists ( select user_id from test_sql ) col from test_sql order by col; + """ + + qt_sql7 """ + select case when age in (select user_id from test_sql) or age in (select user_id from test_sql t2 where t2.user_id = t1.age) or exists ( select user_id from test_sql ) or exists ( select t2.user_id from test_sql t2 where t2.age = t1.user_id) or age < (select sum(age) from test_sql t2 where t2.dt = t1.dt ) then 2 else 1 end col from test_sql t1 order by col; + """ + + sql """ insert into test_sql values (2,'2020-09-09',2,1);""" + + try { + sql """ + select (select age from test_sql) col from test_sql order by col; + """ + } catch (Exception ex) { + assertTrue(ex.getMessage().contains("Expected EQ 1 to be returned by expression")) + } + + qt_sql8 """ + select (select sum(age) from test_sql) col from test_sql order by col; + """ + + qt_sql9 """ + select (select sum(age) from test_sql t2 where t2.dt = t1.dt ) col from test_sql t1 order by col; + """ + + qt_sql10 """ + select age in (select user_id from test_sql) col from test_sql order by col; + """ + + qt_sql11 """ + select age in (select user_id from test_sql t2 where t2.user_id = t1.age) col from test_sql t1 order by col; + """ + + qt_sql12 """ + select exists ( select user_id from test_sql ) col from test_sql order by col; + """ + + qt_sql13 """ + select case when age in (select user_id from test_sql) or age in (select user_id from test_sql t2 where t2.user_id = t1.age) or exists ( select user_id from test_sql ) or exists ( select t2.user_id from test_sql t2 where t2.age = t1.user_id) or age < (select sum(age) from test_sql t2 where t2.dt = t1.dt ) then 2 else 1 end col from test_sql t1 order by col; + """ + + qt_sql14 """ + select dt,case when 'med'='med' then ( + select sum(midean) from ( + select sum(score) / count(*) as midean + from ( + select age score,row_number() over (order by age desc) as desc_math, + row_number() over (order by age asc) as asc_math from test_sql + ) as order_table + where asc_math in (desc_math, desc_math + 1, desc_math - 1)) m + ) + end 'test' from test_sql group by cube(dt) order by dt; + """ + + sql """drop table if exists test_sql;""" +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org