This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new db948695197 Fix inflated total_received count in partitioned dag runs API (#62786)
db948695197 is described below
commit db948695197ccec2ea77365e9f81251cb909986a
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Mar 5 10:10:10 2026 +0800
Fix inflated total_received count in partitioned dag runs API (#62786)
Signed-off-by: Guan-Ming (Wesley) Chiu <[email protected]>
---
.../core_api/routes/ui/partitioned_dag_runs.py | 1 +
.../routes/ui/test_partitioned_dag_runs.py | 60 ++++++++++++++++++++++
2 files changed, 61 insertions(+)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py
index fd67b0e48c6..f9090a6ece0 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py
@@ -100,6 +100,7 @@ def get_partitioned_dag_runs(
DagScheduleAssetReference.dag_id == AssetPartitionDagRun.target_dag_id,
AssetModel.active.has(),
)
+ .correlate(AssetPartitionDagRun)
)
received_subq = (
select(func.count(func.distinct(PartitionedAssetKeyLog.asset_id)))
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py
index 52f0ec1e265..658d23abb21 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py
@@ -158,6 +158,66 @@ class TestGetPartitionedDagRuns:
assert pdr_resp["total_received"] == received_count
assert pdr_resp["total_required"] == num_assets
+ @pytest.mark.parametrize(
+ ("num_target_assets", "num_other_assets", "received_count"),
+ [(1, 1, 1), (1, 2, 0), (2, 1, 1)],
+ )
+ def test_received_count_excludes_other_dags_assets(
+ self, test_client, dag_maker, session, num_target_assets, num_other_assets, received_count
+ ):
+
+ def _make_schedule(prefix, count):
+ assets = [Asset(uri=f"s3://bucket/{prefix}{i}", name=f"{prefix}{i}") for i in range(count)]
+ schedule = assets[0]
+ for a in assets[1:]:
+ schedule = schedule & a
+ return [a.uri for a in assets], schedule
+
+ target_uris, target_schedule = _make_schedule("t", num_target_assets)
+ other_uris, other_schedule = _make_schedule("o", num_other_assets)
+
+ for dag_id, schedule in [("target", target_schedule), ("other", other_schedule)]:
+ with dag_maker(
+ dag_id=dag_id, schedule=PartitionedAssetTimetable(assets=schedule), serialized=True
+ ):
+ EmptyOperator(task_id="t")
+ dag_maker.create_dagrun()
+ dag_maker.sync_dagbag_to_db()
+
+ all_uris = target_uris + other_uris
+ assets = {a.uri: a for a in session.scalars(select(AssetModel).where(AssetModel.uri.in_(all_uris)))}
+
+ # Both Dags need APDRs so an uncorrelated subquery would cross-join and inflate counts.
+ for dag_id in ("target", "other"):
+ session.add(AssetPartitionDagRun(target_dag_id=dag_id, partition_key="2024-06-01"))
+ session.flush()
+
+ pdr = session.scalar(
+ select(AssetPartitionDagRun).where(AssetPartitionDagRun.target_dag_id == "target")
+ )
+ # Log target assets (up to received_count) and all other-Dag assets on the same APDR.
+ for uri in target_uris[:received_count] + other_uris:
+ event = AssetEvent(asset_id=assets[uri].id, timestamp=pendulum.now())
+ session.add(event)
+ session.flush()
+ session.add(
+ PartitionedAssetKeyLog(
+ asset_id=assets[uri].id,
+ asset_event_id=event.id,
+ asset_partition_dag_run_id=pdr.id,
+ source_partition_key="2024-06-01",
+ target_dag_id="target",
+ target_partition_key="2024-06-01",
+ )
+ )
+ session.commit()
+
+ resp = test_client.get("/partitioned_dag_runs?dag_id=target&has_created_dag_run_id=false")
+ assert resp.status_code == 200
+ pdr_resp = resp.json()["partitioned_dag_runs"][0]
+ assert pdr_resp["total_required"] == num_target_assets
+ assert pdr_resp["total_received"] == received_count
+
class TestGetPendingPartitionedDagRun:
def test_should_response_401(self, unauthenticated_test_client):