signal.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2022 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING
  18. import asyncio
  19. import logging
  20. from mausignald import SignaldClient
  21. from mausignald.types import (
  22. Address,
  23. ErrorMessage,
  24. IncomingMessage,
  25. MessageData,
  26. OfferMessageType,
  27. OwnReadReceipt,
  28. ReceiptMessage,
  29. ReceiptType,
  30. TypingAction,
  31. TypingMessage,
  32. WebsocketConnectionStateChangeEvent,
  33. )
  34. from mautrix.types import EventID, MessageType, RoomID
  35. from mautrix.util.logging import TraceLogger
  36. from . import portal as po, puppet as pu, user as u
  37. from .db import Message as DBMessage
  38. if TYPE_CHECKING:
  39. from .__main__ import SignalBridge
  40. # Typing notifications seem to get resent every 10 seconds and the timeout is around 15 seconds
  41. SIGNAL_TYPING_TIMEOUT = 15000
  42. class SignalHandler(SignaldClient):
  43. log: TraceLogger = logging.getLogger("mau.signal")
  44. loop: asyncio.AbstractEventLoop
  45. data_dir: str
  46. delete_unknown_accounts: bool
  47. error_message_lock: asyncio.Lock
  48. error_message_events: dict[tuple[RoomID, Address, int], EventID | None]
  49. def __init__(self, bridge: "SignalBridge") -> None:
  50. super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
  51. self.data_dir = bridge.config["signal.data_dir"]
  52. self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
  53. self.error_message_lock = asyncio.Lock()
  54. self.error_message_events = {}
  55. self.add_event_handler(IncomingMessage, self.on_message)
  56. self.add_event_handler(ErrorMessage, self.on_error_message)
  57. self.add_event_handler(
  58. WebsocketConnectionStateChangeEvent, self.on_websocket_connection_state_change
  59. )
  60. async def on_message(self, evt: IncomingMessage) -> None:
  61. sender = await pu.Puppet.get_by_address(evt.source)
  62. user = await u.User.get_by_username(evt.account)
  63. # TODO add lots of logging
  64. if evt.data_message:
  65. await self.handle_message(user, sender, evt.data_message)
  66. if evt.typing_message:
  67. await self.handle_typing(user, sender, evt.typing_message)
  68. if evt.receipt_message:
  69. await self.handle_receipt(sender, evt.receipt_message)
  70. if evt.call_message:
  71. await self.handle_call_message(user, sender, evt)
  72. if evt.sync_message:
  73. if evt.sync_message.read_messages:
  74. await self.handle_own_receipts(sender, evt.sync_message.read_messages)
  75. if evt.sync_message.sent:
  76. await self.handle_message(
  77. user,
  78. sender,
  79. evt.sync_message.sent.message,
  80. addr_override=evt.sync_message.sent.destination,
  81. )
  82. if evt.sync_message.contacts or evt.sync_message.contacts_complete:
  83. self.log.debug("Sync message includes contacts meta, syncing contacts...")
  84. await user.sync_contacts()
  85. if evt.sync_message.groups:
  86. self.log.debug("Sync message includes groups meta, syncing groups...")
  87. await user.sync_groups()
  88. async with self.error_message_lock:
  89. portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
  90. if not portal or not portal.mxid:
  91. return
  92. error_message_event_key = (portal.mxid, sender.address, evt.timestamp)
  93. if error_message_event_key in self.error_message_events:
  94. event_id = self.error_message_events[error_message_event_key]
  95. if event_id is not None:
  96. await sender.intent_for(portal).redact(portal.mxid, event_id)
  97. del self.error_message_events[error_message_event_key]
  98. async def on_error_message(self, err: ErrorMessage) -> None:
  99. sender = await pu.Puppet.get_by_address(Address.parse(err.data.sender))
  100. user = await u.User.get_by_username(err.account)
  101. portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
  102. if not portal or not portal.mxid:
  103. return
  104. # Add the error to the error_message_events dictionary, then wait for 10 seconds until
  105. # sending an error. If a success for the timestamp comes in before the 10 seconds is up,
  106. # don't send the error message.
  107. error_message_event_key = (portal.mxid, sender.address, err.data.timestamp)
  108. async with self.error_message_lock:
  109. self.error_message_events[error_message_event_key] = None
  110. await asyncio.sleep(10)
  111. err_text = (
  112. "There was an error receiving a message. Check your Signal app for missing messages. "
  113. f"{err.type}: {err.data.message}"
  114. )
  115. async with self.error_message_lock:
  116. if error_message_event_key in self.error_message_events:
  117. event_id = await sender.intent_for(portal).send_text(
  118. portal.mxid, html=err_text, msgtype=MessageType.NOTICE
  119. )
  120. self.error_message_events[error_message_event_key] = event_id
  121. @staticmethod
  122. async def on_websocket_connection_state_change(
  123. evt: WebsocketConnectionStateChangeEvent,
  124. ) -> None:
  125. user = await u.User.get_by_username(evt.account)
  126. user.on_websocket_connection_state_change(evt)
  127. async def handle_message(
  128. self,
  129. user: u.User,
  130. sender: pu.Puppet,
  131. msg: MessageData,
  132. addr_override: Address | None = None,
  133. ) -> None:
  134. if msg.profile_key_update:
  135. self.log.debug("Ignoring profile key update")
  136. return
  137. if msg.group_v2:
  138. portal = await po.Portal.get_by_chat_id(msg.group_v2.id, create=True)
  139. elif msg.group:
  140. portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
  141. else:
  142. portal = await po.Portal.get_by_chat_id(
  143. addr_override or sender.address, receiver=user.username, create=True
  144. )
  145. if addr_override and not sender.is_real_user:
  146. portal.log.debug(
  147. f"Ignoring own message {msg.timestamp} as user doesn't have double puppeting "
  148. "enabled"
  149. )
  150. return
  151. assert portal
  152. if not portal.mxid:
  153. if not msg.is_message and not msg.group_v2:
  154. user.log.debug(
  155. f"Ignoring message {msg.timestamp},"
  156. " probably not bridgeable as there's no portal yet"
  157. )
  158. return
  159. await portal.create_matrix_room(
  160. user, msg.group_v2 or msg.group or addr_override or sender.address
  161. )
  162. if not portal.mxid:
  163. user.log.warning(
  164. f"Failed to create room for incoming message {msg.timestamp}, dropping message"
  165. )
  166. return
  167. elif msg.group_v2 and msg.group_v2.revision > portal.revision:
  168. self.log.debug(f"Got new revision of {msg.group_v2.id}, updating info")
  169. await portal.update_info(user, msg.group_v2, sender)
  170. if msg.reaction:
  171. await portal.handle_signal_reaction(sender, msg.reaction, msg.timestamp)
  172. if msg.is_message:
  173. await portal.handle_signal_message(user, sender, msg)
  174. if msg.expires_in_seconds is not None:
  175. await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
  176. if msg.group and msg.group.type == "UPDATE":
  177. await portal.update_info(user, msg.group)
  178. if msg.remote_delete:
  179. await portal.handle_signal_delete(sender, msg.remote_delete.target_sent_timestamp)
  180. @staticmethod
  181. async def handle_call_message(user: u.User, sender: pu.Puppet, msg: IncomingMessage) -> None:
  182. assert msg.call_message
  183. portal = await po.Portal.get_by_chat_id(
  184. sender.address, receiver=user.username, create=True
  185. )
  186. if not portal.mxid:
  187. # FIXME
  188. # await portal.create_matrix_room(
  189. # user, (msg.group_v2 or msg.group or addr_override or sender.address)
  190. # )
  191. # if not portal.mxid:
  192. # user.log.debug(
  193. # f"Failed to create room for incoming message {msg.timestamp},"
  194. # " dropping message"
  195. # )
  196. return
  197. msg_html = f'<a href="https://matrix.to/#/{sender.mxid}">{sender.name}</a>'
  198. if msg.call_message.offer_message:
  199. call_type = {
  200. OfferMessageType.AUDIO_CALL: "voice call",
  201. OfferMessageType.VIDEO_CALL: "video call",
  202. }.get(msg.call_message.offer_message.type, "call")
  203. msg_html += f" started a {call_type} on Signal. Use the native app to answer the call."
  204. msg_type = MessageType.TEXT
  205. elif msg.call_message.hangup_message:
  206. msg_html += " ended a call on Signal."
  207. msg_type = MessageType.NOTICE
  208. else:
  209. portal.log.debug(f"Unhandled call message. Likely an ICE message. {msg.call_message}")
  210. return
  211. await sender.intent_for(portal).send_text(portal.mxid, html=msg_html, msgtype=msg_type)
  212. @staticmethod
  213. async def handle_own_receipts(sender: pu.Puppet, receipts: list[OwnReadReceipt]) -> None:
  214. for receipt in receipts:
  215. puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
  216. if not puppet:
  217. continue
  218. message = await DBMessage.find_by_sender_timestamp(puppet.address, receipt.timestamp)
  219. if not message:
  220. continue
  221. portal = await po.Portal.get_by_mxid(message.mx_room)
  222. if not portal or (portal.is_direct and not sender.is_real_user):
  223. continue
  224. await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
  225. @staticmethod
  226. async def handle_typing(user: u.User, sender: pu.Puppet, typing: TypingMessage) -> None:
  227. if typing.group_id:
  228. portal = await po.Portal.get_by_chat_id(typing.group_id)
  229. else:
  230. portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
  231. if not portal or not portal.mxid:
  232. return
  233. is_typing = typing.action == TypingAction.STARTED
  234. await sender.intent_for(portal).set_typing(
  235. portal.mxid, is_typing, ignore_cache=True, timeout=SIGNAL_TYPING_TIMEOUT
  236. )
  237. @staticmethod
  238. async def handle_receipt(sender: pu.Puppet, receipt: ReceiptMessage) -> None:
  239. if receipt.type != ReceiptType.READ:
  240. return
  241. messages = await DBMessage.find_by_timestamps(receipt.timestamps)
  242. for message in messages:
  243. portal = await po.Portal.get_by_mxid(message.mx_room)
  244. await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
  245. async def start(self) -> None:
  246. await self.connect()
  247. known_usernames = set()
  248. async for user in u.User.all_logged_in():
  249. # TODO report errors to user?
  250. known_usernames.add(user.username)
  251. if await self.subscribe(user.username):
  252. asyncio.create_task(user.sync())
  253. if self.delete_unknown_accounts:
  254. self.log.debug("Checking for unknown accounts to delete")
  255. for account in await self.list_accounts():
  256. if account.account_id not in known_usernames:
  257. self.log.warning(f"Unknown account ID {account.account_id}, deleting...")
  258. await self.delete_account(account.account_id)
  259. async def stop(self) -> None:
  260. await self.disconnect()