瀏覽代碼

Add support for mentions in both directions

Tulir Asokan 4 年之前
父節點
當前提交
4c128cc85e
共有 5 個文件被更改,包括 166 次插入15 次删除
  1. 2 2
      ROADMAP.md
  2. 8 4
      mausignald/signald.py
  3. 8 1
      mausignald/types.py
  4. 135 0
      mautrix_signal/formatter.py
  5. 13 8
      mautrix_signal/portal.py

+ 2 - 2
ROADMAP.md

@@ -4,7 +4,7 @@
   * [ ] Message content
   * [ ] Message content
     * [x] Text
     * [x] Text
     * [ ] ‡Formatting
     * [ ] ‡Formatting
-    * [ ] Mentions
+    * [x] Mentions
     * [ ] Media
     * [ ] Media
       * [x] Images
       * [x] Images
       * [x] Audio files
       * [x] Audio files
@@ -23,7 +23,7 @@
 * Signal → Matrix
 * Signal → Matrix
   * [ ] Message content
   * [ ] Message content
     * [x] Text
     * [x] Text
-    * [ ] Mentions
+    * [x] Mentions
     * [ ] Media
     * [ ] Media
       * [x] Images
       * [x] Images
       * [x] Voice notes
       * [x] Voice notes

+ 8 - 4
mausignald/signald.py

@@ -12,7 +12,8 @@ from mautrix.util.logging import TraceLogger
 from .rpc import CONNECT_EVENT, SignaldRPCClient
 from .rpc import CONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
-                    Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2)
+                    Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
+                    Mention)
 
 
 T = TypeVar('T')
 T = TypeVar('T')
 EventHandler = Callable[[T], Awaitable[None]]
 EventHandler = Callable[[T], Awaitable[None]]
@@ -147,12 +148,15 @@ class SignaldClient(SignaldRPCClient):
 
 
     async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
     async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
                    quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
                    quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
-                   timestamp: Optional[int] = None) -> None:
+                   mentions: Optional[List[Mention]] = None, timestamp: Optional[int] = None
+                   ) -> None:
         serialized_quote = quote.serialize() if quote else None
         serialized_quote = quote.serialize() if quote else None
         serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
         serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
