Explorar el Código

Add support for real-time group info updates from Signal

Tulir Asokan hace 4 años
padre
commit
411587645f

+ 2 - 0
ROADMAP.md

@@ -38,6 +38,8 @@
   * [ ] Profile info changes
     * [x] When restarting bridge or syncing
     * [ ] Real time
+      * [x] Groups
+      * [ ] Users
   * [ ] Group permissions
   * [x] Typing notifications
   * [x] Read receipts

+ 4 - 3
mausignald/signald.py

@@ -170,9 +170,10 @@ class SignaldClient(SignaldRPCClient):
         return [Contact.deserialize(contact) for contact in contacts]
 
     async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
-        resp = await self.request("list_groups", "group_list", username=username)
-        return ([Group.deserialize(group) for group in resp["groups"]]
-                + [GroupV2.deserialize(group) for group in resp["groupsv2"]])
+        resp = await self.request("list_groups", "list_groups", account=username, version="v1")
+        legacy = [Group.deserialize(group) for group in resp.get("legacyGroups", [])]
+        v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
+        return legacy + v2
 
     async def update_group(self, username: str, group_id: GroupID, title: Optional[str] = None,
                            avatar_path: Optional[str] = None,

+ 19 - 9
mautrix_signal/db/portal.py

@@ -38,6 +38,9 @@ class Portal:
     name: Optional[str]
     avatar_hash: Optional[str]
     avatar_url: Optional[ContentURI]
+    name_set: bool
+    avatar_set: bool
+    revision: int
     encrypted: bool
 
     @property
@@ -46,16 +49,18 @@ class Portal:
 
     async def insert(self) -> None:
         q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
-             "                    encrypted) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7)")
+             "                    name_set, avatar_set, revision, encrypted) "
+             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)")
         await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
                               self.avatar_hash, self.avatar_url, self.encrypted)
 
     async def update(self) -> None:
-        q = ("UPDATE portal SET mxid=$3, name=$4, avatar_hash=$5, avatar_url=$6, encrypted=$7 "
+        q = ("UPDATE portal SET mxid=$3, name=$4, avatar_hash=$5, avatar_url=$6, "
+             "                  name_set=$7, avatar_set=$8, revision=$9, encrypted=$10 "
              "WHERE chat_id=$1 AND receiver=$2")
         await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
-                              self.avatar_hash, self.avatar_url, self.encrypted)
+                              self.avatar_hash, self.avatar_url, self.name_set, self.avatar_set,
+                              self.revision, self.encrypted)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
@@ -67,7 +72,8 @@ class Portal:
 
     @classmethod
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
+             "       revision, encrypted "
              "FROM portal WHERE mxid=$1")
         row = await cls.db.fetchrow(q, mxid)
         if not row:
@@ -77,7 +83,8 @@ class Portal:
     @classmethod
     async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
                              ) -> Optional['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
+             "       revision, encrypted "
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
@@ -86,21 +93,24 @@ class Portal:
 
     @classmethod
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
+             "       revision, encrypted "
              "FROM portal WHERE receiver=$1")
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
     async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
+             "       revision, encrypted "
              "FROM portal WHERE chat_id=$1 AND receiver<>''")
         rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
     async def all_with_room(cls) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
+             "       revision, encrypted "
              "FROM portal WHERE mxid IS NOT NULL")
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]

+ 11 - 2
mautrix_signal/db/upgrade.py

@@ -120,6 +120,15 @@ async def upgrade_v4(conn: Connection) -> None:
 async def upgrade_v5(conn: Connection) -> None:
     await conn.execute("ALTER TABLE puppet ADD COLUMN avatar_hash TEXT")
     await conn.execute("ALTER TABLE puppet ADD COLUMN avatar_url TEXT")
-    await conn.execute("ALTER TABLE puppet ADD COLUMN name_set BOOL NOT NULL DEFAULT false")
-    await conn.execute("ALTER TABLE puppet ADD COLUMN avatar_set BOOL NOT NULL DEFAULT false")
+    await conn.execute("ALTER TABLE puppet ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false")
+    await conn.execute("ALTER TABLE puppet ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false")
     await conn.execute("UPDATE puppet SET name_set=true WHERE name<>''")
+
+
+@upgrade_table.register(description="Add revision to portal table")
+async def upgrade_v6(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE portal ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false")
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false")
+    await conn.execute("ALTER TABLE portal ADD COLUMN revision INTEGER NOT NULL DEFAULT 0")
+    await conn.execute("UPDATE portal SET name_set=true WHERE name<>''")
+    await conn.execute("UPDATE portal SET avatar_set=true WHERE avatar_hash<>''")

+ 68 - 36
mautrix_signal/portal.py

@@ -14,7 +14,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from typing import (Dict, Tuple, Optional, List, Deque, Any, Union, AsyncGenerator, Awaitable,
-                    TYPE_CHECKING, cast)
+                    Callable, TYPE_CHECKING, cast)
 from collections import deque
 from uuid import UUID, uuid4
 import mimetypes
@@ -81,8 +81,10 @@ class Portal(DBPortal, BasePortal):
     def __init__(self, chat_id: Union[GroupID, Address], receiver: str,
                  mxid: Optional[RoomID] = None, name: Optional[str] = None,
                  avatar_hash: Optional[str] = None, avatar_url: Optional[ContentURI] = None,
+                 name_set: bool = False, avatar_set: bool = False, revision: int = 0,
                  encrypted: bool = False) -> None:
-        super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted)
+        super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url,
+                         name_set, avatar_set, revision, encrypted)
         self._create_room_lock = asyncio.Lock()
         self.log = self.log.getChild(self.chat_id_str)
         self._main_intent = None
@@ -278,7 +280,7 @@ class Portal(DBPortal, BasePortal):
             # TODO cleanup if empty
 
     async def handle_matrix_name(self, user: 'u.User', name: str) -> None:
-        if self.name == name or self.is_direct:
+        if self.name == name or self.is_direct or not name:
             return
         self.name = name
         self.log.debug(f"{user.mxid} changed the group name, sending to Signal")
@@ -289,12 +291,12 @@ class Portal(DBPortal, BasePortal):
             self.name = None
 
     async def handle_matrix_avatar(self, user: 'u.User', url: ContentURI) -> None:
-        if self.is_direct:
+        if self.is_direct or not url:
             return
 
         data = await self.main_intent.download_media(url)
         new_hash = hashlib.sha256(data).hexdigest()
-        if new_hash == self.avatar_hash:
+        if new_hash == self.avatar_hash and self.avatar_set:
             self.log.debug(f"New avatar from Matrix set by {user.mxid} is same as current one")
             return
         self.avatar_url = url
@@ -306,9 +308,10 @@ class Portal(DBPortal, BasePortal):
             with open(path, "wb") as file:
                 file.write(data)
             await self.signal.update_group(user.username, self.chat_id, avatar_path=path)
+            self.avatar_set = True
         except Exception:
             self.log.exception("Failed to update Signal group avatar")
-            self.avatar_hash = None
+            self.avatar_set = False
         if self.config["signal.remove_file_after_handling"]:
             try:
                 os.remove(path)
@@ -592,7 +595,8 @@ class Portal(DBPortal, BasePortal):
     # endregion
     # region Updating portal info
 
-    async def update_info(self, source: 'u.User', info: ChatInfo) -> None:
+    async def update_info(self, source: 'u.User', info: ChatInfo,
+                          sender: Optional['p.Puppet'] = None) -> None:
         if self.is_direct:
             if not isinstance(info, (Contact, Profile, Address)):
                 raise ValueError(f"Unexpected type for direct chat update_info: {type(info)}")
@@ -603,15 +607,30 @@ class Portal(DBPortal, BasePortal):
                 self.name = puppet.name
             return
 
+        if isinstance(info, GroupV2ID):
+            info = await self.signal.get_group(source.username, info.id, info.revision or -1)
+            if not info:
+                self.log.debug(f"Failed to get full group v2 info through {source.username}, "
+                               "cancelling update")
+                return
+
+        changed = False
         if isinstance(info, Group):
-            changed = await self._update_name(info.name)
+            changed = await self._update_name(info.name, sender) or changed
         elif isinstance(info, GroupV2):
-            changed = await self._update_name(info.title)
+            if self.revision < info.revision:
+                self.revision = info.revision
+                changed = True
+            elif self.revision > info.revision:
+                self.log.warning(f"Got outdated info when syncing through {source.username} "
+                                 f"({info.revision} < {self.revision}), ignoring...")
+                return
+            changed = await self._update_name(info.title, sender) or changed
         elif isinstance(info, GroupV2ID):
             return
         else:
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
-        changed = await self._update_avatar(info) or changed
+        changed = await self._update_avatar(info, sender) or changed
         await self._update_participants(source, info.members)
         if changed:
             await self.update_bridge_info()
@@ -621,11 +640,16 @@ class Portal(DBPortal, BasePortal):
         if not self.encrypted and not self.private_chat_portal_meta:
             return
 
-        if self.avatar_hash != new_hash:
+        if self.avatar_hash != new_hash or not self.avatar_set:
             self.avatar_hash = new_hash
             self.avatar_url = avatar_url
             if self.mxid:
-                await self.main_intent.set_room_avatar(self.mxid, avatar_url)
+                try:
+                    await self.main_intent.set_room_avatar(self.mxid, avatar_url)
+                    self.avatar_set = True
+                except Exception:
+                    self.log.exception("Error setting avatar")
+                    self.avatar_set = False
                 await self.update_bridge_info()
                 await self.update()
 
@@ -639,34 +663,50 @@ class Portal(DBPortal, BasePortal):
             await self.update_bridge_info()
             await self.update()
 
-    async def _update_name(self, name: str) -> bool:
-        if self.name != name:
+    async def _update_name(self, name: str, sender: Optional['p.Puppet'] = None) -> bool:
+        if self.name != name or not self.name_set:
             self.name = name
             if self.mxid:
-                await self.main_intent.set_room_name(self.mxid, name)
+                try:
+                    await self._try_with_puppet(lambda i: i.set_room_name(self.mxid, self.name),
+                                                puppet=sender)
+                    self.name_set = True
+                except Exception:
+                    self.log.exception("Error setting name")
+                    self.name_set = False
             return True
         return False
 
-    @property
-    def avatar_set(self) -> bool:
-        return bool(self.avatar_hash)
+    async def _try_with_puppet(self, action: Callable[[IntentAPI], Awaitable[Any]],
+                               puppet: Optional['p.Puppet'] = None) -> None:
+        if puppet:
+            try:
+                await action(puppet.intent_for(self))
+            except MForbidden:
+                await action(self.main_intent)
+        else:
+            await action(self.main_intent)
 
-    async def _update_avatar(self, info: ChatInfo) -> bool:
+    async def _update_avatar(self, info: ChatInfo, sender: Optional['p.Puppet'] = None) -> bool:
         path = None
         if isinstance(info, GroupV2):
             path = info.avatar
         elif isinstance(info, Group):
             path = f"group-{self.chat_id}"
-        res = await p.Puppet.upload_avatar(self, path)
+        res = await p.Puppet.upload_avatar(self, path, self.main_intent)
         if res is False:
             return False
         self.avatar_hash, self.avatar_url = res
-        if self.mxid:
-            try:
-                await self.main_intent.set_room_avatar(self.mxid, self.avatar_url)
-            except Exception:
-                self.log.exception("Error setting avatar")
-                self.avatar_hash = None
+        if not self.mxid:
+            return True
+
+        try:
+            await self._try_with_puppet(lambda i: i.set_room_avatar(self.mxid, self.avatar_url),
+                                        puppet=sender)
+            self.avatar_set = True
+        except Exception:
+            self.log.exception("Error setting avatar")
+            self.avatar_set = False
         return True
 
     async def _update_participants(self, source: 'u.User', participants: List[Address]) -> None:
@@ -762,16 +802,6 @@ class Portal(DBPortal, BasePortal):
 
         await self.update_info(source, info)
 
-        # TODO
-        # up = DBUserPortal.get(source.fbid, self.fbid, self.fb_receiver)
-        # if not up:
-        #     in_community = await source._community_helper.add_room(source._community_id, self.mxid)
-        #     DBUserPortal(user=source.fbid, portal=self.fbid, portal_receiver=self.fb_receiver,
-        #                  in_community=in_community).insert()
-        # elif not up.in_community:
-        #     in_community = await source._community_helper.add_room(source._community_id, self.mxid)
-        #     up.edit(in_community=in_community)
-
     async def _create_matrix_room(self, source: 'u.User', info: ChatInfo) -> Optional[RoomID]:
         if self.mxid:
             await self._update_matrix_room(source, info)
@@ -824,6 +854,8 @@ class Portal(DBPortal, BasePortal):
                                                        invitees=invites)
         if not self.mxid:
             raise Exception("Failed to create room: no mxid returned")
+        self.name_set = bool(name)
+        self.avatar_set = bool(self.avatar_url)
 
         if self.encrypted and self.matrix.e2ee and self.is_direct:
             try:

+ 3 - 3
mautrix_signal/puppet.py

@@ -220,7 +220,7 @@ class Puppet(DBPuppet, BasePuppet):
         return False
 
     @staticmethod
-    async def upload_avatar(self: Union['Puppet', 'p.Portal'], path: str
+    async def upload_avatar(self: Union['Puppet', 'p.Portal'], path: str, intent: IntentAPI,
                             ) -> Union[bool, Tuple[str, ContentURI]]:
         if not path:
             return False
@@ -234,11 +234,11 @@ class Puppet(DBPuppet, BasePuppet):
         new_hash = hashlib.sha256(data).hexdigest()
         if self.avatar_set and new_hash == self.avatar_hash:
             return False
-        mxc = await self.default_mxid_intent.upload_media(data)
+        mxc = await intent.upload_media(data)
         return new_hash, mxc
 
     async def _update_avatar(self, path: str) -> bool:
-        res = await Puppet.upload_avatar(self, path)
+        res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
         if res is False:
             return False
         self.avatar_hash, self.avatar_url = res

+ 3 - 0
mautrix_signal/signal.py

@@ -93,6 +93,9 @@ class SignalHandler(SignaldClient):
         if not portal.mxid:
             await portal.create_matrix_room(user, (group_v2_info or msg.group
                                                    or addr_override or sender.address))
+        elif msg.group_v2 and msg.group_v2.revision > portal.revision:
+            self.log.debug(f"Got new revision of {msg.group_v2.id}, updating info")
+            await portal.update_info(user, group_v2_info or msg.group_v2, sender)
         if msg.reaction:
             await portal.handle_signal_reaction(sender, msg.reaction)
         if msg.body or msg.attachments or msg.sticker:

+ 2 - 1
mautrix_signal/user.py

@@ -178,6 +178,7 @@ class User(DBUser, BaseUser):
     async def _sync_groups(self) -> None:
         create_group_portal = self.config["bridge.autocreate_group_portal"]
         for group in await self.bridge.signal.list_groups(self.username):
+            group_id = group.group_id if isinstance(group, Group) else group.id
             try:
                 if isinstance(group, Group):
                     await self._sync_group(group, create_group_portal)
@@ -186,7 +187,7 @@ class User(DBUser, BaseUser):
                 else:
                     self.log.warning("Unknown return type in list_groups: %s", type(group))
             except Exception:
-                self.log.exception(f"Failed to sync group {group.group_id}")
+                self.log.exception(f"Failed to sync group {group_id}")
 
     # region Database getters