This is an automated email from the ASF dual-hosted git repository. duanzhengqiang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push: new a1ea7ffac94 Support Oracle PREDICTION_COST function SQL parse (#34705) a1ea7ffac94 is described below commit a1ea7ffac943060309f676bf849550dc4d919675 Author: ZhangCheng <chengzh...@apache.org> AuthorDate: Wed Feb 19 09:10:16 2025 +0800 Support Oracle PREDICTION_COST function SQL parse (#34705) * Support Oracle PREDICTION_COST function SQL parse * Support Oracle PREDICTION_COST function SQL parse * Support Oracle PREDICTION_COST function SQL parse --- .../src/main/antlr4/imports/oracle/BaseRule.g4 | 10 +++++--- .../visitor/statement/OracleStatementVisitor.java | 13 ++++++++-- .../resources/case/dml/select-special-function.xml | 29 ++++++++++++++++++++++ .../sql/supported/dml/select-special-function.xml | 1 + 4 files changed, 48 insertions(+), 5 deletions(-) diff --git a/parser/sql/dialect/oracle/src/main/antlr4/imports/oracle/BaseRule.g4 b/parser/sql/dialect/oracle/src/main/antlr4/imports/oracle/BaseRule.g4 index 2bde63a09ee..29fb3f7fe13 100644 --- a/parser/sql/dialect/oracle/src/main/antlr4/imports/oracle/BaseRule.g4 +++ b/parser/sql/dialect/oracle/src/main/antlr4/imports/oracle/BaseRule.g4 @@ -796,7 +796,11 @@ leadLagInfo specialFunction : castFunction | charFunction | extractFunction | formatFunction | firstOrLastValueFunction | trimFunction | featureFunction - | setFunction | translateFunction | cursorFunction | toDateFunction | approxRank | wmConcatFunction + | setFunction | translateFunction | cursorFunction | toDateFunction | approxRank | wmConcatFunction | predictionCostFunction + ; + +predictionCostFunction + : PREDICTION_COST LP_ (schemaName DOT_)? modelName (COMMA_ classExpr=expr)? costMatrixClause? miningAttributeClause? RP_ ; wmConcatFunction @@ -825,7 +829,7 @@ setFunction featureFunction : featureFunctionName LP_ (schemaName DOT_)? modelName (COMMA_ featureId)? (COMMA_ numberLiterals (COMMA_ numberLiterals)?)? - (DESC | ASC | ABS)? cost_matrix_clause? miningAttributeClause (AND miningAttributeClause)? RP_ + (DESC | ASC | ABS)? costMatrixClause? miningAttributeClause (AND miningAttributeClause)? RP_ ; featureFunctionName @@ -833,7 +837,7 @@ featureFunctionName | PREDICTION_PROBABILITY | PREDICTION_SET | PREDICTION_BOUNDS | PREDICTION | PREDICTION_DETAILS ; -cost_matrix_clause +costMatrixClause : COST (MODEL (AUTO)?)? | LP_ literals RP_ (COMMA_ LP_ literals RP_)* VALUES LP_ LP_ literals (COMMA_ literals)* RP_ (COMMA_ LP_ literals (COMMA_ literals)* RP_) RP_ ; diff --git a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/OracleStatementVisitor.java b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/OracleStatementVisitor.java index 8bbad56b215..3d99013c1bc 100644 --- a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/OracleStatementVisitor.java +++ b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/OracleStatementVisitor.java @@ -69,6 +69,7 @@ import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.OwnerC import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.PackageNameContext; import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ParameterMarkerContext; import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.PredicateContext; +import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.PredictionCostFunctionContext; import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.PrivateExprOfDbContext; import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.RegularFunctionContext; import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.SchemaNameContext; @@ -1044,9 +1045,17 @@ public abstract class OracleStatementVisitor extends OracleStatementBaseVisitor< if (null != ctx.wmConcatFunction()) { return visit(ctx.wmConcatFunction()); } + if (null != ctx.predictionCostFunction()) { + return visit(ctx.predictionCostFunction()); + } throw new IllegalStateException( - "SpecialFunctionContext must have castFunction, charFunction, extractFunction, formatFunction, firstOrLastValueFunction, trimFunction, toDateFunction, approxCount" - + " or featureFunction."); + "SpecialFunctionContext must have castFunction, charFunction, extractFunction, formatFunction, firstOrLastValueFunction, " + + "trimFunction, toDateFunction, approxCount, predictionCostFunction or featureFunction."); + } + + @Override + public ASTNode visitPredictionCostFunction(final PredictionCostFunctionContext ctx) { + return new FunctionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), ctx.PREDICTION_COST().getText(), getOriginalText(ctx)); } @Override diff --git a/test/it/parser/src/main/resources/case/dml/select-special-function.xml b/test/it/parser/src/main/resources/case/dml/select-special-function.xml index 2929afa1e1a..df0a2afee96 100644 --- a/test/it/parser/src/main/resources/case/dml/select-special-function.xml +++ b/test/it/parser/src/main/resources/case/dml/select-special-function.xml @@ -4881,4 +4881,33 @@ </expression-projection> </projections> </select> + + <select sql-case-id="select_prediction_cost"> + <projections distinct-row="false" start-index="7" stop-index="13"> + <column-projection name="cust_id" start-delimiter="" end-delimiter="" start-index="7" stop-index="13"/> + </projections> + <from start-delimiter="" end-delimiter="" start-index="20" stop-index="38"> + <simple-table name="mining_data_apply_v" start-delimiter="" end-delimiter="" start-index="20" stop-index="38"/> + </from> + <where start-index="40" stop-index="67"> + <expr start-index="46" stop-index="67"> + <binary-operation-expression start-index="46" stop-index="67"> + <left start-index="46" stop-index="57"> + <column name="country_name" start-delimiter="" end-delimiter="" start-index="46" stop-index="57"/> + </left> + <operator>=</operator> + <right start-index="61" stop-index="67"> + <literal-expression value="Italy" start-index="61" stop-index="67"/> + </right> + </binary-operation-expression> + </expr> + </where> + <order-by start-index="69" stop-index="133"> + <expression-item expression="PREDICTION_COST(DT_SH_Clas_sample, 1 COST MODEL USING *)" order-direction="ASC" start-delimiter="" end-delimiter="" start-index="78" stop-index="133"> + <expr start-index="78" stop-index="133"> + <function function-name="PREDICTION_COST" text="PREDICTION_COST(DT_SH_Clas_sample, 1 COST MODEL USING *)" start-index="78" stop-index="133"/> + </expr> + </expression-item> + </order-by> + </select> </sql-parser-test-cases> diff --git a/test/it/parser/src/main/resources/sql/supported/dml/select-special-function.xml b/test/it/parser/src/main/resources/sql/supported/dml/select-special-function.xml index d8349168daa..8b31cc5dae5 100644 --- a/test/it/parser/src/main/resources/sql/supported/dml/select-special-function.xml +++ b/test/it/parser/src/main/resources/sql/supported/dml/select-special-function.xml @@ -282,4 +282,5 @@ <sql-case id="select_lower_function" value="SELECT LOWER('QUADRATICALLY')" db-types="MySQL" /> <sql-case id="select_length" value="SELECT LENGTH('TEXT')" db-types="MySQL" /> <sql-case id="select_locate" value="SELECT LOCATE('bar','foobarbar')" db-types="MySQL" /> + <sql-case id="select_prediction_cost" value="SELECT cust_id FROM mining_data_apply_v WHERE country_name = 'Italy' ORDER BY PREDICTION_COST(DT_SH_Clas_sample, 1 COST MODEL USING *)" db-types="Oracle" /> </sql-cases>