Enable the ShellPool class to be used in the test run, and let
InteractiveShells register themselves with the pool upon shell startup.
Moreover, to prevent the ShellPool from calling InteractiveShell.close
more than once if the shell was not unregistered correctly, add a way
to skip the method if the shell has already been closed.

Signed-off-by: Luca Vizzarro <luca.vizza...@arm.com>
Reviewed-by: Paul Szczepanek <paul.szczepa...@arm.com>
---
 .../remote_session/interactive_shell.py       | 24 ++++++++++++++++++-
 dts/framework/remote_session/python_shell.py  |  3 ++-
 dts/framework/remote_session/testpmd_shell.py |  2 ++
 dts/framework/test_run.py                     |  5 ++++
 4 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/dts/framework/remote_session/interactive_shell.py b/dts/framework/remote_session/interactive_shell.py
index d7e566e5c4..ba8489eafa 100644
--- a/dts/framework/remote_session/interactive_shell.py
+++ b/dts/framework/remote_session/interactive_shell.py
@@ -22,12 +22,15 @@
 """
 
 from abc import ABC, abstractmethod
+from collections.abc import Callable
+from functools import wraps
 from pathlib import PurePath
-from typing import ClassVar
+from typing import ClassVar, Concatenate, ParamSpec, TypeAlias, TypeVar
 
 from paramiko import Channel, channel
 from typing_extensions import Self
 
+from framework.context import get_ctx
 from framework.exception import (
     InteractiveCommandExecutionError,
     InteractiveSSHSessionDeadError,
@@ -38,6 +41,35 @@
 from framework.settings import SETTINGS
 from framework.testbed_model.node import Node
 
+P = ParamSpec("P")
+T = TypeVar("T", bound="InteractiveShell")
+R = TypeVar("R")
+InteractiveShellMethod = Callable[Concatenate[T, P], R]
+InteractiveShellDecorator: TypeAlias = Callable[[InteractiveShellMethod], InteractiveShellMethod]
+
+
+def only_active(func: Callable[Concatenate[T, P], R]) -> Callable[Concatenate[T, P], R | None]:
+    """Skip the decorated method if the shell's SSH channel is not active.
+
+    This lets :meth:`InteractiveShell.close` be called more than once, e.g. by the shell
+    pool on a shell that has already been closed, without touching a closed channel.
+
+    Args:
+        func: The :class:`InteractiveShell` method to guard.
+
+    Returns:
+        A wrapper that calls `func` only while ``self._ssh_channel.active`` is true and
+        returns :data:`None` otherwise. The wrapper keeps the name and docstring of `func`.
+    """
+
+    @wraps(func)
+    def _wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | None:
+        if self._ssh_channel.active:
+            return func(self, *args, **kwargs)
+        return None
+
+    return _wrapper
+
 
 class InteractiveShell(ABC):
     """The base class for managing interactive shells.
