Преглед на файлове

Add support for Signal profile avatars

Tulir Asokan преди 4 години
родител
ревизия
d87d45ca82
променени са 5 файла, в които са добавени 115 реда и са изтрити 46 реда
  1. 2 2
      ROADMAP.md
  2. 38 31
      mautrix_signal/db/puppet.py
  3. 9 0
      mautrix_signal/db/upgrade.py
  4. 13 1
      mautrix_signal/portal.py
  5. 53 12
      mautrix_signal/puppet.py

+ 2 - 2
ROADMAP.md

@@ -36,9 +36,9 @@
   * [x] Remote deletions
   * [x] Remote deletions
   * [ ] Initial profile info
   * [ ] Initial profile info
     * [x] User displayname
     * [x] User displayname
-    * [ ] †User avatar
+    * [x] User avatar
     * [x] Group name
     * [x] Group name
-    * [x] Group avatar
+    * [ ] †Group avatar
   * [ ] Profile info changes
   * [ ] Profile info changes
     * [ ] User displayname
     * [ ] User displayname
     * [ ] †User avatar
     * [ ] †User avatar

+ 38 - 31
mautrix_signal/db/puppet.py

@@ -21,7 +21,7 @@ from yarl import URL
 import asyncpg
 import asyncpg
 
 
 from mausignald.types import Address
 from mausignald.types import Address
-from mautrix.types import UserID, SyncToken
+from mautrix.types import UserID, SyncToken, ContentURI
 from mautrix.util.async_db import Database
 from mautrix.util.async_db import Database
 
 
 fake_db = Database("") if TYPE_CHECKING else None
 fake_db = Database("") if TYPE_CHECKING else None
@@ -34,6 +34,10 @@ class Puppet:
     uuid: Optional[UUID]
     uuid: Optional[UUID]
     number: Optional[str]
     number: Optional[str]
     name: Optional[str]
     name: Optional[str]
+    avatar_hash: Optional[str]
+    avatar_url: Optional[ContentURI]
+    name_set: bool
+    avatar_set: bool
 
 
     uuid_registered: bool
     uuid_registered: bool
     number_registered: bool
     number_registered: bool
@@ -43,13 +47,19 @@ class Puppet:
     next_batch: Optional[SyncToken]
     next_batch: Optional[SyncToken]
     base_url: Optional[URL]
     base_url: Optional[URL]
 
 
+    @property
+    def _base_url_str(self) -> Optional[str]:
+        return str(self.base_url) if self.base_url else None
+
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ("INSERT INTO puppet (uuid, number, name, uuid_registered, number_registered, "
+        q = ("INSERT INTO puppet (uuid, number, name, avatar_hash, avatar_url, name_set, "
+             "                    avatar_set, uuid_registered, number_registered, "
              "                    custom_mxid, access_token, next_batch, base_url) "
              "                    custom_mxid, access_token, next_batch, base_url) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)")
-        await self.db.execute(q, self.uuid, self.number, self.name, self.uuid_registered,
-                              self.number_registered, self.custom_mxid, self.access_token,
-                              self.next_batch, str(self.base_url) if self.base_url else None)
+             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)")
+        await self.db.execute(q, self.uuid, self.number, self.name, self.avatar_hash,
+                              self.avatar_url, self.name_set, self.avatar_set,
+                              self.uuid_registered, self.number_registered, self.custom_mxid,
+                              self.access_token, self.next_batch, self._base_url_str)
 
 
     async def _set_uuid(self, uuid: UUID) -> None:
     async def _set_uuid(self, uuid: UUID) -> None:
         if self.uuid:
         if self.uuid:
@@ -63,17 +73,18 @@ class Puppet:
             await conn.execute("UPDATE reaction SET author=$1 WHERE author=$2", uuid, self.number)
             await conn.execute("UPDATE reaction SET author=$1 WHERE author=$2", uuid, self.number)
 
 
     async def update(self) -> None:
     async def update(self) -> None:
-        if self.uuid is None:
-            q = ("UPDATE puppet SET uuid=$1, name=$3, uuid_registered=$4, number_registered=$5, "
-                 "                  custom_mxid=$6, access_token=$7, next_batch=$8, base_url=$9 "
-                 "WHERE number=$2")
-        else:
-            q = ("UPDATE puppet SET number=$2, name=$3, uuid_registered=$4, number_registered=$5, "
-                 "                  custom_mxid=$6, access_token=$7, next_batch=$8, base_url=$9 "
-                 "WHERE uuid=$1")
-        await self.db.execute(q, self.uuid, self.number, self.name, self.uuid_registered,
-                              self.number_registered, self.custom_mxid, self.access_token,
-                              self.next_batch, str(self.base_url) if self.base_url else None)
+        set_columns = (
+            "name=$3, avatar_hash=$4, avatar_url=$5, name_set=$6, avatar_set=$7, "
+            "uuid_registered=$8, number_registered=$9, "
+            "custom_mxid=$10, access_token=$11, next_batch=$12, base_url=$13"
+        )
+        q = (f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
+             if self.uuid is None
+             else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1")
+        await self.db.execute(q,self.uuid, self.number, self.name, self.avatar_hash,
+                              self.avatar_url, self.name_set, self.avatar_set,
+                              self.uuid_registered, self.number_registered, self.custom_mxid,
+                              self.access_token, self.next_batch, self._base_url_str)
 
 
     @classmethod
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Puppet':
     def _from_row(cls, row: asyncpg.Record) -> 'Puppet':
@@ -82,19 +93,21 @@ class Puppet:
         base_url = URL(base_url_str) if base_url_str is not None else None
         base_url = URL(base_url_str) if base_url_str is not None else None
         return cls(base_url=base_url, **data)
         return cls(base_url=base_url, **data)
 
 
+    _select_base = ("SELECT uuid, number, name, avatar_hash, avatar_url, name_set, avatar_set, "
+                    "       uuid_registered, number_registered, custom_mxid, access_token, "
+                    "       next_batch, base_url "
+                    "FROM puppet")
+
     @classmethod
     @classmethod
     async def get_by_address(cls, address: Address) -> Optional['Puppet']:
     async def get_by_address(cls, address: Address) -> Optional['Puppet']:
-        select = ("SELECT uuid, number, name, uuid_registered, "
-                  "       number_registered, custom_mxid, access_token, next_batch, base_url "
-                  "FROM puppet")
         if address.uuid:
         if address.uuid:
             if address.number:
             if address.number:
-                row = await cls.db.fetchrow(f"{select} WHERE uuid=$1 OR number=$2",
+                row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1 OR number=$2",
                                             address.uuid, address.number)
                                             address.uuid, address.number)
             else:
             else:
-                row = await cls.db.fetchrow(f"{select} WHERE uuid=$1", address.uuid)
+                row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
         elif address.number:
         elif address.number:
-            row = await cls.db.fetchrow(f"{select} WHERE number=$1", address.number)
+            row = await cls.db.fetchrow(f"{cls._select_base} WHERE number=$1", address.number)
         else:
         else:
             raise ValueError("Invalid address")
             raise ValueError("Invalid address")
         if not row:
         if not row:
@@ -103,18 +116,12 @@ class Puppet:
 
 
     @classmethod
     @classmethod
     async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
     async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
-        q = ("SELECT uuid, number, name, uuid_registered, number_registered,"
-             "       custom_mxid, access_token, next_batch, base_url "
-             "FROM puppet WHERE custom_mxid=$1")
-        row = await cls.db.fetchrow(q, mxid)
+        row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
     async def all_with_custom_mxid(cls) -> List['Puppet']:
     async def all_with_custom_mxid(cls) -> List['Puppet']:
-        q = ("SELECT uuid, number, name, uuid_registered, number_registered,"
-             "       custom_mxid, access_token, next_batch, base_url "
-             "FROM puppet WHERE custom_mxid IS NOT NULL")
-        rows = await cls.db.fetch(q)
+        rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]

+ 9 - 0
mautrix_signal/db/upgrade.py

@@ -114,3 +114,12 @@ async def upgrade_v4(conn: Connection) -> None:
                        "FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) "
                        "FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) "
                        "  REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) "
                        "  REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) "
                        "  ON DELETE CASCADE ON UPDATE CASCADE")
                        "  ON DELETE CASCADE ON UPDATE CASCADE")
+
+
+@upgrade_table.register(description="Add avatar info to puppet table")
+async def upgrade_v5(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE puppet ADD COLUMN avatar_hash TEXT")
+    await conn.execute("ALTER TABLE puppet ADD COLUMN avatar_url TEXT")
+    await conn.execute("ALTER TABLE puppet ADD COLUMN name_set BOOL NOT NULL DEFAULT false")
+    await conn.execute("ALTER TABLE puppet ADD COLUMN avatar_set BOOL NOT NULL DEFAULT false")
+    await conn.execute("UPDATE puppet SET name_set=true WHERE name<>''")

+ 13 - 1
mautrix_signal/portal.py

@@ -474,12 +474,24 @@ class Portal(DBPortal, BasePortal):
             return
             return
         else:
         else:
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
-        changed = await self._update_avatar()
+        changed = await self._update_avatar() or changed
         await self._update_participants(source, info.members)
         await self._update_participants(source, info.members)
         if changed:
         if changed:
             await self.update_bridge_info()
             await self.update_bridge_info()
             await self.update()
             await self.update()
 
 
+    async def update_puppet_avatar(self, new_hash: str, avatar_url: ContentURI) -> None:
+        if not self.encrypted and not self.private_chat_portal_meta:
+            return
+
+        if self.avatar_hash != new_hash:
+            self.avatar_hash = new_hash
+            self.avatar_url = avatar_url
+            if self.mxid:
+                await self.main_intent.set_room_avatar(self.mxid, avatar_url)
+                await self.update_bridge_info()
+                await self.update()
+
     async def update_puppet_name(self, name: str) -> None:
     async def update_puppet_name(self, name: str) -> None:
         if not self.encrypted and not self.private_chat_portal_meta:
         if not self.encrypted and not self.private_chat_portal_meta:
             return
             return

+ 53 - 12
mautrix_signal/puppet.py

@@ -16,14 +16,16 @@
 from typing import (Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, Union,
 from typing import (Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, Union,
                     TYPE_CHECKING, cast)
                     TYPE_CHECKING, cast)
 from uuid import UUID
 from uuid import UUID
+import hashlib
 import asyncio
 import asyncio
+import os.path
 
 
 from yarl import URL
 from yarl import URL
 
 
 from mausignald.types import Address, Contact, Profile
 from mausignald.types import Address, Contact, Profile
 from mautrix.bridge import BasePuppet
 from mautrix.bridge import BasePuppet
 from mautrix.appservice import IntentAPI
 from mautrix.appservice import IntentAPI
-from mautrix.types import UserID, SyncToken, RoomID
+from mautrix.types import UserID, SyncToken, RoomID, ContentURI
 from mautrix.errors import MForbidden
 from mautrix.errors import MForbidden
 from mautrix.util.simple_template import SimpleTemplate
 from mautrix.util.simple_template import SimpleTemplate
 
 
@@ -56,12 +58,16 @@ class Puppet(DBPuppet, BasePuppet):
     _update_info_lock: asyncio.Lock
     _update_info_lock: asyncio.Lock
 
 
     def __init__(self, uuid: Optional[UUID], number: Optional[str], name: Optional[str] = None,
     def __init__(self, uuid: Optional[UUID], number: Optional[str], name: Optional[str] = None,
-                 uuid_registered: bool = False, number_registered: bool = False,
-                 custom_mxid: Optional[UserID] = None, access_token: Optional[str] = None,
-                 next_batch: Optional[SyncToken] = None, base_url: Optional[URL] = None) -> None:
-        super().__init__(uuid=uuid, number=number, name=name, uuid_registered=uuid_registered,
-                         number_registered=number_registered, custom_mxid=custom_mxid,
-                         access_token=access_token, next_batch=next_batch, base_url=base_url)
+                 avatar_url: Optional[ContentURI] = None, avatar_hash: Optional[str] = None,
+                 name_set: bool = False, avatar_set: bool = False, uuid_registered: bool = False,
+                 number_registered: bool = False, custom_mxid: Optional[UserID] = None,
+                 access_token: Optional[str] = None, next_batch: Optional[SyncToken] = None,
+                 base_url: Optional[URL] = None) -> None:
+        super().__init__(uuid=uuid, number=number, name=name, avatar_url=avatar_url,
+                         avatar_hash=avatar_hash, name_set=name_set, avatar_set=avatar_set,
+                         uuid_registered=uuid_registered, number_registered=number_registered,
+                         custom_mxid=custom_mxid, access_token=access_token, next_batch=next_batch,
+                         base_url=base_url)
         self.log = self.log.getChild(str(uuid) if uuid else number)
         self.log = self.log.getChild(str(uuid) if uuid else number)
 
 
         self.default_mxid = self.get_mxid_from_id(self.address)
         self.default_mxid = self.get_mxid_from_id(self.address)
@@ -167,8 +173,11 @@ class Puppet(DBPuppet, BasePuppet):
             update = False
             update = False
             if name is not None or self.name is None:
             if name is not None or self.name is None:
                 update = await self._update_name(name) or update
                 update = await self._update_name(name) or update
+            if isinstance(info, Profile):
+                update = await self._update_avatar(info.avatar) or update
             if update:
             if update:
                 await self.update()
                 await self.update()
+                self.loop.create_task(self._update_portal_meta())
 
 
     @staticmethod
     @staticmethod
     def fmt_phone(number: str) -> str:
     def fmt_phone(number: str) -> str:
@@ -199,19 +208,51 @@ class Puppet(DBPuppet, BasePuppet):
 
 
     async def _update_name(self, name: Optional[str]) -> bool:
     async def _update_name(self, name: Optional[str]) -> bool:
         name = self._get_displayname(self.address, name)
         name = self._get_displayname(self.address, name)
-        if name != self.name:
+        if name != self.name or not self.name_set:
             self.name = name
             self.name = name
-            await self.default_mxid_intent.set_displayname(self.name)
-            self.loop.create_task(self._update_portal_names())
+            try:
+                await self.default_mxid_intent.set_displayname(self.name)
+                self.name_set = True
+            except Exception:
+                self.log.exception("Error setting displayname")
+                self.name_set = False
             return True
             return True
         return False
         return False
 
 
-    async def _update_portal_names(self) -> None:
+    async def _update_avatar(self, path: str) -> bool:
+        if not path:
+            return False
+        if not path.startswith("/"):
+            path = os.path.join(self.config["signal.avatar_dir"], path)
+        try:
+            with open(path, "rb") as file:
+                data = file.read()
+        except FileNotFoundError:
+            return False
+        new_hash = hashlib.sha256(data).hexdigest()
+        if self.avatar_set and new_hash == self.avatar_hash:
+            return False
+        mxc = await self.default_mxid_intent.upload_media(data)
+        self.avatar_hash = new_hash
+        self.avatar_url = mxc
+        try:
+            await self.default_mxid_intent.set_avatar_url(self.avatar_url)
+            self.avatar_set = True
+        except Exception:
+            self.log.exception("Error setting avatar")
+            self.avatar_set = False
+        return True
+
+    async def _update_portal_meta(self) -> None:
         async for portal in p.Portal.find_private_chats_with(self.address):
         async for portal in p.Portal.find_private_chats_with(self.address):
             if portal.receiver == self.number:
             if portal.receiver == self.number:
                 # This is a note to self chat, don't change the name
                 # This is a note to self chat, don't change the name
                 continue
                 continue
-            await portal.update_puppet_name(self.name)
+            try:
+                await portal.update_puppet_name(self.name)
+                await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
+            except Exception:
+                self.log.exception(f"Error updating portal meta for {portal.receiver}")
 
 
     async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
     async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
         portal = await p.Portal.get_by_mxid(room_id)
         portal = await p.Portal.get_by_mxid(room_id)