Bläddra i källkod

Move most disappearing message code to mautrix-python

Tulir Asokan 3 år sedan
förälder
incheckning
8d47341753

+ 1 - 1
mautrix_signal/__main__.py

@@ -68,12 +68,12 @@ class SignalBridge(Bridge):
         User.init_cls(self)
         self.add_startup_actions(Puppet.init_cls(self))
         Portal.init_cls(self)
+        self.add_startup_actions(Portal.restart_scheduled_disappearing())
         if self.config["bridge.resend_bridge_info"]:
             self.add_startup_actions(self.resend_bridge_info())
         self.add_startup_actions(self.signal.start())
         await super().start()
         self.periodic_sync_task = asyncio.create_task(self._periodic_sync_loop())
-        asyncio.create_task(Portal.start_disappearing_message_expirations())
 
     @staticmethod
     async def _actual_periodic_sync_loop(log: logging.Logger, interval: int) -> None:

+ 8 - 16
mautrix_signal/db/disappearing_message.py

@@ -17,46 +17,38 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, ClassVar
 
-from attr import dataclass
 import asyncpg
 
+from mautrix.bridge import AbstractDisappearingMessage
 from mautrix.types import EventID, RoomID
 from mautrix.util.async_db import Database
 
 fake_db = Database.create("") if TYPE_CHECKING else None
 
 
-@dataclass
-class DisappearingMessage:
+class DisappearingMessage(AbstractDisappearingMessage):
     db: ClassVar[Database] = fake_db
 
-    room_id: RoomID
-    mxid: EventID
-    expiration_seconds: int
-    expiration_ts: int | None = None
-
     async def insert(self) -> None:
         q = """
             INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts)
             VALUES ($1, $2, $3, $4)
         """
         await self.db.execute(
-            q, self.room_id, self.mxid, self.expiration_seconds, self.expiration_ts
+            q, self.room_id, self.event_id, self.expiration_seconds, self.expiration_ts
         )
 
-    async def set_expiration_ts(self, ts: int) -> None:
+    async def update(self) -> None:
         q = "UPDATE disappearing_message SET expiration_ts=$3 WHERE room_id=$1 AND mxid=$2"
-        self.expiration_ts = ts
-        await self.db.execute(q, self.room_id, self.mxid, self.expiration_ts)
+        await self.db.execute(q, self.room_id, self.event_id, self.expiration_ts)
 
-    @classmethod
-    async def delete(cls, room_id: RoomID, event_id: EventID) -> None:
+    async def delete(self) -> None:
         q = "DELETE from disappearing_message WHERE room_id=$1 AND mxid=$2"
-        await cls.db.execute(q, room_id, event_id)
+        await self.db.execute(q, self.room_id, self.event_id)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> DisappearingMessage:
-        return cls(**row)
+        return cls(row["room_id"], row["mxid"], row["expiration_seconds"], row["expiration_ts"])
 
     @classmethod
     async def get(cls, room_id: RoomID, event_id: EventID) -> DisappearingMessage | None:

+ 0 - 2
mautrix_signal/matrix.py

@@ -128,8 +128,6 @@ class MatrixHandler(BaseMatrixHandler):
         event_id: EventID,
         data: SingleReceiptEventContent,
     ) -> None:
-        await portal.handle_read_receipt(event_id, data)
-
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
         if not message:
             return

+ 23 - 116
mautrix_signal/portal.py

@@ -118,6 +118,7 @@ class Portal(DBPortal, BasePortal):
     by_mxid: dict[RoomID, Portal] = {}
     by_chat_id: dict[tuple[str, str], Portal] = {}
     _sticker_meta_cache: dict[str, StickerPack] = {}
+    disappearing_msg_class = DisappearingMessage
     config: Config
     matrix: m.MatrixHandler
     signal: s.SignalHandler
@@ -149,19 +150,20 @@ class Portal(DBPortal, BasePortal):
         expiration_time: int | None = None,
     ) -> None:
         super().__init__(
-            chat_id,
-            receiver,
-            mxid,
-            name,
-            avatar_hash,
-            avatar_url,
-            name_set,
-            avatar_set,
-            revision,
-            encrypted,
-            relay_user_id,
-            expiration_time,
+            chat_id=chat_id,
+            receiver=receiver,
+            mxid=mxid,
+            name=name,
+            avatar_hash=avatar_hash,
+            avatar_url=avatar_url,
+            name_set=name_set,
+            avatar_set=avatar_set,
+            revision=revision,
+            encrypted=encrypted,
+            relay_user_id=relay_user_id,
+            expiration_time=expiration_time,
         )
+        BasePortal.__init__(self)
         self._create_room_lock = asyncio.Lock()
         self.log = self.log.getChild(self.chat_id_str)
         self._main_intent = None
@@ -183,6 +185,10 @@ class Portal(DBPortal, BasePortal):
     def is_direct(self) -> bool:
         return isinstance(self.chat_id, Address)
 
+    @property
+    def disappearing_enabled(self) -> bool:
+        return self.is_direct or self.config["signal.enable_disappearing_messages_in_groups"]
+
     def handle_uuid_receive(self, uuid: UUID) -> None:
         if not self.is_direct or self.chat_id.uuid:
             raise ValueError(
@@ -203,16 +209,6 @@ class Portal(DBPortal, BasePortal):
         BasePortal.bridge = bridge
         cls.private_chat_portal_meta = cls.config["bridge.private_chat_portal_meta"]
 
-    @classmethod
-    async def start_disappearing_message_expirations(cls):
-        await asyncio.gather(
-            *(
-                cls._expire_event(dm.room_id, dm.mxid, restart=True)
-                for dm in await DisappearingMessage.get_all_scheduled()
-                if dm.expiration_ts
-            )
-        )
-
     # region Misc
 
     async def _send_delivery_receipt(self, event_id: EventID) -> None:
@@ -418,12 +414,11 @@ class Portal(DBPortal, BasePortal):
                     pass
 
             # Handle disappearing messages
-            if self.expiration_time and (
-                self.is_direct or self.config["signal.enable_disappearing_messages_in_groups"]
-            ):
+            if self.expiration_time and self.disappearing_enabled:
                 dm = DisappearingMessage(self.mxid, event_id, self.expiration_time)
+                dm.start_timer()
                 await dm.insert()
-                await Portal._expire_event(dm.room_id, dm.mxid)
+                await self._disappear_event(dm)
 
     async def _signal_send_with_retries(
         self,
@@ -726,89 +721,6 @@ class Portal(DBPortal, BasePortal):
             except FileNotFoundError:
                 pass
 
-    @classmethod
-    async def _expire_event(cls, room_id: RoomID, event_id: EventID, restart: bool = False):
-        """
-        Schedule a task to expire a an event. This should only be called once the message has been
-        read, as the timer for redaction will start immediately, and there is no (supported)
-        mechanism to stop the countdown, even after bridge restart.
-
-        If there is already an expiration event for the given ``room_id`` and ``event_id``, it will
-        not schedule a new task.
-        """
-        portal = await cls.get_by_mxid(room_id)
-        if not portal:
-            raise AttributeError(f"No portal found for {room_id}")
-
-        # Need a lock around this critical section to make sure that we know if a task has been
-        # created for this particular (room_id, event_id) combination.
-        async with portal._expiration_lock:
-            if (
-                not portal.is_direct
-                and not cls.config["signal.enable_disappearing_messages_in_groups"]
-            ):
-                portal.log.debug(
-                    "Not expiring event in group message since "
-                    "signal.enable_disappearing_messages_in_groups is not enabled."
-                )
-                await DisappearingMessage.delete(room_id, event_id)
-                return
-
-            disappearing_message = await DisappearingMessage.get(room_id, event_id)
-            if disappearing_message is None:
-                return
-            wait = disappearing_message.expiration_seconds
-            now = time.time()
-            # If there is an expiration_ts, then there's already a task going, or it's a restart.
-            # If it's a restart, then restart the countdown. This is fairly likely to occur if the
-            # disappearance timeout is weeks.
-            if disappearing_message.expiration_ts:
-                if not restart:
-                    portal.log.debug(f"Expiration task already exists for {event_id} in {room_id}")
-                    return
-                portal.log.debug(f"Resuming expiration for {event_id} in {room_id}")
-                wait = (disappearing_message.expiration_ts / 1000) - now
-            if wait < 0:
-                wait = 0
-
-            # Spawn the actual expiration task.
-            asyncio.create_task(cls._expire_event_task(portal, event_id, wait))
-
-            # Set the expiration_ts only after we have actually created the expiration task.
-            if not disappearing_message.expiration_ts:
-                await disappearing_message.set_expiration_ts(int((now + wait) * 1000))
-
-    @classmethod
-    async def _expire_event_task(cls, portal: Portal, event_id: EventID, wait: float):
-        portal.log.debug(f"Redacting {event_id} in {wait} seconds")
-        await asyncio.sleep(wait)
-
-        async with portal._expiration_lock:
-            if not await DisappearingMessage.get(portal.mxid, event_id):
-                portal.log.debug(
-                    f"{event_id} no longer in disappearing messages list, not redacting"
-                )
-                return
-
-            portal.log.debug(f"Redacting {event_id} because it was expired")
-            try:
-                await portal.main_intent.redact(portal.mxid, event_id)
-                portal.log.debug(f"Redacted {event_id} successfully")
-            except Exception as e:
-                portal.log.warning(f"Redacting expired event {event_id} failed", e)
-            finally:
-                await DisappearingMessage.delete(portal.mxid, event_id)
-
-    async def handle_read_receipt(self, event_id: EventID, data: SingleReceiptEventContent):
-        # Start the redaction timers for all the disappearing messages in the room when the user
-        # reads the room. This is the behavior of the Signal clients.
-        await asyncio.gather(
-            *(
-                Portal._expire_event(dm.room_id, dm.mxid)
-                for dm in await DisappearingMessage.get_unscheduled_for_room(self.mxid)
-            )
-        )
-
     # endregion
     # region Signal event handling
 
@@ -944,13 +856,8 @@ class Portal(DBPortal, BasePortal):
             await self._send_delivery_receipt(event_id)
             self.log.debug(f"Handled Signal message {message.timestamp} -> {event_id}")
 
-            if message.expires_in_seconds and (
-                self.is_direct or self.config["signal.enable_disappearing_messages_in_groups"]
-            ):
-                disappearing_message = DisappearingMessage(
-                    self.mxid, event_id, message.expires_in_seconds
-                )
-                await disappearing_message.insert()
+            if message.expires_in_seconds and self.disappearing_enabled:
+                await DisappearingMessage(self.mxid, event_id, message.expires_in_seconds).insert()
                 self.log.debug(
                     f"{event_id} set to be redacted {message.expires_in_seconds} seconds after "
                     "room is read"

+ 1 - 1
requirements.txt

@@ -4,5 +4,5 @@ commonmark>=0.8,<0.10
 aiohttp>=3,<4
 yarl>=1,<2
 attrs>=19.1
-mautrix>=0.14.6,<0.15
+mautrix>=0.14.8,<0.15
 asyncpg>=0.20,<0.26