@@ -155,6 +174,7 @@ def start_application(self, prompt: str | None = None) -> None:
             self.is_alive = False  # update state on failure to start
             raise InteractiveCommandExecutionError("Failed to start 
application.")
         self._ssh_channel.settimeout(self._timeout)
+        get_ctx().shell_pool.register_shell(self)
 
     def send_command(
         self, command: str, prompt: str | None = None, skip_first_line: bool = 
False
@@ -219,6 +239,7 @@ def send_command(
             self._logger.debug(f"Got output: {out}")
         return out
 
+    @only_active
     def close(self) -> None:
         """Close the shell.
 
@@ -234,6 +255,7 @@ def close(self) -> None:
             self._logger.debug("Application failed to exit before set 
timeout.")
             raise InteractiveSSHTimeoutError("Application 'exit' command") 
from e
         self._ssh_channel.close()
+        get_ctx().shell_pool.unregister_shell(self)
 
     @property
     @abstractmethod
diff --git a/dts/framework/remote_session/python_shell.py b/dts/framework/remote_session/python_shell.py
index 6331d047c4..5b380a5c7a 100644
--- a/dts/framework/remote_session/python_shell.py
+++ b/dts/framework/remote_session/python_shell.py
@@ -15,7 +15,7 @@
 from pathlib import PurePath
 from typing import ClassVar
 
-from .interactive_shell import InteractiveShell
+from .interactive_shell import InteractiveShell, only_active
 
 
 class PythonShell(InteractiveShell):
@@ -32,6 +32,7 @@ def path(self) -> PurePath:
         """Path to the Python3 executable."""
         return PurePath("python3")
 
+    @only_active
     def close(self):
         """Close Python shell."""
         return super().close()
diff --git a/dts/framework/remote_session/testpmd_shell.py b/dts/framework/remote_session/testpmd_shell.py
index 19437b6233..b1939e4a51 100644
--- a/dts/framework/remote_session/testpmd_shell.py
+++ b/dts/framework/remote_session/testpmd_shell.py
@@ -33,6 +33,7 @@
 )
 
 from framework.context import get_ctx
+from framework.remote_session.interactive_shell import only_active
 from framework.testbed_model.topology import TopologyType
 
 if TYPE_CHECKING or environ.get("DTS_DOC_BUILD"):
@@ -2314,6 +2315,7 @@ def rx_vxlan(self, vxlan_id: int, port_id: int, enable: bool, verify: bool = Tru
                 self._logger.debug(f"Failed to set VXLAN:\n{vxlan_output}")
                 raise InteractiveCommandExecutionError(f"Failed to set 
VXLAN:\n{vxlan_output}")
 
+    @only_active
     def close(self) -> None:
         """Overrides :meth:`~.interactive_shell.close`."""
         self.stop()
diff --git a/dts/framework/test_run.py b/dts/framework/test_run.py
index f9cfe5e908..7708fe31bd 100644
--- a/dts/framework/test_run.py
+++ b/dts/framework/test_run.py
@@ -430,6 +430,7 @@ def description(self) -> str:
 
     def next(self) -> State | None:
         """Next state."""
+        self.test_run.ctx.shell_pool.terminate_current_pool()
         self.test_run.ctx.tg.teardown(self.test_run.ctx.topology.tg_ports)
         self.test_run.ctx.dpdk.teardown(self.test_run.ctx.topology.sut_ports)
         self.test_run.ctx.tg_node.teardown()
@@ -473,6 +474,7 @@ def description(self) -> str:
 
     def next(self) -> State | None:
         """Next state."""
+        self.test_run.ctx.shell_pool.start_new_pool()
         self.test_suite.set_up_suite()
         self.result.update_setup(Result.PASS)
         return TestSuiteExecution(self.test_run, self.test_suite, self.result)
@@ -544,6 +546,7 @@ def next(self) -> State | None:
         """Next state."""
         self.test_suite.tear_down_suite()
         self.test_run.ctx.dpdk.kill_cleanup_dpdk_apps()
+        self.test_run.ctx.shell_pool.terminate_current_pool()
         self.result.update_teardown(Result.PASS)
         return TestRunExecution(self.test_run, self.test_run.result)
 
@@ -594,6 +597,7 @@ def description(self) -> str:
 
     def next(self) -> State | None:
         """Next state."""
+        self.test_run.ctx.shell_pool.start_new_pool()
         self.test_suite.set_up_test_case()
         self.result.update_setup(Result.PASS)
         return TestCaseExecution(
@@ -670,6 +674,7 @@ def description(self) -> str:
     def next(self) -> State | None:
         """Next state."""
         self.test_suite.tear_down_test_case()
+        self.test_run.ctx.shell_pool.terminate_current_pool()
         self.result.update_teardown(Result.PASS)
         return TestSuiteExecution(self.test_run, self.test_suite, 
self.test_suite_result)
 
-- 
2.43.0

Reply via email to