This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 82716ec99d [fix](Nereids) type coercion for subquery (#17661)
82716ec99d is described below
commit 82716ec99d28d4174a9bef9158b7ed7abb5ead6b
Author: zhengshiJ <[email protected]>
AuthorDate: Tue Mar 21 20:38:06 2023 +0800
[fix](Nereids) type coercion for subquery (#17661)
Perform type coercion for subquery expressions in FunctionBinder.
Expressions generated for nested subqueries now uniformly receive their
implicit type casts during the analyze stage.
Method: Add a typeCoercionExpr field to the subquery expression to store
the generated cast information.
Also fix the scenario where a ScalarSubquery is an operand of an arithmetic
expression when types are implicitly converted.
---
.../nereids/rules/analysis/BindExpression.java | 3 +-
.../nereids/rules/analysis/FunctionBinder.java | 19 ++++++
.../nereids/rules/analysis/SubqueryToApply.java | 64 ++++++++++----------
.../rules/rewrite/logical/InApplyToJoin.java | 15 +----
.../doris/nereids/trees/expressions/Exists.java | 16 ++++-
.../nereids/trees/expressions/InSubquery.java | 28 ++++++++-
.../doris/nereids/trees/expressions/ListQuery.java | 14 +++++
.../nereids/trees/expressions/ScalarSubquery.java | 18 +++++-
.../nereids/trees/expressions/SubqueryExpr.java | 29 +++++++--
.../doris/nereids/util/TypeCoercionUtils.java | 13 ++++-
.../doris/nereids/trees/plans/MarkJoinTest.java | 68 +++++++++++++++++-----
.../nereids_syntax_p0/sub_query_correlated.out | 19 ++++++
.../nereids_syntax_p0/sub_query_correlated.groovy | 9 +++
13 files changed, 244 insertions(+), 71 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
index 701e18c9d1..6e74cc1dbf 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
@@ -77,6 +77,7 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
+import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -142,7 +143,7 @@ public class BindExpression implements AnalysisRuleFactory {
Set<Expression> boundConjuncts =
filter.getConjuncts().stream()
.map(expr -> bindSlot(expr, filter.children(),
ctx.cascadesContext))
.map(expr -> bindFunction(expr,
ctx.cascadesContext))
- .collect(Collectors.toSet());
+
.collect(Collectors.toCollection(LinkedHashSet::new));
return new LogicalFilter<>(boundConjuncts, filter.child());
})
),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FunctionBinder.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FunctionBinder.java
index 8fe09fb510..031d8c74d7 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FunctionBinder.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FunctionBinder.java
@@ -31,9 +31,12 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Divide;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
+import org.apache.doris.nereids.trees.expressions.ListQuery;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
@@ -212,4 +215,20 @@ public class FunctionBinder extends
DefaultExpressionRewriter<CascadesContext> {
Between newBetween = between.withChildren(rewrittenChildren);
return TypeCoercionUtils.processBetween(newBetween);
}
+
+ @Override
+ public Expression visitInSubquery(InSubquery inSubquery, CascadesContext
context) {
+ Expression newCompareExpr = inSubquery.getCompareExpr().accept(this,
context);
+ Expression newListQuery = inSubquery.getListQuery().accept(this,
context);
+ ComparisonPredicate newCpAfterUnNestingSubquery =
+ new EqualTo(newCompareExpr, ((ListQuery)
newListQuery).getQueryPlan().getOutput().get(0));
+ ComparisonPredicate afterTypeCoercion = (ComparisonPredicate)
TypeCoercionUtils.processComparisonPredicate(
+ newCpAfterUnNestingSubquery, newCompareExpr, newListQuery);
+ if (!newCompareExpr.getDataType().isBigIntType() &&
newListQuery.getDataType().isBitmapType()) {
+ newCompareExpr = new Cast(newCompareExpr, BigIntType.INSTANCE);
+ }
+ return new InSubquery(newCompareExpr, (ListQuery)
afterTypeCoercion.right(),
+ inSubquery.getCorrelateSlots(), ((ListQuery)
afterTypeCoercion.right()).getTypeCoercionExpr(),
+ inSubquery.isNot());
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
index 7ca11befe9..a0b4bfc652 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
@@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
@@ -42,7 +43,6 @@ import
org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
@@ -290,11 +290,11 @@ public class SubqueryToApply implements
AnalysisRuleFactory {
@Override
public Expression visitScalarSubquery(ScalarSubquery scalar,
SubqueryContext context) {
- context.setSubqueryCorrespondingConject(scalar,
scalar.getQueryPlan().getOutput().get(0));
+ context.setSubqueryCorrespondingConject(scalar,
scalar.getSubqueryOutput());
// When there is only one scalarSubQuery and CorrelateSlots is
empty
// it will not be processed by MarkJoin, so it can be returned
directly
if (context.onlySingleSubquery() &&
scalar.getCorrelateSlots().isEmpty()) {
- return scalar.getQueryPlan().getOutput().get(0);
+ return scalar.getSubqueryOutput();
}
MarkJoinSlotReference markJoinSlotReference =
@@ -302,19 +302,17 @@ public class SubqueryToApply implements
AnalysisRuleFactory {
if (isMarkJoin) {
context.setSubqueryToMarkJoinSlot(scalar,
Optional.of(markJoinSlotReference));
}
- return isMarkJoin ? markJoinSlotReference :
scalar.getQueryPlan().getOutput().get(0);
+ return isMarkJoin ? markJoinSlotReference :
scalar.getSubqueryOutput();
}
@Override
public Expression visitNot(Not not, SubqueryContext context) {
// Need to re-update scalarSubQuery unequal conditions into
subqueryCorrespondingConject
if (not.child() instanceof BinaryOperator
- && (((BinaryOperator) not.child()).left() instanceof
ScalarSubquery
- || ((BinaryOperator) not.child()).right() instanceof
ScalarSubquery)) {
+ && (((BinaryOperator)
not.child()).left().containsType(ScalarSubquery.class)
+ || ((BinaryOperator)
not.child()).right().containsType(ScalarSubquery.class))) {
Expression newChild = replace(not.child(), context);
- ScalarSubquery subquery = ((BinaryOperator)
not.child()).left() instanceof ScalarSubquery
- ? (ScalarSubquery) ((BinaryOperator)
not.child()).left()
- : (ScalarSubquery) ((BinaryOperator)
not.child()).right();
+ ScalarSubquery subquery =
collectScalarSubqueryForBinaryOperator((BinaryOperator) not.child());
context.updateSubqueryCorrespondingConjunctInNot(subquery);
return
context.getSubqueryToMarkJoinSlotValue(subquery).isPresent() ? newChild : new
Not(newChild);
}
@@ -324,8 +322,9 @@ public class SubqueryToApply implements AnalysisRuleFactory
{
@Override
public Expression visitBinaryOperator(BinaryOperator binaryOperator,
SubqueryContext context) {
- boolean atLeastOneChildIsScalarSubquery =
- binaryOperator.left() instanceof ScalarSubquery ||
binaryOperator.right() instanceof ScalarSubquery;
+ boolean atLeastOneChildContainsScalarSubquery =
+ binaryOperator.left().containsType(ScalarSubquery.class)
+ ||
binaryOperator.right().containsType(ScalarSubquery.class);
boolean currentMarkJoin =
((binaryOperator.left().anyMatch(SubqueryExpr.class::isInstance)
||
binaryOperator.right().anyMatch(SubqueryExpr.class::isInstance))
&& (binaryOperator instanceof Or)) ||
isMarkJoin;
@@ -334,9 +333,10 @@ public class SubqueryToApply implements
AnalysisRuleFactory {
isMarkJoin = currentMarkJoin;
Expression right = replace(binaryOperator.right(), context);
- if (atLeastOneChildIsScalarSubquery) {
+ if (atLeastOneChildContainsScalarSubquery && !(binaryOperator
instanceof CompoundPredicate)) {
return context.replaceBinaryOperator(binaryOperator, left,
right, isProject);
}
+
return binaryOperator.withChildren(left, right);
}
}
@@ -406,36 +406,36 @@ public class SubqueryToApply implements
AnalysisRuleFactory {
* logicalFilter(predicate=k1 > scalarSub or exists)
* -->
* logicalFilter(predicate=$c$1 or $c$2)
- *
- * For non-MarkJoin scalarSubQuery, do implicit type conversion.
- * e.g.
- * logicalFilter(predicate=k1 > scalarSub(sum(k2)))
- * -->
- * logicalFilter(predicate=Cast(k1[#0] as BIGINT) = sum(k2)[#1])
*/
public Expression replaceBinaryOperator(BinaryOperator binaryOperator,
Expression left,
Expression right,
boolean isProject) {
- boolean leftIsScalar = binaryOperator.left() instanceof
ScalarSubquery;
- ScalarSubquery subquery = leftIsScalar
- ? (ScalarSubquery) binaryOperator.left() :
(ScalarSubquery) binaryOperator.right();
-
- // Perform implicit type conversion on the connection condition of
scalarSubQuery,
- // and record the result in subqueryCorrespondingConjunct
- Expression newLeft = leftIsScalar &&
subqueryToMarkJoinSlot.get(subquery).isPresent()
- ? ((ScalarSubquery)
binaryOperator.left()).getQueryPlan().getOutput().get(0) : left;
- Expression newRight = !leftIsScalar &&
subqueryToMarkJoinSlot.get(subquery).isPresent()
- ? ((ScalarSubquery)
binaryOperator.right()).getQueryPlan().getOutput().get(0) : right;
- Expression newBinary =
TypeCoercionUtils.processComparisonPredicate(
- (ComparisonPredicate) binaryOperator.withChildren(newLeft,
newRight), newLeft, newRight);
+ boolean leftContaionsScalar =
binaryOperator.left().containsType(ScalarSubquery.class);
+ ScalarSubquery subquery =
collectScalarSubqueryForBinaryOperator(binaryOperator);
+
+ // record the result in subqueryCorrespondingConjunct
+ Expression newLeft = leftContaionsScalar &&
subqueryToMarkJoinSlot.get(subquery).isPresent()
+ ? subqueryCorrespondingConjunct.get(subquery) : left;
+ Expression newRight = !leftContaionsScalar &&
subqueryToMarkJoinSlot.get(subquery).isPresent()
+ ? subqueryCorrespondingConjunct.get(subquery) : right;
+ Expression newBinary = binaryOperator.withChildren(newLeft,
newRight);
subqueryCorrespondingConjunct.put(subquery,
- (isProject ? (leftIsScalar ? newLeft : newRight) :
newBinary));
+ (isProject ? (leftContaionsScalar ? newLeft : newRight) :
newBinary));
- if (subqueryToMarkJoinSlot.get(subquery).isPresent()) {
+ if (subqueryToMarkJoinSlot.get(subquery).isPresent() &&
binaryOperator instanceof ComparisonPredicate) {
return subqueryToMarkJoinSlot.get(subquery).get();
}
return newBinary;
}
}
+
+ private static ScalarSubquery
collectScalarSubqueryForBinaryOperator(BinaryOperator binaryOperator) {
+ boolean leftContaionsScalar =
binaryOperator.left().containsType(ScalarSubquery.class);
+ return leftContaionsScalar
+ ? (ScalarSubquery) ((ImmutableSet) binaryOperator.left()
+ .collect(ScalarSubquery.class::isInstance)).asList().get(0)
+ : (ScalarSubquery) ((ImmutableSet) binaryOperator.right()
+ .collect(ScalarSubquery.class::isInstance)).asList().get(0);
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
index ce95589a7a..8325ad18fb 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
@@ -22,7 +22,6 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
@@ -36,9 +35,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.ExpressionUtils;
-import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@@ -84,10 +81,6 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
LogicalAggregate agg = new LogicalAggregate(groupExpressions,
outputExpressions, apply.right());
Expression compareExpr = ((InSubquery)
apply.getSubqueryExpr()).getCompareExpr();
- if (!compareExpr.getDataType().isBigIntType()) {
- //this rule is after type coercion, we need to add cast by
hand
- compareExpr = new Cast(compareExpr, BigIntType.INSTANCE);
- }
Expression expr = new BitmapContains(agg.getOutput().get(0),
compareExpr);
if (((InSubquery) apply.getSubqueryExpr()).isNot()) {
expr = new Not(expr);
@@ -101,14 +94,12 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
//in-predicate to equal
Expression predicate;
Expression left = ((InSubquery)
apply.getSubqueryExpr()).getCompareExpr();
- Expression right = apply.right().getOutput().get(0);
+ Expression right = apply.getSubqueryExpr().getSubqueryOutput();
if (apply.isCorrelated()) {
- predicate = ExpressionUtils.and(
- TypeCoercionUtils.processComparisonPredicate(
- new EqualTo(left, right), left, right),
+ predicate = ExpressionUtils.and(new EqualTo(left, right),
apply.getCorrelationFilter().get());
} else {
- predicate = TypeCoercionUtils.processComparisonPredicate(new
EqualTo(left, right), left, right);
+ predicate = new EqualTo(left, right);
}
if (apply.getSubCorrespondingConject().isPresent()) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Exists.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Exists.java
index 7f2628e03f..28762addd7 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Exists.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Exists.java
@@ -28,6 +28,7 @@ import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Objects;
+import java.util.Optional;
/**
* Exists subquery expression.
@@ -41,8 +42,16 @@ public class Exists extends SubqueryExpr implements
LeafExpression {
}
public Exists(LogicalPlan subquery, List<Slot> correlateSlots, boolean
isNot) {
+ this(Objects.requireNonNull(subquery, "subquery can not be null"),
+ Objects.requireNonNull(correlateSlots, "subquery can not be
null"),
+ Optional.empty(), isNot);
+ }
+
+ public Exists(LogicalPlan subquery, List<Slot> correlateSlots,
+ Optional<Expression> typeCoercionExpr, boolean isNot) {
super(Objects.requireNonNull(subquery, "subquery can not be null"),
- Objects.requireNonNull(correlateSlots, "subquery can not be
null"));
+ Objects.requireNonNull(correlateSlots, "subquery can not be
null"),
+ typeCoercionExpr);
this.isNot = Objects.requireNonNull(isNot, "isNot can not be null");
}
@@ -88,4 +97,9 @@ public class Exists extends SubqueryExpr implements
LeafExpression {
public int hashCode() {
return Objects.hash(this.queryPlan, this.isNot);
}
+
+ @Override
+ public Expression withTypeCoercion(DataType dataType) {
+ return this;
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InSubquery.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InSubquery.java
index fe1dc5428f..e6487e0c81 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InSubquery.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InSubquery.java
@@ -25,6 +25,7 @@ import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Objects;
+import java.util.Optional;
/**
* In predicate expression.
@@ -43,8 +44,20 @@ public class InSubquery extends SubqueryExpr {
}
public InSubquery(Expression compareExpr, ListQuery listQuery, List<Slot>
correlateSlots, boolean isNot) {
+ this(compareExpr, listQuery, correlateSlots, Optional.empty(), isNot);
+ }
+
+ /**
+ * InSubquery Constructor.
+ */
+ public InSubquery(Expression compareExpr,
+ ListQuery listQuery,
+ List<Slot> correlateSlots,
+ Optional<Expression> typeCoercionExpr,
+ boolean isNot) {
super(Objects.requireNonNull(listQuery.getQueryPlan(), "subquery can
not be null"),
- Objects.requireNonNull(correlateSlots, "correlateSlots can not
be null"));
+ Objects.requireNonNull(correlateSlots, "correlateSlots can not
be null"),
+ typeCoercionExpr);
this.compareExpr = Objects.requireNonNull(compareExpr, "compareExpr
can not be null");
this.listQuery = Objects.requireNonNull(listQuery, "listQuery can not
be null");
this.isNot = Objects.requireNonNull(isNot, "isNot can not be null");
@@ -99,7 +112,9 @@ public class InSubquery extends SubqueryExpr {
return false;
}
InSubquery inSubquery = (InSubquery) o;
- return Objects.equals(this.compareExpr, inSubquery.getCompareExpr())
+ return super.equals(inSubquery)
+ && Objects.equals(this.compareExpr,
inSubquery.getCompareExpr())
+ && Objects.equals(this.listQuery, inSubquery.listQuery)
&& this.isNot == inSubquery.isNot;
}
@@ -107,4 +122,13 @@ public class InSubquery extends SubqueryExpr {
public int hashCode() {
return Objects.hash(this.compareExpr, this.listQuery, this.isNot);
}
+
+ @Override
+ public Expression withTypeCoercion(DataType dataType) {
+ return new InSubquery(compareExpr, listQuery, correlateSlots,
+ dataType == listQuery.queryPlan.getOutput().get(0).getDataType()
+ ? Optional.of(listQuery.queryPlan.getOutput().get(0))
+ : Optional.of(new Cast(listQuery.queryPlan.getOutput().get(0),
dataType)),
+ isNot);
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ListQuery.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ListQuery.java
index 961a46ff68..bccc090016 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ListQuery.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ListQuery.java
@@ -24,7 +24,9 @@ import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
+import java.util.List;
import java.util.Objects;
+import java.util.Optional;
/**
* Encapsulate LogicalPlan as Expression.
@@ -35,6 +37,10 @@ public class ListQuery extends SubqueryExpr implements
LeafExpression {
super(Objects.requireNonNull(subquery, "subquery can not be null"));
}
+ public ListQuery(LogicalPlan subquery, List<Slot> correlateSlots,
Optional<Expression> typeCoercionExpr) {
+ super(subquery, correlateSlots, typeCoercionExpr);
+ }
+
@Override
public DataType getDataType() {
Preconditions.checkArgument(queryPlan.getOutput().size() == 1);
@@ -54,4 +60,12 @@ public class ListQuery extends SubqueryExpr implements
LeafExpression {
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitListQuery(this, context);
}
+
+ @Override
+ public Expression withTypeCoercion(DataType dataType) {
+ return new ListQuery(queryPlan, correlateSlots,
+ dataType == queryPlan.getOutput().get(0).getDataType()
+ ? Optional.of(queryPlan.getOutput().get(0))
+ : Optional.of(new Cast(queryPlan.getOutput().get(0),
dataType)));
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
index e49e514511..a17cb0701f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
@@ -27,6 +27,7 @@ import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Objects;
+import java.util.Optional;
/**
* A subquery that will return only one row and one column.
@@ -37,8 +38,15 @@ public class ScalarSubquery extends SubqueryExpr implements
LeafExpression {
}
public ScalarSubquery(LogicalPlan subquery, List<Slot> correlateSlots) {
+ this(Objects.requireNonNull(subquery, "subquery can not be null"),
+ Objects.requireNonNull(correlateSlots, "correlateSlots can not
be null"),
+ Optional.empty());
+ }
+
+ public ScalarSubquery(LogicalPlan subquery, List<Slot> correlateSlots,
Optional<Expression> typeCoercionExpr) {
super(Objects.requireNonNull(subquery, "subquery can not be null"),
- Objects.requireNonNull(correlateSlots, "correlateSlots can not
be null"));
+ Objects.requireNonNull(correlateSlots, "correlateSlots can not
be null"),
+ typeCoercionExpr);
}
@Override
@@ -60,4 +68,12 @@ public class ScalarSubquery extends SubqueryExpr implements
LeafExpression {
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitScalarSubquery(this, context);
}
+
+ @Override
+ public Expression withTypeCoercion(DataType dataType) {
+ return new ScalarSubquery(queryPlan, correlateSlots,
+ dataType == queryPlan.getOutput().get(0).getDataType()
+ ? Optional.of(queryPlan.getOutput().get(0))
+ : Optional.of(new Cast(queryPlan.getOutput().get(0),
dataType)));
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SubqueryExpr.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SubqueryExpr.java
index d98a72f06a..759634623b 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SubqueryExpr.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SubqueryExpr.java
@@ -21,11 +21,13 @@ import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
+import java.util.Optional;
/**
* Subquery Expression.
@@ -34,20 +36,32 @@ public abstract class SubqueryExpr extends Expression {
protected final LogicalPlan queryPlan;
protected final List<Slot> correlateSlots;
+ protected final Optional<Expression> typeCoercionExpr;
+
public SubqueryExpr(LogicalPlan subquery) {
this.queryPlan = Objects.requireNonNull(subquery, "subquery can not be
null");
this.correlateSlots = ImmutableList.of();
+ this.typeCoercionExpr = Optional.empty();
}
- public SubqueryExpr(LogicalPlan subquery, List<Slot> correlateSlots) {
+ public SubqueryExpr(LogicalPlan subquery, List<Slot> correlateSlots,
Optional<Expression> typeCoercionExpr) {
this.queryPlan = Objects.requireNonNull(subquery, "subquery can not be
null");
this.correlateSlots = ImmutableList.copyOf(correlateSlots);
+ this.typeCoercionExpr = typeCoercionExpr;
}
public List<Slot> getCorrelateSlots() {
return correlateSlots;
}
+ public Optional<Expression> getTypeCoercionExpr() {
+ return typeCoercionExpr;
+ }
+
+ public Expression getSubqueryOutput() {
+ return typeCoercionExpr.orElseGet(() -> queryPlan.getOutput().get(0));
+ }
+
@Override
public DataType getDataType() throws UnboundException {
throw new UnboundException("getDataType");
@@ -65,8 +79,10 @@ public abstract class SubqueryExpr extends Expression {
@Override
public String toString() {
- return "(QueryPlan: " + queryPlan
- + "), (CorrelatedSlots: " + correlateSlots + ")";
+ return Utils.toSqlString("SubqueryExpr",
+ "QueryPlan", queryPlan,
+ "CorrelatedSlots", correlateSlots,
+ "typeCoercionExpr", typeCoercionExpr.isPresent() ?
typeCoercionExpr.get() : "null");
}
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
@@ -92,15 +108,18 @@ public abstract class SubqueryExpr extends Expression {
}
SubqueryExpr other = (SubqueryExpr) o;
return Objects.equals(correlateSlots, other.correlateSlots)
- && queryPlan.deepEquals(other.queryPlan);
+ && queryPlan.deepEquals(other.queryPlan)
+ && Objects.equals(typeCoercionExpr, other.typeCoercionExpr);
}
@Override
public int hashCode() {
- return Objects.hash(queryPlan, correlateSlots);
+ return Objects.hash(queryPlan, correlateSlots, typeCoercionExpr);
}
public List<Slot> getOutput() {
return queryPlan.getOutput();
}
+
+ public abstract Expression withTypeCoercion(DataType dataType);
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
index 720fc36904..ea443bc610 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
@@ -227,7 +227,7 @@ public class TypeCoercionUtils {
* cast input type if input's datatype is not same with dateType.
*/
public static Expression castIfNotSameType(Expression input, DataType
targetType) {
- if (input.getDataType().equals(targetType)) {
+ if (input.getDataType().equals(targetType) ||
isSubqueryAndDataTypeIsBitmap(input)) {
return input;
} else {
checkCanCastTo(input.getDataType(), targetType);
@@ -235,6 +235,10 @@ public class TypeCoercionUtils {
}
}
+ private static boolean isSubqueryAndDataTypeIsBitmap(Expression input) {
+ return input instanceof SubqueryExpr &&
input.getDataType().isBitmapType();
+ }
+
private static boolean canCastTo(DataType input, DataType target) {
return Type.canCastTo(input.toCatalogDataType(),
target.toCatalogDataType());
}
@@ -263,6 +267,13 @@ public class TypeCoercionUtils {
}
}
}
+ return recordTypeCoercionForSubQuery(input, dataType);
+ }
+
+ private static Expression recordTypeCoercionForSubQuery(Expression input,
DataType dataType) {
+ if (input instanceof SubqueryExpr) {
+ return ((SubqueryExpr) input).withTypeCoercion(dataType);
+ }
return new Cast(input, dataType);
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/MarkJoinTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/MarkJoinTest.java
index b3d34a3f59..9251c28564 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/MarkJoinTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/MarkJoinTest.java
@@ -29,13 +29,13 @@ public class MarkJoinTest extends TestWithFeService {
useDatabase("test");
createTable("CREATE TABLE `test_sq_dj1` (\n"
- + " `c1` bigint(20) NULL,\n"
+ + " `c1` varchar(20) NULL,\n"
+ " `c2` bigint(20) NULL,\n"
- + " `c3` bigint(20) not NULL,\n"
- + " `k4` bigint(20) not NULL,\n"
- + " `k5` bigint(20) NULL\n"
+ + " `c3` int(20) not NULL,\n"
+ + " `k4` bitmap BITMAP_UNION NULL,\n"
+ + " `k5` bitmap BITMAP_UNION NULL\n"
+ ") ENGINE=OLAP\n"
- + "DUPLICATE KEY(`c1`)\n"
+ + "AGGREGATE KEY(`c1`, `c2`, `c3`)\n"
+ "COMMENT 'OLAP'\n"
+ "DISTRIBUTED BY HASH(`c2`) BUCKETS 1\n"
+ "PROPERTIES (\n"
@@ -49,10 +49,10 @@ public class MarkJoinTest extends TestWithFeService {
+ " `c1` bigint(20) NULL,\n"
+ " `c2` bigint(20) NULL,\n"
+ " `c3` bigint(20) not NULL,\n"
- + " `k4` bigint(20) not NULL,\n"
- + " `k5` bigint(20) NULL\n"
+ + " `k4` bitmap BITMAP_UNION NULL,\n"
+ + " `k5` bitmap BITMAP_UNION NULL\n"
+ ") ENGINE=OLAP\n"
- + "DUPLICATE KEY(`c1`)\n"
+ + "AGGREGATE KEY(`c1`, `c2`, `c3`)\n"
+ "COMMENT 'OLAP'\n"
+ "DISTRIBUTED BY HASH(`c2`) BUCKETS 1\n"
+ "PROPERTIES (\n"
@@ -178,17 +178,17 @@ public class MarkJoinTest extends TestWithFeService {
.checkPlannerResult("SELECT CASE\n"
+ " WHEN (\n"
+ " SELECT COUNT(*) / 2\n"
- + " FROM test_sq_dj1\n"
+ + " FROM test_sq_dj2\n"
+ " ) > c1 THEN (\n"
+ " SELECT AVG(c1)\n"
- + " FROM test_sq_dj1\n"
+ + " FROM test_sq_dj2\n"
+ " )\n"
+ " ELSE (\n"
+ " SELECT SUM(c2)\n"
- + " FROM test_sq_dj1\n"
+ + " FROM test_sq_dj2\n"
+ " )\n"
+ " END AS kk4\n"
- + " FROM test_sq_dj1 ;");
+ + " FROM test_sq_dj2 ;");
}
@Test
@@ -197,17 +197,17 @@ public class MarkJoinTest extends TestWithFeService {
.checkPlannerResult("SELECT CASE\n"
+ " WHEN exists (\n"
+ " SELECT COUNT(*) / 2\n"
- + " FROM test_sq_dj1\n"
+ + " FROM test_sq_dj2\n"
+ " ) THEN (\n"
+ " SELECT AVG(c1)\n"
- + " FROM test_sq_dj1\n"
+ + " FROM test_sq_dj2\n"
+ " )\n"
+ " ELSE (\n"
+ " SELECT SUM(c2)\n"
- + " FROM test_sq_dj1\n"
+ + " FROM test_sq_dj2\n"
+ " )\n"
+ " END AS kk4\n"
- + " FROM test_sq_dj1 ;");
+ + " FROM test_sq_dj2 ;");
}
@Test
@@ -246,4 +246,40 @@ public class MarkJoinTest extends TestWithFeService {
+ " OR c1 < (SELECT sum(c1) FROM test_sq_dj2 WHERE
test_sq_dj1.c1 = test_sq_dj2.c1)"
+ " AND exists (SELECT c1 FROM test_sq_dj2 WHERE
test_sq_dj1.c1 = test_sq_dj2.c1)");
}
+
+ @Test
+ public void test20() {
+ PlanChecker.from(connectContext)
+ .checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE c1 <
(cast('1.2' as decimal(2,1)) * (SELECT sum(c1) FROM test_sq_dj2 WHERE
test_sq_dj1.c1 = test_sq_dj2.c1))");
+ }
+
+ @Test
+ public void test21() {
+ PlanChecker.from(connectContext)
+ .checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE c1 <
(cast('1.2' as decimal(2,1)) * (SELECT sum(c1) FROM test_sq_dj2 WHERE
test_sq_dj1.c1 = test_sq_dj2.c1)) or c1 > 10");
+ }
+
+ @Test
+ public void test22() {
+ PlanChecker.from(connectContext)
+ .checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE c1 !=
(cast('1.2' as decimal(2,1)) * (SELECT sum(c1) FROM test_sq_dj2 WHERE
test_sq_dj1.c1 = test_sq_dj2.c1)) or c1 > 10");
+ }
+
+ @Test
+ public void test23() {
+ PlanChecker.from(connectContext)
+ .checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE c2 in
(select k4 from test_sq_dj2)");
+ }
+
+ @Test
+ public void test24() {
+ PlanChecker.from(connectContext)
+ .checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE c3 in
(select k4 from test_sq_dj2)");
+ }
+
+ @Test
+ public void test25() {
+ PlanChecker.from(connectContext)
+ .checkPlannerResult("select * from test_sq_dj1 where c1 in
(select c1 from test_sq_dj1 where c2 in (select c2 from test_sq_dj2) and c2 >
(select sum(c1) from test_sq_dj2))");
+ }
}
diff --git a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
index b7e57d2613..1073b29d1f 100644
--- a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
+++ b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
@@ -355,3 +355,22 @@
-- !multi_subquery_scalar_and_in_or_scalar_and_exists --
+-- !cast_subquery_in --
+1 2
+1 3
+2 4
+2 5
+3 3
+3 4
+
+-- !cast_subquery_in_with_disconjunct --
+1 2
+1 3
+2 4
+2 5
+3 3
+3 4
+20 2
+22 3
+24 4
+
diff --git
a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
index 1d153405ea..b816f82366 100644
--- a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
+++ b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
@@ -383,4 +383,13 @@ suite ("sub_query_correlated") {
OR k1 < (SELECT sum(k1) FROM
sub_query_correlated_subquery3 WHERE sub_query_correlated_subquery1.k1 =
sub_query_correlated_subquery3.k1))
and exists (SELECT k1 FROM
sub_query_correlated_subquery3 WHERE sub_query_correlated_subquery1.k1 =
sub_query_correlated_subquery3.k1);
"""
+
+ //----------type coercion subquery-----------
+ qt_cast_subquery_in """
+ SELECT * FROM sub_query_correlated_subquery1 WHERE k1 < (cast('1.2' as
decimal(2,1)) * (SELECT sum(k1) FROM sub_query_correlated_subquery3 WHERE
sub_query_correlated_subquery1.k1 = sub_query_correlated_subquery3.k1)) order
by k1, k2;
+ """
+
+ qt_cast_subquery_in_with_disconjunct """
+ SELECT * FROM sub_query_correlated_subquery1 WHERE k1 < (cast('1.2' as
decimal(2,1)) * (SELECT sum(k1) FROM sub_query_correlated_subquery3 WHERE
sub_query_correlated_subquery1.k1 = sub_query_correlated_subquery3.k1)) or k1 >
10 order by k1, k2;
+ """
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]