소스 검색

Merge pull request #211 from mautrix/sumner/bri-1793

inbound error messages: show errors after a delay
Sumner Evans 3 년 전
부모
커밋
7276cedac3
3개의 변경된 파일80개의 추가작업 그리고 2개의 파일을 삭제
  1. 7 0
      mausignald/signald.py
  2. 18 0
      mausignald/types.py
  3. 55 2
      mautrix_signal/signal.py

+ 7 - 0
mausignald/signald.py

@@ -18,6 +18,7 @@ from .types import (
     Address,
     Attachment,
     DeviceInfo,
+    ErrorMessage,
     GetIdentitiesResponse,
     Group,
     GroupID,
@@ -50,6 +51,7 @@ class SignaldClient(SignaldRPCClient):
         self._event_handlers = {}
         self._subscriptions = set()
         self.add_rpc_handler("IncomingMessage", self._parse_message)
+        self.add_rpc_handler("ProtocolInvalidMessageError", self._parse_error)
         self.add_rpc_handler("WebSocketConnectionState", self._websocket_connection_state_change)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
@@ -73,6 +75,11 @@ class SignaldClient(SignaldRPCClient):
                 except Exception:
                     self.log.exception("Exception in event handler")
 
+    async def _parse_error(self, data: dict[str, Any]) -> None:
+        if not data.get("error"):
+            return
+        await self._run_event_handler(ErrorMessage.deserialize(data))
+
     async def _parse_message(self, data: dict[str, Any]) -> None:
         event_type = data["type"]
         event_data = data["data"]

+ 18 - 0
mausignald/types.py

@@ -529,6 +529,24 @@ class IncomingMessage(SerializableAttrs):
     receipt_message: Optional[ReceiptMessage] = None
 
 
+@dataclass(kw_only=True)
+class ErrorMessageData(SerializableAttrs):
+    sender: str
+    timestamp: int
+    message: str
+    sender_device: int
+    content_hint: int
+
+
+@dataclass(kw_only=True)
+class ErrorMessage(SerializableAttrs):
+    type: str
+    version: str
+    data: ErrorMessageData
+    error: bool
+    account: str
+
+
 class WebsocketConnectionState(SerializableEnum):
     # States from signald itself
     DISCONNECTED = "DISCONNECTED"

+ 55 - 2
mautrix_signal/signal.py

@@ -15,13 +15,14 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Awaitable
 import asyncio
 import logging
 
 from mausignald import SignaldClient
 from mausignald.types import (
     Address,
+    ErrorMessage,
     IncomingMessage,
     MessageData,
     OfferMessageType,
@@ -32,7 +33,7 @@ from mausignald.types import (
     TypingMessage,
     WebsocketConnectionStateChangeEvent,
 )
-from mautrix.types import MessageType
+from mautrix.types import EventID, MessageType, TextMessageEventContent
 from mautrix.util.logging import TraceLogger
 
 from . import portal as po, puppet as pu, user as u
@@ -50,12 +51,15 @@ class SignalHandler(SignaldClient):
     loop: asyncio.AbstractEventLoop
     data_dir: str
     delete_unknown_accounts: bool
+    error_message_events: dict[tuple[Address, str, int], Awaitable[EventID] | None]
 
     def __init__(self, bridge: "SignalBridge") -> None:
         super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
         self.data_dir = bridge.config["signal.data_dir"]
         self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
+        self.error_message_events = {}
         self.add_event_handler(IncomingMessage, self.on_message)
+        self.add_event_handler(ErrorMessage, self.on_error_message)
         self.add_event_handler(
             WebsocketConnectionStateChangeEvent, self.on_websocket_connection_state_change
         )
@@ -90,6 +94,55 @@ class SignalHandler(SignaldClient):
                 self.log.debug("Sync message includes groups meta, syncing groups...")
                 await user.sync_groups()
 
+        try:
+            event_id_future = self.error_message_events.pop(
+                (sender.address, user.username, evt.timestamp)
+            )
+        except KeyError:
+            pass
+        else:
+            self.log.debug(f"Got previously errored message {evt.timestamp} from {sender.address}")
+            event_id = await event_id_future if event_id_future is not None else None
+            if event_id is not None:
+                portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
+                if portal and portal.mxid:
+                    await sender.intent_for(portal).redact(portal.mxid, event_id)
+
+    async def on_error_message(self, err: ErrorMessage) -> None:
+        self.log.warning(
+            f"Error reading message from {err.data.sender}/{err.data.sender_device} "
+            f"(timestamp: {err.data.timestamp}, content hint: {err.data.content_hint}): "
+            f"{err.data.message}"
+        )
+
+        sender = await pu.Puppet.get_by_address(Address.parse(err.data.sender))
+        user = await u.User.get_by_username(err.account)
+        portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
+        if not portal or not portal.mxid:
+            return
+
+        # Add the error to the error_message_events dictionary, then wait for 10 seconds until
+        # sending an error. If a success for the timestamp comes in before the 10 seconds is up,
+        # don't send the error message.
+        error_message_event_key = (sender.address, user.username, err.data.timestamp)
+        self.error_message_events[error_message_event_key] = None
+
+        await asyncio.sleep(10)
+
+        err_text = (
+            "There was an error receiving a message. Check your Signal app for missing messages."
+        )
+        if error_message_event_key in self.error_message_events:
+            fut = self.error_message_events[error_message_event_key] = self.loop.create_future()
+            event_id = None
+            try:
+                event_id = await portal._send_message(
+                    intent=sender.intent_for(portal),
+                    content=TextMessageEventContent(body=err_text, msgtype=MessageType.NOTICE),
+                )
+            finally:
+                fut.set_result(event_id)
+
     @staticmethod
     async def on_websocket_connection_state_change(
         evt: WebsocketConnectionStateChangeEvent,