This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit f9ae03ac3c766dd3ee0e831af65cb20af1ff911e Author: 924060929 <924060...@qq.com> AuthorDate: Fri Mar 22 10:58:43 2024 +0800 [feature](Nereids) support data masking policy (#32526) support data masking policy note: if a user sends the query ```sql select name from tbl limit 1 ``` and the user has a row policy on `tbl.name` with the filter `name = 'Beijing'`, and has a data masking policy on `tbl.name` with the masking `concat(substring(name, 1, 4), '****')`, we will rewrite the query to ```sql select concat(substring(name, 1, 4), '****') as name from tbl where name = 'Beijing' -- note that this name is from tbl, not from the alias in the select list limit 1 ``` the result would be `Beij****` --- .../doris/nereids/rules/analysis/CheckPolicy.java | 24 +-- .../trees/plans/logical/LogicalCheckPolicy.java | 67 +++++-- .../java/org/apache/doris/policy/PolicyMgr.java | 6 +- .../nereids/privileges/TestCheckPrivileges.java | 200 ++++++++++++++++++++- 4 files changed, 270 insertions(+), 27 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java index aa2917ae1e1..1e7d4dbb09d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java @@ -24,14 +24,16 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy; +import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy.RelatedPolicy; import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import
org.apache.doris.nereids.trees.plans.logical.LogicalRelation; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; -import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import java.util.Set; @@ -60,7 +62,7 @@ public class CheckPolicy implements AnalysisRuleFactory { return ctx.root.child(); } LogicalRelation relation = (LogicalRelation) child; - Set<Expression> combineFilter = new HashSet<>(); + Set<Expression> combineFilter = new LinkedHashSet<>(); // replace incremental params as AND expression if (relation instanceof LogicalFileScan) { @@ -72,18 +74,20 @@ public class CheckPolicy implements AnalysisRuleFactory { } } - // row policy - checkPolicy.getFilter(relation, ctx.connectContext) - .ifPresent(expression -> combineFilter.addAll( + RelatedPolicy relatedPolicy = checkPolicy.findPolicy(relation, ctx.cascadesContext); + relatedPolicy.rowPolicyFilter.ifPresent(expression -> combineFilter.addAll( ExpressionUtils.extractConjunctionToSet(expression))); - - if (combineFilter.isEmpty()) { - return ctx.root.child(); - } + Plan result = relation; if (upperFilter != null) { combineFilter.addAll(upperFilter.getConjuncts()); } - return new LogicalFilter<>(combineFilter, relation); + if (!combineFilter.isEmpty()) { + result = new LogicalFilter<>(combineFilter, relation); + } + if (relatedPolicy.dataMaskProjects.isPresent()) { + result = new LogicalProject<>(relatedPolicy.dataMaskProjects.get(), result); + } + return result; }) ) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java index 54f38034761..bda3b1f49d5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java @@ -19,13 +19,20 @@ package 
org.apache.doris.nereids.trees.plans.logical; import org.apache.doris.analysis.UserIdentity; import org.apache.doris.mysql.privilege.AccessControllerManager; +import org.apache.doris.mysql.privilege.DataMaskPolicy; import org.apache.doris.mysql.privilege.RowFilterPolicy; +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.analyzer.UnboundAlias; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.LogicalProperties; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.PropagateFuncDeps; @@ -37,6 +44,7 @@ import org.apache.doris.qe.ConnectContext; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import org.apache.commons.collections.CollectionUtils; import java.util.ArrayList; import java.util.List; @@ -113,32 +121,58 @@ public class LogicalCheckPolicy<CHILD_TYPE extends Plan> extends LogicalUnary<CH } /** - * get wherePredicate of policy for logicalRelation. + * find related policy for logicalRelation. 
* * @param logicalRelation include tableName and dbName - * @param connectContext include information about user and policy + * @param cascadesContext include information about user and policy */ - public Optional<Expression> getFilter(LogicalRelation logicalRelation, ConnectContext connectContext) { + public RelatedPolicy findPolicy(LogicalRelation logicalRelation, CascadesContext cascadesContext) { if (!(logicalRelation instanceof CatalogRelation)) { - return Optional.empty(); + return RelatedPolicy.NO_POLICY; } + ConnectContext connectContext = cascadesContext.getConnectContext(); AccessControllerManager accessManager = connectContext.getEnv().getAccessManager(); UserIdentity currentUserIdentity = connectContext.getCurrentUserIdentity(); if (currentUserIdentity.isRootUser() || currentUserIdentity.isAdminUser()) { - return Optional.empty(); + return RelatedPolicy.NO_POLICY; } CatalogRelation catalogRelation = (CatalogRelation) logicalRelation; String ctlName = catalogRelation.getDatabase().getCatalog().getName(); String dbName = catalogRelation.getDatabase().getFullName(); String tableName = catalogRelation.getTable().getName(); - List<? extends RowFilterPolicy> policies = accessManager.evalRowFilterPolicies(currentUserIdentity, ctlName, - dbName, tableName); - if (policies.isEmpty()) { - return Optional.empty(); + + NereidsParser nereidsParser = new NereidsParser(); + ImmutableList.Builder<NamedExpression> dataMasks + = ImmutableList.builderWithExpectedSize(logicalRelation.getOutput().size()); + + boolean hasDataMask = false; + for (Slot slot : logicalRelation.getOutput()) { + Optional<DataMaskPolicy> dataMaskPolicy = accessManager.evalDataMaskPolicy( + currentUserIdentity, ctlName, dbName, tableName, slot.getName()); + if (dataMaskPolicy.isPresent()) { + Expression unboundExpr = nereidsParser.parseExpression(dataMaskPolicy.get().getMaskTypeDef()); + Expression childOfAlias + = unboundExpr instanceof UnboundAlias ? 
unboundExpr.child(0) : unboundExpr; + Alias alias = new Alias( + StatementScopeIdGenerator.newExprId(), + ImmutableList.of(childOfAlias), + slot.getName(), slot.getQualifier(), false + ); + dataMasks.add(alias); + hasDataMask = true; + } else { + dataMasks.add(slot); + } } - return Optional.ofNullable(mergeRowPolicy(policies)); + + List<? extends RowFilterPolicy> policies = accessManager.evalRowFilterPolicies( + currentUserIdentity, ctlName, dbName, tableName); + return new RelatedPolicy( + Optional.ofNullable(CollectionUtils.isEmpty(policies) ? null : mergeRowPolicy(policies)), + hasDataMask ? Optional.of(dataMasks.build()) : Optional.empty() + ); } private Expression mergeRowPolicy(List<? extends RowFilterPolicy> policies) { @@ -172,4 +206,17 @@ public class LogicalCheckPolicy<CHILD_TYPE extends Plan> extends LogicalUnary<CH return null; } } + + /** RelatedPolicy */ + public static class RelatedPolicy { + public static final RelatedPolicy NO_POLICY = new RelatedPolicy(Optional.empty(), Optional.empty()); + + public final Optional<Expression> rowPolicyFilter; + public final Optional<List<NamedExpression>> dataMaskProjects; + + public RelatedPolicy(Optional<Expression> rowPolicyFilter, Optional<List<NamedExpression>> dataMaskProjects) { + this.rowPolicyFilter = rowPolicyFilter; + this.dataMaskProjects = dataMaskProjects; + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/policy/PolicyMgr.java b/fe/fe-core/src/main/java/org/apache/doris/policy/PolicyMgr.java index 6673cccd0bf..575c10c3b74 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/policy/PolicyMgr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/policy/PolicyMgr.java @@ -292,7 +292,7 @@ public class PolicyMgr implements Writable { return; } rowPolicy.setCtlName(InternalCatalog.INTERNAL_CATALOG_NAME); - rowPolicy.setDbName(db.get().getName()); + rowPolicy.setDbName(db.get().getFullName()); rowPolicy.setTableName(table.get().getName()); } } @@ -330,7 +330,7 @@ public class PolicyMgr 
implements Writable { return; } log.setCtlName(InternalCatalog.INTERNAL_CATALOG_NAME); - log.setDbName(db.get().getName()); + log.setDbName(db.get().getFullName()); log.setTableName(table.get().getName()); } unprotectedDrop(log); @@ -545,7 +545,7 @@ public class PolicyMgr implements Writable { continue; } rowPolicy.setCtlName(InternalCatalog.INTERNAL_CATALOG_NAME); - rowPolicy.setDbName(db.get().getName()); + rowPolicy.setDbName(db.get().getFullName()); rowPolicy.setTableName(table.get().getName()); } compatiblePolicies.add(rowPolicy); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/privileges/TestCheckPrivileges.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/privileges/TestCheckPrivileges.java index 7b9c3ccd767..07cbb002a64 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/privileges/TestCheckPrivileges.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/privileges/TestCheckPrivileges.java @@ -32,7 +32,18 @@ import org.apache.doris.mysql.privilege.DataMaskPolicy; import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.mysql.privilege.RowFilterPolicy; import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.pattern.GeneratedMemoPatterns; +import org.apache.doris.nereids.rules.RulePromise; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Concat; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.policy.FilterType; import org.apache.doris.utframe.TestWithFeService; import 
com.google.common.collect.ImmutableList; @@ -49,8 +60,10 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; -public class TestCheckPrivileges extends TestWithFeService { +public class TestCheckPrivileges extends TestWithFeService implements GeneratedMemoPatterns { private static final Map<String, Map<String, List<Column>>> CATALOG_META = ImmutableMap.of( "test_db", ImmutableMap.of( "test_tbl1", ImmutableList.of( @@ -64,12 +77,16 @@ public class TestCheckPrivileges extends TestWithFeService { "test_tbl3", ImmutableList.of( new Column("id", PrimitiveType.INT), new Column("name", PrimitiveType.VARCHAR) + ), + "test_tbl4", ImmutableList.of( + new Column("id", PrimitiveType.INT), + new Column("name", PrimitiveType.VARCHAR) ) ) ); @Test - public void testColumnPrivileges() throws Exception { + public void testPrivilegesAndPolicies() throws Exception { FeConstants.runningUnitTest = true; String catalogProvider = "org.apache.doris.nereids.privileges.TestCheckPrivileges$CustomCatalogProvider"; @@ -89,6 +106,7 @@ public class TestCheckPrivileges extends TestWithFeService { String table1 = "test_tbl1"; String table2 = "test_tbl2"; String table3 = "test_tbl3"; + String table4 = "test_tbl4"; String view1 = "query_tbl2_view1"; createView("create view " + internalDb + "." 
@@ -118,7 +136,12 @@ public class TestCheckPrivileges extends TestWithFeService { .allowSelectColumns(user, ImmutableSet.of("name")), MakePrivileges.table("internal", internalDb, view4) - .allowSelectColumns(user, ImmutableSet.of("id")) + .allowSelectColumns(user, ImmutableSet.of("id")), + + // data masking and row policy + MakePrivileges.table(catalog, db, table4).allowSelectTable(user) + .addRowPolicy(user, "id = 1") + .addDataMasking(user, "id", "concat(id, '_****_', id)") ); AccessControllerManager accessManager = Env.getCurrentEnv().getAccessManager(); @@ -181,6 +204,64 @@ public class TestCheckPrivileges extends TestWithFeService { query("select name from " + internalDb + "." + view4) ); } + + // test row policy with data masking + { + Function<NamedExpression, Boolean> checkId = (NamedExpression ne) -> { + if (!(ne instanceof Alias) || !ne.getName().equals("id")) { + return false; + } + return ne.child(0) instanceof Concat; + }; + PlanChecker.from(connectContext) + .parse("select id," + + " test_tbl4.id," + + " test_db.test_tbl4.id, " + + " custom_catalog.test_db.test_tbl4.id, " + + " * " + + "from custom_catalog.test_db.test_tbl4") + .analyze() + .rewrite() + .matches(logicalProject( + logicalFilter( + logicalTestScan() + ).when(f -> { + EqualTo predicate = (EqualTo) f.getPredicate(); + return predicate.left() instanceof Slot + && predicate.right().equals(new IntegerLiteral((byte) 1)); + }) + ).when(p -> { + List<NamedExpression> projects = p.getProjects(); + if (!checkId.apply(projects.get(0)) || !checkId.apply(projects.get(1)) + || !checkId.apply(projects.get(2)) || !checkId.apply(projects.get(3)) + || !checkId.apply(projects.get(4))) { + return false; + } + return projects.get(5) instanceof Slot && projects.get(5).getName().equals("name"); + })); + + PlanChecker.from(connectContext) + .parse("select id, t.id, *" + + "from custom_catalog.test_db.test_tbl4 t") + .analyze() + .rewrite() + .matches(logicalProject( + logicalFilter( + logicalTestScan() + 
).when(f -> { + EqualTo predicate = (EqualTo) f.getPredicate(); + return predicate.left() instanceof Slot + && predicate.right().equals(new IntegerLiteral((byte) 1)); + }) + ).when(p -> { + List<NamedExpression> projects = p.getProjects(); + if (!checkId.apply(projects.get(0)) || !checkId.apply(projects.get(1)) + || !checkId.apply(projects.get(2))) { + return false; + } + return projects.get(3) instanceof Slot && projects.get(3).getName().equals("name"); + })); + } }); } @@ -194,23 +275,36 @@ public class TestCheckPrivileges extends TestWithFeService { private void withPrivileges(List<MakeTablePrivileges> privileges, Runnable task) { List<TablePrivilege> tablePrivileges = Lists.newArrayList(); List<ColumnPrivilege> columnPrivileges = Lists.newArrayList(); + List<CustomRowPolicy> rowPolicies = Lists.newArrayList(); + List<CustomDataMaskingPolicy> dataMaskingPolicies = Lists.newArrayList(); for (MakeTablePrivileges privilege : privileges) { tablePrivileges.addAll(privilege.tablePrivileges); columnPrivileges.addAll(privilege.columnPrivileges); + rowPolicies.addAll(privilege.rowPolicies); + dataMaskingPolicies.addAll(privilege.dataMaskingPolicies); } SimpleCatalogAccessController.tablePrivileges.set(tablePrivileges); SimpleCatalogAccessController.columnPrivileges.set(columnPrivileges); + SimpleCatalogAccessController.rowPolicies.set(rowPolicies); + SimpleCatalogAccessController.dataMaskings.set(dataMaskingPolicies); try { task.run(); } finally { + SimpleCatalogAccessController.rowPolicies.remove(); + SimpleCatalogAccessController.dataMaskings.remove(); SimpleCatalogAccessController.tablePrivileges.remove(); SimpleCatalogAccessController.columnPrivileges.remove(); } } + @Override + public RulePromise defaultPromise() { + return RulePromise.REWRITE; + } + public static class CustomCatalogProvider implements TestCatalogProvider { @Override @@ -229,6 +323,8 @@ public class TestCheckPrivileges extends TestWithFeService { public static class SimpleCatalogAccessController 
implements CatalogAccessController { private static ThreadLocal<List<TablePrivilege>> tablePrivileges = new ThreadLocal<>(); private static ThreadLocal<List<ColumnPrivilege>> columnPrivileges = new ThreadLocal<>(); + private static ThreadLocal<List<CustomRowPolicy>> rowPolicies = new ThreadLocal<>(); + private static ThreadLocal<List<CustomDataMaskingPolicy>> dataMaskings = new ThreadLocal<>(); @Override public boolean checkGlobalPriv(UserIdentity currentUser, PrivPredicate wanted) { @@ -305,13 +401,40 @@ public class TestCheckPrivileges extends TestWithFeService { @Override public Optional<DataMaskPolicy> evalDataMaskPolicy(UserIdentity currentUser, String ctl, String db, String tbl, String col) { + List<CustomDataMaskingPolicy> dataMaskingPolicies = dataMaskings.get(); + if (dataMaskingPolicies == null) { + return Optional.empty(); + } + + for (CustomDataMaskingPolicy dataMaskingPolicy : dataMaskingPolicies) { + if (dataMaskingPolicy.column.equalsIgnoreCase(col)) { + return Optional.of(dataMaskingPolicy); + } + } return Optional.empty(); } @Override public List<? 
extends RowFilterPolicy> evalRowFilterPolicies(UserIdentity currentUser, String ctl, String db, String tbl) { - return Lists.newArrayList(); + List<CustomRowPolicy> customRowPolicies = rowPolicies.get(); + if (customRowPolicies == null) { + return ImmutableList.of(); + } + NereidsParser nereidsParser = new NereidsParser(); + return customRowPolicies.stream() + .map(p -> new RowFilterPolicy() { + @Override + public Expression getFilterExpression() { + return nereidsParser.parseExpression(p.filter); + } + + @Override + public String getPolicyIdent() { + return "custom policy: " + p.filter; + } + }) + .collect(Collectors.toList()); } } @@ -328,6 +451,8 @@ public class TestCheckPrivileges extends TestWithFeService { private List<TablePrivilege> tablePrivileges; private List<ColumnPrivilege> columnPrivileges; + private List<CustomRowPolicy> rowPolicies; + private List<CustomDataMaskingPolicy> dataMaskingPolicies; public MakeTablePrivileges(String catalog, String db, String table) { this.catalog = catalog; @@ -335,6 +460,8 @@ public class TestCheckPrivileges extends TestWithFeService { this.table = table; this.tablePrivileges = Lists.newArrayList(); this.columnPrivileges = Lists.newArrayList(); + this.rowPolicies = Lists.newArrayList(); + this.dataMaskingPolicies = Lists.newArrayList(); } public MakeTablePrivileges allowSelectTable(String user) { @@ -346,6 +473,16 @@ public class TestCheckPrivileges extends TestWithFeService { columnPrivileges.add(new ColumnPrivilege(catalog, db, table, user, allowColumns)); return this; } + + public MakeTablePrivileges addRowPolicy(String user, String filter) { + rowPolicies.add(new CustomRowPolicy(user, filter)); + return this; + } + + public MakeTablePrivileges addDataMasking(String user, String column, String project) { + dataMaskingPolicies.add(new CustomDataMaskingPolicy(user, column, project)); + return this; + } } private static class TablePrivilege { @@ -402,4 +539,59 @@ public class TestCheckPrivileges extends TestWithFeService 
{ && StringUtils.equals(this.table, tbl); } } + + private static class CustomRowPolicy implements RowFilterPolicy { + private final String user; + private final String filter; + + public CustomRowPolicy(String user, String filter) { + this.user = user; + this.filter = filter; + } + + public String getUser() { + return user; + } + + @Override + public Expression getFilterExpression() { + return new NereidsParser().parseExpression(filter); + } + + @Override + public String getPolicyIdent() { + return "custom policy: " + filter; + } + + @Override + public FilterType getFilterType() { + return FilterType.PERMISSIVE; + } + } + + private static class CustomDataMaskingPolicy implements DataMaskPolicy { + private final String user; + private final String column; + private final String project; + + public CustomDataMaskingPolicy(String user, String name, String project) { + this.user = user; + this.column = name; + this.project = project; + } + + public String getUser() { + return user; + } + + @Override + public String getMaskTypeDef() { + return project; + } + + @Override + public String getPolicyIdent() { + return "custom policy: " + project; + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org