Эх сурвалжийг харах

Update to mautrix-python v0.11 and add experimental support for SQLite

Closes #166

Co-authored-by: Will Hunt <willh@matrix.org>
Tulir Asokan 3 жил өмнө
parent
commit
cc7883ff8c

+ 8 - 4
mautrix_signal/__main__.py

@@ -58,8 +58,8 @@ class SignalBridge(Bridge):
         self.state_store = PgBridgeStateStore(self.db, self.get_puppet, self.get_double_puppet)
 
     def prepare_db(self) -> None:
-        self.db = Database(self.config["appservice.database"], upgrade_table=upgrade_table,
-                           loop=self.loop, db_args=self.config["appservice.database_opts"])
+        self.db = Database.create(self.config["appservice.database"], upgrade_table=upgrade_table,
+                                  db_args=self.config["appservice.database_opts"])
         init_db(self.db)
 
     def prepare_bridge(self) -> None:
@@ -71,9 +71,9 @@ class SignalBridge(Bridge):
 
     async def start(self) -> None:
         await self.db.start()
-        await self.state_store.upgrade_table.upgrade(self.db.pool)
+        await self.state_store.upgrade_table.upgrade(self.db)
         if self.matrix.e2ee:
-            self.matrix.e2ee.crypto_db.override_pool(self.db.pool)
+            self.matrix.e2ee.crypto_db.override_pool(self.db)
         User.init_cls(self)
         self.add_startup_actions(Puppet.init_cls(self))
         Portal.init_cls(self)
@@ -83,6 +83,10 @@ class SignalBridge(Bridge):
         await super().start()
         self.periodic_sync_task = asyncio.create_task(self._periodic_sync_loop())
 
+    async def stop(self) -> None:
+        await super().stop()
+        await self.db.stop()
+
     @staticmethod
     async def _actual_periodic_sync_loop(log: logging.Logger, interval: int) -> None:
         while True:

+ 6 - 0
mautrix_signal/db/__init__.py

@@ -1,4 +1,6 @@
 from mautrix.util.async_db import Database
+import sqlite3
+import uuid
 
 from .upgrade import upgrade_table
 from .user import User
@@ -13,4 +15,8 @@ def init(db: Database) -> None:
         table.db = db
 
 
+# TODO should this be in mautrix-python?
+sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
+sqlite3.register_converter("UUID", lambda b: uuid.UUID(b))
+
 __all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction"]

+ 10 - 4
mautrix_signal/db/message.py

@@ -25,7 +25,7 @@ from mautrix.util.async_db import Database
 
 from ..util import id_to_str
 
-fake_db = Database("") if TYPE_CHECKING else None
+fake_db = Database.create("") if TYPE_CHECKING else None
 
 
 @dataclass
@@ -88,9 +88,15 @@ class Message:
 
     @classmethod
     async def find_by_timestamps(cls, timestamps: List[int]) -> List['Message']:
-        q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-             "FROM message WHERE timestamp=ANY($1)")
-        rows = await cls.db.fetch(q, timestamps)
+        if cls.db.scheme == "postgres":
+            q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
+                 "FROM message WHERE timestamp=ANY($1)")
+            rows = await cls.db.fetch(q, timestamps)
+        else:
+            placeholders = ", ".join(f"?" for _ in range(len(timestamps)))
+            q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
+                 f"FROM message WHERE timestamp IN ({placeholders})")
+            rows = await cls.db.fetch(q, *timestamps)
         return [cls._from_row(row) for row in rows]
 
     @classmethod

+ 7 - 7
mautrix_signal/db/portal.py

@@ -24,7 +24,7 @@ from mautrix.util.async_db import Database
 
 from ..util import id_to_str
 
-fake_db = Database("") if TYPE_CHECKING else None
+fake_db = Database.create("") if TYPE_CHECKING else None
 
 
 @dataclass
@@ -56,12 +56,12 @@ class Portal:
                               self.revision, self.encrypted, self.relay_user_id)
 
     async def update(self) -> None:
-        q = ("UPDATE portal SET mxid=$3, name=$4, avatar_hash=$5, avatar_url=$6, name_set=$7, "
-             "                  avatar_set=$8, revision=$9, encrypted=$10, relay_user_id=$11 "
-             "WHERE chat_id=$1 AND receiver=$2")
-        await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
-                              self.avatar_hash, self.avatar_url, self.name_set, self.avatar_set,
-                              self.revision, self.encrypted, self.relay_user_id)
+        q = ("UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5, "
+             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9 "
+             "WHERE chat_id=$10 AND receiver=$11")
+        await self.db.execute(q, self.mxid, self.name, self.avatar_hash, self.avatar_url,
+                              self.name_set, self.avatar_set, self.revision, self.encrypted,
+                              self.relay_user_id, self.chat_id_str, self.receiver)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':

+ 1 - 1
mautrix_signal/db/puppet.py

@@ -24,7 +24,7 @@ from mausignald.types import Address
 from mautrix.types import UserID, SyncToken, ContentURI
 from mautrix.util.async_db import Database
 
-fake_db = Database("") if TYPE_CHECKING else None
+fake_db = Database.create("") if TYPE_CHECKING else None
 
 
 @dataclass

+ 1 - 1
mautrix_signal/db/reaction.py

@@ -25,7 +25,7 @@ from mautrix.util.async_db import Database
 
 from ..util import id_to_str
 
-fake_db = Database("") if TYPE_CHECKING else None
+fake_db = Database.create("") if TYPE_CHECKING else None
 
 
 @dataclass

+ 39 - 1
mautrix_signal/db/upgrade.py

@@ -103,7 +103,45 @@ async def upgrade_v3(conn: Connection) -> None:
 
 
 @upgrade_table.register(description="Allow phone numbers as message sender identifiers")
-async def upgrade_v4(conn: Connection) -> None:
+async def upgrade_v4(conn: Connection, scheme: str) -> None:
+    if scheme == "sqlite":
+        # SQLite doesn't have anything in the tables yet,
+        # so just recreate them without migrating data
+        await conn.execute("DROP TABLE message")
+        await conn.execute("DROP TABLE reaction")
+        await conn.execute("""CREATE TABLE message (
+            mxid    TEXT NOT NULL,
+            mx_room TEXT NOT NULL,
+            sender          TEXT,
+            timestamp       BIGINT,
+            signal_chat_id  TEXT,
+            signal_receiver TEXT,
+
+            PRIMARY KEY (sender, timestamp, signal_chat_id, signal_receiver),
+            FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
+                ON UPDATE CASCADE ON DELETE CASCADE,
+            UNIQUE (mxid, mx_room)
+        )""")
+        await conn.execute("""CREATE TABLE reaction (
+            mxid    TEXT NOT NULL,
+            mx_room TEXT NOT NULL,
+
+            signal_chat_id  TEXT   NOT NULL,
+            signal_receiver TEXT   NOT NULL,
+            msg_author      TEXT   NOT NULL,
+            msg_timestamp   BIGINT NOT NULL,
+            author          TEXT   NOT NULL,
+
+            emoji TEXT NOT NULL,
+
+            PRIMARY KEY (signal_chat_id, signal_receiver, msg_author, msg_timestamp, author),
+            FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver)
+                REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
+                ON DELETE CASCADE ON UPDATE CASCADE,
+            UNIQUE (mxid, mx_room)
+        )""")
+        return
+
     cname = await conn.fetchval("SELECT constraint_name FROM information_schema.table_constraints "
                                 "WHERE table_name='reaction' AND constraint_name LIKE '%_fkey'")
     await conn.execute(f"ALTER TABLE reaction DROP CONSTRAINT {cname}")

+ 4 - 4
mautrix_signal/db/user.py

@@ -1,5 +1,5 @@
 # mautrix-signal - A Matrix-Signal puppeting bridge
-# Copyright (C) 2020 Tulir Asokan
+# Copyright (C) 2021 Tulir Asokan
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License as published by
@@ -21,7 +21,7 @@ from attr import dataclass
 from mautrix.types import UserID, RoomID
 from mautrix.util.async_db import Database
 
-fake_db = Database("") if TYPE_CHECKING else None
+fake_db = Database.create("") if TYPE_CHECKING else None
 
 
 @dataclass
@@ -39,8 +39,8 @@ class User:
         await self.db.execute(q, self.mxid, self.username, self.uuid, self.notice_room)
 
     async def update(self) -> None:
-        await self.db.execute('UPDATE "user" SET username=$2, uuid=$3, notice_room=$4 '
-                              'WHERE mxid=$1', self.mxid, self.username, self.uuid, self.notice_room)
+        q = 'UPDATE "user" SET username=$1, uuid=$2, notice_room=$3 WHERE mxid=$4'
+        await self.db.execute(q, self.username, self.uuid, self.notice_room, self.mxid)
 
     @classmethod
     async def get_by_mxid(cls, mxid: UserID) -> Optional['User']:

+ 7 - 2
mautrix_signal/example-config.yaml

@@ -31,10 +31,15 @@ appservice:
     # Usually 1 is enough, but on high-traffic bridges you might need to increase this to avoid 413s
     max_body_size: 1
 
-    # The full URI to the database. Only Postgres is currently supported.
+    # The full URI to the database. SQLite and Postgres are supported.
+    # Format examples:
+    #   SQLite:   sqlite:///filename.db
+    #   Postgres: postgres://username:password@hostname/dbname
     database: postgres://username:password@hostname/db
-    # Additional arguments for asyncpg.create_pool()
+    # Additional arguments for asyncpg.create_pool() or sqlite3.connect()
     # https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.pool.create_pool
+    # https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
+    # For sqlite, min_size is used as the connection thread pool size and max_size is ignored.
     database_opts:
         min_size: 5
         max_size: 10

+ 4 - 1
optional-requirements.txt

@@ -7,7 +7,7 @@ pycryptodome>=3,<4
 unpaddedbase64>=1,<2
 
 #/metrics
-prometheus_client>=0.6,<0.12
+prometheus_client>=0.6,<0.13
 
 #/formattednumbers
 phonenumbers>=8,<9
@@ -18,3 +18,6 @@ Pillow>=4,<9
 
 #/stickers
 signalstickers-client>=3,<4
+
+#/sqlite
+aiosqlite>=0.16,<0.18

+ 1 - 1
requirements.txt

@@ -4,5 +4,5 @@ commonmark>=0.8,<0.10
 aiohttp>=3,<4
 yarl>=1,<2
 attrs>=19.1
-mautrix>=0.10.5,<0.11
+mautrix>=0.11.1,<0.12
 asyncpg>=0.20,<0.25