Browse Source

inbound error messages: show errors after a delay

Sometimes, signald asks for a resend and in those cases, we don't want to show any error.
Sumner Evans 3 years ago
parent
commit
b62f4c261c
3 changed files with 70 additions and 1 deletions
  1. 7 0
      mausignald/signald.py
  2. 18 0
      mausignald/types.py
  3. 45 1
      mautrix_signal/signal.py

+ 7 - 0
mausignald/signald.py

@@ -18,6 +18,7 @@ from .types import (
     Address,
     Attachment,
     DeviceInfo,
+    ErrorMessage,
     GetIdentitiesResponse,
     Group,
     GroupID,
@@ -50,6 +51,7 @@ class SignaldClient(SignaldRPCClient):
         self._event_handlers = {}
         self._subscriptions = set()
         self.add_rpc_handler("IncomingMessage", self._parse_message)
+        self.add_rpc_handler("ProtocolInvalidMessageError", self._parse_error)
         self.add_rpc_handler("WebSocketConnectionState", self._websocket_connection_state_change)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
@@ -73,6 +75,11 @@ class SignaldClient(SignaldRPCClient):
                 except Exception:
                     self.log.exception("Exception in event handler")
 
+    async def _parse_error(self, data: dict[str, Any]) -> None:
+        if not data.get("error"):
+            return
+        await self._run_event_handler(ErrorMessage.deserialize(data))
+
     async def _parse_message(self, data: dict[str, Any]) -> None:
         event_type = data["type"]
         event_data = data["data"]

+ 18 - 0
mausignald/types.py

@@ -529,6 +529,24 @@ class IncomingMessage(SerializableAttrs):
     receipt_message: Optional[ReceiptMessage] = None
 
 
+@dataclass(kw_only=True)
+class ErrorMessageData(SerializableAttrs):
+    sender: str
+    timestamp: int
+    message: str
+    sender_device: int
+    content_hint: int
+
+
+@dataclass(kw_only=True)
+class ErrorMessage(SerializableAttrs):
+    type: str
+    version: str
+    data: ErrorMessageData
+    error: bool
+    account: str
+
+
 class WebsocketConnectionState(SerializableEnum):
     # States from signald itself
     DISCONNECTED = "DISCONNECTED"

+ 45 - 1
mautrix_signal/signal.py

@@ -22,6 +22,7 @@ import logging
 from mausignald import SignaldClient
 from mausignald.types import (
     Address,
+    ErrorMessage,
     IncomingMessage,
     MessageData,
     OfferMessageType,
@@ -32,7 +33,7 @@ from mausignald.types import (
     TypingMessage,
     WebsocketConnectionStateChangeEvent,
 )
-from mautrix.types import MessageType
+from mautrix.types import EventID, MessageType, RoomID
 from mautrix.util.logging import TraceLogger
 
 from . import portal as po, puppet as pu, user as u
@@ -50,12 +51,17 @@ class SignalHandler(SignaldClient):
     loop: asyncio.AbstractEventLoop
     data_dir: str
     delete_unknown_accounts: bool
+    error_message_lock: asyncio.Lock
+    error_message_events: dict[tuple[RoomID, Address, int], EventID | None]
 
     def __init__(self, bridge: "SignalBridge") -> None:
         super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
         self.data_dir = bridge.config["signal.data_dir"]
         self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
+        self.error_message_lock = asyncio.Lock()
+        self.error_message_events = {}
         self.add_event_handler(IncomingMessage, self.on_message)
+        self.add_event_handler(ErrorMessage, self.on_error_message)
         self.add_event_handler(
             WebsocketConnectionStateChangeEvent, self.on_websocket_connection_state_change
         )
@@ -90,6 +96,44 @@ class SignalHandler(SignaldClient):
                 self.log.debug("Sync message includes groups meta, syncing groups...")
                 await user.sync_groups()
 
+        async with self.error_message_lock:
+            portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
+            if not portal or not portal.mxid:
+                return
+            error_message_event_key = (portal.mxid, sender.address, evt.timestamp)
+            if error_message_event_key in self.error_message_events:
+                event_id = self.error_message_events[error_message_event_key]
+                if event_id is not None:
+                    await sender.intent_for(portal).redact(portal.mxid, event_id)
+                del self.error_message_events[error_message_event_key]
+
+    async def on_error_message(self, err: ErrorMessage) -> None:
+        sender = await pu.Puppet.get_by_address(Address.parse(err.data.sender))
+        user = await u.User.get_by_username(err.account)
+        portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
+        if not portal or not portal.mxid:
+            return
+
+        # Add the error to the error_message_events dictionary, then wait for 10 seconds until
+        # sending an error. If a success for the timestamp comes in before the 10 seconds is up,
+        # don't send the error message.
+        error_message_event_key = (portal.mxid, sender.address, err.data.timestamp)
+        async with self.error_message_lock:
+            self.error_message_events[error_message_event_key] = None
+
+        await asyncio.sleep(10)
+
+        err_text = (
+            "There was an error receiving a message. Check your Signal app for missing messages. "
+            f"{err.type}: {err.data.message}"
+        )
+        async with self.error_message_lock:
+            if error_message_event_key in self.error_message_events:
+                event_id = await sender.intent_for(portal).send_text(
+                    portal.mxid, html=err_text, msgtype=MessageType.NOTICE
+                )
+                self.error_message_events[error_message_event_key] = event_id
+
     @staticmethod
     async def on_websocket_connection_state_change(
         evt: WebsocketConnectionStateChangeEvent,