Vitor-Avila commented on code in PR #33542:
URL: https://github.com/apache/superset/pull/33542#discussion_r2116065764


##########
superset/sql_lab.py:
##########
@@ -197,101 +197,131 @@ def get_sql_results(  # pylint: disable=too-many-arguments
                 return handle_query_error(ex, query)
 
 
-def execute_sql_statement(  # pylint: disable=too-many-statements, 
too-many-locals  # noqa: C901
-    sql_statement: str,
-    query: Query,
-    cursor: Any,
-    log_params: Optional[dict[str, Any]],
-    apply_ctas: bool = False,
-) -> SupersetResultSet:
-    """Executes a single SQL statement"""
-    database: Database = query.database
-    db_engine_spec = database.db_engine_spec
+def apply_rls(query: Query, parsed_statement: BaseSQLStatement[Any]) -> None:
+    """
+    Modify statement inplace to ensure RLS rules are applied.
+    """
+    # we need the default schema to fully qualify the table names
+    default_schema = query.database.get_default_schema_for_query(query)
+
+    # There are two ways to insert RLS: either replacing the table with a 
subquery
+    # that has the RLS, or appending the RLS to the ``WHERE`` clause. The 
former is
+    # safer, but not supported in all databases.
+    method = (
+        RLSMethod.AS_SUBQUERY
+        if query.database.db_engine_spec.allows_subqueries
+        and query.database.db_engine_spec.allows_alias_in_select
+        else RLSMethod.AS_PREDICATE
+    )
 
-    parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
-    if is_feature_enabled("RLS_IN_SQLLAB"):
-        # There are two ways to insert RLS: either replacing the table with a 
subquery
-        # that has the RLS, or appending the RLS to the ``WHERE`` clause. The 
former is
-        # safer, but not supported in all databases.
-        insert_rls = (
-            insert_rls_as_subquery
-            if database.db_engine_spec.allows_subqueries
-            and database.db_engine_spec.allows_alias_in_select
-            else insert_rls_in_predicate
-        )
+    # collect all RLS predicates
+    predicates: dict[Table, list[Any]] = defaultdict(list)
+    for table in parsed_statement.tables:
+        if table_predicates := get_predicates_for_table(
+            table,
+            query.database,
+            query.catalog,
+            default_schema,
+        ):
+            predicates[table].extend(
+                parsed_statement.parse_predicate(predicate)
+                for predicate in table_predicates
+            )
 
-        # Insert any applicable RLS predicates
-        parsed_query = ParsedQuery(
-            str(
-                insert_rls(
-                    parsed_query._parsed[0],  # pylint: 
disable=protected-access
-                    database.id,
-                    query.schema,
-                )
-            ),
-            engine=db_engine_spec.engine,
-        )
+    parsed_statement.apply_rls(query.catalog, default_schema, predicates, 
method)
 
-    sql = parsed_query.stripped()
 
-    # This is a test to see if the query is being
-    # limited by either the dropdown or the sql.
-    # We are testing to see if more rows exist than the limit.
-    increased_limit = None if query.limit is None else query.limit + 1
+def get_predicates_for_table(
+    table: Table,
+    database: Database,
+    catalog: str,
+    schema: str,
+) -> list[str]:
+    """
+    Get the RLS predicates for a table.
 
-    if not database.allow_dml:
-        errors = []
-        try:
-            parsed_statement = SQLStatement(
-                statement=sql_statement,
-                engine=db_engine_spec.engine,
-            )
-            disallowed = parsed_statement.is_mutating()
-        except SupersetParseError as ex:
-            # if we fail to parse the query, disallow by default
-            disallowed = True
-            errors.append(ex.error)
-
-        if disallowed:
-            errors.append(
-                SupersetError(
-                    message=__(
-                        "This database does not allow for DDL/DML, and the 
query "
-                        "could not be parsed to confirm it is a read-only 
query. Please "  # noqa: E501
-                        "contact your administrator for more assistance."
-                    ),
-                    error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
-                    level=ErrorLevel.ERROR,
-                )
+    This is used to inject RLS rules into SQL statements run in SQL Lab.
+    """
+    dataset = (
+        db.session.query(SqlaTable)
+        .filter(
+            and_(
+                SqlaTable.database_id == database.id,
+                SqlaTable.catalog == table.catalog or catalog,
+                SqlaTable.schema == table.schema or schema,
+                SqlaTable.table_name == table.table,
             )
-            raise SupersetErrorsException(errors)
-
-    original_sql = sql
-    if apply_ctas:
-        if not query.tmp_table_name:
-            start_dttm = datetime.fromtimestamp(query.start_time)
-            query.tmp_table_name = (
-                
f"tmp_{query.user_id}_table_{start_dttm.strftime('%Y_%m_%d_%H_%M_%S')}"
+        )
+        .one_or_none()
+    )
+    if not dataset:
+        return []
+
+    return [
+        str(
+            and_(*filters).compile(
+                dialect=database.get_dialect(),
+                compile_kwargs={"literal_binds": True},
             )
-        sql = parsed_query.as_create_table(
-            query.tmp_table_name,
-            schema_name=query.tmp_schema_name,
-            method=query.ctas_method,
         )
-        query.select_as_cta_used = True
+        for filters in dataset.get_sqla_row_level_filters()
+    ]
 
+
+S = TypeVar("S", bound=BaseSQLStatement[Any])
+
+
+def apply_ctas(query: Query, parsed_statement: S) -> S:
+    """
+    Apply CTAS/CVAS.
+    """
+    if not query.tmp_table_name:
+        start_dttm = datetime.fromtimestamp(query.start_time)
+        prefix = f"tmp_{query.user_id}_table"
+        query.tmp_table_name = 
start_dttm.strftime(f"{prefix}_%Y_%m_%d_%H_%M_%S")
+
+    catalog = (
+        query.catalog
+        if query.database.db_engine_spec.supports_cross_catalog_queries
+        else None

Review Comment:
   Shouldn't we default here to `get_default_catalog()` instead? I'm wondering
if there's a DB out there with `supports_catalog = True` that does not support
cross-catalog queries.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to