o-nikolas commented on code in PR #62343:
URL: https://github.com/apache/airflow/pull/62343#discussion_r2891424498


##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1759,6 +1760,15 @@ def _run_scheduler_loop(self) -> None:
                     action=bundle_cleanup_mgr.remove_stale_bundle_versions,
                 )
 
+        timers.call_regular_interval(
+            delay=conf.getfloat("scheduler", 
"connection_test_dispatch_interval", fallback=2.0),
+            action=self._dispatch_connection_tests,

Review Comment:
   This ultimately schedules workloads to the executors; I believe it should 
be in the scheduler loop below.



##########
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py:
##########
@@ -1140,6 +1148,275 @@ def 
test_should_test_new_connection_without_existing(self, test_client):
         assert response.json()["status"] is True
 
 
+class TestAsyncConnectionTest(TestConnectionEndpoint):
+    """Tests for the async connection test endpoints (POST + GET polling)."""
+
+    @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+    def test_post_should_respond_202(self, test_client, session):
+        """POST /connections/test-async with a saved connection returns 202 + 
token."""
+        self.create_connection()
+        response = test_client.post("/connections/test-async", 
json={"connection_id": TEST_CONN_ID})
+        assert response.status_code == 202
+        body = response.json()
+        assert "token" in body
+        assert body["connection_id"] == TEST_CONN_ID
+        assert body["state"] == "pending"
+        assert len(body["token"]) > 0
+
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.post(
+            "/connections/test-async", json={"connection_id": TEST_CONN_ID}
+        )
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.post(
+            "/connections/test-async", json={"connection_id": TEST_CONN_ID}
+        )
+        assert response.status_code == 403
+
+    def test_should_respond_403_by_default(self, test_client):
+        """Connection testing is disabled by default."""
+        response = test_client.post("/connections/test-async", 
json={"connection_id": TEST_CONN_ID})
+        assert response.status_code == 403
+        assert response.json() == {
+            "detail": "Testing connections is disabled in Airflow 
configuration. "
+            "Contact your deployment admin to enable it."
+        }
+
+    @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+    def test_should_respond_404_for_nonexistent_connection(self, test_client):
+        """Connection must be saved before testing."""
+        response = test_client.post("/connections/test-async", 
json={"connection_id": "nonexistent"})
+        assert response.status_code == 404
+        assert "was not found" in response.json()["detail"]
+
+    @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+    def test_post_creates_connection_test_row(self, test_client, session):
+        """POST creates a ConnectionTest row in PENDING state."""
+        self.create_connection()
+        response = test_client.post("/connections/test-async", 
json={"connection_id": TEST_CONN_ID})
+        assert response.status_code == 202
+        token = response.json()["token"]
+
+        ct = session.scalar(select(ConnectionTest).filter_by(token=token))
+        assert ct is not None
+        assert ct.connection_id == TEST_CONN_ID
+        assert ct.state == "pending"
+
+    @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+    def test_post_passes_queue_parameter(self, test_client, session):
+        """POST /connections/test-async passes the queue parameter to the 
ConnectionTest."""
+        self.create_connection()
+        response = test_client.post(
+            "/connections/test-async",
+            json={"connection_id": TEST_CONN_ID, "queue": "gpu_workers"},
+        )
+        assert response.status_code == 202
+        token = response.json()["token"]
+
+        ct = session.scalar(select(ConnectionTest).filter_by(token=token))
+        assert ct is not None
+        assert ct.queue == "gpu_workers"
+
+    @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+    def test_get_status_returns_pending(self, test_client, session):
+        """GET /connections/test-async/{token} returns current status (pending 
before scheduler dispatch)."""
+        self.create_connection()
+        post_response = test_client.post("/connections/test-async", 
json={"connection_id": TEST_CONN_ID})
+        token = post_response.json()["token"]
+
+        response = test_client.get(f"/connections/test-async/{token}")
+        assert response.status_code == 200
+        body = response.json()
+        assert body["token"] == token
+        assert body["connection_id"] == TEST_CONN_ID
+        assert body["state"] == "pending"
+        assert body["result_message"] is None
+        assert "created_at" in body
+
+    @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+    def test_get_status_returns_completed_result(self, test_client, session):
+        """GET returns result after the worker has updated the 
ConnectionTest."""
+        self.create_connection()
+        post_response = test_client.post("/connections/test-async", 
json={"connection_id": TEST_CONN_ID})
+        token = post_response.json()["token"]
+
+        ct = session.scalar(select(ConnectionTest).filter_by(token=token))
+        ct.state = ConnectionTestState.SUCCESS
+        ct.result_message = "Connection successfully tested"
+        session.commit()
+
+        response = test_client.get(f"/connections/test-async/{token}")
+        assert response.status_code == 200
+        body = response.json()
+        assert body["state"] == "success"
+        assert body["result_message"] == "Connection successfully tested"
+
+    def test_get_status_returns_404_for_invalid_token(self, test_client):
+        """GET with an unknown token returns 404."""
+        response = test_client.get("/connections/test-async/nonexistent-token")
+        assert response.status_code == 404
+
+
+class TestSaveAndTest(TestConnectionEndpoint):
+    """Tests for the combined PATCH /{connection_id}/test endpoint."""
+
+    @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+    def test_save_and_test_returns_200_with_token(self, test_client, session):
+        """PATCH /{connection_id}/test updates the connection and returns a 
test token."""
+        self.create_connection()
+        response = test_client.patch(
+            f"/connections/{TEST_CONN_ID}/test",
+            json={
+                "connection_id": TEST_CONN_ID,
+                "conn_type": TEST_CONN_TYPE,
+                "host": "updated-host.example.com",
+            },
+        )
+        assert response.status_code == 200
+        body = response.json()
+        assert body["test_token"]
+        assert body["test_state"] == "pending"
+        assert body["connection"]["host"] == "updated-host.example.com"
+
+    @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+    def test_save_and_test_creates_snapshot(self, test_client, session):
+        """PATCH /{connection_id}/test creates a ConnectionTest with a 
connection_snapshot."""
+        self.create_connection()
+        response = test_client.patch(
+            f"/connections/{TEST_CONN_ID}/test",
+            json={
+                "connection_id": TEST_CONN_ID,
+                "conn_type": TEST_CONN_TYPE,
+                "host": "new-host.example.com",
+            },
+        )
+        assert response.status_code == 200
+        token = response.json()["test_token"]
+
+        ct = session.scalar(select(ConnectionTest).filter_by(token=token))
+        assert ct is not None
+        assert ct.connection_snapshot is not None
+        snapshot = ct.connection_snapshot
+        assert "pre" in snapshot
+        assert "post" in snapshot
+        assert snapshot["pre"]["host"] == TEST_CONN_HOST
+        assert snapshot["post"]["host"] == "new-host.example.com"
+
+    @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+    def test_save_and_test_passes_executor_parameter(self, test_client, 
session):
+        """PATCH /{connection_id}/test passes the executor parameter to the 
ConnectionTest."""
+        self.create_connection()
+        response = test_client.patch(
+            f"/connections/{TEST_CONN_ID}/test?executor=team_a",

Review Comment:
   Why is the executor being set to a team name?



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -3243,6 +3253,100 @@ def _activate_assets_generate_warnings() -> 
Iterator[tuple[str, str]]:
             session.add(warning)
             existing_warned_dag_ids.add(warning.dag_id)
 
+    @provide_session
+    def _dispatch_connection_tests(self, *, session: Session = NEW_SESSION) -> 
None:
+        """Dispatch pending connection tests to executors that support them."""
+        max_concurrency = conf.getint("core", 
"max_connection_test_concurrency", fallback=4)
+        timeout = conf.getint("core", "connection_test_timeout", fallback=60)
+
+        active_count = session.scalar(
+            
select(func.count(ConnectionTest.id)).where(ConnectionTest.state.in_(ACTIVE_STATES))
+        )
+        budget = max_concurrency - (active_count or 0)
+        if budget <= 0:
+            return
+
+        pending_stmt = (
+            select(ConnectionTest)
+            .where(ConnectionTest.state == ConnectionTestState.PENDING)
+            .order_by(ConnectionTest.created_at)
+            .limit(budget)
+        )
+        pending_stmt = with_row_locks(pending_stmt, session, 
of=ConnectionTest, skip_locked=True)
+        pending_tests = session.scalars(pending_stmt).all()
+
+        if not pending_tests:
+            return
+
+        for ct in pending_tests:
+            executor = self._find_executor_for_connection_test(ct.executor)
+            if executor is None:
+                reason = (
+                    f"No executor matches '{ct.executor}'"
+                    if ct.executor
+                    else "No executor supports connection testing"
+                )
+                ct.state = ConnectionTestState.FAILED
+                ct.result_message = reason
+                self.log.warning("Failing connection test %s: %s", ct.id, 
reason)
+                continue
+
+            workload = workloads.TestConnection.make(
+                connection_test_id=ct.id,
+                connection_id=ct.connection_id,
+                timeout=timeout,
+                queue=ct.queue,
+                generator=executor.jwt_generator,
+            )
+            executor.queue_workload(workload, session=session)
+            ct.state = ConnectionTestState.QUEUED
+
+        session.flush()
+
+    @provide_session
+    def _reap_stale_connection_tests(self, *, session: Session = NEW_SESSION) 
-> None:
+        """Mark connection tests that have exceeded their timeout as FAILED."""
+        timeout = conf.getint("core", "connection_test_timeout", fallback=60)
+        grace_period = max(30, timeout // 2)
+        cutoff = timezone.utcnow() - timedelta(seconds=timeout + grace_period)
+
+        stale_stmt = select(ConnectionTest).where(
+            ConnectionTest.state.in_(ACTIVE_STATES),
+            ConnectionTest.updated_at < cutoff,
+        )
+        stale_stmt = with_row_locks(stale_stmt, session, of=ConnectionTest, 
skip_locked=True)
+        stale_tests = session.scalars(stale_stmt).all()
+
+        for ct in stale_tests:
+            ct.state = ConnectionTestState.FAILED
+            ct.result_message = f"Connection test timed out (exceeded 
{timeout}s + {grace_period}s grace)"
+            self.log.warning("Reaped stale connection test %s", ct.id)
+            if ct.connection_snapshot:
+                attempt_revert(ct, session=session)
+
+        session.flush()
+
+    def _find_executor_for_connection_test(self, executor_name: str | None) -> 
BaseExecutor | None:

Review Comment:
   This executor-loading logic duplicates existing code, and it already shows 
the usual problem with duplication: it has drifted from the implementation that 
all other workloads use. It is missing team support, and it is missing one of 
the cases executors are matched on (the executor class name).
   
   Please use `_try_to_load_executor()`, which accepts a workload. Your new 
test-connection workload type should be a subtype of that workload type; if it 
is not, please update it so that it is. This is the single canonical location 
for executor lookups in the scheduler job.



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -3243,6 +3253,100 @@ def _activate_assets_generate_warnings() -> 
Iterator[tuple[str, str]]:
             session.add(warning)
             existing_warned_dag_ids.add(warning.dag_id)
 
+    @provide_session
+    def _dispatch_connection_tests(self, *, session: Session = NEW_SESSION) -> 
None:
+        """Dispatch pending connection tests to executors that support them."""
+        max_concurrency = conf.getint("core", 
"max_connection_test_concurrency", fallback=4)
+        timeout = conf.getint("core", "connection_test_timeout", fallback=60)
+
+        active_count = session.scalar(
+            
select(func.count(ConnectionTest.id)).where(ConnectionTest.state.in_(ACTIVE_STATES))
+        )
+        budget = max_concurrency - (active_count or 0)

Review Comment:
   This must take available executor slots into account. It is the 
scheduler's responsibility to ensure we don't oversubscribe any executor, 
across all teams.
   
   Please look at how the other workload types handle this, for reference: 
`_enqueue_executor_callbacks` and `_critical_section_enqueue_task_instances`. 
Both contain logic to assess the number of available executor slots and to 
submit/queue only as many workloads as there are slots available.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to