Browse Source

Make signalstickers-client optional and improve sticker handling

Tulir Asokan 4 years ago
parent
commit
bcd1f976a1
3 changed files with 100 additions and 26 deletions
  1. 97 25
      mautrix_signal/portal.py
  2. 3 0
      optional-requirements.txt
  3. 0 1
      requirements.txt

+ 97 - 25
mautrix_signal/portal.py

@@ -24,10 +24,8 @@ import os.path
 import time
 import time
 import os
 import os
 
 
-from signalstickers_client import StickersClient
-
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
-                              Attachment, GroupID, GroupV2ID, GroupV2, Mention)
+                              Attachment, GroupID, GroupV2ID, GroupV2, Mention, Sticker)
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal
 from mautrix.bridge import BasePortal
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
@@ -48,6 +46,12 @@ try:
 except ImportError:
 except ImportError:
     encrypt_attachment = decrypt_attachment = None
     encrypt_attachment = decrypt_attachment = None
 
 
+try:
+    from signalstickers_client import StickersClient
+    from signalstickers_client.models import StickerPack
+except ImportError:
+    StickersClient = StickerPack = None
+
 try:
 try:
     import magic
     import magic
 except ImportError:
 except ImportError:
@@ -61,6 +65,7 @@ ChatInfo = Union[Group, GroupV2, GroupV2ID, Contact, Profile, Address]
 class Portal(DBPortal, BasePortal):
 class Portal(DBPortal, BasePortal):
     by_mxid: Dict[RoomID, 'Portal'] = {}
     by_mxid: Dict[RoomID, 'Portal'] = {}
     by_chat_id: Dict[Tuple[str, str], 'Portal'] = {}
     by_chat_id: Dict[Tuple[str, str], 'Portal'] = {}
+    _sticker_meta_cache: Dict[str, StickerPack] = {}
     config: Config
     config: Config
     matrix: 'm.MatrixHandler'
     matrix: 'm.MatrixHandler'
     signal: 's.SignalHandler'
     signal: 's.SignalHandler'
@@ -323,26 +328,21 @@ class Portal(DBPortal, BasePortal):
         reply_to = await self._find_quote_event_id(message.quote)
         reply_to = await self._find_quote_event_id(message.quote)
 
 
         if message.sticker:
         if message.sticker:
-            if not message.sticker.attachment.incoming_filename:
-                self.log.debug("Downloading sticker from signal, as no incoming filename was defined: %s",
-                               message.sticker.attachment)
-                try:
-                    async with StickersClient() as client:
-                        sticker_data = await client.download_sticker(message.sticker.sticker_id,
-                                                                     message.sticker.pack_id,
-                                                                     message.sticker.pack_key)
-
-                    path = os.path.join(self.config["signal.outgoing_attachment_dir"],
-                                        f"{message.sticker.pack_id}_{message.sticker.sticker_id}")
-                    with open(path, "wb") as file:
-                        file.write(sticker_data)
-                    message.sticker.attachment.incoming_filename = path
-                except Exception as ex:
-                    self.log.warning("Failed to download sticker: %s", ex)
-
             if message.sticker.attachment.incoming_filename:
             if message.sticker.attachment.incoming_filename:
                 content = await self._handle_signal_attachment(intent, message.sticker.attachment)
                 content = await self._handle_signal_attachment(intent, message.sticker.attachment)
-                if reply_to:
+            elif StickersClient:
+                content = await self._handle_signal_sticker(intent, message.sticker)
+            else:
+                self.log.debug(f"Not handling sticker in {message.timestamp}: no incoming_filename"
+                               " and signalstickers-client not installed.")
+                return
+
+            if content:
+                if message.sticker.attachment.blurhash:
+                    content.info["blurhash"] = message.sticker.attachment.blurhash
+                    content.info["xyz.amorgan.blurhash"] = message.sticker.attachment.blurhash
+                await self._add_sticker_meta(message.sticker, content)
+                if reply_to and not message.body:
                     content.set_reply(reply_to)
                     content.set_reply(reply_to)
                     reply_to = None
                     reply_to = None
                 event_id = await self._send_message(intent, content, timestamp=message.timestamp,
                 event_id = await self._send_message(intent, content, timestamp=message.timestamp,
@@ -397,6 +397,9 @@ class Portal(DBPortal, BasePortal):
         if not attachment.custom_filename:
         if not attachment.custom_filename:
             ext = mimetypes.guess_extension(attachment.content_type) or ""
             ext = mimetypes.guess_extension(attachment.content_type) or ""
             attachment.custom_filename = attachment.id + ext
             attachment.custom_filename = attachment.id + ext
+        if attachment.blurhash:
+            info["blurhash"] = attachment.blurhash
+            info["xyz.amorgan.blurhash"] = attachment.blurhash
         return MediaMessageEventContent(msgtype=msgtype, info=info,
         return MediaMessageEventContent(msgtype=msgtype, info=info,
                                         body=attachment.custom_filename)
                                         body=attachment.custom_filename)
 
 
@@ -414,17 +417,86 @@ class Portal(DBPortal, BasePortal):
         if self.config["signal.remove_file_after_handling"]:
         if self.config["signal.remove_file_after_handling"]:
             os.remove(attachment.incoming_filename)
             os.remove(attachment.incoming_filename)
 
 
-        upload_mime_type = attachment.content_type
+        await self._upload_attachment(intent, content, data, attachment.id)
+        return content
+
+    async def _add_sticker_meta(self, sticker: Sticker, content: MediaMessageEventContent) -> None:
+        try:
+            pack = self._sticker_meta_cache[sticker.pack_id]
+        except KeyError:
+            self.log.debug(f"Fetching sticker pack metadata for {sticker.pack_id}")
+            try:
+                async with StickersClient() as client:
+                    pack = await client.get_pack_metadata(sticker.pack_id, sticker.pack_key)
+                self._sticker_meta_cache[sticker.pack_id] = pack
+            except Exception:
+                self.log.warning(f"Failed to fetch pack metadata for {sticker.pack_id}",
+                                 exc_info=True)
+                pack = None
+        if not pack:
+            content.info["fi.mau.signal.sticker"] = {
+                "id": sticker.sticker_id,
+                "pack": {
+                    "id": sticker.pack_id,
+                    "key": sticker.pack_key,
+                },
+            }
+            return
+        sticker_meta = pack.stickers[sticker.sticker_id]
+        content.body = sticker_meta.emoji
+        content.info["fi.mau.signal.sticker"] = {
+            "id": sticker.sticker_id,
+            "emoji": sticker_meta.emoji,
+            "pack": {
+                "id": pack.id,
+                "key": pack.key,
+                "title": pack.title,
+                "author": pack.author,
+            },
+        }
+
+    async def _handle_signal_sticker(self, intent: IntentAPI, sticker: Sticker
+                                     ) -> Optional[MediaMessageEventContent]:
+        try:
+            self.log.debug(f"Fetching sticker {sticker.pack_id}#{sticker.sticker_id}")
+            async with StickersClient() as client:
+                data = await client.download_sticker(sticker.sticker_id,
+                                                     sticker.pack_id, sticker.pack_key)
+        except Exception:
+            self.log.warning(f"Failed to download sticker {sticker.sticker_id}", exc_info=True)
+            return None
+        info = ImageInfo(mimetype=sticker.attachment.content_type, size=len(data),
+                         width=sticker.attachment.width, height=sticker.attachment.height)
+        if info.width > 256 or info.height > 256:
+            if info.width == info.height:
+                info.width = info.height = 256
+            elif info.width > info.height:
+                info.height = int(info.height / (info.width / 256))
+                info.width = 256
+            else:
+                info.width = int(info.width / (info.height / 256))
+                info.height = 256
+        if magic:
+            info.mimetype = magic.from_buffer(data, mime=True)
+        ext = mimetypes.guess_extension(info.mimetype)
+        if not ext and info.mimetype == "image/webp":
+            ext = ".webp"
+        content = MediaMessageEventContent(msgtype=MessageType.IMAGE, info=info,
+                                           body=f"sticker{ext}")
+        await self._upload_attachment(intent, content, data, sticker.attachment.id)
+        return content
+
+    async def _upload_attachment(self, intent: IntentAPI, content: MediaMessageEventContent,
+                                 data: bytes, id: str) -> None:
+        upload_mime_type = content.info.mimetype
         if self.encrypted and encrypt_attachment:
         if self.encrypted and encrypt_attachment:
             data, content.file = encrypt_attachment(data)
             data, content.file = encrypt_attachment(data)
             upload_mime_type = "application/octet-stream"
             upload_mime_type = "application/octet-stream"
 
 
-        content.url = await intent.upload_media(data, mime_type=upload_mime_type,
-                                                filename=attachment.id)
+        content.url = await intent.upload_media(data, mime_type=upload_mime_type, filename=id)
         if content.file:
         if content.file:
             content.file.url = content.url
             content.file.url = content.url
             content.url = None
             content.url = None
-        return content
 
 
     async def handle_signal_reaction(self, sender: 'p.Puppet', reaction: Reaction) -> None:
     async def handle_signal_reaction(self, sender: 'p.Puppet', reaction: Reaction) -> None:
         author_address = await self._resolve_address(reaction.target_author)
         author_address = await self._resolve_address(reaction.target_author)

+ 3 - 0
optional-requirements.txt

@@ -15,3 +15,6 @@ phonenumbers>=8,<9
 #/qrlink
 #/qrlink
 qrcode>=6,<7
 qrcode>=6,<7
 Pillow>=4,<9
 Pillow>=4,<9
+
+#/stickers
+signalstickers-client>=3.0

+ 0 - 1
requirements.txt

@@ -6,4 +6,3 @@ yarl>=1,<2
 attrs>=19.1
 attrs>=19.1
 mautrix>=0.8.11,<0.9
 mautrix>=0.8.11,<0.9
 asyncpg>=0.20,<0.22
 asyncpg>=0.20,<0.22
-signalstickers-client>=3.0