Browse Source

disappearing messages: handle inbound and outbound disappearing messages

When a message is sent or recieved in a Signal chat that has
disappearing messages enabled, it will automatically be redacted by the
bridge after the time configured on the room. However, note that the
countdown timer will only start once the room has been read.

The main mechanism for this is using async functions that just wait the
configured number of seconds before redacting. However, there we also
store all of the state necessary for determining when to redact a
message in the database in case a restart occurs. In the event of a
restart, we can resume waiting for the expiration (or redact immediately
if we are past the expiration timestamp).

This change also adds m.notice notifications for the user when the
disappearing messages setting changes on the Signal chat.
Sumner Evans 3 years ago
parent
commit
f62f5252a9

+ 1 - 1
mautrix_signal/__main__.py

@@ -68,7 +68,7 @@ class SignalBridge(Bridge):
     async def start(self) -> None:
     async def start(self) -> None:
         User.init_cls(self)
         User.init_cls(self)
         self.add_startup_actions(Puppet.init_cls(self))
         self.add_startup_actions(Puppet.init_cls(self))
-        Portal.init_cls(self)
+        self.add_startup_actions(Portal.init_cls(self))
         if self.config["bridge.resend_bridge_info"]:
         if self.config["bridge.resend_bridge_info"]:
             self.add_startup_actions(self.resend_bridge_info())
             self.add_startup_actions(self.resend_bridge_info())
         self.add_startup_actions(self.signal.start())
         self.add_startup_actions(self.signal.start())

+ 4 - 2
mautrix_signal/db/__init__.py

@@ -3,6 +3,7 @@ import sqlite3
 import uuid
 import uuid
 
 
 from .upgrade import upgrade_table
 from .upgrade import upgrade_table
+from .disappearing_message import DisappearingMessage
 from .user import User
 from .user import User
 from .puppet import Puppet
 from .puppet import Puppet
 from .portal import Portal
 from .portal import Portal
@@ -11,7 +12,7 @@ from .reaction import Reaction
 
 
 
 
 def init(db: Database) -> None:
 def init(db: Database) -> None:
-    for table in (User, Puppet, Portal, Message, Reaction):
+    for table in (User, Puppet, Portal, Message, Reaction, DisappearingMessage):
         table.db = db
         table.db = db
 
 
 
 
@@ -19,4 +20,5 @@ def init(db: Database) -> None:
 sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
 sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
 sqlite3.register_converter("UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b))
 sqlite3.register_converter("UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b))
 
 
-__all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction"]
+__all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction",
+           "DisappearingMessage"]

+ 94 - 0
mautrix_signal/db/disappearing_message.py

@@ -0,0 +1,94 @@
+# mautrix-signal - A Matrix-Signal puppeting bridge
+# Copyright (C) 2021 Sumner Evans
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from typing import ClassVar, List, Optional, TYPE_CHECKING
+
+from attr import dataclass
+import asyncpg
+
+from mautrix.types import RoomID, EventID
+from mautrix.util.async_db import Database
+
+fake_db = Database.create("") if TYPE_CHECKING else None
+
+
+@dataclass
+class DisappearingMessage:
+    db: ClassVar[Database] = fake_db
+
+    room_id: RoomID
+    mxid: EventID
+    expiration_seconds: int
+    expiration_ts: Optional[int] = None
+
+    async def insert(self) -> None:
+        q = """
+        INSERT INTO disappearing_messages (room_id, mxid, expiration_seconds, expiration_ts)
+        VALUES ($1, $2, $3, $4)
+        """
+        await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
+                              self.expiration_ts)
+
+    async def update(self) -> None:
+        q = """
+        UPDATE disappearing_messages
+        SET expiration_seconds=$3, expiration_ts=$4
+        WHERE room_id=$1 AND mxid=$2
+        """
+        try:
+            await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
+                              self.expiration_ts)
+        except Exception as e:
+            print(e)
+
+    @classmethod
+    async def delete(cls, room_id: RoomID, event_id: EventID) -> None:
+        q = "DELETE from disappearing_messages WHERE room_id=$1 AND mxid=$2"
+        await cls.db.execute(q, room_id, event_id)
+
+    @classmethod
+    async def delete_all(cls, room_id: RoomID) -> None:
+        await cls.db.execute("DELETE FROM message WHERE room_id=$1", room_id)
+
+    @classmethod
+    def _from_row(cls, row: asyncpg.Record) -> "DisappearingMessage":
+        return cls(**row)
+
+    @classmethod
+    async def get(cls, room_id: RoomID, event_id: EventID) -> Optional["DisappearingMessage"]:
+        q = """
+        SELECT room_id, mxid, expiration_seconds, expiration_ts
+          FROM disappearing_messages
+         WHERE room_id = $1
+           AND mxid = $2
+        """
+        try:
+            return cls._from_row(await cls.db.fetchrow(q, room_id, event_id))
+        except Exception:
+            return None
+
+    @classmethod
+    async def get_all(cls) -> List["DisappearingMessage"]:
+        q = "SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_messages"
+        return [cls._from_row(r) for r in await cls.db.fetch(q)]
+
+    @classmethod
+    async def get_all_for_room(cls, room_id: RoomID) -> List["DisappearingMessage"]:
+        q = """
+        SELECT room_id, mxid, expiration_seconds, expiration_ts
+          FROM disappearing_messages
+         WHERE room_id = $1
+        """
+        return [cls._from_row(r) for r in await cls.db.fetch(q, room_id)]

+ 16 - 11
mautrix_signal/db/portal.py

@@ -42,6 +42,7 @@ class Portal:
     revision: int
     revision: int
     encrypted: bool
     encrypted: bool
     relay_user_id: Optional[UserID]
     relay_user_id: Optional[UserID]
+    expiration_time: Optional[int]
 
 
     @property
     @property
     def chat_id_str(self) -> str:
     def chat_id_str(self) -> str:
@@ -49,19 +50,23 @@ class Portal:
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
         q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
         q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
-             "                    name_set, avatar_set, revision, encrypted, relay_user_id) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)")
+             "                    name_set, avatar_set, revision, encrypted, relay_user_id, "
+             "                    expiration_time) "
+             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)")
         await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
         await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
                               self.avatar_hash, self.avatar_url, self.name_set, self.avatar_set,
                               self.avatar_hash, self.avatar_url, self.name_set, self.avatar_set,
-                              self.revision, self.encrypted, self.relay_user_id)
+                              self.revision, self.encrypted, self.relay_user_id,
+                              self.expiration_time)
 
 
     async def update(self) -> None:
     async def update(self) -> None:
         q = ("UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5, "
         q = ("UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5, "
-             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9 "
-             "WHERE chat_id=$10 AND receiver=$11")
+             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9, "
+             "                  expiration_time=$10"
+             "WHERE chat_id=$11 AND receiver=$12")
         await self.db.execute(q, self.mxid, self.name, self.avatar_hash, self.avatar_url,
         await self.db.execute(q, self.mxid, self.name, self.avatar_hash, self.avatar_url,
                               self.name_set, self.avatar_set, self.revision, self.encrypted,
                               self.name_set, self.avatar_set, self.revision, self.encrypted,
-                              self.relay_user_id, self.chat_id_str, self.receiver)
+                              self.relay_user_id, self.expiration_time, self.chat_id_str,
+                              self.receiver)
 
 
     @classmethod
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
@@ -74,7 +79,7 @@ class Portal:
     @classmethod
     @classmethod
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE mxid=$1")
              "FROM portal WHERE mxid=$1")
         row = await cls.db.fetchrow(q, mxid)
         row = await cls.db.fetchrow(q, mxid)
         if not row:
         if not row:
@@ -85,7 +90,7 @@ class Portal:
     async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
     async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
                              ) -> Optional['Portal']:
                              ) -> Optional['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
         if not row:
@@ -95,7 +100,7 @@ class Portal:
     @classmethod
     @classmethod
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE receiver=$1")
              "FROM portal WHERE receiver=$1")
         rows = await cls.db.fetch(q, receiver)
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
@@ -103,7 +108,7 @@ class Portal:
     @classmethod
     @classmethod
     async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
     async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE chat_id=$1 AND receiver<>''")
              "FROM portal WHERE chat_id=$1 AND receiver<>''")
         rows = await cls.db.fetch(q, other_user.best_identifier)
         rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
@@ -111,7 +116,7 @@ class Portal:
     @classmethod
     @classmethod
     async def all_with_room(cls) -> List['Portal']:
     async def all_with_room(cls) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id "
+             "       revision, encrypted, relay_user_id, expiration_time "
              "FROM portal WHERE mxid IS NOT NULL")
              "FROM portal WHERE mxid IS NOT NULL")
         rows = await cls.db.fetch(q)
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]

+ 17 - 0
mautrix_signal/db/upgrade.py

@@ -175,3 +175,20 @@ async def upgrade_v6(conn: Connection) -> None:
 @upgrade_table.register(description="Add relay user field to portal table")
 @upgrade_table.register(description="Add relay user field to portal table")
 async def upgrade_v7(conn: Connection) -> None:
 async def upgrade_v7(conn: Connection) -> None:
     await conn.execute("ALTER TABLE portal ADD COLUMN relay_user_id TEXT")
     await conn.execute("ALTER TABLE portal ADD COLUMN relay_user_id TEXT")
+
+
+@upgrade_table.register(description="Add table for tracking when to redact disappearing messages")
+async def upgrade_v8(conn: Connection) -> None:
+    await conn.execute("""CREATE TABLE disappearing_messages (
+        room_id             TEXT,
+        mxid                TEXT,
+        expiration_seconds  BIGINT,
+        expiration_ts       BIGINT,
+
+        PRIMARY KEY (room_id, mxid)
+    )""")
+
+
+@upgrade_table.register(description="Add expiration_time column to portal table")
+async def upgrade_v9(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE portal ADD COLUMN expiration_time BIGINT")

+ 5 - 0
mautrix_signal/matrix.py

@@ -14,12 +14,15 @@
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from typing import List, Union, TYPE_CHECKING
 from typing import List, Union, TYPE_CHECKING
+from datetime import datetime
 
 
 from mautrix.bridge import BaseMatrixHandler
 from mautrix.bridge import BaseMatrixHandler
 from mautrix.types import (Event, ReactionEvent, StateEvent, RoomID, EventID, UserID, TypingEvent,
 from mautrix.types import (Event, ReactionEvent, StateEvent, RoomID, EventID, UserID, TypingEvent,
                            ReactionEventContent, RelationType, EventType, ReceiptEvent,
                            ReactionEventContent, RelationType, EventType, ReceiptEvent,
                            PresenceEvent, RedactionEvent, SingleReceiptEventContent)
                            PresenceEvent, RedactionEvent, SingleReceiptEventContent)
 
 
+from mautrix_signal.db.disappearing_message import DisappearingMessage
+
 from .db import Message as DBMessage
 from .db import Message as DBMessage
 from . import portal as po, user as u, signal as s
 from . import portal as po, user as u, signal as s
 
 
@@ -102,6 +105,8 @@ class MatrixHandler(BaseMatrixHandler):
 
 
     async def handle_read_receipt(self, user: 'u.User', portal: 'po.Portal', event_id: EventID,
     async def handle_read_receipt(self, user: 'u.User', portal: 'po.Portal', event_id: EventID,
                                   data: SingleReceiptEventContent) -> None:
                                   data: SingleReceiptEventContent) -> None:
+        await portal.handle_read_receipt(event_id, data)
+
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
         if not message:
         if not message:
             return
             return

+ 122 - 5
mautrix_signal/portal.py

@@ -19,6 +19,7 @@ from html import escape as escape_html
 from collections import deque
 from collections import deque
 from uuid import UUID, uuid4
 from uuid import UUID, uuid4
 from string import Template
 from string import Template
+from datetime import datetime
 import mimetypes
 import mimetypes
 import pathlib
 import pathlib
 import hashlib
 import hashlib
@@ -36,12 +37,13 @@ from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, Format,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, Format,
                            MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
                            MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
                            TextMessageEventContent, ImageInfo, VideoInfo, FileInfo, AudioInfo,
                            TextMessageEventContent, ImageInfo, VideoInfo, FileInfo, AudioInfo,
-                           PowerLevelStateEventContent, UserID)
+                           PowerLevelStateEventContent, UserID, SingleReceiptEventContent)
 from mautrix.util.bridge_state import BridgeStateEvent
 from mautrix.util.bridge_state import BridgeStateEvent
 from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus
 from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus
 from mautrix.errors import MatrixError, MForbidden, IntentError
 from mautrix.errors import MatrixError, MForbidden, IntentError
 
 
-from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
+from .db import (Portal as DBPortal, Message as DBMessage, Reaction as DBReaction,
+                 DisappearingMessage)
 from .config import Config
 from .config import Config
 from .formatter import matrix_to_signal, signal_to_matrix
 from .formatter import matrix_to_signal, signal_to_matrix
 from .util import id_to_str
 from .util import id_to_str
@@ -80,6 +82,7 @@ class Portal(DBPortal, BasePortal):
     signal: 's.SignalHandler'
     signal: 's.SignalHandler'
     az: AppService
     az: AppService
     private_chat_portal_meta: bool
     private_chat_portal_meta: bool
+    expiration_time: Optional[int]
 
 
     _main_intent: Optional[IntentAPI]
     _main_intent: Optional[IntentAPI]
     _create_room_lock: asyncio.Lock
     _create_room_lock: asyncio.Lock
@@ -93,9 +96,10 @@ class Portal(DBPortal, BasePortal):
                  mxid: Optional[RoomID] = None, name: Optional[str] = None,
                  mxid: Optional[RoomID] = None, name: Optional[str] = None,
                  avatar_hash: Optional[str] = None, avatar_url: Optional[ContentURI] = None,
                  avatar_hash: Optional[str] = None, avatar_url: Optional[ContentURI] = None,
                  name_set: bool = False, avatar_set: bool = False, revision: int = 0,
                  name_set: bool = False, avatar_set: bool = False, revision: int = 0,
-                 encrypted: bool = False, relay_user_id: Optional[UserID] = None) -> None:
+                 encrypted: bool = False, relay_user_id: Optional[UserID] = None,
+                 expiration_time: Optional[int] = None) -> None:
         super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url,
         super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url,
-                         name_set, avatar_set, revision, encrypted, relay_user_id)
+                         name_set, avatar_set, revision, encrypted, relay_user_id, expiration_time)
         self._create_room_lock = asyncio.Lock()
         self._create_room_lock = asyncio.Lock()
         self.log = self.log.getChild(self.chat_id_str)
         self.log = self.log.getChild(self.chat_id_str)
         self._main_intent = None
         self._main_intent = None
@@ -141,7 +145,7 @@ class Portal(DBPortal, BasePortal):
         self.by_chat_id[(self.chat_id_str, self.receiver)] = self
         self.by_chat_id[(self.chat_id_str, self.receiver)] = self
 
 
     @classmethod
     @classmethod
-    def init_cls(cls, bridge: 'SignalBridge') -> None:
+    async def init_cls(cls, bridge: 'SignalBridge') -> None:
         cls.config = bridge.config
         cls.config = bridge.config
         cls.matrix = bridge.matrix
         cls.matrix = bridge.matrix
         cls.signal = bridge.signal
         cls.signal = bridge.signal
@@ -150,6 +154,10 @@ class Portal(DBPortal, BasePortal):
         BasePortal.bridge = bridge
         BasePortal.bridge = bridge
         cls.private_chat_portal_meta = cls.config["bridge.private_chat_portal_meta"]
         cls.private_chat_portal_meta = cls.config["bridge.private_chat_portal_meta"]
 
 
+        for dm in await DisappearingMessage.get_all():
+            if dm.expiration_ts:
+                asyncio.create_task(cls._expire_event(dm))
+
     # region Misc
     # region Misc
 
 
     async def _send_delivery_receipt(self, event_id: EventID) -> None:
     async def _send_delivery_receipt(self, event_id: EventID) -> None:
@@ -328,6 +336,14 @@ class Portal(DBPortal, BasePortal):
                 except FileNotFoundError:
                 except FileNotFoundError:
                     pass
                     pass
 
 
+            # Handle disappearing messages
+            if self.expiration_time:
+                disappearing_message = DisappearingMessage(
+                    self.mxid, event_id, self.expiration_time
+                )
+                await disappearing_message.insert()
+                asyncio.create_task(Portal._expire_event(disappearing_message))
+
     async def handle_matrix_reaction(self, sender: 'u.User', event_id: EventID,
     async def handle_matrix_reaction(self, sender: 'u.User', event_id: EventID,
                                      reacting_to: EventID, emoji: str) -> None:
                                      reacting_to: EventID, emoji: str) -> None:
         if not await sender.is_logged_in():
         if not await sender.is_logged_in():
@@ -522,6 +538,56 @@ class Portal(DBPortal, BasePortal):
             except FileNotFoundError:
             except FileNotFoundError:
                 pass
                 pass
 
 
+    @classmethod
+    async def _expire_event(cls, disappearing_message: DisappearingMessage):
+        """
+        Expire a :class:`DisappearingMessage`. This should only be called once the message has been
+        read, as the timer for redaction will start immediately, and there is no (supported)
+        mechanism to stop the countdown, even after bridge restart.
+        """
+        room_id = disappearing_message.room_id
+        event_id = disappearing_message.mxid
+        wait = disappearing_message.expiration_seconds
+        # If there is an expiration_ts, then there was probably a bridge restart, so we have to
+        # resume the countdown. This is fairly likely to occur if the disappearance timeout is
+        # weeks.
+        # If there is not an expiration_ts, then set one.
+        now = datetime.now().timestamp()
+        if disappearing_message.expiration_ts is not None:
+            wait = (disappearing_message.expiration_ts / 1000) - now
+        else:
+            disappearing_message.expiration_ts = int((now + wait) * 1000)
+            await disappearing_message.update()
+
+        if wait < 0:
+            wait = 0
+
+        portal = await cls.get_by_mxid(room_id)
+        if not portal:
+            portal.log.warning(f"No portal for {room_id}")
+
+        portal.log.debug(f"Redacting {event_id} in {wait} seconds")
+        await asyncio.sleep(wait)
+
+        if not await DisappearingMessage.get(room_id, event_id):
+            portal.log.debug(f"{event_id} no longer in disappearing messages list, not redacting")
+            return
+
+        portal.log.debug(f"Redacting {event_id} because it was expired")
+        try:
+            await portal.main_intent.redact(room_id, event_id)
+            portal.log.debug(f"Redacted {event_id} successfully")
+        except Exception as e:
+            portal.log.warning("Redacting expired event didn't work", e)
+        finally:
+            await DisappearingMessage.delete(room_id, event_id)
+
+    async def handle_read_receipt(self, event_id: EventID, data: SingleReceiptEventContent):
+        # Start the redaction timers for all of the disappearing messages in the room when the user
+        # reads the room. This is the behavior of the Signal clients.
+        for disappearing_message in await DisappearingMessage.get_all_for_room(self.mxid):
+            asyncio.create_task(Portal._expire_event(disappearing_message))
+
     # endregion
     # endregion
     # region Signal event handling
     # region Signal event handling
 
 
@@ -621,6 +687,16 @@ class Portal(DBPortal, BasePortal):
                                            timestamps=[message.timestamp])
                                            timestamps=[message.timestamp])
             await self._send_delivery_receipt(event_id)
             await self._send_delivery_receipt(event_id)
             self.log.debug(f"Handled Signal message {message.timestamp} -> {event_id}")
             self.log.debug(f"Handled Signal message {message.timestamp} -> {event_id}")
+
+            if message.expires_in_seconds:
+                disappearing_message = DisappearingMessage(
+                    self.mxid, event_id, message.expires_in_seconds
+                )
+                await disappearing_message.insert()
+                self.log.debug(
+                    f"{event_id} set to be redacted {message.expires_in_seconds} seconds after "
+                    "room is read"
+                )
         else:
         else:
             self.log.debug(f"Didn't get event ID for {message.timestamp}")
             self.log.debug(f"Didn't get event ID for {message.timestamp}")
 
 
@@ -853,6 +929,46 @@ class Portal(DBPortal, BasePortal):
             await self.update_bridge_info()
             await self.update_bridge_info()
             await self.update()
             await self.update()
 
 
+    async def update_expires_in_seconds(self, sender: 'p.Puppet', expires_in_seconds: int) -> None:
+        if expires_in_seconds == 0:
+            expires_in_seconds = None
+        if self.expiration_time == expires_in_seconds:
+            return
+
+        def format_time(seconds) -> str:
+            if seconds is None:
+                return "Off"
+
+            # Technique from https://stackoverflow.com/a/24542445
+            intervals = (
+                ("weeks", 604800),  # 60 * 60 * 24 * 7
+                ("days", 86400),    # 60 * 60 * 24
+                ("hours", 3600),    # 60 * 60
+                ("minutes", 60),
+                ("seconds", 1),
+            )
+
+            result = []
+            for name, count in intervals:
+                value = seconds // count
+                if value:
+                    seconds -= value * count
+                    if value == 1:
+                        name = name[:-1]
+                    result.append(f"{value} {name}")
+            return ", ".join(result)
+
+        assert self.mxid
+        self.expiration_time = expires_in_seconds
+        await self.update()
+
+        await self.main_intent.send_notice(
+            self.mxid,
+            text=None,
+            html=f'<a href="https://matrix.to/#/{sender.mxid}">{sender.name}</a> set the '
+            f'disappearing message timer to {format_time(expires_in_seconds)}.'
+        )
+
     async def update_puppet_avatar(self, new_hash: str, avatar_url: ContentURI) -> None:
     async def update_puppet_avatar(self, new_hash: str, avatar_url: ContentURI) -> None:
         if not self.encrypted and not self.private_chat_portal_meta:
         if not self.encrypted and not self.private_chat_portal_meta:
             return
             return
@@ -1195,6 +1311,7 @@ class Portal(DBPortal, BasePortal):
         elif not self.is_direct:
         elif not self.is_direct:
             self._main_intent = self.az.intent
             self._main_intent = self.az.intent
 
 
+
     async def delete(self) -> None:
     async def delete(self) -> None:
         await DBMessage.delete_all(self.mxid)
         await DBMessage.delete_all(self.mxid)
         self.by_mxid.pop(self.mxid, None)
         self.by_mxid.pop(self.mxid, None)

+ 2 - 0
mautrix_signal/signal.py

@@ -113,6 +113,8 @@ class SignalHandler(SignaldClient):
             await portal.update_info(user, msg.group)
             await portal.update_info(user, msg.group)
         if msg.remote_delete:
         if msg.remote_delete:
             await portal.handle_signal_delete(sender, msg.remote_delete.target_sent_timestamp)
             await portal.handle_signal_delete(sender, msg.remote_delete.target_sent_timestamp)
+        if msg.expires_in_seconds is not None:
+            await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
 
 
     @staticmethod
     @staticmethod
     async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None:
     async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None: