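The connection manager takes care of connecting to and disconnecting
from the server. It also tries to reconnect to the server in certain
situations where the client has been disconnected due to an error
condition. The number of retries and the delay between them can be set
with the new --num-retries and --retry-delay options.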

Signed-off-by: G S Niteesh Babu <niteesh...@gmail.com>
---
 python/qemu/aqmp/aqmp_tui.py | 127 +++++++++++++++++++++++++++++------
 1 file changed, 105 insertions(+), 22 deletions(-)

diff --git a/python/qemu/aqmp/aqmp_tui.py b/python/qemu/aqmp/aqmp_tui.py
index 03d4808acd..c47abe0a25 100644
--- a/python/qemu/aqmp/aqmp_tui.py
+++ b/python/qemu/aqmp/aqmp_tui.py
@@ -35,8 +35,9 @@
 import urwid_readline
 
 from ..qmp import QEMUMonitorProtocol, QMPBadPortError
+from .error import ProtocolError
 from .message import DeserializationError, Message, UnexpectedTypeError
-from .protocol import ConnectError
+from .protocol import ConnectError, Runstate
 from .qmp_client import ExecInterruptedError, QMPClient
 from .util import create_task, pretty_traceback
 
@@ -128,17 +129,26 @@ class App(QMPClient):
 
     Initializes the widgets and starts the urwid event loop.
     """
-    def __init__(self, address: Union[str, Tuple[str, int]]) -> None:
+    def __init__(self, address: Union[str, Tuple[str, int]], num_retries: int,
+                 retry_delay: Optional[int]) -> None:
         """
         Initializes the TUI.

         :param address:
             Address of the server to connect to.
+        :param num_retries:
+            Number of retries after a failed connect before giving up.
+        :param retry_delay:
+            Seconds to wait before each retry; None selects the 2s default.
         """
         urwid.register_signal(type(self), UPDATE_MSG)
         self.window = Window(self)
         self.address = address
         self.aloop: Optional[asyncio.AbstractEventLoop] = None
+        self.num_retries = num_retries
+        self.retry_delay = retry_delay if retry_delay is not None else 2
+        self.retry: bool = False
+        self.disconnecting: bool = False
         super().__init__()
 
     def add_to_history(self, msg: str, level: Optional[str] = None) -> None:
@@ -212,10 +222,10 @@ def handle_event(self, event: Message) -> None:
         """
         try:
             await self._raw(msg, assign_id='id' not in msg)
-        except ExecInterruptedError:
-            logging.info('Error server disconnected before reply')
+        except ExecInterruptedError as err:
+            logging.info('Error server disconnected before reply %s', str(err))
             self.add_to_history('Server disconnected before reply', 'ERROR')
-            self._set_status("[Server Disconnected]")
+            await self.disconnect()
         except Exception as err:
             logging.error('Exception from _send_to_server: %s', str(err))
             raise err
@@ -237,10 +247,10 @@ def cb_send_to_server(self, raw_msg: str) -> None:
             create_task(self._send_to_server(msg))
         except (ValueError, TypeError) as err:
             logging.info('Invalid message: %s', str(err))
-            self.add_to_history(f'{raw_msg}: {err}')
+            self.add_to_history(f'{raw_msg}: {err}', 'ERROR')
         except (DeserializationError, UnexpectedTypeError) as err:
             logging.info('Invalid message: %s', err.error_message)
-            self.add_to_history(f'{raw_msg}: {err.error_message}')
+            self.add_to_history(f'{raw_msg}: {err.error_message}', 'ERROR')
 
     def unhandled_input(self, key: str) -> None:
         """
@@ -266,18 +276,32 @@ def kill_app(self) -> None:
 
         :raise Exception: When an unhandled exception is caught.
         """
-        # It is ok to call disconnect even in disconnect state
+        await self.disconnect()
+        logging.debug('Disconnect finished. Exiting app')
+        raise urwid.ExitMainLoop()
+
+    async def disconnect(self) -> None:
+        """
+        Handles errors locally; sets `self.retry` only on EOFError.
+        """
+        if self.disconnecting:  # already in progress; do not re-enter
+            return
         try:
-            await self.disconnect()
-            logging.debug('Disconnect finished. Exiting app')
-        except EOFError:
-            # We receive an EOF during disconnect, ignore that
-            pass
+            self.disconnecting = True
+            await super().disconnect()
+            self.retry = False
+        except EOFError as err:
+            logging.info('disconnect: %s', str(err))
+            self.retry = True
+        except ProtocolError as err:
+            logging.info('disconnect: %s', str(err))
+            self.retry = False
         except Exception as err:
-            logging.info('_kill_app: %s', str(err))
-            # Let the app crash after providing a proper stack trace
+            logging.error('disconnect: Unhandled exception %s', str(err))
+            self.retry = False
             raise err
-        raise urwid.ExitMainLoop()
+        finally:
+            self.disconnecting = False
 
     def _set_status(self, msg: str) -> None:
         """
@@ -301,18 +325,72 @@ def _get_formatted_address(self) -> str:
             addr = f'{self.address}'
         return addr
 
-    async def connect_server(self) -> None:
+    async def _initiate_connection(self) -> Optional[ConnectError]:
+        """
+        Tries connecting to the server, retrying up to `self.num_retries`
+        times with a delay of `self.retry_delay` seconds between attempts.
+        Stops as soon as an attempt succeeds.
+
+        :return: The last ConnectError if every attempt failed, else None.
+        """
+        current_retries = 0
+
+        # Initial try. Keep its error so that it is still reported when
+        # no retries are allowed (num_retries == 0).
+        err = await self.connect_server()
+        while self.retry and current_retries < self.num_retries:
+            logging.info('Connection Failed, retrying in %d', self.retry_delay)
+            status = f'[Retry #{current_retries + 1} ({self.retry_delay}s)]'
+            self._set_status(status)
+
+            await asyncio.sleep(self.retry_delay)
+
+            err = await self.connect_server()
+            current_retries += 1
+        # err is only set when the most recent attempt failed.
+        if err is not None:
+            logging.info('All retries failed: %s', err)
+            return err
+        return None
+
+    async def manage_connection(self) -> None:
+        """
+        Manage the connection based on the current run state.
+
+        On IDLE, connect (retrying up to `num_retries` times); on
+        DISCONNECTING, disconnect and reconnect at once if left IDLE.
+        Otherwise, or after connecting, wait for the next runstate change.
+        """
+        while True:
+            if self.runstate == Runstate.IDLE:
+                err = await self._initiate_connection()
+                # A returned error means every connection attempt failed.
+                if err:
+                    self._set_status(f'[Error: {err.error_message}]')
+                else:
+                    addr = self._get_formatted_address()
+                    self._set_status(f'[Connected {addr}]')
+            elif self.runstate == Runstate.DISCONNECTING:
+                self._set_status('[Disconnected]')
+                await self.disconnect()
+                # If disconnect() left us IDLE, loop back and reconnect now.
+                if self.runstate == Runstate.IDLE:
+                    continue
+            await self.runstate_changed()
+
+    async def connect_server(self) -> Optional[ConnectError]:
         """
-        Initiates a connection to the server at address `self.address`
-        and in case of a failure, sets the status to the respective error.
+        Connects to the server at `self.address`, setting `self.retry` to
+        whether a retry is needed; returns the ConnectError on failure.
         """
         try:
             await self.connect(self.address)
-            addr = self._get_formatted_address()
-            self._set_status(f'Connected to {addr}')
+            self.retry = False
         except ConnectError as err:
             logging.info('connect_server: ConnectError %s', str(err))
-            self._set_status(f'[ConnectError: {err.error_message}]')
+            self.retry = True
+            return err
+        return None
 
     def run(self, debug: bool = False) -> None:
         """
@@ -341,7 +419,7 @@ def run(self, debug: bool = False) -> None:
                                    event_loop=event_loop)
 
         create_task(self.wait_for_events(), self.aloop)
-        create_task(self.connect_server(), self.aloop)
+        create_task(self.manage_connection(), self.aloop)
         try:
             main_loop.run()
         except Exception as err:
@@ -566,6 +644,11 @@ def main() -> None:
     parser = argparse.ArgumentParser(description='AQMP TUI')
     parser.add_argument('qmp_server', help='Address of the QMP server. '
                         'Format <UNIX socket path | TCP addr:port>')
+    parser.add_argument('--num-retries', type=int, default=10,
+                        help='Number of times to reconnect before giving up.')
+    parser.add_argument('--retry-delay', type=int,
+                        help='Time(s) to wait before next retry. '
+                        'Default action is to wait 2s between each retry.')
     parser.add_argument('--log-file', help='The Log file name')
     parser.add_argument('--log-level', default='WARNING',
                         help='Log level <CRITICAL|ERROR|WARNING|INFO|DEBUG|>')
@@ -581,7 +664,7 @@ def main() -> None:
     except QMPBadPortError as err:
         parser.error(str(err))
 
-    app = App(address)
+    app = App(address, args.num_retries, args.retry_delay)
 
     root_logger = logging.getLogger()
     root_logger.setLevel(logging.getLevelName(args.log_level))
-- 
2.17.1


Reply via email to