This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch hackathon-12-2025
in repository https://gitbox.apache.org/repos/asf/superset.git

commit bf342d66db6fe15f47e3acf60ca4883f7a3bbf32
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri Dec 19 13:22:58 2025 -0500

    Fix case
---
 superset/sql/parse.py | 96 +++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 71 insertions(+), 25 deletions(-)

diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 709085d728..4ec056e6fd 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -35,6 +35,7 @@ from sqlglot.dialects.dialect import (
 )
 from sqlglot.dialects.singlestore import SingleStore
 from sqlglot.errors import ParseError
+from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
 from sqlglot.optimizer.pushdown_predicates import (
     pushdown_predicates,
 )
@@ -292,24 +293,40 @@ class RLSTransformer:
         catalog: str | None,
         schema: str | None,
         rules: dict[Table, list[exp.Expression]],
+        dialect: Dialects | type[Dialect] | None = None,
     ) -> None:
         self.catalog = catalog
         self.schema = schema
-        # Normalize table keys to lowercase for case-insensitive matching
-        # This is needed because apply_cls calls qualify() which may change
-        # identifier case (e.g., Snowflake uppercases identifiers)
+        self.dialect = dialect
+        # Normalize table keys using dialect-aware normalization
+        # This ensures matching works correctly regardless of how the dialect
+        # handles identifier case (e.g., Snowflake uppercases, Postgres lowercases)
         self.rules = {
             self._normalize_table(table): predicates
             for table, predicates in rules.items()
         }
 
-    @staticmethod
-    def _normalize_table(table: Table) -> Table:
-        """Normalize table to lowercase for case-insensitive matching."""
+    def _normalize_table(self, table: Table) -> Table:
+        """
+        Normalize table identifiers using dialect-aware normalization.
+
+        This uses sqlglot's normalize_identifiers to match how the dialect
+        handles identifier case:
+        - Snowflake: uppercases unquoted identifiers
+        - PostgreSQL: lowercases unquoted identifiers
+        - Quoted identifiers preserve their case
+        """
+        # Create a temporary exp.Table node for normalization
+        table_exp = exp.Table(
+            this=exp.Identifier(this=table.table) if table.table else None,
+            db=exp.Identifier(this=table.schema) if table.schema else None,
+            catalog=exp.Identifier(this=table.catalog) if table.catalog else None,
+        )
+        normalized = normalize_identifiers(table_exp, dialect=self.dialect)
         return Table(
-            table=table.table.lower() if table.table else table.table,
-            schema=table.schema.lower() if table.schema else table.schema,
-            catalog=table.catalog.lower() if table.catalog else table.catalog,
+            table=normalized.name if normalized.name else table.table,
+            schema=normalized.db if normalized.db else table.schema,
+            catalog=normalized.catalog if normalized.catalog else table.catalog,
         )
 
     def get_predicate(self, table_node: exp.Table) -> exp.Expression | None:
@@ -473,20 +490,47 @@ class CLSTransformer:
         rules: CLSRules,
         dialect: Dialects | type[Dialect] | None,
     ) -> None:
-        self.rules = self._normalize_rules(rules)
         self.dialect = dialect
+        self.rules = self._normalize_rules(rules)
         self.hash_pattern = CLS_HASH_FUNCTIONS.get(dialect, CLS_HASH_FUNCTIONS[None])
 
+    def _normalize_identifier(self, name: str) -> str:
+        """Normalize an identifier using dialect-aware normalization."""
+        ident = exp.Identifier(this=name)
+        normalized = normalize_identifiers(ident, dialect=self.dialect)
+        return normalized.name
+
+    def _normalize_table(self, table: Table) -> Table:
+        """
+        Normalize table identifiers using dialect-aware normalization.
+
+        This uses sqlglot's normalize_identifiers to match how the dialect
+        handles identifier case:
+        - Snowflake: uppercases unquoted identifiers
+        - PostgreSQL: lowercases unquoted identifiers
+        - Quoted identifiers preserve their case
+        """
+        table_exp = exp.Table(
+            this=exp.Identifier(this=table.table) if table.table else None,
+            db=exp.Identifier(this=table.schema) if table.schema else None,
+            catalog=exp.Identifier(this=table.catalog) if table.catalog else None,
+        )
+        normalized = normalize_identifiers(table_exp, dialect=self.dialect)
+        return Table(
+            table=normalized.name if normalized.name else table.table,
+            schema=normalized.db if normalized.db else table.schema,
+            catalog=normalized.catalog if normalized.catalog else table.catalog,
+        )
+
     def _normalize_rules(self, rules: CLSRules) -> dict[Table, dict[str, CLSAction]]:
         """
-        Normalize table and column names to lowercase for case-insensitive matching.
+        Normalize table and column names using dialect-aware normalization.
         """
         return {
-            Table(
-                table=table.table.lower(),
-                schema=table.schema.lower() if table.schema else None,
-                catalog=table.catalog.lower() if table.catalog else None,
-            ): {col.lower(): action for col, action in cols.items()}
+            self._normalize_table(table): {
+                self._normalize_identifier(col): action
+                for col, action in cols.items()
+            }
             for table, cols in rules.items()
         }
 
@@ -509,22 +553,24 @@ class CLSTransformer:
             return None
 
         # Create a normalized Table for lookup
-        lookup_table = Table(
-            table=table_name.lower(),
-            schema=schema.lower() if schema else None,
-            catalog=catalog.lower() if catalog else None,
+        lookup_table = self._normalize_table(
+            Table(
+                table=table_name,
+                schema=schema,
+                catalog=catalog,
+            )
         )
+        normalized_column = self._normalize_identifier(column_name)
 
         # First try exact match with schema/catalog
-        table_rules = self.rules.get(lookup_table)
-        if table_rules:
-            return table_rules.get(column_name.lower())
+        if table_rules := self.rules.get(lookup_table):
+            return table_rules.get(normalized_column)
 
         # Fallback: match by table name only
         # This handles cases where the rule has schema/catalog but the query doesn't
         for rule_table, cols in self.rules.items():
             if rule_table.table == lookup_table.table:
-                action = cols.get(column_name.lower())
+                action = cols.get(normalized_column)
                 if action:
                     return action
 
@@ -1535,7 +1581,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         if method not in transformers:
             raise ValueError(f"Invalid RLS method: {method}")
 
-        transformer = transformers[method](catalog, schema, predicates)
+        transformer = transformers[method](catalog, schema, predicates, self._dialect)
         self._parsed = self._parsed.transform(transformer)
 
     def apply_cls(

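For reference, a minimal standalone sketch (not part of this commit) of the
sqlglot behavior the new _normalize_table/_normalize_identifier helpers rely
on. The table and column names below are hypothetical:

    from sqlglot import exp
    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    # Hypothetical unquoted identifiers, as they might appear in an RLS/CLS rule.
    table = exp.Table(
        this=exp.Identifier(this="Orders"),
        db=exp.Identifier(this="Sales"),
    )

    # Snowflake uppercases unquoted identifiers...
    print(normalize_identifiers(table.copy(), dialect="snowflake").sql())
    # SALES.ORDERS

    # ...while PostgreSQL lowercases them.
    print(normalize_identifiers(table.copy(), dialect="postgres").sql())
    # sales.orders

    # Quoted identifiers keep their case under both dialects.
    quoted = exp.Identifier(this="MixedCase", quoted=True)
    print(normalize_identifiers(quoted.copy(), dialect="snowflake").sql())
    # "MixedCase"

The Table keys are then rebuilt from normalized.name, normalized.db, and
normalized.catalog, so rule keys and query identifiers end up in the same
case before the dictionary lookup.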