瀏覽代碼

Merge pull request #135 from mautrix/better-disconnect-detection

Better disconnect detection
Sumner Evans 4 年之前
父節點
當前提交
2fb9380b1c
共有 3 個文件被更改,包括 26 次插入5 次删除
  1. 14 3
      mausignald/signald.py
  2. 1 0
      mausignald/types.py
  3. 11 2
      mautrix_signal/user.py

+ 14 - 3
mausignald/signald.py

@@ -8,7 +8,7 @@ import asyncio
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from .rpc import CONNECT_EVENT, SignaldRPCClient
+from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse
 from .errors import UnexpectedError, UnexpectedResponse
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
@@ -30,9 +30,10 @@ class SignaldClient(SignaldRPCClient):
         self._subscriptions = set()
         self._subscriptions = set()
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("listen_started", self._parse_listen_start)
         self.add_rpc_handler("listen_started", self._parse_listen_start)
-        self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
+        self.add_rpc_handler("listener_stopped", self._parse_listener_stop)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
+        self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
 
 
     def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
     def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
         self._event_handlers.setdefault(event_class, []).append(handler)
         self._event_handlers.setdefault(event_class, []).append(handler)
@@ -70,7 +71,7 @@ class SignaldClient(SignaldRPCClient):
         evt = ListenEvent(action=ListenAction.STARTED, username=data["data"])
         evt = ListenEvent(action=ListenAction.STARTED, username=data["data"])
         await self._run_event_handler(evt)
         await self._run_event_handler(evt)
 
 
-    async def _parse_listen_stop(self, data: Dict[str, Any]) -> None:
+    async def _parse_listener_stop(self, data: Dict[str, Any]) -> None:
         evt = ListenEvent(action=ListenAction.STOPPED, username=data["data"],
         evt = ListenEvent(action=ListenAction.STOPPED, username=data["data"],
                           exception=data.get("exception", None))
                           exception=data.get("exception", None))
         await self._run_event_handler(evt)
         await self._run_event_handler(evt)
@@ -82,6 +83,8 @@ class SignaldClient(SignaldRPCClient):
             return True
             return True
         except UnexpectedError as e:
         except UnexpectedError as e:
             self.log.debug("Failed to subscribe to %s: %s", username, e)
             self.log.debug("Failed to subscribe to %s: %s", username, e)
+            evt = ListenEvent(action=ListenAction.STOPPED, username=username, exception=e)
+            await self._run_event_handler(evt)
             return False
             return False
 
 
     async def unsubscribe(self, username: str) -> bool:
     async def unsubscribe(self, username: str) -> bool:
@@ -99,6 +102,14 @@ class SignaldClient(SignaldRPCClient):
             for username in list(self._subscriptions):
             for username in list(self._subscriptions):
                 await self.subscribe(username)
                 await self.subscribe(username)
 
 
+    async def _on_disconnect(self, *_) -> None:
+        if self._subscriptions:
+            self.log.debug("Notifying of disconnection from users")
+            for username in self._subscriptions:
+                evt = ListenEvent(action=ListenAction.SOCKET_DISCONNECTED, username=username,
+                                  exception="Disconnected from signald")
+                await self._run_event_handler(evt)
+
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
                        ) -> str:
                        ) -> str:
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)

+ 1 - 0
mausignald/types.py

@@ -413,6 +413,7 @@ class Message(SerializableAttrs):
 class ListenAction(SerializableEnum):
 class ListenAction(SerializableEnum):
     STARTED = "started"
     STARTED = "started"
     STOPPED = "stopped"
     STOPPED = "stopped"
+    SOCKET_DISCONNECTED = "socket-disconnected"
 
 
 
 
 @dataclass
 @dataclass

+ 11 - 2
mautrix_signal/user.py

@@ -140,13 +140,22 @@ class User(DBUser, BaseUser):
             self._track_metric(METRIC_LOGGED_IN, True)
             self._track_metric(METRIC_LOGGED_IN, True)
             self._connected = True
             self._connected = True
             asyncio.create_task(self.push_bridge_state(BridgeStateEvent.CONNECTED))
             asyncio.create_task(self.push_bridge_state(BridgeStateEvent.CONNECTED))
-        elif evt.action == ListenAction.STOPPED:
+        elif evt.action in (ListenAction.SOCKET_DISCONNECTED, ListenAction.STOPPED):
             if evt.exception:
             if evt.exception:
                 self.log.warning(f"Disconnected from Signal: {evt.exception}")
                 self.log.warning(f"Disconnected from Signal: {evt.exception}")
             else:
             else:
                 self.log.info("Disconnected from Signal")
                 self.log.info("Disconnected from Signal")
             self._track_metric(METRIC_CONNECTED, False)
             self._track_metric(METRIC_CONNECTED, False)
-            asyncio.create_task(self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR))
+            asyncio.create_task(
+                self.push_bridge_state(
+                    (
+                        BridgeStateEvent.TRANSIENT_DISCONNECT
+                        if evt.action == ListenAction.SOCKET_DISCONNECTED
+                        else BridgeStateEvent.UNKNOWN_ERROR
+                    ),
+                    error=str(evt.exception),
+                )
+            )
             self._connected = False
             self._connected = False
         else:
         else:
             self.log.warning(f"Unrecognized listen action {evt.action}")
             self.log.warning(f"Unrecognized listen action {evt.action}")