uranusjr commented on a change in pull request #19758:
URL: https://github.com/apache/airflow/pull/19758#discussion_r756359399



##########
File path: airflow/api_connexion/endpoints/dag_endpoint.py
##########
@@ -109,6 +109,53 @@ def patch_dag(session, dag_id, update_mask=None):
     return dag_schema.dump(dag)
 
 
[email protected]_access([(permissions.ACTION_CAN_READ, 
permissions.RESOURCE_DAG)])
+@format_parameters({'limit': check_limit})
+@provide_session
+def patch_dags(limit, session, offset=0, only_active=True, tags=None, 
dag_id_pattern=None, update_mask=None):
+    """Patch multiple DAGs."""
+    if only_active:
+        dags_query = session.query(DagModel).filter(~DagModel.is_subdag, 
DagModel.is_active)
+    else:
+        dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
+
+    if dag_id_pattern:
+        dags_query = 
dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))
+
+    editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user)
+
+    dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
+    if tags:
+        cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
+        dags_query = dags_query.filter(or_(*cond))
+
+    total_entries = len(dags_query.all())

Review comment:
       Uh. I noticed `get_dags` also does this. It’s slow because it loads every 
matching row just to count them; we should use `func.count()` instead. Could you change this?

##########
File path: airflow/api_connexion/endpoints/dag_endpoint.py
##########
@@ -109,6 +109,53 @@ def patch_dag(session, dag_id, update_mask=None):
     return dag_schema.dump(dag)
 
 
[email protected]_access([(permissions.ACTION_CAN_READ, 
permissions.RESOURCE_DAG)])
+@format_parameters({'limit': check_limit})
+@provide_session
+def patch_dags(limit, session, offset=0, only_active=True, tags=None, 
dag_id_pattern=None, update_mask=None):
+    """Patch multiple DAGs."""
+    if only_active:
+        dags_query = session.query(DagModel).filter(~DagModel.is_subdag, 
DagModel.is_active)
+    else:
+        dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
+
+    if dag_id_pattern:
+        dags_query = 
dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))
+
+    editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user)
+
+    dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
+    if tags:
+        cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
+        dags_query = dags_query.filter(or_(*cond))
+
+    total_entries = len(dags_query.all())
+
+    dags = 
dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all()
+
+    try:
+        patch_body = dag_schema.load(request.json, session=session)
+    except ValidationError as err:
+        raise BadRequest("Invalid Dag schema", detail=str(err.messages))
+    if update_mask:
+        patch_body_ = {}
+        if len(update_mask) > 1:
+            raise BadRequest(detail="Only `is_paused` field can be updated 
through the REST API")
+        update_mask = update_mask[0]
+        if update_mask != 'is_paused':
+            raise BadRequest(detail="Only `is_paused` field can be updated 
through the REST API")
+        patch_body_[update_mask] = patch_body[update_mask]
+        patch_body = patch_body_
+    dags_to_update = {dag.dag_id for dag in dags}
+    session.query(DagModel).filter(DagModel.dag_id.in_(dags_to_update)).update(
+        {DagModel.is_paused: patch_body['is_paused']}, 
synchronize_session='fetch'
+    )
+
+    session.commit()

Review comment:
       This can be a `session.flush()` instead, deferring the commit (which 
`provide_session` performs after the function finishes).

##########
File path: airflow/api_connexion/endpoints/dag_endpoint.py
##########
@@ -109,6 +109,53 @@ def patch_dag(session, dag_id, update_mask=None):
     return dag_schema.dump(dag)
 
 
[email protected]_access([(permissions.ACTION_CAN_READ, 
permissions.RESOURCE_DAG)])
+@format_parameters({'limit': check_limit})
+@provide_session
+def patch_dags(limit, session, offset=0, only_active=True, tags=None, 
dag_id_pattern=None, update_mask=None):
+    """Patch multiple DAGs."""
+    if only_active:
+        dags_query = session.query(DagModel).filter(~DagModel.is_subdag, 
DagModel.is_active)
+    else:
+        dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
+
+    if dag_id_pattern:
+        dags_query = 
dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))
+
+    editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user)
+
+    dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
+    if tags:
+        cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
+        dags_query = dags_query.filter(or_(*cond))
+
+    total_entries = len(dags_query.all())
+
+    dags = 
dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all()
+
+    try:
+        patch_body = dag_schema.load(request.json, session=session)
+    except ValidationError as err:
+        raise BadRequest("Invalid Dag schema", detail=str(err.messages))
+    if update_mask:
+        patch_body_ = {}
+        if len(update_mask) > 1:
+            raise BadRequest(detail="Only `is_paused` field can be updated 
through the REST API")
+        update_mask = update_mask[0]
+        if update_mask != 'is_paused':
+            raise BadRequest(detail="Only `is_paused` field can be updated 
through the REST API")

Review comment:
       Can this repeated check be simplified to this?
   
   ```python
   if update_mask != ["is_paused"]:
       raise ...
   ```

##########
File path: airflow/api_connexion/endpoints/dag_endpoint.py
##########
@@ -109,6 +109,53 @@ def patch_dag(session, dag_id, update_mask=None):
     return dag_schema.dump(dag)
 
 
[email protected]_access([(permissions.ACTION_CAN_READ, 
permissions.RESOURCE_DAG)])
+@format_parameters({'limit': check_limit})
+@provide_session
+def patch_dags(limit, session, offset=0, only_active=True, tags=None, 
dag_id_pattern=None, update_mask=None):
+    """Patch multiple DAGs."""
+    if only_active:
+        dags_query = session.query(DagModel).filter(~DagModel.is_subdag, 
DagModel.is_active)
+    else:
+        dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
+
+    if dag_id_pattern:
+        dags_query = 
dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))
+
+    editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user)
+
+    dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
+    if tags:
+        cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
+        dags_query = dags_query.filter(or_(*cond))
+
+    total_entries = len(dags_query.all())
+
+    dags = 
dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all()
+
+    try:
+        patch_body = dag_schema.load(request.json, session=session)
+    except ValidationError as err:
+        raise BadRequest("Invalid Dag schema", detail=str(err.messages))
+    if update_mask:
+        patch_body_ = {}
+        if len(update_mask) > 1:
+            raise BadRequest(detail="Only `is_paused` field can be updated 
through the REST API")
+        update_mask = update_mask[0]
+        if update_mask != 'is_paused':
+            raise BadRequest(detail="Only `is_paused` field can be updated 
through the REST API")
+        patch_body_[update_mask] = patch_body[update_mask]
+        patch_body = patch_body_

Review comment:
       This block should be moved to the beginning of the endpoint, so that an 
invalid payload is rejected before any database queries are run unnecessarily.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to