Bläddra i källkod

Merge pull request #137 from mautrix/utilize-websocket-connection-state

Utilize signald's new websocket connection state reporting
Sumner Evans 3 år sedan
förälder
incheckning
a33bf74c99
5 ändrade filer med 101 tillägg och 47 borttagningar
  1. 15 14
      mausignald/signald.py
  2. 15 7
      mausignald/types.py
  3. 7 1
      mautrix_signal/portal.py
  4. 7 5
      mautrix_signal/signal.py
  5. 57 20
      mautrix_signal/user.py

+ 15 - 14
mausignald/signald.py

@@ -11,8 +11,8 @@ from mautrix.util.logging import TraceLogger
 from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
 from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse
 from .errors import UnexpectedError, UnexpectedResponse
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
-                    Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
-                    Mention, LinkSession)
+                    Profile, GroupID, GetIdentitiesResponse, GroupV2, Mention, LinkSession,
+                    WebsocketConnectionState, WebsocketConnectionStateChangeEvent)
 
 
 T = TypeVar('T')
 T = TypeVar('T')
 EventHandler = Callable[[T], Awaitable[None]]
 EventHandler = Callable[[T], Awaitable[None]]
@@ -29,8 +29,8 @@ class SignaldClient(SignaldRPCClient):
         self._event_handlers = {}
         self._event_handlers = {}
         self._subscriptions = set()
         self._subscriptions = set()
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("message", self._parse_message)
-        self.add_rpc_handler("listen_started", self._parse_listen_start)
-        self.add_rpc_handler("listener_stopped", self._parse_listener_stop)
+        self.add_rpc_handler("websocket_connection_state_change",
+                             self._websocket_connection_state_change)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
         self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
         self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
@@ -67,13 +67,8 @@ class SignaldClient(SignaldRPCClient):
         version = data["data"]["version"]
         version = data["data"]["version"]
         self.log.info(f"Connected to {name} v{version}")
         self.log.info(f"Connected to {name} v{version}")
 
 
-    async def _parse_listen_start(self, data: Dict[str, Any]) -> None:
-        evt = ListenEvent(action=ListenAction.STARTED, username=data["data"])
-        await self._run_event_handler(evt)
-
-    async def _parse_listener_stop(self, data: Dict[str, Any]) -> None:
-        evt = ListenEvent(action=ListenAction.STOPPED, username=data["data"],
-                          exception=data.get("exception", None))
+    async def _websocket_connection_state_change(self, change_event: Dict[str, Any]) -> None:
+        evt = WebsocketConnectionStateChangeEvent.deserialize(change_event["data"])
         await self._run_event_handler(evt)
         await self._run_event_handler(evt)
 
 
     async def subscribe(self, username: str) -> bool:
     async def subscribe(self, username: str) -> bool:
@@ -83,7 +78,10 @@ class SignaldClient(SignaldRPCClient):
             return True
             return True
         except UnexpectedError as e:
         except UnexpectedError as e:
             self.log.debug("Failed to subscribe to %s: %s", username, e)
             self.log.debug("Failed to subscribe to %s: %s", username, e)
-            evt = ListenEvent(action=ListenAction.STOPPED, username=username, exception=e)
+            evt = WebsocketConnectionStateChangeEvent(
+                state=WebsocketConnectionState.DISCONNECTED,
+                account=username,
+            )
             await self._run_event_handler(evt)
             await self._run_event_handler(evt)
             return False
             return False
 
 
@@ -106,8 +104,11 @@ class SignaldClient(SignaldRPCClient):
         if self._subscriptions:
         if self._subscriptions:
             self.log.debug("Notifying of disconnection from users")
             self.log.debug("Notifying of disconnection from users")
             for username in self._subscriptions:
             for username in self._subscriptions:
-                evt = ListenEvent(action=ListenAction.SOCKET_DISCONNECTED, username=username,
-                                  exception="Disconnected from signald")
+                evt = WebsocketConnectionStateChangeEvent(
+                    state=WebsocketConnectionState.SOCKET_DISCONNECTED,
+                    account=username,
+                    exception="Disconnected from signald"
+                )
                 await self._run_event_handler(evt)
                 await self._run_event_handler(evt)
 
 
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None

