Prechádzať zdrojové kódy

Merge pull request #182 from mautrix/sumner/bri-1204-support-disappearing-messages-in-signal

disappearing messages: add support for both inbound and outbound
Sumner Evans 3 rokov pred
rodič
commit
e13206318b

+ 1 - 0
mautrix_signal/__main__.py

@@ -74,6 +74,7 @@ class SignalBridge(Bridge):
         self.add_startup_actions(self.signal.start())
         await super().start()
         self.periodic_sync_task = asyncio.create_task(self._periodic_sync_loop())
+        asyncio.create_task(Portal.start_disappearing_message_expirations())
 
     @staticmethod
     async def _actual_periodic_sync_loop(log: logging.Logger, interval: int) -> None:

+ 1 - 0
mautrix_signal/config.py

@@ -54,6 +54,7 @@ class Config(BaseBridgeConfig):
         copy("signal.delete_unknown_accounts_on_start")
         copy("signal.remove_file_after_handling")
         copy("signal.registration_enabled")
+        copy("signal.enable_disappearing_messages_in_groups")
 
         copy("metrics.enabled")
         copy("metrics.listen_port")

+ 4 - 2
mautrix_signal/db/__init__.py

@@ -3,6 +3,7 @@ import sqlite3
 import uuid
 
 from .upgrade import upgrade_table
+from .disappearing_message import DisappearingMessage
 from .user import User
 from .puppet import Puppet
 from .portal import Portal
@@ -11,7 +12,7 @@ from .reaction import Reaction
 
 
 def init(db: Database) -> None:
-    for table in (User, Puppet, Portal, Message, Reaction):
+    for table in (User, Puppet, Portal, Message, Reaction, DisappearingMessage):
         table.db = db
 
 
@@ -19,4 +20,5 @@ def init(db: Database) -> None:
 sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
 sqlite3.register_converter("UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b))
 
-__all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction"]
+__all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction",
+           "DisappearingMessage"]

+ 90 - 0
mautrix_signal/db/disappearing_message.py

@@ -0,0 +1,90 @@
+# mautrix-signal - A Matrix-Signal puppeting bridge
+# Copyright (C) 2021 Sumner Evans
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from typing import ClassVar, List, Optional, TYPE_CHECKING
+
+from attr import dataclass
+import asyncpg
+
+from mautrix.types import RoomID, EventID
+from mautrix.util.async_db import Database
+
+fake_db = Database.create("") if TYPE_CHECKING else None
+
+
+@dataclass
+class DisappearingMessage:
+    db: ClassVar[Database] = fake_db
+
+    room_id: RoomID
+    mxid: EventID
+    expiration_seconds: int
+    expiration_ts: Optional[int] = None
+
+    async def insert(self) -> None:
+        q = """
+        INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts)
+        VALUES ($1, $2, $3, $4)
+        """
+        await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
+                              self.expiration_ts)
+
+    async def update(self) -> None:
+        q = """
+        UPDATE disappearing_message
+        SET expiration_seconds=$3, expiration_ts=$4
+        WHERE room_id=$1 AND mxid=$2
+        """
+        try:
+            await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
+                              self.expiration_ts)
+        except Exception as e:
+            print(e)
+
+    @classmethod
+    async def delete(cls, room_id: RoomID, event_id: EventID) -> None:
+        q = "DELETE from disappearing_message WHERE room_id=$1 AND mxid=$2"
+        await cls.db.execute(q, room_id, event_id)
+
+    @classmethod
+    def _from_row(cls, row: asyncpg.Record) -> "DisappearingMessage":
+        return cls(**row)
+
+    @classmethod
+    async def get(cls, room_id: RoomID, event_id: EventID) -> Optional["DisappearingMessage"]:
+        q = """
+        SELECT room_id, mxid, expiration_seconds, expiration_ts
+          FROM disappearing_message
+         WHERE room_id = $1
+           AND mxid = $2
+        """
+        try:
+            return cls._from_row(await cls.db.fetchrow(q, room_id, event_id))
+        except Exception:
+            return None
+
+    @classmethod
+    async def get_all(cls) -> List["DisappearingMessage"]:
+        q = "SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message"
+        return [cls._from_row(r) for r in await cls.db.fetch(q)]
+
+    @classmethod
+    async def get_all_for_room(cls, room_id: RoomID) -> List["DisappearingMessage"]:
+        q = """
+        SELECT room_id, mxid, expiration_seconds, expiration_ts
+          FROM disappearing_message
+         WHERE room_id = $1
+        """
+        return [cls._from_row(r) for r in await cls.db.fetch(q, room_id)]

+ 16 - 11
mautrix_signal/db/portal.py

@@ -42,6 +42,7 @@ class Portal:
     revision: int
     encrypted: bool
     relay_user_id: Optional[UserID]
+    expiration_time: Optional[int]
 
     @property
     def chat_id_str(self) -> str:
@@ -49,19 +50,23 @@ class Portal:
 
     async def insert(self) -> None:
         q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
-             "                    name_set, avatar_set, revision, encrypted, relay_user_id) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)")
+             "                    name_set, avatar_set, revision, encrypted, relay_user_id, "
+             "                    expiration_time) "
+             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)")
         await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
                               self.avatar_hash, self.avatar_url, self.name_set, self.avatar_set,
-                              self.revision, self.encrypted, self.relay_user_id)
+                              self.revision, self.encrypted, self.relay_user_id,
+                              self.expiration_time)
 
     async def update(self) -> None:
         q = ("UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5, "
-             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9 "
-             "WHERE chat_id=$10 AND receiver=$11")
+             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9, "
+             "                  expiration_time=$10"
+             "WHERE chat_id=$11 AND receiver=$12")
         await self.db.execute(q, self.mxid, self.name, self.avatar_hash, self.avatar_url,
                               self.name_set, self.avatar_set, self.revision, self.encrypted,
-                              self.relay_user_id, self.chat_id_str, self.receiver)
+                              self.relay_user_id, self.expiration_time, self.chat_id_str,
+                              self.receiver)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
@@ -74,7 +79,7 @@ class Portal:
     @classmethod
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE mxid=$1")
         row = await cls.db.fetchrow(q, mxid)
         if not row:
@@ -85,7 +90,7 @@ class Portal:
     async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
                              ) -> Optional['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
@@ -95,7 +100,7 @@ class Portal:
     @classmethod
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE receiver=$1")
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
@@ -103,7 +108,7 @@ class Portal:
     @classmethod
     async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE chat_id=$1 AND receiver<>''")
         rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
@@ -111,7 +116,7 @@ class Portal:
     @classmethod
     async def all_with_room(cls) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE mxid IS NOT NULL")
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]

+ 13 - 0
mautrix_signal/db/upgrade.py

@@ -175,3 +175,16 @@ async def upgrade_v6(conn: Connection) -> None:
 @upgrade_table.register(description="Add relay user field to portal table")
 async def upgrade_v7(conn: Connection) -> None:
     await conn.execute("ALTER TABLE portal ADD COLUMN relay_user_id TEXT")
+
+
+@upgrade_table.register(description="Add support for disappearing messages")
+async def upgrade_v8(conn: Connection) -> None:
+    await conn.execute("""CREATE TABLE disappearing_message (
+        room_id             TEXT,
+        mxid                TEXT,
+        expiration_seconds  BIGINT,
+        expiration_ts       BIGINT,
+
+        PRIMARY KEY (room_id, mxid)
+    )""")
+    await conn.execute("ALTER TABLE portal ADD COLUMN expiration_time BIGINT")

+ 4 - 0
mautrix_signal/example-config.yaml

@@ -105,6 +105,10 @@ signal:
     remove_file_after_handling: true
     # Whether or not users can register a primary device
     registration_enabled: true
+    # Whether or not to enable disappearing messages in groups. If enabled, then the expiration
+    # time of the messages will be determined by the first users to read the message, rather
+    # than individually. If the bridge has a single user, this can be turned on safely.
+    enable_disappearing_messages_in_groups: false
 
 # Bridge config
 bridge:

+ 4 - 0
mautrix_signal/matrix.py

@@ -20,6 +20,8 @@ from mautrix.types import (Event, ReactionEvent, StateEvent, RoomID, EventID, Us
                            ReactionEventContent, RelationType, EventType, ReceiptEvent,
                            PresenceEvent, RedactionEvent, SingleReceiptEventContent)
 
+from mautrix_signal.db.disappearing_message import DisappearingMessage
+
 from .db import Message as DBMessage
 from . import portal as po, user as u, signal as s
 
@@ -102,6 +104,8 @@ class MatrixHandler(BaseMatrixHandler):
 
     async def handle_read_receipt(self, user: 'u.User', portal: 'po.Portal', event_id: EventID,
                                   data: SingleReceiptEventContent) -> None:
+        await portal.handle_read_receipt(event_id, data)
+
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
         if not message:
             return

+ 146 - 4
mautrix_signal/portal.py

@@ -36,12 +36,14 @@ from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, Format,
                            MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
                            TextMessageEventContent, ImageInfo, VideoInfo, FileInfo, AudioInfo,
-                           PowerLevelStateEventContent, UserID)
+                           PowerLevelStateEventContent, UserID, SingleReceiptEventContent)
+from mautrix.util.format_duration import format_duration
 from mautrix.util.bridge_state import BridgeStateEvent
 from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus
 from mautrix.errors import MatrixError, MForbidden, IntentError
 
-from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
+from .db import (Portal as DBPortal, Message as DBMessage, Reaction as DBReaction,
+                 DisappearingMessage)
 from .config import Config
 from .formatter import matrix_to_signal, signal_to_matrix
 from .util import id_to_str
@@ -80,6 +82,7 @@ class Portal(DBPortal, BasePortal):
     signal: 's.SignalHandler'
     az: AppService
     private_chat_portal_meta: bool
+    expiration_time: Optional[int]
 
     _main_intent: Optional[IntentAPI]
     _create_room_lock: asyncio.Lock
@@ -88,14 +91,16 @@ class Portal(DBPortal, BasePortal):
     _reaction_lock: asyncio.Lock
     _pending_members: Optional[Set[UUID]]
     _relay_user: Optional['u.User']
+    _expiration_lock: asyncio.Lock
 
     def __init__(self, chat_id: Union[GroupID, Address], receiver: str,
                  mxid: Optional[RoomID] = None, name: Optional[str] = None,
                  avatar_hash: Optional[str] = None, avatar_url: Optional[ContentURI] = None,
                  name_set: bool = False, avatar_set: bool = False, revision: int = 0,
-                 encrypted: bool = False, relay_user_id: Optional[UserID] = None) -> None:
+                 encrypted: bool = False, relay_user_id: Optional[UserID] = None,
+                 expiration_time: Optional[int] = None) -> None:
         super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url,
-                         name_set, avatar_set, revision, encrypted, relay_user_id)
+                         name_set, avatar_set, revision, encrypted, relay_user_id, expiration_time)
         self._create_room_lock = asyncio.Lock()
         self.log = self.log.getChild(self.chat_id_str)
         self._main_intent = None
@@ -105,6 +110,7 @@ class Portal(DBPortal, BasePortal):
         self._reaction_lock = asyncio.Lock()
         self._pending_members = None
         self._relay_user = None
+        self._expiration_lock = asyncio.Lock()
 
     @property
     def has_relay(self) -> bool:
@@ -150,6 +156,14 @@ class Portal(DBPortal, BasePortal):
         BasePortal.bridge = bridge
         cls.private_chat_portal_meta = cls.config["bridge.private_chat_portal_meta"]
 
+    @classmethod
+    async def start_disappearing_message_expirations(cls):
+        await asyncio.gather(*(
+            cls._expire_event(dm.room_id, dm.mxid, restart=True)
+            for dm in await DisappearingMessage.get_all()
+            if dm.expiration_ts
+        ))
+
     # region Misc
 
     async def _send_delivery_receipt(self, event_id: EventID) -> None:
@@ -328,6 +342,18 @@ class Portal(DBPortal, BasePortal):
                 except FileNotFoundError:
                     pass
 
+            # Handle disappearing messages
+            if (
+                self.expiration_time
+                and (
+                    self.is_direct
+                    or self.config["signal.enable_disappearing_messages_in_groups"]
+                )
+            ):
+                dm = DisappearingMessage(self.mxid, event_id, self.expiration_time)
+                await dm.insert()
+                await Portal._expire_event(dm.room_id, dm.mxid)
+
     async def handle_matrix_reaction(self, sender: 'u.User', event_id: EventID,
                                      reacting_to: EventID, emoji: str) -> None:
         if not await sender.is_logged_in():
@@ -522,6 +548,89 @@ class Portal(DBPortal, BasePortal):
             except FileNotFoundError:
                 pass
 
+    @classmethod
+    async def _expire_event(cls, room_id: RoomID, event_id: EventID, restart: bool = False):
+        """
+        Schedule a task to expire a an event. This should only be called once the message has been
+        read, as the timer for redaction will start immediately, and there is no (supported)
+        mechanism to stop the countdown, even after bridge restart.
+
+        If there is already an expiration event for the given ``room_id`` and ``event_id``, it will
+        not schedule a new task.
+        """
+        portal = await cls.get_by_mxid(room_id)
+        if not portal:
+            raise AttributeError(f"No portal found for {room_id}")
+
+        # Need a lock around this critical section to make sure that we know if a task has been
+        # created for this particular (room_id, event_id) combination.
+        async with portal._expiration_lock:
+            if (
+                not portal.is_direct and not
+                cls.config["signal.enable_disappearing_messages_in_groups"]
+            ):
+                portal.log.debug(
+                    "Not expiring event in group message since "
+                    "signal.enable_disappearing_messages_in_groups is not enabled."
+                )
+                await DisappearingMessage.delete(room_id, event_id)
+                return
+
+            disappearing_message = await DisappearingMessage.get(room_id, event_id)
+            if disappearing_message is None:
+                return
+            wait = disappearing_message.expiration_seconds
+            now = time.time()
+            # If there is an expiration_ts, then there's already a task going, or it's a restart.
+            # If it's a restart, then restart the countdown. This is fairly likely to occur if the
+            # disappearance timeout is weeks.
+            if disappearing_message.expiration_ts:
+                if not restart:
+                    portal.log.debug(f"Expiration task already exists for {event_id} in {room_id}")
+                    return
+                portal.log.debug(f"Resuming expiration for {event_id} in {room_id}")
+                wait = (disappearing_message.expiration_ts / 1000) - now
+            if wait < 0:
+                wait = 0
+
+            # Spawn the actual expiration task.
+            asyncio.create_task(cls._expire_event_task(portal, event_id, wait))
+
+            # Set the expiration_ts only after we have actually created the expiration task.
+            if not disappearing_message.expiration_ts:
+                disappearing_message.expiration_ts = int((now + wait) * 1000)
+                await disappearing_message.update()
+
+
+    @classmethod
+    async def _expire_event_task(cls, portal: 'Portal', event_id: EventID, wait: float):
+        portal.log.debug(f"Redacting {event_id} in {wait} seconds")
+        await asyncio.sleep(wait)
+
+        async with portal._expiration_lock:
+            if not await DisappearingMessage.get(portal.mxid, event_id):
+                portal.log.debug(
+                    f"{event_id} no longer in disappearing messages list, not redacting"
+                )
+                return
+
+            portal.log.debug(f"Redacting {event_id} because it was expired")
+            try:
+                await portal.main_intent.redact(portal.mxid, event_id)
+                portal.log.debug(f"Redacted {event_id} successfully")
+            except Exception as e:
+                portal.log.warning(f"Redacting expired event {event_id} failed", e)
+            finally:
+                await DisappearingMessage.delete(portal.mxid, event_id)
+
+    async def handle_read_receipt(self, event_id: EventID, data: SingleReceiptEventContent):
+        # Start the redaction timers for all of the disappearing messages in the room when the user
+        # reads the room. This is the behavior of the Signal clients.
+        await asyncio.gather(*(
+            Portal._expire_event(dm.room_id, dm.mxid)
+            for dm in await DisappearingMessage.get_all_for_room(self.mxid)
+        ))
+
     # endregion
     # region Signal event handling
 
@@ -621,6 +730,22 @@ class Portal(DBPortal, BasePortal):
                                            timestamps=[message.timestamp])
             await self._send_delivery_receipt(event_id)
             self.log.debug(f"Handled Signal message {message.timestamp} -> {event_id}")
+
+            if (
+                message.expires_in_seconds
+                and (
+                    self.is_direct
+                    or self.config["signal.enable_disappearing_messages_in_groups"]
+                )
+            ):
+                disappearing_message = DisappearingMessage(
+                    self.mxid, event_id, message.expires_in_seconds
+                )
+                await disappearing_message.insert()
+                self.log.debug(
+                    f"{event_id} set to be redacted {message.expires_in_seconds} seconds after "
+                    "room is read"
+                )
         else:
             self.log.debug(f"Didn't get event ID for {message.timestamp}")
 
@@ -853,6 +978,23 @@ class Portal(DBPortal, BasePortal):
             await self.update_bridge_info()
             await self.update()
 
+    async def update_expires_in_seconds(self, sender: 'p.Puppet', expires_in_seconds: int) -> None:
+        if expires_in_seconds == 0:
+            expires_in_seconds = None
+        if self.expiration_time == expires_in_seconds:
+            return
+
+        assert self.mxid
+        self.expiration_time = expires_in_seconds
+        await self.update()
+
+        time_str = "Off" if expires_in_seconds is None else format_duration(expires_in_seconds)
+        await self.main_intent.send_notice(
+            self.mxid,
+            html=f'<a href="https://matrix.to/#/{sender.mxid}">{sender.name}</a> set the '
+            f'disappearing message timer to {time_str}.'
+        )
+
     async def update_puppet_avatar(self, new_hash: str, avatar_url: ContentURI) -> None:
         if not self.encrypted and not self.private_chat_portal_meta:
             return

+ 2 - 0
mautrix_signal/signal.py

@@ -113,6 +113,8 @@ class SignalHandler(SignaldClient):
             await portal.update_info(user, msg.group)
         if msg.remote_delete:
             await portal.handle_signal_delete(sender, msg.remote_delete.target_sent_timestamp)
+        if msg.expires_in_seconds is not None:
+            await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
 
     @staticmethod
     async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None:

+ 1 - 1
requirements.txt

@@ -4,5 +4,5 @@ commonmark>=0.8,<0.10
 aiohttp>=3,<4
 yarl>=1,<2
 attrs>=19.1
-mautrix>=0.13rc1,<0.14
+mautrix>=0.13.0,<0.14
 asyncpg>=0.20,<0.26