This is an automated email from the ASF dual-hosted git repository.
englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 62020018bcb [opt](nereids) support extract join multiple tables
(#51569)
62020018bcb is described below
commit 62020018bcbc6e01ac076fbff6ca1e0828c603ca
Author: yujun <[email protected]>
AuthorDate: Wed Jun 11 22:12:55 2025 +0800
[opt](nereids) support extract join multiple tables (#51569)
### What problem does this PR solve?
The rule ExtractSingleTableExpressionFromDisjunction extracts, for both
LogicalFilter and LogicalJoin, an expression for every single table. But
this is not enough: for a join, it should also extract multi-table expressions:
one expression over the left tables, and one expression over the right tables.
---
...xtractSingleTableExpressionFromDisjunction.java | 45 ++++++++----
.../nereids/trees/expressions/NamedExpression.java | 4 ++
...ctSingleTableExpressionFromDisjunctionTest.java | 81 ++++++++++++++++++++++
3 files changed, 115 insertions(+), 15 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
index c3c64c9e076..d56d60bc5e3 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java
@@ -65,7 +65,8 @@ public class ExtractSingleTableExpressionFromDisjunction
implements RewriteRuleF
public List<Rule> buildRules() {
return ImmutableList.of(
logicalFilter().then(filter -> {
- List<Expression> dependentPredicates =
extractDependentConjuncts(filter.getConjuncts());
+ List<Expression> dependentPredicates =
extractDependentConjuncts(filter.getConjuncts(),
+ Lists.newArrayList());
if (dependentPredicates.isEmpty()) {
return null;
}
@@ -78,8 +79,14 @@ public class ExtractSingleTableExpressionFromDisjunction
implements RewriteRuleF
return new LogicalFilter<>(newPredicates, filter.child());
}).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION),
logicalJoin().when(join ->
ALLOW_JOIN_TYPE.contains(join.getJoinType())).then(join -> {
+ List<Set<String>> qualifierBatches = Lists.newArrayList();
+ // for join, also extract multiple tables: left tables and
right tables
+
qualifierBatches.add(join.left().getOutputSet().stream().map(Slot::getJoinQualifier)
+ .collect(Collectors.toSet()));
+
qualifierBatches.add(join.right().getOutputSet().stream().map(Slot::getJoinQualifier)
+ .collect(Collectors.toSet()));
List<Expression> dependentOtherPredicates =
extractDependentConjuncts(
- ImmutableSet.copyOf(join.getOtherJoinConjuncts()));
+ ImmutableSet.copyOf(join.getOtherJoinConjuncts()),
qualifierBatches);
if (dependentOtherPredicates.isEmpty()) {
return null;
}
@@ -95,7 +102,7 @@ public class ExtractSingleTableExpressionFromDisjunction
implements RewriteRuleF
}).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION));
}
- private List<Expression> extractDependentConjuncts(Set<Expression>
conjuncts) {
+ private List<Expression> extractDependentConjuncts(Set<Expression>
conjuncts, List<Set<String>> qualifierBatches) {
List<Expression> dependentPredicates = Lists.newArrayList();
for (Expression conjunct : conjuncts) {
// conjunct=(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY')
@@ -112,11 +119,25 @@ public class ExtractSingleTableExpressionFromDisjunction
implements RewriteRuleF
Set<String> qualifiers = disjuncts.get(0).getInputSlots().stream()
.map(slot -> String.join(".", slot.getQualifier()))
.collect(Collectors.toCollection(Sets::newLinkedHashSet));
+ List<Set<String>> includeSingleQualifierBatches =
Lists.newArrayListWithExpectedSize(
+ qualifiers.size() + qualifierBatches.size());
+ // extract single table's expression
for (String qualifier : qualifiers) {
+ includeSingleQualifierBatches.add(ImmutableSet.of(qualifier));
+ }
+ // for join, extract left tables and right tables
+ for (Set<String> batch : qualifierBatches) {
+ Set<String> newBatch =
batch.stream().filter(qualifiers::contains).collect(Collectors.toSet());
+ // if newBatch.size == 1, then it had put into
includeSingleQualifierBatches
+ if (newBatch.size() > 1) {
+ includeSingleQualifierBatches.add(newBatch);
+ }
+ }
+ for (Set<String> batch : includeSingleQualifierBatches) {
List<Expression> extractForAll = Lists.newArrayList();
boolean success = true;
for (Expression expr : disjuncts) {
- Optional<Expression> extracted =
extractSingleTableExpression(expr, qualifier);
+ Optional<Expression> extracted =
extractMultipleTableExpression(expr, batch);
if (!extracted.isPresent()) {
// extract failed
success = false;
@@ -136,7 +157,7 @@ public class ExtractSingleTableExpressionFromDisjunction
implements RewriteRuleF
// extract some conjucts from expr, all slots of the extracted conjunct
comes from the table referred by qualifier.
// example: expr=(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY'),
qualifier="n1."
// output: n1.n_name = 'FRANCE'
- private Optional<Expression> extractSingleTableExpression(Expression expr,
String qualifier) {
+ private Optional<Expression> extractMultipleTableExpression(Expression
expr, Set<String> qualifiers) {
// suppose the qualifier is table T, then the process steps are as
follow:
// 1. split the expression into conjunctions: c1 and c2 and c3 and ...
// 2. for each conjunction ci, suppose its extract is Ei:
@@ -158,14 +179,14 @@ public class ExtractSingleTableExpressionFromDisjunction
implements RewriteRuleF
List<Expression> output = Lists.newArrayList();
List<Expression> conjuncts = ExpressionUtils.extractConjunction(expr);
for (Expression conjunct : conjuncts) {
- if (isSingleTableExpression(conjunct, qualifier)) {
+ if (isTableExpression(conjunct, qualifiers)) {
output.add(conjunct);
} else if (conjunct instanceof Or) {
List<Expression> disjuncts =
ExpressionUtils.extractDisjunction(conjunct);
List<Expression> extracted =
Lists.newArrayListWithExpectedSize(disjuncts.size());
boolean success = true;
for (Expression disjunct : disjuncts) {
- Optional<Expression> extractedDisjunct =
extractSingleTableExpression(disjunct, qualifier);
+ Optional<Expression> extractedDisjunct =
extractMultipleTableExpression(disjunct, qualifiers);
if (extractedDisjunct.isPresent()) {
extracted.addAll(ExpressionUtils.extractDisjunction(extractedDisjunct.get()));
} else {
@@ -186,14 +207,8 @@ public class ExtractSingleTableExpressionFromDisjunction
implements RewriteRuleF
}
}
- private boolean isSingleTableExpression(Expression expr, String qualifier)
{
+ private boolean isTableExpression(Expression expr, Set<String> qualifiers)
{
//TODO: cache getSlotQualifierAsString() result.
- for (Slot slot : expr.getInputSlots()) {
- String slotQualifier = String.join(".", slot.getQualifier());
- if (!slotQualifier.equals(qualifier)) {
- return false;
- }
- }
- return true;
+ return expr.getInputSlots().stream().allMatch(slot ->
qualifiers.contains(slot.getJoinQualifier()));
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NamedExpression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NamedExpression.java
index d03669234cd..e363296fd72 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NamedExpression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NamedExpression.java
@@ -48,6 +48,10 @@ public abstract class NamedExpression extends Expression {
throw new UnboundException("qualifier");
}
+ public String getJoinQualifier() {
+ return String.join(".", getQualifier());
+ }
+
/**
* Get qualified name of NamedExpression.
*
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
index 27901e2db9f..d099315ed52 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java
@@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.rewrite;
+import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
@@ -24,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
@@ -45,22 +47,29 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import java.util.Set;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class ExtractSingleTableExpressionFromDisjunctionTest implements
MemoPatternMatchSupported {
Plan student;
+ Plan score;
Plan course;
+ Plan salary;
SlotReference courseCid;
SlotReference courseName;
SlotReference studentAge;
SlotReference studentGender;
+ SlotReference scoreSid;
+ SlotReference salaryId;
@BeforeAll
public final void beforeAll() {
student = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
PlanConstructor.student, ImmutableList.of(""));
+ score = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
PlanConstructor.score, ImmutableList.of(""));
course = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
PlanConstructor.course, ImmutableList.of(""));
+ salary = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
PlanConstructor.salary, ImmutableList.of(""));
//select *
//from student join course
//where (course.cid=1 and student.age=10) or (student.gender = 0 and
course.name='abc')
@@ -68,6 +77,8 @@ public class ExtractSingleTableExpressionFromDisjunctionTest
implements MemoPatt
courseName = (SlotReference) course.getOutput().get(1);
studentAge = (SlotReference) student.getOutput().get(3);
studentGender = (SlotReference) student.getOutput().get(1);
+ scoreSid = (SlotReference) score.getOutput().get(0);
+ salaryId = (SlotReference) salary.getOutput().get(0);
}
/**
*(cid=1 and sage=10) or (sgender=1 and cname='abc')
@@ -265,4 +276,74 @@ public class
ExtractSingleTableExpressionFromDisjunctionTest implements MemoPatt
return conjuncts.size() == 3 && conjuncts.contains(or1) &&
conjuncts.contains(or2);
}
+
+ @Test
+ void testExtractMultipleTables() {
+ Expression expr = new Or(
+ ExpressionUtils.and(
+ new GreaterThan(studentAge, new IntegerLiteral(1)),
+ new GreaterThan(courseCid, new IntegerLiteral(1)),
+ new GreaterThan(scoreSid, new IntegerLiteral(1)),
+ new GreaterThan(salaryId, new IntegerLiteral(1)),
+ new EqualTo(new Add(studentAge, courseCid), new
BigIntLiteral(100L)),
+ new EqualTo(new Add(scoreSid, salaryId), new
BigIntLiteral(100L))
+ ),
+ ExpressionUtils.and(
+ new GreaterThan(studentAge, new IntegerLiteral(2)),
+ new GreaterThan(courseCid, new IntegerLiteral(2)),
+ new GreaterThan(scoreSid, new IntegerLiteral(2)),
+ new GreaterThan(salaryId, new IntegerLiteral(2)),
+ new EqualTo(new Add(studentAge, courseCid), new
BigIntLiteral(200L)),
+ new EqualTo(new Add(scoreSid, salaryId), new
BigIntLiteral(200L))
+ )
+ );
+ Plan left = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course,
null);
+ Plan right = new LogicalJoin<>(JoinType.CROSS_JOIN, score, salary,
null);
+ Plan root = new LogicalJoin<>(JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION,
+ Collections.singletonList(expr), left, right, null);
+
+ List<Expression> expectJoinConjuncts = Arrays.asList(
+ // origin expression
+ expr,
+
+ // four single table expression
+ new Or(new GreaterThan(studentAge, new IntegerLiteral(1)),
+ new GreaterThan(studentAge, new IntegerLiteral(2))),
+ new Or(new GreaterThan(courseCid, new IntegerLiteral(1)),
+ new GreaterThan(courseCid, new IntegerLiteral(2))),
+ new Or(new GreaterThan(scoreSid, new IntegerLiteral(1)),
+ new GreaterThan(scoreSid, new IntegerLiteral(2))),
+ new Or(new GreaterThan(salaryId, new IntegerLiteral(1)),
+ new GreaterThan(salaryId, new IntegerLiteral(2))),
+
+ // left tables
+ new Or(
+ ExpressionUtils.and(
+ new GreaterThan(studentAge, new
IntegerLiteral(1)),
+ new GreaterThan(courseCid, new
IntegerLiteral(1)),
+ new EqualTo(new Add(studentAge, courseCid),
new BigIntLiteral(100L))),
+ ExpressionUtils.and(
+ new GreaterThan(studentAge, new
IntegerLiteral(2)),
+ new GreaterThan(courseCid, new
IntegerLiteral(2)),
+ new EqualTo(new Add(studentAge, courseCid),
new BigIntLiteral(200L)))),
+
+ // right tables
+ new Or(
+ ExpressionUtils.and(
+ new GreaterThan(scoreSid, new
IntegerLiteral(1)),
+ new GreaterThan(salaryId, new
IntegerLiteral(1)),
+ new EqualTo(new Add(scoreSid, salaryId), new
BigIntLiteral(100L))),
+ ExpressionUtils.and(
+ new GreaterThan(scoreSid, new
IntegerLiteral(2)),
+ new GreaterThan(salaryId, new
IntegerLiteral(2)),
+ new EqualTo(new Add(scoreSid, salaryId), new
BigIntLiteral(200L))))
+ );
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new
ExtractSingleTableExpressionFromDisjunction())
+ .matchesFromRoot(
+ logicalJoin()
+ .when(join ->
expectJoinConjuncts.equals(join.getOtherJoinConjuncts()))
+ );
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]