Przeglądaj źródła

Remove error_message_lock and encrypt error messages on Matrix

Tulir Asokan 3 lat temu
rodzic
commit
bf211b086f
1 zmienionych plików z 21 dodań i 22 usunięć
  1. 21 22
      mautrix_signal/signal.py

+ 21 - 22
mautrix_signal/signal.py

@@ -15,7 +15,7 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Awaitable
 import asyncio
 import logging
 
@@ -33,7 +33,7 @@ from mausignald.types import (
     TypingMessage,
     WebsocketConnectionStateChangeEvent,
 )
-from mautrix.types import EventID, MessageType, RoomID
+from mautrix.types import EventID, MessageType, TextMessageEventContent
 from mautrix.util.logging import TraceLogger
 
 from . import portal as po, puppet as pu, user as u
@@ -51,14 +51,12 @@ class SignalHandler(SignaldClient):
     loop: asyncio.AbstractEventLoop
     data_dir: str
     delete_unknown_accounts: bool
-    error_message_lock: asyncio.Lock
-    error_message_events: dict[tuple[RoomID, Address, int], EventID | None]
+    error_message_events: dict[tuple[Address, str, int], Awaitable[EventID] | None]
 
     def __init__(self, bridge: "SignalBridge") -> None:
         super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
         self.data_dir = bridge.config["signal.data_dir"]
         self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
-        self.error_message_lock = asyncio.Lock()
         self.error_message_events = {}
         self.add_event_handler(IncomingMessage, self.on_message)
         self.add_event_handler(ErrorMessage, self.on_error_message)
@@ -96,16 +94,14 @@ class SignalHandler(SignaldClient):
                 self.log.debug("Sync message includes groups meta, syncing groups...")
                 await user.sync_groups()
 
-        async with self.error_message_lock:
+        event_id_future = self.error_message_events.pop(
+            (sender.address, user.username, evt.timestamp), None
+        )
+        event_id = await event_id_future if event_id_future is not None else None
+        if event_id is not None:
             portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
-            if not portal or not portal.mxid:
-                return
-            error_message_event_key = (portal.mxid, sender.address, evt.timestamp)
-            if error_message_event_key in self.error_message_events:
-                event_id = self.error_message_events[error_message_event_key]
-                if event_id is not None:
-                    await sender.intent_for(portal).redact(portal.mxid, event_id)
-                del self.error_message_events[error_message_event_key]
+            if portal and portal.mxid:
+                await sender.intent_for(portal).redact(portal.mxid, event_id)
 
     async def on_error_message(self, err: ErrorMessage) -> None:
         sender = await pu.Puppet.get_by_address(Address.parse(err.data.sender))
@@ -117,9 +113,8 @@ class SignalHandler(SignaldClient):
         # Add the error to the error_message_events dictionary, then wait for 10 seconds until
         # sending an error. If a success for the timestamp comes in before the 10 seconds is up,
         # don't send the error message.
-        error_message_event_key = (portal.mxid, sender.address, err.data.timestamp)
-        async with self.error_message_lock:
-            self.error_message_events[error_message_event_key] = None
+        error_message_event_key = (sender.address, user.username, err.data.timestamp)
+        self.error_message_events[error_message_event_key] = None
 
         await asyncio.sleep(10)
 
@@ -127,12 +122,16 @@ class SignalHandler(SignaldClient):
             "There was an error receiving a message. Check your Signal app for missing messages. "
             f"{err.type}: {err.data.message}"
         )
-        async with self.error_message_lock:
-            if error_message_event_key in self.error_message_events:
-                event_id = await sender.intent_for(portal).send_text(
-                    portal.mxid, html=err_text, msgtype=MessageType.NOTICE
+        if error_message_event_key in self.error_message_events:
+            fut = self.error_message_events[error_message_event_key] = self.loop.create_future()
+            event_id = None
+            try:
+                event_id = await portal._send_message(
+                    intent=sender.intent_for(portal),
+                    content=TextMessageEventContent(body=err_text, msgtype=MessageType.NOTICE),
                 )
-                self.error_message_events[error_message_event_key] = event_id
+            finally:
+                fut.set_result(event_id)
 
     @staticmethod
     async def on_websocket_connection_state_change(