Browse Source

Automatically reconnect to signald

If the signald socket gets disconnected (e.g., signald restart), we will
now retry reconnecting to the socket forever, with a 5 second wait
between attempts.  We will also retry the initial connect forever if the
socket connection fails on startup.

This change creates a couple "synthetic" RPC events for connect and
disconnect that handlers can be attached to, as well as a NotConnected
exception class.  These should help callers deal with the fact that the
client may not always be connected.

Fixes #18
Steve Atwell 4 years ago
parent
commit
cf08898dc6
3 changed files with 66 additions and 10 deletions
  1. 4 0
      mausignald/errors.py
  2. 49 8
      mausignald/rpc.py
  3. 13 2
      mausignald/signald.py

+ 4 - 0
mausignald/errors.py

@@ -27,6 +27,10 @@ class LinkingError(RPCError):
         self.number = number
         self.number = number
 
 
 
 
+class NotConnected(RPCError):
+    pass
+
+
 class LinkingTimeout(LinkingError):
 class LinkingTimeout(LinkingError):
     pass
     pass
 
 

+ 49 - 8
mausignald/rpc.py

@@ -11,10 +11,15 @@ import json
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from .errors import UnexpectedError, UnexpectedResponse
+from .errors import NotConnected, UnexpectedError, UnexpectedResponse
 
 
 EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
 EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
 
 
+# These are synthetic RPC events for registering callbacks on socket
+# connect and disconnect.
+CONNECT_EVENT = "_socket_connected"
+DISCONNECT_EVENT = "_socket_disconnected"
+
 
 
 class SignaldRPCClient:
 class SignaldRPCClient:
     loop: asyncio.AbstractEventLoop
     loop: asyncio.AbstractEventLoop
@@ -23,6 +28,7 @@ class SignaldRPCClient:
     socket_path: str
     socket_path: str
     _reader: Optional[asyncio.StreamReader]
     _reader: Optional[asyncio.StreamReader]
     _writer: Optional[asyncio.StreamWriter]
     _writer: Optional[asyncio.StreamWriter]
+    _communicate_task: Optional[asyncio.Task]
 
 
     _response_waiters: Dict[UUID, asyncio.Future]
     _response_waiters: Dict[UUID, asyncio.Future]
     _rpc_event_handlers: Dict[str, List[EventHandler]]
     _rpc_event_handlers: Dict[str, List[EventHandler]]
@@ -34,21 +40,47 @@ class SignaldRPCClient:
         self.loop = loop or asyncio.get_event_loop()
         self.loop = loop or asyncio.get_event_loop()
         self._reader = None
         self._reader = None
         self._writer = None
         self._writer = None
+        self._communicate_task = None
         self._response_waiters = {}
         self._response_waiters = {}
-        self._rpc_event_handlers = {}
+        self._rpc_event_handlers = {CONNECT_EVENT: [], DISCONNECT_EVENT: []}
+        self.add_rpc_handler(DISCONNECT_EVENT, self._abandon_responses)
 
 
     async def connect(self) -> None:
     async def connect(self) -> None:
         if self._writer is not None:
         if self._writer is not None:
             return
             return
 
 
-        self._reader, self._writer = await asyncio.open_unix_connection(self.socket_path)
-        self.loop.create_task(self._try_read_loop())
+        initial_connect = self.loop.create_future()
+        self._communicate_task = self.loop.create_task(self._communicate_forever(initial_connect))
+        await initial_connect
+
+    async def _communicate_forever(self, initial_connect: Optional[asyncio.Future] = None) -> None:
+        while True:
+            try:
+                self._reader, self._writer = await asyncio.open_unix_connection(self.socket_path)
+            except OSError as e:
+                self.log.error(f"Connection to {self.socket_path} failed: {e}")
+                await asyncio.sleep(5)
+                continue
+
+            read_loop = self.loop.create_task(self._try_read_loop())
+            await self._run_rpc_handler(CONNECT_EVENT, {})
+
+            if initial_connect:
+                initial_connect.set_result(True)
+                initial_connect = None
+
+            await read_loop
+            await self._run_rpc_handler(DISCONNECT_EVENT, {})
 
 
     async def disconnect(self) -> None:
     async def disconnect(self) -> None:
-        self._writer.write_eof()
-        await self._writer.drain()
-        self._writer = None
-        self._reader = None
+        if self._writer is not None:
+            self._writer.write_eof()
+            await self._writer.drain()
+            if self._communicate_task:
+                self._communicate_task.cancel()
+                self._communicate_task = None
+            self._writer = None
+            self._reader = None
 
 
     def add_rpc_handler(self, method: str, handler: EventHandler) -> None:
     def add_rpc_handler(self, method: str, handler: EventHandler) -> None:
         self._rpc_event_handlers.setdefault(method, []).append(handler)
         self._rpc_event_handlers.setdefault(method, []).append(handler)
@@ -141,7 +173,16 @@ class SignaldRPCClient:
             future = self._response_waiters[req_id] = self.loop.create_future()
             future = self._response_waiters[req_id] = self.loop.create_future()
         return future
         return future
 
 
+    async def _abandon_responses(self, unused_data: Dict[str, Any]) -> None:
+        for req_id, waiter in self._response_waiters.items():
+            if not waiter.done():
+                self.log.trace(f"Abandoning response for {req_id}")
+                waiter.set_exception(NotConnected("Disconnected from signald before RPC completed"))
+
     async def _send_request(self, data: Dict[str, Any]) -> None:
     async def _send_request(self, data: Dict[str, Any]) -> None:
+        if self._writer is None:
+            raise NotConnected("Not connected to signald")
+
         self._writer.write(json.dumps(data).encode("utf-8"))
         self._writer.write(json.dumps(data).encode("utf-8"))
         self._writer.write(b"\n")
         self._writer.write(b"\n")
         await self._writer.drain()
         await self._writer.drain()

+ 13 - 2
mausignald/signald.py

@@ -3,13 +3,13 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, TypeVar, Type
+from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
 from uuid import uuid4
 from uuid import uuid4
 import asyncio
 import asyncio
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from .rpc import SignaldRPCClient
+from .rpc import CONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2)
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2)
@@ -20,16 +20,19 @@ EventHandler = Callable[[T], Awaitable[None]]
 
 
 class SignaldClient(SignaldRPCClient):
 class SignaldClient(SignaldRPCClient):
     _event_handlers: Dict[Type[T], List[EventHandler]]
     _event_handlers: Dict[Type[T], List[EventHandler]]
+    _subscriptions: Set[str]
 
 
     def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
     def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
                  log: Optional[TraceLogger] = None,
                  log: Optional[TraceLogger] = None,
                  loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
                  loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
         super().__init__(socket_path, log, loop)
         super().__init__(socket_path, log, loop)
         self._event_handlers = {}
         self._event_handlers = {}
+        self._subscriptions = set()
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("listen_started", self._parse_listen_start)
         self.add_rpc_handler("listen_started", self._parse_listen_start)
         self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
         self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler("version", self._log_version)
+        self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
 
 
     def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
     def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
         self._event_handlers.setdefault(event_class, []).append(handler)
         self._event_handlers.setdefault(event_class, []).append(handler)
@@ -75,6 +78,7 @@ class SignaldClient(SignaldRPCClient):
     async def subscribe(self, username: str) -> bool:
     async def subscribe(self, username: str) -> bool:
         try:
         try:
             await self.request("subscribe", "subscribed", username=username)
             await self.request("subscribe", "subscribed", username=username)
+            self._subscriptions.add(username)
             return True
             return True
         except UnexpectedError as e:
         except UnexpectedError as e:
             self.log.debug("Failed to subscribe to %s: %s", username, e)
             self.log.debug("Failed to subscribe to %s: %s", username, e)
@@ -83,11 +87,18 @@ class SignaldClient(SignaldRPCClient):
     async def unsubscribe(self, username: str) -> bool:
     async def unsubscribe(self, username: str) -> bool:
         try:
         try:
             await self.request("unsubscribe", "unsubscribed", username=username)
             await self.request("unsubscribe", "unsubscribed", username=username)
+            self._subscriptions.remove(username)
             return True
             return True
         except UnexpectedError as e:
         except UnexpectedError as e:
             self.log.debug("Failed to unsubscribe from %s: %s", username, e)
             self.log.debug("Failed to unsubscribe from %s: %s", username, e)
             return False
             return False
 
 
+    async def _resubscribe(self, unused_data: Dict[str, Any]) -> None:
+        if self._subscriptions:
+            self.log.debug("Resubscribing to users")
+            for username in list(self._subscriptions):
+                await self.subscribe(username)
+
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
                        ) -> str:
                        ) -> str:
         resp = await self.request("register", "verification_required", username=phone,
         resp = await self.request("register", "verification_required", username=phone,