signal.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
# mautrix-signal - A Matrix-Signal puppeting bridge
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import TYPE_CHECKING, Awaitable
from uuid import UUID
import asyncio
import html
import logging

from mausignald import SignaldClient
from mausignald.types import (
    Address,
    ErrorMessage,
    IncomingMessage,
    MessageData,
    MessageResendSuccessEvent,
    OfferMessageType,
    OwnReadReceipt,
    ReceiptMessage,
    ReceiptType,
    StorageChange,
    TypingAction,
    TypingMessage,
    WebsocketConnectionStateChangeEvent,
)
from mautrix.types import EventID, EventType, Format, MessageType, TextMessageEventContent
from mautrix.util import background_task
from mautrix.util.logging import TraceLogger
from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus

from . import portal as po, puppet as pu, user as u
from .db import Message as DBMessage
from .web.segment_analytics import track

if TYPE_CHECKING:
    from .__main__ import SignalBridge
  46. # Typing notifications seem to get resent every 10 seconds and the timeout is around 15 seconds
  47. SIGNAL_TYPING_TIMEOUT = 15000
  48. class SignalHandler(SignaldClient):
  49. log: TraceLogger = logging.getLogger("mau.signal")
  50. loop: asyncio.AbstractEventLoop
  51. data_dir: str
  52. delete_unknown_accounts: bool
  53. error_message_events: dict[tuple[UUID, str, int], Awaitable[EventID] | None]
  54. def __init__(self, bridge: "SignalBridge") -> None:
  55. super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
  56. self.data_dir = bridge.config["signal.data_dir"]
  57. self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
  58. self.error_message_events = {}
  59. self.add_event_handler(IncomingMessage, self.on_message)
  60. self.add_event_handler(ErrorMessage, self.on_error_message)
  61. self.add_event_handler(StorageChange, self.on_storage_change)
  62. self.add_event_handler(
  63. WebsocketConnectionStateChangeEvent, self.on_websocket_connection_state_change
  64. )
  65. self.add_event_handler(MessageResendSuccessEvent, self.on_message_resend_success)
  66. async def on_message(self, evt: IncomingMessage) -> None:
  67. sender = await pu.Puppet.get_by_address(evt.source, resolve_via=evt.account)
  68. if not sender:
  69. self.log.warning(f"Didn't find puppet for incoming message {evt.source}")
  70. return
  71. user = await u.User.get_by_username(evt.account)
  72. # TODO add lots of logging
  73. if evt.data_message:
  74. await self.handle_message(user, sender, evt.data_message)
  75. if evt.typing_message:
  76. await self.handle_typing(user, sender, evt.typing_message)
  77. if evt.receipt_message:
  78. await self.handle_receipt(sender, evt.receipt_message)
  79. if evt.call_message:
  80. await self.handle_call_message(user, sender, evt)
  81. if evt.decryption_error_message:
  82. await self.handle_decryption_error(user, sender, evt)
  83. if evt.sync_message:
  84. if evt.sync_message.read_messages:
  85. await self.handle_own_receipts(sender, evt.sync_message.read_messages)
  86. if evt.sync_message.sent:
  87. if (
  88. evt.sync_message.sent.destination
  89. and not evt.sync_message.sent.destination.uuid
  90. ):
  91. self.log.warning(
  92. "Got sent message without destination UUID "
  93. f"{evt.sync_message.sent.destination}"
  94. )
  95. await self.handle_message(
  96. user,
  97. sender,
  98. evt.sync_message.sent.message,
  99. addr_override=evt.sync_message.sent.destination,
  100. )
  101. if evt.sync_message.contacts or evt.sync_message.contacts_complete:
  102. self.log.debug("Sync message includes contacts meta, syncing contacts...")
  103. await user.sync_contacts()
  104. if evt.sync_message.groups:
  105. self.log.debug("Sync message includes groups meta, syncing groups...")
  106. await user.sync_groups()
  107. try:
  108. event_id_future = self.error_message_events.pop(
  109. (sender.uuid, user.username, evt.timestamp)
  110. )
  111. except KeyError:
  112. pass
  113. else:
  114. self.log.debug(f"Got previously errored message {evt.timestamp} from {sender.address}")
  115. event_id = await event_id_future if event_id_future is not None else None
  116. if event_id is not None:
  117. portal = await po.Portal.get_by_chat_id(sender.uuid, receiver=user.username)
  118. if portal and portal.mxid:
  119. await sender.intent_for(portal).redact(portal.mxid, event_id)
  120. error = {"sender": str(sender.uuid), "timestamp": str(evt.timestamp)}
  121. track(user, "$signal_inbound_error_redacted", error)
  122. async def on_error_message(self, err: ErrorMessage) -> None:
  123. self.log.warning(
  124. f"Error reading message from {err.data.sender}/{err.data.sender_device} "
  125. f"(timestamp: {err.data.timestamp}, content hint: {err.data.content_hint}): "
  126. f"{err.data.message}"
  127. )
  128. if err.data.content_hint == 2:
  129. return
  130. sender = await pu.Puppet.get_by_address(
  131. Address.parse(err.data.sender), resolve_via=err.account
  132. )
  133. if not sender:
  134. return
  135. user = await u.User.get_by_username(err.account)
  136. portal = await po.Portal.get_by_chat_id(sender.uuid, receiver=user.username)
  137. if not portal or not portal.mxid:
  138. return
  139. # Add the error to the error_message_events dictionary, then wait for 10 seconds until
  140. # sending an error. If a success for the timestamp comes in before the 10 seconds is up,
  141. # don't send the error message.
  142. error_message_event_key = (sender.uuid, user.username, err.data.timestamp)
  143. self.error_message_events[error_message_event_key] = None
  144. await asyncio.sleep(10)
  145. err_text = (
  146. "There was an error receiving a message. Check your Signal app for missing messages."
  147. )
  148. if error_message_event_key in self.error_message_events:
  149. fut = self.error_message_events[error_message_event_key] = self.loop.create_future()
  150. event_id = None
  151. try:
  152. event_id = await portal._send_message(
  153. intent=sender.intent_for(portal),
  154. content=TextMessageEventContent(body=err_text, msgtype=MessageType.NOTICE),
  155. )
  156. error = {
  157. "message": err_text,
  158. "sender": str(sender.uuid),
  159. "timestamp": str(err.data.timestamp),
  160. }
  161. track(user, "$signal_inbound_error_displayed", error)
  162. finally:
  163. fut.set_result(event_id)
  164. async def on_storage_change(self, storage_change: StorageChange) -> None:
  165. self.log.info("Handling StorageChange %s", str(storage_change))
  166. if user := await u.User.get_by_username(storage_change.account):
  167. await user.sync()
  168. @staticmethod
  169. async def on_websocket_connection_state_change(
  170. evt: WebsocketConnectionStateChangeEvent,
  171. ) -> None:
  172. user = await u.User.get_by_username(evt.account)
  173. user.on_websocket_connection_state_change(evt)
  174. @staticmethod
  175. async def on_message_resend_success(evt: MessageResendSuccessEvent):
  176. user = await u.User.get_by_username(evt.account)
  177. await user.on_message_resend_success(evt)
  178. async def handle_message(
  179. self,
  180. user: u.User,
  181. sender: pu.Puppet,
  182. msg: MessageData,
  183. addr_override: Address | None = None,
  184. ) -> None:
  185. try:
  186. await self._handle_message(user, sender, msg, addr_override)
  187. except Exception as e:
  188. await user.handle_auth_failure(e)
  189. raise
  190. async def _handle_message(
  191. self,
  192. user: u.User,
  193. sender: pu.Puppet,
  194. msg: MessageData,
  195. addr_override: Address | None = None,
  196. ) -> None:
  197. if msg.profile_key_update:
  198. background_task.create(user.sync_contact(sender.address, use_cache=False))
  199. return
  200. if msg.group_v2:
  201. portal = await po.Portal.get_by_chat_id(msg.group_v2.id, create=True)
  202. else:
  203. if addr_override and not addr_override.uuid:
  204. target = await pu.Puppet.get_by_address(addr_override, resolve_via=user.username)
  205. if not target:
  206. self.log.warning(
  207. f"Didn't find puppet for recipient of incoming message {addr_override}"
  208. )
  209. return
  210. portal = await po.Portal.get_by_chat_id(
  211. addr_override.uuid if addr_override else sender.uuid,
  212. receiver=user.username,
  213. create=True,
  214. )
  215. if addr_override and not sender.is_real_user:
  216. portal.log.debug(
  217. f"Ignoring own message {msg.timestamp} as user doesn't have double puppeting "
  218. "enabled"
  219. )
  220. return
  221. assert portal
  222. # Handle the user being removed from the group.
  223. if msg.group_v2 and msg.group_v2.removed:
  224. if portal.mxid:
  225. await portal.handle_signal_kicked(user, sender)
  226. return
  227. if not portal.mxid:
  228. if not msg.is_message and not msg.group_v2:
  229. user.log.debug(
  230. f"Ignoring message {msg.timestamp},"
  231. " probably not bridgeable as there's no portal yet"
  232. )
  233. return
  234. await portal.create_matrix_room(user, msg.group_v2 or addr_override or sender.address)
  235. if not portal.mxid:
  236. user.log.warning(
  237. f"Failed to create room for incoming message {msg.timestamp}, dropping message"
  238. )
  239. return
  240. elif (
  241. msg.group_v2
  242. and msg.group_v2.group_change
  243. and msg.group_v2.revision == portal.revision + 1
  244. ):
  245. self.log.debug(
  246. f"Got update for {msg.group_v2.id} ({portal.revision} -> "
  247. f"{msg.group_v2.revision}), applying diff"
  248. )
  249. await portal.handle_signal_group_change(msg.group_v2.group_change, user)
  250. elif msg.group_v2 and msg.group_v2.revision > portal.revision:
  251. self.log.debug(
  252. f"Got update with multiple revisions for {msg.group_v2.id} ({portal.revision} -> "
  253. f"{msg.group_v2.revision}), resyncing info"
  254. )
  255. await portal.update_info(user, msg.group_v2)
  256. if msg.expires_in_seconds is not None and (msg.is_message or msg.is_expiration_update):
  257. await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
  258. if msg.reaction:
  259. await portal.handle_signal_reaction(sender, msg.reaction, msg.timestamp)
  260. if msg.is_message:
  261. await portal.handle_signal_message(user, sender, msg)
  262. if msg.remote_delete:
  263. await portal.handle_signal_delete(sender, msg.remote_delete.target_sent_timestamp)
  264. @staticmethod
  265. async def handle_call_message(user: u.User, sender: pu.Puppet, msg: IncomingMessage) -> None:
  266. assert msg.call_message
  267. portal = await po.Portal.get_by_chat_id(sender.uuid, receiver=user.username, create=True)
  268. if not portal.mxid:
  269. # FIXME
  270. # await portal.create_matrix_room(
  271. # user, (msg.group_v2 or msg.group or addr_override or sender.address)
  272. # )
  273. # if not portal.mxid:
  274. # user.log.debug(
  275. # f"Failed to create room for incoming message {msg.timestamp},"
  276. # " dropping message"
  277. # )
  278. return
  279. msg_prefix_html = f'<a href="https://matrix.to/#/{sender.mxid}">{sender.name}</a>'
  280. msg_prefix_text = f"{sender.name}"
  281. msg_suffix = ""
  282. if msg.call_message.offer_message:
  283. call_type = {
  284. OfferMessageType.AUDIO_CALL: "voice call",
  285. OfferMessageType.VIDEO_CALL: "video call",
  286. }.get(msg.call_message.offer_message.type, "call")
  287. msg_suffix = (
  288. f" started a {call_type} on Signal. Use the native app to answer the call."
  289. )
  290. msg_type = MessageType.TEXT
  291. elif msg.call_message.hangup_message:
  292. msg_suffix = " ended a call on Signal."
  293. msg_type = MessageType.NOTICE
  294. else:
  295. portal.log.debug(f"Unhandled call message. Likely an ICE message. {msg.call_message}")
  296. return
  297. await portal._send_message(
  298. intent=sender.intent_for(portal),
  299. content=TextMessageEventContent(
  300. format=Format.HTML,
  301. formatted_body=msg_prefix_html + msg_suffix,
  302. body=msg_prefix_text + msg_suffix,
  303. msgtype=msg_type,
  304. ),
  305. )
  306. @staticmethod
  307. async def handle_own_receipts(sender: pu.Puppet, receipts: list[OwnReadReceipt]) -> None:
  308. for receipt in receipts:
  309. puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
  310. if not puppet:
  311. continue
  312. message = await DBMessage.find_by_sender_timestamp(puppet.uuid, receipt.timestamp)
  313. if not message:
  314. continue
  315. portal = await po.Portal.get_by_mxid(message.mx_room)
  316. if not portal or (portal.is_direct and not sender.is_real_user):
  317. continue
  318. await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
  319. @staticmethod
  320. async def handle_typing(user: u.User, sender: pu.Puppet, typing: TypingMessage) -> None:
  321. if typing.group_id:
  322. portal = await po.Portal.get_by_chat_id(typing.group_id)
  323. else:
  324. portal = await po.Portal.get_by_chat_id(sender.uuid, receiver=user.username)
  325. if not portal or not portal.mxid:
  326. return
  327. is_typing = typing.action == TypingAction.STARTED
  328. await sender.intent_for(portal).set_typing(
  329. portal.mxid, timeout=SIGNAL_TYPING_TIMEOUT if is_typing else 0
  330. )
  331. @staticmethod
  332. async def handle_receipt(sender: pu.Puppet, receipt: ReceiptMessage) -> None:
  333. if receipt.type != ReceiptType.READ:
  334. return
  335. messages = await DBMessage.find_by_timestamps(receipt.timestamps)
  336. for message in messages:
  337. portal = await po.Portal.get_by_mxid(message.mx_room)
  338. await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
  339. async def handle_decryption_error(
  340. self, user: u.User, sender: pu.Puppet, msg: IncomingMessage
  341. ) -> None:
  342. # These messages mean that a message resend was requested. Signald will handle it, but we
  343. # need to update the checkpoints.
  344. assert msg.decryption_error_message
  345. my_uuid = user.address.uuid
  346. timestamp = msg.decryption_error_message.timestamp
  347. self.log.debug(f"Got decryption error message for {my_uuid}/{timestamp}")
  348. message = await DBMessage.find_by_sender_timestamp(my_uuid, timestamp)
  349. if not message:
  350. self.log.warning("Couldn't find message to referenced in decryption error")
  351. return
  352. self.log.debug(
  353. f"Got decryption error message for {message.mxid} from {sender.uuid} "
  354. f"in {message.mx_room}"
  355. )
  356. portal = await po.Portal.get_by_mxid(message.mx_room)
  357. if not portal or not portal.mxid:
  358. self.log.warning("Couldn't find portal for message referenced in decryption error")
  359. return
  360. evt = await portal.main_intent.get_event(message.mx_room, message.mxid)
  361. if evt.content.get("fi.mau.double_puppet_source"):
  362. self.log.debug(
  363. "Message requested in decryption error is double-puppeted, not sending checkpoint"
  364. )
  365. return
  366. user.send_remote_checkpoint(
  367. status=MessageSendCheckpointStatus.DELIVERY_FAILED,
  368. event_id=message.mxid,
  369. room_id=message.mx_room,
  370. event_type=EventType.ROOM_MESSAGE,
  371. error=f"{sender.uuid} sent a decryption error message for this message",
  372. )
  373. async def start(self) -> None:
  374. await self.connect()
  375. known_usernames = set()
  376. async for user in u.User.all_logged_in():
  377. # TODO report errors to user?
  378. known_usernames.add(user.username)
  379. if await self.subscribe(user.username):
  380. self.log.info(
  381. f"Successfully subscribed {user.username}, running sync in background"
  382. )
  383. background_task.create(user.sync())
  384. if self.delete_unknown_accounts:
  385. self.log.debug("Checking for unknown accounts to delete")
  386. for account in await self.list_accounts():
  387. if account.account_id not in known_usernames:
  388. self.log.warning(f"Unknown account ID {account.account_id}, deleting...")
  389. await self.delete_account(account.account_id)
  390. else:
  391. self.log.debug("No unknown accounts found")
  392. async def stop(self) -> None:
  393. await self.disconnect()