Przeglądaj źródła

portal: support group chat avatars

Signed-off-by: Sumner Evans <sumner@beeper.com>
Sumner Evans 2 lat temu
rodzic
commit
6f40e012a2

+ 8 - 1
mauigpapi/types/__init__.py

@@ -62,7 +62,14 @@ from .mqtt import (
     ZeroProductProvisioningEvent,
 )
 from .qe import AndroidExperiment, QeSyncExperiment, QeSyncExperimentParam, QeSyncResponse
-from .thread import Thread, ThreadItem, ThreadTheme, ThreadUser, ThreadUserLastSeenAt
+from .thread import (
+    Thread,
+    ThreadImageCandidate,
+    ThreadItem,
+    ThreadTheme,
+    ThreadUser,
+    ThreadUserLastSeenAt,
+)
 from .thread_item import (
     AnimatedMediaImage,
     AnimatedMediaImages,

+ 24 - 0
mauigpapi/types/thread.py

@@ -47,6 +47,28 @@ class ThreadUserLastSeenAt(SerializableAttrs):
         return int(self.timestamp) // 1000
 
 
+@dataclass
+class ThreadImageCandidate(SerializableAttrs):
+    width: int
+    height: int
+    url: str
+    url_expiration_timestamp_us: int
+
+
+@dataclass
+class ThreadImageCandidates(SerializableAttrs):
+    candidates: List[ThreadImageCandidate]
+
+
+@dataclass
+class ThreadImage(SerializableAttrs):
+    id: int
+    media_type: int
+    image_versions2: ThreadImageCandidates
+    original_width: int
+    original_height: int
+
+
 @dataclass(kw_only=True)
 class Thread(SerializableAttrs):
     thread_id: str
@@ -85,6 +107,8 @@ class Thread(SerializableAttrs):
     has_older: bool
     has_newer: bool
 
+    thread_image: Optional[ThreadImage] = None
+
     theme: ThreadTheme
     last_seen_at: Dict[str, ThreadUserLastSeenAt] = attr.ib(factory=lambda: {})
 

+ 6 - 3
mautrix_instagram/db/portal.py

@@ -43,7 +43,8 @@ class Portal:
     first_event_id: EventID | None
     next_batch_id: BatchID | None
     historical_base_insertion_event_id: EventID | None
-    cursor: str | none
+    cursor: str | None
+    thread_image_id: int | None
 
     @property
     def _values(self):
@@ -62,6 +63,7 @@ class Portal:
             self.next_batch_id,
             self.historical_base_insertion_event_id,
             self.cursor,
+            self.thread_image_id,
         )
 
     column_names = ",".join(
@@ -80,13 +82,14 @@ class Portal:
             "next_batch_id",
             "historical_base_insertion_event_id",
             "cursor",
+            "thread_image_id",
         )
     )
 
     async def insert(self) -> None:
         q = (
             f"INSERT INTO portal ({self.column_names}) "
-            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)"
+            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)"
         )
         await self.db.execute(q, *self._values)
 
@@ -95,7 +98,7 @@ class Portal:
             "UPDATE portal SET other_user_pk=$3, mxid=$4, name=$5, avatar_url=$6, encrypted=$7,"
             "                  name_set=$8, avatar_set=$9, relay_user_id=$10, first_event_id=$11,"
             "                  next_batch_id=$12, historical_base_insertion_event_id=$13,"
-            "                  cursor=$14 "
+            "                  cursor=$14, thread_image_id=$15 "
             "WHERE thread_id=$1 AND receiver=$2"
         )
         await self.db.execute(q, *self._values)

+ 1 - 0
mautrix_instagram/db/upgrade/__init__.py

@@ -29,4 +29,5 @@ from . import (
     v09_backfill_queue,
     v10_portal_infinite_backfill,
     v11_per_user_thread_sync_status,
+    v12_portal_thread_image_id,
 )

+ 23 - 0
mautrix_instagram/db/upgrade/v12_portal_thread_image_id.py

@@ -0,0 +1,23 @@
+# mautrix-instagram - A Matrix-Instagram puppeting bridge.
+# Copyright (C) 2022 Tulir Asokan, Sumner Evans
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from mautrix.util.async_db import Connection
+
+from . import upgrade_table
+
+
+@upgrade_table.register(description="Add column to portal to track group thread image ID")
+async def upgrade_v12(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE portal ADD COLUMN thread_image_id INTEGER")

+ 57 - 19
mautrix_instagram/portal.py

@@ -15,7 +15,7 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Union, cast
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Optional, Union, cast
 from collections import deque
 from io import BytesIO
 import asyncio
@@ -46,6 +46,7 @@ from mauigpapi.types import (
     ReelShareType,
     RegularMediaItem,
     Thread,
+    ThreadImageCandidate,
     ThreadItem,
     ThreadItemType,
     ThreadUser,
@@ -166,6 +167,7 @@ class Portal(DBPortal, BasePortal):
         next_batch_id: BatchID | None = None,
         historical_base_insertion_event_id: EventID | None = None,
         cursor: str | None = None,
+        thread_image_id: int | None = None,
     ) -> None:
         super().__init__(
             thread_id,
@@ -182,6 +184,7 @@ class Portal(DBPortal, BasePortal):
             next_batch_id,
             historical_base_insertion_event_id,
             cursor,
+            thread_image_id,
         )
         self._create_room_lock = asyncio.Lock()
         self.log = self.log.getChild(thread_id)
@@ -810,17 +813,9 @@ class Portal(DBPortal, BasePortal):
         content["org.matrix.msc3245.voice"] = {}
         return content
 
-    async def _reupload_instagram_file(
-        self,
-        source: u.User,
-        url: str,
-        msgtype: MessageType | None,
-        info: ImageInfo | VideoInfo | AudioInfo,
-        intent: IntentAPI,
-        convert_fn: Callable[[bytes, str], Awaitable[tuple[bytes, str]]] | None = None,
-        allow_encrypt: bool = True,
-    ) -> MediaMessageEventContent:
-        data = None
+    async def _download_instagram_file(
+        self, source: u.User, url: str
+    ) -> tuple[Optional[bytes], str]:
         async with source.client.raw_http_get(url) as resp:
             try:
                 length = int(resp.headers["Content-Length"])
@@ -837,8 +832,24 @@ class Portal(DBPortal, BasePortal):
                 )
                 raise ValueError("Attachment not available: too large")
             data = await resp.read()
-            info.mimetype = resp.headers["Content-Type"] or magic.from_buffer(data, mime=True)
+            if not data:
+                return None, ""
+            mimetype = resp.headers["Content-Type"] or magic.from_buffer(data, mime=True)
+            return data, mimetype
+
+    async def _reupload_instagram_file(
+        self,
+        source: u.User,
+        url: str,
+        msgtype: MessageType | None,
+        info: ImageInfo | VideoInfo | AudioInfo,
+        intent: IntentAPI,
+        convert_fn: Callable[[bytes, str], Awaitable[tuple[bytes, str]]] | None = None,
+        allow_encrypt: bool = True,
+    ) -> MediaMessageEventContent:
+        data, mimetype = await self._download_instagram_file(source, url)
         assert data is not None
+        info.mimetype = mimetype
 
         # Run the conversion function on the data.
         if convert_fn is not None:
@@ -1691,14 +1702,38 @@ class Portal(DBPortal, BasePortal):
                 return tpl.format(
                     displayname=ui.full_name or ui.username, id=ui.pk, username=ui.username
                 )
-            pass
         elif thread.thread_title:
             return self.config["bridge.group_chat_name_template"].format(name=thread.thread_title)
-        else:
-            return ""
+
+        return ""
+
+    async def _get_thread_avatar(self, source: u.User, thread: Thread) -> Optional[ContentURI]:
+        if self.is_direct or not thread.thread_image:
+            return None
+        if self.thread_image_id == thread.thread_image.id:
+            return self.avatar_url
+        best: Optional[ThreadImageCandidate] = None
+        for candidate in thread.thread_image.image_versions2.candidates:
+            if best is None or candidate.width > best.width:
+                best = candidate
+        if not best:
+            return None
+        data, mimetype = await self._download_instagram_file(source, best.url)
+        if not data:
+            return None
+        mxc = await self.main_intent.upload_media(
+            data=data,
+            mime_type=mimetype,
+            filename=thread.thread_image.id,
+            async_upload=self.config["homeserver.async_media"],
+        )
+        self.thread_image_id = thread.thread_image.id
+        return mxc
 
     async def update_info(self, thread: Thread, source: u.User) -> None:
         changed = await self._update_name(self._get_thread_name(thread))
+        if thread_avatar := await self._get_thread_avatar(source, thread):
+            changed = await self._update_photo(thread_avatar)
         changed = await self._update_participants(thread.users, source) or changed
         if changed:
             await self.update_bridge_info()
@@ -1730,12 +1765,15 @@ class Portal(DBPortal, BasePortal):
     async def _update_photo_from_puppet(self, puppet: p.Puppet) -> bool:
         if not self.private_chat_portal_meta and not self.encrypted:
             return False
-        if self.avatar_set and self.avatar_url == puppet.photo_mxc:
+        return await self._update_photo(puppet.photo_mxc)
+
+    async def _update_photo(self, photo_mxc: ContentURI) -> bool:
+        if self.avatar_set and self.avatar_url == photo_mxc:
             return False
-        self.avatar_url = puppet.photo_mxc
+        self.avatar_url = photo_mxc
         if self.mxid:
             try:
-                await self.main_intent.set_room_avatar(self.mxid, puppet.photo_mxc)
+                await self.main_intent.set_room_avatar(self.mxid, photo_mxc)
                 self.avatar_set = True
             except Exception:
                 self.log.exception("Failed to set room avatar")