Przeglądaj źródła

disappearing messages: handle inbound and outbound disappearing messages

When a message is sent or recieved in a Signal chat that has
disappearing messages enabled, it will automatically be redacted by the
bridge after the time configured on the room. However, note that the
countdown timer will only start once the room has been read.

The main mechanism for this is using async functions that just wait the
configured number of seconds before redacting. However, there we also
store all of the state necessary for determining when to redact a
message in the database in case a restart occurs. In the event of a
restart, we can resume waiting for the expiration (or redact immediately
if we are past the expiration timestamp).

This change also adds m.notice notifications for the user when the
disappearing messages setting changes on the Signal chat.
Sumner Evans 3 lat temu
rodzic
commit
f62f5252a9

+ 1 - 1
mautrix_signal/__main__.py

@@ -68,7 +68,7 @@ class SignalBridge(Bridge):
     async def start(self) -> None:
         User.init_cls(self)
         self.add_startup_actions(Puppet.init_cls(self))
-        Portal.init_cls(self)
+        self.add_startup_actions(Portal.init_cls(self))
         if self.config["bridge.resend_bridge_info"]:
             self.add_startup_actions(self.resend_bridge_info())
         self.add_startup_actions(self.signal.start())

+ 4 - 2
mautrix_signal/db/__init__.py

@@ -3,6 +3,7 @@ import sqlite3
 import uuid
 
 from .upgrade import upgrade_table
+from .disappearing_message import DisappearingMessage
 from .user import User
 from .puppet import Puppet
 from .portal import Portal
@@ -11,7 +12,7 @@ from .reaction import Reaction
 
 
 def init(db: Database) -> None:
-    for table in (User, Puppet, Portal, Message, Reaction):
+    for table in (User, Puppet, Portal, Message, Reaction, DisappearingMessage):
         table.db = db
 
 
@@ -19,4 +20,5 @@ def init(db: Database) -> None:
 sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
 sqlite3.register_converter("UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b))
 
-__all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction"]
+__all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction",
+           "DisappearingMessage"]

+ 94 - 0
mautrix_signal/db/disappearing_message.py

@@ -0,0 +1,94 @@
+# mautrix-signal - A Matrix-Signal puppeting bridge
+# Copyright (C) 2021 Sumner Evans
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from typing import ClassVar, List, Optional, TYPE_CHECKING
+
+from attr import dataclass
+import asyncpg
+
+from mautrix.types import RoomID, EventID
+from mautrix.util.async_db import Database
+
+fake_db = Database.create("") if TYPE_CHECKING else None
+
+
+@dataclass
+class DisappearingMessage:
+    db: ClassVar[Database] = fake_db
+
+    room_id: RoomID
+    mxid: EventID
+    expiration_seconds: int
+    expiration_ts: Optional[int] = None
+
+    async def insert(self) -> None:
+        q = """
+        INSERT INTO disappearing_messages (room_id, mxid, expiration_seconds, expiration_ts)
+        VALUES ($1, $2, $3, $4)
+        """
+        await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
+                              self.expiration_ts)
+
+    async def update(self) -> None:
+        q = """
+        UPDATE disappearing_messages
+        SET expiration_seconds=$3, expiration_ts=$4
+        WHERE room_id=$1 AND mxid=$2
+        """
+        try:
+            await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
+                              self.expiration_ts)
+        except Exception as e:
+            print(e)
+
+    @classmethod
+    async def delete(cls, room_id: RoomID, event_id: EventID) -> None:
+        q = "DELETE from disappearing_messages WHERE room_id=$1 AND mxid=$2"
+        await cls.db.execute(q, room_id, event_id)
+
+    @classmethod
+    async def delete_all(cls, room_id: RoomID) -> None:
+        await cls.db.execute("DELETE FROM message WHERE room_id=$1", room_id)
+
+    @classmethod
+    def _from_row(cls, row: asyncpg.Record) -> "DisappearingMessage":
+        return cls(**row)
+
+    @classmethod
+    async def get(cls, room_id: RoomID, event_id: EventID) -> Optional["DisappearingMessage"]:
+        q = """
+        SELECT room_id, mxid, expiration_seconds, expiration_ts
+          FROM disappearing_messages
+         WHERE room_id = $1
+           AND mxid = $2
+        """
+        try:
+            return cls._from_row(await cls.db.fetchrow(q, room_id, event_id))
+        except Exception:
+            return None
+
+    @classmethod
+    async def get_all(cls) -> List["DisappearingMessage"]:
+        q = "SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_messages"
+        return [cls._from_row(r) for r in await cls.db.fetch(q)]
+
+    @classmethod
+    async def get_all_for_room(cls, room_id: RoomID) -> List["DisappearingMessage"]:
+        q = """
+        SELECT room_id, mxid, expiration_seconds, expiration_ts
+          FROM disappearing_messages
+         WHERE room_id = $1
+        """
+        return [cls._from_row(r) for r in await cls.db.fetch(q, room_id)]

+ 16 - 11
mautrix_signal/db/portal.py

@@ -42,6 +42,7 @@ class Portal:
     revision: int
     encrypted: bool
     relay_user_id: Optional[UserID]
+    expiration_time: Optional[int]
 
     @property
     def chat_id_str(self) -> str:
@@ -49,19 +50,23 @@ class Portal:
 
     async def insert(self) -> None:
         q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
-             "                    name_set, avatar_set, revision, encrypted, relay_user_id) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)")
+             "                    name_set, avatar_set, revision, encrypted, relay_user_id, "
+             "                    expiration_time) "
+             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)")
         await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
                               self.avatar_hash, self.avatar_url, self.name_set, self.avatar_set,
-                              self.revision, self.encrypted, self.relay_user_id)
+                              self.revision, self.encrypted, self.relay_user_id,
+                              self.expiration_time)
 
     async def update(self) -> None:
         q = ("UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5, "
-             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9 "
-             "WHERE chat_id=$10 AND receiver=$11")
+             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9, "
+             "                  expiration_time=$10"
+             "WHERE chat_id=$11 AND receiver=$12")
         await self.db.execute(q, self.mxid, self.name, self.avatar_hash, self.avatar_url,
                               self.name_set, self.avatar_set, self.revision, self.encrypted,
-                              self.relay_user_id, self.chat_id_str, self.receiver)
+                              self.relay_user_id, self.expiration_time, self.chat_id_str,
+                              self.receiver)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
@@ -74,7 +79,7 @@ class Portal:
     @classmethod
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE mxid=$1")
         row = await cls.db.fetchrow(q, mxid)
         if not row:
@@ -85,7 +90,7 @@ class Portal:
     async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
                              ) -> Optional['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
@@ -95,7 +100,7 @@ class Portal:
     @classmethod
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE receiver=$1")
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
@@ -103,7 +108,7 @@ class Portal:
     @classmethod
     async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE chat_id=$1 AND receiver<>''")
         rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
@@ -111,7 +116,7 @@ class Portal:
     @classmethod
     async def all_with_room(cls) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE mxid IS NOT NULL")
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]

+ 17 - 0
mautrix_signal/db/upgrade.py

@@ -175,3 +175,20 @@ async def upgrade_v6(conn: Connection) -> None:
 @upgrade_table.register(description="Add relay user field to portal table")
 async def upgrade_v7(conn: Connection) -> None:
     await conn.execute("ALTER TABLE portal ADD COLUMN relay_user_id TEXT")
+
+
+@upgrade_table.register(description="Add table for tracking when to redact disappearing messages")
+async def upgrade_v8(conn: Connection) -> None:
+    await conn.execute("""CREATE TABLE disappearing_messages (
+        room_id             TEXT,
+        mxid                TEXT,
+        expiration_seconds  BIGINT,
+        expiration_ts       BIGINT,
+
+        PRIMARY KEY (room_id, mxid)
+    )""")
+
+
+@upgrade_table.register(description="Add expiration_time column to portal table")
+async def upgrade_v9(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE portal ADD COLUMN expiration_time BIGINT")

+ 5 - 0
mautrix_signal/matrix.py

@@ -14,12 +14,15 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from typing import List, Union, TYPE_CHECKING
+from datetime import datetime
 
 from mautrix.bridge import BaseMatrixHandler
 from mautrix.types import (Event, ReactionEvent, StateEvent, RoomID, EventID, UserID, TypingEvent,
                            ReactionEventContent, RelationType, EventType, ReceiptEvent,
                            PresenceEvent, RedactionEvent, SingleReceiptEventContent)
 
+from mautrix_signal.db.disappearing_message import DisappearingMessage
+
 from .db import Message as DBMessage
 from . import portal as po, user as u, signal as s
 
@@ -102,6 +105,8 @@ class MatrixHandler(BaseMatrixHandler):
 
     async def handle_read_receipt(self, user: 'u.User', portal: 'po.Portal', event_id: EventID,
                                   data: SingleReceiptEventContent) -> None:
+        await portal.handle_read_receipt(event_id, data)
+
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
         if not message:
             return

+ 122 - 5
mautrix_signal/portal.py

@@ -19,6 +19,7 @@ from html import escape as escape_html
 from collections import deque
 from uuid import UUID, uuid4
 from string import Template
+from datetime import datetime
 import mimetypes
 import pathlib
 import hashlib
@@ -36,12 +37,13 @@ from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, Format,
                            MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
                            TextMessageEventContent, ImageInfo, VideoInfo, FileInfo, AudioInfo,
-                           PowerLevelStateEventContent, UserID)
+                           PowerLevelStateEventContent, UserID, SingleReceiptEventContent)
 from mautrix.util.bridge_state import BridgeStateEvent
 from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus
 from mautrix.errors import MatrixError, MForbidden, IntentError
 
-from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
+from .db import (Portal as DBPortal, Message as DBMessage, Reaction as DBReaction,
+                 DisappearingMessage)
 from .config import Config
 from .formatter import matrix_to_signal, signal_to_matrix
 from .util import id_to_str
@@ -80,6 +82,7 @@ class Portal(DBPortal, BasePortal):
     signal: 's.SignalHandler'
     az: AppService
     private_chat_portal_meta: bool
+    expiration_time: Optional[int]
 
     _main_intent: Optional[IntentAPI]
     _create_room_lock: asyncio.Lock
@@ -93,9 +96,10 @@ class Portal(DBPortal, BasePortal):
                  mxid: Optional[RoomID] = None, name: Optional[str] = None,
                  avatar_hash: Optional[str] = None, avatar_url: Optional[ContentURI] = None,
                  name_set: bool = False, avatar_set: bool = False, revision: int = 0,
-                 encrypted: bool = False, relay_user_id: Optional[UserID] = None) -> None:
+                 encrypted: bool = False, relay_user_id: Optional[UserID] = None,
+                 expiration_time: Optional[int] = None) -> None:
         super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url,
-                         name_set, avatar_set, revision, encrypted, relay_user_id)
+                         name_set, avatar_set, revision, encrypted, relay_user_id, expiration_time)
         self._create_room_lock = asyncio.Lock()
         self.log = self.log.getChild(self.chat_id_str)
         self._main_intent = None
@@ -141,7 +145,7 @@ class Portal(DBPortal, BasePortal):
         self.by_chat_id[(self.chat_id_str, self.receiver)] = self
 
     @classmethod
-    def init_cls(cls, bridge: 'SignalBridge') -> None:
+    async def init_cls(cls, bridge: 'SignalBridge') -> None:
         cls.config = bridge.config
         cls.matrix = bridge.matrix
         cls.signal = bridge.signal
@@ -150,6 +154,10 @@ class Portal(DBPortal, BasePortal):
         BasePortal.bridge = bridge
         cls.private_chat_portal_meta = cls.config["bridge.private_chat_portal_meta"]
 
+        for dm in await DisappearingMessage.get_all():
+            if dm.expiration_ts:
+                asyncio.create_task(cls._expire_event(dm))
+
     # region Misc
 
     async def _send_delivery_receipt(self, event_id: EventID) -> None:
@@ -328,6 +336,14 @@ class Portal(DBPortal, BasePortal):
                 except FileNotFoundError:
                     pass
 
+            # Handle disappearing messages
+            if self.expiration_time:
+                disappearing_message = DisappearingMessage(
+                    self.mxid, event_id, self.expiration_time
+                )
+                await disappearing_message.insert()
+                asyncio.create_task(Portal._expire_event(disappearing_message))
+
     async def handle_matrix_reaction(self, sender: 'u.User', event_id: EventID,
                                      reacting_to: EventID, emoji: str) -> None:
         if not await sender.is_logged_in():
@@ -522,6 +538,56 @@ class Portal(DBPortal, BasePortal):
             except FileNotFoundError:
                 pass
 
+    @classmethod
+    async def _expire_event(cls, disappearing_message: DisappearingMessage):
+        """
+        Expire a :class:`DisappearingMessage`. This should only be called once the message has been
+        read, as the timer for redaction will start immediately, and there is no (supported)
+        mechanism to stop the countdown, even after bridge restart.
+        """
+        room_id = disappearing_message.room_id
+        event_id = disappearing_message.mxid
+        wait = disappearing_message.expiration_seconds
+        # If there is an expiration_ts, then there was probably a bridge restart, so we have to
+        # resume the countdown. This is fairly likely to occur if the disappearance timeout is
+        # weeks.
+        # If there is not an expiration_ts, then set one.
+        now = datetime.now().timestamp()
+        if disappearing_message.expiration_ts is not None:
+            wait = (disappearing_message.expiration_ts / 1000) - now
+        else:
+            disappearing_message.expiration_ts = int((now + wait) * 1000)
+            await disappearing_message.update()
+
+        if wait < 0:
+            wait = 0
+
+        portal = await cls.get_by_mxid(room_id)
+        if not portal:
+            portal.log.warning(f"No portal for {room_id}")
+
+        portal.log.debug(f"Redacting {event_id} in {wait} seconds")
+        await asyncio.sleep(wait)
+
+        if not await DisappearingMessage.get(room_id, event_id):
+            portal.log.debug(f"{event_id} no longer in disappearing messages list, not redacting")
+            return
+
+        portal.log.debug(f"Redacting {event_id} because it was expired")
+        try:
+            await portal.main_intent.redact(room_id, event_id)
+            portal.log.debug(f"Redacted {event_id} successfully")
+        except Exception as e:
+            portal.log.warning("Redacting expired event didn't work", e)
+        finally:
+            await DisappearingMessage.delete(room_id, event_id)
+
+    async def handle_read_receipt(self, event_id: EventID, data: SingleReceiptEventContent):
+        # Start the redaction timers for all of the disappearing messages in the room when the user
+        # reads the room. This is the behavior of the Signal clients.
+        for disappearing_message in await DisappearingMessage.get_all_for_room(self.mxid):
+            asyncio.create_task(Portal._expire_event(disappearing_message))
+
     # endregion
     # region Signal event handling
 
@@ -621,6 +687,16 @@ class Portal(DBPortal, BasePortal):
                                            timestamps=[message.timestamp])
             await self._send_delivery_receipt(event_id)
             self.log.debug(f"Handled Signal message {message.timestamp} -> {event_id}")
+
+            if message.expires_in_seconds:
+                disappearing_message = DisappearingMessage(
+                    self.mxid, event_id, message.expires_in_seconds
+                )
+                await disappearing_message.insert()
+                self.log.debug(
+                    f"{event_id} set to be redacted {message.expires_in_seconds} seconds after "
+                    "room is read"
+                )
         else:
             self.log.debug(f"Didn't get event ID for {message.timestamp}")
 
@@ -853,6 +929,46 @@ class Portal(DBPortal, BasePortal):
             await self.update_bridge_info()
             await self.update()
 
+    async def update_expires_in_seconds(self, sender: 'p.Puppet', expires_in_seconds: int) -> None:
+        if expires_in_seconds == 0:
+            expires_in_seconds = None
+        if self.expiration_time == expires_in_seconds:
+            return
+
+        def format_time(seconds) -> str:
+            if seconds is None:
+                return "Off"
+
+            # Technique from https://stackoverflow.com/a/24542445
+            intervals = (
+                ("weeks", 604800),  # 60 * 60 * 24 * 7
+                ("days", 86400),    # 60 * 60 * 24
+                ("hours", 3600),    # 60 * 60
+                ("minutes", 60),
+                ("seconds", 1),
+            )
+
+            result = []
+            for name, count in intervals:
+                value = seconds // count
+                if value:
+                    seconds -= value * count
+                    if value == 1:
+                        name = name[:-1]
+                    result.append(f"{value} {name}")
+            return ", ".join(result)
+
+        assert self.mxid
+        self.expiration_time = expires_in_seconds
+        await self.update()
+
+        await self.main_intent.send_notice(
+            self.mxid,
+            text=None,
+            html=f'<a href="https://matrix.to/#/{sender.mxid}">{sender.name}</a> set the '
+            f'disappearing message timer to {format_time(expires_in_seconds)}.'
+        )
+
     async def update_puppet_avatar(self, new_hash: str, avatar_url: ContentURI) -> None:
         if not self.encrypted and not self.private_chat_portal_meta:
             return
@@ -1195,6 +1311,7 @@ class Portal(DBPortal, BasePortal):
         elif not self.is_direct:
             self._main_intent = self.az.intent
 
+
     async def delete(self) -> None:
         await DBMessage.delete_all(self.mxid)
         self.by_mxid.pop(self.mxid, None)

+ 2 - 0
mautrix_signal/signal.py

@@ -113,6 +113,8 @@ class SignalHandler(SignaldClient):
             await portal.update_info(user, msg.group)
         if msg.remote_delete:
             await portal.handle_signal_delete(sender, msg.remote_delete.target_sent_timestamp)
+        if msg.expires_in_seconds is not None:
+            await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
 
     @staticmethod
     async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None: