signal.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, List, TYPE_CHECKING
  17. import asyncio
  18. import logging
  19. from mausignald import SignaldClient
  20. from mausignald.types import (Message, MessageData, Address, TypingNotification, TypingAction,
  21. OwnReadReceipt, Receipt, ReceiptType, ListenEvent)
  22. from mautrix.util.logging import TraceLogger
  23. from .db import Message as DBMessage
  24. from . import user as u, portal as po, puppet as pu
  25. if TYPE_CHECKING:
  26. from .__main__ import SignalBridge
# Typing notifications seem to get resent every 10 seconds and the timeout is around 15 seconds
# This value is forwarded as the ``timeout`` argument of ``set_typing`` when bridging Signal
# typing notifications to Matrix.
# NOTE(review): 15000 matches the ~15 seconds above only if the unit is milliseconds - confirm
# against the intent API.
SIGNAL_TYPING_TIMEOUT = 15000
  29. class SignalHandler(SignaldClient):
  30. log: TraceLogger = logging.getLogger("mau.signal")
  31. loop: asyncio.AbstractEventLoop
  32. def __init__(self, bridge: 'SignalBridge') -> None:
  33. super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
  34. self.add_event_handler(Message, self.on_message)
  35. self.add_event_handler(ListenEvent, self.on_listen)
  36. async def on_message(self, evt: Message) -> None:
  37. sender = await pu.Puppet.get_by_address(evt.source)
  38. user = await u.User.get_by_username(evt.username)
  39. # TODO add lots of logging
  40. if evt.data_message:
  41. await self.handle_message(user, sender, evt.data_message)
  42. if evt.typing:
  43. await self.handle_typing(user, sender, evt.typing)
  44. if evt.receipt:
  45. await self.handle_receipt(sender, evt.receipt)
  46. if evt.sync_message:
  47. if evt.sync_message.read_messages:
  48. await self.handle_own_receipts(sender, evt.sync_message.read_messages)
  49. if evt.sync_message.contacts:
  50. # Contact list update?
  51. pass
  52. if evt.sync_message.sent:
  53. await self.handle_message(user, sender, evt.sync_message.sent.message,
  54. addr_override=evt.sync_message.sent.destination)
  55. if evt.sync_message.typing:
  56. # Typing notification from own device
  57. pass
  58. @staticmethod
  59. async def on_listen(evt: ListenEvent) -> None:
  60. user = await u.User.get_by_username(evt.username)
  61. user.on_listen(evt)
  62. async def handle_message(self, user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
  63. addr_override: Optional[Address] = None) -> None:
  64. group_v2_info = None
  65. if msg.group_v2:
  66. portal = await po.Portal.get_by_chat_id(msg.group_v2.id, create=True)
  67. if not portal.mxid:
  68. group_v2_info = await self.get_group(user.username, msg.group_v2.id)
  69. if not group_v2_info:
  70. user.log.debug(f"Dropping message in unknown v2 group {msg.group_v2.id}")
  71. return
  72. elif msg.group:
  73. portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
  74. else:
  75. portal = await po.Portal.get_by_chat_id(addr_override or sender.address,
  76. receiver=user.username, create=True)
  77. if addr_override and not sender.is_real_user:
  78. portal.log.debug(f"Ignoring own message {msg.timestamp} as user doesn't have"
  79. " double puppeting enabled")
  80. return
  81. if not portal.mxid:
  82. await portal.create_matrix_room(user, (group_v2_info or msg.group
  83. or addr_override or sender.address))
  84. if msg.reaction:
  85. await portal.handle_signal_reaction(sender, msg.reaction)
  86. if msg.body or msg.attachments or msg.sticker:
  87. await portal.handle_signal_message(user, sender, msg)
  88. if msg.group and msg.group.type == "UPDATE":
  89. await portal.update_info(msg.group)
  90. if msg.remote_delete:
  91. await portal.handle_signal_delete(sender, msg.remote_delete.target_sent_timestamp)
  92. @staticmethod
  93. async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None:
  94. for receipt in receipts:
  95. puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
  96. if not puppet:
  97. continue
  98. message = await DBMessage.find_by_sender_timestamp(puppet.address, receipt.timestamp)
  99. if not message:
  100. continue
  101. portal = await po.Portal.get_by_mxid(message.mx_room)
  102. if not portal or (portal.is_direct and not sender.is_real_user):
  103. continue
  104. await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
  105. @staticmethod
  106. async def handle_typing(user: 'u.User', sender: 'pu.Puppet',
  107. typing: TypingNotification) -> None:
  108. if typing.group_id:
  109. portal = await po.Portal.get_by_chat_id(typing.group_id)
  110. else:
  111. portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
  112. if not portal or not portal.mxid:
  113. return
  114. is_typing = typing.action == TypingAction.STARTED
  115. await sender.intent_for(portal).set_typing(portal.mxid, is_typing, ignore_cache=True,
  116. timeout=SIGNAL_TYPING_TIMEOUT)
  117. @staticmethod
  118. async def handle_receipt(sender: 'pu.Puppet', receipt: Receipt) -> None:
  119. if receipt.type != ReceiptType.READ:
  120. pass
  121. messages = await DBMessage.find_by_timestamps(receipt.timestamps)
  122. for message in messages:
  123. portal = await po.Portal.get_by_mxid(message.mx_room)
  124. await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
  125. async def start(self) -> None:
  126. await self.connect()
  127. async for user in u.User.all_logged_in():
  128. # TODO report errors to user?
  129. if await self.subscribe(user.username):
  130. self.loop.create_task(user.sync())
  131. async def stop(self) -> None:
  132. await self.disconnect()