This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ab3429c318 Add STOPPED to the failure cases for Sagemaker Training Jobs (#42423)
ab3429c318 is described below

commit ab3429c3189ceb244eb3d78062159859dbe611ce
Author: D. Ferruzzi <ferru...@amazon.com>
AuthorDate: Tue Sep 24 15:07:40 2024 -0700

    Add STOPPED to the failure cases for Sagemaker Training Jobs (#42423)
---
 airflow/providers/amazon/aws/hooks/sagemaker.py   | 3 ++-
 airflow/providers/amazon/aws/sensors/sagemaker.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index af131697a5..2c0f4fb25e 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -155,6 +155,7 @@ class SageMakerHook(AwsBaseHook):
     endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"}
     pipeline_non_terminal_states = {"Executing", "Stopping"}
     failed_states = {"Failed"}
+    training_failed_states = {*failed_states, "Stopped"}
 
     def __init__(self, *args, **kwargs):
         super().__init__(client_type="sagemaker", *args, **kwargs)
@@ -309,7 +310,7 @@ class SageMakerHook(AwsBaseHook):
             self.check_training_status_with_log(
                 config["TrainingJobName"],
                 self.non_terminal_states,
-                self.failed_states,
+                self.training_failed_states,
                 wait_for_completion,
                 check_interval,
                 max_ingestion_time,
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py
index b01e24cd5b..af07c504aa 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker.py
@@ -238,7 +238,7 @@ class SageMakerTrainingSensor(SageMakerBaseSensor):
         return SageMakerHook.non_terminal_states
 
     def failed_states(self):
-        return SageMakerHook.failed_states
+        return SageMakerHook.training_failed_states
 
     def get_sagemaker_response(self):
         if self.print_log:

Reply via email to