Browse Source

Merge remote-tracking branch 'satwell/signald-reconnect'

Tulir Asokan 4 years ago
parent
commit
febfb9ac5e
3 changed files with 66 additions and 10 deletions
  1. 4 0
      mausignald/errors.py
  2. 49 8
      mausignald/rpc.py
  3. 13 2
      mausignald/signald.py

+ 4 - 0
mausignald/errors.py

@@ -27,6 +27,10 @@ class LinkingError(RPCError):
         self.number = number
         self.number = number
 
 
 
 
+class NotConnected(RPCError):
+    pass
+
+
 class LinkingTimeout(LinkingError):
 class LinkingTimeout(LinkingError):
     pass
     pass
 
 

+ 49 - 8
mausignald/rpc.py

@@ -11,10 +11,15 @@ import json
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from .errors import UnexpectedError, UnexpectedResponse
+from .errors import NotConnected, UnexpectedError, UnexpectedResponse
 
 
 EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
 EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
 
 
+# These are synthetic RPC events for registering callbacks on socket
+# connect and disconnect.
+CONNECT_EVENT = "_socket_connected"
+DISCONNECT_EVENT = "_socket_disconnected"
+
 
 
 class SignaldRPCClient:
 class SignaldRPCClient:
     loop: asyncio.AbstractEventLoop
     loop: asyncio.AbstractEventLoop
@@ -23,6 +28,7 @@ class SignaldRPCClient:
     socket_path: str
     socket_path: str
     _reader: Optional[asyncio.StreamReader]
     _reader: Optional[asyncio.StreamReader]
     _writer: Optional[asyncio.StreamWriter]
     _writer: Optional[asyncio.StreamWriter]
+    _communicate_task: Optional[asyncio.Task]
 
 
     _response_waiters: Dict[UUID, asyncio.Future]
     _response_waiters: Dict[UUID, asyncio.Future]
     _rpc_event_handlers: Dict[str, List[EventHandler]]
     _rpc_event_handlers: Dict[str, List[EventHandler]]
@@ -34,21 +40,47 @@ class SignaldRPCClient:
         self.loop = loop or asyncio.get_event_loop()
         self.loop = loop or asyncio.get_event_loop()
         self._reader = None
         self._reader = None
         self._writer = None
         self._writer = None
+        self._communicate_task = None
         self._response_waiters = {}
         self._response_waiters = {}
-        self._rpc_event_handlers = {}
+        self._rpc_event_handlers = {CONNECT_EVENT: [], DISCONNECT_EVENT: []}
+        self.add_rpc_handler(DISCONNECT_EVENT, self._abandon_responses)
 
 
     async def connect(self) -> None:
     async def connect(self) -> None:
         if self._writer is not None:
         if self._writer is not None:
             return
             return
 
 
-        self._reader, self._writer = await asyncio.open_unix_connection(self.socket_path)
-        self.loop.create_task(self._try_read_loop())
+        initial_connect = self.loop.create_future()
+        self._communicate_task = self.loop.create_task(self._communicate_forever(initial_connect))
+        await initial_connect
+
+    async def _communicate_forever(self, initial_connect: Optional[asyncio.Future] = None) -> None:
+        while True:
+            try:
+                self._reader, self._writer = await asyncio.open_unix_connection(self.socket_path)
+            except OSError as e:
+                self.log.error(f"Connection to {self.socket_path} failed: {e}")
+                await asyncio.sleep(5)
+                continue
+
+            read_loop = self.loop.create_task(self._try_read_loop())
+            await self._run_rpc_handler(CONNECT_EVENT, {})
+
+            if initial_connect:
+                initial_connect.set_result(True)
+                initial_connect = None
+
+            await read_loop
+            await self._run_rpc_handler(DISCONNECT_EVENT, {})
 
 
     async def disconnect(self) -> None:
     async def disconnect(self) -> None:
-        self._writer.write_eof()
-        await self._writer.drain()
-        self._writer = None
-        self._reader = None
+        if self._writer is not None:
+            self._writer.write_eof()
+            await self._writer.drain()
+            if self._communicate_task:
+                self._communicate_task.cancel()
+                self._communicate_task = None
+            self._writer = None
+            self._reader = None
 
 
     def add_rpc_handler(self, method: str, handler: EventHandler) -> None:
     def add_rpc_handler(self, method: str, handler: EventHandler) -> None:
         self._rpc_event_handlers.setdefault(method, []).append(handler)
         self._rpc_event_handlers.setdefault(method, []).append(handler)
