This is an automated email from the ASF dual-hosted git repository.

gavinchou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 8d6659cb2e4 [feature](security) Support block specific query with AST 
names (#43533)
8d6659cb2e4 is described below

commit 8d6659cb2e46cb764918416f1d44cc26b4e4ec5f
Author: Siyang Tang <tangsiy...@selectdb.com>
AuthorDate: Tue Nov 12 23:51:57 2024 +0800

    [feature](security) Support block specific query with AST names (#43533)
    
    Support blocking specific queries by AST name when necessary for security
    reasons. Configure the name list in fe.conf, for example:
    ```
    block_sql_ast_names="CreateFileStmt, CreateFunctionStmt"
    ```
---
 .../main/java/org/apache/doris/common/Config.java  |  6 ++
 .../main/java/org/apache/doris/catalog/Env.java    |  3 +
 .../java/org/apache/doris/qe/StmtExecutor.java     | 21 +++++
 .../doris/nereids/parser/NereidsParserTest.java    | 36 +++++++++
 .../java/org/apache/doris/qe/StmtExecutorTest.java | 90 ++++++++++++++++++++++
 5 files changed, 156 insertions(+)

diff --git a/fe/fe-common/src/main/java/org/apache/doris/common/Config.java 
b/fe/fe-common/src/main/java/org/apache/doris/common/Config.java
index ccc93ed799a..993702c4dac 100644
--- a/fe/fe-common/src/main/java/org/apache/doris/common/Config.java
+++ b/fe/fe-common/src/main/java/org/apache/doris/common/Config.java
@@ -3249,4 +3249,10 @@ public class Config extends ConfigBase {
             "For testing purposes, all queries are forcibly forwarded to the 
master to verify"
                     + "the behavior of forwarding queries."})
     public static boolean force_forward_all_queries = false;
+
+    @ConfField(description = {"用于禁用某些SQL,配置项为AST的class simple name列表(例如CreateRepositoryStmt,"
+            + "CreatePolicyCommand),用逗号间隔开",
+            "For disabling certain SQL queries, the configuration item is a list of simple class names of AST "
+                    + "(for example CreateRepositoryStmt, CreatePolicyCommand), separated by commas."})
+    public static String block_sql_ast_names = "";
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java
index 9e65e5c866b..f67444b6cb1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java
@@ -246,6 +246,7 @@ import org.apache.doris.qe.GlobalVariable;
 import org.apache.doris.qe.JournalObservable;
 import org.apache.doris.qe.QueryCancelWorker;
 import org.apache.doris.qe.SessionVariable;
+import org.apache.doris.qe.StmtExecutor;
 import org.apache.doris.qe.VariableMgr;
 import org.apache.doris.resource.AdmissionControl;
 import org.apache.doris.resource.Tag;
@@ -1116,6 +1117,8 @@ public class Env {
             notifyNewFETypeTransfer(FrontendNodeType.MASTER);
         }
         queryCancelWorker.start();
+
+        StmtExecutor.initBlockSqlAstNames();
     }
 
     // wait until FE is ready.
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
index 80c7ad0912d..8b76933e9e0 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
@@ -259,6 +259,8 @@ public class StmtExecutor {
     private static final AtomicLong STMT_ID_GENERATOR = new AtomicLong(0);
     public static final int MAX_DATA_TO_SEND_FOR_TXN = 100;
     public static final String NULL_VALUE_FOR_LOAD = "\\N";
+    // Simple class names of ASTs rejected by checkSqlBlocked(); reassigned by initBlockSqlAstNames().
+    private static volatile Set<String> blockSqlAstNames = Sets.newHashSet();
     private Pattern beIpPattern = Pattern.compile("\\[(\\d+):");
     private ConnectContext context;
     private final StatementContext statementContext;
@@ -694,6 +696,7 @@ public class StmtExecutor {
                 "Nereids only process LogicalPlanAdapter, but parsedStmt is " 
+ parsedStmt.getClass().getName());
         context.getState().setNereids(true);
         LogicalPlan logicalPlan = ((LogicalPlanAdapter) 
parsedStmt).getLogicalPlan();
+        checkSqlBlocked(logicalPlan.getClass());
         if (context.getCommand() == MysqlCommand.COM_STMT_PREPARE) {
             if (isForwardToMaster()) {
                 throw new UserException("Forward master command is not 
supported for prepare statement");
@@ -834,6 +837,23 @@ public class StmtExecutor {
         }
     }
 
+    /** Rebuilds the blocked AST name set from Config.block_sql_ast_names (comma separated, trimmed). */
+    public static void initBlockSqlAstNames() {
+        // Publish a freshly built set instead of clear()-ing and refilling the shared one, and skip empty
+        // entries: an anonymous class has simple name "", so "" (e.g. from the default config) must not match.
+        blockSqlAstNames = Pattern.compile(",").splitAsStream(Config.block_sql_ast_names)
+                .map(String::trim).filter(name -> !name.isEmpty())
+                .collect(Collectors.toSet());
+    }
+
+    /** Throws a UserException if the simple class name of the given AST class is in the blocked set. */
+    public void checkSqlBlocked(Class<?> clazz) throws UserException {
+        String astName = clazz.getSimpleName();
+        if (blockSqlAstNames.contains(astName)) {
+            throw new UserException("SQL is blocked with AST name: " + astName);
+        }
+    }
+
     private void parseByNereids() {
         if (parsedStmt != null) {
             return;
@@ -981,6 +1001,7 @@ public class StmtExecutor {
         try {
             // parsedStmt maybe null here, we parse it. Or the predicate will 
not work.
             parseByLegacy();
+            checkSqlBlocked(parsedStmt.getClass());
             if (context.isTxnModel() && !(parsedStmt instanceof InsertStmt)
                     && !(parsedStmt instanceof TransactionStmt)) {
                 throw new TException("This is in a transaction, only insert, 
update, delete, "
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/NereidsParserTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/NereidsParserTest.java
index ff9e81f2bf3..9a46b810586 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/NereidsParserTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/NereidsParserTest.java
@@ -43,6 +43,8 @@ import org.apache.doris.nereids.types.DateTimeType;
 import org.apache.doris.nereids.types.DateType;
 import org.apache.doris.nereids.types.DecimalV2Type;
 import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.qe.StmtExecutor;
 
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
@@ -704,4 +706,38 @@ public class NereidsParserTest extends ParserTestBase {
             nereidsParser.parseSingle(sql);
         }
     }
+
+    @Test
+    public void testBlockSqlAst() {
+        String sql = "plan replayer dump select `AD``D` from t1 where a = 1";
+        NereidsParser nereidsParser = new NereidsParser();
+        LogicalPlan logicalPlan = nereidsParser.parseSingle(sql);
+        StmtExecutor stmtExecutor = new StmtExecutor(new ConnectContext(), "");
+        // Config.block_sql_ast_names is global state: restore it afterwards so other tests are unaffected.
+        String originalBlockList = Config.block_sql_ast_names;
+        try {
+            // A single name blocks the AST class with that simple name.
+            Config.block_sql_ast_names = "ReplayCommand";
+            StmtExecutor.initBlockSqlAstNames();
+            Exception ex = Assertions.assertThrows(Exception.class,
+                    () -> stmtExecutor.checkSqlBlocked(logicalPlan.getClass()));
+            Assertions.assertTrue(ex.getMessage().contains("ReplayCommand"));
+
+            // Entries of a comma separated list are trimmed before matching.
+            Config.block_sql_ast_names = "CreatePolicyCommand, ReplayCommand";
+            StmtExecutor.initBlockSqlAstNames();
+            ex = Assertions.assertThrows(Exception.class,
+                    () -> stmtExecutor.checkSqlBlocked(logicalPlan.getClass()));
+            Assertions.assertTrue(ex.getMessage().contains("ReplayCommand"));
+
+            // An empty list blocks nothing.
+            Config.block_sql_ast_names = "";
+            StmtExecutor.initBlockSqlAstNames();
+            Assertions.assertDoesNotThrow(
+                    () -> stmtExecutor.checkSqlBlocked(logicalPlan.getClass()));
+        } finally {
+            Config.block_sql_ast_names = originalBlockList;
+            StmtExecutor.initBlockSqlAstNames();
+        }
+    }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java
index 07d39e52180..8ab187315d8 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java
@@ -19,6 +19,8 @@ package org.apache.doris.qe;
 
 import org.apache.doris.analysis.AccessTestUtil;
 import org.apache.doris.analysis.Analyzer;
+import org.apache.doris.analysis.CreateFileStmt;
+import org.apache.doris.analysis.CreateFunctionStmt;
 import org.apache.doris.analysis.DdlStmt;
 import org.apache.doris.analysis.Expr;
 import org.apache.doris.analysis.KillStmt;
@@ -31,6 +33,7 @@ import org.apache.doris.analysis.SqlParser;
 import org.apache.doris.analysis.StatementBase;
 import org.apache.doris.analysis.UseStmt;
 import org.apache.doris.catalog.Env;
+import org.apache.doris.common.Config;
 import org.apache.doris.common.DdlException;
 import org.apache.doris.common.jmockit.Deencapsulation;
 import org.apache.doris.common.profile.Profile;
@@ -801,4 +804,91 @@ public class StmtExecutorTest {
 
         Assert.assertEquals(QueryState.MysqlStateType.ERR, 
state.getStateType());
     }
+
+    @Test
+    public void testBlockSqlAst(@Mocked UseStmt useStmt, @Mocked 
CreateFileStmt createFileStmt,
+            @Mocked CreateFunctionStmt createFunctionStmt, @Mocked SqlParser 
parser) throws Exception {
+        new Expectations() {
+            {
+                useStmt.analyze((Analyzer) any);
+                minTimes = 0;
+
+                useStmt.getDatabase();
+                minTimes = 0;
+                result = "testDb";
+
+                useStmt.getRedirectStatus();
+                minTimes = 0;
+                result = RedirectStatus.NO_FORWARD;
+
+                useStmt.getCatalogName();
+                minTimes = 0;
+                result = InternalCatalog.INTERNAL_CATALOG_NAME;
+
+                Symbol symbol = new Symbol(0, 
Lists.newArrayList(createFileStmt));
+                parser.parse();
+                minTimes = 0;
+                result = symbol;
+            }
+        };
+
+        Config.block_sql_ast_names = "CreateFileStmt";
+        StmtExecutor.initBlockSqlAstNames();
+        StmtExecutor executor = new StmtExecutor(ctx, "");
+        try {
+            executor.execute();
+        } catch (Exception ignore) {
+            // the blocked statement is expected to surface through the query state asserted below
+        }
+        Assert.assertEquals(QueryState.MysqlStateType.ERR, 
state.getStateType());
+        Assert.assertTrue(state.getErrorMessage().contains("SQL is blocked 
with AST name: CreateFileStmt"));
+
+        Config.block_sql_ast_names = "AlterStmt, CreateFileStmt";
+        StmtExecutor.initBlockSqlAstNames();
+        executor = new StmtExecutor(ctx, "");
+        try {
+            executor.execute();
+        } catch (Exception ignore) {
+            // the blocked statement is expected to surface through the query state asserted below
+        }
+        Assert.assertEquals(QueryState.MysqlStateType.ERR, 
state.getStateType());
+        Assert.assertTrue(state.getErrorMessage().contains("SQL is blocked 
with AST name: CreateFileStmt"));
+
+        new Expectations() {
+            {
+                Symbol symbol = new Symbol(0, 
Lists.newArrayList(createFunctionStmt));
+                parser.parse();
+                minTimes = 0;
+                result = symbol;
+            }
+        };
+        Config.block_sql_ast_names = "CreateFunctionStmt, CreateFileStmt";
+        StmtExecutor.initBlockSqlAstNames();
+        executor = new StmtExecutor(ctx, "");
+        try {
+            executor.execute();
+        } catch (Exception ignore) {
+            // the blocked statement is expected to surface through the query state asserted below
+        }
+        Assert.assertEquals(QueryState.MysqlStateType.ERR, 
state.getStateType());
+        Assert.assertTrue(state.getErrorMessage().contains("SQL is blocked 
with AST name: CreateFunctionStmt"));
+
+        new Expectations() {
+            {
+                Symbol symbol = new Symbol(0, Lists.newArrayList(useStmt));
+                parser.parse();
+                minTimes = 0;
+                result = symbol;
+            }
+        };
+        executor = new StmtExecutor(ctx, "");
+        executor.execute();
+        Assert.assertEquals(QueryState.MysqlStateType.OK, 
state.getStateType());
+
+        Config.block_sql_ast_names = "";
+        StmtExecutor.initBlockSqlAstNames();
+        executor = new StmtExecutor(ctx, "");
+        executor.execute();
+        Assert.assertEquals(QueryState.MysqlStateType.OK, 
state.getStateType());
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to