This is an automated email from the ASF dual-hosted git repository. vavila pushed a commit to branch feat/handle-oauth2-dance-table-selector in repository https://gitbox.apache.org/repos/asf/superset.git
commit 59db6536f2a6701a14429373bcb72dea31043537 Author: Vitor Avila <[email protected]> AuthorDate: Wed Jan 21 18:14:05 2026 -0300 feat: Handle OAuth2 dance with TableSelector --- .../TableSelector/TableSelector.test.tsx | 41 ++++++++ .../src/components/TableSelector/index.tsx | 28 +++-- superset-frontend/src/hooks/apiResources/tables.ts | 9 +- superset/models/core.py | 9 ++ tests/unit_tests/models/core_test.py | 115 +++++++++++++++++++++ 5 files changed, 187 insertions(+), 15 deletions(-) diff --git a/superset-frontend/src/components/TableSelector/TableSelector.test.tsx b/superset-frontend/src/components/TableSelector/TableSelector.test.tsx index a47c40731a..0e7d577ebf 100644 --- a/superset-frontend/src/components/TableSelector/TableSelector.test.tsx +++ b/superset-frontend/src/components/TableSelector/TableSelector.test.tsx @@ -290,3 +290,44 @@ test('TableOption renders correct icons for different table types', () => { ); expect(mvContainer.querySelector('.anticon')).toBeInTheDocument(); }); + +test('handles OAuth2 error by displaying ErrorMessageWithStackTrace instead of calling handleError', async () => { + const oauth2ErrorResponse = { + errors: [ + { + error_type: 'OAUTH2_REDIRECT', + message: "OAuth token needed.", + level: 'warning', + extra: { + url: 'https://oauth.example.com/authorize', + tab_id: 'test-tab-id', + redirect_uri: 'https://superset.example.com/oauth2/', + }, + }, + ], + }; + + fetchMock.get(catalogApiRoute, { result: [] }); + fetchMock.get(schemaApiRoute, { result: ['test_schema'] }); + fetchMock.get(tablesApiRoute, { + status: 500, + body: oauth2ErrorResponse, + }); + + const handleError = jest.fn(); + const props = createProps({ handleError }); + render(<TableSelector {...props} />, { useRedux: true, store }); + + // Wait for the API call to complete and error to be processed + await waitFor( + () => { + // The ErrorMessageWithStackTrace component should render when errors array is present + // handleError should NOT be called when errors array exists (OAuth2 pattern) + expect(handleError).not.toHaveBeenCalled(); + }, + { timeout: 10000 }, + ); + + // Verify the error alert component is rendered + expect(screen.getByRole('alert')).toBeInTheDocument(); +}); diff --git a/superset-frontend/src/components/TableSelector/index.tsx b/superset-frontend/src/components/TableSelector/index.tsx index 60c6171e6d..f81eccc3d6 100644 --- a/superset-frontend/src/components/TableSelector/index.tsx +++ b/superset-frontend/src/components/TableSelector/index.tsx @@ -26,10 +26,10 @@ import { import type { SelectValue } from '@superset-ui/core/components'; import { t } from '@apache-superset/core'; -import { getClientErrorMessage, getClientErrorObject } from '@superset-ui/core'; +import { SupersetError } from '@superset-ui/core'; import { styled } from '@apache-superset/core/ui'; import { CertifiedBadge, Select } from '@superset-ui/core/components'; -import { DatabaseSelector } from 'src/components'; +import { DatabaseSelector, ErrorMessageWithStackTrace } from 'src/components'; import { Icons } from '@superset-ui/core/components/Icons'; import type { DatabaseObject } from 'src/components/DatabaseSelector/types'; import { StyledFormLabel } from 'src/components/DatabaseSelector/styles'; @@ -183,6 +183,7 @@ const TableSelector: FunctionComponent<TableSelectorProps> = ({ const [tableSelectValue, setTableSelectValue] = useState< SelectValue | undefined >(undefined); + const [errorPayload, setErrorPayload] = useState<SupersetError | null>(null); const { currentData: data, isFetching: loadingTables, @@ -192,19 +193,17 @@ const TableSelector: FunctionComponent<TableSelectorProps> = ({ catalog: currentCatalog, schema: currentSchema, onSuccess: (data, isFetched) => { + setErrorPayload(null); if (isFetched) { addSuccessToast(t('List updated')); } }, - onError: err => { - getClientErrorObject(err).then(clientError => { - handleError( - getClientErrorMessage( - t('There was an error loading the tables'), - clientError, - ), - ); - }); + onError: error => { + if (error?.errors) { + setErrorPayload(error?.errors?.[0] ?? null); + } else { + handleError(error?.error || t('There was an error loading the tables')); + } }, }); @@ -345,6 +344,12 @@ const TableSelector: FunctionComponent<TableSelectorProps> = ({ ); } + function renderError() { + return errorPayload ? ( + <ErrorMessageWithStackTrace error={errorPayload} source="crud" /> + ) : null; + } + return ( <TableSelectorWrapper> <DatabaseSelector @@ -364,6 +369,7 @@ const TableSelector: FunctionComponent<TableSelectorProps> = ({ readOnly={readOnly} /> {sqlLabMode && !formMode && <div className="divider" />} + {renderError()} {renderTableSelect()} </TableSelectorWrapper> ); diff --git a/superset-frontend/src/hooks/apiResources/tables.ts b/superset-frontend/src/hooks/apiResources/tables.ts index 81792b4a52..99fc568d1f 100644 --- a/superset-frontend/src/hooks/apiResources/tables.ts +++ b/superset-frontend/src/hooks/apiResources/tables.ts @@ -17,6 +17,7 @@ * under the License. */ import { useCallback, useMemo, useEffect, useRef } from 'react'; +import { ClientErrorObject } from '@superset-ui/core'; import useEffectEvent from 'src/hooks/useEffectEvent'; import { toQueryString } from 'src/utils/urlUtils'; import { api, JsonResponse } from './queryApi'; @@ -55,7 +56,7 @@ export type FetchTablesQueryParams = { schema?: string; forceRefresh?: boolean; onSuccess?: (data: Data, isRefetched: boolean) => void; - onError?: (error: Response) => void; + onError?: (error: ClientErrorObject) => void; }; export type FetchTableMetadataQueryParams = { @@ -192,7 +193,7 @@ export function useTables(options: Params) { onSuccess?.(data, isRefetched); }); - const handleOnError = useEffectEvent((error: Response) => { + const handleOnError = useEffectEvent((error: ClientErrorObject) => { onError?.(error); }); @@ -204,7 +205,7 @@ export function useTables(options: Params) { handleOnSuccess(data, true); } if (isError) { - handleOnError(error as Response); + handleOnError(error as ClientErrorObject); } }, ); @@ -227,7 +228,7 @@ export function useTables(options: Params) { handleOnSuccess(currentData, false); } if (isError) { - handleOnError(error as Response); + handleOnError(error as ClientErrorObject); } } } else { diff --git a/superset/models/core.py b/superset/models/core.py index 8fe4911996..cb7bdf2d35 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -896,6 +896,9 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: ) } except Exception as ex: + if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): + self.start_oauth2_dance() + raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @cache_util.memoized_func( @@ -930,6 +933,9 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: ) } except Exception as ex: + if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): + self.start_oauth2_dance() + raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @cache_util.memoized_func( @@ -966,6 +972,9 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: ) } except Exception as ex: + if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): + self.start_oauth2_dance() + raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex return set() diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 2667d359d2..998a1033bb 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -410,6 +410,121 @@ def test_get_all_catalog_names_needs_oauth2(mocker: MockerFixture) -> None: assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT +def test_get_all_table_names_in_schema_needs_oauth2(mocker: MockerFixture) -> None: + """ + Test the `get_all_table_names_in_schema` method when OAuth2 is needed. + """ + database = Database( + database_name="db", + sqlalchemy_uri="snowflake://:@abcd1234.snowflakecomputing.com/db", + encrypted_extra=json.dumps(oauth2_client_info), + ) + + class DriverSpecificError(Exception): + """ + A custom exception that is raised by the Snowflake driver. + """ + + mocker.patch.object( + database.db_engine_spec, + "oauth2_exception", + DriverSpecificError, + ) + mocker.patch.object( + database.db_engine_spec, + "get_table_names", + side_effect=DriverSpecificError("User needs to authenticate"), + ) + mocker.patch.object(database, "get_inspector") + user = mocker.MagicMock() + user.id = 42 + mocker.patch("superset.db_engine_specs.base.g", user=user) + + with pytest.raises(OAuth2RedirectError) as excinfo: + database.get_all_table_names_in_schema(catalog=None, schema="public") + + assert excinfo.value.message == "You don't have permission to access the data." + assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT + + +def test_get_all_view_names_in_schema_needs_oauth2(mocker: MockerFixture) -> None: + """ + Test the `get_all_view_names_in_schema` method when OAuth2 is needed. + """ + database = Database( + database_name="db", + sqlalchemy_uri="snowflake://:@abcd1234.snowflakecomputing.com/db", + encrypted_extra=json.dumps(oauth2_client_info), + ) + + class DriverSpecificError(Exception): + """ + A custom exception that is raised by the Snowflake driver. + """ + + mocker.patch.object( + database.db_engine_spec, + "oauth2_exception", + DriverSpecificError, + ) + mocker.patch.object( + database.db_engine_spec, + "get_view_names", + side_effect=DriverSpecificError("User needs to authenticate"), + ) + mocker.patch.object(database, "get_inspector") + user = mocker.MagicMock() + user.id = 42 + mocker.patch("superset.db_engine_specs.base.g", user=user) + + with pytest.raises(OAuth2RedirectError) as excinfo: + database.get_all_view_names_in_schema(catalog=None, schema="public") + + assert excinfo.value.message == "You don't have permission to access the data." + assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT + + +def test_get_all_materialized_view_names_in_schema_needs_oauth2( + mocker: MockerFixture, +) -> None: + """ + Test the `get_all_materialized_view_names_in_schema` method when OAuth2 is needed. + """ + database = Database( + database_name="db", + sqlalchemy_uri="snowflake://:@abcd1234.snowflakecomputing.com/db", + encrypted_extra=json.dumps(oauth2_client_info), + ) + + class DriverSpecificError(Exception): + """ + A custom exception that is raised by the Snowflake driver. + """ + + mocker.patch.object( + database.db_engine_spec, + "oauth2_exception", + DriverSpecificError, + ) + mocker.patch.object( + database.db_engine_spec, + "get_materialized_view_names", + side_effect=DriverSpecificError("User needs to authenticate"), + ) + mocker.patch.object(database, "get_inspector") + user = mocker.MagicMock() + user.id = 42 + mocker.patch("superset.db_engine_specs.base.g", user=user) + + with pytest.raises(OAuth2RedirectError) as excinfo: + database.get_all_materialized_view_names_in_schema( + catalog=None, schema="public" + ) + + assert excinfo.value.message == "You don't have permission to access the data." + assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT + + def test_get_sqla_engine(mocker: MockerFixture) -> None: """ Test `_get_sqla_engine`.