@@ -141,7 +173,16 @@ class SignaldRPCClient:
             future = self._response_waiters[req_id] = self.loop.create_future()
             future = self._response_waiters[req_id] = self.loop.create_future()
         return future
         return future
 
 
+    async def _abandon_responses(self, unused_data: Dict[str, Any]) -> None:
+        for req_id, waiter in self._response_waiters.items():
+            if not waiter.done():
+                self.log.trace(f"Abandoning response for {req_id}")
+                waiter.set_exception(NotConnected("Disconnected from signald before RPC completed"))
+
     async def _send_request(self, data: Dict[str, Any]) -> None:
     async def _send_request(self, data: Dict[str, Any]) -> None:
+        if self._writer is None:
+            raise NotConnected("Not connected to signald")
+
         self._writer.write(json.dumps(data).encode("utf-8"))
         self._writer.write(json.dumps(data).encode("utf-8"))
         self._writer.write(b"\n")
         self._writer.write(b"\n")
         await self._writer.drain()
         await self._writer.drain()

+ 13 - 2
mausignald/signald.py

@@ -3,13 +3,13 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, TypeVar, Type
+from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
 from uuid import uuid4
 from uuid import uuid4
 import asyncio
 import asyncio
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from .rpc import SignaldRPCClient
+from .rpc import CONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2)
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2)
@@ -20,16 +20,19 @@ EventHandler = Callable[[T], Awaitable[None]]
 
 
 class SignaldClient(SignaldRPCClient):
 class SignaldClient(SignaldRPCClient):
     _event_handlers: Dict[Type[T], List[EventHandler]]
     _event_handlers: Dict[Type[T], List[EventHandler]]
+    _subscriptions: Set[str]
 
 
     def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
     def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
                  log: Optional[TraceLogger] = None,
                  log: Optional[TraceLogger] = None,
                  loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
                  loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
         super().__init__(socket_path, log, loop)
         super().__init__(socket_path, log, loop)
         self._event_handlers = {}
         self._event_handlers = {}
+        self._subscriptions = set()
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("listen_started", self._parse_listen_start)
         self.add_rpc_handler("listen_started", self._parse_listen_start)
         self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
         self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler("version", self._log_version)
+        self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
 
 
     def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
     def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
         self._event_handlers.setdefault(event_class, []).append(handler)
         self._event_handlers.setdefault(event_class, []).append(handler)
@@ -75,6 +78,7 @@ class SignaldClient(SignaldRPCClient):
     async def subscribe(self, username: str) -> bool:
     async def subscribe(self, username: str) -> bool:
         try:
         try:
             await self.request("subscribe", "subscribed", username=username)
             await self.request("subscribe", "subscribed", username=username)
+            self._subscriptions.add(username)
             return True
             return True
         except UnexpectedError as e:
         except UnexpectedError as e:
             self.log.debug("Failed to subscribe to %s: %s", username, e)
             self.log.debug("Failed to subscribe to %s: %s", username, e)
@@ -83,11 +87,18 @@ class SignaldClient(SignaldRPCClient):
     async def unsubscribe(self, username: str) -> bool:
     async def unsubscribe(self, username: str) -> bool:
         try:
         try:
             await self.request("unsubscribe", "unsubscribed", username=username)
             await self.request("unsubscribe", "unsubscribed", username=username)
+            self._subscriptions.remove(username)
             return True
             return True
         except UnexpectedError as e:
         except UnexpectedError as e:
             self.log.debug("Failed to unsubscribe from %s: %s", username, e)
             self.log.debug("Failed to unsubscribe from %s: %s", username, e)
             return False
             return False
 
 
+    async def _resubscribe(self, unused_data: Dict[str, Any]) -> None:
+        if self._subscriptions:
+            self.log.debug("Resubscribing to users")
+            for username in list(self._subscriptions):
+                await self.subscribe(username)
+
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
                        ) -> str:
                        ) -> str:
         resp = await self.request("register", "verification_required", username=phone,
         resp = await self.request("register", "verification_required", username=phone,