Parcourir la source

Add support for Signal profile avatars

Tulir Asokan il y a 4 ans
Parent
commit
d87d45ca82
5 fichiers modifiés avec 115 ajouts et 46 suppressions
  1. 2 2
      ROADMAP.md
  2. 38 31
      mautrix_signal/db/puppet.py
  3. 9 0
      mautrix_signal/db/upgrade.py
  4. 13 1
      mautrix_signal/portal.py
  5. 53 12
      mautrix_signal/puppet.py

+ 2 - 2
ROADMAP.md

@@ -36,9 +36,9 @@
   * [x] Remote deletions
   * [ ] Initial profile info
     * [x] User displayname
-    * [ ] †User avatar
+    * [x] User avatar
     * [x] Group name
-    * [x] Group avatar
+    * [ ] †Group avatar
   * [ ] Profile info changes
     * [ ] User displayname
     * [ ] †User avatar

+ 38 - 31
mautrix_signal/db/puppet.py

@@ -21,7 +21,7 @@ from yarl import URL
 import asyncpg
 
 from mausignald.types import Address
-from mautrix.types import UserID, SyncToken
+from mautrix.types import UserID, SyncToken, ContentURI
 from mautrix.util.async_db import Database
 
 fake_db = Database("") if TYPE_CHECKING else None
@@ -34,6 +34,10 @@ class Puppet:
     uuid: Optional[UUID]
     number: Optional[str]
     name: Optional[str]
+    avatar_hash: Optional[str]
+    avatar_url: Optional[ContentURI]
+    name_set: bool
+    avatar_set: bool
 
     uuid_registered: bool
     number_registered: bool
@@ -43,13 +47,19 @@ class Puppet:
     next_batch: Optional[SyncToken]
     base_url: Optional[URL]
 
+    @property
+    def _base_url_str(self) -> Optional[str]:
+        return str(self.base_url) if self.base_url else None
+
     async def insert(self) -> None:
-        q = ("INSERT INTO puppet (uuid, number, name, uuid_registered, number_registered, "
+        q = ("INSERT INTO puppet (uuid, number, name, avatar_hash, avatar_url, name_set, "
+             "                    avatar_set, uuid_registered, number_registered, "
              "                    custom_mxid, access_token, next_batch, base_url) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)")
-        await self.db.execute(q, self.uuid, self.number, self.name, self.uuid_registered,
-                              self.number_registered, self.custom_mxid, self.access_token,
-                              self.next_batch, str(self.base_url) if self.base_url else None)
+             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)")
+        await self.db.execute(q, self.uuid, self.number, self.name, self.avatar_hash,
+                              self.avatar_url, self.name_set, self.avatar_set,
+                              self.uuid_registered, self.number_registered, self.custom_mxid,
+                              self.access_token, self.next_batch, self._base_url_str)
 
     async def _set_uuid(self, uuid: UUID) -> None:
         if self.uuid:
@@ -63,17 +73,18 @@ class Puppet:
             await conn.execute("UPDATE reaction SET author=$1 WHERE author=$2", uuid, self.number)
 
     async def update(self) -> None:
-        if self.uuid is None:
-            q = ("UPDATE puppet SET uuid=$1, name=$3, uuid_registered=$4, number_registered=$5, "
-                 "                  custom_mxid=$6, access_token=$7, next_batch=$8, base_url=$9 "
-                 "WHERE number=$2")
-        else:
-            q = ("UPDATE puppet SET number=$2, name=$3, uuid_registered=$4, number_registered=$5, "
-                 "                  custom_mxid=$6, access_token=$7, next_batch=$8, base_url=$9 "
-                 "WHERE uuid=$1")
-        await self.db.execute(q, self.uuid, self.number, self.name, self.uuid_registered,
-                              self.number_registered, self.custom_mxid, self.access_token,
-                              self.next_batch, str(self.base_url) if self.base_url else None)
+        set_columns = (
+            "name=$3, avatar_hash=$4, avatar_url=$5, name_set=$6, avatar_set=$7, "
+            "uuid_registered=$8, number_registered=$9, "
+            "custom_mxid=$10, access_token=$11, next_batch=$12, base_url=$13"
+        )
+        q = (f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
+             if self.uuid is None
+             else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1")
+        await self.db.execute(q,self.uuid, self.number, self.name, self.avatar_hash,
+                              self.avatar_url, self.name_set, self.avatar_set,
+                              self.uuid_registered, self.number_registered, self.custom_mxid,
+                              self.access_token, self.next_batch, self._base_url_str)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Puppet':
@@ -82,19 +93,21 @@ class Puppet:
         base_url = URL(base_url_str) if base_url_str is not None else None
         return cls(base_url=base_url, **data)
 
+    _select_base = ("SELECT uuid, number, name, avatar_hash, avatar_url, name_set, avatar_set, "
+                    "       uuid_registered, number_registered, custom_mxid, access_token, "
+                    "       next_batch, base_url "
+                    "FROM puppet")
+
     @classmethod
     async def get_by_address(cls, address: Address) -> Optional['Puppet']:
-        select = ("SELECT uuid, number, name, uuid_registered, "
-                  "       number_registered, custom_mxid, access_token, next_batch, base_url "
-                  "FROM puppet")
         if address.uuid:
             if address.number:
-                row = await cls.db.fetchrow(f"{select} WHERE uuid=$1 OR number=$2",
+                row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1 OR number=$2",
                                             address.uuid, address.number)
             else:
-                row = await cls.db.fetchrow(f"{select} WHERE uuid=$1", address.uuid)
+                row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
         elif address.number:
-            row = await cls.db.fetchrow(f"{select} WHERE number=$1", address.number)
+            row = await cls.db.fetchrow(f"{cls._select_base} WHERE number=$1", address.number)
         else:
             raise ValueError("Invalid address")
         if not row:
@@ -103,18 +116,12 @@ class Puppet:
 
     @classmethod
     async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
-        q = ("SELECT uuid, number, name, uuid_registered, number_registered,"
-             "       custom_mxid, access_token, next_batch, base_url "
-             "FROM puppet WHERE custom_mxid=$1")
-        row = await cls.db.fetchrow(q, mxid)
+        row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
         if not row:
             return None
         return cls._from_row(row)
 
     @classmethod
     async def all_with_custom_mxid(cls) -> List['Puppet']:
-        q = ("SELECT uuid, number, name, uuid_registered, number_registered,"
-             "       custom_mxid, access_token, next_batch, base_url "
-             "FROM puppet WHERE custom_mxid IS NOT NULL")
-        rows = await cls.db.fetch(q)
+        rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
         return [cls._from_row(row) for row in rows]

+ 9 - 0
mautrix_signal/db/upgrade.py

@@ -114,3 +114,12 @@ async def upgrade_v4(conn: Connection) -> None:
                        "FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) "
                        "  REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) "
                        "  ON DELETE CASCADE ON UPDATE CASCADE")
