Browse Source

Merge pull request #182 from mautrix/sumner/bri-1204-support-disappearing-messages-in-signal

disappearing messages: add support for both inbound and outbound
Sumner Evans 3 years ago
parent
commit
e13206318b

+ 1 - 0
mautrix_signal/__main__.py

@@ -74,6 +74,7 @@ class SignalBridge(Bridge):
         self.add_startup_actions(self.signal.start())
         self.add_startup_actions(self.signal.start())
         await super().start()
         await super().start()
         self.periodic_sync_task = asyncio.create_task(self._periodic_sync_loop())
         self.periodic_sync_task = asyncio.create_task(self._periodic_sync_loop())
+        asyncio.create_task(Portal.start_disappearing_message_expirations())
 
 
     @staticmethod
     @staticmethod
     async def _actual_periodic_sync_loop(log: logging.Logger, interval: int) -> None:
     async def _actual_periodic_sync_loop(log: logging.Logger, interval: int) -> None:

+ 1 - 0
mautrix_signal/config.py

@@ -54,6 +54,7 @@ class Config(BaseBridgeConfig):
         copy("signal.delete_unknown_accounts_on_start")
         copy("signal.delete_unknown_accounts_on_start")
         copy("signal.remove_file_after_handling")
         copy("signal.remove_file_after_handling")
         copy("signal.registration_enabled")
         copy("signal.registration_enabled")
+        copy("signal.enable_disappearing_messages_in_groups")
 
 
         copy("metrics.enabled")
         copy("metrics.enabled")
         copy("metrics.listen_port")
         copy("metrics.listen_port")

+ 4 - 2
mautrix_signal/db/__init__.py

@@ -3,6 +3,7 @@ import sqlite3
 import uuid
 import uuid
 
 
 from .upgrade import upgrade_table
 from .upgrade import upgrade_table
+from .disappearing_message import DisappearingMessage
 from .user import User
 from .user import User
 from .puppet import Puppet
 from .puppet import Puppet
 from .portal import Portal
 from .portal import Portal
@@ -11,7 +12,7 @@ from .reaction import Reaction
 
 
 
 
 def init(db: Database) -> None:
 def init(db: Database) -> None:
-    for table in (User, Puppet, Portal, Message, Reaction):
+    for table in (User, Puppet, Portal, Message, Reaction, DisappearingMessage):
         table.db = db
         table.db = db
 
 
 
 
@@ -19,4 +20,5 @@ def init(db: Database) -> None:
 sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
 sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
 sqlite3.register_converter("UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b))
 sqlite3.register_converter("UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b))
 
 
-__all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction"]
+__all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction",
+           "DisappearingMessage"]

+ 90 - 0
mautrix_signal/db/disappearing_message.py

@@ -0,0 +1,90 @@
+# mautrix-signal - A Matrix-Signal puppeting bridge
+# Copyright (C) 2021 Sumner Evans
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from typing import ClassVar, List, Optional, TYPE_CHECKING
+
+from attr import dataclass
+import asyncpg
+
+from mautrix.types import RoomID, EventID
+from mautrix.util.async_db import Database
+
+fake_db = Database.create("") if TYPE_CHECKING else None
+
+
+@dataclass
+class DisappearingMessage:
+    db: ClassVar[Database] = fake_db
+
+    room_id: RoomID
+    mxid: EventID
+    expiration_seconds: int
+    expiration_ts: Optional[int] = None
+
+    async def insert(self) -> None:
+        q = """
+        INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts)
+        VALUES ($1, $2, $3, $4)
+        """
+        await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
+                              self.expiration_ts)
+
+    async def update(self) -> None:
+        q = """
+        UPDATE disappearing_message
+        SET expiration_seconds=$3, expiration_ts=$4
+        WHERE room_id=$1 AND mxid=$2
+        """
+        try:
+            await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
+                              self.expiration_ts)
+        except Exception as e:
+            print(e)
+
+    @classmethod
+    async def delete(cls, room_id: RoomID, event_id: EventID) -> None:
+        q = "DELETE from disappearing_message WHERE room_id=$1 AND mxid=$2"
+        await cls.db.execute(q, room_id, event_id)
+
+    @classmethod
+    def _from_row(cls, row: asyncpg.Record) -> "DisappearingMessage":
+        return cls(**row)
+
+    @classmethod
+    async def get(cls, room_id: RoomID, event_id: EventID) -> Optional["DisappearingMessage"]:
+        q = """
+        SELECT room_id, mxid, expiration_seconds, expiration_ts
+          FROM disappearing_message
+         WHERE room_id = $1
+           AND mxid = $2
+        """
+        try:
+            return cls._from_row(await cls.db.fetchrow(q, room_id, event_id))
+        except Exception:
+            return None
+
+    @classmethod
+    async def get_all(cls) -> List["DisappearingMessage"]:
+        q = "SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message"
+        return [cls._from_row(r) for r in await cls.db.fetch(q)]
+
+    @classmethod
+    async def get_all_for_room(cls, room_id: RoomID) -> List["DisappearingMessage"]:
+        q = """
+        SELECT room_id, mxid, expiration_seconds, expiration_ts
+          FROM disappearing_message
+         WHERE room_id = $1
+        """
+        return [cls._from_row(r) for r in await cls.db.fetch(q, room_id)]

+ 16 - 11
mautrix_signal/db/portal.py

@@ -42,6 +42,7 @@ class Portal:
     revision: int
     revision: int
     encrypted: bool
     encrypted: bool
     relay_user_id: Optional[UserID]
     relay_user_id: Optional[UserID]
+    expiration_time: Optional[int]
 
 
     @property
     @property
     def chat_id_str(self) -> str:
     def chat_id_str(self) -> str:
@@ -49,19 +50,23 @@ class Portal:
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
         q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
         q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
-             "                    name_set, avatar_set, revision, encrypted, relay_user_id) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)")
+             "                    name_set, avatar_set, revision, encrypted, relay_user_id, "
+             "                    expiration_time) "
+             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)")
         await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
         await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
                               self.avatar_hash, self.avatar_url, self.name_set, self.avatar_set,
                               self.avatar_hash, self.avatar_url, self.name_set, self.avatar_set,
-                              self.revision, self.encrypted, self.relay_user_id)
+                              self.revision, self.encrypted, self.relay_user_id,
+                              self.expiration_time)
 
 
     async def update(self) -> None:
     async def update(self) -> None:
         q = ("UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5, "
         q = ("UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5, "
-             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9 "
-             "WHERE chat_id=$10 AND receiver=$11")
+             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9, "
+             "                  expiration_time=$10"
+             "WHERE chat_id=$11 AND receiver=$12")
         await self.db.execute(q, self.mxid, self.name, self.avatar_hash, self.avatar_url,
         await self.db.execute(q, self.mxid, self.name, self.avatar_hash, self.avatar_url,
                               self.name_set, self.avatar_set, self.revision, self.encrypted,
                               self.name_set, self.avatar_set, self.revision, self.encrypted,
-                              self.relay_user_id, self.chat_id_str, self.receiver)
+                              self.relay_user_id, self.expiration_time, self.chat_id_str,
+                              self.receiver)
 
 
     @classmethod
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
@@ -74,7 +79,7 @@ class Portal:
     @classmethod
     @classmethod
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE mxid=$1")
              "FROM portal WHERE mxid=$1")
         row = await cls.db.fetchrow(q, mxid)
         row = await cls.db.fetchrow(q, mxid)
         if not row:
         if not row:
@@ -85,7 +90,7 @@ class Portal:
     async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
     async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
                              ) -> Optional['Portal']:
                              ) -> Optional['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
         if not row:
@@ -95,7 +100,7 @@ class Portal:
     @classmethod
     @classmethod
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE receiver=$1")
              "FROM portal WHERE receiver=$1")
         rows = await cls.db.fetch(q, receiver)
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
@@ -103,7 +108,7 @@ class Portal:
     @classmethod
     @classmethod
     async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
     async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE chat_id=$1 AND receiver<>''")
              "FROM portal WHERE chat_id=$1 AND receiver<>''")
         rows = await cls.db.fetch(q, other_user.best_identifier)
         rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
@@ -111,7 +116,7 @@ class Portal:
     @classmethod
     @classmethod
     async def all_with_room(cls) -> List['Portal']:
     async def all_with_room(cls) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE mxid IS NOT NULL")
              "FROM portal WHERE mxid IS NOT NULL")
         rows = await cls.db.fetch(q)
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]

+ 13 - 0
mautrix_signal/db/upgrade.py

@@ -175,3 +175,16 @@ async def upgrade_v6(conn: Connection) -> None:
 @upgrade_table.register(description="Add relay user field to portal table")
 @upgrade_table.register(description="Add relay user field to portal table")
 async def upgrade_v7(conn: Connection) -> None:
 async def upgrade_v7(conn: Connection) -> None:
     await conn.execute("ALTER TABLE portal ADD COLUMN relay_user_id TEXT")
     await conn.execute("ALTER TABLE portal ADD COLUMN relay_user_id TEXT")
+
+
+@upgrade_table.register(description="Add support for disappearing messages")
+async def upgrade_v8(conn: Connection) -> None:
+    await conn.execute("""CREATE TABLE disappearing_message (
+        room_id             TEXT,
+        mxid                TEXT,
+        expiration_seconds  BIGINT,
+        expiration_ts       BIGINT,
+
+        PRIMARY KEY (room_id, mxid)
+    )""")
+    await conn.execute("ALTER TABLE portal ADD COLUMN expiration_time BIGINT")

+ 4 - 0
mautrix_signal/example-config.yaml

@@ -105,6 +105,10 @@ signal:
     remove_file_after_handling: true
     remove_file_after_handling: true
     # Whether or not users can register a primary device
     # Whether or not users can register a primary device
     registration_enabled: true
     registration_enabled: true
+    # Whether or not to enable disappearing messages in groups. If enabled, then the expiration
+    # time of the messages will be determined by the first users to read the message, rather
+    # than individually. If the bridge has a single user, this can be turned on safely.
+    enable_disappearing_messages_in_groups: false
 
 
 # Bridge config
 # Bridge config
 bridge:
 bridge:

+ 4 - 0
mautrix_signal/matrix.py

@@ -20,6 +20,8 @@ from mautrix.types import (Event, ReactionEvent, StateEvent, RoomID, EventID, Us
                            ReactionEventContent, RelationType, EventType, ReceiptEvent,
                            ReactionEventContent, RelationType, EventType, ReceiptEvent,
                            PresenceEvent, RedactionEvent, SingleReceiptEventContent)
                            PresenceEvent, RedactionEvent, SingleReceiptEventContent)
 
 
+from mautrix_signal.db.disappearing_message import DisappearingMessage
+
 from .db import Message as DBMessage
 from .db import Message as DBMessage
 from . import portal as po, user as u, signal as s
 from . import portal as po, user as u, signal as s
 
 
@@ -102,6 +104,8 @@ class MatrixHandler(BaseMatrixHandler):
 
 
     async def handle_read_receipt(self, user: 'u.User', portal: 'po.Portal', event_id: EventID,
     async def handle_read_receipt(self, user: 'u.User', portal: 'po.Portal', event_id: EventID,
                                   data: SingleReceiptEventContent) -> None:
                                   data: SingleReceiptEventContent) -> None:
+        await portal.handle_read_receipt(event_id, data)
+
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
         if not message:
         if not message:
             return
             return

+ 146 - 4
mautrix_signal/portal.py

@@ -36,12 +36,14 @@ from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, Format,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, Format,
                            MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
                            MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
                            TextMessageEventContent, ImageInfo, VideoInfo, FileInfo, AudioInfo,
                            TextMessageEventContent, ImageInfo, VideoInfo, FileInfo, AudioInfo,
-                           PowerLevelStateEventContent, UserID)
+                           PowerLevelStateEventContent, UserID, SingleReceiptEventContent)
+from mautrix.util.format_duration import format_duration
 from mautrix.util.bridge_state import BridgeStateEvent
 from mautrix.util.bridge_state import BridgeStateEvent
 from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus
 from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus
 from mautrix.errors import MatrixError, MForbidden, IntentError
 from mautrix.errors import MatrixError, MForbidden, IntentError
 
 
-from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
+from .db import (Portal as DBPortal, Message as DBMessage, Reaction as DBReaction,
+                 DisappearingMessage)
 from .config import Config
 from .config import Config
 from .formatter import matrix_to_signal, signal_to_matrix
 from .formatter import matrix_to_signal, signal_to_matrix
 from .util import id_to_str
 from .util import id_to_str
@@ -80,6 +82,7 @@ class Portal(DBPortal, BasePortal):
     signal: 's.SignalHandler'
     signal: 's.SignalHandler'
     az: AppService
     az: AppService
     private_chat_portal_meta: bool
     private_chat_portal_meta: bool
+    expiration_time: Optional[int]
 
 
     _main_intent: Optional[IntentAPI]
     _main_intent: Optional[IntentAPI]
     _create_room_lock: asyncio.Lock
     _create_room_lock: asyncio.Lock
@@ -88,14 +91,16 @@ class Portal(DBPortal, BasePortal):
     _reaction_lock: asyncio.Lock
     _reaction_lock: asyncio.Lock
     _pending_members: Optional[Set[UUID]]
     _pending_members: Optional[Set[UUID]]
     _relay_user: Optional['u.User']
     _relay_user: Optional['u.User']
+    _expiration_lock: asyncio.Lock
 
 
     def __init__(self, chat_id: Union[GroupID, Address], receiver: str,
     def __init__(self, chat_id: Union[GroupID, Address], receiver: str,
                  mxid: Optional[RoomID] = None, name: Optional[str] = None,
                  mxid: Optional[RoomID] = None, name: Optional[str] = None,
                  avatar_hash: Optional[str] = None, avatar_url: Optional[ContentURI] = None,
                  avatar_hash: Optional[str] = None, avatar_url: Optional[ContentURI] = None,
                  name_set: bool = False, avatar_set: bool = False, revision: int = 0,
                  name_set: bool = False, avatar_set: bool = False, revision: int = 0,
-                 encrypted: bool = False, relay_user_id: Optional[UserID] = None) -> None:
+                 encrypted: bool = False, relay_user_id: Optional[UserID] = None,
+                 expiration_time: Optional[int] = None) -> None:
         super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url,
         super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url,
-                         name_set, avatar_set, revision, encrypted, relay_user_id)
+                         name_set, avatar_set, revision, encrypted, relay_user_id, expiration_time)
         self._create_room_lock = asyncio.Lock()
         self._create_room_lock = asyncio.Lock()
         self.log = self.log.getChild(self.chat_id_str)
         self.log = self.log.getChild(self.chat_id_str)
         self._main_intent = None
         self._main_intent = None
@@ -105,6 +110,7 @@ class Portal(DBPortal, BasePortal):
         self._reaction_lock = asyncio.Lock()
         self._reaction_lock = asyncio.Lock()
         self._pending_members = None
         self._pending_members = None
         self._relay_user = None
         self._relay_user = None
+        self._expiration_lock = asyncio.Lock()
 
 
     @property
     @property
     def has_relay(self) -> bool:
     def has_relay(self) -> bool:
@@ -150,6 +156,14 @@ class Portal(DBPortal, BasePortal):
         BasePortal.bridge = bridge
         BasePortal.bridge = bridge
         cls.private_chat_portal_meta = cls.config["bridge.private_chat_portal_meta"]
         cls.private_chat_portal_meta = cls.config["bridge.private_chat_portal_meta"]
 
 
+    @classmethod
+    async def start_disappearing_message_expirations(cls):
+        await asyncio.gather(*(
+            cls._expire_event(dm.room_id, dm.mxid, restart=True)
+            for dm in await DisappearingMessage.get_all()
+            if dm.expiration_ts
+        ))
+
     # region Misc
     # region Misc
 
 
     async def _send_delivery_receipt(self, event_id: EventID) -> None:
     async def _send_delivery_receipt(self, event_id: EventID) -> None:
@@ -328,6 +342,18 @@ class Portal(DBPortal, BasePortal):
                 except FileNotFoundError:
                 except FileNotFoundError:
                     pass
                     pass
 
 
+            # Handle disappearing messages
+            if (
+                self.expiration_time
+                and (
+                    self.is_direct
+                    or self.config["signal.enable_disappearing_messages_in_groups"]
+                )
+            ):
+                dm = DisappearingMessage(self.mxid, event_id, self.expiration_time)
+                await dm.insert()
+                await Portal._expire_event(dm.room_id, dm.mxid)
+
     async def handle_matrix_reaction(self, sender: 'u.User', event_id: EventID,
     async def handle_matrix_reaction(self, sender: 'u.User', event_id: EventID,
                                      reacting_to: EventID, emoji: str) -> None:
                                      reacting_to: EventID, emoji: str) -> None:
         if not await sender.is_logged_in():
         if not await sender.is_logged_in():
@@ -522,6 +548,89 @@ class Portal(DBPortal, BasePortal):
             except FileNotFoundError:
             except FileNotFoundError:
                 pass
                 pass
 
 
+    @classmethod
+    async def _expire_event(cls, room_id: RoomID, event_id: EventID, restart: bool = False):
+        """
+        Schedule a task to expire a an event. This should only be called once the message has been
+        read, as the timer for redaction will start immediately, and there is no (supported)
+        mechanism to stop the countdown, even after bridge restart.
+
+        If there is already an expiration event for the given ``room_id`` and ``event_id``, it will
+        not schedule a new task.
+        """
+        portal = await cls.get_by_mxid(room_id)
+        if not portal:
+            raise AttributeError(f"No portal found for {room_id}")
+
+        # Need a lock around this critical section to make sure that we know if a task has been
+        # created for this particular (room_id, event_id) combination.
+        async with portal._expiration_lock:
+            if (
+                not portal.is_direct and not
+                cls.config["signal.enable_disappearing_messages_in_groups"]
+            ):
+                portal.log.debug(
+                    "Not expiring event in group message since "
+                    "signal.enable_disappearing_messages_in_groups is not enabled."
+                )
+                await DisappearingMessage.delete(room_id, event_id)
+                return
+
+            disappearing_message = await DisappearingMessage.get(room_id, event_id)
+            if disappearing_message is None:
+                return
+            wait = disappearing_message.expiration_seconds
+            now = time.time()
+            # If there is an expiration_ts, then there's already a task going, or it's a restart.
+            # If it's a restart, then restart the countdown. This is fairly likely to occur if the
+            # disappearance timeout is weeks.
+            if disappearing_message.expiration_ts:
+                if not restart:
+                    portal.log.debug(f"Expiration task already exists for {event_id} in {room_id}")
+                    return
+                portal.log.debug(f"Resuming expiration for {event_id} in {room_id}")
+                wait = (disappearing_message.expiration_ts / 1000) - now
+            if wait < 0:
+                wait = 0
+
+            # Spawn the actual expiration task.
+            asyncio.create_task(cls._expire_event_task(portal, event_id, wait))
+
+            # Set the expiration_ts only after we have actually created the expiration task.
+            if not disappearing_message.expiration_ts:
+                disappearing_message.expiration_ts = int((now + wait) * 1000)
+                await disappearing_message.update()
+
+
+    @classmethod
+    async def _expire_event_task(cls, portal: 'Portal', event_id: EventID, wait: float):
+        portal.log.debug(f"Redacting {event_id} in {wait} seconds")
+        await asyncio.sleep(wait)
+
+        async with portal._expiration_lock:
+            if not await DisappearingMessage.get(portal.mxid, event_id):
+                portal.log.debug(
+                    f"{event_id} no longer in disappearing messages list, not redacting"
+                )
+                return
+
+            portal.log.debug(f"Redacting {event_id} because it was expired")
+            try:
+                await portal.main_intent.redact(portal.mxid, event_id)
+                portal.log.debug(f"Redacted {event_id} successfully")
+            except Exception as e:
+                portal.log.warning(f"Redacting expired event {event_id} failed", e)
+            finally:
+                await DisappearingMessage.delete(portal.mxid, event_id)
+
+    async def handle_read_receipt(self, event_id: EventID, data: SingleReceiptEventContent):
+        # Start the redaction timers for all of the disappearing messages in the room when the user
+        # reads the room. This is the behavior of the Signal clients.
+        await asyncio.gather(*(
+            Portal._expire_event(dm.room_id, dm.mxid)
+            for dm in await DisappearingMessage.get_all_for_room(self.mxid)
+        ))
+
     # endregion
     # endregion
     # region Signal event handling
     # region Signal event handling
 
 
@@ -621,6 +730,22 @@ class Portal(DBPortal, BasePortal):
                                            timestamps=[message.timestamp])
                                            timestamps=[message.timestamp])
             await self._send_delivery_receipt(event_id)
             await self._send_delivery_receipt(event_id)
             self.log.debug(f"Handled Signal message {message.timestamp} -> {event_id}")
             self.log.debug(f"Handled Signal message {message.timestamp} -> {event_id}")
+
+            if (
+                message.expires_in_seconds
+                and (
+                    self.is_direct
+                    or self.config["signal.enable_disappearing_messages_in_groups"]
+                )
+            ):
+                disappearing_message = DisappearingMessage(
+                    self.mxid, event_id, message.expires_in_seconds
+                )
+                await disappearing_message.insert()
+                self.log.debug(
+                    f"{event_id} set to be redacted {message.expires_in_seconds} seconds after "
+                    "room is read"
+                )
         else:
         else:
             self.log.debug(f"Didn't get event ID for {message.timestamp}")
             self.log.debug(f"Didn't get event ID for {message.timestamp}")
 
 
@@ -853,6 +978,23 @@ class Portal(DBPortal, BasePortal):
             await self.update_bridge_info()
             await self.update_bridge_info()
             await self.update()
             await self.update()
 
 
+    async def update_expires_in_seconds(self, sender: 'p.Puppet', expires_in_seconds: int) -> None:
+        if expires_in_seconds == 0:
+            expires_in_seconds = None
+        if self.expiration_time == expires_in_seconds:
+            return
+
+        assert self.mxid
+        self.expiration_time = expires_in_seconds
+        await self.update()
+
+        time_str = "Off" if expires_in_seconds is None else format_duration(expires_in_seconds)
+        await self.main_intent.send_notice(
+            self.mxid,
+            html=f'<a href="https://matrix.to/#/{sender.mxid}">{sender.name}</a> set the '
+            f'disappearing message timer to {time_str}.'
+        )
+
     async def update_puppet_avatar(self, new_hash: str, avatar_url: ContentURI) -> None:
     async def update_puppet_avatar(self, new_hash: str, avatar_url: ContentURI) -> None:
         if not self.encrypted and not self.private_chat_portal_meta:
         if not self.encrypted and not self.private_chat_portal_meta:
             return
             return

+ 2 - 0
mautrix_signal/signal.py

@@ -113,6 +113,8 @@ class SignalHandler(SignaldClient):
             await portal.update_info(user, msg.group)
             await portal.update_info(user, msg.group)
         if msg.remote_delete:
         if msg.remote_delete:
             await portal.handle_signal_delete(sender, msg.remote_delete.target_sent_timestamp)
             await portal.handle_signal_delete(sender, msg.remote_delete.target_sent_timestamp)
+        if msg.expires_in_seconds is not None:
+            await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
 
 
     @staticmethod
     @staticmethod
     async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None:
     async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None:

+ 1 - 1
requirements.txt

@@ -4,5 +4,5 @@ commonmark>=0.8,<0.10
 aiohttp>=3,<4
 aiohttp>=3,<4
 yarl>=1,<2
 yarl>=1,<2
 attrs>=19.1
 attrs>=19.1
-mautrix>=0.13rc1,<0.14
+mautrix>=0.13.0,<0.14
 asyncpg>=0.20,<0.26
 asyncpg>=0.20,<0.26