This is an automated email from the ASF dual-hosted git repository.

kgabryje pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 47db185e3b9 fix(mcp): include x_axis column in query context for series charts with group_by (#37639)
47db185e3b9 is described below

commit 47db185e3b968fb875d83f06b85b9272b1eff050
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Thu Feb 5 11:59:44 2026 -0700

    fix(mcp): include x_axis column in query context for series charts with group_by (#37639)
---
 superset/mcp_service/chart/preview_utils.py        |  23 ++-
 superset/mcp_service/chart/schemas.py              |  10 +-
 superset/mcp_service/chart/tool/get_chart_data.py  |  15 +-
 .../mcp_service/chart/tool/get_chart_preview.py    |  25 +++-
 .../mcp_service/chart/test_chart_schemas.py        |  36 +++++
 .../mcp_service/chart/test_preview_utils.py        | 158 +++++++++++++++++++++
 .../mcp_service/chart/tool/test_get_chart_data.py  | 109 ++++++++++++++
 .../chart/tool/test_get_chart_preview.py           | 113 +++++++++++++++
 8 files changed, 483 insertions(+), 6 deletions(-)

diff --git a/superset/mcp_service/chart/preview_utils.py b/superset/mcp_service/chart/preview_utils.py
index 3db475c0da1..677d3034fd4 100644
--- a/superset/mcp_service/chart/preview_utils.py
+++ b/superset/mcp_service/chart/preview_utils.py
@@ -36,6 +36,23 @@ from superset.mcp_service.chart.schemas import (
 logger = logging.getLogger(__name__)
 
 
+def _build_query_columns(form_data: Dict[str, Any]) -> list[str]:
+    """Build query columns list from form_data, including both x_axis and groupby."""
+    x_axis_config = form_data.get("x_axis")
+    groupby_columns: list[str] = form_data.get("groupby") or []
+    raw_columns: list[str] = form_data.get("columns") or []
+
+    columns = raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
+    if x_axis_config and isinstance(x_axis_config, str):
+        if x_axis_config not in columns:
+            columns.insert(0, x_axis_config)
+    elif x_axis_config and isinstance(x_axis_config, dict):
+        col_name = x_axis_config.get("column_name")
+        if col_name and col_name not in columns:
+            columns.insert(0, col_name)
+    return columns
+
+
 def generate_preview_from_form_data(
     form_data: Dict[str, Any], dataset_id: int, preview_format: str
 ) -> Any:
@@ -64,12 +81,16 @@ def generate_preview_from_form_data(
         # Create query context from form data using factory
         from superset.common.query_context_factory import QueryContextFactory
 
+        # Build columns list: include x_axis and groupby for XY charts,
+        # fall back to form_data "columns" for table charts
+        columns = _build_query_columns(form_data)
+
         factory = QueryContextFactory()
         query_context_obj = factory.create(
             datasource={"id": dataset_id, "type": "table"},
             queries=[
                 {
-                    "columns": form_data.get("columns", []),
+                    "columns": columns,
                     "metrics": form_data.get("metrics", []),
                     "orderby": form_data.get("orderby", []),
                     "row_limit": form_data.get("row_limit", 100),
diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py
index b813bc4ebc5..6b5ed699672 100644
--- a/superset/mcp_service/chart/schemas.py
+++ b/superset/mcp_service/chart/schemas.py
@@ -606,6 +606,8 @@ class FilterConfig(BaseModel):
 
 # Actual chart types
 class TableChartConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
     chart_type: Literal["table"] = Field(
         ..., description="Chart type (REQUIRED: must be 'table')"
     )
@@ -659,6 +661,8 @@ class TableChartConfig(BaseModel):
 
 
 class XYChartConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
     chart_type: Literal["xy"] = Field(
         ...,
         description=(
@@ -692,7 +696,11 @@ class XYChartConfig(BaseModel):
         False,
         description="Stack bars/areas on top of each other instead of side-by-side",
     )
-    group_by: ColumnRef | None = Field(None, description="Column to group by")
+    group_by: ColumnRef | None = Field(
+        None,
+        description="Column to group by (creates series/breakdown). "
+        "Use this field for series grouping — do NOT use 'series'.",
+    )
     x_axis: AxisConfig | None = Field(None, description="X-axis configuration")
     y_axis: AxisConfig | None = Field(None, description="Y-axis configuration")
     legend: LegendConfig | None = Field(None, description="Legend configuration")
diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py
index 9984aa1d9bb..91d478b9a54 100644
--- a/superset/mcp_service/chart/tool/get_chart_data.py
+++ b/superset/mcp_service/chart/tool/get_chart_data.py
@@ -176,7 +176,18 @@ async def get_chart_data(  # noqa: C901
                 else:
                     # Standard charts use "metrics" (plural) and "groupby"
                     metrics = form_data.get("metrics", [])
-                    groupby_columns = form_data.get("groupby", [])
+                    groupby_columns = form_data.get("groupby") or []
+
+                # Build query columns list: include both x_axis and groupby
+                x_axis_config = form_data.get("x_axis")
+                query_columns = groupby_columns.copy()
+                if x_axis_config and isinstance(x_axis_config, str):
+                    if x_axis_config not in query_columns:
+                        query_columns.insert(0, x_axis_config)
+                elif x_axis_config and isinstance(x_axis_config, dict):
+                    col_name = x_axis_config.get("column_name")
+                    if col_name and col_name not in query_columns:
+                        query_columns.insert(0, col_name)
 
                 query_context = factory.create(
                     datasource={
@@ -186,7 +197,7 @@ async def get_chart_data(  # noqa: C901
                     queries=[
                         {
                             "filters": form_data.get("filters", []),
-                            "columns": groupby_columns,
+                            "columns": query_columns,
                             "metrics": metrics,
                             "row_limit": row_limit,
                             "order_desc": True,
diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py
index 9e001540396..fbc1a5802be 100644
--- a/superset/mcp_service/chart/tool/get_chart_preview.py
+++ b/superset/mcp_service/chart/tool/get_chart_preview.py
@@ -56,6 +56,22 @@ class ChartLike(Protocol):
     uuid: Any
 
 
+def _build_query_columns(form_data: Dict[str, Any]) -> list[str]:
+    """Build query columns list from form_data, including both x_axis and groupby."""
+    x_axis_config = form_data.get("x_axis")
+    groupby_columns: list[str] = form_data.get("groupby") or []
+
+    columns = groupby_columns.copy()
+    if x_axis_config and isinstance(x_axis_config, str):
+        if x_axis_config not in columns:
+            columns.insert(0, x_axis_config)
+    elif x_axis_config and isinstance(x_axis_config, dict):
+        col_name = x_axis_config.get("column_name")
+        if col_name and col_name not in columns:
+            columns.insert(0, col_name)
+    return columns
+
+
 class PreviewFormatStrategy:
     """Base class for preview format strategies."""
 
@@ -185,6 +201,8 @@ class TablePreviewStrategy(PreviewFormatStrategy):
                     error_type="InvalidChart",
                 )
 
+            columns = _build_query_columns(form_data)
+
             factory = QueryContextFactory()
             query_context = factory.create(
                 datasource={
@@ -194,7 +212,7 @@ class TablePreviewStrategy(PreviewFormatStrategy):
                 queries=[
                     {
                         "filters": form_data.get("filters", []),
-                        "columns": form_data.get("groupby", []),
+                        "columns": columns,
                         "metrics": form_data.get("metrics", []),
                         "row_limit": 20,
                         "order_desc": True,
@@ -279,6 +297,9 @@ class VegaLitePreviewStrategy(PreviewFormatStrategy):
                     utils_json.loads(self.chart.params) if self.chart.params else {}
                 )
 
+            # Build columns list: include both x_axis and groupby
+            columns = _build_query_columns(form_data)
+
             # Create query context for data retrieval
             factory = QueryContextFactory()
             query_context = factory.create(
@@ -289,7 +310,7 @@ class VegaLitePreviewStrategy(PreviewFormatStrategy):
                 queries=[
                     {
                         "filters": form_data.get("filters", []),
-                        "columns": form_data.get("groupby", []),
+                        "columns": columns,
                         "metrics": form_data.get("metrics", []),
                         "row_limit": 1000,  # More data for visualization
                         "order_desc": True,
diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py
index 5bbae63912b..ae13bfc8a75 100644
--- a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py
+++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py
@@ -219,3 +219,39 @@ class TestXYChartConfig:
             kind="area",
         )
         assert config.kind == "area"
+
+    def test_unknown_fields_rejected(self) -> None:
+        """Test that unknown fields like 'series' are rejected."""
+        with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
+            XYChartConfig(
+                chart_type="xy",
+                x=ColumnRef(name="territory"),
+                y=[ColumnRef(name="sales", aggregate="SUM")],
+                kind="bar",
+                series=ColumnRef(name="year"),
+            )
+
+    def test_group_by_accepted(self) -> None:
+        """Test that group_by is the correct field for series grouping."""
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="territory"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="bar",
+            group_by=ColumnRef(name="year"),
+        )
+        assert config.group_by is not None
+        assert config.group_by.name == "year"
+
+
+class TestTableChartConfigExtraFields:
+    """Test TableChartConfig rejects unknown fields."""
+
+    def test_unknown_fields_rejected(self) -> None:
+        """Test that unknown fields are rejected."""
+        with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
+            TableChartConfig(
+                chart_type="table",
+                columns=[ColumnRef(name="product")],
+                foo="bar",
+            )
diff --git a/tests/unit_tests/mcp_service/chart/test_preview_utils.py b/tests/unit_tests/mcp_service/chart/test_preview_utils.py
new file mode 100644
index 00000000000..0190203e90e
--- /dev/null
+++ b/tests/unit_tests/mcp_service/chart/test_preview_utils.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Tests for preview_utils query context column building.
+"""
+
+
+class TestPreviewUtilsColumnBuilding:
+    """Tests for x_axis + groupby column building in generate_preview_from_form_data.
+
+    The function must build the columns list from both x_axis and groupby for
+    XY charts, and fall back to form_data["columns"] for table charts.
+    """
+
+    def test_xy_chart_uses_x_axis_and_groupby(self):
+        """Test XY chart form_data builds columns from x_axis + groupby."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        raw_columns = form_data.get("columns", [])
+
+        columns = (
+            raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
+        )
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+        elif x_axis_config and isinstance(x_axis_config, dict):
+            col_name = x_axis_config.get("column_name")
+            if col_name and col_name not in columns:
+                columns.insert(0, col_name)
+
+        assert columns == ["territory", "year"]
+
+    def test_table_chart_uses_columns_field(self):
+        """Test table chart form_data uses 'columns' field directly."""
+        form_data = {
+            "columns": ["name", "region", "sales"],
+            "metrics": [],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        raw_columns = form_data.get("columns", [])
+
+        columns = (
+            raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
+        )
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["name", "region", "sales"]
+
+    def test_xy_chart_x_axis_dict_format(self):
+        """Test XY chart with x_axis as dict (column_name key)."""
+        form_data = {
+            "x_axis": {"column_name": "order_date"},
+            "groupby": ["product_type"],
+            "metrics": [{"label": "SUM(revenue)"}],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        raw_columns = form_data.get("columns", [])
+
+        columns = (
+            raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
+        )
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+        elif x_axis_config and isinstance(x_axis_config, dict):
+            col_name = x_axis_config.get("column_name")
+            if col_name and col_name not in columns:
+                columns.insert(0, col_name)
+
+        assert columns == ["order_date", "product_type"]
+
+    def test_no_x_axis_no_columns_uses_groupby(self):
+        """Test fallback to groupby when no x_axis and no columns."""
+        form_data = {
+            "groupby": ["category"],
+            "metrics": [{"label": "COUNT(*)"}],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        raw_columns = form_data.get("columns", [])
+
+        columns = (
+            raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
+        )
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["category"]
+
+    def test_empty_form_data_returns_empty_columns(self):
+        """Test empty form_data returns empty columns list."""
+        form_data: dict = {
+            "metrics": [{"label": "COUNT(*)"}],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        raw_columns = form_data.get("columns", [])
+
+        columns = (
+            raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
+        )
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == []
+
+    def test_x_axis_not_duplicated_when_in_groupby(self):
+        """Test x_axis is not added if already present in groupby."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["territory", "year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        raw_columns = form_data.get("columns", [])
+
+        columns = (
+            raw_columns.copy() if "columns" in form_data else groupby_columns.copy()
+        )
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["territory", "year"]
diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
index 2669366f526..d276691425f 100644
--- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
+++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
@@ -152,6 +152,115 @@ class TestBigNumberChartFallback:
         assert groupby_columns == []
 
 
+class TestXAxisInQueryContext:
+    """Tests for x_axis inclusion in fallback query context columns."""
+
+    def test_x_axis_string_included_in_columns(self):
+        """Test that x_axis (string format) is included alongside groupby columns."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["year"],
+            "metrics": [{"label": "SUM(sales)"}],
+            "viz_type": "echarts_timeseries_bar",
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["territory", "year"]
+
+    def test_x_axis_dict_included_in_columns(self):
+        """Test that x_axis (dict format with column_name) is included."""
+        form_data = {
+            "x_axis": {"column_name": "territory"},
+            "groupby": ["year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+        elif x_axis_config and isinstance(x_axis_config, dict):
+            col_name = x_axis_config.get("column_name")
+            if col_name and col_name not in columns:
+                columns.insert(0, col_name)
+
+        assert columns == ["territory", "year"]
+
+    def test_no_x_axis_uses_groupby_only(self):
+        """Test that without x_axis, only groupby columns are used."""
+        form_data = {
+            "groupby": ["region", "category"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["region", "category"]
+
+    def test_x_axis_not_duplicated_if_in_groupby(self):
+        """Test that x_axis is not duplicated if already in groupby list."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["territory", "year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["territory", "year"]
+
+    def test_x_axis_without_groupby(self):
+        """Test that x_axis works when there's no groupby."""
+        form_data = {
+            "x_axis": "date",
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["date"]
+
+    def test_empty_groupby_with_x_axis(self):
+        """Test x_axis with explicitly empty groupby."""
+        form_data = {
+            "x_axis": "platform",
+            "groupby": [],
+            "metrics": [{"label": "SUM(global_sales)"}],
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["platform"]
+
+
 class TestGetChartDataRequestSchema:
     """Test the GetChartDataRequest schema validation."""
 
diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
index cbff760778e..fdd824886d7 100644
--- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
+++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
@@ -29,6 +29,119 @@ from superset.mcp_service.chart.schemas import (
 )
 
 
+class TestPreviewXAxisInQueryContext:
+    """Tests for x_axis inclusion in preview query context columns.
+
+    When generating chart previews (table, vega_lite), the query context must
+    include both x_axis and groupby columns. Previously only groupby was used,
+    causing series charts with group_by to lose the x_axis dimension.
+    """
+
+    def test_table_preview_includes_x_axis_and_groupby(self):
+        """Test that table preview builds columns with both x_axis and groupby."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["territory", "year"]
+
+    def test_vega_lite_preview_includes_x_axis_and_groupby(self):
+        """Test that vega_lite preview builds columns with both x_axis and groupby."""
+        form_data = {
+            "x_axis": "platform",
+            "groupby": ["genre"],
+            "metrics": [{"label": "SUM(global_sales)"}],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["platform", "genre"]
+
+    def test_preview_x_axis_dict_format(self):
+        """Test preview column building with x_axis as dict."""
+        form_data = {
+            "x_axis": {"column_name": "order_date"},
+            "groupby": ["region"],
+            "metrics": [{"label": "SUM(revenue)"}],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+        elif x_axis_config and isinstance(x_axis_config, dict):
+            col_name = x_axis_config.get("column_name")
+            if col_name and col_name not in columns:
+                columns.insert(0, col_name)
+
+        assert columns == ["order_date", "region"]
+
+    def test_preview_no_groupby_x_axis_only(self):
+        """Test preview with x_axis but no groupby."""
+        form_data = {
+            "x_axis": "date",
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["date"]
+
+    def test_preview_no_x_axis_groupby_only(self):
+        """Test preview with groupby but no x_axis (e.g., table chart)."""
+        form_data = {
+            "groupby": ["category", "region"],
+            "metrics": [{"label": "COUNT(*)"}],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["category", "region"]
+
+    def test_preview_x_axis_not_duplicated(self):
+        """Test x_axis isn't duplicated if already in groupby."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["territory", "year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        x_axis_config = form_data.get("x_axis")
+        groupby_columns = form_data.get("groupby", [])
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["territory", "year"]
+
+
 class TestGetChartPreview:
     """Tests for get_chart_preview MCP tool."""
 

Reply via email to