Ver Fonte

Add bridging for group avatars from Signal

Tulir Asokan há 4 anos atrás
pai
commit
b25c2e8a88

+ 4 - 4
ROADMAP.md

@@ -31,14 +31,14 @@
   * [x] Message reactions
   * [x] Message reactions
   * [ ] Initial profile info
   * [ ] Initial profile info
     * [x] User displayname
     * [x] User displayname
-    * [ ] User avatar
+    * [ ] User avatar
     * [x] Group name
     * [x] Group name
-    * [ ] Group avatar
+    * [x] Group avatar
   * [ ] Profile info changes
   * [ ] Profile info changes
     * [ ] User displayname
     * [ ] User displayname
-    * [ ] User avatar
+    * [ ] User avatar
     * [x] Group name
     * [x] Group name
-    * [ ] Group avatar
+    * [x] Group avatar
   * [ ] Typing notifications
   * [ ] Typing notifications
   * [x] Read receipts
   * [x] Read receipts
   * [ ] Disappearing messages
   * [ ] Disappearing messages

+ 1 - 0
mautrix_signal/config.py

@@ -55,6 +55,7 @@ class Config(BaseBridgeConfig):
 
 
         copy("signal.socket_path")
         copy("signal.socket_path")
         copy("signal.outgoing_attachment_dir")
         copy("signal.outgoing_attachment_dir")
+        copy("signal.avatar_dir")
         copy("signal.remove_file_after_handling")
         copy("signal.remove_file_after_handling")
 
 
         copy("metrics.enabled")
         copy("metrics.enabled")

+ 17 - 11
mautrix_signal/db/portal.py

@@ -19,7 +19,7 @@ from uuid import UUID
 from attr import dataclass
 from attr import dataclass
 import asyncpg
 import asyncpg
 
 
-from mautrix.types import RoomID
+from mautrix.types import RoomID, ContentURI
 from mautrix.util.async_db import Database
 from mautrix.util.async_db import Database
 
 
 fake_db = Database("") if TYPE_CHECKING else None
 fake_db = Database("") if TYPE_CHECKING else None
@@ -33,19 +33,22 @@ class Portal:
     receiver: str
     receiver: str
     mxid: Optional[RoomID]
     mxid: Optional[RoomID]
     name: Optional[str]
     name: Optional[str]
+    avatar_hash: Optional[str]
+    avatar_url: Optional[ContentURI]
     encrypted: bool
     encrypted: bool
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ("INSERT INTO portal (chat_id, receiver, mxid, name, encrypted) "
-             "VALUES ($1, $2, $3, $4, $5)")
+        q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
+             "                    encrypted) "
+             "VALUES ($1, $2, $3, $4, $5, $6, $7)")
         await self.db.execute(q, str(self.chat_id), self.receiver, self.mxid, self.name,
         await self.db.execute(q, str(self.chat_id), self.receiver, self.mxid, self.name,
                               self.encrypted)
                               self.encrypted)
 
 
     async def update(self) -> None:
     async def update(self) -> None:
-        q = ("UPDATE portal SET mxid=$3, name=$4, encrypted=$5 "
+        q = ("UPDATE portal SET mxid=$3, name=$4, avatar_hash=$5, avatar_url=$6, encrypted=$7 "
              "WHERE chat_id=$1 AND receiver=$2")
              "WHERE chat_id=$1 AND receiver=$2")
         await self.db.execute(q, str(self.chat_id), self.receiver, self.mxid, self.name,
         await self.db.execute(q, str(self.chat_id), self.receiver, self.mxid, self.name,
-                              self.encrypted)
+                              self.avatar_hash, self.avatar_url, self.encrypted)
 
 
     @classmethod
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
@@ -58,7 +61,8 @@ class Portal:
 
 
     @classmethod
     @classmethod
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
-        q = "SELECT chat_id, receiver, mxid, name, encrypted FROM portal WHERE mxid=$1"
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+             "FROM portal WHERE mxid=$1")
         row = await cls.db.fetchrow(q, mxid)
         row = await cls.db.fetchrow(q, mxid)
         if not row:
         if not row:
             return None
             return None
@@ -67,7 +71,7 @@ class Portal:
     @classmethod
     @classmethod
     async def get_by_chat_id(cls, chat_id: Union[UUID, str], receiver: str = ""
     async def get_by_chat_id(cls, chat_id: Union[UUID, str], receiver: str = ""
                              ) -> Optional['Portal']:
                              ) -> Optional['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, encrypted "
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
         row = await cls.db.fetchrow(q, str(chat_id), receiver)
         row = await cls.db.fetchrow(q, str(chat_id), receiver)
         if not row:
         if not row:
@@ -76,19 +80,21 @@ class Portal:
 
 
     @classmethod
     @classmethod
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
-        q = "SELECT chat_id, receiver, mxid, name, encrypted FROM portal WHERE receiver=$1"
+        q =( "SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+             "FROM portal WHERE receiver=$1")
         rows = await cls.db.fetch(q, receiver)
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 
     @classmethod
     @classmethod
     async def find_private_chats_with(cls, other_user: UUID) -> List['Portal']:
     async def find_private_chats_with(cls, other_user: UUID) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, encrypted FROM portal "
-             "WHERE chat_id=$1::text AND receiver<>''")
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+             "FROM portal WHERE chat_id=$1::text AND receiver<>''")
         rows = await cls.db.fetch(q, other_user)
         rows = await cls.db.fetch(q, other_user)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 
     @classmethod
     @classmethod
     async def all_with_room(cls) -> List['Portal']:
     async def all_with_room(cls) -> List['Portal']:
-        q = "SELECT chat_id, receiver, mxid, name, encrypted FROM portal WHERE mxid IS NOT NULL"
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+             "FROM portal WHERE mxid IS NOT NULL")
         rows = await cls.db.fetch(q)
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]

+ 6 - 0
mautrix_signal/db/upgrade.py

@@ -89,3 +89,9 @@ async def upgrade_v1(conn: Connection) -> None:
             ON DELETE CASCADE ON UPDATE CASCADE,
             ON DELETE CASCADE ON UPDATE CASCADE,
         UNIQUE (mxid, mx_room)
         UNIQUE (mxid, mx_room)
     )""")
     )""")
+
+
+@upgrade_table.register(description="Add avatar info to portal table")
+async def upgrade_v2(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_hash TEXT")
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_url TEXT")

+ 2 - 0
mautrix_signal/example-config.yaml

@@ -70,6 +70,8 @@ signal:
     # absolute path that signald can read. For attachments in the other direction,
     # absolute path that signald can read. For attachments in the other direction,
     # make sure signald is configured to use an absolute path as the data directory.
     # make sure signald is configured to use an absolute path as the data directory.
     outgoing_attachment_dir: /tmp
     outgoing_attachment_dir: /tmp
+    # Directory where signald stores avatars for groups.
+    avatar_dir: ~/.config/signald/avatars
     # Whether or not message attachments should be removed from disk after they're bridged.
     # Whether or not message attachments should be removed from disk after they're bridged.
     remove_file_after_handling: true
     remove_file_after_handling: true
 
 

+ 25 - 3
mautrix_signal/portal.py

@@ -18,6 +18,7 @@ from typing import (Dict, Tuple, Optional, List, Deque, Set, Any, Union, AsyncGe
 from collections import deque
 from collections import deque
 from uuid import UUID, uuid4
 from uuid import UUID, uuid4
 import mimetypes
 import mimetypes
+import hashlib
 import asyncio
 import asyncio
 import os.path
 import os.path
 import time
 import time
@@ -28,7 +29,7 @@ from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Cont
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal
 from mautrix.bridge import BasePortal
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
-                           TextMessageEventContent, MessageEvent, EncryptedEvent,
+                           TextMessageEventContent, MessageEvent, EncryptedEvent, ContentURI,
                            MediaMessageEventContent, ImageInfo, VideoInfo, FileInfo, AudioInfo)
                            MediaMessageEventContent, ImageInfo, VideoInfo, FileInfo, AudioInfo)
 from mautrix.errors import MatrixError, MForbidden
 from mautrix.errors import MatrixError, MForbidden
 
 
@@ -70,8 +71,9 @@ class Portal(DBPortal, BasePortal):
     _reaction_lock: asyncio.Lock
     _reaction_lock: asyncio.Lock
 
 
     def __init__(self, chat_id: Union[str, UUID], receiver: str, mxid: Optional[RoomID] = None,
     def __init__(self, chat_id: Union[str, UUID], receiver: str, mxid: Optional[RoomID] = None,
-                 name: Optional[str] = None, encrypted: bool = False) -> None:
-        super().__init__(chat_id, receiver, mxid, name, encrypted)
+                 name: Optional[str] = None, avatar_hash: Optional[str] = None,
+                 avatar_url: Optional[ContentURI] = None, encrypted: bool = False) -> None:
+        super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted)
         self._create_room_lock = asyncio.Lock()
         self._create_room_lock = asyncio.Lock()
         self.log = self.log.getChild(str(chat_id))
         self.log = self.log.getChild(str(chat_id))
         self._main_intent = None
         self._main_intent = None
@@ -434,6 +436,7 @@ class Portal(DBPortal, BasePortal):
         if not isinstance(info, Group):
         if not isinstance(info, Group):
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
         changed = await self._update_name(info.name)
         changed = await self._update_name(info.name)
+        changed = await self._update_avatar()
         await self._update_participants(info.members)
         await self._update_participants(info.members)
         if changed:
         if changed:
             await self.update_bridge_info()
             await self.update_bridge_info()
@@ -457,6 +460,24 @@ class Portal(DBPortal, BasePortal):
             return True
             return True
         return False
         return False
 
 
+    async def _update_avatar(self) -> bool:
+        if self.is_direct:
+            return False
+        path = os.path.join(self.config["signal.avatar_dir"], f"group-{self.chat_id}")
+        try:
+            with open(path, "rb") as file:
+                data = file.read()
+        except FileNotFoundError:
+            return False
+        new_hash = hashlib.sha256(data).hexdigest()
+        if self.avatar_hash and new_hash == self.avatar_hash:
+            return False
+        mxc = await self.main_intent.upload_media(data)
+        await self.main_intent.set_room_avatar(self.mxid, mxc)
+        self.avatar_url = mxc
+        self.avatar_hash = new_hash
+        return True
+
     async def _update_participants(self, participants: List[Address]) -> None:
     async def _update_participants(self, participants: List[Address]) -> None:
         if not self.mxid or not participants:
         if not self.mxid or not participants:
             return
             return
@@ -487,6 +508,7 @@ class Portal(DBPortal, BasePortal):
             "channel": {
             "channel": {
                 "id": str(self.chat_id),
                 "id": str(self.chat_id),
                 "displayname": self.name,
                 "displayname": self.name,
+                "avatar_url": self.avatar_url,
             }
             }
         }
         }