Each lock is held per node. The lock ensures that multiple connections
to the same node don't execute anything at the same time, removing the
possibility of race conditions.

Signed-off-by: Juraj Linkeš <juraj.lin...@pantheon.tech>
---
 dts/framework/ssh_pexpect.py | 14 ++++--
 dts/framework/utils.py       | 88 ++++++++++++++++++++++++++++++++++++
 2 files changed, 99 insertions(+), 3 deletions(-)

diff --git a/dts/framework/ssh_pexpect.py b/dts/framework/ssh_pexpect.py
index c73c1048a4..01ebd1c010 100644
--- a/dts/framework/ssh_pexpect.py
+++ b/dts/framework/ssh_pexpect.py
@@ -12,7 +12,7 @@
 from .exception import (SSHConnectionException, SSHSessionDeadException,
                         TimeoutException)
 from .logger import DTSLOG
-from .utils import GREEN, RED
+from .utils import GREEN, RED, parallel_lock
 
 """
 Module handles ssh sessions to TG and SUT.
@@ -33,6 +33,7 @@ def __init__(
         username: str,
         password: Optional[str],
         logger: DTSLOG,
+        sut_id: int,
     ):
         self.magic_prompt = "MAGIC PROMPT"
         self.logger = logger
@@ -42,11 +43,18 @@ def __init__(
         self.password = password or ""
         self.logger.info(f"ssh {self.username}@{self.node}")
 
-        self._connect_host()
+        self._connect_host(sut_id=sut_id)
 
-    def _connect_host(self) -> None:
+    @parallel_lock(num=8)
+    def _connect_host(self, sut_id: int = 0) -> None:
         """
         Create connection to assigned node.
+        The sut_id parameter is used by parallel_lock to ensure that each
+        node gets its own isolated lock.
+        Parallel ssh connections are limited by the MaxStartups option in the
+        SSHD configuration file. By default that limit is 10, so the default
+        number of threads is limited to 8, which is less than 10. The lock
+        number can be modified along with the MaxStartups value.
         """
         retry_times = 10
         try:
diff --git a/dts/framework/utils.py b/dts/framework/utils.py
index 7036843dd7..a637c4641e 100644
--- a/dts/framework/utils.py
+++ b/dts/framework/utils.py
@@ -1,7 +1,95 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # Copyright(c) 2010-2014 Intel Corporation
+# Copyright(c) 2022 PANTHEON.tech s.r.o.
+# Copyright(c) 2022 University of New Hampshire
 #
 
+import threading
+from functools import wraps
+from typing import Any, Callable, TypeVar
+
+locks_info: list[dict[str, Any]] = list()
+
+T = TypeVar("T")
+
+
+def parallel_lock(num: int = 1) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    """
+    Wrapper function for protecting parallel threads, allowing multiple threads
+    to share one lock. Locks are created based on function name. Thread locks
+    are separated between SUTs according to the argument 'sut_id'.
+    Parameter:
+        num: Number of parallel threads for the lock
+    """
+    global locks_info
+
+    def decorate(func: Callable[..., T]) -> Callable[..., T]:
+        # mypy does not know how to handle the types of this function, so Any is required
+        @wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> T:
+            if "sut_id" in kwargs:
+                sut_id = kwargs["sut_id"]
+            else:
+                sut_id = 0
+
+            # in case the function arguments are not correct
+            if sut_id >= len(locks_info):
+                sut_id = 0
+
+            lock_info = locks_info[sut_id]
+            uplock = lock_info["update_lock"]
+
+            name = func.__name__
+            uplock.acquire()
+
+            if name not in lock_info:
+                lock_info[name] = dict()
+                lock_info[name]["lock"] = threading.RLock()
+                lock_info[name]["current_thread"] = 1
+            else:
+                lock_info[name]["current_thread"] += 1
+
+            lock = lock_info[name]["lock"]
+
+            # make sure that whoever owns the global lock also owns the update lock
+            if lock_info[name]["current_thread"] >= num:
+                if lock._is_owned():
+                    print(
+                        RED(
+                            f"SUT{sut_id:d} {threading.current_thread().name} waiting for func lock {func.__name__}"
+                        )
+                    )
+                lock.acquire()
+            else:
+                uplock.release()
+
+            try:
+                ret = func(*args, **kwargs)
+            except Exception as e:
+                if not uplock._is_owned():
+                    uplock.acquire()
+
+                if lock._is_owned():
+                    lock.release()
+                    lock_info[name]["current_thread"] = 0
+                uplock.release()
+                raise e
+
+            if not uplock._is_owned():
+                uplock.acquire()
+
+            if lock._is_owned():
+                lock.release()
+                lock_info[name]["current_thread"] = 0
+
+            uplock.release()
+
+            return ret
+
+        return wrapper
+
+    return decorate
+
 
 def RED(text: str) -> str:
     return f"\u001B[31;1m{str(text)}\u001B[0m"
-- 
2.30.2

Reply via email to