This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 49b19dd9e81 Retry test DB cleanup functions on transient MySQL errors #62768 (#62823)
49b19dd9e81 is described below

commit 49b19dd9e81caed512661db3cddb8e4682fd466f
Author: Haseeb Malik <[email protected]>
AuthorDate: Fri Mar 6 04:16:10 2026 -0500

    Retry test DB cleanup functions on transient MySQL errors #62768 (#62823)
---
 devel-common/src/tests_common/test_utils/db.py | 51 ++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/devel-common/src/tests_common/test_utils/db.py b/devel-common/src/tests_common/test_utils/db.py
index acaa24df35b..4947e922adf 100644
--- a/devel-common/src/tests_common/test_utils/db.py
+++ b/devel-common/src/tests_common/test_utils/db.py
@@ -17,7 +17,9 @@
 # under the License.
 from __future__ import annotations
 
+import functools
 import json
+import logging
 import os
 from tempfile import gettempdir
 from typing import TYPE_CHECKING
@@ -50,6 +52,7 @@ from airflow.utils.db import (
     create_default_connections,
     reflect_tables,
 )
+from airflow.utils.retries import run_with_db_retries
 from airflow.utils.session import create_session
 
 from tests_common.test_utils.compat import (
@@ -62,6 +65,8 @@ from tests_common.test_utils.compat import (
 )
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS
 
+log = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from pathlib import Path
 
@@ -74,6 +79,23 @@ if AIRFLOW_V_3_1_PLUS:
     from airflow.models.dag_favorite import DagFavorite
 
 
+def _retry_db(func):
+    """
+    Retry ``func`` on transient DB errors, using the attempts from ``run_with_db_retries``.
+
+    Handles MySQL mid-query disconnects (error 2013) in CI; each retry re-runs ``func`` in full.
+    See https://github.com/apache/airflow/issues/62768
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        for attempt in run_with_db_retries(logger=log):  # one loop iteration per try
+            with attempt:  # presumably tenacity-style: captures errors, then retries or re-raises
+                return func(*args, **kwargs)  # first successful call returns func's result
+
+    return wrapper
+
+
 def _deactivate_unknown_dags(active_dag_ids, session):
     """
     Given a list of known DAGs, deactivate any other DAGs that are marked as active in the ORM.
@@ -198,6 +220,7 @@ def parse_and_sync_to_db(folder: Path | str, include_examples: bool = False):
     return dagbag
 
 
+@_retry_db
 def clear_db_runs():
     with create_session() as session:
         session.execute(delete(Job))
@@ -212,6 +235,7 @@ def clear_db_runs():
             pass
 
 
+@_retry_db
 def clear_db_backfills():
     from airflow.models.backfill import Backfill, BackfillDagRun
 
@@ -220,6 +244,7 @@ def clear_db_backfills():
         session.execute(delete(Backfill))
 
 
+@_retry_db
 def clear_db_assets():
     with create_session() as session:
         session.execute(delete(AssetEvent))
@@ -251,6 +276,7 @@ def clear_db_assets():
             session.execute(delete(AssetWatcherModel))
 
 
+@_retry_db
 def clear_db_triggers():
     with create_session() as session:
         if AIRFLOW_V_3_2_PLUS:
@@ -260,6 +286,7 @@ def clear_db_triggers():
         session.execute(delete(Trigger))
 
 
+@_retry_db
 def clear_db_dags():
     with create_session() as session:
         if AIRFLOW_V_3_1_PLUS:
@@ -272,6 +299,7 @@ def clear_db_dags():
         session.execute(delete(DagModel))
 
 
+@_retry_db
 def clear_db_deadline():
     with create_session() as session:
         if AIRFLOW_V_3_0_PLUS:
@@ -280,6 +308,7 @@ def clear_db_deadline():
             session.execute(delete(Deadline))
 
 
+@_retry_db
 def clear_db_deadline_alert():
     with create_session() as session:
         if AIRFLOW_V_3_2_PLUS:
@@ -288,6 +317,7 @@ def clear_db_deadline_alert():
             session.execute(delete(DeadlineAlert))
 
 
+@_retry_db
 def drop_tables_with_prefix(prefix):
     with create_session() as session:
         metadata = reflect_tables(None, session)
@@ -296,11 +326,13 @@ def drop_tables_with_prefix(prefix):
                 table.drop(session.bind)
 
 
+@_retry_db
 def clear_db_serialized_dags():
     with create_session() as session:
         session.execute(delete(SerializedDagModel))
 
 
+@_retry_db
 def clear_db_pools():
     with create_session() as session:
         session.execute(delete(Pool))
@@ -319,6 +351,7 @@ def clear_test_connections(add_default_connections_back=True):
         create_default_connections_for_tests()
 
 
+@_retry_db
 def clear_db_connections(add_default_connections_back=True):
     with create_session() as session:
         session.execute(delete(Connection))
@@ -326,16 +359,19 @@ def clear_db_connections(add_default_connections_back=True):
             create_default_connections(session)
 
 
+@_retry_db
 def clear_db_variables():
     with create_session() as session:
         session.execute(delete(Variable))
 
 
+@_retry_db
 def clear_db_dag_code():
     with create_session() as session:
         session.execute(delete(DagCode))
 
 
+@_retry_db
 def clear_db_callbacks():
     with create_session() as session:
         if AIRFLOW_V_3_2_PLUS:
@@ -347,32 +383,38 @@ def clear_db_callbacks():
             session.execute(delete(DbCallbackRequest))
 
 
+@_retry_db
 def set_default_pool_slots(slots):
     with create_session() as session:
         default_pool = Pool.get_default_pool(session)
         default_pool.slots = slots
 
 
+@_retry_db
 def clear_rendered_ti_fields():
     with create_session() as session:
         session.execute(delete(RenderedTaskInstanceFields))
 
 
+@_retry_db
 def clear_db_import_errors():
     with create_session() as session:
         session.execute(delete(ParseImportError))
 
 
+@_retry_db
 def clear_db_dag_warnings():
     with create_session() as session:
         session.execute(delete(DagWarning))
 
 
+@_retry_db
 def clear_db_xcom():
     with create_session() as session:
         session.execute(delete(XCom))
 
 
+@_retry_db
 def clear_db_pakl():
     if not AIRFLOW_V_3_2_PLUS:
         return
@@ -382,6 +424,7 @@ def clear_db_pakl():
         session.execute(delete(PartitionedAssetKeyLog))
 
 
+@_retry_db
 def clear_db_apdr():
     if not AIRFLOW_V_3_2_PLUS:
         return
@@ -391,21 +434,25 @@ def clear_db_apdr():
         session.execute(delete(AssetPartitionDagRun))
 
 
+@_retry_db
 def clear_db_logs():
     with create_session() as session:
         session.execute(delete(Log))
 
 
+@_retry_db
 def clear_db_jobs():
     with create_session() as session:
         session.execute(delete(Job))
 
 
+@_retry_db
 def clear_db_task_reschedule():
     with create_session() as session:
         session.execute(delete(TaskReschedule))
 
 
+@_retry_db
 def clear_db_dag_parsing_requests():
     with create_session() as session:
         from airflow.models.dagbag import DagPriorityParsingRequest
@@ -413,6 +460,7 @@ def clear_db_dag_parsing_requests():
         session.execute(delete(DagPriorityParsingRequest))
 
 
+@_retry_db
 def clear_db_dag_bundles():
     with create_session() as session:
         from airflow.models.dagbundle import DagBundleModel
@@ -420,6 +468,7 @@ def clear_db_dag_bundles():
         session.execute(delete(DagBundleModel))
 
 
+@_retry_db
 def clear_db_teams():
     with create_session() as session:
         from airflow.models.team import Team
@@ -427,6 +476,7 @@ def clear_db_teams():
         session.execute(delete(Team))
 
 
+@_retry_db
 def clear_db_revoked_tokens():
     with create_session() as session:
         from airflow.models.revoked_token import RevokedToken
@@ -434,6 +484,7 @@ def clear_db_revoked_tokens():
         session.execute(delete(RevokedToken))
 
 
+@_retry_db
 def clear_dag_specific_permissions():
     if "FabAuthManager" not in conf.get("core", "auth_manager"):
         return

Reply via email to