+ 15 - 7
mausignald/types.py

@@ -410,14 +410,22 @@ class Message(SerializableAttrs):
     receipt: Optional[Receipt] = None
     receipt: Optional[Receipt] = None
 
 
 
 
-class ListenAction(SerializableEnum):
-    STARTED = "started"
-    STOPPED = "stopped"
-    SOCKET_DISCONNECTED = "socket-disconnected"
+class WebsocketConnectionState(SerializableEnum):
+    # States from signald itself
+    DISCONNECTED = "DISCONNECTED"
+    CONNECTING = "CONNECTING"
+    CONNECTED = "CONNECTED"
+    RECONNECTING = "RECONNECTING"
+    DISCONNECTING = "DISCONNECTING"
+    AUTHENTICATION_FAILED = "AUTHENTICATION_FAILED"
+    FAILED = "FAILED"
+
+    # Socket disconnect state
+    SOCKET_DISCONNECTED = "SOCKET_DISCONNECTED"
 
 
 
 
 @dataclass
 @dataclass
-class ListenEvent(SerializableAttrs):
-    action: ListenAction
-    username: str
+class WebsocketConnectionStateChangeEvent(SerializableAttrs):
+    state: WebsocketConnectionState
+    account: str
     exception: Optional[str] = None
     exception: Optional[str] = None

+ 7 - 1
mautrix_signal/portal.py

@@ -13,6 +13,7 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from mautrix.util.bridge_state import BridgeStateEvent
 from typing import (Dict, Tuple, Optional, List, Deque, Any, Union, AsyncGenerator, Awaitable, Set,
 from typing import (Dict, Tuple, Optional, List, Deque, Any, Union, AsyncGenerator, Awaitable, Set,
                     Callable, TYPE_CHECKING, cast)
                     Callable, TYPE_CHECKING, cast)
 from html import escape as escape_html
 from html import escape as escape_html
@@ -30,7 +31,7 @@ import os
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
                               Attachment, GroupID, GroupV2ID, GroupV2, Mention, Sticker,
                               Attachment, GroupID, GroupV2ID, GroupV2, Mention, Sticker,
                               GroupAccessControl, AccessControlMode, GroupMemberRole)
                               GroupAccessControl, AccessControlMode, GroupMemberRole)
-from mausignald.errors import RPCError
+from mausignald.errors import AuthorizationFailedException, RPCError, ResponseError
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, Format,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, Format,
@@ -291,6 +292,11 @@ class Portal(DBPortal, BasePortal):
                                    mentions=mentions, quote=quote, attachments=attachments,
                                    mentions=mentions, quote=quote, attachments=attachments,
                                    timestamp=request_id)
                                    timestamp=request_id)
         except Exception as e:
         except Exception as e:
+            auth_failed = (
+                "org.whispersystems.signalservice.api.push.exceptions.AuthorizationFailedException"
+            )
+            if isinstance(e, ResponseError) and auth_failed in e.data.get("exceptions"):
+                await sender.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error=str(e))
             await self._send_message(
             await self._send_message(
                 self.main_intent,
                 self.main_intent,
                 TextMessageEventContent(
                 TextMessageEventContent(

+ 7 - 5
mautrix_signal/signal.py

@@ -19,7 +19,8 @@ import logging
 
 
 from mausignald import SignaldClient
 from mausignald import SignaldClient
 from mausignald.types import (Message, MessageData, Address, TypingNotification, TypingAction,
 from mausignald.types import (Message, MessageData, Address, TypingNotification, TypingAction,
-                              OwnReadReceipt, Receipt, ReceiptType, ListenEvent)
+                              OwnReadReceipt, Receipt, ReceiptType,
+                              WebsocketConnectionStateChangeEvent)
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
 from .db import Message as DBMessage
 from .db import Message as DBMessage
@@ -43,7 +44,8 @@ class SignalHandler(SignaldClient):
         self.data_dir = bridge.config["signal.data_dir"]
         self.data_dir = bridge.config["signal.data_dir"]
         self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
         self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
         self.add_event_handler(Message, self.on_message)
         self.add_event_handler(Message, self.on_message)
-        self.add_event_handler(ListenEvent, self.on_listen)
+        self.add_event_handler(WebsocketConnectionStateChangeEvent,
+                               self.on_websocket_connection_state_change)
 
 
     async def on_message(self, evt: Message) -> None:
     async def on_message(self, evt: Message) -> None:
         sender = await pu.Puppet.get_by_address(evt.source)
         sender = await pu.Puppet.get_by_address(evt.source)
@@ -73,9 +75,9 @@ class SignalHandler(SignaldClient):
                 await user.sync_groups()
                 await user.sync_groups()
 
 
     @staticmethod
     @staticmethod
-    async def on_listen(evt: ListenEvent) -> None:
-        user = await u.User.get_by_username(evt.username)
-        user.on_listen(evt)
+    async def on_websocket_connection_state_change(evt: WebsocketConnectionStateChangeEvent) -> None:
+        user = await u.User.get_by_username(evt.account)
+        user.on_websocket_connection_state_change(evt)
 
 
     async def handle_message(self, user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
     async def handle_message(self, user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
                              addr_override: Optional[Address] = None) -> None:
                              addr_override: Optional[Address] = None) -> None:

+ 57 - 20
mautrix_signal/user.py

@@ -13,11 +13,14 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from asyncio.tasks import sleep
+from datetime import datetime
 from typing import Union, Dict, Optional, AsyncGenerator, List, TYPE_CHECKING, cast
 from typing import Union, Dict, Optional, AsyncGenerator, List, TYPE_CHECKING, cast
 from uuid import UUID
 from uuid import UUID
 import asyncio
 import asyncio
 
 
-from mausignald.types import Account, Address, Profile, Group, GroupV2, ListenEvent, ListenAction
+from mausignald.types import (Account, Address, Profile, Group, GroupV2, WebsocketConnectionState,
+                              WebsocketConnectionStateChangeEvent)
 from mautrix.bridge import BaseUser, BridgeState, AutologinError, async_getter_lock
 from mautrix.bridge import BaseUser, BridgeState, AutologinError, async_getter_lock
 from mautrix.types import UserID, RoomID
 from mautrix.types import UserID, RoomID
 from mautrix.util.bridge_state import BridgeStateEvent
 from mautrix.util.bridge_state import BridgeStateEvent
@@ -56,6 +59,8 @@ class User(DBUser, BaseUser):
     _sync_lock: asyncio.Lock
     _sync_lock: asyncio.Lock
     _notice_room_lock: asyncio.Lock
     _notice_room_lock: asyncio.Lock
     _connected: bool
     _connected: bool
+    _websocket_connection_state: Optional[WebsocketConnectionState]
+    _latest_non_transient_disconnect_state: Optional[datetime]
 
 
     def __init__(self, mxid: UserID, username: Optional[str] = None, uuid: Optional[UUID] = None,
     def __init__(self, mxid: UserID, username: Optional[str] = None, uuid: Optional[UUID] = None,
                  notice_room: Optional[RoomID] = None) -> None:
                  notice_room: Optional[RoomID] = None) -> None:
@@ -64,6 +69,7 @@ class User(DBUser, BaseUser):
         self._notice_room_lock = asyncio.Lock()
         self._notice_room_lock = asyncio.Lock()
         self._sync_lock = asyncio.Lock()
         self._sync_lock = asyncio.Lock()
         self._connected = False
         self._connected = False
+        self._websocket_connection_state = None
         perms = self.config.get_permissions(mxid)
         perms = self.config.get_permissions(mxid)
         self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
         self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
 
 
@@ -133,32 +139,63 @@ class User(DBUser, BaseUser):
         asyncio.create_task(self.sync())
         asyncio.create_task(self.sync())
         self._track_metric(METRIC_LOGGED_IN, True)
         self._track_metric(METRIC_LOGGED_IN, True)
 
 
-    def on_listen(self, evt: ListenEvent) -> None:
-        if evt.action == ListenAction.STARTED:
+    def on_websocket_connection_state_change(self, evt: WebsocketConnectionStateChangeEvent) -> None:
+        if evt.state == WebsocketConnectionState.CONNECTED:
             self.log.info("Connected to Signal")
             self.log.info("Connected to Signal")
             self._track_metric(METRIC_CONNECTED, True)
             self._track_metric(METRIC_CONNECTED, True)
             self._track_metric(METRIC_LOGGED_IN, True)
             self._track_metric(METRIC_LOGGED_IN, True)
             self._connected = True
             self._connected = True
-            asyncio.create_task(self.push_bridge_state(BridgeStateEvent.CONNECTED))
-        elif evt.action in (ListenAction.SOCKET_DISCONNECTED, ListenAction.STOPPED):
-            if evt.exception:
-                self.log.warning(f"Disconnected from Signal: {evt.exception}")
-            else:
-                self.log.info("Disconnected from Signal")
+        else:
+            self.log.warning(
+                f"New websocket state from signald: {evt.state}. Error: {evt.exception}")
             self._track_metric(METRIC_CONNECTED, False)
             self._track_metric(METRIC_CONNECTED, False)
-            asyncio.create_task(
-                self.push_bridge_state(
-                    (
-                        BridgeStateEvent.TRANSIENT_DISCONNECT
-                        if evt.action == ListenAction.SOCKET_DISCONNECTED
-                        else BridgeStateEvent.UNKNOWN_ERROR
-                    ),
-                    error=str(evt.exception),
-                )
-            )
             self._connected = False
             self._connected = False
+
+        bridge_state = {
+            # Signald disconnected
+            WebsocketConnectionState.SOCKET_DISCONNECTED: BridgeStateEvent.TRANSIENT_DISCONNECT,
+
+            # Websocket state reported by signald
+            WebsocketConnectionState.DISCONNECTED: (
+                None
+                if self._websocket_connection_state == BridgeStateEvent.BAD_CREDENTIALS
+                else BridgeStateEvent.TRANSIENT_DISCONNECT
+            ),
+            WebsocketConnectionState.CONNECTING: BridgeStateEvent.CONNECTING,
+            WebsocketConnectionState.CONNECTED: BridgeStateEvent.CONNECTED,
+            WebsocketConnectionState.RECONNECTING: BridgeStateEvent.TRANSIENT_DISCONNECT,
+            WebsocketConnectionState.DISCONNECTING: BridgeStateEvent.TRANSIENT_DISCONNECT,
+            WebsocketConnectionState.AUTHENTICATION_FAILED: BridgeStateEvent.BAD_CREDENTIALS,
+            WebsocketConnectionState.FAILED: BridgeStateEvent.UNKNOWN_ERROR,
+        }.get(evt.state)
+        if bridge_state:
+            asyncio.create_task(self.push_bridge_state(bridge_state))
+
+            now = datetime.now()
+            if bridge_state == BridgeStateEvent.TRANSIENT_DISCONNECT:
+                # Wait for two minutes. if the bridge stays in TRANSIENT_DISCONNECT for that long,
+                # something terrible has happened (signald failed to restart, the internet broke,
+                # etc.)
+                async def wait_report_disconnected():
+                    await sleep(120)
+                    if (
+                        self._latest_non_transient_disconnect_state
+                        and now > self._latest_non_transient_disconnect_state
+                    ):
+                        asyncio.create_task(self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR))
+                    else:
+                        self.log.info(
+                            "New state since last TRANSIENT_DISCONNECT push. "
+                            "Not transitioning to UNKNOWN_ERROR."
+                        )
+
+                asyncio.create_task(wait_report_disconnected())
+            else:
+                self._latest_non_transient_disconnect_state = now
+
+            self._websocket_connection_state = bridge_state
         else:
         else:
-            self.log.warning(f"Unrecognized listen action {evt.action}")
+            self.log.info(f"Websocket state {evt.state} seen. Will not report new Bridge State")
 
 
     async def _sync_puppet(self) -> None:
     async def _sync_puppet(self) -> None:
         puppet = await pu.Puppet.get_by_address(self.address)
         puppet = await pu.Puppet.get_by_address(self.address)