This is an automated email from the ASF dual-hosted git repository.

zhaojinchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 6edf9651468 Support bind merge statement. (#28280)
6edf9651468 is described below

commit 6edf96514686695316ebd448a73ccc1362651a7e
Author: Chuxin Chen <[email protected]>
AuthorDate: Mon Aug 28 14:39:29 2023 +0800

    Support bind merge statement. (#28280)
---
 .../infra/binder/engine/SQLBindEngine.java         |   5 +
 .../infra/binder/enums/SegmentType.java            |   2 +-
 .../binder/statement/dml/MergeStatementBinder.java | 130 +++++++++++++++++++++
 .../binder/statement/MergeStatementBinderTest.java | 100 ++++++++++++++++
 .../handler/dml/InsertStatementHandler.java        |  26 +++++
 5 files changed, 262 insertions(+), 1 deletion(-)

diff --git 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/SQLBindEngine.java
 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/SQLBindEngine.java
index aefef195a8c..ffb09f1a780 100644
--- 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/SQLBindEngine.java
+++ 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/SQLBindEngine.java
@@ -22,6 +22,7 @@ import 
org.apache.shardingsphere.infra.binder.context.statement.SQLStatementCont
 import 
org.apache.shardingsphere.infra.binder.statement.ddl.CursorStatementBinder;
 import 
org.apache.shardingsphere.infra.binder.statement.dml.DeleteStatementBinder;
 import 
org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementBinder;
+import 
org.apache.shardingsphere.infra.binder.statement.dml.MergeStatementBinder;
 import 
org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementBinder;
 import 
org.apache.shardingsphere.infra.binder.statement.dml.UpdateStatementBinder;
 import org.apache.shardingsphere.infra.hint.HintValueContext;
@@ -33,6 +34,7 @@ import 
org.apache.shardingsphere.sql.parser.sql.common.statement.ddl.DDLStatemen
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DMLStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DeleteStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.MergeStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
 import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.opengauss.ddl.OpenGaussCursorStatement;
@@ -98,6 +100,9 @@ public final class SQLBindEngine {
         if (statement instanceof DeleteStatement) {
             return new DeleteStatementBinder().bind((DeleteStatement) 
statement, metaData, defaultDatabaseName);
         }
+        if (statement instanceof MergeStatement) {
+            return new MergeStatementBinder().bind((MergeStatement) statement, 
metaData, defaultDatabaseName);
+        }
         return statement;
     }
     
diff --git 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/enums/SegmentType.java
 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/enums/SegmentType.java
index 8aaa431dda1..b647a2b7e36 100644
--- 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/enums/SegmentType.java
+++ 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/enums/SegmentType.java
@@ -22,5 +22,5 @@ package org.apache.shardingsphere.infra.binder.enums;
  */
 public enum SegmentType {
     
-    PROJECTION, PREDICATE, JOIN_ON, JOIN_USING, ORDER_BY, GROUP_BY, LOCK
+    PROJECTION, PREDICATE, JOIN_ON, JOIN_USING, ORDER_BY, GROUP_BY, LOCK, 
SET_ASSIGNMENT, VALUES
 }
diff --git 
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/MergeStatementBinder.java
 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/MergeStatementBinder.java
new file mode 100644
index 00000000000..afe5ade07c1
--- /dev/null
+++ 
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/MergeStatementBinder.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.infra.binder.statement.dml;
+
+import lombok.SneakyThrows;
+import org.apache.commons.collections4.map.CaseInsensitiveMap;
+import org.apache.shardingsphere.infra.binder.enums.SegmentType;
+import 
org.apache.shardingsphere.infra.binder.segment.expression.ExpressionSegmentBinder;
+import 
org.apache.shardingsphere.infra.binder.segment.expression.impl.ColumnSegmentBinder;
+import org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinder;
+import 
org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinderContext;
+import org.apache.shardingsphere.infra.binder.segment.where.WhereSegmentBinder;
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementBinder;
+import 
org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
+import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.MergeStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
+import 
org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.UpdateStatementHandler;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Merge statement binder.
+ */
+public final class MergeStatementBinder implements 
SQLStatementBinder<MergeStatement> {
+    
+    @SneakyThrows
+    @Override
+    public MergeStatement bind(final MergeStatement sqlStatement, final 
ShardingSphereMetaData metaData, final String defaultDatabaseName) {
+        MergeStatement result = 
sqlStatement.getClass().getDeclaredConstructor().newInstance();
+        Map<String, TableSegmentBinderContext> tableBinderContexts = new 
CaseInsensitiveMap<>();
+        SQLStatementBinderContext statementBinderContext = new 
SQLStatementBinderContext(metaData, defaultDatabaseName, 
sqlStatement.getDatabaseType());
+        TableSegment boundedTargetTableSegment = 
TableSegmentBinder.bind(sqlStatement.getTarget(), statementBinderContext, 
tableBinderContexts);
+        TableSegment boundedSourceTableSegment = 
TableSegmentBinder.bind(sqlStatement.getSource(), statementBinderContext, 
tableBinderContexts);
+        result.setTarget(boundedTargetTableSegment);
+        result.setSource(boundedSourceTableSegment);
+        result.setExpr(ExpressionSegmentBinder.bind(sqlStatement.getExpr(), 
SegmentType.JOIN_ON, statementBinderContext, tableBinderContexts, 
Collections.emptyMap()));
+        
result.setInsert(Optional.ofNullable(sqlStatement.getInsert()).map(optional -> 
bindMergeInsert(optional,
+                (SimpleTableSegment) boundedTargetTableSegment, 
statementBinderContext, tableBinderContexts)).orElse(null));
+        
result.setUpdate(Optional.ofNullable(sqlStatement.getUpdate()).map(optional -> 
bindMergeUpdate(optional,
+                (SimpleTableSegment) boundedTargetTableSegment, 
statementBinderContext, tableBinderContexts)).orElse(null));
+        
result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments());
+        result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
+        return result;
+    }
+    
+    @SneakyThrows
+    private InsertStatement bindMergeInsert(final InsertStatement 
sqlStatement, final SimpleTableSegment tableSegment, final 
SQLStatementBinderContext statementBinderContext,
+                                            final Map<String, 
TableSegmentBinderContext> tableBinderContexts) {
+        InsertStatement result = 
sqlStatement.getClass().getDeclaredConstructor().newInstance();
+        result.setTable(tableSegment);
+        sqlStatement.getInsertColumns().ifPresent(result::setInsertColumns);
+        sqlStatement.getInsertSelect().ifPresent(result::setInsertSelect);
+        Collection<InsertValuesSegment> insertValues = new LinkedList<>();
+        for (InsertValuesSegment each : sqlStatement.getValues()) {
+            List<ExpressionSegment> values = new LinkedList<>();
+            for (ExpressionSegment value : each.getValues()) {
+                values.add(ExpressionSegmentBinder.bind(value, 
SegmentType.VALUES, statementBinderContext, tableBinderContexts, 
Collections.emptyMap()));
+            }
+            insertValues.add(new InsertValuesSegment(each.getStartIndex(), 
each.getStopIndex(), values));
+        }
+        result.getValues().addAll(insertValues);
+        
InsertStatementHandler.getOnDuplicateKeyColumnsSegment(sqlStatement).ifPresent(optional
 -> InsertStatementHandler.setOnDuplicateKeyColumnsSegment(result, optional));
+        
InsertStatementHandler.getSetAssignmentSegment(sqlStatement).ifPresent(optional 
-> InsertStatementHandler.setSetAssignmentSegment(result, optional));
+        InsertStatementHandler.getWithSegment(sqlStatement).ifPresent(optional 
-> InsertStatementHandler.setWithSegment(result, optional));
+        
InsertStatementHandler.getOutputSegment(sqlStatement).ifPresent(optional -> 
InsertStatementHandler.setOutputSegment(result, optional));
+        
InsertStatementHandler.getInsertMultiTableElementSegment(sqlStatement).ifPresent(optional
 -> InsertStatementHandler.setInsertMultiTableElementSegment(result, optional));
+        
InsertStatementHandler.getReturningSegment(sqlStatement).ifPresent(optional -> 
InsertStatementHandler.setReturningSegment(result, optional));
+        
InsertStatementHandler.getWhereSegment(sqlStatement).ifPresent(optional -> 
InsertStatementHandler.setWhereSegment(result,
+                WhereSegmentBinder.bind(optional, statementBinderContext, 
tableBinderContexts, Collections.emptyMap())));
+        
result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments());
+        result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
+        return result;
+    }
+    
+    @SneakyThrows
+    private UpdateStatement bindMergeUpdate(final UpdateStatement 
sqlStatement, final SimpleTableSegment tableSegment, final 
SQLStatementBinderContext statementBinderContext,
+                                            final Map<String, 
TableSegmentBinderContext> tableBinderContexts) {
+        UpdateStatement result = 
sqlStatement.getClass().getDeclaredConstructor().newInstance();
+        result.setTable(tableSegment);
+        Collection<AssignmentSegment> assignments = new LinkedList<>();
+        for (AssignmentSegment each : 
sqlStatement.getSetAssignment().getAssignments()) {
+            List<ColumnSegment> columnSegments = new 
ArrayList<>(each.getColumns().size());
+            each.getColumns().forEach(column -> 
columnSegments.add(ColumnSegmentBinder.bind(column, SegmentType.SET_ASSIGNMENT, 
statementBinderContext, tableBinderContexts, Collections.emptyMap())));
+            ExpressionSegment value = 
ExpressionSegmentBinder.bind(each.getValue(), SegmentType.SET_ASSIGNMENT, 
statementBinderContext, tableBinderContexts, Collections.emptyMap());
+            ColumnAssignmentSegment columnAssignmentSegment = new 
ColumnAssignmentSegment(each.getStartIndex(), each.getStopIndex(), 
columnSegments, value);
+            assignments.add(columnAssignmentSegment);
+        }
+        SetAssignmentSegment setAssignmentSegment = new 
SetAssignmentSegment(sqlStatement.getSetAssignment().getStartIndex(), 
sqlStatement.getSetAssignment().getStopIndex(), assignments);
+        result.setSetAssignment(setAssignmentSegment);
+        sqlStatement.getWhere().ifPresent(optional -> 
result.setWhere(WhereSegmentBinder.bind(optional, statementBinderContext, 
tableBinderContexts, Collections.emptyMap())));
+        
UpdateStatementHandler.getOrderBySegment(sqlStatement).ifPresent(optional -> 
UpdateStatementHandler.setOrderBySegment(result, optional));
+        
UpdateStatementHandler.getLimitSegment(sqlStatement).ifPresent(optional -> 
UpdateStatementHandler.setLimitSegment(result, optional));
+        UpdateStatementHandler.getWithSegment(sqlStatement).ifPresent(optional 
-> UpdateStatementHandler.setWithSegment(result, optional));
+        
result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments());
+        result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
+        return result;
+    }
+}
diff --git 
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
 
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
new file mode 100644
index 00000000000..7cdf2382e21
--- /dev/null
+++ 
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.infra.binder.statement;
+
+import 
org.apache.shardingsphere.infra.binder.statement.dml.MergeStatementBinder;
+import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
+import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
+import 
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
+import 
org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.AliasSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.MergeStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
+import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleMergeStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.dialect.statement.oracle.dml.OracleUpdateStatement;
+import org.junit.jupiter.api.Test;
+
+import java.sql.Types;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.hamcrest.CoreMatchers.instanceOf;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+class MergeStatementBinderTest {
+    
+    @Test
+    void assertBind() {
+        MergeStatement mergeStatement = new OracleMergeStatement();
+        SimpleTableSegment targetTable = new SimpleTableSegment(new 
TableNameSegment(0, 0, new IdentifierValue("t_order")));
+        targetTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("a")));
+        mergeStatement.setTarget(targetTable);
+        SimpleTableSegment sourceTable = new SimpleTableSegment(new 
TableNameSegment(0, 0, new IdentifierValue("t_order_item")));
+        sourceTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("b")));
+        mergeStatement.setSource(sourceTable);
+        mergeStatement.setExpr(new BinaryOperationExpression(0, 0, new 
ColumnSegment(0, 0, new IdentifierValue("id")),
+                new ColumnSegment(0, 0, new IdentifierValue("order_id")), "=", 
"id = order_id"));
+        UpdateStatement updateStatement = new OracleUpdateStatement();
+        updateStatement.setTable(targetTable);
+        ColumnSegment targetTableColumn = new ColumnSegment(0, 0, new 
IdentifierValue("status"));
+        targetTableColumn.setOwner(new OwnerSegment(0, 0, new 
IdentifierValue("a")));
+        ColumnSegment sourceTableColumn = new ColumnSegment(0, 0, new 
IdentifierValue("status"));
+        sourceTableColumn.setOwner(new OwnerSegment(0, 0, new 
IdentifierValue("b")));
+        SetAssignmentSegment setAssignmentSegment = new 
SetAssignmentSegment(0, 0,
+                Collections.singletonList(new ColumnAssignmentSegment(0, 0, 
Collections.singletonList(targetTableColumn), sourceTableColumn)));
+        updateStatement.setSetAssignment(setAssignmentSegment);
+        mergeStatement.setUpdate(updateStatement);
+        MergeStatement actual = new 
MergeStatementBinder().bind(mergeStatement, createMetaData(), 
DefaultDatabase.LOGIC_NAME);
+        assertThat(actual, not(mergeStatement));
+        assertThat(actual.getSource(), not(mergeStatement.getSource()));
+        assertThat(actual.getSource(), instanceOf(SimpleTableSegment.class));
+        assertThat(actual.getTarget(), not(mergeStatement.getTarget()));
+        assertThat(actual.getTarget(), instanceOf(SimpleTableSegment.class));
+        assertThat(actual.getUpdate(), not(mergeStatement.getUpdate()));
+        
assertThat(actual.getUpdate().getSetAssignment().getAssignments().iterator().next().getValue(),
 instanceOf(ColumnSegment.class));
