Forráskód Böngészése

disappearing messages: don't spawn expiration events for every message every read receipt

Sumner Evans 3 éve
szülő
commit
b1e27324a6
1 módosított fájl, 60 hozzáadás és 42 törlés
  1. 60 42
      mautrix_signal/portal.py

+ 60 - 42
mautrix_signal/portal.py

@@ -158,9 +158,11 @@ class Portal(DBPortal, BasePortal):
 
     @classmethod
     async def start_disappearing_message_expirations(cls):
-        for dm in await DisappearingMessage.get_all():
-            if dm.expiration_ts:
-                asyncio.create_task(cls._expire_event(dm))
+        await asyncio.gather(*(
+            cls._expire_event(dm.room_id, dm.mxid, restart=True)
+            for dm in await DisappearingMessage.get_all()
+            if dm.expiration_ts
+        ))
 
     # region Misc
 
@@ -348,11 +350,9 @@ class Portal(DBPortal, BasePortal):
                     or self.config["signal.enable_disappearing_messages_in_groups"]
                 )
             ):
-                disappearing_message = DisappearingMessage(
-                    self.mxid, event_id, self.expiration_time
-                )
-                await disappearing_message.insert()
-                asyncio.create_task(Portal._expire_event(disappearing_message))
+                dm = DisappearingMessage(self.mxid, event_id, self.expiration_time)
+                await dm.insert()
+                await Portal._expire_event(dm.room_id, dm.mxid)
 
     async def handle_matrix_reaction(self, sender: 'u.User', event_id: EventID,
                                      reacting_to: EventID, emoji: str) -> None:
@@ -549,50 +549,66 @@ class Portal(DBPortal, BasePortal):
                 pass
 
     @classmethod
-    async def _expire_event(cls, disappearing_message: DisappearingMessage):
+    async def _expire_event(cls, room_id: RoomID, event_id: EventID, restart: bool = False):
         """
-        Expire a :class:`DisappearingMessage`. This should only be called once the message has been
+        Schedule a task to expire a an event. This should only be called once the message has been
         read, as the timer for redaction will start immediately, and there is no (supported)
         mechanism to stop the countdown, even after bridge restart.
-        """
-        room_id = disappearing_message.room_id
-        event_id = disappearing_message.mxid
 
+        If there is already an expiration event for the given ``room_id`` and ``event_id``, it will
+        not schedule a new task.
+        """
         portal = await cls.get_by_mxid(room_id)
         if not portal:
             raise AttributeError(f"No portal found for {room_id}")
 
-        if (
-            not portal.is_direct and not
-            cls.config["signal.enable_disappearing_messages_in_groups"]
-        ):
-            portal.log.debug(
-                "Not expiring event in group message since "
-                "signal.enable_disappearing_messages_in_groups is not enabled."
-            )
-            await disappearing_message.delete(room_id, event_id)
-            return
+        # Need a lock around this critical section to make sure that we know if a task has been
+        # created for this particular (room_id, event_id) combination.
+        async with portal._expiration_lock:
+            if (
+                not portal.is_direct and not
+                cls.config["signal.enable_disappearing_messages_in_groups"]
+            ):
+                portal.log.debug(
+                    "Not expiring event in group message since "
+                    "signal.enable_disappearing_messages_in_groups is not enabled."
+                )
+                await DisappearingMessage.delete(room_id, event_id)
+                return
 
-        wait = disappearing_message.expiration_seconds
-        # If there is an expiration_ts, then there was probably a bridge restart, so we have to
-        # resume the countdown. This is fairly likely to occur if the disappearance timeout is
-        # weeks.
-        # If there is not an expiration_ts, then set one.
-        now = time.time()
-        if disappearing_message.expiration_ts is not None:
-            wait = (disappearing_message.expiration_ts / 1000) - now
-        else:
-            disappearing_message.expiration_ts = int((now + wait) * 1000)
-            await disappearing_message.update()
+            disappearing_message = await DisappearingMessage.get(room_id, event_id)
+            if disappearing_message is None:
+                return
+            wait = disappearing_message.expiration_seconds
+            now = time.time()
+            # If there is an expiration_ts, then there's already a task going, or it's a restart.
+            # If it's a restart, then restart the countdown. This is fairly likely to occur if the
+            # disappearance timeout is weeks.
+            if disappearing_message.expiration_ts:
+                if not restart:
+                    portal.log.debug(f"Expiration task already exists for {event_id} in {room_id}")
+                    return
+                portal.log.debug(f"Resuming expiration for {event_id} in {room_id}")
+                wait = (disappearing_message.expiration_ts / 1000) - now
+            if wait < 0:
+                wait = 0
+
+            # Spawn the actual expiration task.
+            asyncio.create_task(cls._expire_event_task(portal, event_id, wait))
+
+            # Set the expiration_ts only after we have actually created the expiration task.
+            if not disappearing_message.expiration_ts:
+                disappearing_message.expiration_ts = int((now + wait) * 1000)
+                await disappearing_message.update()
 
-        if wait < 0:
-            wait = 0
 
+    @classmethod
+    async def _expire_event_task(cls, portal: 'Portal', event_id: EventID, wait: float):
         portal.log.debug(f"Redacting {event_id} in {wait} seconds")
         await asyncio.sleep(wait)
 
         async with portal._expiration_lock:
-            if not await DisappearingMessage.get(room_id, event_id):
+            if not await DisappearingMessage.get(portal.mxid, event_id):
                 portal.log.debug(
                     f"{event_id} no longer in disappearing messages list, not redacting"
                 )
@@ -600,18 +616,20 @@ class Portal(DBPortal, BasePortal):
 
             portal.log.debug(f"Redacting {event_id} because it was expired")
             try:
-                await portal.main_intent.redact(room_id, event_id)
+                await portal.main_intent.redact(portal.mxid, event_id)
                 portal.log.debug(f"Redacted {event_id} successfully")
             except Exception as e:
-                portal.log.warning("Redacting expired event didn't work", e)
+                portal.log.warning(f"Redacting expired event {event_id} failed", e)
             finally:
-                await DisappearingMessage.delete(room_id, event_id)
+                await DisappearingMessage.delete(portal.mxid, event_id)
 
     async def handle_read_receipt(self, event_id: EventID, data: SingleReceiptEventContent):
         # Start the redaction timers for all of the disappearing messages in the room when the user
         # reads the room. This is the behavior of the Signal clients.
-        for disappearing_message in await DisappearingMessage.get_all_for_room(self.mxid):
-            asyncio.create_task(Portal._expire_event(disappearing_message))
+        await asyncio.gather(*(
+            Portal._expire_event(dm.room_id, dm.mxid)
+            for dm in await DisappearingMessage.get_all_for_room(self.mxid)
+        ))
 
     # endregion
     # region Signal event handling