Browse Source

Utilize signald's new websocket connection state reporting

Sumner Evans 3 năm trước cách đây
mục cha
commit
a057376ec3
4 tập tin đã thay đổi với 97 bổ sung46 xóa
  1. 18 14
      mausignald/signald.py
  2. 15 7
      mausignald/types.py
  3. 7 5
      mautrix_signal/signal.py
  4. 57 20
      mautrix_signal/user.py

+ 18 - 14
mausignald/signald.py

@@ -11,8 +11,8 @@ from mautrix.util.logging import TraceLogger
 from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
-                    Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
-                    Mention, LinkSession)
+                    Profile, GroupID, GetIdentitiesResponse, GroupV2, Mention, LinkSession,
+                    WebsocketConnectionState, WebsocketConnectionStateChangeEvent)
 
 T = TypeVar('T')
 EventHandler = Callable[[T], Awaitable[None]]
@@ -29,8 +29,8 @@ class SignaldClient(SignaldRPCClient):
         self._event_handlers = {}
         self._subscriptions = set()
         self.add_rpc_handler("message", self._parse_message)
-        self.add_rpc_handler("listen_started", self._parse_listen_start)
-        self.add_rpc_handler("listener_stopped", self._parse_listener_stop)
+        self.add_rpc_handler("websocket_connection_state_change",
+                             self._websocket_connection_state_change)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
         self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
@@ -67,13 +67,11 @@ class SignaldClient(SignaldRPCClient):
         version = data["data"]["version"]
         self.log.info(f"Connected to {name} v{version}")
 
-    async def _parse_listen_start(self, data: Dict[str, Any]) -> None:
-        evt = ListenEvent(action=ListenAction.STARTED, username=data["data"])
-        await self._run_event_handler(evt)
-
-    async def _parse_listener_stop(self, data: Dict[str, Any]) -> None:
-        evt = ListenEvent(action=ListenAction.STOPPED, username=data["data"],
-                          exception=data.get("exception", None))
+    async def _websocket_connection_state_change(self, change_event: Dict[str, Any]) -> None:
+        evt = WebsocketConnectionStateChangeEvent(
+            state=WebsocketConnectionState.deserialize(change_event["data"]["state"]),
+            account=change_event["data"]["account"],
+        )
         await self._run_event_handler(evt)
 
     async def subscribe(self, username: str) -> bool:
@@ -83,7 +81,10 @@ class SignaldClient(SignaldRPCClient):
             return True
         except UnexpectedError as e:
             self.log.debug("Failed to subscribe to %s: %s", username, e)
-            evt = ListenEvent(action=ListenAction.STOPPED, username=username, exception=e)
+            evt = WebsocketConnectionStateChangeEvent(
+                state=WebsocketConnectionState.DISCONNECTED,
+                account=username,
+            )
             await self._run_event_handler(evt)
             return False
 
@@ -106,8 +107,11 @@ class SignaldClient(SignaldRPCClient):
         if self._subscriptions:
             self.log.debug("Notifying of disconnection from users")
             for username in self._subscriptions:
-                evt = ListenEvent(action=ListenAction.SOCKET_DISCONNECTED, username=username,
-                                  exception="Disconnected from signald")
+                evt = WebsocketConnectionStateChangeEvent(
+                    state=WebsocketConnectionState.SOCKET_DISCONNECTED,
+                    account=username,
+                    exception="Disconnected from signald"
+                )
                 await self._run_event_handler(evt)
 
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None

+ 15 - 7
mausignald/types.py

@@ -410,14 +410,22 @@ class Message(SerializableAttrs):
     receipt: Optional[Receipt] = None
 
 
-class ListenAction(SerializableEnum):
-    STARTED = "started"
-    STOPPED = "stopped"
-    SOCKET_DISCONNECTED = "socket-disconnected"
+class WebsocketConnectionState(SerializableEnum):
+    # States from signald itself
+    DISCONNECTED = "DISCONNECTED"
+    CONNECTING = "CONNECTING"
+    CONNECTED = "CONNECTED"
+    RECONNECTING = "RECONNECTING"
+    DISCONNECTING = "DISCONNECTING"
+    AUTHENTICATION_FAILED = "AUTHENTICATION_FAILED"
+    FAILED = "FAILED"
+
+    # Socket disconnect state
+    SOCKET_DISCONNECTED = "SOCKET_DISCONNECTED"
 
 
 @dataclass
-class ListenEvent(SerializableAttrs):
-    action: ListenAction
-    username: str
+class WebsocketConnectionStateChangeEvent(SerializableAttrs):
+    state: WebsocketConnectionState
+    account: str
     exception: Optional[str] = None

+ 7 - 5
mautrix_signal/signal.py

@@ -19,7 +19,8 @@ import logging
 
 from mausignald import SignaldClient
 from mausignald.types import (Message, MessageData, Address, TypingNotification, TypingAction,
-                              OwnReadReceipt, Receipt, ReceiptType, ListenEvent)
+                              OwnReadReceipt, Receipt, ReceiptType,
+                              WebsocketConnectionStateChangeEvent)
 from mautrix.util.logging import TraceLogger
 
 from .db import Message as DBMessage
@@ -43,7 +44,8 @@ class SignalHandler(SignaldClient):
         self.data_dir = bridge.config["signal.data_dir"]
         self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
         self.add_event_handler(Message, self.on_message)
-        self.add_event_handler(ListenEvent, self.on_listen)
+        self.add_event_handler(WebsocketConnectionStateChangeEvent,
+                               self.on_websocket_connection_state_change)
 
     async def on_message(self, evt: Message) -> None:
         sender = await pu.Puppet.get_by_address(evt.source)
@@ -73,9 +75,9 @@ class SignalHandler(SignaldClient):
                 await user.sync_groups()
 
     @staticmethod
-    async def on_listen(evt: ListenEvent) -> None:
-        user = await u.User.get_by_username(evt.username)
-        user.on_listen(evt)
+    async def on_websocket_connection_state_change(evt: WebsocketConnectionStateChangeEvent) -> None:
+        user = await u.User.get_by_username(evt.account)
+        user.on_websocket_connection_state_change(evt)
 
     async def handle_message(self, user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
                              addr_override: Optional[Address] = None) -> None:

+ 57 - 20
mautrix_signal/user.py

@@ -13,11 +13,14 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from asyncio.tasks import sleep
+from datetime import datetime
 from typing import Union, Dict, Optional, AsyncGenerator, List, TYPE_CHECKING, cast
 from uuid import UUID
 import asyncio
 
-from mausignald.types import Account, Address, Profile, Group, GroupV2, ListenEvent, ListenAction
+from mausignald.types import (Account, Address, Profile, Group, GroupV2, WebsocketConnectionState,
+                              WebsocketConnectionStateChangeEvent)
 from mautrix.bridge import BaseUser, BridgeState, AutologinError, async_getter_lock
 from mautrix.types import UserID, RoomID
 from mautrix.util.bridge_state import BridgeStateEvent
@@ -56,6 +59,8 @@ class User(DBUser, BaseUser):
     _sync_lock: asyncio.Lock
     _notice_room_lock: asyncio.Lock
     _connected: bool
+    _websocket_connection_state: Optional[WebsocketConnectionState]
+    _latest_non_transient_disconnect_state: Optional[datetime]
 
     def __init__(self, mxid: UserID, username: Optional[str] = None, uuid: Optional[UUID] = None,
                  notice_room: Optional[RoomID] = None) -> None:
@@ -64,6 +69,7 @@ class User(DBUser, BaseUser):
         self._notice_room_lock = asyncio.Lock()
         self._sync_lock = asyncio.Lock()
         self._connected = False
+        self._websocket_connection_state = None
         perms = self.config.get_permissions(mxid)
         self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
 
@@ -133,32 +139,63 @@ class User(DBUser, BaseUser):
         asyncio.create_task(self.sync())
         self._track_metric(METRIC_LOGGED_IN, True)
 
-    def on_listen(self, evt: ListenEvent) -> None:
-        if evt.action == ListenAction.STARTED:
+    def on_websocket_connection_state_change(self, evt: WebsocketConnectionStateChangeEvent) -> None:
+        if evt.state == WebsocketConnectionState.CONNECTED:
             self.log.info("Connected to Signal")
             self._track_metric(METRIC_CONNECTED, True)
             self._track_metric(METRIC_LOGGED_IN, True)
             self._connected = True
-            asyncio.create_task(self.push_bridge_state(BridgeStateEvent.CONNECTED))
-        elif evt.action in (ListenAction.SOCKET_DISCONNECTED, ListenAction.STOPPED):
-            if evt.exception:
-                self.log.warning(f"Disconnected from Signal: {evt.exception}")
-            else:
-                self.log.info("Disconnected from Signal")
+        else:
+            self.log.warning(
+                f"New websocket state from signald: {evt.state}. Error: {evt.exception}")
             self._track_metric(METRIC_CONNECTED, False)
-            asyncio.create_task(
-                self.push_bridge_state(
-                    (
-                        BridgeStateEvent.TRANSIENT_DISCONNECT
-                        if evt.action == ListenAction.SOCKET_DISCONNECTED
-                        else BridgeStateEvent.UNKNOWN_ERROR
-                    ),
-                    error=str(evt.exception),
-                )
-            )
             self._connected = False
+
+        bridge_state = {
+            # Signald disconnected
+            WebsocketConnectionState.SOCKET_DISCONNECTED: BridgeStateEvent.TRANSIENT_DISCONNECT,
+
+            # Websocket state reported by signald
+            WebsocketConnectionState.DISCONNECTED: (
+                None
+                if self._websocket_connection_state == BridgeStateEvent.BAD_CREDENTIALS
+                else BridgeStateEvent.TRANSIENT_DISCONNECT
+            ),
+            WebsocketConnectionState.CONNECTING: BridgeStateEvent.CONNECTING,
+            WebsocketConnectionState.CONNECTED: BridgeStateEvent.CONNECTED,
+            WebsocketConnectionState.RECONNECTING: BridgeStateEvent.TRANSIENT_DISCONNECT,
+            WebsocketConnectionState.DISCONNECTING: BridgeStateEvent.TRANSIENT_DISCONNECT,
+            WebsocketConnectionState.AUTHENTICATION_FAILED: BridgeStateEvent.BAD_CREDENTIALS,
+            WebsocketConnectionState.FAILED: BridgeStateEvent.UNKNOWN_ERROR,
+        }.get(evt.state)
+        if bridge_state:
+            asyncio.create_task(self.push_bridge_state(bridge_state))
+
+            now = datetime.now()
+            if bridge_state == BridgeStateEvent.TRANSIENT_DISCONNECT:
+                # Wait for two minutes. if the bridge stays in TRANSIENT_DISCONNECT for that long,
+                # something terrible has happened (signald failed to restart, the internet broke,
+                # etc.)
+                async def wait_report_disconnected():
+                    await sleep(120)
+                    if (
+                        self._latest_non_transient_disconnect_state
+                        and now > self._latest_non_transient_disconnect_state
+                    ):
+                        asyncio.create_task(self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR))
+                    else:
+                        self.log.info(
+                            "New state since last TRANSIENT_DISCONNECT push. "
+                            "Not transitioning to UNKNOWN_ERROR."
+                        )
+
+                asyncio.create_task(wait_report_disconnected())
+            else:
+                self._latest_non_transient_disconnect_state = now
+
+            self._websocket_connection_state = bridge_state
         else:
-            self.log.warning(f"Unrecognized listen action {evt.action}")
+            self.log.info(f"Websocket state {evt.state} seen. Will not report new Bridge State")
 
     async def _sync_puppet(self) -> None:
         puppet = await pu.Puppet.get_by_address(self.address)