+
+
+@upgrade_table.register(description="Add avatar info to puppet table")
+async def upgrade_v5(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE puppet ADD COLUMN avatar_hash TEXT")
+    await conn.execute("ALTER TABLE puppet ADD COLUMN avatar_url TEXT")
+    await conn.execute("ALTER TABLE puppet ADD COLUMN name_set BOOL NOT NULL DEFAULT false")
+    await conn.execute("ALTER TABLE puppet ADD COLUMN avatar_set BOOL NOT NULL DEFAULT false")
+    await conn.execute("UPDATE puppet SET name_set=true WHERE name<>''")

+ 13 - 1
mautrix_signal/portal.py

@@ -474,12 +474,24 @@ class Portal(DBPortal, BasePortal):
             return
         else:
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
-        changed = await self._update_avatar()
+        changed = await self._update_avatar() or changed
         await self._update_participants(source, info.members)
         if changed:
             await self.update_bridge_info()
             await self.update()
 
+    async def update_puppet_avatar(self, new_hash: str, avatar_url: ContentURI) -> None:
+        if not self.encrypted and not self.private_chat_portal_meta:
+            return
+
+        if self.avatar_hash != new_hash:
+            self.avatar_hash = new_hash
+            self.avatar_url = avatar_url
+            if self.mxid:
+                await self.main_intent.set_room_avatar(self.mxid, avatar_url)
+                await self.update_bridge_info()
+                await self.update()
+
     async def update_puppet_name(self, name: str) -> None:
         if not self.encrypted and not self.private_chat_portal_meta:
             return

+ 53 - 12
mautrix_signal/puppet.py

@@ -16,14 +16,16 @@
 from typing import (Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, Union,
                     TYPE_CHECKING, cast)
 from uuid import UUID
+import hashlib
 import asyncio
+import os.path
 
 from yarl import URL
 
 from mausignald.types import Address, Contact, Profile
 from mautrix.bridge import BasePuppet
 from mautrix.appservice import IntentAPI
-from mautrix.types import UserID, SyncToken, RoomID
+from mautrix.types import UserID, SyncToken, RoomID, ContentURI
 from mautrix.errors import MForbidden
 from mautrix.util.simple_template import SimpleTemplate
 
@@ -56,12 +58,16 @@ class Puppet(DBPuppet, BasePuppet):
     _update_info_lock: asyncio.Lock
 
     def __init__(self, uuid: Optional[UUID], number: Optional[str], name: Optional[str] = None,
-                 uuid_registered: bool = False, number_registered: bool = False,
-                 custom_mxid: Optional[UserID] = None, access_token: Optional[str] = None,
-                 next_batch: Optional[SyncToken] = None, base_url: Optional[URL] = None) -> None:
-        super().__init__(uuid=uuid, number=number, name=name, uuid_registered=uuid_registered,
-                         number_registered=number_registered, custom_mxid=custom_mxid,
-                         access_token=access_token, next_batch=next_batch, base_url=base_url)
+                 avatar_url: Optional[ContentURI] = None, avatar_hash: Optional[str] = None,
+                 name_set: bool = False, avatar_set: bool = False, uuid_registered: bool = False,
+                 number_registered: bool = False, custom_mxid: Optional[UserID] = None,
+                 access_token: Optional[str] = None, next_batch: Optional[SyncToken] = None,
+                 base_url: Optional[URL] = None) -> None:
+        super().__init__(uuid=uuid, number=number, name=name, avatar_url=avatar_url,
+                         avatar_hash=avatar_hash, name_set=name_set, avatar_set=avatar_set,
+                         uuid_registered=uuid_registered, number_registered=number_registered,
+                         custom_mxid=custom_mxid, access_token=access_token, next_batch=next_batch,
+                         base_url=base_url)
         self.log = self.log.getChild(str(uuid) if uuid else number)
 
         self.default_mxid = self.get_mxid_from_id(self.address)
@@ -167,8 +173,11 @@ class Puppet(DBPuppet, BasePuppet):
             update = False
             if name is not None or self.name is None:
                 update = await self._update_name(name) or update
+            if isinstance(info, Profile):
+                update = await self._update_avatar(info.avatar) or update
             if update:
                 await self.update()
+                self.loop.create_task(self._update_portal_meta())
 
     @staticmethod
     def fmt_phone(number: str) -> str:
@@ -199,19 +208,51 @@ class Puppet(DBPuppet, BasePuppet):
 
     async def _update_name(self, name: Optional[str]) -> bool:
         name = self._get_displayname(self.address, name)
-        if name != self.name:
+        if name != self.name or not self.name_set:
             self.name = name
-            await self.default_mxid_intent.set_displayname(self.name)
-            self.loop.create_task(self._update_portal_names())
+            try:
+                await self.default_mxid_intent.set_displayname(self.name)
+                self.name_set = True
+            except Exception:
+                self.log.exception("Error setting displayname")
+                self.name_set = False
             return True
         return False
 
-    async def _update_portal_names(self) -> None:
+    async def _update_avatar(self, path: str) -> bool:
+        if not path:
+            return False
+        if not path.startswith("/"):
+            path = os.path.join(self.config["signal.avatar_dir"], path)
+        try:
+            with open(path, "rb") as file:
+                data = file.read()
+        except FileNotFoundError:
+            return False
+        new_hash = hashlib.sha256(data).hexdigest()
+        if self.avatar_set and new_hash == self.avatar_hash:
+            return False
+        mxc = await self.default_mxid_intent.upload_media(data)
+        self.avatar_hash = new_hash
+        self.avatar_url = mxc
+        try:
+            await self.default_mxid_intent.set_avatar_url(self.avatar_url)
+            self.avatar_set = True
+        except Exception:
+            self.log.exception("Error setting avatar")
+            self.avatar_set = False
+        return True
+
+    async def _update_portal_meta(self) -> None:
         async for portal in p.Portal.find_private_chats_with(self.address):
             if portal.receiver == self.number:
                 # This is a note to self chat, don't change the name
                 continue
-            await portal.update_puppet_name(self.name)
+            try:
+                await portal.update_puppet_name(self.name)
+                await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
+            except Exception:
+                self.log.exception(f"Error updating portal meta for {portal.receiver}")
 
     async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
         portal = await p.Portal.get_by_mxid(room_id)