Browse Source

bridge state: send TRANSIENT_DISCONNECT when disconnected from signald socket

Sumner Evans 3 years ago
parent
commit
385039e8e9
3 changed files with 22 additions and 3 deletions
  1. 10 1
      mausignald/signald.py
  2. 1 0
      mausignald/types.py
  3. 11 2
      mautrix_signal/user.py

+ 10 - 1
mausignald/signald.py

@@ -8,7 +8,7 @@ import asyncio
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from .rpc import CONNECT_EVENT, SignaldRPCClient
+from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse
 from .errors import UnexpectedError, UnexpectedResponse
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
@@ -33,6 +33,7 @@ class SignaldClient(SignaldRPCClient):
         self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
         self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
+        self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
 
 
     def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
     def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
         self._event_handlers.setdefault(event_class, []).append(handler)
         self._event_handlers.setdefault(event_class, []).append(handler)
@@ -99,6 +100,14 @@ class SignaldClient(SignaldRPCClient):
             for username in list(self._subscriptions):
             for username in list(self._subscriptions):
                 await self.subscribe(username)
                 await self.subscribe(username)
 
 
+    async def _on_disconnect(self, *_) -> None:
+        if self._subscriptions:
+            self.log.debug("Notifying of disconnection from users")
+            for username in self._subscriptions:
+                evt = ListenEvent(action=ListenAction.SOCKET_DISCONNECTED, username=username,
+                                  exception="Disconnected from signald")
+                await self._run_event_handler(evt)
+
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
                        ) -> str:
                        ) -> str:
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)

+ 1 - 0
mausignald/types.py

@@ -413,6 +413,7 @@ class Message(SerializableAttrs):
 class ListenAction(SerializableEnum):
 class ListenAction(SerializableEnum):
     STARTED = "started"
     STARTED = "started"
     STOPPED = "stopped"
     STOPPED = "stopped"
+    SOCKET_DISCONNECTED = "socket-disconnected"
 
 
 
 
 @dataclass
 @dataclass

+ 11 - 2
mautrix_signal/user.py

@@ -140,13 +140,22 @@ class User(DBUser, BaseUser):
             self._track_metric(METRIC_LOGGED_IN, True)
             self._track_metric(METRIC_LOGGED_IN, True)
             self._connected = True
             self._connected = True
             asyncio.create_task(self.push_bridge_state(BridgeStateEvent.CONNECTED))
             asyncio.create_task(self.push_bridge_state(BridgeStateEvent.CONNECTED))
-        elif evt.action == ListenAction.STOPPED:
+        elif evt.action in (ListenAction.SOCKET_DISCONNECTED, ListenAction.STOPPED):
             if evt.exception:
             if evt.exception:
                 self.log.warning(f"Disconnected from Signal: {evt.exception}")
                 self.log.warning(f"Disconnected from Signal: {evt.exception}")
             else:
             else:
                 self.log.info("Disconnected from Signal")
                 self.log.info("Disconnected from Signal")
             self._track_metric(METRIC_CONNECTED, False)
             self._track_metric(METRIC_CONNECTED, False)
-            asyncio.create_task(self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR))
+            asyncio.create_task(
+                self.push_bridge_state(
+                    (
+                        BridgeStateEvent.TRANSIENT_DISCONNECT
+                        if evt.action == ListenAction.SOCKET_DISCONNECTED
+                        else BridgeStateEvent.UNKNOWN_ERROR
+                    ),
+                    error=str(evt.exception),
+                )
+            )
             self._connected = False
             self._connected = False
         else:
         else:
             self.log.warning(f"Unrecognized listen action {evt.action}")
             self.log.warning(f"Unrecognized listen action {evt.action}")