瀏覽代碼

Add support for Signal->Matrix attachments

Tulir Asokan 4 年之前
父節點
當前提交
90d9266ba9
共有 5 個文件被更改,包括 140 次插入42 次删除
  1. 9 8
      ROADMAP.md
  2. 32 6
      mausignald/types.py
  3. 83 14
      mautrix_signal/portal.py
  4. 7 4
      mautrix_signal/puppet.py
  5. 9 10
      mautrix_signal/signal.py

+ 9 - 8
ROADMAP.md

@@ -9,22 +9,23 @@
       * [ ] Files
       * [ ] Files
       * [ ] Gifs
       * [ ] Gifs
       * [ ] Locations
       * [ ] Locations
-      * [ ] Stickers
+      * [ ] Stickers
   * [x] Message reactions
   * [x] Message reactions
   * [ ] Typing notifications
   * [ ] Typing notifications
   * [ ] Read receipts
   * [ ] Read receipts
 * Signal → Matrix
 * Signal → Matrix
   * [ ] Message content
   * [ ] Message content
     * [x] Text
     * [x] Text
-    * [ ] Media
-      * [ ] Images
-      * [ ] Files
-      * [ ] Gifs
+    * [x] Media
+      * [x] Images
+	  * [x] Voice notes
+      * [x] Files
+      * [x] Gifs
       * [ ] Contacts
       * [ ] Contacts
-      * [ ] Locations
-      * [ ] Stickers
+      * [x] Locations
+      * [x] Stickers
   * [x] Message reactions
   * [x] Message reactions
-  * [ ] User and group avatars
+  * [ ] User and group avatars
   * [ ] Typing notifications
   * [ ] Typing notifications
   * [x] Read receipts
   * [x] Read receipts
   * [ ] Disappearing messages
   * [ ] Disappearing messages

+ 32 - 6
mausignald/types.py

@@ -66,11 +66,22 @@ class FullGroup(Group, SerializableAttrs['FullGroup']):
 
 
 @dataclass
 @dataclass
 class Attachment(SerializableAttrs['Attachment']):
 class Attachment(SerializableAttrs['Attachment']):
-    filename: str
+    width: int = 0
+    height: int = 0
+    voice_note: bool = attr.ib(default=False, metadata={"json": "voiceNote"})
+    content_type: Optional[str] = attr.ib(default=None, metadata={"json": "contentType"})
+
+    # Only for incoming
+    id: Optional[str] = None
+    stored_filename: Optional[str] = attr.ib(default=None, metadata={"json": "storedFilename"})
+
+    blurhash: Optional[str] = None
+    digest: Optional[str] = None
+
+    # Only for outgoing
+    filename: Optional[str] = None
+
     caption: Optional[str] = None
     caption: Optional[str] = None
-    width: Optional[int] = None
-    height: Optional[int] = None
-    voice_note: Optional[bool] = attr.ib(default=None, metadata={"json": "voiceNote"})
     preview: Optional[str] = None
     preview: Optional[str] = None
 
 
 
 
@@ -90,6 +101,14 @@ class Reaction(SerializableAttrs['Reaction']):
     target_sent_timestamp: int = attr.ib(metadata={"json": "targetSentTimestamp"})
     target_sent_timestamp: int = attr.ib(metadata={"json": "targetSentTimestamp"})
 
 
 
 
+@dataclass
+class Sticker(SerializableAttrs['Sticker']):
+    attachment: Attachment
+    pack_id: str = attr.ib(metadata={"json": "packID"})
+    pack_key: str = attr.ib(metadata={"json": "packKey"})
+    sticker_id: int = attr.ib(metadata={"json": "stickerID"})
+
+
 @dataclass
 @dataclass
 class MessageData(SerializableAttrs['MessageData']):
 class MessageData(SerializableAttrs['MessageData']):
     timestamp: int
     timestamp: int
@@ -97,7 +116,9 @@ class MessageData(SerializableAttrs['MessageData']):
     body: Optional[str] = None
     body: Optional[str] = None
     quote: Optional[Quote] = None
     quote: Optional[Quote] = None
     reaction: Optional[Reaction] = None
     reaction: Optional[Reaction] = None
-    # TODO attachments, mentions
+    attachments: List[Attachment] = attr.ib(factory=lambda: [])
+    sticker: Optional[Sticker] = None
+    # TODO mentions (although signald doesn't support group v2 yet)
 
 
     group: Optional[Group] = None
     group: Optional[Group] = None
 
 
@@ -106,6 +127,10 @@ class MessageData(SerializableAttrs['MessageData']):
     profile_key_update: bool = attr.ib(default=False, metadata={"json": "profileKeyUpdate"})
     profile_key_update: bool = attr.ib(default=False, metadata={"json": "profileKeyUpdate"})
     view_once: bool = attr.ib(default=False, metadata={"json": "viewOnce"})
     view_once: bool = attr.ib(default=False, metadata={"json": "viewOnce"})
 
 
+    @property
+    def all_attachments(self) -> List[Attachment]:
+        return self.attachments + ([self.sticker] if self.sticker else [])
+
 
 
 @dataclass
 @dataclass
 class SentSyncMessage(SerializableAttrs['SentSyncMessage']):
 class SentSyncMessage(SerializableAttrs['SentSyncMessage']):
@@ -151,7 +176,8 @@ class Receipt(SerializableAttrs['Receipt']):
 class SyncMessage(SerializableAttrs['SyncMessage']):
 class SyncMessage(SerializableAttrs['SyncMessage']):
     sent: Optional[SentSyncMessage] = None
     sent: Optional[SentSyncMessage] = None
     typing: Optional[TypingNotification] = None
     typing: Optional[TypingNotification] = None
-    read_messages: Optional[List[OwnReadReceipt]] = attr.ib(default=None, metadata={"json": "readMessages"})
+    read_messages: Optional[List[OwnReadReceipt]] = attr.ib(default=None,
+                                                            metadata={"json": "readMessages"})
     contacts: Optional[Dict[str, Any]] = None
     contacts: Optional[Dict[str, Any]] = None
     contacts_complete: bool = attr.ib(default=False, metadata={"json": "contactsComplete"})
     contacts_complete: bool = attr.ib(default=False, metadata={"json": "contactsComplete"})
 
 

+ 83 - 14
mautrix_signal/portal.py

@@ -17,15 +17,17 @@ from typing import (Dict, Tuple, Optional, List, Deque, Set, Any, Union, AsyncGe
                     Awaitable, TYPE_CHECKING, cast)
                     Awaitable, TYPE_CHECKING, cast)
 from collections import deque
 from collections import deque
 from uuid import UUID
 from uuid import UUID
+import mimetypes
 import asyncio
 import asyncio
 import time
 import time
 
 
 from mausignald.types import (Address, MessageData, Reaction, Quote, FullGroup, Group, Contact,
 from mausignald.types import (Address, MessageData, Reaction, Quote, FullGroup, Group, Contact,
-                              Profile)
+                              Profile, Attachment)
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal
 from mautrix.bridge import BasePortal
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
-                           TextMessageEventContent, MessageEvent, EncryptedEvent)
+                           TextMessageEventContent, MessageEvent, EncryptedEvent,
+                           MediaMessageEventContent, ImageInfo, VideoInfo, FileInfo, AudioInfo)
 from mautrix.errors import MatrixError, MForbidden
 from mautrix.errors import MatrixError, MForbidden
 
 
 from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
 from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
@@ -40,9 +42,14 @@ try:
 except ImportError:
 except ImportError:
     encrypt_attachment = decrypt_attachment = None
     encrypt_attachment = decrypt_attachment = None
 
 
+try:
+    import magic
+except ImportError:
+    magic = None
+
 StateBridge = EventType.find("m.bridge", EventType.Class.STATE)
 StateBridge = EventType.find("m.bridge", EventType.Class.STATE)
 StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STATE)
 StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STATE)
-ChatInfo = Union[FullGroup, Group, Contact, Profile]
+ChatInfo = Union[FullGroup, Group, Contact, Profile, Address]
 
 
 
 
 class Portal(DBPortal, BasePortal):
 class Portal(DBPortal, BasePortal):
@@ -255,16 +262,28 @@ class Portal(DBPortal, BasePortal):
             self.log.debug(f"Ignoring message {message.timestamp} by {sender.uuid}"
             self.log.debug(f"Ignoring message {message.timestamp} by {sender.uuid}"
                            " as it was already handled (message.id found in database)")
                            " as it was already handled (message.id found in database)")
             return
             return
+        self.log.debug(f"Started handling message {message.timestamp} by {sender.uuid}")
+        self.log.trace(f"Message content: {message}")
         self._msgts_dedup.appendleft((sender.uuid, message.timestamp))
         self._msgts_dedup.appendleft((sender.uuid, message.timestamp))
         intent = sender.intent_for(self)
         intent = sender.intent_for(self)
         event_id = None
         event_id = None
         reply_to = await self._find_quote_event_id(message.quote)
         reply_to = await self._find_quote_event_id(message.quote)
-        # TODO attachments
+
+        for attachment in message.all_attachments:
+            content = await self._handle_signal_attachment(intent, attachment)
+            if content:
+                if reply_to and not message.body:
+                    # If there's no text, set the first image as the reply
+                    content.set_reply(reply_to)
+                    reply_to = None
+                event_id = await self._send_message(intent, content, timestamp=message.timestamp)
+
         if message.body:
         if message.body:
             content = TextMessageEventContent(msgtype=MessageType.TEXT, body=message.body)
             content = TextMessageEventContent(msgtype=MessageType.TEXT, body=message.body)
             if reply_to:
             if reply_to:
                 content.set_reply(reply_to)
                 content.set_reply(reply_to)
             event_id = await self._send_message(intent, content, timestamp=message.timestamp)
             event_id = await self._send_message(intent, content, timestamp=message.timestamp)
+
         if event_id:
         if event_id:
             msg = DBMessage(mxid=event_id, mx_room=self.mxid,
             msg = DBMessage(mxid=event_id, mx_room=self.mxid,
                             sender=sender.uuid, timestamp=message.timestamp,
                             sender=sender.uuid, timestamp=message.timestamp,
@@ -272,6 +291,52 @@ class Portal(DBPortal, BasePortal):
             await msg.insert()
             await msg.insert()
             await self._send_delivery_receipt(event_id)
             await self._send_delivery_receipt(event_id)
             self.log.debug(f"Handled Signal message {message.timestamp} -> {event_id}")
             self.log.debug(f"Handled Signal message {message.timestamp} -> {event_id}")
+        else:
+            self.log.debug(f"Didn't get event ID for {message.timestamp}")
+
+    @staticmethod
+    def _make_media_content(attachment: Attachment) -> MediaMessageEventContent:
+        if attachment.content_type.startswith("image/"):
+            msgtype = MessageType.IMAGE
+            info = ImageInfo(mimetype=attachment.content_type,
+                             width=attachment.width, height=attachment.height)
+        elif attachment.content_type.startswith("video/"):
+            msgtype = MessageType.VIDEO
+            info = VideoInfo(mimetype=attachment.content_type,
+                             width=attachment.width, height=attachment.height)
+        elif attachment.voice_note or attachment.content_type.startswith("audio/"):
+            msgtype = MessageType.AUDIO
+            info = AudioInfo(mimetype=attachment.content_type)
+        else:
+            msgtype = MessageType.FILE
+            info = FileInfo(mimetype=attachment.content_type)
+        # TODO add something to signald so we can get the actual file name if one is set
+        ext = mimetypes.guess_extension(attachment.content_type) or ""
+        return MediaMessageEventContent(msgtype=msgtype, body=attachment.id + ext, info=info)
+
+    async def _handle_signal_attachment(self, intent: IntentAPI, attachment: Attachment
+                                        ) -> Optional[MediaMessageEventContent]:
+        self.log.trace(f"Reuploading attachment {attachment}")
+        if not attachment.content_type:
+            attachment.content_type = (magic.from_file(attachment.stored_filename, mime=True)
+                                       if magic is not None else "application/octet-stream")
+
+        content = self._make_media_content(attachment)
+
+        with open(attachment.stored_filename, "rb") as file:
+            data = file.read()
+
+        upload_mime_type = attachment.content_type
+        if self.encrypted and encrypt_attachment:
+            data, content.file = encrypt_attachment(data)
+            upload_mime_type = "application/octet-stream"
+
+        content.url = await intent.upload_media(data, mime_type=upload_mime_type,
+                                                filename=content.body)
+        if content.file:
+            content.file.url = content.url
+            content.url = None
+        return content
 
 
     async def handle_signal_reaction(self, sender: 'p.Puppet', reaction: Reaction) -> None:
     async def handle_signal_reaction(self, sender: 'p.Puppet', reaction: Reaction) -> None:
         author_uuid = await self._find_address_uuid(reaction.target_author)
         author_uuid = await self._find_address_uuid(reaction.target_author)
@@ -318,12 +383,13 @@ class Portal(DBPortal, BasePortal):
 
 
     async def update_info(self, info: ChatInfo) -> None:
     async def update_info(self, info: ChatInfo) -> None:
         if self.is_direct:
         if self.is_direct:
-            # TODO do we need to do something here?
-            #      I think all profile updates should just call puppet.update_info() directly
-            # if not isinstance(info, (Contact, Profile)):
-            #     raise ValueError(f"Unexpected type for direct chat update_info: {type(info)}")
-            # puppet = await p.Puppet.get_by_address(Address(uuid=self.chat_id))
-            # await puppet.update_info(info)
+            if not isinstance(info, (Contact, Profile, Address)):
+                raise ValueError(f"Unexpected type for direct chat update_info: {type(info)}")
+            if not self.name:
+                puppet = await p.Puppet.get_by_address(Address(uuid=self.chat_id))
+                if not puppet.name:
+                    await puppet.update_info(info)
+                self.name = puppet.name
             return
             return
 
 
         if not isinstance(info, Group):
         if not isinstance(info, Group):
@@ -381,7 +447,7 @@ class Portal(DBPortal, BasePortal):
                 "avatar_url": self.config["appservice.bot_avatar"],
                 "avatar_url": self.config["appservice.bot_avatar"],
             },
             },
             "channel": {
             "channel": {
-                "id": self.chat_id,
+                "id": str(self.chat_id),
                 "displayname": self.name,
                 "displayname": self.name,
             }
             }
         }
         }
@@ -406,7 +472,7 @@ class Portal(DBPortal, BasePortal):
     async def update_matrix_room(self, source: 'u.User', info: ChatInfo) -> None:
     async def update_matrix_room(self, source: 'u.User', info: ChatInfo) -> None:
         if not self.is_direct and not isinstance(info, Group):
         if not self.is_direct and not isinstance(info, Group):
             raise ValueError(f"Unexpected type for updating group portal: {type(info)}")
             raise ValueError(f"Unexpected type for updating group portal: {type(info)}")
-        elif self.is_direct and not isinstance(info, (Contact, Profile)):
+        elif self.is_direct and not isinstance(info, (Contact, Profile, Address)):
             raise ValueError(f"Unexpected type for updating direct chat portal: {type(info)}")
             raise ValueError(f"Unexpected type for updating direct chat portal: {type(info)}")
         try:
         try:
             await self._update_matrix_room(source, info)
             await self._update_matrix_room(source, info)
@@ -416,8 +482,11 @@ class Portal(DBPortal, BasePortal):
     async def create_matrix_room(self, source: 'u.User', info: ChatInfo) -> Optional[RoomID]:
     async def create_matrix_room(self, source: 'u.User', info: ChatInfo) -> Optional[RoomID]:
         if not self.is_direct and not isinstance(info, Group):
         if not self.is_direct and not isinstance(info, Group):
             raise ValueError(f"Unexpected type for creating group portal: {type(info)}")
             raise ValueError(f"Unexpected type for creating group portal: {type(info)}")
-        elif self.is_direct and not isinstance(info, (Contact, Profile)):
+        elif self.is_direct and not isinstance(info, (Contact, Profile, Address)):
             raise ValueError(f"Unexpected type for creating direct chat portal: {type(info)}")
             raise ValueError(f"Unexpected type for creating direct chat portal: {type(info)}")
+        if isinstance(info, Group):
+            groups = await self.signal.list_groups(source.username)
+            info = next((g for g in groups if g.group_id == info.group_id), info)
         if self.mxid:
         if self.mxid:
             await self.update_matrix_room(source, info)
             await self.update_matrix_room(source, info)
             return self.mxid
             return self.mxid
@@ -494,7 +563,7 @@ class Portal(DBPortal, BasePortal):
         await self.update()
         await self.update()
         self.log.debug(f"Matrix room created: {self.mxid}")
         self.log.debug(f"Matrix room created: {self.mxid}")
         self.by_mxid[self.mxid] = self
         self.by_mxid[self.mxid] = self
-        if not self.is_direct:
+        if not self.is_direct and isinstance(info, FullGroup):
             await self._update_participants(info.members)
             await self._update_participants(info.members)
         else:
         else:
             puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
             puppet = await p.Puppet.get_by_custom_mxid(source.mxid)

+ 7 - 4
mautrix_signal/puppet.py

@@ -129,15 +129,18 @@ class Puppet(DBPuppet, BasePuppet):
             await prev_intent.leave_room(room_id)
             await prev_intent.leave_room(room_id)
 
 
     async def update_info(self, info: Union[Profile, Contact]) -> None:
     async def update_info(self, info: Union[Profile, Contact]) -> None:
-        if isinstance(info, Contact):
-            if info.address.uuid and not self.uuid:
-                await self.handle_uuid_receive(info.address.uuid)
+        if isinstance(info, (Contact, Address)):
+            address = info.address if isinstance(info, Contact) else info
+            if address.uuid and not self.uuid:
+                await self.handle_uuid_receive(address.uuid)
             if not self.config["bridge.allow_contact_list_name_updates"] and self.name is not None:
             if not self.config["bridge.allow_contact_list_name_updates"] and self.name is not None:
                 return
                 return
 
 
+        name = info.name if isinstance(info, (Contact, Profile)) else None
+
         async with self._update_info_lock:
         async with self._update_info_lock:
             update = False
             update = False
-            update = await self._update_name(info.name) or update
+            update = await self._update_name(name) or update
             if update:
             if update:
                 await self.update()
                 await self.update()
 
 

+ 9 - 10
mautrix_signal/signal.py

@@ -61,27 +61,26 @@ class SignalHandler(SignaldClient):
                 pass
                 pass
             if evt.sync_message.sent:
             if evt.sync_message.sent:
                 await self.handle_message(user, sender, evt.sync_message.sent.message,
                 await self.handle_message(user, sender, evt.sync_message.sent.message,
-                                          recipient_override=evt.sync_message.sent.destination)
+                                          addr_override=evt.sync_message.sent.destination)
             if evt.sync_message.typing:
             if evt.sync_message.typing:
                 # Typing notification from own device
                 # Typing notification from own device
                 pass
                 pass
 
 
     @staticmethod
     @staticmethod
     async def handle_message(user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
     async def handle_message(user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
-                             recipient_override: Optional[Address] = None) -> None:
+                             addr_override: Optional[Address] = None) -> None:
         if msg.group:
         if msg.group:
-            portal = await po.Portal.get_by_chat_id(msg.group.group_id, receiver=user.username)
+            portal = await po.Portal.get_by_chat_id(msg.group.group_id, receiver=user.username,
+                                                    create=True)
         else:
         else:
-            portal = await po.Portal.get_by_chat_id(recipient_override.uuid
-                                                    if recipient_override else sender.uuid,
-                                                    receiver=user.username)
+            portal = await po.Portal.get_by_chat_id(addr_override.uuid
+                                                    if addr_override else sender.uuid,
+                                                    receiver=user.username, create=True)
         if not portal.mxid:
         if not portal.mxid:
-            # TODO create room?
-            # TODO definitely at least log
-            return
+            await portal.create_matrix_room(user, msg.group or addr_override or sender.address)
         if msg.reaction:
         if msg.reaction:
             await portal.handle_signal_reaction(sender, msg.reaction)
             await portal.handle_signal_reaction(sender, msg.reaction)
-        if msg.body:
+        if msg.body or msg.attachments or msg.sticker:
             await portal.handle_signal_message(sender, msg)
             await portal.handle_signal_message(sender, msg)
 
 
     @staticmethod
     @staticmethod