This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 098765e  Allow setting specific cwd for BashOperator (#17751)
098765e is described below

commit 098765e227d4ab7873a2f675845b50c633e356da
Author: Lionel <[email protected]>
AuthorDate: Tue Aug 31 13:12:59 2021 +0800

    Allow setting specific cwd for BashOperator (#17751)
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 airflow/hooks/subprocess.py  | 20 ++++++++++++++------
 airflow/operators/bash.py    | 11 +++++++++++
 tests/operators/test_bash.py | 35 ++++++++++++++++++++++++++++++++++-
 3 files changed, 59 insertions(+), 7 deletions(-)

diff --git a/airflow/hooks/subprocess.py b/airflow/hooks/subprocess.py
index 1c6aec4..c54a818 100644
--- a/airflow/hooks/subprocess.py
+++ b/airflow/hooks/subprocess.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import contextlib
 import os
 import signal
 from collections import namedtuple
@@ -35,23 +35,31 @@ class SubprocessHook(BaseHook):
         super().__init__()
 
     def run_command(
-        self, command: List[str], env: Optional[Dict[str, str]] = None, output_encoding: str = 'utf-8'
+        self,
+        command: List[str],
+        env: Optional[Dict[str, str]] = None,
+        output_encoding: str = 'utf-8',
+        cwd: str = None,
     ) -> SubprocessResult:
         """
-        Execute the command in a temporary directory which will be cleaned afterwards
+        Execute the command.
 
+        If ``cwd`` is None, execute the command in a temporary directory which will be cleaned afterwards.
         If ``env`` is not supplied, ``os.environ`` is passed
 
         :param command: the command to run
         :param env: Optional dict containing environment variables to be made available to the shell
             environment in which ``command`` will be executed.  If omitted, ``os.environ`` will be used.
         :param output_encoding: encoding to use for decoding stdout
+        :param cwd: Working directory to run the command in.
+            If None (default), the command is run in a temporary directory.
         :return: :class:`namedtuple` containing ``exit_code`` and ``output``, the last line from stderr
             or stdout
         """
         self.log.info('Tmp dir root location: \n %s', gettempdir())
-
-        with TemporaryDirectory(prefix='airflowtmp') as tmp_dir:
+        with contextlib.ExitStack() as stack:
+            if cwd is None:
+                cwd = stack.enter_context(TemporaryDirectory(prefix='airflowtmp'))
 
             def pre_exec():
                 # Restore default signal disposition and invoke setsid
@@ -66,7 +74,7 @@ class SubprocessHook(BaseHook):
                 command,
                 stdout=PIPE,
                 stderr=STDOUT,
-                cwd=tmp_dir,
+                cwd=cwd,
                 env=env if env or env == {} else os.environ,
                 preexec_fn=pre_exec,
             )
diff --git a/airflow/operators/bash.py b/airflow/operators/bash.py
index c32cd5d..a551a82 100644
--- a/airflow/operators/bash.py
+++ b/airflow/operators/bash.py
@@ -50,6 +50,9 @@ class BashOperator(BaseOperator):
         in ``skipped`` state (default: 99). If set to ``None``, any non-zero
         exit code will be treated as a failure.
     :type skip_exit_code: int
+    :param cwd: Working directory to execute the command in.
+        If None (default), the command is run in a temporary directory.
+    :type cwd: str
 
     Airflow will evaluate the exit code of the bash command. In general, a non-zero exit code will result in
     task failure and zero will result in task success. Exit code ``99`` (or another set in ``skip_exit_code``)
@@ -134,6 +137,7 @@ class BashOperator(BaseOperator):
         env: Optional[Dict[str, str]] = None,
         output_encoding: str = 'utf-8',
         skip_exit_code: int = 99,
+        cwd: str = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -141,6 +145,7 @@ class BashOperator(BaseOperator):
         self.env = env
         self.output_encoding = output_encoding
         self.skip_exit_code = skip_exit_code
+        self.cwd = cwd
         if kwargs.get('xcom_push') is not None:
             raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead")
 
@@ -164,11 +169,17 @@ class BashOperator(BaseOperator):
         return env
 
     def execute(self, context):
+        if self.cwd is not None:
+            if not os.path.exists(self.cwd):
+                raise AirflowException(f"Can not find the cwd: {self.cwd}")
+            if not os.path.isdir(self.cwd):
+                raise AirflowException(f"The cwd {self.cwd} must be a directory")
         env = self.get_env(context)
         result = self.subprocess_hook.run_command(
             command=['bash', '-c', self.bash_command],
             env=env,
             output_encoding=self.output_encoding,
+            cwd=self.cwd,
         )
         if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code:
             raise AirflowSkipException(f"Bash command returned exit code {self.skip_exit_code}. Skipping.")
diff --git a/tests/operators/test_bash.py b/tests/operators/test_bash.py
index 9810272..36e4744 100644
--- a/tests/operators/test_bash.py
+++ b/tests/operators/test_bash.py
@@ -18,7 +18,7 @@
 
 import unittest
 from datetime import datetime, timedelta
-from tempfile import NamedTemporaryFile
+from tempfile import NamedTemporaryFile, TemporaryDirectory
 from unittest import mock
 
 import pytest
@@ -123,6 +123,39 @@ class TestBashOperator(unittest.TestCase):
         ):
             BashOperator(task_id='abc', bash_command='set -e; something-that-isnt-on-path').execute({})
 
+    def test_unset_cwd(self):
+        val = "xxxx"
+        op = BashOperator(task_id='abc', bash_command=f'set -e; echo "{val}";')
+        line = op.execute({})
+        assert line == val
+
+    def test_cwd_does_not_exist(self):
+        test_cmd = 'set -e; echo "xxxx" |tee outputs.txt'
+        with TemporaryDirectory(prefix='test_command_with_cwd') as tmp_dir:
+            # Get a nonexistent temporary directory to do the test
+            pass
+        # There should be no exceptions when creating the operator even the `cwd` doesn't exist
+        bash_operator = BashOperator(task_id='abc', bash_command=test_cmd, cwd=tmp_dir)
+        with pytest.raises(AirflowException, match=f"Can not find the cwd: {tmp_dir}"):
+            bash_operator.execute({})
+
+    def test_cwd_is_file(self):
+        test_cmd = 'set -e; echo "xxxx" |tee outputs.txt'
+        with NamedTemporaryFile(suffix="var.env") as tmp_file:
+            # Test if the cwd is a file_path
+            with pytest.raises(AirflowException, match=f"The cwd {tmp_file.name} must be a directory"):
+                BashOperator(task_id='abc', bash_command=test_cmd, cwd=tmp_file.name).execute({})
+
+    def test_valid_cwd(self):
+
+        test_cmd = 'set -e; echo "xxxx" |tee outputs.txt'
+        with TemporaryDirectory(prefix='test_command_with_cwd') as test_cwd_folder:
+            # Test everything went alright
+            result = BashOperator(task_id='abc', bash_command=test_cmd, cwd=test_cwd_folder).execute({})
+            assert result == "xxxx"
+            with open(f'{test_cwd_folder}/outputs.txt') as tmp_file:
+                assert tmp_file.read().splitlines()[0] == "xxxx"
+
     @parameterized.expand(
         [
             (None, 99, AirflowSkipException),

Reply via email to