Procházet zdrojové kódy

Add bridging for group avatars from Signal

Tulir Asokan před 4 roky
rodič
revize
b25c2e8a88

+ 4 - 4
ROADMAP.md

@@ -31,14 +31,14 @@
   * [x] Message reactions
   * [ ] Initial profile info
     * [x] User displayname
-    * [ ] User avatar
+    * [ ] User avatar
     * [x] Group name
-    * [ ] Group avatar
+    * [x] Group avatar
   * [ ] Profile info changes
     * [ ] User displayname
-    * [ ] User avatar
+    * [ ] User avatar
     * [x] Group name
-    * [ ] Group avatar
+    * [x] Group avatar
   * [ ] Typing notifications
   * [x] Read receipts
   * [ ] Disappearing messages

+ 1 - 0
mautrix_signal/config.py

@@ -55,6 +55,7 @@ class Config(BaseBridgeConfig):
 
         copy("signal.socket_path")
         copy("signal.outgoing_attachment_dir")
+        copy("signal.avatar_dir")
         copy("signal.remove_file_after_handling")
 
         copy("metrics.enabled")

+ 17 - 11
mautrix_signal/db/portal.py

@@ -19,7 +19,7 @@ from uuid import UUID
 from attr import dataclass
 import asyncpg
 
-from mautrix.types import RoomID
+from mautrix.types import RoomID, ContentURI
 from mautrix.util.async_db import Database
 
 fake_db = Database("") if TYPE_CHECKING else None
@@ -33,19 +33,22 @@ class Portal:
     receiver: str
     mxid: Optional[RoomID]
     name: Optional[str]
+    avatar_hash: Optional[str]
+    avatar_url: Optional[ContentURI]
     encrypted: bool
 
     async def insert(self) -> None:
-        q = ("INSERT INTO portal (chat_id, receiver, mxid, name, encrypted) "
-             "VALUES ($1, $2, $3, $4, $5)")
+        q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
+             "                    encrypted) "
+             "VALUES ($1, $2, $3, $4, $5, $6, $7)")
         await self.db.execute(q, str(self.chat_id), self.receiver, self.mxid, self.name,
                               self.encrypted)
 
     async def update(self) -> None:
-        q = ("UPDATE portal SET mxid=$3, name=$4, encrypted=$5 "
+        q = ("UPDATE portal SET mxid=$3, name=$4, avatar_hash=$5, avatar_url=$6, encrypted=$7 "
              "WHERE chat_id=$1 AND receiver=$2")
         await self.db.execute(q, str(self.chat_id), self.receiver, self.mxid, self.name,
-                              self.encrypted)
+                              self.avatar_hash, self.avatar_url, self.encrypted)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
@@ -58,7 +61,8 @@ class Portal:
 
     @classmethod
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
-        q = "SELECT chat_id, receiver, mxid, name, encrypted FROM portal WHERE mxid=$1"
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+             "FROM portal WHERE mxid=$1")
         row = await cls.db.fetchrow(q, mxid)
         if not row:
             return None
@@ -67,7 +71,7 @@ class Portal:
     @classmethod
     async def get_by_chat_id(cls, chat_id: Union[UUID, str], receiver: str = ""
                              ) -> Optional['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, encrypted "
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
         row = await cls.db.fetchrow(q, str(chat_id), receiver)
         if not row:
@@ -76,19 +80,21 @@ class Portal:
 
     @classmethod
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
-        q = "SELECT chat_id, receiver, mxid, name, encrypted FROM portal WHERE receiver=$1"
+        q =( "SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+             "FROM portal WHERE receiver=$1")
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
     async def find_private_chats_with(cls, other_user: UUID) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, encrypted FROM portal "
-             "WHERE chat_id=$1::text AND receiver<>''")
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+             "FROM portal WHERE chat_id=$1::text AND receiver<>''")
         rows = await cls.db.fetch(q, other_user)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
     async def all_with_room(cls) -> List['Portal']:
-        q = "SELECT chat_id, receiver, mxid, name, encrypted FROM portal WHERE mxid IS NOT NULL"
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+             "FROM portal WHERE mxid IS NOT NULL")
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]

+ 6 - 0
mautrix_signal/db/upgrade.py

@@ -89,3 +89,9 @@ async def upgrade_v1(conn: Connection) -> None:
             ON DELETE CASCADE ON UPDATE CASCADE,
         UNIQUE (mxid, mx_room)
     )""")
+
+
+@upgrade_table.register(description="Add avatar info to portal table")
+async def upgrade_v2(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_hash TEXT")
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_url TEXT")

+ 2 - 0
mautrix_signal/example-config.yaml

@@ -70,6 +70,8 @@ signal:
     # absolute path that signald can read. For attachments in the other direction,
     # make sure signald is configured to use an absolute path as the data directory.
     outgoing_attachment_dir: /tmp
+    # Directory where signald stores avatars for groups.
+    avatar_dir: ~/.config/signald/avatars
     # Whether or not message attachments should be removed from disk after they're bridged.
     remove_file_after_handling: true
 

+ 25 - 3
mautrix_signal/portal.py

@@ -18,6 +18,7 @@ from typing import (Dict, Tuple, Optional, List, Deque, Set, Any, Union, AsyncGe
 from collections import deque
 from uuid import UUID, uuid4
 import mimetypes
+import hashlib
 import asyncio
 import os.path
 import time
@@ -28,7 +29,7 @@ from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Cont
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
-                           TextMessageEventContent, MessageEvent, EncryptedEvent,
+                           TextMessageEventContent, MessageEvent, EncryptedEvent, ContentURI,
                            MediaMessageEventContent, ImageInfo, VideoInfo, FileInfo, AudioInfo)
 from mautrix.errors import MatrixError, MForbidden
 
@@ -70,8 +71,9 @@ class Portal(DBPortal, BasePortal):
     _reaction_lock: asyncio.Lock
 
     def __init__(self, chat_id: Union[str, UUID], receiver: str, mxid: Optional[RoomID] = None,
-                 name: Optional[str] = None, encrypted: bool = False) -> None:
-        super().__init__(chat_id, receiver, mxid, name, encrypted)
+                 name: Optional[str] = None, avatar_hash: Optional[str] = None,
+                 avatar_url: Optional[ContentURI] = None, encrypted: bool = False) -> None:
+        super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted)
         self._create_room_lock = asyncio.Lock()
         self.log = self.log.getChild(str(chat_id))
         self._main_intent = None
@@ -434,6 +436,7 @@ class Portal(DBPortal, BasePortal):
         if not isinstance(info, Group):
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
         changed = await self._update_name(info.name)
+        changed = await self._update_avatar()
         await self._update_participants(info.members)
         if changed:
             await self.update_bridge_info()
@@ -457,6 +460,24 @@ class Portal(DBPortal, BasePortal):
             return True
         return False
 
+    async def _update_avatar(self) -> bool:
+        if self.is_direct:
+            return False
+        path = os.path.join(self.config["signal.avatar_dir"], f"group-{self.chat_id}")
+        try:
+            with open(path, "rb") as file:
+                data = file.read()
+        except FileNotFoundError:
+            return False
+        new_hash = hashlib.sha256(data).hexdigest()
+        if self.avatar_hash and new_hash == self.avatar_hash:
+            return False
+        mxc = await self.main_intent.upload_media(data)
+        await self.main_intent.set_room_avatar(self.mxid, mxc)
+        self.avatar_url = mxc
+        self.avatar_hash = new_hash
+        return True
+
     async def _update_participants(self, participants: List[Address]) -> None:
         if not self.mxid or not participants:
             return
@@ -487,6 +508,7 @@ class Portal(DBPortal, BasePortal):
             "channel": {
                 "id": str(self.chat_id),
                 "displayname": self.name,
+                "avatar_url": self.avatar_url,
             }
         }