This is an automated email from the ASF dual-hosted git repository.
kgabryje pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 47db185e3b9 fix(mcp): include x_axis column in query context for
series charts with group_by (#37639)
47db185e3b9 is described below
commit 47db185e3b968fb875d83f06b85b9272b1eff050
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Thu Feb 5 11:59:44 2026 -0700
fix(mcp): include x_axis column in query context for series charts with
group_by (#37639)
---
superset/mcp_service/chart/preview_utils.py | 23 ++-
superset/mcp_service/chart/schemas.py | 10 +-
superset/mcp_service/chart/tool/get_chart_data.py | 15 +-
.../mcp_service/chart/tool/get_chart_preview.py | 25 +++-
.../mcp_service/chart/test_chart_schemas.py | 36 +++++
.../mcp_service/chart/test_preview_utils.py | 158 +++++++++++++++++++++
.../mcp_service/chart/tool/test_get_chart_data.py | 109 ++++++++++++++
.../chart/tool/test_get_chart_preview.py | 113 +++++++++++++++
8 files changed, 483 insertions(+), 6 deletions(-)
diff --git a/superset/mcp_service/chart/preview_utils.py
b/superset/mcp_service/chart/preview_utils.py
index 3db475c0da1..677d3034fd4 100644
--- a/superset/mcp_service/chart/preview_utils.py
+++ b/superset/mcp_service/chart/preview_utils.py
@@ -36,6 +36,23 @@ from superset.mcp_service.chart.schemas import (
logger = logging.getLogger(__name__)
+def _build_query_columns(form_data: Dict[str, Any]) -> list[Any]:
+    """Build the query ``columns`` list from chart form_data.
+
+    Table charts keep their dimensions in ``form_data["columns"]``; XY/series
+    charts keep them in ``groupby`` plus a separate ``x_axis``. The base list
+    is a copy of ``columns`` when that key is present, otherwise a copy of
+    ``groupby`` (``None`` is treated as empty). The x-axis column is then
+    prepended unless already present, so series charts with a group_by keep
+    their x dimension in the query.
+
+    ``x_axis`` may be a column name, a dict with a ``column_name`` key, or an
+    adhoc column dict carrying a ``sqlExpression``; adhoc definitions are
+    passed through unchanged because query columns accept them.
+
+    :param form_data: Chart form_data; it is not mutated.
+    :returns: A new list of query columns (names and/or adhoc column dicts).
+    """
+    if "columns" in form_data:
+        columns: list[Any] = list(form_data.get("columns") or [])
+    else:
+        columns = list(form_data.get("groupby") or [])
+
+    x_axis_config = form_data.get("x_axis")
+    x_axis_column: Any = None
+    if isinstance(x_axis_config, str):
+        x_axis_column = x_axis_config
+    elif isinstance(x_axis_config, dict):
+        if "column_name" in x_axis_config:
+            x_axis_column = x_axis_config.get("column_name")
+        elif x_axis_config.get("sqlExpression"):
+            # Adhoc SQL column: keep the definition itself so the x dimension
+            # is not dropped from the query.
+            x_axis_column = x_axis_config
+
+    if x_axis_column and x_axis_column not in columns:
+        columns.insert(0, x_axis_column)
+    return columns
+
+
def generate_preview_from_form_data(
form_data: Dict[str, Any], dataset_id: int, preview_format: str
) -> Any:
@@ -64,12 +81,16 @@ def generate_preview_from_form_data(
# Create query context from form data using factory
from superset.common.query_context_factory import QueryContextFactory
+ # Build columns list: include x_axis and groupby for XY charts,
+ # fall back to form_data "columns" for table charts
+ columns = _build_query_columns(form_data)
+
factory = QueryContextFactory()
query_context_obj = factory.create(
datasource={"id": dataset_id, "type": "table"},
queries=[
{
- "columns": form_data.get("columns", []),
+ "columns": columns,
"metrics": form_data.get("metrics", []),
"orderby": form_data.get("orderby", []),
"row_limit": form_data.get("row_limit", 100),
diff --git a/superset/mcp_service/chart/schemas.py
b/superset/mcp_service/chart/schemas.py
index b813bc4ebc5..6b5ed699672 100644
--- a/superset/mcp_service/chart/schemas.py
+++ b/superset/mcp_service/chart/schemas.py
@@ -606,6 +606,8 @@ class FilterConfig(BaseModel):
# Actual chart types
class TableChartConfig(BaseModel):
+ model_config = ConfigDict(extra="forbid")
+
chart_type: Literal["table"] = Field(
..., description="Chart type (REQUIRED: must be 'table')"
)
@@ -659,6 +661,8 @@ class TableChartConfig(BaseModel):
class XYChartConfig(BaseModel):
+ model_config = ConfigDict(extra="forbid")
+
chart_type: Literal["xy"] = Field(
...,
description=(
@@ -692,7 +696,11 @@ class XYChartConfig(BaseModel):
False,
description="Stack bars/areas on top of each other instead of
side-by-side",
)
- group_by: ColumnRef | None = Field(None, description="Column to group by")
+ group_by: ColumnRef | None = Field(
+ None,
+ description="Column to group by (creates series/breakdown). "
+ "Use this field for series grouping — do NOT use 'series'.",
+ )
x_axis: AxisConfig | None = Field(None, description="X-axis configuration")
y_axis: AxisConfig | None = Field(None, description="Y-axis configuration")
legend: LegendConfig | None = Field(None, description="Legend
configuration")
diff --git a/superset/mcp_service/chart/tool/get_chart_data.py
b/superset/mcp_service/chart/tool/get_chart_data.py
index 9984aa1d9bb..91d478b9a54 100644
--- a/superset/mcp_service/chart/tool/get_chart_data.py
+++ b/superset/mcp_service/chart/tool/get_chart_data.py
@@ -176,7 +176,18 @@ async def get_chart_data( # noqa: C901
else:
# Standard charts use "metrics" (plural) and "groupby"
metrics = form_data.get("metrics", [])
- groupby_columns = form_data.get("groupby", [])
+ groupby_columns = form_data.get("groupby") or []
+
+ # Build query columns list: include both x_axis and groupby
+ x_axis_config = form_data.get("x_axis")
+ query_columns = groupby_columns.copy()
+ if x_axis_config and isinstance(x_axis_config, str):
+ if x_axis_config not in query_columns:
+ query_columns.insert(0, x_axis_config)
+ elif x_axis_config and isinstance(x_axis_config, dict):
+ col_name = x_axis_config.get("column_name")
+ if col_name and col_name not in query_columns:
+ query_columns.insert(0, col_name)
query_context = factory.create(
datasource={
@@ -186,7 +197,7 @@ async def get_chart_data( # noqa: C901
queries=[
{
"filters": form_data.get("filters", []),
- "columns": groupby_columns,
+ "columns": query_columns,
"metrics": metrics,
"row_limit": row_limit,
"order_desc": True,
diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py
b/superset/mcp_service/chart/tool/get_chart_preview.py
index 9e001540396..fbc1a5802be 100644
--- a/superset/mcp_service/chart/tool/get_chart_preview.py
+++ b/superset/mcp_service/chart/tool/get_chart_preview.py
@@ -56,6 +56,22 @@ class ChartLike(Protocol):
uuid: Any
+def _build_query_columns(form_data: Dict[str, Any]) -> list[Any]:
+    """Build the query ``columns`` list for chart previews.
+
+    Starts from a copy of ``form_data["groupby"]`` (``None`` is treated as
+    empty) and prepends the x-axis column unless it is already present, so
+    series charts with a group_by keep their x dimension in the query.
+    ``form_data["columns"]`` is deliberately not consulted here, matching the
+    groupby-only columns these preview strategies queried before x_axis
+    support was added.
+
+    ``x_axis`` may be a column name, a dict with a ``column_name`` key, or an
+    adhoc column dict carrying a ``sqlExpression``; adhoc definitions are
+    passed through unchanged because query columns accept them.
+
+    :param form_data: Chart form_data; it is not mutated.
+    :returns: A new list of query columns (names and/or adhoc column dicts).
+    """
+    columns: list[Any] = list(form_data.get("groupby") or [])
+
+    x_axis_config = form_data.get("x_axis")
+    x_axis_column: Any = None
+    if isinstance(x_axis_config, str):
+        x_axis_column = x_axis_config
+    elif isinstance(x_axis_config, dict):
+        if "column_name" in x_axis_config:
+            x_axis_column = x_axis_config.get("column_name")
+        elif x_axis_config.get("sqlExpression"):
+            # Adhoc SQL column: keep the definition itself so the x dimension
+            # is not dropped from the query.
+            x_axis_column = x_axis_config
+
+    if x_axis_column and x_axis_column not in columns:
+        columns.insert(0, x_axis_column)
+    return columns
+
+
class PreviewFormatStrategy:
"""Base class for preview format strategies."""
@@ -185,6 +201,8 @@ class TablePreviewStrategy(PreviewFormatStrategy):
error_type="InvalidChart",
)
+ columns = _build_query_columns(form_data)
+
factory = QueryContextFactory()
query_context = factory.create(
datasource={
@@ -194,7 +212,7 @@ class TablePreviewStrategy(PreviewFormatStrategy):
queries=[
{
"filters": form_data.get("filters", []),
- "columns": form_data.get("groupby", []),
+ "columns": columns,
"metrics": form_data.get("metrics", []),
"row_limit": 20,
"order_desc": True,
@@ -279,6 +297,9 @@ class VegaLitePreviewStrategy(PreviewFormatStrategy):
utils_json.loads(self.chart.params) if self.chart.params
else {}
)
+ # Build columns list: include both x_axis and groupby
+ columns = _build_query_columns(form_data)
+
# Create query context for data retrieval
factory = QueryContextFactory()
query_context = factory.create(
@@ -289,7 +310,7 @@ class VegaLitePreviewStrategy(PreviewFormatStrategy):
queries=[
{
"filters": form_data.get("filters", []),
- "columns": form_data.get("groupby", []),
+ "columns": columns,
"metrics": form_data.get("metrics", []),
"row_limit": 1000, # More data for visualization
"order_desc": True,
diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py
b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py
index 5bbae63912b..ae13bfc8a75 100644
--- a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py
+++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py
@@ -219,3 +219,39 @@ class TestXYChartConfig:
kind="area",
)
assert config.kind == "area"
+
+    def test_unknown_fields_rejected(self) -> None:
+        """Test that unknown fields like 'series' are rejected."""
+        with pytest.raises(
+            ValidationError, match="Extra inputs are not permitted"
+        ) as exc_info:
+            XYChartConfig(
+                chart_type="xy",
+                x=ColumnRef(name="territory"),
+                y=[ColumnRef(name="sales", aggregate="SUM")],
+                kind="bar",
+                series=ColumnRef(name="year"),
+            )
+        # Pin that the rejection is reported against the 'series' field
+        # itself, not some other invalid input.
+        assert "series" in str(exc_info.value)
+
+    def test_group_by_accepted(self) -> None:
+        """group_by is the supported field for series/breakdown grouping."""
+        series_column = ColumnRef(name="year")
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="territory"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="bar",
+            group_by=series_column,
+        )
+        grouped = config.group_by
+        assert grouped is not None
+        assert grouped.name == "year"
+
+
+class TestTableChartConfigExtraFields:
+    """Test TableChartConfig rejects unknown fields."""
+
+    def test_unknown_fields_rejected(self) -> None:
+        """Test that unknown fields are rejected and the error names them."""
+        with pytest.raises(
+            ValidationError, match="Extra inputs are not permitted"
+        ) as exc_info:
+            TableChartConfig(
+                chart_type="table",
+                columns=[ColumnRef(name="product")],
+                foo="bar",
+            )
+        # Pin that the rejection is reported against the unknown 'foo' field.
+        assert "foo" in str(exc_info.value)
diff --git a/tests/unit_tests/mcp_service/chart/test_preview_utils.py
b/tests/unit_tests/mcp_service/chart/test_preview_utils.py
new file mode 100644
index 00000000000..0190203e90e
--- /dev/null
+++ b/tests/unit_tests/mcp_service/chart/test_preview_utils.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Tests for preview_utils query context column building.
+"""
+
+
+from superset.mcp_service.chart.preview_utils import _build_query_columns
+
+
+class TestPreviewUtilsColumnBuilding:
+    """Tests for query column building in generate_preview_from_form_data.
+
+    ``_build_query_columns`` must build the columns list from both x_axis and
+    groupby for XY charts, and use form_data["columns"] for table charts.
+    The tests call the production helper directly; re-implementing its logic
+    inline would not catch regressions in preview_utils.
+    """
+
+    def test_xy_chart_uses_x_axis_and_groupby(self):
+        """Test XY chart form_data builds columns from x_axis + groupby."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        assert _build_query_columns(form_data) == ["territory", "year"]
+
+    def test_table_chart_uses_columns_field(self):
+        """Test table chart form_data uses 'columns' field directly."""
+        form_data = {
+            "columns": ["name", "region", "sales"],
+            "metrics": [],
+        }
+
+        assert _build_query_columns(form_data) == ["name", "region", "sales"]
+
+    def test_columns_field_takes_precedence_over_groupby(self):
+        """Test 'columns' is used instead of groupby when both are present."""
+        form_data = {
+            "columns": ["name"],
+            "groupby": ["category"],
+            "x_axis": "order_date",
+        }
+
+        assert _build_query_columns(form_data) == ["order_date", "name"]
+
+    def test_xy_chart_x_axis_dict_format(self):
+        """Test XY chart with x_axis as dict (column_name key)."""
+        form_data = {
+            "x_axis": {"column_name": "order_date"},
+            "groupby": ["product_type"],
+            "metrics": [{"label": "SUM(revenue)"}],
+        }
+
+        assert _build_query_columns(form_data) == ["order_date", "product_type"]
+
+    def test_no_x_axis_no_columns_uses_groupby(self):
+        """Test fallback to groupby when no x_axis and no columns."""
+        form_data = {
+            "groupby": ["category"],
+            "metrics": [{"label": "COUNT(*)"}],
+        }
+
+        assert _build_query_columns(form_data) == ["category"]
+
+    def test_empty_form_data_returns_empty_columns(self):
+        """Test form_data without any dimensions returns an empty list."""
+        form_data: dict = {
+            "metrics": [{"label": "COUNT(*)"}],
+        }
+
+        assert _build_query_columns(form_data) == []
+
+    def test_none_groupby_treated_as_empty(self):
+        """Test an explicit None groupby does not break column building."""
+        form_data = {"x_axis": "date", "groupby": None}
+
+        assert _build_query_columns(form_data) == ["date"]
+
+    def test_x_axis_not_duplicated_when_in_groupby(self):
+        """Test x_axis is not added if already present in groupby."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["territory", "year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        assert _build_query_columns(form_data) == ["territory", "year"]
+
+    def test_form_data_not_mutated(self):
+        """Test the helper returns a new list and leaves form_data intact."""
+        groupby = ["year"]
+        form_data = {"x_axis": "territory", "groupby": groupby}
+
+        columns = _build_query_columns(form_data)
+
+        assert columns == ["territory", "year"]
+        assert groupby == ["year"]
diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
index 2669366f526..d276691425f 100644
--- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
+++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
@@ -152,6 +152,115 @@ class TestBigNumberChartFallback:
assert groupby_columns == []
+class TestXAxisInQueryContext:
+    """Tests for x_axis inclusion in fallback query context columns.
+
+    NOTE(review): these tests re-implement the column-building logic inline
+    instead of calling get_chart_data, so they cannot detect a regression in
+    the production code. The inline copy also reads groupby with
+    ``form_data.get("groupby", [])`` while get_chart_data uses
+    ``form_data.get("groupby") or []`` (they differ for ``groupby=None``).
+    TODO: extract the logic in get_chart_data into a helper and test that
+    helper directly.
+    """
+
+    def test_x_axis_string_included_in_columns(self):
+        """Test that x_axis (string format) is included alongside groupby columns."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["year"],
+            "metrics": [{"label": "SUM(sales)"}],
+            "viz_type": "echarts_timeseries_bar",
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["territory", "year"]
+
+    def test_x_axis_dict_included_in_columns(self):
+        """Test that x_axis (dict format with column_name) is included."""
+        form_data = {
+            "x_axis": {"column_name": "territory"},
+            "groupby": ["year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+        elif x_axis_config and isinstance(x_axis_config, dict):
+            col_name = x_axis_config.get("column_name")
+            if col_name and col_name not in columns:
+                columns.insert(0, col_name)
+
+        assert columns == ["territory", "year"]
+
+    def test_no_x_axis_uses_groupby_only(self):
+        """Test that without x_axis, only groupby columns are used."""
+        form_data = {
+            "groupby": ["region", "category"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["region", "category"]
+
+    def test_x_axis_not_duplicated_if_in_groupby(self):
+        """Test that x_axis is not duplicated if already in groupby list."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["territory", "year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["territory", "year"]
+
+    def test_x_axis_without_groupby(self):
+        """Test that x_axis works when there's no groupby."""
+        form_data = {
+            "x_axis": "date",
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["date"]
+
+    def test_empty_groupby_with_x_axis(self):
+        """Test x_axis with explicitly empty groupby."""
+        form_data = {
+            "x_axis": "platform",
+            "groupby": [],
+            "metrics": [{"label": "SUM(global_sales)"}],
+        }
+
+        groupby_columns = form_data.get("groupby", [])
+        x_axis_config = form_data.get("x_axis")
+        columns = groupby_columns.copy()
+        if x_axis_config and isinstance(x_axis_config, str):
+            if x_axis_config not in columns:
+                columns.insert(0, x_axis_config)
+
+        assert columns == ["platform"]
+
+
class TestGetChartDataRequestSchema:
"""Test the GetChartDataRequest schema validation."""
diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
index cbff760778e..fdd824886d7 100644
--- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
+++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
@@ -29,6 +29,119 @@ from superset.mcp_service.chart.schemas import (
)
+from superset.mcp_service.chart.tool.get_chart_preview import _build_query_columns
+
+
+class TestPreviewXAxisInQueryContext:
+    """Tests for x_axis inclusion in preview query context columns.
+
+    When generating chart previews (table, vega_lite), the query context must
+    include both x_axis and groupby columns. Previously only groupby was used,
+    causing series charts with group_by to lose the x_axis dimension. Both
+    strategies build their columns with ``_build_query_columns``, so the tests
+    exercise that helper directly rather than a copy of its logic.
+    """
+
+    def test_table_preview_includes_x_axis_and_groupby(self):
+        """Test that table preview builds columns with both x_axis and groupby."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        assert _build_query_columns(form_data) == ["territory", "year"]
+
+    def test_vega_lite_preview_includes_x_axis_and_groupby(self):
+        """Test that vega_lite preview builds columns with both x_axis and groupby."""
+        form_data = {
+            "x_axis": "platform",
+            "groupby": ["genre"],
+            "metrics": [{"label": "SUM(global_sales)"}],
+        }
+
+        assert _build_query_columns(form_data) == ["platform", "genre"]
+
+    def test_preview_x_axis_dict_format(self):
+        """Test preview column building with x_axis as dict."""
+        form_data = {
+            "x_axis": {"column_name": "order_date"},
+            "groupby": ["region"],
+            "metrics": [{"label": "SUM(revenue)"}],
+        }
+
+        assert _build_query_columns(form_data) == ["order_date", "region"]
+
+    def test_preview_no_groupby_x_axis_only(self):
+        """Test preview with x_axis but no groupby."""
+        form_data = {
+            "x_axis": "date",
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        assert _build_query_columns(form_data) == ["date"]
+
+    def test_preview_none_groupby_x_axis_only(self):
+        """Test preview with x_axis and an explicit None groupby."""
+        form_data = {"x_axis": "date", "groupby": None}
+
+        assert _build_query_columns(form_data) == ["date"]
+
+    def test_preview_no_x_axis_groupby_only(self):
+        """Test preview with groupby but no x_axis (e.g., table chart)."""
+        form_data = {
+            "groupby": ["category", "region"],
+            "metrics": [{"label": "COUNT(*)"}],
+        }
+
+        assert _build_query_columns(form_data) == ["category", "region"]
+
+    def test_preview_x_axis_not_duplicated(self):
+        """Test x_axis isn't duplicated if already in groupby."""
+        form_data = {
+            "x_axis": "territory",
+            "groupby": ["territory", "year"],
+            "metrics": [{"label": "SUM(sales)"}],
+        }
+
+        assert _build_query_columns(form_data) == ["territory", "year"]
+
+    def test_preview_form_data_not_mutated(self):
+        """Test the helper returns a new list and leaves groupby intact."""
+        groupby = ["year"]
+        form_data = {"x_axis": "territory", "groupby": groupby}
+
+        columns = _build_query_columns(form_data)
+
+        assert columns == ["territory", "year"]
+        assert groupby == ["year"]
+
+
class TestGetChartPreview:
"""Tests for get_chart_preview MCP tool."""