This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 73fcbb0 Refactor SSHOperator so a subclass can run many commands (#10874) (#17378)
73fcbb0 is described below
commit 73fcbb0e4e151c9965fd69ba08de59462bbbe6dc
Author: Bjorn Olsen <[email protected]>
AuthorDate: Wed Oct 13 22:14:54 2021 +0200
Refactor SSHOperator so a subclass can run many commands (#10874) (#17378)
---
airflow/providers/ssh/operators/ssh.py | 204 ++++++++++++++++--------------
tests/providers/ssh/operators/test_ssh.py | 112 ++++++++++++++--
2 files changed, 211 insertions(+), 105 deletions(-)
diff --git a/airflow/providers/ssh/operators/ssh.py b/airflow/providers/ssh/operators/ssh.py
index 7f97b03..300e155 100644
--- a/airflow/providers/ssh/operators/ssh.py
+++ b/airflow/providers/ssh/operators/ssh.py
@@ -19,7 +19,9 @@
import warnings
from base64 import b64encode
from select import select
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
+
+from paramiko.client import SSHClient
from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -107,103 +109,115 @@ class SSHOperator(BaseOperator):
stacklevel=1,
)
- def execute(self, context) -> Union[bytes, str, bool]:
+ def get_hook(self) -> SSHHook:
+ if self.ssh_conn_id:
+ if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
+ self.log.info("ssh_conn_id is ignored when ssh_hook is
provided.")
+ else:
+ self.log.info("ssh_hook is not provided or invalid. Trying
ssh_conn_id to create SSHHook.")
+ self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
conn_timeout=self.conn_timeout)
+
+ if not self.ssh_hook:
+ raise AirflowException("Cannot operate without ssh_hook or
ssh_conn_id.")
+
+ if self.remote_host is not None:
+ self.log.info(
+ "remote_host is provided explicitly. "
+ "It will replace the remote_host which was defined "
+ "in ssh_hook or predefined in connection of ssh_conn_id."
+ )
+ self.ssh_hook.remote_host = self.remote_host
+
+ return self.ssh_hook
+
+ def get_ssh_client(self) -> SSHClient:
+ # Remember to use context manager or call .close() on this when done
+ self.log.info('Creating ssh_client')
+ return self.get_hook().get_conn()
+
+ def exec_ssh_client_command(self, ssh_client: SSHClient, command: str) ->
Tuple[int, bytes, bytes]:
+ self.log.info("Running command: %s", command)
+
+ # set timeout taken as params
+ stdin, stdout, stderr = ssh_client.exec_command(
+ command=command,
+ get_pty=self.get_pty,
+ timeout=self.timeout,
+ environment=self.environment,
+ )
+ # get channels
+ channel = stdout.channel
+
+ # closing stdin
+ stdin.close()
+ channel.shutdown_write()
+
+ agg_stdout = b''
+ agg_stderr = b''
+
+ # capture any initial output in case channel is closed already
+ stdout_buffer_length = len(stdout.channel.in_buffer)
+
+ if stdout_buffer_length > 0:
+ agg_stdout += stdout.channel.recv(stdout_buffer_length)
+
+ # read from both stdout and stderr
+ while not channel.closed or channel.recv_ready() or
channel.recv_stderr_ready():
+ readq, _, _ = select([channel], [], [], self.cmd_timeout)
+ for recv in readq:
+ if recv.recv_ready():
+ line = stdout.channel.recv(len(recv.in_buffer))
+ agg_stdout += line
+ self.log.info(line.decode('utf-8', 'replace').strip('\n'))
+ if recv.recv_stderr_ready():
+ line =
stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
+ agg_stderr += line
+ self.log.warning(line.decode('utf-8',
'replace').strip('\n'))
+ if (
+ stdout.channel.exit_status_ready()
+ and not stderr.channel.recv_stderr_ready()
+ and not stdout.channel.recv_ready()
+ ):
+ stdout.channel.shutdown_read()
+ try:
+ stdout.channel.close()
+ except Exception:
+ # there is a race that when shutdown_read has been called
and when
+ # you try to close the connection, the socket is already
closed
+ # We should ignore such errors (but we should log them
with warning)
+ self.log.warning("Ignoring exception on close",
exc_info=True)
+ break
+
+ stdout.close()
+ stderr.close()
+
+ exit_status = stdout.channel.recv_exit_status()
+
+ return exit_status, agg_stdout, agg_stderr
+
+ def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
+ if exit_status != 0:
+ error_msg = stderr.decode('utf-8')
+ raise AirflowException(f"error running cmd: {self.command}, error:
{error_msg}")
+
+ def run_ssh_client_command(self, ssh_client: SSHClient, command: str) ->
bytes:
+ exit_status, agg_stdout, agg_stderr =
self.exec_ssh_client_command(ssh_client, command)
+ self.raise_for_status(exit_status, agg_stderr)
+ return agg_stdout
+
+ def execute(self, context=None) -> Union[bytes, str]:
+ result = None
+ if self.command is None:
+ raise AirflowException("SSH operator error: SSH command not
specified. Aborting.")
try:
- if self.ssh_conn_id:
- if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
- self.log.info("ssh_conn_id is ignored when ssh_hook is
provided.")
- else:
- self.log.info(
- "ssh_hook is not provided or invalid. Trying
ssh_conn_id to create SSHHook."
- )
- self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
conn_timeout=self.conn_timeout)
-
- if not self.ssh_hook:
- raise AirflowException("Cannot operate without ssh_hook or
ssh_conn_id.")
-
- if self.remote_host is not None:
- self.log.info(
- "remote_host is provided explicitly. "
- "It will replace the remote_host which was defined "
- "in ssh_hook or predefined in connection of ssh_conn_id."
- )
- self.ssh_hook.remote_host = self.remote_host
-
- if not self.command:
- raise AirflowException("SSH command not specified. Aborting.")
-
- with self.ssh_hook.get_conn() as ssh_client:
- self.log.info("Running command: %s", self.command)
-
- # set timeout taken as params
- stdin, stdout, stderr = ssh_client.exec_command(
- command=self.command,
- get_pty=self.get_pty,
- timeout=self.cmd_timeout,
- environment=self.environment,
- )
- # get channels
- channel = stdout.channel
-
- # closing stdin
- stdin.close()
- channel.shutdown_write()
-
- agg_stdout = b''
- agg_stderr = b''
-
- # capture any initial output in case channel is closed already
- stdout_buffer_length = len(stdout.channel.in_buffer)
-
- if stdout_buffer_length > 0:
- agg_stdout += stdout.channel.recv(stdout_buffer_length)
-
- # read from both stdout and stderr
- while not channel.closed or channel.recv_ready() or
channel.recv_stderr_ready():
- readq, _, _ = select([channel], [], [], self.cmd_timeout)
- for recv in readq:
- if recv.recv_ready():
- line = stdout.channel.recv(len(recv.in_buffer))
- agg_stdout += line
- self.log.info(line.decode('utf-8',
'replace').strip('\n'))
- if recv.recv_stderr_ready():
- line =
stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
- agg_stderr += line
- self.log.warning(line.decode('utf-8',
'replace').strip('\n'))
- if (
- stdout.channel.exit_status_ready()
- and not stderr.channel.recv_stderr_ready()
- and not stdout.channel.recv_ready()
- ):
- stdout.channel.shutdown_read()
- try:
- stdout.channel.close()
- except Exception:
- # there is a race that when shutdown_read has been
called and when
- # you try to close the connection, the socket is
already closed
- # We should ignore such errors (but we should log
them with warning)
- self.log.warning("Ignoring exception on close",
exc_info=True)
- break
-
- stdout.close()
- stderr.close()
-
- exit_status = stdout.channel.recv_exit_status()
- if exit_status == 0:
- enable_pickling = conf.getboolean('core',
'enable_xcom_pickling')
- if enable_pickling:
- return agg_stdout
- else:
- return b64encode(agg_stdout).decode('utf-8')
-
- else:
- error_msg = agg_stderr.decode('utf-8')
- raise AirflowException(f"error running cmd:
{self.command}, error: {error_msg}")
-
+ with self.get_ssh_client() as ssh_client:
+ result = self.run_ssh_client_command(ssh_client, self.command)
except Exception as e:
raise AirflowException(f"SSH operator error: {str(e)}")
-
- return True
+ enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
+ if not enable_pickling:
+ result = b64encode(result).decode('utf-8')
+ return result
def tunnel(self) -> None:
"""Get ssh tunnel"""
diff --git a/tests/providers/ssh/operators/test_ssh.py b/tests/providers/ssh/operators/test_ssh.py
index c477416..551b64b 100644
--- a/tests/providers/ssh/operators/test_ssh.py
+++ b/tests/providers/ssh/operators/test_ssh.py
@@ -38,18 +38,28 @@ COMMAND = "echo -n airflow"
COMMAND_WITH_SUDO = "sudo " + COMMAND
+class SSHClientSideEffect:
+ def __init__(self, hook):
+ self.hook = hook
+
+ def __call__(self):
+ self.return_value = self.hook.get_conn()
+ return self.return_value
+
+
class TestSSHOperator:
def setup_method(self):
from airflow.providers.ssh.hooks.ssh import SSHHook
hook = SSHHook(ssh_conn_id='ssh_default')
hook.no_host_key_check = True
+ self.dag = DAG('ssh_test', default_args={'start_date': DEFAULT_DATE})
self.hook = hook
def test_hook_created_correctly_with_timeout(self):
timeout = 20
ssh_id = "ssh_default"
- with DAG('unit_tests_ssh_test_op_arg_checking',
default_args={'start_date': DEFAULT_DATE}):
+ with self.dag:
task = SSHOperator(task_id="test", command=COMMAND,
timeout=timeout, ssh_conn_id="ssh_default")
task.execute(None)
assert timeout == task.ssh_hook.conn_timeout
@@ -59,7 +69,7 @@ class TestSSHOperator:
conn_timeout = 20
cmd_timeout = 45
ssh_id = 'ssh_default'
- with DAG('unit_tests_ssh_test_op_arg_checking',
default_args={'start_date': DEFAULT_DATE}):
+ with self.dag:
task = SSHOperator(
task_id="test",
command=COMMAND,
@@ -130,10 +140,8 @@ class TestSSHOperator:
@unittest.mock.patch('os.environ', {'AIRFLOW_CONN_' +
TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
def test_arg_checking(self):
- dag = DAG('unit_tests_ssh_test_op_arg_checking',
default_args={'start_date': DEFAULT_DATE})
-
# Exception should be raised if neither ssh_hook nor ssh_conn_id is
provided.
- task_0 = SSHOperator(task_id="test", command=COMMAND, timeout=TIMEOUT,
dag=dag)
+ task_0 = SSHOperator(task_id="test", command=COMMAND, timeout=TIMEOUT,
dag=self.dag)
with pytest.raises(AirflowException, match="Cannot operate without
ssh_hook or ssh_conn_id."):
task_0.execute(None)
@@ -144,7 +152,7 @@ class TestSSHOperator:
ssh_conn_id=TEST_CONN_ID,
command=COMMAND,
timeout=TIMEOUT,
- dag=dag,
+ dag=self.dag,
)
try:
task_1.execute(None)
@@ -157,7 +165,7 @@ class TestSSHOperator:
ssh_conn_id=TEST_CONN_ID, # No ssh_hook provided.
command=COMMAND,
timeout=TIMEOUT,
- dag=dag,
+ dag=self.dag,
)
try:
task_2.execute(None)
@@ -172,10 +180,26 @@ class TestSSHOperator:
ssh_conn_id=TEST_CONN_ID,
command=COMMAND,
timeout=TIMEOUT,
- dag=dag,
+ dag=self.dag,
)
task_3.execute(None)
assert task_3.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
+ # If remote_host was specified, ensure it is used
+ task_4 = SSHOperator(
+ task_id="test_4",
+ ssh_hook=self.hook,
+ ssh_conn_id=TEST_CONN_ID,
+ command=COMMAND,
+ timeout=TIMEOUT,
+ dag=self.dag,
+ remote_host='operator_remote_host',
+ )
+ try:
+ task_4.execute(None)
+ except Exception:
+ pass
+ assert task_4.ssh_hook.ssh_conn_id == self.hook.ssh_conn_id
+ assert task_4.ssh_hook.remote_host == 'operator_remote_host'
@pytest.mark.parametrize(
"command, get_pty_in, get_pty_out",
@@ -188,7 +212,6 @@ class TestSSHOperator:
],
)
def test_get_pyt_set_correctly(self, command, get_pty_in, get_pty_out):
- dag = DAG('unit_tests_ssh_test_op_arg_checking',
default_args={'start_date': DEFAULT_DATE})
task = SSHOperator(
task_id="test",
ssh_hook=self.hook,
@@ -196,7 +219,7 @@ class TestSSHOperator:
conn_timeout=TIMEOUT,
cmd_timeout=TIMEOUT,
get_pty=get_pty_in,
- dag=dag,
+ dag=self.dag,
)
if command is None:
with pytest.raises(AirflowException) as ctx:
@@ -205,3 +228,72 @@ class TestSSHOperator:
else:
task.execute(None)
assert task.get_pty == get_pty_out
+
+ def test_ssh_client_managed_correctly(self):
+ # Ensure ssh_client gets created once
+ # Ensure connection gets closed once
+ task = SSHOperator(
+ task_id="test",
+ ssh_hook=self.hook,
+ command="ls",
+ dag=self.dag,
+ )
+
+ se = SSHClientSideEffect(self.hook)
+ with unittest.mock.patch.object(task, 'get_ssh_client') as mock_get,
unittest.mock.patch(
+ 'paramiko.client.SSHClient.close'
+ ) as mock_close:
+ mock_get.side_effect = se
+ task.execute()
+ mock_get.assert_called_once()
+ mock_close.assert_called_once()
+
+ def test_one_ssh_client_many_commands(self):
+ # Ensure we can run multiple commands with one client
+ many_commands = ['ls', 'date', 'pwd']
+
+ class CustomSSHOperator(SSHOperator):
+ def execute(self, context=None):
+ success = False
+ with self.get_ssh_client() as ssh_client:
+ for c in many_commands:
+ self.run_ssh_client_command(ssh_client, c)
+ success = True
+ return success
+
+ task = CustomSSHOperator(task_id="test", ssh_hook=self.hook,
dag=self.dag)
+ se = SSHClientSideEffect(self.hook)
+ with unittest.mock.patch.object(task, 'get_ssh_client') as mock_get,
unittest.mock.patch.object(
+ task, 'run_ssh_client_command'
+ ) as mock_run_cmd,
unittest.mock.patch('paramiko.client.SSHClient.close') as mock_close:
+ mock_get.side_effect = se
+ task.execute()
+ mock_get.assert_called_once()
+ mock_close.assert_called_once()
+
+ ssh_client = se.return_value
+ calls = [unittest.mock.call(ssh_client, c) for c in many_commands]
+ mock_run_cmd.assert_has_calls(calls)
+
+ def test_fail_with_no_command(self):
+ # Test that run_ssh_client_command fails on no command
+ task = SSHOperator(
+ task_id="test",
+ ssh_hook=self.hook,
+ # command="ls",
+ dag=self.dag,
+ )
+ with pytest.raises(AirflowException, match="SSH command not specified.
Aborting."):
+ task.execute(None)
+
+ def test_command_errored(self):
+ # Test that run_ssh_client_command works on invalid commands
+ command = "not_a_real_command"
+ task = SSHOperator(
+ task_id="test",
+ ssh_hook=self.hook,
+ command=command,
+ dag=self.dag,
+ )
+ with pytest.raises(AirflowException, match=f"error running cmd:
{command}, error: .*"):
+ task.execute(None)