+        assertThat(((ColumnSegment) 
actual.getUpdate().getSetAssignment().getAssignments().iterator().next().getValue()).getColumnBoundedInfo().getOriginalTable().getValue(),
 is("t_order_item"));
+    }
+    
+    private ShardingSphereMetaData createMetaData() {
+        ShardingSphereSchema schema = mock(ShardingSphereSchema.class, 
RETURNS_DEEP_STUBS);
+        
when(schema.getTable("t_order").getColumnValues()).thenReturn(Arrays.asList(
+                new ShardingSphereColumn("id", Types.INTEGER, true, false, 
false, true, false, false),
+                new ShardingSphereColumn("user_id", Types.INTEGER, false, 
false, false, true, false, false),
+                new ShardingSphereColumn("status", Types.INTEGER, false, 
false, false, true, false, false)));
+        
when(schema.getTable("t_order_item").getColumnValues()).thenReturn(Arrays.asList(
+                new ShardingSphereColumn("item_id", Types.INTEGER, true, 
false, false, true, false, false),
+                new ShardingSphereColumn("order_id", Types.INTEGER, false, 
false, false, true, false, false),
+                new ShardingSphereColumn("status", Types.INTEGER, false, 
false, false, true, false, false)));
+        ShardingSphereMetaData result = mock(ShardingSphereMetaData.class, 
RETURNS_DEEP_STUBS);
+        
when(result.getDatabase(DefaultDatabase.LOGIC_NAME).getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(schema);
+        return result;
+    }
+}
diff --git 
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/InsertStatementHandler.java
 
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/InsertStatementHandler.java
index 4773924c474..215bf7ecdd1 100644
--- 
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/InsertStatementHandler.java
+++ 
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/dialect/handler/dml/InsertStatementHandler.java
@@ -22,6 +22,7 @@ import lombok.NoArgsConstructor;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.ReturningSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.OnDuplicateKeyColumnsSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.InsertMultiTableElementSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OutputSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.WithSegment;
@@ -217,4 +218,29 @@ public final class InsertStatementHandler implements 
SQLStatementHandler {
             ((OpenGaussInsertStatement) 
insertStatement).setReturningSegment(returningSegment);
         }
     }
+    
+    /**
+     * Get where segment.
+     *
+     * @param insertStatement insert statement
+     * @return where segment
+     */
+    public static Optional<WhereSegment> getWhereSegment(final InsertStatement 
insertStatement) {
+        if (insertStatement instanceof OracleInsertStatement) {
+            return ((OracleInsertStatement) insertStatement).getWhere();
+        }
+        return Optional.empty();
+    }
+    
+    /**
+     * Set where segment.
+     * 
+     * @param insertStatement insert statement
+     * @param whereSegment where segment
+     */
+    public static void setWhereSegment(final InsertStatement insertStatement, 
final WhereSegment whereSegment) {
+        if (insertStatement instanceof OracleInsertStatement) {
+            ((OracleInsertStatement) insertStatement).setWhere(whereSegment);
+        }
+    }
 }

Reply via email to