瀏覽代碼

Merge remote-tracking branch 'satwell/signald-reconnect'

Tulir Asokan 4 年之前
父節點
當前提交
febfb9ac5e
共有 3 個文件被更改,包括 66 次插入10 次删除
  1. 4 0
      mausignald/errors.py
  2. 49 8
      mausignald/rpc.py
  3. 13 2
      mausignald/signald.py

+ 4 - 0
mausignald/errors.py

@@ -27,6 +27,10 @@ class LinkingError(RPCError):
         self.number = number
 
 
+class NotConnected(RPCError):
+    pass
+
+
 class LinkingTimeout(LinkingError):
     pass
 

+ 49 - 8
mausignald/rpc.py

@@ -11,10 +11,15 @@ import json
 
 from mautrix.util.logging import TraceLogger
 
-from .errors import UnexpectedError, UnexpectedResponse
+from .errors import NotConnected, UnexpectedError, UnexpectedResponse
 
 EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
 
+# These are synthetic RPC events for registering callbacks on socket
+# connect and disconnect.
+CONNECT_EVENT = "_socket_connected"
+DISCONNECT_EVENT = "_socket_disconnected"
+
 
 class SignaldRPCClient:
     loop: asyncio.AbstractEventLoop
@@ -23,6 +28,7 @@ class SignaldRPCClient:
     socket_path: str
     _reader: Optional[asyncio.StreamReader]
     _writer: Optional[asyncio.StreamWriter]
+    _communicate_task: Optional[asyncio.Task]
 
     _response_waiters: Dict[UUID, asyncio.Future]
     _rpc_event_handlers: Dict[str, List[EventHandler]]
@@ -34,21 +40,47 @@ class SignaldRPCClient:
         self.loop = loop or asyncio.get_event_loop()
         self._reader = None
         self._writer = None
+        self._communicate_task = None
         self._response_waiters = {}
-        self._rpc_event_handlers = {}
+        self._rpc_event_handlers = {CONNECT_EVENT: [], DISCONNECT_EVENT: []}
+        self.add_rpc_handler(DISCONNECT_EVENT, self._abandon_responses)
 
     async def connect(self) -> None:
         if self._writer is not None:
             return
 
-        self._reader, self._writer = await asyncio.open_unix_connection(self.socket_path)
-        self.loop.create_task(self._try_read_loop())
+        initial_connect = self.loop.create_future()
+        self._communicate_task = self.loop.create_task(self._communicate_forever(initial_connect))
+        await initial_connect
+
+    async def _communicate_forever(self, initial_connect: Optional[asyncio.Future] = None) -> None:
+        while True:
+            try:
+                self._reader, self._writer = await asyncio.open_unix_connection(self.socket_path)
+            except OSError as e:
+                self.log.error(f"Connection to {self.socket_path} failed: {e}")
+                await asyncio.sleep(5)
+                continue
+
+            read_loop = self.loop.create_task(self._try_read_loop())
+            await self._run_rpc_handler(CONNECT_EVENT, {})
+
+            if initial_connect:
+                initial_connect.set_result(True)
+                initial_connect = None
+
+            await read_loop
+            await self._run_rpc_handler(DISCONNECT_EVENT, {})
 
     async def disconnect(self) -> None:
-        self._writer.write_eof()
-        await self._writer.drain()
-        self._writer = None
-        self._reader = None
+        if self._writer is not None:
+            self._writer.write_eof()
+            await self._writer.drain()
+            if self._communicate_task:
+                self._communicate_task.cancel()
+                self._communicate_task = None
+            self._writer = None
+            self._reader = None
 
     def add_rpc_handler(self, method: str, handler: EventHandler) -> None:
         self._rpc_event_handlers.setdefault(method, []).append(handler)
@@ -141,7 +173,16 @@ class SignaldRPCClient:
             future = self._response_waiters[req_id] = self.loop.create_future()
         return future
 
+    async def _abandon_responses(self, unused_data: Dict[str, Any]) -> None:
+        for req_id, waiter in self._response_waiters.items():
+            if not waiter.done():
+                self.log.trace(f"Abandoning response for {req_id}")
+                waiter.set_exception(NotConnected("Disconnected from signald before RPC completed"))
+
     async def _send_request(self, data: Dict[str, Any]) -> None:
+        if self._writer is None:
+            raise NotConnected("Not connected to signald")
+
         self._writer.write(json.dumps(data).encode("utf-8"))
         self._writer.write(b"\n")
         await self._writer.drain()

+ 13 - 2
mausignald/signald.py

@@ -3,13 +3,13 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, TypeVar, Type
+from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
 from uuid import uuid4
 import asyncio
 
 from mautrix.util.logging import TraceLogger
 
-from .rpc import SignaldRPCClient
+from .rpc import CONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2)
@@ -20,16 +20,19 @@ EventHandler = Callable[[T], Awaitable[None]]
 
 class SignaldClient(SignaldRPCClient):
     _event_handlers: Dict[Type[T], List[EventHandler]]
+    _subscriptions: Set[str]
 
     def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
                  log: Optional[TraceLogger] = None,
                  loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
         super().__init__(socket_path, log, loop)
         self._event_handlers = {}
+        self._subscriptions = set()
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("listen_started", self._parse_listen_start)
         self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
         self.add_rpc_handler("version", self._log_version)
+        self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
 
     def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
         self._event_handlers.setdefault(event_class, []).append(handler)
@@ -75,6 +78,7 @@ class SignaldClient(SignaldRPCClient):
     async def subscribe(self, username: str) -> bool:
         try:
             await self.request("subscribe", "subscribed", username=username)
+            self._subscriptions.add(username)
             return True
         except UnexpectedError as e:
             self.log.debug("Failed to subscribe to %s: %s", username, e)
@@ -83,11 +87,18 @@ class SignaldClient(SignaldRPCClient):
     async def unsubscribe(self, username: str) -> bool:
         try:
             await self.request("unsubscribe", "unsubscribed", username=username)
+            self._subscriptions.remove(username)
             return True
         except UnexpectedError as e:
             self.log.debug("Failed to unsubscribe from %s: %s", username, e)
             return False
 
+    async def _resubscribe(self, unused_data: Dict[str, Any]) -> None:
+        if self._subscriptions:
+            self.log.debug("Resubscribing to users")
+            for username in list(self._subscriptions):
+                await self.subscribe(username)
+
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
                        ) -> str:
         resp = await self.request("register", "verification_required", username=phone,