Răsfoiți Sursa

Reduce code duplication in avatar uploads

Tulir Asokan 4 ani în urmă
părinte
comite
520405aebd
2 a modificat fișierele cu 24 adăugiri și 18 ștergeri
  1. 13 14
      mautrix_signal/portal.py
  2. 11 4
      mautrix_signal/puppet.py

+ 13 - 14
mautrix_signal/portal.py

@@ -609,27 +609,26 @@ class Portal(DBPortal, BasePortal):
             return True
         return False
 
+    @property
+    def avatar_set(self) -> bool:
+        return bool(self.avatar_hash)
+
     async def _update_avatar(self, info: ChatInfo) -> bool:
         path = None
         if isinstance(info, GroupV2):
             path = info.avatar
         elif isinstance(info, Group):
-            path = os.path.join(self.config["signal.avatar_dir"], f"group-{self.chat_id}")
-        if not path:
-            return False
-        try:
-            with open(path, "rb") as file:
-                data = file.read()
-        except FileNotFoundError:
+            path = f"group-{self.chat_id}"
+        res = await p.Puppet.upload_avatar(self, path)
+        if res is False:
             return False
-        new_hash = hashlib.sha256(data).hexdigest()
-        if self.avatar_hash and new_hash == self.avatar_hash:
-            return False
-        mxc = await self.main_intent.upload_media(data)
+        self.avatar_hash, self.avatar_url = res
         if self.mxid:
-            await self.main_intent.set_room_avatar(self.mxid, mxc)
-        self.avatar_url = mxc
-        self.avatar_hash = new_hash
+            try:
+                await self.main_intent.set_room_avatar(self.mxid, self.avatar_url)
+            except Exception:
+                self.log.exception("Error setting avatar")
+                self.avatar_hash = None
         return True
 
     async def _update_participants(self, source: 'u.User', participants: List[Address]) -> None:

+ 11 - 4
mautrix_signal/puppet.py

@@ -13,7 +13,7 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import (Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, Union,
+from typing import (Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, Union, Tuple,
                     TYPE_CHECKING, cast)
 from uuid import UUID
 import hashlib
@@ -219,7 +219,9 @@ class Puppet(DBPuppet, BasePuppet):
             return True
         return False
 
-    async def _update_avatar(self, path: str) -> bool:
+    @staticmethod
+    async def upload_avatar(self: Union['Puppet', 'p.Portal'], path: str
+                            ) -> Union[bool, Tuple[str, ContentURI]]:
         if not path:
             return False
         if not path.startswith("/"):
@@ -233,8 +235,13 @@ class Puppet(DBPuppet, BasePuppet):
         if self.avatar_set and new_hash == self.avatar_hash:
             return False
         mxc = await self.default_mxid_intent.upload_media(data)
-        self.avatar_hash = new_hash
-        self.avatar_url = mxc
+        return new_hash, mxc
+
+    async def _update_avatar(self, path: str) -> bool:
+        res = await Puppet.upload_avatar(self, path)
+        if res is False:
+            return False
+        self.avatar_hash, self.avatar_url = res
         try:
             await self.default_mxid_intent.set_avatar_url(self.avatar_url)
             self.avatar_set = True