Each lock is held per node. The lock ensures that multiple connections to the same node don't execute anything at the same time, removing the possibility of race conditions.
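
To illustrate that guarantee, here is a minimal sketch of how the new parallel_lock decorator is meant to be used. It is not part of the patch: the locks_info setup, the host addresses and num=1 (the patch uses num=8 for _connect_host) are illustrative assumptions, and the import path assumes the dts directory is on sys.path.

    import threading

    from framework.utils import locks_info, parallel_lock

    # Assumed setup, not shown in this patch: one entry per SUT node.
    # An RLock is used because the wrapper calls uplock._is_owned(),
    # which plain Lock objects do not provide.
    locks_info.extend({"update_lock": threading.RLock()} for _ in range(2))

    @parallel_lock(num=1)
    def connect(host: str, sut_id: int = 0) -> None:
        # With num=1, only one thread per node runs this body at a time.
        print(f"connecting to {host} (SUT{sut_id})")

    # The two SUT0 connections are serialized against each other; the SUT1
    # connection uses its own per-node lock and is not blocked by them.
    threads = [
        threading.Thread(target=connect, kwargs={"host": "10.0.0.1", "sut_id": 0}),
        threading.Thread(target=connect, kwargs={"host": "10.0.0.1", "sut_id": 0}),
        threading.Thread(target=connect, kwargs={"host": "10.0.0.2", "sut_id": 1}),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
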
Signed-off-by: Juraj Linkeš <juraj.lin...@pantheon.tech>
---
 dts/framework/ssh_pexpect.py | 14 ++++--
 dts/framework/utils.py       | 88 ++++++++++++++++++++++++++++++++++++
 2 files changed, 99 insertions(+), 3 deletions(-)

diff --git a/dts/framework/ssh_pexpect.py b/dts/framework/ssh_pexpect.py
index c73c1048a4..01ebd1c010 100644
--- a/dts/framework/ssh_pexpect.py
+++ b/dts/framework/ssh_pexpect.py
@@ -12,7 +12,7 @@
 from .exception import (SSHConnectionException, SSHSessionDeadException,
                         TimeoutException)
 from .logger import DTSLOG
-from .utils import GREEN, RED
+from .utils import GREEN, RED, parallel_lock
 
 """
 Module handles ssh sessions to TG and SUT.
@@ -33,6 +33,7 @@ def __init__(
         username: str,
         password: Optional[str],
         logger: DTSLOG,
+        sut_id: int,
     ):
         self.magic_prompt = "MAGIC PROMPT"
         self.logger = logger
@@ -42,11 +43,18 @@ def __init__(
         self.password = password or ""
         self.logger.info(f"ssh {self.username}@{self.node}")
 
-        self._connect_host()
+        self._connect_host(sut_id=sut_id)
 
-    def _connect_host(self) -> None:
+    @parallel_lock(num=8)
+    def _connect_host(self, sut_id: int = 0) -> None:
         """
         Create connection to assigned node.
+        The sut_id parameter is used by parallel_lock to provide an
+        isolated lock for each node.
+        Parallel ssh connections are limited by the MaxStartups option
+        in the SSHD configuration file. Its default value is 10, so the
+        default number of threads is limited to 8, which is less than 10.
+        The lock number can be modified along with the MaxStartups value.
         """
         retry_times = 10
         try:
diff --git a/dts/framework/utils.py b/dts/framework/utils.py
index 7036843dd7..a637c4641e 100644
--- a/dts/framework/utils.py
+++ b/dts/framework/utils.py
@@ -1,7 +1,95 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # Copyright(c) 2010-2014 Intel Corporation
+# Copyright(c) 2022 PANTHEON.tech s.r.o.
+# Copyright(c) 2022 University of New Hampshire
 #
+import threading
+from functools import wraps
+from typing import Any, Callable, TypeVar
+
+locks_info: list[dict[str, Any]] = list()
+
+T = TypeVar("T")
+
+
+def parallel_lock(num: int = 1) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    """
+    Decorator protecting functions run in parallel threads by letting
+    multiple threads share one lock. Locks are created per function name and
+    are separated between SUTs according to the 'sut_id' argument.
+
+    Parameter:
+        num: Number of parallel threads for the lock
+    """
+    global locks_info
+
+    def decorate(func: Callable[..., T]) -> Callable[..., T]:
+        # mypy does not know how to handle the types of this function, so Any is required
+        @wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> T:
+            if "sut_id" in kwargs:
+                sut_id = kwargs["sut_id"]
+            else:
+                sut_id = 0
+
+            # in case the function arguments are not correct
+            if sut_id >= len(locks_info):
+                sut_id = 0
+
+            lock_info = locks_info[sut_id]
+            uplock = lock_info["update_lock"]
+
+            name = func.__name__
+            uplock.acquire()
+
+            if name not in lock_info:
+                lock_info[name] = dict()
+                lock_info[name]["lock"] = threading.RLock()
+                lock_info[name]["current_thread"] = 1
+            else:
+                lock_info[name]["current_thread"] += 1
+
+            lock = lock_info[name]["lock"]
+
+            # when the function lock is owned, the update lock must also be owned
+            if lock_info[name]["current_thread"] >= num:
+                if lock._is_owned():
+                    print(
+                        RED(
+                            f"SUT{sut_id:d} {threading.current_thread().name} waiting for func lock {func.__name__}"
+                        )
+                    )
+                lock.acquire()
+            else:
+                uplock.release()
+
+            try:
+                ret = func(*args, **kwargs)
+            except Exception as e:
+                if not uplock._is_owned():
+                    uplock.acquire()
+
+                if lock._is_owned():
+                    lock.release()
+                    lock_info[name]["current_thread"] = 0
+                uplock.release()
+                raise e
+
+            if not uplock._is_owned():
+                uplock.acquire()
+
+            if lock._is_owned():
+                lock.release()
+                lock_info[name]["current_thread"] = 0
+
+            uplock.release()
+
+            return ret
+
+        return wrapper
+
+    return decorate
 
 
 def RED(text: str) -> str:
     return f"\u001B[31;1m{str(text)}\u001B[0m"
-- 
2.30.2
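
One detail the patch itself does not show is how locks_info gets populated: the wrapper indexes locks_info[sut_id] and reads an "update_lock" from that entry, so the list must be filled in once per SUT before any decorated function runs. A hedged sketch of such setup code follows; the helper name create_parallel_locks and the choice of an RLock are assumptions, not taken from this patch.

    import threading
    from typing import Any


    def create_parallel_locks(num_suts: int) -> list[dict[str, Any]]:
        # One entry per SUT; "update_lock" guards the per-function lock
        # bookkeeping inside parallel_lock's wrapper. An RLock is used
        # because the wrapper calls uplock._is_owned(), which plain Lock
        # objects do not provide.
        return [{"update_lock": threading.RLock()} for _ in range(num_suts)]


    # e.g. during framework start-up, before any ssh session is opened:
    # locks_info.extend(create_parallel_locks(number_of_suts))
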