Procházet zdrojové kódy

Set avatars for private chat portals

Tulir Asokan před 3 roky
rodič
revize
30f529d80f

+ 28 - 15
mautrix_instagram/db/portal.py

@@ -18,7 +18,7 @@ from typing import Optional, ClassVar, List, TYPE_CHECKING
 from attr import dataclass
 import asyncpg
 
-from mautrix.types import RoomID
+from mautrix.types import RoomID, ContentURI
 from mautrix.util.async_db import Database
 
 fake_db = Database("") if TYPE_CHECKING else None
@@ -33,19 +33,26 @@ class Portal:
     other_user_pk: Optional[int]
     mxid: Optional[RoomID]
     name: Optional[str]
+    avatar_url: Optional[ContentURI]
     encrypted: bool
+    name_set: bool
+    avatar_set: bool
 
     async def insert(self) -> None:
-        q = ("INSERT INTO portal (thread_id, receiver, other_user_pk, mxid, name, encrypted) "
-             "VALUES ($1, $2, $3, $4, $5, $6)")
+        q = ("INSERT INTO portal (thread_id, receiver, other_user_pk, mxid, name, avatar_url, "
+             "                    encrypted, name_set, avatar_set) "
+             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)")
         await self.db.execute(q, self.thread_id, self.receiver, self.other_user_pk,
-                              self.mxid, self.name, self.encrypted)
+                              self.mxid, self.name, self.avatar_url, self.encrypted,
+                              self.name_set, self.avatar_set)
 
     async def update(self) -> None:
-        q = ("UPDATE portal SET other_user_pk=$3, mxid=$4, name=$5, encrypted=$6 "
+        q = ("UPDATE portal SET other_user_pk=$3, mxid=$4, name=$5, avatar_url=$6, encrypted=$7,"
+             "                  name_set=$8, avatar_set=$9 "
              "WHERE thread_id=$1 AND receiver=$2")
         await self.db.execute(q, self.thread_id, self.receiver, self.other_user_pk,
-                              self.mxid, self.name, self.encrypted)
+                              self.mxid, self.name, self.avatar_url, self.encrypted,
+                              self.name_set, self.avatar_set)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
@@ -53,7 +60,8 @@ class Portal:
 
     @classmethod
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
-        q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, encrypted "
+        q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
+             "       name_set, avatar_set "
              "FROM portal WHERE mxid=$1")
         row = await cls.db.fetchrow(q, mxid)
         if not row:
@@ -63,10 +71,12 @@ class Portal:
     @classmethod
     async def get_by_thread_id(cls, thread_id: str, receiver: int,
                                rec_must_match: bool = True) -> Optional['Portal']:
-        q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, encrypted "
+        q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
+             "       name_set, avatar_set "
              "FROM portal WHERE thread_id=$1 AND receiver=$2")
         if not rec_must_match:
-            q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, encrypted "
+            q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
+                 "       name_set, avatar_set "
                  "FROM portal WHERE thread_id=$1 AND (receiver=$2 OR receiver=0)")
         row = await cls.db.fetchrow(q, thread_id, receiver)
         if not row:
@@ -75,21 +85,24 @@ class Portal:
 
     @classmethod
     async def find_private_chats_of(cls, receiver: int) -> List['Portal']:
-        q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, encrypted FROM portal "
-             "WHERE receiver=$1 AND other_user_pk IS NOT NULL")
+        q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
+             "       name_set, avatar_set "
+             "FROM portal WHERE receiver=$1 AND other_user_pk IS NOT NULL")
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
     async def find_private_chats_with(cls, other_user: int) -> List['Portal']:
-        q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, encrypted FROM portal "
-             "WHERE other_user_pk=$1")
+        q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
+             "       name_set, avatar_set "
+             "FROM portal WHERE other_user_pk=$1")
         rows = await cls.db.fetch(q, other_user)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
     async def all_with_room(cls) -> List['Portal']:
-        q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, encrypted FROM portal "
-             'WHERE mxid IS NOT NULL')
+        q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
+             "       name_set, avatar_set "
+             "FROM portal WHERE mxid IS NOT NULL")
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]

+ 8 - 0
mautrix_instagram/db/upgrade.py

@@ -80,3 +80,11 @@ async def upgrade_v1(conn: Connection) -> None:
             ON DELETE CASCADE ON UPDATE CASCADE,
         UNIQUE (mxid, mx_room)
     )""")
+
+
+@upgrade_table.register(description="Add name_set and avatar_set to portal table")
+async def upgrade_v2(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_url TEXT")
+    await conn.execute("ALTER TABLE portal ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false")
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false")
+    await conn.execute("UPDATE portal SET name_set=true WHERE name<>''")

+ 45 - 17
mautrix_instagram/portal.py

@@ -82,9 +82,11 @@ class Portal(DBPortal, BasePortal):
     _typing: Set[UserID]
 
     def __init__(self, thread_id: str, receiver: int, other_user_pk: Optional[int],
-                 mxid: Optional[RoomID] = None, name: Optional[str] = None, encrypted: bool = False
-                 ) -> None:
-        super().__init__(thread_id, receiver, other_user_pk, mxid, name, encrypted)
+                 mxid: Optional[RoomID] = None, name: Optional[str] = None,
+                 avatar_url: Optional[ContentURI] = None, encrypted: bool = False,
+                 name_set: bool = False, avatar_set: bool = False) -> None:
+        super().__init__(thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted,
+                         name_set, avatar_set)
         self._create_room_lock = asyncio.Lock()
         self.log = self.log.getChild(thread_id)
         self._msgid_dedup = deque(maxlen=100)
@@ -686,37 +688,62 @@ class Portal(DBPortal, BasePortal):
 
     async def update_info(self, thread: Thread, source: 'u.User') -> None:
         changed = await self._update_name(self._get_thread_name(thread))
+        changed = await self._update_participants(thread.users, source) or changed
         if changed:
             await self.update_bridge_info()
             await self.update()
-        await self._update_participants(thread.users, source)
         # TODO update power levels with thread.admin_user_ids
 
     async def _update_name(self, name: str) -> bool:
-        if self.name != name and name:
+        if name and (self.name != name or not self.name_set):
             self.name = name
             if self.mxid:
-                await self.main_intent.set_room_name(self.mxid, name)
+                try:
+                    await self.main_intent.set_room_name(self.mxid, name)
+                    self.name_set = True
+                except Exception:
+                    self.log.exception("Failed to update name")
+                    self.name_set = False
             return True
         return False
 
-    async def _update_participants(self, users: List[ThreadUser], source: 'u.User') -> None:
-        if not self.mxid:
-            return
+    async def _update_photo_from_puppet(self, puppet: 'p.Puppet') -> bool:
+        if not self.private_chat_portal_meta and not self.encrypted:
+            return False
+        if self.avatar_set and self.avatar_url == puppet.photo_mxc:
+            return False
+        self.avatar_url = puppet.photo_mxc
+        if self.mxid:
+            try:
+                await self.main_intent.set_room_avatar(self.mxid, puppet.photo_mxc)
+                self.avatar_set = True
+            except Exception:
+                self.log.exception("Failed to set room avatar")
+                self.avatar_set = False
+        return True
+
+    async def _update_participants(self, users: List[ThreadUser], source: 'u.User') -> bool:
+        meta_changed = False
 
         # Make sure puppets who should be here are here
         for user in users:
             puppet = await p.Puppet.get_by_pk(user.pk)
             await puppet.update_info(user, source)
-            await puppet.intent_for(self).ensure_joined(self.mxid)
+            if self.mxid:
+                await puppet.intent_for(self).ensure_joined(self.mxid)
+            if puppet.pk == self.other_user_pk:
+                meta_changed = await self._update_photo_from_puppet(puppet)
+
+        if self.mxid:
+            # Kick puppets who shouldn't be here
+            current_members = {int(user.pk) for user in users}
+            for user_id in await self.main_intent.get_room_members(self.mxid):
+                pk = p.Puppet.get_id_from_mxid(user_id)
+                if pk and pk not in current_members and pk != self.other_user_pk:
+                    await self.main_intent.kick_user(self.mxid, p.Puppet.get_mxid_from_id(pk),
+                                                     reason="User had left this Instagram DM")
 
-        # Kick puppets who shouldn't be here
-        current_members = {int(user.pk) for user in users}
-        for user_id in await self.main_intent.get_room_members(self.mxid):
-            pk = p.Puppet.get_id_from_mxid(user_id)
-            if pk and pk not in current_members and pk != self.other_user_pk:
-                await self.main_intent.kick_user(self.mxid, p.Puppet.get_mxid_from_id(pk),
-                                                 reason="User had left this Instagram DM")
+        return meta_changed
 
     async def _update_read_receipts(self, receipts: Dict[Union[int, str], ThreadUserLastSeenAt]
                                     ) -> None:
@@ -804,6 +831,7 @@ class Portal(DBPortal, BasePortal):
             "channel": {
                 "id": self.thread_id,
                 "displayname": self.name,
+                "avatar_url": self.avatar_url,
             }
         }