Эх сурвалжийг харах

Handle thread image changes

Tulir Asokan 2 жил өмнө
parent
commit
ee3642f604

+ 2 - 0
mauigpapi/mqtt/conn.py

@@ -318,6 +318,8 @@ class AndroidMQTT:
             subitem_key = rest[0]
             if subitem_key == "approval_required_for_new_members":
                 additional["approval_required_for_new_members"] = True
+            elif subitem_key == "thread_image":
+                additional["is_thread_image"] = True
             elif subitem_key == "participants" and len(rest) > 2 and rest[2] == "has_seen":
                 additional["has_seen"] = int(rest[1])
             elif subitem_key == "items":

+ 2 - 8
mauigpapi/types/__init__.py

@@ -63,14 +63,7 @@ from .mqtt import (
     ZeroProductProvisioningEvent,
 )
 from .qe import AndroidExperiment, QeSyncExperiment, QeSyncExperimentParam, QeSyncResponse
-from .thread import (
-    Thread,
-    ThreadImageCandidate,
-    ThreadItem,
-    ThreadTheme,
-    ThreadUser,
-    ThreadUserLastSeenAt,
-)
+from .thread import Thread, ThreadTheme, ThreadUser, ThreadUserLastSeenAt
 from .thread_item import (
     AnimatedMediaImage,
     AnimatedMediaImages,
@@ -98,6 +91,7 @@ from .thread_item import (
     RegularMediaItem,
     ReplayableMediaItem,
     SharingFrictionInfo,
+    ThreadImage,
     ThreadItem,
     ThreadItemActionLog,
     ThreadItemType,

+ 1 - 0
mauigpapi/types/mqtt.py

@@ -110,6 +110,7 @@ class MessageSyncMessage(ThreadItem, SerializableAttrs):
     # These come from parsing the path
     admin_user_id: Optional[int] = None
     approval_required_for_new_members: Optional[bool] = None
+    is_thread_image: Optional[bool] = None
     has_seen: Optional[int] = None
     thread_id: Optional[str] = None
 

+ 1 - 23
mauigpapi/types/thread.py

@@ -21,7 +21,7 @@ import attr
 from mautrix.types import SerializableAttrs
 
 from .account import BaseResponseUser
-from .thread_item import ThreadItem
+from .thread_item import ThreadImage, ThreadItem
 
 
 @dataclass
@@ -47,28 +47,6 @@ class ThreadUserLastSeenAt(SerializableAttrs):
         return int(self.timestamp) // 1000
 
 
-@dataclass
-class ThreadImageCandidate(SerializableAttrs):
-    width: int
-    height: int
-    url: str
-    url_expiration_timestamp_us: int
-
-
-@dataclass
-class ThreadImageCandidates(SerializableAttrs):
-    candidates: List[ThreadImageCandidate]
-
-
-@dataclass
-class ThreadImage(SerializableAttrs):
-    id: int
-    media_type: int
-    image_versions2: ThreadImageCandidates
-    original_width: int
-    original_height: int
-
-
 @dataclass(kw_only=True)
 class Thread(SerializableAttrs):
     thread_id: str

+ 27 - 16
mauigpapi/types/thread_item.py

@@ -124,6 +124,25 @@ class ImageVersions(SerializableAttrs):
     candidates: List[ImageVersion]
 
 
+@dataclass
+class ImageVersionsContainer(SerializableAttrs):
+    image_versions2: Optional[ImageVersions] = None
+    original_width: Optional[int] = None
+    original_height: Optional[int] = None
+
+    @property
+    def best_image(self) -> Optional[ImageVersion]:
+        if not self.image_versions2:
+            return None
+        best: Optional[ImageVersion] = None
+        for version in self.image_versions2.candidates:
+            if version.width == self.original_width and version.height == self.original_height:
+                return version
+            elif not best or (version.width * version.height > best.width * best.height):
+                best = version
+        return best
+
+
 @dataclass(kw_only=True)
 class VideoVersion(SerializableAttrs):
     type: int
@@ -156,12 +175,9 @@ class ExpiredMediaItem(SerializableAttrs):
 
 
 @dataclass(kw_only=True)
-class RegularMediaItem(SerializableAttrs):
+class RegularMediaItem(ImageVersionsContainer, SerializableAttrs):
     id: str
-    image_versions2: Optional[ImageVersions] = None
     video_versions: Optional[List[VideoVersion]] = None
-    original_width: Optional[int] = None
-    original_height: Optional[int] = None
     media_type: MediaType
     media_id: Optional[int] = None
     organic_tracking_token: Optional[str] = None
@@ -170,18 +186,6 @@ class RegularMediaItem(SerializableAttrs):
     is_commercial: Optional[bool] = None
     commerciality_status: Optional[str] = None  # TODO enum? commercial
 
-    @property
-    def best_image(self) -> Optional[ImageVersion]:
-        if not self.image_versions2:
-            return None
-        best: Optional[ImageVersion] = None
-        for version in self.image_versions2.candidates:
-            if version.width == self.original_width and version.height == self.original_height:
-                return version
-            elif not best or (version.width * version.height > best.width * best.height):
-                best = version
-        return best
-
     @property
     def best_video(self) -> Optional[VideoVersion]:
         if not self.video_versions:
@@ -588,6 +592,12 @@ class MentionedEntity(SerializableAttrs):
     interop_user_type: int
 
 
+@dataclass(kw_only=True)
+class ThreadImage(ImageVersionsContainer, SerializableAttrs):
+    id: int
+    media_type: int
+
+
 @dataclass(kw_only=True)
 class ThreadItem(SerializableAttrs):
     item_id: Optional[str] = None
@@ -603,6 +613,7 @@ class ThreadItem(SerializableAttrs):
     client_context: Optional[str] = None
     show_forward_attribution: Optional[bool] = None
     action_log: Optional[ThreadItemActionLog] = None
+    thread_image: Optional[ThreadImage] = None
     auxiliary_text: Optional[str] = None
     auxiliary_text_source_type: Optional[int] = None
     message_item_type: Optional[str] = None

+ 21 - 18
mautrix_instagram/portal.py

@@ -58,7 +58,7 @@ from mauigpapi.types import (
     ReelShareType,
     RegularMediaItem,
     Thread,
-    ThreadImageCandidate,
+    ThreadImage,
     ThreadItem,
     ThreadItemType,
     ThreadUser,
@@ -1892,33 +1892,35 @@ class Portal(DBPortal, BasePortal):
 
         return ""
 
-    async def _get_thread_avatar(self, source: u.User, thread: Thread) -> Optional[ContentURI]:
-        if self.is_direct or not thread.thread_image:
-            return None
-        if self.thread_image_id == thread.thread_image.id:
-            return self.avatar_url
-        best: Optional[ThreadImageCandidate] = None
-        for candidate in thread.thread_image.image_versions2.candidates:
-            if best is None or candidate.width > best.width:
-                best = candidate
+    async def update_thread_image(
+        self, source: u.User, thread_image: ThreadImage, sender: p.Puppet | None = None
+    ) -> bool:
+        if (
+            self.is_direct
+            or not thread_image
+            or (self.thread_image_id == thread_image.id and self.avatar_set)
+        ):
+            return False
+
+        best = thread_image.best_image
         if not best:
-            return None
+            return False
         data, mimetype = await self._download_instagram_file(source, best.url)
         if not data:
-            return None
+            return False
+        self.thread_image_id = thread_image.id
+        self.avatar_set = False
         mxc = await self.main_intent.upload_media(
             data=data,
             mime_type=mimetype,
-            filename=str(thread.thread_image.id),
+            filename=str(thread_image.id),
             async_upload=self.config["homeserver.async_media"],
         )
-        self.thread_image_id = thread.thread_image.id
-        return mxc
+        return await self._update_photo(mxc, sender=sender)
 
     async def update_info(self, thread: Thread, source: u.User) -> None:
         changed = await self._update_name(self._get_thread_name(thread))
-        if thread_avatar := await self._get_thread_avatar(source, thread):
-            changed = await self._update_photo(thread_avatar)
+        changed = await self.update_thread_image(source, thread.thread_image) or changed
         changed = await self._update_participants(thread.users, source) or changed
         if changed:
             await self.update_bridge_info()
@@ -1947,13 +1949,14 @@ class Portal(DBPortal, BasePortal):
             return True
         return False
 
-    async def _update_photo(self, photo_mxc: ContentURI) -> bool:
+    async def _update_photo(self, photo_mxc: ContentURI, sender: p.Puppet | None = None) -> bool:
         if self.avatar_url == photo_mxc and (self.avatar_set or not self.set_dm_room_metadata):
             return False
         self.avatar_url = photo_mxc
         self.avatar_set = False
         if self.mxid and self.set_dm_room_metadata:
             try:
+                # TODO use sender intent
                 await self.main_intent.set_room_avatar(self.mxid, photo_mxc)
                 self.avatar_set = True
             except Exception:

+ 3 - 1
mautrix_instagram/user.py

@@ -1095,7 +1095,9 @@ class User(DBUser, BaseUser):
             )
             return
         sender = await pu.Puppet.get_by_pk(evt.message.user_id) if evt.message.user_id else None
-        if evt.message.op == Operation.ADD:
+        if evt.message.is_thread_image:
+            await portal.update_thread_image(self, evt.message.thread_image, sender=sender)
+        elif evt.message.op == Operation.ADD:
             if not sender:
                 # I don't think we care about adds with no sender
                 return