This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 098765e Allow setting specific cwd for BashOperator (#17751)
098765e is described below
commit 098765e227d4ab7873a2f675845b50c633e356da
Author: Lionel <[email protected]>
AuthorDate: Tue Aug 31 13:12:59 2021 +0800
Allow setting specific cwd for BashOperator (#17751)
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/hooks/subprocess.py | 20 ++++++++++++++------
airflow/operators/bash.py | 11 +++++++++++
tests/operators/test_bash.py | 35 ++++++++++++++++++++++++++++++++++-
3 files changed, 59 insertions(+), 7 deletions(-)
diff --git a/airflow/hooks/subprocess.py b/airflow/hooks/subprocess.py
index 1c6aec4..c54a818 100644
--- a/airflow/hooks/subprocess.py
+++ b/airflow/hooks/subprocess.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+import contextlib
import os
import signal
from collections import namedtuple
@@ -35,23 +35,31 @@ class SubprocessHook(BaseHook):
super().__init__()
def run_command(
- self, command: List[str], env: Optional[Dict[str, str]] = None, output_encoding: str = 'utf-8'
+ self,
+ command: List[str],
+ env: Optional[Dict[str, str]] = None,
+ output_encoding: str = 'utf-8',
+ cwd: str = None,
) -> SubprocessResult:
"""
- Execute the command in a temporary directory which will be cleaned afterwards
+ Execute the command.
+ If ``cwd`` is None, execute the command in a temporary directory which will be cleaned afterwards.
If ``env`` is not supplied, ``os.environ`` is passed
:param command: the command to run
:param env: Optional dict containing environment variables to be made available to the shell
environment in which ``command`` will be executed. If omitted, ``os.environ`` will be used.
:param output_encoding: encoding to use for decoding stdout
+ :param cwd: Working directory to run the command in.
+ If None (default), the command is run in a temporary directory.
:return: :class:`namedtuple` containing ``exit_code`` and ``output``, the last line from stderr
or stdout
"""
self.log.info('Tmp dir root location: \n %s', gettempdir())
-
- with TemporaryDirectory(prefix='airflowtmp') as tmp_dir:
+ with contextlib.ExitStack() as stack:
+ if cwd is None:
+ cwd = stack.enter_context(TemporaryDirectory(prefix='airflowtmp'))
def pre_exec():
# Restore default signal disposition and invoke setsid
@@ -66,7 +74,7 @@ class SubprocessHook(BaseHook):
command,
stdout=PIPE,
stderr=STDOUT,
- cwd=tmp_dir,
+ cwd=cwd,
env=env if env or env == {} else os.environ,
preexec_fn=pre_exec,
)
diff --git a/airflow/operators/bash.py b/airflow/operators/bash.py
index c32cd5d..a551a82 100644
--- a/airflow/operators/bash.py
+++ b/airflow/operators/bash.py
@@ -50,6 +50,9 @@ class BashOperator(BaseOperator):
in ``skipped`` state (default: 99). If set to ``None``, any non-zero
exit code will be treated as a failure.
:type skip_exit_code: int
+ :param cwd: Working directory to execute the command in.
+ If None (default), the command is run in a temporary directory.
+ :type cwd: str
Airflow will evaluate the exit code of the bash command. In general, a non-zero exit code will result in
task failure and zero will result in task success. Exit code ``99`` (or another set in ``skip_exit_code``)
@@ -134,6 +137,7 @@ class BashOperator(BaseOperator):
env: Optional[Dict[str, str]] = None,
output_encoding: str = 'utf-8',
skip_exit_code: int = 99,
+ cwd: str = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -141,6 +145,7 @@ class BashOperator(BaseOperator):
self.env = env
self.output_encoding = output_encoding
self.skip_exit_code = skip_exit_code
+ self.cwd = cwd
if kwargs.get('xcom_push') is not None:
raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead")
@@ -164,11 +169,17 @@ class BashOperator(BaseOperator):
return env
def execute(self, context):
+ if self.cwd is not None:
+ if not os.path.exists(self.cwd):
+ raise AirflowException(f"Can not find the cwd: {self.cwd}")
+ if not os.path.isdir(self.cwd):
+ raise AirflowException(f"The cwd {self.cwd} must be a directory")
env = self.get_env(context)
result = self.subprocess_hook.run_command(
command=['bash', '-c', self.bash_command],
env=env,
output_encoding=self.output_encoding,
+ cwd=self.cwd,
)
if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code:
raise AirflowSkipException(f"Bash command returned exit code {self.skip_exit_code}. Skipping.")
diff --git a/tests/operators/test_bash.py b/tests/operators/test_bash.py
index 9810272..36e4744 100644
--- a/tests/operators/test_bash.py
+++ b/tests/operators/test_bash.py
@@ -18,7 +18,7 @@
import unittest
from datetime import datetime, timedelta
-from tempfile import NamedTemporaryFile
+from tempfile import NamedTemporaryFile, TemporaryDirectory
from unittest import mock
import pytest
@@ -123,6 +123,39 @@ class TestBashOperator(unittest.TestCase):
):
BashOperator(task_id='abc', bash_command='set -e; something-that-isnt-on-path').execute({})
+ def test_unset_cwd(self):
+ val = "xxxx"
+ op = BashOperator(task_id='abc', bash_command=f'set -e; echo "{val}";')
+ line = op.execute({})
+ assert line == val
+
+ def test_cwd_does_not_exist(self):
+ test_cmd = 'set -e; echo "xxxx" |tee outputs.txt'
+ with TemporaryDirectory(prefix='test_command_with_cwd') as tmp_dir:
+ # Get a nonexistent temporary directory to do the test
+ pass
+ # There should be no exceptions when creating the operator even the `cwd` doesn't exist
+ bash_operator = BashOperator(task_id='abc', bash_command=test_cmd, cwd=tmp_dir)
+ with pytest.raises(AirflowException, match=f"Can not find the cwd: {tmp_dir}"):
+ bash_operator.execute({})
+
+ def test_cwd_is_file(self):
+ test_cmd = 'set -e; echo "xxxx" |tee outputs.txt'
+ with NamedTemporaryFile(suffix="var.env") as tmp_file:
+ # Test if the cwd is a file_path
+ with pytest.raises(AirflowException, match=f"The cwd {tmp_file.name} must be a directory"):
+ BashOperator(task_id='abc', bash_command=test_cmd, cwd=tmp_file.name).execute({})
+
+ def test_valid_cwd(self):
+
+ test_cmd = 'set -e; echo "xxxx" |tee outputs.txt'
+ with TemporaryDirectory(prefix='test_command_with_cwd') as test_cwd_folder:
+ # Test everything went alright
+ result = BashOperator(task_id='abc', bash_command=test_cmd, cwd=test_cwd_folder).execute({})
+ assert result == "xxxx"
+ with open(f'{test_cwd_folder}/outputs.txt') as tmp_file:
+ assert tmp_file.read().splitlines()[0] == "xxxx"
+
@parameterized.expand(
[
(None, 99, AirflowSkipException),