This is an automated email from the ASF dual-hosted git repository.

duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push: new 88c71e3c24b Add extractAllExpressions method to replace extractAndPredicates (#35404) 88c71e3c24b is described below commit 88c71e3c24bc60d3e8521bb0dde04137cff3f32a Author: ZhangCheng <chengzh...@apache.org> AuthorDate: Wed May 14 18:40:15 2025 +0800 Add extractAllExpressions method to replace extractAndPredicates (#35404) * Add extractAllExpressions method to replace extractAndPredicates * Add extractAllExpressions method to replace extractAndPredicates * Add extractAllExpressions method to replace extractAndPredicates --- .../EncryptPredicateColumnSupportedChecker.java | 13 ++--- .../rewrite/condition/EncryptConditionEngine.java | 7 +-- .../EncryptPredicateColumnTokenGenerator.java | 23 ++++---- .../statement/core/extractor/ColumnExtractor.java | 11 +--- .../core/extractor/ExpressionExtractor.java | 29 +++++++++ .../core/extractor/ExpressionExtractorTest.java | 68 ++++++++++++++++++++++ 6 files changed, 118 insertions(+), 33 deletions(-) diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedChecker.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedChecker.java index dfbd41f33d9..cfab7b75ae2 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedChecker.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/checker/sql/predicate/EncryptPredicateColumnSupportedChecker.java @@ -34,7 +34,6 @@ import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionE import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression; import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment; -import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment; import java.util.Collection; @@ -73,18 +72,16 @@ public final class EncryptPredicateColumnSupportedChecker implements SupportedSQ private boolean includesLike(final Collection<WhereSegment> whereSegments, final ColumnSegment targetColumnSegment) { for (WhereSegment each : whereSegments) { - Collection<AndPredicate> andPredicates = ExpressionExtractor.extractAndPredicates(each.getExpr()); - for (AndPredicate andPredicate : andPredicates) { - if (isLikeColumnSegment(andPredicate, targetColumnSegment)) { - return true; - } + Collection<ExpressionSegment> expressions = ExpressionExtractor.extractAllExpressions(each.getExpr()); + if (isLikeColumnSegment(expressions, targetColumnSegment)) { + return true; } } return false; } - private boolean isLikeColumnSegment(final AndPredicate andPredicate, final ColumnSegment targetColumnSegment) { - for (ExpressionSegment each : andPredicate.getPredicates()) { + private boolean isLikeColumnSegment(final Collection<ExpressionSegment> expressions, final ColumnSegment targetColumnSegment) { + for (ExpressionSegment each : expressions) { if (each instanceof BinaryOperationExpression && "LIKE".equalsIgnoreCase(((BinaryOperationExpression) each).getOperator()) && isSameColumnSegment(((BinaryOperationExpression) each).getLeft(), targetColumnSegment)) { return true; diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java index 3266573ed58..b0ae7ef1f53 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java 
+++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/condition/EncryptConditionEngine.java @@ -37,7 +37,6 @@ import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.List import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.LiteralExpressionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.SimpleExpressionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubqueryExpressionSegment; -import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment; import java.util.Collection; @@ -84,10 +83,8 @@ public final class EncryptConditionEngine { public Collection<EncryptCondition> createEncryptConditions(final Collection<WhereSegment> whereSegments) { Collection<EncryptCondition> result = new LinkedList<>(); for (WhereSegment each : whereSegments) { - Collection<AndPredicate> andPredicates = ExpressionExtractor.extractAndPredicates(each.getExpr()); - for (AndPredicate predicate : andPredicates) { - addEncryptConditions(result, predicate.getPredicates()); - } + Collection<ExpressionSegment> expressions = ExpressionExtractor.extractAllExpressions(each.getExpr()); + addEncryptConditions(result, expressions); } return result; } diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java index 99fc28bd12f..a40b88e5ea8 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java +++ 
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/predicate/EncryptPredicateColumnTokenGenerator.java @@ -44,7 +44,6 @@ import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionE import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment; -import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment; import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue; @@ -72,16 +71,14 @@ public final class EncryptPredicateColumnTokenGenerator implements CollectionSQL public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) { Collection<SelectStatementContext> allSubqueryContexts = SQLStatementContextExtractor.getAllSubqueryContexts(sqlStatementContext); Collection<WhereSegment> whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts); - Collection<AndPredicate> andPredicates = getAndPredicates(whereSegments); - return generateSQLTokens(andPredicates, sqlStatementContext); + Collection<ExpressionSegment> expressions = getAllExpressions(whereSegments); + return generateSQLTokens(expressions, sqlStatementContext); } - private Collection<SQLToken> generateSQLTokens(final Collection<AndPredicate> andPredicates, final SQLStatementContext sqlStatementContext) { + private Collection<SQLToken> generateSQLTokens(final Collection<ExpressionSegment> expressions, final SQLStatementContext sqlStatementContext) { Collection<SQLToken> result = new LinkedList<>(); - for (AndPredicate each : andPredicates) { - for (ExpressionSegment 
expression : each.getPredicates()) { - result.addAll(generateSQLTokens(sqlStatementContext, expression)); - } + for (ExpressionSegment each : expressions) { + result.addAll(generateSQLTokens(sqlStatementContext, each)); } return result; } @@ -98,10 +95,14 @@ public final class EncryptPredicateColumnTokenGenerator implements CollectionSQL return result; } - private Collection<AndPredicate> getAndPredicates(final Collection<WhereSegment> whereSegments) { - Collection<AndPredicate> result = new LinkedList<>(); + private Collection<ExpressionSegment> getAllExpressions(final Collection<WhereSegment> whereSegments) { + if (1 == whereSegments.size()) { + return ExpressionExtractor.extractAllExpressions(whereSegments.iterator().next().getExpr()); + } + Collection<ExpressionSegment> result = new LinkedList<>(); for (WhereSegment each : whereSegments) { - result.addAll(ExpressionExtractor.extractAndPredicates(each.getExpr())); + Collection<ExpressionSegment> expressions = ExpressionExtractor.extractAllExpressions(each.getExpr()); + result.addAll(expressions); } return result; } diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ColumnExtractor.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ColumnExtractor.java index 3ac47c3c414..d0ba99974ff 100644 --- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ColumnExtractor.java +++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ColumnExtractor.java @@ -40,7 +40,6 @@ import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.order.Ord import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.order.item.ColumnOrderByItemSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.order.item.ExpressionOrderByItemSegment; import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.order.item.OrderByItemSegment; -import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.HavingSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.CollectionTableSegment; @@ -165,18 +164,12 @@ public final class ColumnExtractor { */ public static void extractColumnSegments(final Collection<ColumnSegment> columnSegments, final Collection<WhereSegment> whereSegments) { for (WhereSegment each : whereSegments) { - for (AndPredicate andPredicate : ExpressionExtractor.extractAndPredicates(each.getExpr())) { - extractColumnSegments(columnSegments, andPredicate); + for (ExpressionSegment expression : ExpressionExtractor.extractAllExpressions(each.getExpr())) { + columnSegments.addAll(extract(expression)); } } } - private static void extractColumnSegments(final Collection<ColumnSegment> columnSegments, final AndPredicate andPredicate) { - for (ExpressionSegment each : andPredicate.getPredicates()) { - columnSegments.addAll(extract(each)); - } - } - /** * Extract column segments. 
* diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ExpressionExtractor.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ExpressionExtractor.java index 56f52dd2e69..d57e9f87a99 100644 --- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ExpressionExtractor.java +++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ExpressionExtractor.java @@ -48,6 +48,7 @@ import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.match import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.Deque; import java.util.LinkedList; import java.util.List; import java.util.Optional; @@ -105,6 +106,34 @@ public final class ExpressionExtractor { return result; } + /** + * Extract all expressions. + * + * @param expression to be extracted expression segment + * @return all expressions + */ + public static Collection<ExpressionSegment> extractAllExpressions(final ExpressionSegment expression) { + Collection<ExpressionSegment> result = new LinkedList<>(); + Deque<ExpressionSegment> stack = new LinkedList<>(); + stack.push(expression); + while (!stack.isEmpty()) { + ExpressionSegment expressionSegment = stack.pop(); + if (expressionSegment instanceof BinaryOperationExpression) { + BinaryOperationExpression binaryExpression = (BinaryOperationExpression) expressionSegment; + Optional<LogicalOperator> logicalOperator = LogicalOperator.valueFrom(binaryExpression.getOperator()); + if (logicalOperator.isPresent() && (LogicalOperator.OR == logicalOperator.get() || LogicalOperator.AND == logicalOperator.get())) { + stack.push(binaryExpression.getRight()); + stack.push(binaryExpression.getLeft()); + } else { + result.add(expressionSegment); + } + } else { + result.add(expressionSegment); + } + } + return 
result; + } + /** * Get parameter marker expressions. * diff --git a/parser/sql/statement/core/src/test/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ExpressionExtractorTest.java b/parser/sql/statement/core/src/test/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ExpressionExtractorTest.java index 48dd5802c2a..74b951a437c 100644 --- a/parser/sql/statement/core/src/test/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ExpressionExtractorTest.java +++ b/parser/sql/statement/core/src/test/java/org/apache/shardingsphere/sql/parser/statement/core/extractor/ExpressionExtractorTest.java @@ -110,6 +110,74 @@ class ExpressionExtractorTest { assertThat(andPredicate2.getPredicates().size(), is(2)); } + @Test + void assertExtractAllExpressionsWithAndOperation() { + BinaryOperationExpression expression1 = + new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("order_id")), new LiteralExpressionSegment(10, 11, "1"), "=", "order_id=1"); + BinaryOperationExpression expression2 = new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("status")), new LiteralExpressionSegment(22, 23, "2"), "=", "status=2"); + BinaryOperationExpression andExpression = new BinaryOperationExpression(0, 0, expression1, expression2, "AND", "order_id=1 AND status=2"); + Collection<ExpressionSegment> actual = ExpressionExtractor.extractAllExpressions(andExpression); + assertThat(actual.size(), is(2)); + Iterator<ExpressionSegment> iterator = actual.iterator(); + assertThat(iterator.next(), is(expression1)); + assertThat(iterator.next(), is(expression2)); + } + + @Test + void assertExtractAllExpressionsWithOrOperation() { + BinaryOperationExpression expression1 = new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("age")), + new LiteralExpressionSegment(8, 9, "5"), ">", "age>5"); + BinaryOperationExpression expression2 = new BinaryOperationExpression(0, 0, new 
ColumnSegment(0, 0, new IdentifierValue("score")), + new LiteralExpressionSegment(19, 21, "10"), "<", "score<10"); + BinaryOperationExpression orExpression = new BinaryOperationExpression(0, 0, expression1, expression2, "OR", "age>5 OR score<10"); + Collection<ExpressionSegment> actual = ExpressionExtractor.extractAllExpressions(orExpression); + assertThat(actual.size(), is(2)); + Iterator<ExpressionSegment> iterator = actual.iterator(); + assertThat(iterator.next(), is(expression1)); + assertThat(iterator.next(), is(expression2)); + } + + @Test + void assertExtractAllExpressionsWithNestedOperations() { + BinaryOperationExpression expression1 = new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("A")), + new LiteralExpressionSegment(0, 0, "1"), "=", "A=1"); + BinaryOperationExpression expression2 = new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("B")), + new LiteralExpressionSegment(0, 0, "2"), "=", "B=2"); + BinaryOperationExpression andExpression = new BinaryOperationExpression(0, 0, expression1, expression2, "AND", "A=1 AND B=2"); + BinaryOperationExpression expression3 = new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0, new IdentifierValue("C")), + new LiteralExpressionSegment(0, 0, "3"), "=", "C=3"); + BinaryOperationExpression fullExpression = new BinaryOperationExpression(0, 0, andExpression, expression3, "OR", "A=1 AND B=2 OR C=3"); + Collection<ExpressionSegment> actual = ExpressionExtractor.extractAllExpressions(fullExpression); + assertThat(actual.size(), is(3)); + Iterator<ExpressionSegment> iterator = actual.iterator(); + assertThat(iterator.next(), is(expression1)); + assertThat(iterator.next(), is(expression2)); + assertThat(iterator.next(), is(expression3)); + } + + @Test + void assertExtractAllExpressionsWithComplexNestedExpressions() { + BinaryOperationExpression expression1 = new BinaryOperationExpression(0, 0, new ColumnSegment(21, 22, new IdentifierValue("A")), + new 
LiteralExpressionSegment(0, 0, "1"), "=", "A=1"); + BinaryOperationExpression expression2 = new BinaryOperationExpression(0, 0, new ColumnSegment(14, 15, new IdentifierValue("B")), + new LiteralExpressionSegment(0, 0, "2"), "=", "B=2"); + BinaryOperationExpression expression3 = new BinaryOperationExpression(0, 0, new ColumnSegment(7, 8, new IdentifierValue("C")), + new LiteralExpressionSegment(0, 0, "3"), "=", "C=3"); + BinaryOperationExpression expression4 = new BinaryOperationExpression(0, 0, + new ColumnSegment(0, 0, new IdentifierValue("D")), + new LiteralExpressionSegment(0, 0, "4"), "=", "D=4"); + BinaryOperationExpression orExpression = new BinaryOperationExpression(0, 0, expression2, expression3, "OR", "B=2 OR C=3"); + BinaryOperationExpression innerAndExpression = new BinaryOperationExpression(0, 0, expression1, orExpression, "AND", "A=1 AND (B=2 OR C=3)"); + BinaryOperationExpression topLevelExpression = new BinaryOperationExpression(0, 0, innerAndExpression, expression4, "AND", "(A=1 AND (B=2 OR C=3)) AND D=4"); + Collection<ExpressionSegment> actual = ExpressionExtractor.extractAllExpressions(topLevelExpression); + assertThat(actual.size(), is(4)); + Iterator<ExpressionSegment> iterator = actual.iterator(); + assertThat(iterator.next(), is(expression1)); + assertThat(iterator.next(), is(expression2)); + assertThat(iterator.next(), is(expression3)); + assertThat(iterator.next(), is(expression4)); + } + @Test void assertExtractGetParameterMarkerExpressions() { FunctionSegment functionSegment = new FunctionSegment(0, 0, "IF", "IF(number + 1 <= ?, 1, -1)");