signal.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2022 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Awaitable, Coroutine
from uuid import UUID
import asyncio
import html
import logging

from mausignald import SignaldClient
from mausignald.types import (
    Address,
    ErrorMessage,
    IncomingMessage,
    MessageData,
    OfferMessageType,
    OwnReadReceipt,
    ReceiptMessage,
    ReceiptType,
    StorageChange,
    TypingAction,
    TypingMessage,
    WebsocketConnectionStateChangeEvent,
)
from mautrix.types import EventID, Format, MessageType, TextMessageEventContent
from mautrix.util.logging import TraceLogger

from . import portal as po, puppet as pu, user as u
from .db import Message as DBMessage

if TYPE_CHECKING:
    from .__main__ import SignalBridge
  42. # Typing notifications seem to get resent every 10 seconds and the timeout is around 15 seconds
  43. SIGNAL_TYPING_TIMEOUT = 15000
  44. class SignalHandler(SignaldClient):
  45. log: TraceLogger = logging.getLogger("mau.signal")
  46. loop: asyncio.AbstractEventLoop
  47. data_dir: str
  48. delete_unknown_accounts: bool
  49. error_message_events: dict[tuple[Address, str, int], Awaitable[EventID] | None]
  50. def __init__(self, bridge: "SignalBridge") -> None:
  51. super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
  52. self.data_dir = bridge.config["signal.data_dir"]
  53. self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
  54. self.error_message_events = {}
  55. self.add_event_handler(IncomingMessage, self.on_message)
  56. self.add_event_handler(ErrorMessage, self.on_error_message)
  57. self.add_event_handler(StorageChange, self.on_storage_change)
  58. self.add_event_handler(
  59. WebsocketConnectionStateChangeEvent, self.on_websocket_connection_state_change
  60. )
  61. async def on_message(self, evt: IncomingMessage) -> None:
  62. sender = await pu.Puppet.get_by_address(evt.source, resolve_via=evt.account)
  63. if not sender:
  64. self.log.warning(f"Didn't find puppet for incoming message {evt.source}")
  65. return
  66. user = await u.User.get_by_username(evt.account)
  67. # TODO add lots of logging
  68. if evt.data_message:
  69. await self.handle_message(user, sender, evt.data_message)
  70. if evt.typing_message:
  71. await self.handle_typing(user, sender, evt.typing_message)
  72. if evt.receipt_message:
  73. await self.handle_receipt(sender, evt.receipt_message)
  74. if evt.call_message:
  75. await self.handle_call_message(user, sender, evt)
  76. if evt.sync_message:
  77. if evt.sync_message.read_messages:
  78. await self.handle_own_receipts(sender, evt.sync_message.read_messages)
  79. if evt.sync_message.sent:
  80. await self.handle_message(
  81. user,
  82. sender,
  83. evt.sync_message.sent.message,
  84. addr_override=evt.sync_message.sent.destination,
  85. )
  86. if evt.sync_message.contacts or evt.sync_message.contacts_complete:
  87. self.log.debug("Sync message includes contacts meta, syncing contacts...")
  88. await user.sync_contacts()
  89. if evt.sync_message.groups:
  90. self.log.debug("Sync message includes groups meta, syncing groups...")
  91. await user.sync_groups()
  92. try:
  93. event_id_future = self.error_message_events.pop(
  94. (sender.address, user.username, evt.timestamp)
  95. )
  96. except KeyError:
  97. pass
  98. else:
  99. self.log.debug(f"Got previously errored message {evt.timestamp} from {sender.address}")
  100. event_id = await event_id_future if event_id_future is not None else None
  101. if event_id is not None:
  102. portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
  103. if portal and portal.mxid:
  104. await sender.intent_for(portal).redact(portal.mxid, event_id)
  105. async def on_error_message(self, err: ErrorMessage) -> None:
  106. self.log.warning(
  107. f"Error reading message from {err.data.sender}/{err.data.sender_device} "
  108. f"(timestamp: {err.data.timestamp}, content hint: {err.data.content_hint}): "
  109. f"{err.data.message}"
  110. )
  111. sender = await pu.Puppet.get_by_address(
  112. Address.parse(err.data.sender), resolve_via=err.account
  113. )
  114. if not sender:
  115. return
  116. user = await u.User.get_by_username(err.account)
  117. portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
  118. if not portal or not portal.mxid:
  119. return
  120. # Add the error to the error_message_events dictionary, then wait for 10 seconds until
  121. # sending an error. If a success for the timestamp comes in before the 10 seconds is up,
  122. # don't send the error message.
  123. error_message_event_key = (sender.address, user.username, err.data.timestamp)
  124. self.error_message_events[error_message_event_key] = None
  125. await asyncio.sleep(10)
  126. err_text = (
  127. "There was an error receiving a message. Check your Signal app for missing messages."
  128. )
  129. if error_message_event_key in self.error_message_events:
  130. fut = self.error_message_events[error_message_event_key] = self.loop.create_future()
  131. event_id = None
  132. try:
  133. event_id = await portal._send_message(
  134. intent=sender.intent_for(portal),
  135. content=TextMessageEventContent(body=err_text, msgtype=MessageType.NOTICE),
  136. )
  137. finally:
  138. fut.set_result(event_id)
  139. async def on_storage_change(self, storage_change: StorageChange) -> None:
  140. self.log.info("Handling StorageChange %s", str(storage_change))
  141. if user := await u.User.get_by_username(storage_change.account):
  142. await user.sync()
  143. @staticmethod
  144. async def on_websocket_connection_state_change(
  145. evt: WebsocketConnectionStateChangeEvent,
  146. ) -> None:
  147. user = await u.User.get_by_username(evt.account)
  148. user.on_websocket_connection_state_change(evt)
  149. async def handle_message(
  150. self,
  151. user: u.User,
  152. sender: pu.Puppet,
  153. msg: MessageData,
  154. addr_override: Address | None = None,
  155. ) -> None:
  156. try:
  157. await self._handle_message(user, sender, msg, addr_override)
  158. except Exception as e:
  159. await user.handle_auth_failure(e)
  160. raise
  161. async def _handle_message(
  162. self,
  163. user: u.User,
  164. sender: pu.Puppet,
  165. msg: MessageData,
  166. addr_override: Address | None = None,
  167. ) -> None:
  168. if msg.profile_key_update:
  169. asyncio.create_task(user.sync_contact(sender.address, use_cache=False))
  170. return
  171. if msg.group_v2:
  172. portal = await po.Portal.get_by_chat_id(msg.group_v2.id, create=True)
  173. elif msg.group:
  174. portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
  175. else:
  176. portal = await po.Portal.get_by_chat_id(
  177. addr_override or sender.address, receiver=user.username, create=True
  178. )
  179. if addr_override and not sender.is_real_user:
  180. portal.log.debug(
  181. f"Ignoring own message {msg.timestamp} as user doesn't have double puppeting "
  182. "enabled"
  183. )
  184. return
  185. assert portal
  186. # Handle the user being removed from the group.
  187. if msg.group_v2 and msg.group_v2.removed:
  188. if portal.mxid:
  189. await portal.handle_signal_kicked(user, sender)
  190. return
  191. if not portal.mxid:
  192. if not msg.is_message and not msg.group_v2:
  193. user.log.debug(
  194. f"Ignoring message {msg.timestamp},"
  195. " probably not bridgeable as there's no portal yet"
  196. )
  197. return
  198. await portal.create_matrix_room(
  199. user, msg.group_v2 or msg.group or addr_override or sender.address
  200. )
  201. if not portal.mxid:
  202. user.log.warning(
  203. f"Failed to create room for incoming message {msg.timestamp}, dropping message"
  204. )
  205. return
  206. elif (
  207. msg.group_v2
  208. and msg.group_v2.group_change
  209. and msg.group_v2.revision == portal.revision + 1
  210. ):
  211. self.log.debug(
  212. f"Got update for {msg.group_v2.id} ({portal.revision} -> "
  213. f"{msg.group_v2.revision}), applying diff"
  214. )
  215. await portal.handle_signal_group_change(msg.group_v2.group_change, user)
  216. elif msg.group_v2 and msg.group_v2.revision > portal.revision:
  217. self.log.debug(
  218. f"Got update with multiple revisions for {msg.group_v2.id} ({portal.revision} -> "
  219. f"{msg.group_v2.revision}), resyncing info"
  220. )
  221. await portal.update_info(user, msg.group_v2)
  222. if msg.expires_in_seconds is not None and (msg.is_message or msg.is_expiration_update):
  223. await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
  224. if msg.reaction:
  225. await portal.handle_signal_reaction(sender, msg.reaction, msg.timestamp)
  226. if msg.is_message:
  227. await portal.handle_signal_message(user, sender, msg)
  228. if msg.group and msg.group.type == "UPDATE":
  229. await portal.update_info(user, msg.group)
  230. if msg.remote_delete:
  231. await portal.handle_signal_delete(sender, msg.remote_delete.target_sent_timestamp)
  232. @staticmethod
  233. async def handle_call_message(user: u.User, sender: pu.Puppet, msg: IncomingMessage) -> None:
  234. assert msg.call_message
  235. portal = await po.Portal.get_by_chat_id(
  236. sender.address, receiver=user.username, create=True
  237. )
  238. if not portal.mxid:
  239. # FIXME
  240. # await portal.create_matrix_room(
  241. # user, (msg.group_v2 or msg.group or addr_override or sender.address)
  242. # )
  243. # if not portal.mxid:
  244. # user.log.debug(
  245. # f"Failed to create room for incoming message {msg.timestamp},"
  246. # " dropping message"
  247. # )
  248. return
  249. msg_html = f'<a href="https://matrix.to/#/{sender.mxid}">{sender.name}</a>'
  250. if msg.call_message.offer_message:
  251. call_type = {
  252. OfferMessageType.AUDIO_CALL: "voice call",
  253. OfferMessageType.VIDEO_CALL: "video call",
  254. }.get(msg.call_message.offer_message.type, "call")
  255. msg_html += f" started a {call_type} on Signal. Use the native app to answer the call."
  256. msg_type = MessageType.TEXT
  257. elif msg.call_message.hangup_message:
  258. msg_html += " ended a call on Signal."
  259. msg_type = MessageType.NOTICE
  260. else:
  261. portal.log.debug(f"Unhandled call message. Likely an ICE message. {msg.call_message}")
  262. return
  263. await portal._send_message(
  264. intent=sender.intent_for(portal),
  265. content=TextMessageEventContent(
  266. format=Format.HTML, formatted_body=msg_html, msgtype=msg_type
  267. ),
  268. )
  269. @staticmethod
  270. async def handle_own_receipts(sender: pu.Puppet, receipts: list[OwnReadReceipt]) -> None:
  271. for receipt in receipts:
  272. puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
  273. if not puppet:
  274. continue
  275. message = await DBMessage.find_by_sender_timestamp(puppet.address, receipt.timestamp)
  276. if not message:
  277. continue
  278. portal = await po.Portal.get_by_mxid(message.mx_room)
  279. if not portal or (portal.is_direct and not sender.is_real_user):
  280. continue
  281. await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
  282. @staticmethod
  283. async def handle_typing(user: u.User, sender: pu.Puppet, typing: TypingMessage) -> None:
  284. if typing.group_id:
  285. portal = await po.Portal.get_by_chat_id(typing.group_id)
  286. else:
  287. portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
  288. if not portal or not portal.mxid:
  289. return
  290. is_typing = typing.action == TypingAction.STARTED
  291. await sender.intent_for(portal).set_typing(
  292. portal.mxid, is_typing, ignore_cache=True, timeout=SIGNAL_TYPING_TIMEOUT
  293. )
  294. @staticmethod
  295. async def handle_receipt(sender: pu.Puppet, receipt: ReceiptMessage) -> None:
  296. if receipt.type != ReceiptType.READ:
  297. return
  298. messages = await DBMessage.find_by_timestamps(receipt.timestamps)
  299. for message in messages:
  300. portal = await po.Portal.get_by_mxid(message.mx_room)
  301. await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
  302. async def start(self) -> None:
  303. await self.connect()
  304. known_usernames = set()
  305. async for user in u.User.all_logged_in():
  306. # TODO report errors to user?
  307. known_usernames.add(user.username)
  308. if await self.subscribe(user.username):
  309. asyncio.create_task(user.sync())
  310. if self.delete_unknown_accounts:
  311. self.log.debug("Checking for unknown accounts to delete")
  312. for account in await self.list_accounts():
  313. if account.account_id not in known_usernames:
  314. self.log.warning(f"Unknown account ID {account.account_id}, deleting...")
  315. await self.delete_account(account.account_id)
  316. async def stop(self) -> None:
  317. await self.disconnect()