Reviewed-by: Nicholas Pratte <npra...@iol.unh.edu>

On Wed, May 29, 2024 at 3:49 PM <jspew...@iol.unh.edu> wrote:
>
> From: Jeremy Spewock <jspew...@iol.unh.edu>
>
> The current implementation of consuming output from interactive shells
> relies on being able to find an expected prompt somewhere within the
> output buffer after sending the command. This is useful in situations
> where the prompt does not appear in the output itself, but in some
> practical cases (such as the starting of an XML-RPC server for scapy)
> the prompt exists in one of the commands sent to the shell and this can
> cause the command to exit early and create a race condition between the
> server starting and the first command being sent to the server.
>
> This patch addresses this problem by searching for a line that strictly
> ends with the provided prompt, rather than one that simply contains it,
> so that the detection that a command is finished is more consistent. It
> also adds a catch to detect when a command times out before finding the
> prompt or the underlying SSH session dies so that the exception can be
> wrapped into a more explicit one and be more consistent with the
> non-interactive shells.
>
> Bugzilla ID: 1359
> Fixes: 88489c0501af ("dts: add smoke tests")
>
> Signed-off-by: Jeremy Spewock <jspew...@iol.unh.edu>
> ---
>  dts/framework/exception.py                    | 66 ++++++++++++-------
>  .../remote_session/interactive_shell.py       | 51 ++++++++++----
>  2 files changed, 81 insertions(+), 36 deletions(-)
>
> diff --git a/dts/framework/exception.py b/dts/framework/exception.py
> index cce1e0231a..627190c781 100644
> --- a/dts/framework/exception.py
> +++ b/dts/framework/exception.py
> @@ -48,26 +48,6 @@ class DTSError(Exception):
>      severity: ClassVar[ErrorSeverity] = ErrorSeverity.GENERIC_ERR
>
>
> -class SSHTimeoutError(DTSError):
> -    """The SSH execution of a command timed out."""
> -
> -    #:
> -    severity: ClassVar[ErrorSeverity] = ErrorSeverity.SSH_ERR
> -    _command: str
> -
> -    def __init__(self, command: str):
> -        """Define the meaning of the first argument.
> -
> -        Args:
> -            command: The executed command.
> -        """
> -        self._command = command
> -
> -    def __str__(self) -> str:
> -        """Add some context to the string representation."""
> -        return f"{self._command} execution timed out."
> -
> -
>  class SSHConnectionError(DTSError):
>      """An unsuccessful SSH connection."""
>
> @@ -95,8 +75,42 @@ def __str__(self) -> str:
>          return message
>
>
> -class SSHSessionDeadError(DTSError):
> -    """The SSH session is no longer alive."""
> +class _SSHTimeoutError(DTSError):
> +    """The execution of a command via SSH timed out.
> +
> +    This class is private and meant to be raised as its interactive and non-interactive variants.
> +    """
> +
> +    #:
> +    severity: ClassVar[ErrorSeverity] = ErrorSeverity.SSH_ERR
> +    _command: str
> +
> +    def __init__(self, command: str):
> +        """Define the meaning of the first argument.
> +
> +        Args:
> +            command: The executed command.
> +        """
> +        self._command = command
> +
> +    def __str__(self) -> str:
> +        """Add some context to the string representation."""
> +        return f"{self._command} execution timed out."
> +
> +
> +class SSHTimeoutError(_SSHTimeoutError):
> +    """The execution of a command on a non-interactive SSH session timed out."""
> +
> +
> +class InteractiveSSHTimeoutError(_SSHTimeoutError):
> +    """The execution of a command on an interactive SSH session timed out."""
> +
> +
> +class _SSHSessionDeadError(DTSError):
> +    """The SSH session is no longer alive.
> +
> +    This class is private and meant to be raised as its interactive and non-interactive variants.
> +    """
>
>      #:
>      severity: ClassVar[ErrorSeverity] = ErrorSeverity.SSH_ERR
> @@ -115,6 +129,14 @@ def __str__(self) -> str:
>          return f"SSH session with {self._host} has died."
>
>
> +class SSHSessionDeadError(_SSHSessionDeadError):
> +    """Non-interactive SSH session has died."""
> +
> +
> +class InteractiveSSHSessionDeadError(_SSHSessionDeadError):
> +    """Interactive SSH session has died."""
> +
> +
>  class ConfigurationError(DTSError):
>      """An invalid configuration."""
>
> diff --git a/dts/framework/remote_session/interactive_shell.py b/dts/framework/remote_session/interactive_shell.py
> index 5cfe202e15..148907f645 100644
> --- a/dts/framework/remote_session/interactive_shell.py
> +++ b/dts/framework/remote_session/interactive_shell.py
> @@ -18,11 +18,17 @@
>  from pathlib import PurePath
>  from typing import Callable, ClassVar
>
> -from paramiko import Channel, SSHClient, channel  # type: ignore[import]
> +from paramiko import Channel, channel  # type: ignore[import]
>
> +from framework.exception import (
> +    InteractiveSSHSessionDeadError,
> +    InteractiveSSHTimeoutError,
> +)
>  from framework.logger import DTSLogger
>  from framework.settings import SETTINGS
>
> +from .interactive_remote_session import InteractiveRemoteSession
> +
>
>  class InteractiveShell(ABC):
>      """The base class for managing interactive shells.
> @@ -34,7 +40,7 @@ class InteractiveShell(ABC):
>      session.
>      """
>
> -    _interactive_session: SSHClient
> +    _interactive_session: InteractiveRemoteSession
>      _stdin: channel.ChannelStdinFile
>      _stdout: channel.ChannelFile
>      _ssh_channel: Channel
> @@ -48,7 +54,10 @@ class InteractiveShell(ABC):
>
>      #: Extra characters to add to the end of every command
>      #: before sending them. This is often overridden by subclasses and is
> -    #: most commonly an additional newline character.
> +    #: most commonly an additional newline character. This additional newline
> +    #: character is used to force the line that is currently awaiting input
> +    #: into the stdout buffer so that it can be consumed and checked against
> +    #: the expected prompt.
>      _command_extra_chars: ClassVar[str] = ""
>
>      #: Path to the executable to start the interactive application.
> @@ -60,7 +69,7 @@ class InteractiveShell(ABC):
>
>      def __init__(
>          self,
> -        interactive_session: SSHClient,
> +        interactive_session: InteractiveRemoteSession,
>          logger: DTSLogger,
>          get_privileged_command: Callable[[str], str] | None,
>          app_args: str = "",
> @@ -80,7 +89,7 @@ def __init__(
>                  and no output is gathered within the timeout, an exception is thrown.
>          """
>          self._interactive_session = interactive_session
> -        self._ssh_channel = self._interactive_session.invoke_shell()
> +        self._ssh_channel = self._interactive_session.session.invoke_shell()
>          self._stdin = self._ssh_channel.makefile_stdin("w")
>          self._stdout = self._ssh_channel.makefile("r")
>          self._ssh_channel.settimeout(timeout)
> @@ -124,20 +133,34 @@ def send_command(self, command: str, prompt: str | None = None) -> str:
>
>          Returns:
>              All output in the buffer before expected string.
> +
> +        Raises:
> +            InteractiveSSHSessionDeadError: The session died while executing the command.
> +            InteractiveSSHTimeoutError: If command was sent but prompt could not be found in
> +                the output before the timeout.
>          """
>          self._logger.info(f"Sending: '{command}'")
>          if prompt is None:
>              prompt = self._default_prompt
> -        self._stdin.write(f"{command}{self._command_extra_chars}\n")
> -        self._stdin.flush()
>          out: str = ""
> -        for line in self._stdout:
> -            out += line
> -            if prompt in line and not line.rstrip().endswith(
> -                command.rstrip()
> -            ):  # ignore line that sent command
> -                break
> -        self._logger.debug(f"Got output: {out}")
> +        try:
> +            self._stdin.write(f"{command}{self._command_extra_chars}\n")
> +            self._stdin.flush()
> +            for line in self._stdout:
> +                out += line
> +                if line.rstrip().endswith(prompt):
> +                    break
> +        except TimeoutError as e:
> +            self._logger.exception(e)
> +            self._logger.debug(
> +                f"Prompt ({prompt}) was not found in output from command before timeout."
> +            )
> +            raise InteractiveSSHTimeoutError(command) from e
> +        except OSError as e:
> +            self._logger.exception(e)
> +            raise InteractiveSSHSessionDeadError(self._interactive_session.hostname) from e
> +        finally:
> +            self._logger.debug(f"Got output: {out}")
>          return out
>
>      def close(self) -> None:
> --
> 2.45.1
>

Reply via email to