Pārlūkot izejas kodu

Make disappearing message db queries more efficient

Tulir Asokan 3 gadi atpakaļ
vecāks
revīzija
7be41c1d39

+ 19 - 27
mautrix_signal/db/disappearing_message.py

@@ -15,7 +15,7 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, ClassVar, List, Optional
+from typing import TYPE_CHECKING, ClassVar
 
 from attr import dataclass
 import asyncpg
@@ -33,29 +33,21 @@ class DisappearingMessage:
     room_id: RoomID
     mxid: EventID
     expiration_seconds: int
-    expiration_ts: Optional[int] = None
+    expiration_ts: int | None = None
 
     async def insert(self) -> None:
         q = """
-        INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts)
-        VALUES ($1, $2, $3, $4)
+            INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts)
+            VALUES ($1, $2, $3, $4)
         """
         await self.db.execute(
             q, self.room_id, self.mxid, self.expiration_seconds, self.expiration_ts
         )
 
-    async def update(self) -> None:
-        q = """
-        UPDATE disappearing_message
-        SET expiration_seconds=$3, expiration_ts=$4
-        WHERE room_id=$1 AND mxid=$2
-        """
-        try:
-            await self.db.execute(
-                q, self.room_id, self.mxid, self.expiration_seconds, self.expiration_ts
-            )
-        except Exception as e:
-            print(e)
+    async def set_expiration_ts(self, ts: int) -> None:
+        q = "UPDATE disappearing_message SET expiration_ts=$3 WHERE room_id=$1 AND mxid=$2"
+        self.expiration_ts = ts
+        await self.db.execute(q, self.room_id, self.mxid, self.expiration_ts)
 
     @classmethod
     async def delete(cls, room_id: RoomID, event_id: EventID) -> None:
@@ -67,12 +59,10 @@ class DisappearingMessage:
         return cls(**row)
 
     @classmethod
-    async def get(cls, room_id: RoomID, event_id: EventID) -> Optional[DisappearingMessage]:
+    async def get(cls, room_id: RoomID, event_id: EventID) -> DisappearingMessage | None:
         q = """
-        SELECT room_id, mxid, expiration_seconds, expiration_ts
-          FROM disappearing_message
-         WHERE room_id = $1
-           AND mxid = $2
+            SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message
+            WHERE room_id=$1 AND mxid=$2
         """
         try:
             return cls._from_row(await cls.db.fetchrow(q, room_id, event_id))
@@ -80,15 +70,17 @@ class DisappearingMessage:
             return None
 
     @classmethod
-    async def get_all(cls) -> List[DisappearingMessage]:
-        q = "SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message"
+    async def get_all_scheduled(cls) -> list[DisappearingMessage]:
+        q = """
+            SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message
+            WHERE expiration_ts IS NOT NULL
+        """
         return [cls._from_row(r) for r in await cls.db.fetch(q)]
 
     @classmethod
-    async def get_all_for_room(cls, room_id: RoomID) -> List[DisappearingMessage]:
+    async def get_unscheduled_for_room(cls, room_id: RoomID) -> list[DisappearingMessage]:
         q = """
-        SELECT room_id, mxid, expiration_seconds, expiration_ts
-          FROM disappearing_message
-         WHERE room_id = $1
+            SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message
+            WHERE room_id = $1 AND expiration_ts IS NULL
         """
         return [cls._from_row(r) for r in await cls.db.fetch(q, room_id)]

+ 11 - 10
mautrix_signal/matrix.py

@@ -13,7 +13,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import TYPE_CHECKING, List, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
 
 from mautrix.bridge import BaseMatrixHandler
 from mautrix.types import (
@@ -32,7 +34,6 @@ from mautrix.types import (
     TypingEvent,
     UserID,
 )
-from mautrix_signal.db.disappearing_message import DisappearingMessage
 
 from . import portal as po, signal as s, user as u
 from .db import Message as DBMessage
@@ -42,7 +43,7 @@ if TYPE_CHECKING:
 
 
 class MatrixHandler(BaseMatrixHandler):
-    signal: "s.SignalHandler"
+    signal: s.SignalHandler
 
     def __init__(self, bridge: "SignalBridge") -> None:
         prefix, suffix = bridge.config["bridge.username_template"].format(userid=":").split(":")
@@ -53,7 +54,7 @@ class MatrixHandler(BaseMatrixHandler):
 
         super().__init__(bridge=bridge)
 
-    async def send_welcome_message(self, room_id: RoomID, inviter: "u.User") -> None:
+    async def send_welcome_message(self, room_id: RoomID, inviter: u.User) -> None:
         await super().send_welcome_message(room_id, inviter)
         if not inviter.notice_room:
             inviter.notice_room = room_id
@@ -122,8 +123,8 @@ class MatrixHandler(BaseMatrixHandler):
 
     async def handle_read_receipt(
         self,
-        user: "u.User",
-        portal: "po.Portal",
+        user: u.User,
+        portal: po.Portal,
         event_id: EventID,
         data: SingleReceiptEventContent,
     ) -> None:
@@ -138,7 +139,7 @@ class MatrixHandler(BaseMatrixHandler):
             user.username, message.sender, timestamps=[message.timestamp], when=data.ts, read=True
         )
 
-    async def handle_typing(self, room_id: RoomID, typing: List[UserID]) -> None:
+    async def handle_typing(self, room_id: RoomID, typing: list[UserID]) -> None:
         pass
         # portal = await po.Portal.get_by_mxid(room_id)
         # if not portal:
@@ -159,7 +160,7 @@ class MatrixHandler(BaseMatrixHandler):
             await self.handle_redaction(evt.room_id, evt.sender, evt.redacts, evt.event_id)
 
     async def handle_ephemeral_event(
-        self, evt: Union[ReceiptEvent, PresenceEvent, TypingEvent]
+        self, evt: ReceiptEvent | PresenceEvent | TypingEvent
     ) -> None:
         if evt.type == EventType.TYPING:
             await self.handle_typing(evt.room_id, evt.content.user_ids)
@@ -182,8 +183,8 @@ class MatrixHandler(BaseMatrixHandler):
         elif evt.type == EventType.ROOM_AVATAR:
             await portal.handle_matrix_avatar(user, evt.content.url)
 
-    async def allow_message(self, user: "u.User") -> bool:
+    async def allow_message(self, user: u.User) -> bool:
         return user.relay_whitelisted
 
-    async def allow_bridging_message(self, user: "u.User", portal: "po.Portal") -> bool:
+    async def allow_bridging_message(self, user: u.User, portal: po.Portal) -> bool:
         return portal.has_relay or await user.is_logged_in()

+ 4 - 5
mautrix_signal/portal.py

@@ -204,7 +204,7 @@ class Portal(DBPortal, BasePortal):
         await asyncio.gather(
             *(
                 cls._expire_event(dm.room_id, dm.mxid, restart=True)
-                for dm in await DisappearingMessage.get_all()
+                for dm in await DisappearingMessage.get_all_scheduled()
                 if dm.expiration_ts
             )
         )
@@ -731,8 +731,7 @@ class Portal(DBPortal, BasePortal):
 
             # Set the expiration_ts only after we have actually created the expiration task.
             if not disappearing_message.expiration_ts:
-                disappearing_message.expiration_ts = int((now + wait) * 1000)
-                await disappearing_message.update()
+                await disappearing_message.set_expiration_ts(int((now + wait) * 1000))
 
     @classmethod
     async def _expire_event_task(cls, portal: Portal, event_id: EventID, wait: float):
@@ -756,12 +755,12 @@ class Portal(DBPortal, BasePortal):
                 await DisappearingMessage.delete(portal.mxid, event_id)
 
     async def handle_read_receipt(self, event_id: EventID, data: SingleReceiptEventContent):
-        # Start the redaction timers for all of the disappearing messages in the room when the user
+        # Start the redaction timers for all the disappearing messages in the room when the user
         # reads the room. This is the behavior of the Signal clients.
         await asyncio.gather(
             *(
                 Portal._expire_event(dm.room_id, dm.mxid)
-                for dm in await DisappearingMessage.get_all_for_room(self.mxid)
+                for dm in await DisappearingMessage.get_unscheduled_for_room(self.mxid)
             )
         )