Reviewed-by: Nicholas Pratte <npra...@iol.unh.edu>
On Wed, May 29, 2024 at 3:49 PM <jspew...@iol.unh.edu> wrote: > > From: Jeremy Spewock <jspew...@iol.unh.edu> > > The current implementation of consuming output from interactive shells > relies on being able to find an expected prompt somewhere within the > output buffer after sending the command. This is useful in situations > where the prompt does not appear in the output itself, but in some > practical cases (such as the starting of an XML-RPC server for scapy) > the prompt exists in one of the commands sent to the shell and this can > cause the command to exit early and create a race condition between the > server starting and the first command being sent to the server. > > This patch addresses this problem by searching for a line that strictly > ends with the provided prompt, rather than one that simply contains it, > so that the detection that a command is finished is more consistent. It > also adds a catch to detect when a command times out before finding the > prompt or the underlying SSH session dies so that the exception can be > wrapped into a more explicit one and be more consistent with the > non-interactive shells. 
> > Bugzilla ID: 1359 > Fixes: 88489c0501af ("dts: add smoke tests") > > Signed-off-by: Jeremy Spewock <jspew...@iol.unh.edu> > --- > dts/framework/exception.py | 66 ++++++++++++------- > .../remote_session/interactive_shell.py | 51 ++++++++++---- > 2 files changed, 81 insertions(+), 36 deletions(-) > > diff --git a/dts/framework/exception.py b/dts/framework/exception.py > index cce1e0231a..627190c781 100644 > --- a/dts/framework/exception.py > +++ b/dts/framework/exception.py > @@ -48,26 +48,6 @@ class DTSError(Exception): > severity: ClassVar[ErrorSeverity] = ErrorSeverity.GENERIC_ERR > > > -class SSHTimeoutError(DTSError): > - """The SSH execution of a command timed out.""" > - > - #: > - severity: ClassVar[ErrorSeverity] = ErrorSeverity.SSH_ERR > - _command: str > - > - def __init__(self, command: str): > - """Define the meaning of the first argument. > - > - Args: > - command: The executed command. > - """ > - self._command = command > - > - def __str__(self) -> str: > - """Add some context to the string representation.""" > - return f"{self._command} execution timed out." > - > - > class SSHConnectionError(DTSError): > """An unsuccessful SSH connection.""" > > @@ -95,8 +75,42 @@ def __str__(self) -> str: > return message > > > -class SSHSessionDeadError(DTSError): > - """The SSH session is no longer alive.""" > +class _SSHTimeoutError(DTSError): > + """The execution of a command via SSH timed out. > + > + This class is private and meant to be raised as its interactive and > non-interactive variants. > + """ > + > + #: > + severity: ClassVar[ErrorSeverity] = ErrorSeverity.SSH_ERR > + _command: str > + > + def __init__(self, command: str): > + """Define the meaning of the first argument. > + > + Args: > + command: The executed command. > + """ > + self._command = command > + > + def __str__(self) -> str: > + """Add some context to the string representation.""" > + return f"{self._command} execution timed out." 
> + > + > +class SSHTimeoutError(_SSHTimeoutError): > + """The execution of a command on a non-interactive SSH session timed > out.""" > + > + > +class InteractiveSSHTimeoutError(_SSHTimeoutError): > + """The execution of a command on an interactive SSH session timed out.""" > + > + > +class _SSHSessionDeadError(DTSError): > + """The SSH session is no longer alive. > + > + This class is private and meant to be raised as its interactive and > non-interactive variants. > + """ > > #: > severity: ClassVar[ErrorSeverity] = ErrorSeverity.SSH_ERR > @@ -115,6 +129,14 @@ def __str__(self) -> str: > return f"SSH session with {self._host} has died." > > > +class SSHSessionDeadError(_SSHSessionDeadError): > + """Non-interactive SSH session has died.""" > + > + > +class InteractiveSSHSessionDeadError(_SSHSessionDeadError): > + """Interactive SSH session has died.""" > + > + > class ConfigurationError(DTSError): > """An invalid configuration.""" > > diff --git a/dts/framework/remote_session/interactive_shell.py > b/dts/framework/remote_session/interactive_shell.py > index 5cfe202e15..148907f645 100644 > --- a/dts/framework/remote_session/interactive_shell.py > +++ b/dts/framework/remote_session/interactive_shell.py > @@ -18,11 +18,17 @@ > from pathlib import PurePath > from typing import Callable, ClassVar > > -from paramiko import Channel, SSHClient, channel # type: ignore[import] > +from paramiko import Channel, channel # type: ignore[import] > > +from framework.exception import ( > + InteractiveSSHSessionDeadError, > + InteractiveSSHTimeoutError, > +) > from framework.logger import DTSLogger > from framework.settings import SETTINGS > > +from .interactive_remote_session import InteractiveRemoteSession > + > > class InteractiveShell(ABC): > """The base class for managing interactive shells. > @@ -34,7 +40,7 @@ class InteractiveShell(ABC): > session. 
> """ > > - _interactive_session: SSHClient > + _interactive_session: InteractiveRemoteSession > _stdin: channel.ChannelStdinFile > _stdout: channel.ChannelFile > _ssh_channel: Channel > @@ -48,7 +54,10 @@ class InteractiveShell(ABC): > > #: Extra characters to add to the end of every command > #: before sending them. This is often overridden by subclasses and is > - #: most commonly an additional newline character. > + #: most commonly an additional newline character. This additional newline > + #: character is used to force the line that is currently awaiting input > + #: into the stdout buffer so that it can be consumed and checked against > + #: the expected prompt. > _command_extra_chars: ClassVar[str] = "" > > #: Path to the executable to start the interactive application. > @@ -60,7 +69,7 @@ class InteractiveShell(ABC): > > def __init__( > self, > - interactive_session: SSHClient, > + interactive_session: InteractiveRemoteSession, > logger: DTSLogger, > get_privileged_command: Callable[[str], str] | None, > app_args: str = "", > @@ -80,7 +89,7 @@ def __init__( > and no output is gathered within the timeout, an exception > is thrown. > """ > self._interactive_session = interactive_session > - self._ssh_channel = self._interactive_session.invoke_shell() > + self._ssh_channel = self._interactive_session.session.invoke_shell() > self._stdin = self._ssh_channel.makefile_stdin("w") > self._stdout = self._ssh_channel.makefile("r") > self._ssh_channel.settimeout(timeout) > @@ -124,20 +133,34 @@ def send_command(self, command: str, prompt: str | None > = None) -> str: > > Returns: > All output in the buffer before expected string. > + > + Raises: > + InteractiveSSHSessionDeadError: The session died while executing > the command. > + InteractiveSSHTimeoutError: If command was sent but prompt could > not be found in > + the output before the timeout. 
> """ > self._logger.info(f"Sending: '{command}'") > if prompt is None: > prompt = self._default_prompt > - self._stdin.write(f"{command}{self._command_extra_chars}\n") > - self._stdin.flush() > out: str = "" > - for line in self._stdout: > - out += line > - if prompt in line and not line.rstrip().endswith( > - command.rstrip() > - ): # ignore line that sent command > - break > - self._logger.debug(f"Got output: {out}") > + try: > + self._stdin.write(f"{command}{self._command_extra_chars}\n") > + self._stdin.flush() > + for line in self._stdout: > + out += line > + if line.rstrip().endswith(prompt): > + break > + except TimeoutError as e: > + self._logger.exception(e) > + self._logger.debug( > + f"Prompt ({prompt}) was not found in output from command > before timeout." > + ) > + raise InteractiveSSHTimeoutError(command) from e > + except OSError as e: > + self._logger.exception(e) > + raise > InteractiveSSHSessionDeadError(self._interactive_session.hostname) from e > + finally: > + self._logger.debug(f"Got output: {out}") > return out > > def close(self) -> None: > -- > 2.45.1 >