-        await self.request("send", "send_results", username=username, messageBody=body,
+        serialized_mentions = [mention.serialize() for mention in (mentions or [])]
+        await self.request("send", "send", username=username, messageBody=body,
                            attachments=serialized_attachments, quote=serialized_quote,
                            attachments=serialized_attachments, quote=serialized_quote,
-                           timestamp=timestamp, **self._recipient_to_args(recipient))
+                           mentions=serialized_mentions, timestamp=timestamp,
+                           **self._recipient_to_args(recipient), version="v1")
         # TODO return something?
         # TODO return something?
 
 
     async def send_receipt(self, username: str, sender: Address, timestamps: List[int],
     async def send_receipt(self, username: str, sender: Address, timestamps: List[int],

+ 8 - 1
mausignald/types.py

@@ -188,6 +188,13 @@ class RemoteDelete(SerializableAttrs['RemoteDelete']):
     target_sent_timestamp: int = attr.ib(metadata={"json": "targetSentTimestamp"})
     target_sent_timestamp: int = attr.ib(metadata={"json": "targetSentTimestamp"})
 
 
 
 
+@dataclass
+class Mention(SerializableAttrs['Mention']):
+    uuid: UUID
+    length: int
+    start: int = 0
+
+
 @dataclass
 @dataclass
 class MessageData(SerializableAttrs['MessageData']):
 class MessageData(SerializableAttrs['MessageData']):
     timestamp: int
     timestamp: int
@@ -197,7 +204,7 @@ class MessageData(SerializableAttrs['MessageData']):
     reaction: Optional[Reaction] = None
     reaction: Optional[Reaction] = None
     attachments: List[Attachment] = attr.ib(factory=lambda: [])
     attachments: List[Attachment] = attr.ib(factory=lambda: [])
     sticker: Optional[Sticker] = None
     sticker: Optional[Sticker] = None
-    # TODO mentions
+    mentions: List[Mention] = attr.ib(factory=lambda: [])
 
 
     group: Optional[Group] = None
     group: Optional[Group] = None
     group_v2: Optional[GroupV2ID] = attr.ib(default=None, metadata={"json": "groupV2"})
     group_v2: Optional[GroupV2ID] = attr.ib(default=None, metadata={"json": "groupV2"})

+ 135 - 0
mautrix_signal/formatter.py

@@ -0,0 +1,135 @@
+# mautrix-signal - A Matrix-Signal puppeting bridge
+# Copyright (C) 2020 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from typing import Tuple, List, cast
+from html import escape
+import struct
+
+from mausignald.types import MessageData, Address, Mention
+from mautrix.types import TextMessageEventContent, MessageType, Format
+from mautrix.util.formatter import (MatrixParser as BaseMatrixParser, EntityString, SimpleEntity,
+                                    EntityType, MarkdownString)
+
+from . import puppet as pu, user as u
+
+
+# Helper methods from rom https://github.com/LonamiWebs/Telethon/blob/master/telethon/helpers.py
+# I don't know if this is how Signal actually calculates lengths, but it seems
+# to work better than plain len()
+def add_surrogate(text: str) -> str:
+    return ''.join(
+        ''.join(chr(y) for y in struct.unpack('<HH', x.encode('utf-16le')))
+        if (0x10000 <= ord(x) <= 0x10FFFF) else x for x in text
+    )
+
+
+def del_surrogate(text: str) -> str:
+    return text.encode('utf-16', 'surrogatepass').decode('utf-16')
+
+
+async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
+    content = TextMessageEventContent(msgtype=MessageType.TEXT, body=message.body)
+    surrogated_text = add_surrogate(message.body)
+    if message.mentions:
+        text_chunks = []
+        html_chunks = []
+        last_offset = 0
+        for mention in message.mentions:
+            before = surrogated_text[last_offset:mention.start]
+            last_offset = mention.start + mention.length
+
+            text_chunks.append(before)
+            html_chunks.append(escape(before))
+            puppet = await pu.Puppet.get_by_address(Address(uuid=mention.uuid))
+            name = add_surrogate(puppet.name or puppet.mxid)
+            text_chunks.append(name)
+            html_chunks.append(f'<a href="https://matrix.to/#/{puppet.mxid}">{name}</a>')
+        end = surrogated_text[last_offset:]
+        text_chunks.append(end)
+        html_chunks.append(escape(end))
+        content.body = del_surrogate("".join(text_chunks))
+        content.format = Format.HTML
+        content.formatted_body = del_surrogate("".join(html_chunks))
+    return content
+
+
+# TODO this has a lot of duplication with mautrix-facebook, maybe move to mautrix-python
+class SignalFormatString(EntityString[SimpleEntity, EntityType], MarkdownString):
+    def format(self, entity_type: EntityType, **kwargs) -> 'SignalFormatString':
+        prefix = suffix = ""
+        if entity_type == EntityType.USER_MENTION:
+            self.entities.append(SimpleEntity(type=entity_type, offset=0, length=len(self.text),
+                                              extra_info={"user_id": kwargs["user_id"]}))
+            return self
+        elif entity_type == EntityType.BOLD:
+            prefix = suffix = "**"
+        elif entity_type == EntityType.ITALIC:
+            prefix = suffix = "_"
+        elif entity_type == EntityType.STRIKETHROUGH:
+            prefix = suffix = "~~"
+        elif entity_type == EntityType.URL:
+            if kwargs['url'] != self.text:
+                suffix = f" ({kwargs['url']})"
+        elif entity_type == EntityType.PREFORMATTED:
+            prefix = f"```{kwargs['language']}\n"
+            suffix = "\n```"
+        elif entity_type == EntityType.INLINE_CODE:
+            prefix = suffix = "`"
+        elif entity_type == EntityType.BLOCKQUOTE:
+            children = self.trim().split("\n")
+            children = [child.prepend("> ") for child in children]
+            return self.join(children, "\n")
+        elif entity_type == EntityType.HEADER:
+            prefix = "#" * kwargs["size"] + " "
+        else:
+            return self
+
+        self._offset_entities(len(prefix))
+        self.text = f"{prefix}{self.text}{suffix}"
+        return self
+
+
+class MatrixParser(BaseMatrixParser[SignalFormatString]):
+    fs = SignalFormatString
+
+    @classmethod
+    def parse(cls, data: str) -> SignalFormatString:
+        return cast(SignalFormatString, super().parse(data))
+
+
+async def matrix_to_signal(content: TextMessageEventContent) -> Tuple[str, List[Mention]]:
+    if content.msgtype == MessageType.EMOTE:
+        content.body = f"/me {content.body}"
+        if content.formatted_body:
+            content.formatted_body = f"/me {content.formatted_body}"
+    mentions = []
+    if content.format == Format.HTML and content.formatted_body:
+        parsed = MatrixParser.parse(add_surrogate(content.formatted_body))
+        text = del_surrogate(parsed.text)
+        for mention in parsed.entities:
+            mxid = mention.extra_info["user_id"]
+            user = await u.User.get_by_mxid(mxid, create=False)
+            if user and user.uuid:
+                uuid = user.uuid
+            else:
+                puppet = await pu.Puppet.get_by_mxid(mxid, create=False)
+                if puppet:
+                    uuid = puppet.uuid
+                else:
+                    continue
+            mentions.append(Mention(uuid=uuid, start=mention.offset, length=mention.length))
+    else:
+        text = content.body
+    return text, mentions

+ 13 - 8
mautrix_signal/portal.py

@@ -25,16 +25,17 @@ import time
 import os
 import os
 
 
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
-                              Attachment, GroupID, GroupV2ID, GroupV2)
+                              Attachment, GroupID, GroupV2ID, GroupV2, Mention)
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal
 from mautrix.bridge import BasePortal
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
-                           TextMessageEventContent, MessageEvent, EncryptedEvent, ContentURI,
-                           MediaMessageEventContent, ImageInfo, VideoInfo, FileInfo, AudioInfo)
+                           MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
+                           ImageInfo, VideoInfo, FileInfo, AudioInfo)
 from mautrix.errors import MatrixError, MForbidden
 from mautrix.errors import MatrixError, MForbidden
 
 
 from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
 from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
 from .config import Config
 from .config import Config
+from .formatter import matrix_to_signal, signal_to_matrix
 from . import user as u, puppet as p, matrix as m, signal as s
 from . import user as u, puppet as p, matrix as m, signal as s
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -180,19 +181,23 @@ class Portal(DBPortal, BasePortal):
             if reply is not None:
             if reply is not None:
                 quote = Quote(id=reply.timestamp, author=reply.sender, text="")
                 quote = Quote(id=reply.timestamp, author=reply.sender, text="")
 
 
-        text = message.body
         attachments: Optional[List[Attachment]] = None
         attachments: Optional[List[Attachment]] = None
         attachment_path: Optional[str] = None
         attachment_path: Optional[str] = None
-        if message.msgtype == MessageType.EMOTE:
-            text = f"/me {text}"
+        mentions: Optional[List[Mention]] = None
+        if message.msgtype.is_text:
+            text, mentions = await matrix_to_signal(message)
         elif message.msgtype.is_media:
         elif message.msgtype.is_media:
             attachment_path = await self._download_matrix_media(message)
             attachment_path = await self._download_matrix_media(message)
             attachment = self._make_attachment(message, attachment_path)
             attachment = self._make_attachment(message, attachment_path)
             attachments = [attachment]
             attachments = [attachment]
             text = None
             text = None
             self.log.trace("Formed outgoing attachment %s", attachment)
             self.log.trace("Formed outgoing attachment %s", attachment)
+        else:
+            self.log.debug(f"Unknown msgtype {message.msgtype} in Matrix message {event_id}")
+            return
         await self.signal.send(username=sender.username, recipient=self.chat_id, body=text,
         await self.signal.send(username=sender.username, recipient=self.chat_id, body=text,
-                               quote=quote, attachments=attachments, timestamp=request_id)
+                               mentions=mentions, quote=quote, attachments=attachments,
+                               timestamp=request_id)
         msg = DBMessage(mxid=event_id, mx_room=self.mxid, sender=sender.address,
         msg = DBMessage(mxid=event_id, mx_room=self.mxid, sender=sender.address,
                         timestamp=request_id,
                         timestamp=request_id,
                         signal_chat_id=self.chat_id, signal_receiver=self.receiver)
                         signal_chat_id=self.chat_id, signal_receiver=self.receiver)
@@ -340,7 +345,7 @@ class Portal(DBPortal, BasePortal):
             event_id = await self._send_message(intent, content, timestamp=message.timestamp)
             event_id = await self._send_message(intent, content, timestamp=message.timestamp)
 
 
         if message.body:
         if message.body:
-            content = TextMessageEventContent(msgtype=MessageType.TEXT, body=message.body)
+            content = await signal_to_matrix(message)
             if reply_to:
             if reply_to:
                 content.set_reply(reply_to)
                 content.set_reply(reply_to)
             event_id = await self._send_message(intent, content, timestamp=message.timestamp)
             event_id = await self._send_message(intent, content, timestamp=message.timestamp)