Pārlūkot izejas kodu

Merge pull request #135 from mautrix/better-disconnect-detection

Better disconnect detection
Sumner Evans 3 gadi atpakaļ
vecāks
revīzija
2fb9380b1c
3 mainītis faili ar 26 papildinājumiem un 5 dzēšanām
  1. 14 3
      mausignald/signald.py
  2. 1 0
      mausignald/types.py
  3. 11 2
      mautrix_signal/user.py

+ 14 - 3
mausignald/signald.py

@@ -8,7 +8,7 @@ import asyncio
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from .rpc import CONNECT_EVENT, SignaldRPCClient
+from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse
 from .errors import UnexpectedError, UnexpectedResponse
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
@@ -30,9 +30,10 @@ class SignaldClient(SignaldRPCClient):
         self._subscriptions = set()
         self._subscriptions = set()
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("listen_started", self._parse_listen_start)
         self.add_rpc_handler("listen_started", self._parse_listen_start)
-        self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
+        self.add_rpc_handler("listener_stopped", self._parse_listener_stop)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
+        self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
 
 
     def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
     def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
         self._event_handlers.setdefault(event_class, []).append(handler)
         self._event_handlers.setdefault(event_class, []).append(handler)
@@ -70,7 +71,7 @@ class SignaldClient(SignaldRPCClient):
         evt = ListenEvent(action=ListenAction.STARTED, username=data["data"])
         evt = ListenEvent(action=ListenAction.STARTED, username=data["data"])
         await self._run_event_handler(evt)
         await self._run_event_handler(evt)
 
 
-    async def _parse_listen_stop(self, data: Dict[str, Any]) -> None:
+    async def _parse_listener_stop(self, data: Dict[str, Any]) -> None:
         evt = ListenEvent(action=ListenAction.STOPPED, username=data["data"],
         evt = ListenEvent(action=ListenAction.STOPPED, username=data["data"],
                           exception=data.get("exception", None))
                           exception=data.get("exception", None))
         await self._run_event_handler(evt)
         await self._run_event_handler(evt)
@@ -82,6 +83,8 @@ class SignaldClient(SignaldRPCClient):
             return True
             return True
         except UnexpectedError as e:
         except UnexpectedError as e:
             self.log.debug("Failed to subscribe to %s: %s", username, e)
             self.log.debug("Failed to subscribe to %s: %s", username, e)
+            evt = ListenEvent(action=ListenAction.STOPPED, username=username, exception=e)
+            await self._run_event_handler(evt)
             return False
             return False
 
 
     async def unsubscribe(self, username: str) -> bool:
     async def unsubscribe(self, username: str) -> bool:
@@ -99,6 +102,14 @@ class SignaldClient(SignaldRPCClient):
             for username in list(self._subscriptions):
             for username in list(self._subscriptions):
                 await self.subscribe(username)
                 await self.subscribe(username)
 
 
+    async def _on_disconnect(self, *_) -> None:
+        if self._subscriptions:
+            self.log.debug("Notifying of disconnection from users")
+            for username in self._subscriptions:
+                evt = ListenEvent(action=ListenAction.SOCKET_DISCONNECTED, username=username,
+                                  exception="Disconnected from signald")
+                await self._run_event_handler(evt)
+
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
                        ) -> str:
                        ) -> str:
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)

+ 1 - 0
mausignald/types.py

@@ -413,6 +413,7 @@ class Message(SerializableAttrs):
 class ListenAction(SerializableEnum):
 class ListenAction(SerializableEnum):
     STARTED = "started"
     STARTED = "started"
     STOPPED = "stopped"
     STOPPED = "stopped"
+    SOCKET_DISCONNECTED = "socket-disconnected"
 
 
 
 
 @dataclass
 @dataclass

+ 11 - 2
mautrix_signal/user.py

@@ -140,13 +140,22 @@ class User(DBUser, BaseUser):
             self._track_metric(METRIC_LOGGED_IN, True)
             self._track_metric(METRIC_LOGGED_IN, True)
             self._connected = True
             self._connected = True
             asyncio.create_task(self.push_bridge_state(BridgeStateEvent.CONNECTED))
             asyncio.create_task(self.push_bridge_state(BridgeStateEvent.CONNECTED))
-        elif evt.action == ListenAction.STOPPED:
+        elif evt.action in (ListenAction.SOCKET_DISCONNECTED, ListenAction.STOPPED):
             if evt.exception:
             if evt.exception:
                 self.log.warning(f"Disconnected from Signal: {evt.exception}")
                 self.log.warning(f"Disconnected from Signal: {evt.exception}")
             else:
             else:
                 self.log.info("Disconnected from Signal")
                 self.log.info("Disconnected from Signal")
             self._track_metric(METRIC_CONNECTED, False)
             self._track_metric(METRIC_CONNECTED, False)
-            asyncio.create_task(self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR))
+            asyncio.create_task(
+                self.push_bridge_state(
+                    (
+                        BridgeStateEvent.TRANSIENT_DISCONNECT
+                        if evt.action == ListenAction.SOCKET_DISCONNECTED
+                        else BridgeStateEvent.UNKNOWN_ERROR
+                    ),
+                    error=str(evt.exception),
+                )
+            )
             self._connected = False
             self._connected = False
         else:
         else:
             self.log.warning(f"Unrecognized listen action {evt.action}")
             self.log.warning(f"Unrecognized listen action {evt.action}")