This is an automated email from the ASF dual-hosted git repository.

duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 64b5c3a8507 Fix subquery expression projection in merge bind ex (#28392)
64b5c3a8507 is described below

commit 64b5c3a85072ba497aebb4f293119d66b7a3314c
Author: Chuxin Chen <[email protected]>
AuthorDate: Fri Sep 8 16:12:41 2023 +0800

    Fix subquery expression projection in merge bind ex (#28392)
    
    * Fix subquery expression projection in merge bind ex
    
    * Fix subquery expression projection in merge bind ex
---
 .../expression/impl/ColumnSegmentBinder.java       | 13 +++++----
 .../binder/statement/MergeStatementBinderTest.java | 34 ++++++++++++++++++++++
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
index dad579cf30f..4502311fa6e 100644
--- 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
+++ 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/expression/impl/ColumnSegmentBinder.java
@@ -130,8 +130,11 @@ public final class ColumnSegmentBinder {
             }
         }
         if (!isFindInputColumn) {
-            result = findInputColumnSegmentFromExternalTables(segment, 
statementBinderContext.getExternalTableBinderContexts()).orElse(null);
-            isFindInputColumn = result != null;
+            Optional<ProjectionSegment> projectionSegment = 
findInputColumnSegmentFromExternalTables(segment, 
statementBinderContext.getExternalTableBinderContexts());
+            isFindInputColumn = projectionSegment.isPresent();
+            if (projectionSegment.isPresent() && projectionSegment.get() 
instanceof ColumnProjectionSegment) {
+                result = ((ColumnProjectionSegment) 
projectionSegment.get()).getColumn();
+            }
         }
         if (!isFindInputColumn) {
             result = findInputColumnSegmentByVariables(segment, 
statementBinderContext.getVariableNames()).orElse(null);
@@ -142,11 +145,11 @@ public final class ColumnSegmentBinder {
         return Optional.ofNullable(result);
     }
     
-    private static Optional<ColumnSegment> 
findInputColumnSegmentFromExternalTables(final ColumnSegment segment, final 
Map<String, TableSegmentBinderContext> externalTableBinderContexts) {
+    private static Optional<ProjectionSegment> 
findInputColumnSegmentFromExternalTables(final ColumnSegment segment, final 
Map<String, TableSegmentBinderContext> externalTableBinderContexts) {
         for (TableSegmentBinderContext each : 
externalTableBinderContexts.values()) {
             ProjectionSegment projectionSegment = 
each.getProjectionSegmentByColumnLabel(segment.getIdentifier().getValue());
-            if (projectionSegment instanceof ColumnProjectionSegment) {
-                return Optional.of(((ColumnProjectionSegment) 
projectionSegment).getColumn());
+            if (null != projectionSegment) {
+                return Optional.of(projectionSegment);
             }
         }
         return Optional.empty();
diff --git 
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
 
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
index 0ca795d4a74..e5671c97ad2 100644
--- 
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
+++ 
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
@@ -27,15 +27,20 @@ import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.Se
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ExpressionProjectionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.AliasSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.MergeStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
 import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleMergeStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleSelectStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleUpdateStatement;
 import org.junit.jupiter.api.Test;
 
@@ -105,4 +110,33 @@ class MergeStatementBinderTest {
         
when(result.getDatabase(DefaultDatabase.LOGIC_NAME).getSchema(DefaultDatabase.LOGIC_NAME).containsTable("t_order_item")).thenReturn(true);
         return result;
     }
+    
+    @Test
+    void assertBindWithSubQuery() {
+        MergeStatement mergeStatement = new OracleMergeStatement();
+        SimpleTableSegment targetTable = new SimpleTableSegment(new 
TableNameSegment(0, 0, new IdentifierValue("t_order")));
+        targetTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("a")));
+        mergeStatement.setTarget(targetTable);
+        ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0);
+        ExpressionProjectionSegment expressionProjectionSegment = new 
ExpressionProjectionSegment(0, 0, "status + 1", new 
BinaryOperationExpression(0, 0,
+                new ColumnSegment(0, 0, new IdentifierValue("status")), new 
LiteralExpressionSegment(0, 0, 1), "+", "status + 1"));
+        expressionProjectionSegment.setAlias(new AliasSegment(0, 0, new 
IdentifierValue("new_status")));
+        projectionsSegment.getProjections().add(expressionProjectionSegment);
+        OracleSelectStatement oracleSelectStatement = new 
OracleSelectStatement();
+        oracleSelectStatement.setProjections(projectionsSegment);
+        oracleSelectStatement.setFrom(new SimpleTableSegment(new 
TableNameSegment(0, 0, new IdentifierValue("t_order_item"))));
+        SubqueryTableSegment subqueryTableSegment = new 
SubqueryTableSegment(new SubquerySegment(0, 0, oracleSelectStatement));
+        subqueryTableSegment.setAlias(new AliasSegment(0, 0, new 
IdentifierValue("b")));
+        mergeStatement.setSource(subqueryTableSegment);
+        UpdateStatement updateStatement = new OracleUpdateStatement();
+        ColumnSegment targetTableColumn = new ColumnSegment(0, 0, new 
IdentifierValue("status"));
+        targetTableColumn.setOwner(new OwnerSegment(0, 0, new 
IdentifierValue("a")));
+        ColumnSegment sourceTableColumn = new ColumnSegment(0, 0, new 
IdentifierValue("new_status"));
+        SetAssignmentSegment setAssignmentSegment = new 
SetAssignmentSegment(0, 0,
+                Collections.singletonList(new ColumnAssignmentSegment(0, 0, 
Collections.singletonList(targetTableColumn), sourceTableColumn)));
+        updateStatement.setSetAssignment(setAssignmentSegment);
+        mergeStatement.setUpdate(updateStatement);
+        MergeStatement actual = new 
MergeStatementBinder().bind(mergeStatement, createMetaData(), 
DefaultDatabase.LOGIC_NAME);
+        assertThat(actual, not(mergeStatement));
+    }
 }

Reply via email to