This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new ab3429c318 Add STOPPED to the failure cases for Sagemaker Training Jobs (#42423)
ab3429c318 is described below

commit ab3429c3189ceb244eb3d78062159859dbe611ce
Author: D. Ferruzzi <ferru...@amazon.com>
AuthorDate: Tue Sep 24 15:07:40 2024 -0700

    Add STOPPED to the failure cases for Sagemaker Training Jobs (#42423)
---
 airflow/providers/amazon/aws/hooks/sagemaker.py   | 3 ++-
 airflow/providers/amazon/aws/sensors/sagemaker.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index af131697a5..2c0f4fb25e 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -155,6 +155,7 @@ class SageMakerHook(AwsBaseHook):
     endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"}
     pipeline_non_terminal_states = {"Executing", "Stopping"}
     failed_states = {"Failed"}
+    training_failed_states = {*failed_states, "Stopped"}
 
     def __init__(self, *args, **kwargs):
         super().__init__(client_type="sagemaker", *args, **kwargs)
@@ -309,7 +310,7 @@ class SageMakerHook(AwsBaseHook):
             self.check_training_status_with_log(
                 config["TrainingJobName"],
                 self.non_terminal_states,
-                self.failed_states,
+                self.training_failed_states,
                 wait_for_completion,
                 check_interval,
                 max_ingestion_time,
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py
index b01e24cd5b..af07c504aa 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker.py
@@ -238,7 +238,7 @@ class SageMakerTrainingSensor(SageMakerBaseSensor):
         return SageMakerHook.non_terminal_states
 
     def failed_states(self):
-        return SageMakerHook.failed_states
+        return SageMakerHook.training_failed_states
 
     def get_sagemaker_response(self):
         if self.print_log: