Эх сурвалжийг харах

Add support for group descriptions

Tulir Asokan 3 жил өмнө
parent
commit
942478969a

+ 1 - 0
mausignald/types.py

@@ -205,6 +205,7 @@ class GroupMember(SerializableAttrs):
 @dataclass(kw_only=True)
 class GroupV2(GroupV2ID, SerializableAttrs):
     title: str
+    description: Optional[str] = None
     avatar: Optional[str] = None
     timer: Optional[int] = None
     master_key: Optional[str] = field(default=None, json="masterKey")

+ 29 - 52
mautrix_signal/db/portal.py

@@ -37,6 +37,7 @@ class Portal:
     receiver: str
     mxid: RoomID | None
     name: str | None
+    topic: str | None
     avatar_hash: str | None
     avatar_url: ContentURI | None
     name_set: bool
@@ -50,18 +51,14 @@ class Portal:
     def chat_id_str(self) -> str:
         return id_to_str(self.chat_id)
 
-    async def insert(self) -> None:
-        q = """
-        INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set,
-                            avatar_set, revision, encrypted, relay_user_id, expiration_time)
-        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
-        """
-        await self.db.execute(
-            q,
+    @property
+    def _values(self):
+        return (
             self.chat_id_str,
             self.receiver,
             self.mxid,
             self.name,
+            self.topic,
             self.avatar_hash,
             self.avatar_url,
             self.name_set,
@@ -72,28 +69,23 @@ class Portal:
             self.expiration_time,
         )
 
+    async def insert(self) -> None:
+        q = """
+        INSERT INTO portal (
+            chat_id, receiver, mxid, name, topic, avatar_hash, avatar_url, name_set, avatar_set,
+            revision, encrypted, relay_user_id, expiration_time
+        ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
+        """
+        await self.db.execute(q, *self._values)
+
     async def update(self) -> None:
         q = """
-        UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5,
-                          avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9,
-                          expiration_time=$10
-        WHERE chat_id=$11 AND receiver=$12
+        UPDATE portal SET mxid=$3, name=$4, topic=$5, avatar_hash=$6, avatar_url=$7, name_set=$8,
+                          avatar_set=$9, revision=$10, encrypted=$11, relay_user_id=$12,
+                          expiration_time=$13
+        WHERE chat_id=$1 AND receiver=$2
         """
-        await self.db.execute(
-            q,
-            self.mxid,
-            self.name,
-            self.avatar_hash,
-            self.avatar_url,
-            self.name_set,
-            self.avatar_set,
-            self.revision,
-            self.encrypted,
-            self.relay_user_id,
-            self.expiration_time,
-            self.chat_id_str,
-            self.receiver,
-        )
+        await self.db.execute(q, *self._values)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> Portal:
@@ -103,13 +95,14 @@ class Portal:
             chat_id = Address.parse(chat_id)
         return cls(chat_id=chat_id, **data)
 
+    _columns = (
+        "chat_id, receiver, mxid, name, topic, avatar_hash, avatar_url, name_set, avatar_set, "
+        "revision, encrypted, relay_user_id, expiration_time"
+    )
+
     @classmethod
     async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
-        q = """
-        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
-               revision, encrypted, relay_user_id, expiration_time FROM portal
-        WHERE mxid=$1
-        """
+        q = f"SELECT {cls._columns} FROM portal WHERE mxid=$1"
         row = await cls.db.fetchrow(q, mxid)
         if not row:
             return None
@@ -117,11 +110,7 @@ class Portal:
 
     @classmethod
     async def get_by_chat_id(cls, chat_id: GroupID | Address, receiver: str = "") -> Portal | None:
-        q = """
-        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
-               revision, encrypted, relay_user_id, expiration_time FROM portal
-        WHERE chat_id=$1 AND receiver=$2
-        """
+        q = f"SELECT {cls._columns} FROM portal WHERE chat_id=$1 AND receiver=$2"
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
             return None
@@ -129,30 +118,18 @@ class Portal:
 
     @classmethod
     async def find_private_chats_of(cls, receiver: str) -> list[Portal]:
-        q = """
-        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
-               revision, encrypted, relay_user_id, expiration_time FROM portal
-        WHERE receiver=$1
-        """
+        q = f"SELECT {cls._columns} FROM portal WHERE receiver=$1"
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
     async def find_private_chats_with(cls, other_user: Address) -> list[Portal]:
-        q = """
-        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
-               revision, encrypted, relay_user_id, expiration_time FROM portal
-        WHERE chat_id=$1 AND receiver<>''
-        """
+        q = f"SELECT {cls._columns} FROM portal WHERE chat_id=$1 AND receiver<>''"
         rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
     async def all_with_room(cls) -> list[Portal]:
-        q = """
-        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
-               revision, encrypted, relay_user_id, expiration_time FROM portal
-        WHERE mxid IS NOT NULL
-        """
+        q = f"SELECT {cls._columns} FROM portal WHERE mxid IS NOT NULL"
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]

+ 1 - 0
mautrix_signal/db/upgrade/__init__.py

@@ -11,4 +11,5 @@ from . import (
     v06_portal_revision,
     v07_portal_relay_user,
     v08_disappearing_messages,
+    v09_group_topic,
 )

+ 2 - 1
mautrix_signal/db/upgrade/v00_latest_revision.py

@@ -18,7 +18,7 @@ from mautrix.util.async_db import Connection
 from . import upgrade_table
 
 
-@upgrade_table.register(description="Initial revision", upgrades_to=8)
+@upgrade_table.register(description="Initial revision", upgrades_to=9)
 async def upgrade_latest(conn: Connection) -> None:
     await conn.execute(
         """CREATE TABLE portal (
@@ -26,6 +26,7 @@ async def upgrade_latest(conn: Connection) -> None:
             receiver    TEXT,
             mxid        TEXT,
             name        TEXT,
+            topic       TEXT,
             encrypted   BOOLEAN NOT NULL DEFAULT false,
             avatar_hash TEXT,
             avatar_url  TEXT,

+ 23 - 0
mautrix_signal/db/upgrade/v09_group_topic.py

@@ -0,0 +1,23 @@
+# mautrix-signal - A Matrix-Signal puppeting bridge
+# Copyright (C) 2022 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from mautrix.util.async_db import Connection
+
+from . import upgrade_table
+
+
+@upgrade_table.register(description="Add support for group descriptions")
+async def upgrade_v9(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE portal ADD COLUMN topic TEXT")

+ 19 - 1
mautrix_signal/portal.py

@@ -143,6 +143,7 @@ class Portal(DBPortal, BasePortal):
         receiver: str,
         mxid: RoomID | None = None,
         name: str | None = None,
+        topic: str | None = None,
         avatar_hash: str | None = None,
         avatar_url: ContentURI | None = None,
         name_set: bool = False,
@@ -157,6 +158,7 @@ class Portal(DBPortal, BasePortal):
             receiver=receiver,
             mxid=mxid,
             name=name,
+            topic=topic,
             avatar_hash=avatar_hash,
             avatar_url=avatar_url,
             name_set=name_set,
@@ -955,7 +957,7 @@ class Portal(DBPortal, BasePortal):
         else:
             self.log.debug(f"Didn't get event ID for {message.timestamp}")
 
-    async def handle_signal_kicked(self, user: u.User, sender: pu.Puppet) -> None:
+    async def handle_signal_kicked(self, user: u.User, sender: p.Puppet) -> None:
         self.log.debug(f"{user.mxid} was kicked by {sender.number} from {self.mxid}")
         await self.main_intent.kick_user(self.mxid, user.mxid, f"{sender.name} kicked you")
 
@@ -1250,6 +1252,7 @@ class Portal(DBPortal, BasePortal):
                 )
                 return
             changed = await self._update_name(info.title, sender) or changed
+            changed = await self._update_topic(info.description, sender) or changed
         elif isinstance(info, GroupV2ID):
             return
         else:
@@ -1321,6 +1324,20 @@ class Portal(DBPortal, BasePortal):
             return True
         return False
 
+    async def _update_topic(self, topic: str, sender: p.Puppet | None = None) -> bool:
+        if self.topic != topic:
+            self.topic = topic
+            if self.mxid:
+                try:
+                    await self._try_with_puppet(
+                        lambda i: i.set_room_topic(self.mxid, self.topic), puppet=sender
+                    )
+                except Exception:
+                    self.log.exception("Error setting topic")
+                    self.topic = None
+            return True
+        return False
+
     async def _try_with_puppet(
         self, action: Callable[[IntentAPI], Awaitable[Any]], puppet: p.Puppet | None = None
     ) -> None:
@@ -1596,6 +1613,7 @@ class Portal(DBPortal, BasePortal):
             creation_content["m.federate"] = False
         self.mxid = await self.main_intent.create_room(
             name=name,
+            topic=self.topic,
             is_direct=self.is_direct,
             initial_state=initial_state,
             invitees=invites,