This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e58985598f Ensure @contextmanager decorates generator func (#23103)
e58985598f is described below
commit e58985598f202395098e15b686aec33645a906ff
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon May 30 03:24:08 2022 -0400
Ensure @contextmanager decorates generator func (#23103)
---
airflow/cli/commands/task_command.py | 4 ++--
airflow/models/taskinstance.py | 3 +--
airflow/providers/google/cloud/hooks/gcs.py | 19 ++++++++++++++++---
.../google/cloud/utils/credentials_provider.py | 9 ++++++---
airflow/providers/google/common/hooks/base_google.py | 10 +++++-----
airflow/providers/microsoft/psrp/hooks/psrp.py | 4 ++--
airflow/utils/db.py | 11 ++++++++---
airflow/utils/process_utils.py | 4 ++--
airflow/utils/session.py | 4 ++--
dev/breeze/src/airflow_breeze/utils/run_utils.py | 4 ++--
dev/provider_packages/prepare_provider_packages.py | 4 ++--
11 files changed, 48 insertions(+), 28 deletions(-)
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index e054e4575c..f8caf08487 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -23,7 +23,7 @@ import logging
import os
import textwrap
from contextlib import contextmanager, redirect_stderr, redirect_stdout
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Generator, List, Optional, Tuple, Union
from pendulum.parsing.exceptions import ParserError
from sqlalchemy.orm.exc import NoResultFound
@@ -269,7 +269,7 @@ def _extract_external_executor_id(args) -> Optional[str]:
@contextmanager
-def _capture_task_logs(ti):
+def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]:
"""Manage logging context for a task run
- Replace the root logger configuration with the airflow.task configuration
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index e160153af7..0f5d49b819 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -37,7 +37,6 @@ from typing import (
Dict,
Generator,
Iterable,
- Iterator,
List,
NamedTuple,
Optional,
@@ -139,7 +138,7 @@ if TYPE_CHECKING:
@contextlib.contextmanager
-def set_current_context(context: Context) -> Iterator[Context]:
+def set_current_context(context: Context) -> Generator[Context, None, None]:
"""
Sets the current execution context to the provided context object.
This method should be called once per Task execution, before calling
operator.execute.
diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index a336385c81..36e693b608 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -28,7 +28,20 @@ from functools import partial
from io import BytesIO
from os import path
from tempfile import NamedTemporaryFile
-from typing import Callable, List, Optional, Sequence, Set, Tuple, TypeVar,
Union, cast, overload
+from typing import (
+ IO,
+ Callable,
+ Generator,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
from urllib.parse import urlparse
from google.api_core.exceptions import NotFound
@@ -373,7 +386,7 @@ class GCSHook(GoogleBaseHook):
object_name: Optional[str] = None,
object_url: Optional[str] = None,
dir: Optional[str] = None,
- ):
+ ) -> Generator[IO[bytes], None, None]:
"""
Downloads the file to a temporary directory and returns a file handle
@@ -401,7 +414,7 @@ class GCSHook(GoogleBaseHook):
bucket_name: str = PROVIDE_BUCKET,
object_name: Optional[str] = None,
object_url: Optional[str] = None,
- ):
+ ) -> Generator[IO[bytes], None, None]:
"""
Creates temporary file, returns a file handle and uploads the files
content
on close.
diff --git a/airflow/providers/google/cloud/utils/credentials_provider.py b/airflow/providers/google/cloud/utils/credentials_provider.py
index 0a8143ceae..1cf33ea70b 100644
--- a/airflow/providers/google/cloud/utils/credentials_provider.py
+++ b/airflow/providers/google/cloud/utils/credentials_provider.py
@@ -74,7 +74,10 @@ def build_gcp_conn(
@contextmanager
-def provide_gcp_credentials(key_file_path: Optional[str] = None,
key_file_dict: Optional[Dict] = None):
+def provide_gcp_credentials(
+ key_file_path: Optional[str] = None,
+ key_file_dict: Optional[Dict] = None,
+) -> Generator[None, None, None]:
"""
Context manager that provides a Google Cloud credentials for application
supporting
`Application Default Credentials (ADC) strategy`__.
@@ -111,7 +114,7 @@ def provide_gcp_connection(
key_file_path: Optional[str] = None,
scopes: Optional[Sequence] = None,
project_id: Optional[str] = None,
-) -> Generator:
+) -> Generator[None, None, None]:
"""
Context manager that provides a temporary value of
:envvar:`AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT`
connection. It build a new connection that includes path to provided
service json,
@@ -135,7 +138,7 @@ def provide_gcp_conn_and_credentials(
key_file_path: Optional[str] = None,
scopes: Optional[Sequence] = None,
project_id: Optional[str] = None,
-) -> Generator:
+) -> Generator[None, None, None]:
"""
Context manager that provides both:
diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py
index f2c0d5157a..d9fe5daba5 100644
--- a/airflow/providers/google/common/hooks/base_google.py
+++ b/airflow/providers/google/common/hooks/base_google.py
@@ -25,7 +25,7 @@ import tempfile
import warnings
from contextlib import ExitStack, contextmanager
from subprocess import check_output
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar,
Union, cast
+from typing import Any, Callable, Dict, Generator, Optional, Sequence, Tuple,
TypeVar, Union, cast
import google.auth
import google.auth.credentials
@@ -459,7 +459,7 @@ class GoogleBaseHook(BaseHook):
return cast(T, wrapper)
@contextmanager
- def provide_gcp_credential_file_as_context(self):
+ def provide_gcp_credential_file_as_context(self) ->
Generator[Optional[str], None, None]:
"""
Context manager that provides a Google Cloud credentials for
application supporting `Application
Default Credentials (ADC) strategy
<https://cloud.google.com/docs/authentication/production>`__.
@@ -467,8 +467,8 @@ class GoogleBaseHook(BaseHook):
It can be used to provide credentials for external programs (e.g.
gcloud) that expect authorization
file in ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable.
"""
- key_path = self._get_field('key_path', None) # type: Optional[str]
#
- keyfile_dict = self._get_field('keyfile_dict', None) # type:
Optional[Dict]
+ key_path: Optional[str] = self._get_field('key_path', None)
+ keyfile_dict: Optional[str] = self._get_field('keyfile_dict', None)
if key_path and keyfile_dict:
raise AirflowException(
"The `keyfile_dict` and `key_path` fields are mutually
exclusive. "
@@ -490,7 +490,7 @@ class GoogleBaseHook(BaseHook):
yield None
@contextmanager
- def provide_authorized_gcloud(self):
+ def provide_authorized_gcloud(self) -> Generator[None, None, None]:
"""
Provides a separate gcloud configuration with current credentials.
diff --git a/airflow/providers/microsoft/psrp/hooks/psrp.py b/airflow/providers/microsoft/psrp/hooks/psrp.py
index 005f1e215d..0aebe63d03 100644
--- a/airflow/providers/microsoft/psrp/hooks/psrp.py
+++ b/airflow/providers/microsoft/psrp/hooks/psrp.py
@@ -19,7 +19,7 @@
from contextlib import contextmanager
from copy import copy
from logging import DEBUG, ERROR, INFO, WARNING
-from typing import Any, Callable, Dict, Iterator, Optional
+from typing import Any, Callable, Dict, Generator, Optional
from weakref import WeakKeyDictionary
from pypsrp.host import PSHost
@@ -155,7 +155,7 @@ class PsrpHook(BaseHook):
return pool
@contextmanager
- def invoke(self) -> Iterator[PowerShell]:
+ def invoke(self) -> Generator[PowerShell, None, None]:
"""
Context manager that yields a PowerShell object to which commands can
be
added. Upon exit, the commands will be invoked.
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 1576cec7db..1e2adbadbb 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -24,7 +24,7 @@ import time
import warnings
from dataclasses import dataclass
from tempfile import gettempdir
-from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple,
Union
+from typing import TYPE_CHECKING, Callable, Generator, Iterable, List,
Optional, Tuple, Union
from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select,
table, text, tuple_
from sqlalchemy.orm.session import Session
@@ -68,6 +68,7 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.version import version
if TYPE_CHECKING:
+ from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
from sqlalchemy.orm import Query
@@ -708,7 +709,7 @@ def check_migrations(timeout):
@contextlib.contextmanager
-def _configured_alembic_environment():
+def _configured_alembic_environment() -> Generator["EnvironmentContext", None,
None]:
from alembic.runtime.environment import EnvironmentContext
config = _get_alembic_config()
@@ -1605,7 +1606,11 @@ class DBLocks(enum.IntEnum):
@contextlib.contextmanager
-def create_global_lock(session: Session, lock: DBLocks, lock_timeout=1800):
+def create_global_lock(
+ session: Session,
+ lock: DBLocks,
+ lock_timeout: int = 1800,
+) -> Generator[None, None, None]:
"""Contextmanager that will create and teardown a global db lock."""
conn = session.get_bind().connect()
dialect = conn.dialect
diff --git a/airflow/utils/process_utils.py b/airflow/utils/process_utils.py
index 1cbb8e8b6c..d547f2c0de 100644
--- a/airflow/utils/process_utils.py
+++ b/airflow/utils/process_utils.py
@@ -34,7 +34,7 @@ if not IS_WINDOWS:
import pty
from contextlib import contextmanager
-from typing import Dict, List, Optional
+from typing import Dict, Generator, List, Optional
import psutil
from lockfile.pidlockfile import PIDLockFile
@@ -268,7 +268,7 @@ def kill_child_processes_by_pids(pids_to_kill: List[int], timeout: int = 5) -> N
@contextmanager
-def patch_environ(new_env_variables: Dict[str, str]):
+def patch_environ(new_env_variables: Dict[str, str]) -> Generator[None, None,
None]:
"""
Sets environment variables in context. After leaving the context, it
restores its original state.
diff --git a/airflow/utils/session.py b/airflow/utils/session.py
index 3565e216a2..377ff55cbf 100644
--- a/airflow/utils/session.py
+++ b/airflow/utils/session.py
@@ -17,13 +17,13 @@
import contextlib
from functools import wraps
from inspect import signature
-from typing import Callable, Iterator, TypeVar, cast
+from typing import Callable, Generator, TypeVar, cast
from airflow import settings
@contextlib.contextmanager
-def create_session() -> Iterator[settings.SASession]:
+def create_session() -> Generator[settings.SASession, None, None]:
"""Contextmanager that will create and teardown a session."""
if not settings.Session:
raise RuntimeError("Session must be set before!")
diff --git a/dev/breeze/src/airflow_breeze/utils/run_utils.py b/dev/breeze/src/airflow_breeze/utils/run_utils.py
index f297ecca46..54bb288795 100644
--- a/dev/breeze/src/airflow_breeze/utils/run_utils.py
+++ b/dev/breeze/src/airflow_breeze/utils/run_utils.py
@@ -25,7 +25,7 @@ from distutils.version import StrictVersion
from functools import lru_cache
from pathlib import Path
from re import match
-from typing import Dict, List, Mapping, Optional, Union
+from typing import Dict, Generator, List, Mapping, Optional, Union
from airflow_breeze.branch_defaults import AIRFLOW_BRANCH
from airflow_breeze.params._common_build_params import _CommonBuildParams
@@ -213,7 +213,7 @@ def instruct_build_image(python: str):
@contextlib.contextmanager
-def working_directory(source_path: Path):
+def working_directory(source_path: Path) -> Generator[None, None, None]:
"""
# Equivalent of pushd and popd in bash script.
# https://stackoverflow.com/a/42441759/3101838
diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py
index 1f7315447c..44498a9aed 100755
--- a/dev/provider_packages/prepare_provider_packages.py
+++ b/dev/provider_packages/prepare_provider_packages.py
@@ -39,7 +39,7 @@ from os.path import dirname, relpath
from pathlib import Path
from random import choice
from shutil import copyfile
-from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Set,
Tuple, Union
+from typing import Any, Dict, Generator, Iterable, List, NamedTuple, Optional,
Set, Tuple, Union
import jsonschema
import rich_click as click
@@ -196,7 +196,7 @@ argument_package_ids = click.argument('package_ids', nargs=-1)
@contextmanager
-def with_group(title):
+def with_group(title: str) -> Generator[None, None, None]:
"""
If used in GitHub Action, creates an expandable group in the GitHub Action
log.
Otherwise, display simple text groups.