ruijiang-rjian commented on code in PR #62962: URL: https://github.com/apache/airflow/pull/62962#discussion_r2892706891
########## providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio_explicit_params.py: ########## @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import ( + SageMakerNotebookOperator, +) + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import DAG, chain, task +else: + # Airflow 2 path + from airflow.decorators import task # type: ignore[attr-defined,no-redef] + from airflow.models.baseoperator import chain # type: ignore[attr-defined,no-redef] + from airflow.models.dag import DAG # type: ignore[attr-defined,no-redef,assignment] + +from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder + +""" +Prerequisites: The account which runs this test must manually have the following: +1. An IAM IDC organization set up in the testing region with a user initialized Review Comment: sry what's IAM IDC.. do you mean IDC([IAM Identity Center](https://aws.amazon.com/iam/identity-center/))? Or IAM domain or IDC domain...? 
########## providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio_explicit_params.py: ########## @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import ( + SageMakerNotebookOperator, +) + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import DAG, chain, task +else: + # Airflow 2 path + from airflow.decorators import task # type: ignore[attr-defined,no-redef] + from airflow.models.baseoperator import chain # type: ignore[attr-defined,no-redef] + from airflow.models.dag import DAG # type: ignore[attr-defined,no-redef,assignment] + +from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder + +""" +Prerequisites: The account which runs this test must manually have the following: +1. An IAM IDC organization set up in the testing region with a user initialized +2. A SageMaker Unified Studio Domain (with default VPC and roles) +3. A project within the SageMaker Unified Studio Domain +4. 
A notebook (test_notebook.ipynb) placed in the project's s3 path + +This test is identical to example_sagemaker_unified_studio, but passes domain_id, project_id, +and domain_region explicitly as operator parameters instead of relying on environment variables. +""" + +DAG_ID = "example_sagemaker_unified_studio_explicit_params" + +# Externally fetched variables: +DOMAIN_ID_KEY = "DOMAIN_ID" +PROJECT_ID_KEY = "PROJECT_ID" +S3_PATH_KEY = "S3_PATH" +REGION_NAME_KEY = "REGION_NAME" + +sys_test_context_task = ( + SystemTestContextBuilder() + .add_variable(DOMAIN_ID_KEY) + .add_variable(PROJECT_ID_KEY) + .add_variable(S3_PATH_KEY) + .add_variable(REGION_NAME_KEY) + .build() +) + + +def get_mwaa_environment_params(s3_path: str): + AIRFLOW_PREFIX = "AIRFLOW__WORKFLOWS__" + return {f"{AIRFLOW_PREFIX}PROJECT_S3_PATH": s3_path} Review Comment: no we shouldn't need to use this(and actually our system test should prove that we don't need any env variables to be set in the airflow env for the operator with new parameters provided as inputs to work) we're expecting to be able to infer the s3 path using the provided domain_id and project_id: https://code.amazon.com/packages/MaxDomePythonSDK/blobs/947dbe39c55874437a8b9c5e181a3d3483c85ca9/--/src/sagemaker_studio/execution/remote_execution_client.py#L887,L888 ########## providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio_explicit_params.py: ########## @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import ( + SageMakerNotebookOperator, +) + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import DAG, chain, task +else: + # Airflow 2 path + from airflow.decorators import task # type: ignore[attr-defined,no-redef] + from airflow.models.baseoperator import chain # type: ignore[attr-defined,no-redef] + from airflow.models.dag import DAG # type: ignore[attr-defined,no-redef,assignment] + +from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder + +""" +Prerequisites: The account which runs this test must manually have the following: +1. An IAM IDC organization set up in the testing region with a user initialized +2. A SageMaker Unified Studio Domain (with default VPC and roles) +3. A project within the SageMaker Unified Studio Domain +4. A notebook (test_notebook.ipynb) placed in the project's s3 path + +This test is identical to example_sagemaker_unified_studio, but passes domain_id, project_id, +and domain_region explicitly as operator parameters instead of relying on environment variables. 
+""" + +DAG_ID = "example_sagemaker_unified_studio_explicit_params" + +# Externally fetched variables: +DOMAIN_ID_KEY = "DOMAIN_ID" +PROJECT_ID_KEY = "PROJECT_ID" +S3_PATH_KEY = "S3_PATH" +REGION_NAME_KEY = "REGION_NAME" + +sys_test_context_task = ( + SystemTestContextBuilder() + .add_variable(DOMAIN_ID_KEY) + .add_variable(PROJECT_ID_KEY) + .add_variable(S3_PATH_KEY) + .add_variable(REGION_NAME_KEY) + .build() +) + + +def get_mwaa_environment_params(s3_path: str): + AIRFLOW_PREFIX = "AIRFLOW__WORKFLOWS__" + return {f"{AIRFLOW_PREFIX}PROJECT_S3_PATH": s3_path} + + +@task +def mock_mwaa_environment(parameters: dict): + """ + Sets several environment variables in the container to emulate an MWAA environment provisioned + within SageMaker Unified Studio. When running in the ECSExecutor, this is a no-op. + """ + import os + + for key, value in parameters.items(): + os.environ[key] = value + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, +) as dag: + test_context = sys_test_context_task() + + test_env_id = test_context[ENV_ID_KEY] + domain_id = test_context[DOMAIN_ID_KEY] + project_id = test_context[PROJECT_ID_KEY] + s3_path = test_context[S3_PATH_KEY] + region_name = test_context[REGION_NAME_KEY] + + mock_mwaa_environment_params = get_mwaa_environment_params(s3_path) + + setup_mwaa_environment = mock_mwaa_environment(mock_mwaa_environment_params) Review Comment: we should probably remove the whole mock_mwaa_environment() ########## providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio_explicit_params.py: ########## @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import ( + SageMakerNotebookOperator, +) + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import DAG, chain, task +else: + # Airflow 2 path + from airflow.decorators import task # type: ignore[attr-defined,no-redef] + from airflow.models.baseoperator import chain # type: ignore[attr-defined,no-redef] + from airflow.models.dag import DAG # type: ignore[attr-defined,no-redef,assignment] + +from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder + +""" +Prerequisites: The account which runs this test must manually have the following: +1. An IAM IDC organization set up in the testing region with a user initialized +2. A SageMaker Unified Studio Domain (with default VPC and roles) +3. A project within the SageMaker Unified Studio Domain +4. A notebook (test_notebook.ipynb) placed in the project's s3 path + +This test is identical to example_sagemaker_unified_studio, but passes domain_id, project_id, +and domain_region explicitly as operator parameters instead of relying on environment variables. 
+""" + +DAG_ID = "example_sagemaker_unified_studio_explicit_params" + +# Externally fetched variables: +DOMAIN_ID_KEY = "DOMAIN_ID" +PROJECT_ID_KEY = "PROJECT_ID" +S3_PATH_KEY = "S3_PATH" +REGION_NAME_KEY = "REGION_NAME" + +sys_test_context_task = ( + SystemTestContextBuilder() + .add_variable(DOMAIN_ID_KEY) + .add_variable(PROJECT_ID_KEY) + .add_variable(S3_PATH_KEY) + .add_variable(REGION_NAME_KEY) + .build() +) + + +def get_mwaa_environment_params(s3_path: str): + AIRFLOW_PREFIX = "AIRFLOW__WORKFLOWS__" + return {f"{AIRFLOW_PREFIX}PROJECT_S3_PATH": s3_path} + + +@task +def mock_mwaa_environment(parameters: dict): + """ + Sets several environment variables in the container to emulate an MWAA environment provisioned + within SageMaker Unified Studio. When running in the ECSExecutor, this is a no-op. + """ + import os + + for key, value in parameters.items(): + os.environ[key] = value + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, +) as dag: + test_context = sys_test_context_task() + + test_env_id = test_context[ENV_ID_KEY] + domain_id = test_context[DOMAIN_ID_KEY] + project_id = test_context[PROJECT_ID_KEY] + s3_path = test_context[S3_PATH_KEY] + region_name = test_context[REGION_NAME_KEY] + + mock_mwaa_environment_params = get_mwaa_environment_params(s3_path) + + setup_mwaa_environment = mock_mwaa_environment(mock_mwaa_environment_params) + + # [START howto_operator_sagemaker_unified_studio_notebook_explicit_params] + notebook_path = "test_notebook.ipynb" # This should be the path to your .ipynb, .sqlnb, or .vetl file in your project. + + run_notebook = SageMakerNotebookOperator( + task_id="run-notebook", + domain_id=domain_id, + project_id=project_id, + domain_region=region_name, Review Comment: if possible, we'd also need to test without providing domain_region as it's an optional field.. 
########## providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py: ########## @@ -101,6 +101,11 @@ class SageMakerNotebookOperator(BaseOperator): """ operator_extra_links = (SageMakerUnifiedStudioLink(),) + # These fields are declared as template_fields so Airflow resolves XCom references + # (e.g. task_instance.xcom_pull(...)) to actual string values before execute() is called. + # Without this, the hook would be instantiated with unresolved PlainXComArg objects, + # causing a ParamValidationError when the underlying SDK tries to use them as strings. + template_fields = ("domain_id", "project_id", "domain_region") Review Comment: Good catch on this! and +1 to question above, do we want to apply to other parameters as well? ########## providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio.py: ########## @@ -78,10 +75,6 @@ def get_mwaa_environment_params( parameters = {} parameters[f"{AIRFLOW_PREFIX}DATAZONE_DOMAIN_ID"] = domain_id parameters[f"{AIRFLOW_PREFIX}DATAZONE_PROJECT_ID"] = project_id - parameters[f"{AIRFLOW_PREFIX}DATAZONE_ENVIRONMENT_ID"] = environment_id - parameters[f"{AIRFLOW_PREFIX}DATAZONE_SCOPE_NAME"] = "dev" - parameters[f"{AIRFLOW_PREFIX}DATAZONE_STAGE"] = "prod" - parameters[f"{AIRFLOW_PREFIX}DATAZONE_ENDPOINT"] = f"https://datazone.{region_name}.api.aws" Review Comment: these are not being actually used in SDK either so we completely removed these in SMUS SDK recently ########## providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio_explicit_params.py: ########## Review Comment: yeah I agree, we should try to keep in the same system test as it's the same operator but with a different setup, +1 a new task before the line to set the env variables in MWAA should hopefully work..? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected]. For queries about this service, please contact Infrastructure at: [email protected].
