signald.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
# Copyright (c) 2022 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. from __future__ import annotations
  7. from typing import Any, Awaitable, Callable, Type, TypeVar
  8. from uuid import UUID
  9. import asyncio
  10. from mautrix.util.logging import TraceLogger
  11. from .errors import AuthorizationFailedError, NoSuchAccountError, RPCError, UnexpectedResponse
  12. from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
  13. from .types import (
  14. Account,
  15. Address,
  16. Attachment,
  17. DeviceInfo,
  18. ErrorMessage,
  19. GetIdentitiesResponse,
  20. GroupAccessControl,
  21. GroupID,
  22. GroupMember,
  23. GroupV2,
  24. IncomingMessage,
  25. JoinGroupResponse,
  26. LinkPreview,
  27. LinkSession,
  28. Mention,
  29. MessageResendSuccessEvent,
  30. Profile,
  31. ProofRequiredType,
  32. Quote,
  33. Reaction,
  34. SendMessageResponse,
  35. StorageChange,
  36. TrustLevel,
  37. WebsocketConnectionState,
  38. WebsocketConnectionStateChangeEvent,
  39. )
  40. T = TypeVar("T")
  41. EventHandler = Callable[[T], Awaitable[None]]
  42. class SignaldClient(SignaldRPCClient):
  43. _event_handlers: dict[Type[T], list[EventHandler]]
  44. _subscriptions: set[str]
  45. def __init__(
  46. self,
  47. socket_path: str = "/var/run/signald/signald.sock",
  48. log: TraceLogger | None = None,
  49. loop: asyncio.AbstractEventLoop | None = None,
  50. ) -> None:
  51. super().__init__(socket_path, log, loop)
  52. self._event_handlers = {}
  53. self._subscriptions = set()
  54. self.add_rpc_handler("IncomingMessage", self._parse_message)
  55. self.add_rpc_handler("ProtocolInvalidMessageError", self._parse_error)
  56. self.add_rpc_handler("WebSocketConnectionState", self._websocket_connection_state_change)
  57. self.add_rpc_handler("version", self._log_version)
  58. self.add_rpc_handler("StorageChange", self._parse_storage_change)
  59. self.add_rpc_handler("MessageResendSuccess", self._parse_message_resend_request)
  60. self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
  61. self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
  62. def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  63. self._event_handlers.setdefault(event_class, []).append(handler)
  64. def remove_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  65. self._event_handlers.setdefault(event_class, []).remove(handler)
  66. async def _run_event_handler(self, event: T) -> None:
  67. try:
  68. handlers = self._event_handlers[type(event)]
  69. except KeyError:
  70. self.log.warning(f"No handlers for {type(event)}")
  71. else:
  72. for handler in handlers:
  73. try:
  74. await handler(event)
  75. except Exception:
  76. self.log.exception("Exception in event handler")
  77. async def _parse_error(self, data: dict[str, Any]) -> None:
  78. if not data.get("error"):
  79. return
  80. await self._run_event_handler(ErrorMessage.deserialize(data))
  81. async def _parse_storage_change(self, data: dict[str, Any]) -> None:
  82. if data["type"] != "StorageChange":
  83. return
  84. await self._run_event_handler(StorageChange.deserialize(data))
  85. async def _parse_message_resend_request(self, data: dict[str, Any]) -> None:
  86. if data["type"] != "MesaageResendSuccess":
  87. return
  88. await self._run_event_handler(MessageResendSuccessEvent.deserialize(data))
  89. async def _parse_message(self, data: dict[str, Any]) -> None:
  90. event_type = data["type"]
  91. event_data = data["data"]
  92. event_class = {
  93. "IncomingMessage": IncomingMessage,
  94. }[event_type]
  95. event = event_class.deserialize(event_data)
  96. await self._run_event_handler(event)
  97. async def _log_version(self, data: dict[str, Any]) -> None:
  98. name = data["data"]["name"]
  99. version = data["data"]["version"]
  100. self.log.info(f"Connected to {name} v{version}")
  101. async def _websocket_connection_state_change(self, change_event: dict[str, Any]) -> None:
  102. evt = WebsocketConnectionStateChangeEvent.deserialize(
  103. {
  104. "account": change_event["account"],
  105. **change_event["data"],
  106. }
  107. )
  108. await self._run_event_handler(evt)
  109. async def subscribe(self, username: str) -> bool:
  110. try:
  111. await self.request_v1("subscribe", account=username)
  112. self._subscriptions.add(username)
  113. return True
  114. except RPCError as e:
  115. self.log.debug("Failed to subscribe to %s: %s", username, e)
  116. state = WebsocketConnectionState.DISCONNECTED
  117. if isinstance(e, (AuthorizationFailedError, NoSuchAccountError)):
  118. state = WebsocketConnectionState.AUTHENTICATION_FAILED
  119. evt = WebsocketConnectionStateChangeEvent(state=state, account=username)
  120. await self._run_event_handler(evt)
  121. return False
  122. async def unsubscribe(self, username: str) -> bool:
  123. try:
  124. await self.request_v1("unsubscribe", account=username)
  125. self._subscriptions.discard(username)
  126. return True
  127. except RPCError as e:
  128. self.log.debug("Failed to unsubscribe from %s: %s", username, e)
  129. return False
  130. async def _resubscribe(self, unused_data: dict[str, Any]) -> None:
  131. if self._subscriptions:
  132. self.log.debug("Resubscribing to users")
  133. for username in list(self._subscriptions):
  134. await self.subscribe(username)
  135. async def _on_disconnect(self, *_) -> None:
  136. if self._subscriptions:
  137. self.log.debug("Notifying of disconnection from users")
  138. for username in self._subscriptions:
  139. evt = WebsocketConnectionStateChangeEvent(
  140. state=WebsocketConnectionState.SOCKET_DISCONNECTED,
  141. account=username,
  142. exception="Disconnected from signald",
  143. )
  144. await self._run_event_handler(evt)
  145. async def register(self, phone: str, voice: bool = False, captcha: str | None = None) -> str:
  146. resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
  147. return resp["account_id"]
  148. async def verify(self, username: str, code: str) -> Account:
  149. resp = await self.request_v1("verify", account=username, code=code)
  150. return Account.deserialize(resp)
  151. async def start_link(self) -> LinkSession:
  152. return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
  153. async def wait_for_scan(self, session_id: str) -> None:
  154. await self.request_v1("wait_for_scan", session_id=session_id)
  155. async def finish_link(
  156. self, session_id: str, device_name: str = "mausignald", overwrite: bool = False
  157. ) -> Account:
  158. resp = await self.request_v1(
  159. "finish_link", device_name=device_name, session_id=session_id, overwrite=overwrite
  160. )
  161. return Account.deserialize(resp)
  162. @staticmethod
  163. def _recipient_to_args(
  164. recipient: UUID | Address | GroupID, simple_name: bool = False
  165. ) -> dict[str, Any]:
  166. if isinstance(recipient, UUID):
  167. recipient = Address(uuid=recipient)
  168. if isinstance(recipient, Address):
  169. recipient = recipient.serialize()
  170. field_name = "address" if simple_name else "recipientAddress"
  171. else:
  172. field_name = "group" if simple_name else "recipientGroupId"
  173. return {field_name: recipient}
  174. async def react(
  175. self,
  176. username: str,
  177. recipient: UUID | Address | GroupID,
  178. reaction: Reaction,
  179. req_id: UUID | None = None,
  180. ) -> None:
  181. await self.request_v1(
  182. "react",
  183. username=username,
  184. reaction=reaction.serialize(),
  185. req_id=req_id,
  186. **self._recipient_to_args(recipient),
  187. )
  188. async def remote_delete(
  189. self, username: str, recipient: UUID | Address | GroupID, timestamp: int
  190. ) -> None:
  191. await self.request_v1(
  192. "remote_delete",
  193. account=username,
  194. timestamp=timestamp,
  195. **self._recipient_to_args(recipient, simple_name=True),
  196. )
  197. async def send_raw(
  198. self,
  199. username: str,
  200. recipient: UUID | Address | GroupID,
  201. body: str,
  202. quote: Quote | None = None,
  203. attachments: list[Attachment] | None = None,
  204. mentions: list[Mention] | None = None,
  205. previews: list[LinkPreview] | None = None,
  206. timestamp: int | None = None,
  207. req_id: UUID | None = None,
  208. ) -> SendMessageResponse:
  209. serialized_quote = quote.serialize() if quote else None
  210. serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
  211. serialized_mentions = [mention.serialize() for mention in (mentions or [])]
  212. serialized_previews = [preview.serialize() for preview in (previews or [])]
  213. resp = await self.request_v1(
  214. "send",
  215. username=username,
  216. messageBody=body,
  217. attachments=serialized_attachments,
  218. quote=serialized_quote,
  219. mentions=serialized_mentions,
  220. previews=serialized_previews,
  221. timestamp=timestamp,
  222. req_id=req_id,
  223. **self._recipient_to_args(recipient),
  224. )
  225. return SendMessageResponse.deserialize(resp)
  226. async def send(
  227. self,
  228. username: str,
  229. recipient: UUID | Address | GroupID,
  230. body: str,
  231. quote: Quote | None = None,
  232. attachments: list[Attachment] | None = None,
  233. mentions: list[Mention] | None = None,
  234. previews: list[LinkPreview] | None = None,
  235. timestamp: int | None = None,
  236. req_id: UUID | None = None,
  237. ) -> None:
  238. resp = await self.send_raw(
  239. username, recipient, body, quote, attachments, mentions, previews, timestamp, req_id
  240. )
  241. # We handle unregisteredFailure a little differently than other errors. If there are no
  242. # successful sends, then we show an error with the unregisteredFailure details, otherwise
  243. # we ignore it.
  244. errors = []
  245. unregistered_failures = []
  246. successful_send_count = 0
  247. for result in resp.results:
  248. number = result.address.number_or_uuid
  249. if result.network_failure:
  250. errors.append(f"Network failure occurred while sending message to {number}.")
  251. elif result.unregistered_failure:
  252. unregistered_failures.append(
  253. f"Unregistered failure occurred while sending message to {number}."
  254. )
  255. elif result.identity_failure:
  256. errors.append(
  257. f"Identity failure occurred while sending message to {number}. New identity: "
  258. f"{result.identity_failure}"
  259. )
  260. elif result.proof_required_failure:
  261. prf = result.proof_required_failure
  262. self.log.warning(
  263. f"Proof Required Failure {prf.options}. Retry after: {prf.retry_after}. "
  264. f"Token: {prf.token}. Message: {prf.message}."
  265. )
  266. errors.append(
  267. f"Proof required failure occurred while sending message to {number}. Message: "
  268. f"{prf.message}"
  269. )
  270. if ProofRequiredType.RECAPTCHA in prf.options:
  271. errors.append("RECAPTCHA required.")
  272. elif ProofRequiredType.PUSH_CHALLENGE in prf.options:
  273. # Just submit the challenge automatically.
  274. await self.request_v1("submit_challenge")
  275. else:
  276. successful_send_count += 1
  277. self.log.info(
  278. f"Successfully sent message to {successful_send_count}/{len(resp.results)} users in "
  279. f"{recipient} with {len(unregistered_failures)} unregistered failures"
  280. )
  281. if len(unregistered_failures) == len(resp.results):
  282. errors.extend(unregistered_failures)
  283. if errors:
  284. raise Exception("\n".join(errors))
  285. async def send_receipt(
  286. self,
  287. username: str,
  288. sender: Address,
  289. timestamps: list[int],
  290. when: int | None = None,
  291. read: bool = False,
  292. ) -> None:
  293. if not read:
  294. # TODO implement
  295. return
  296. await self.request_v1(
  297. "mark_read", account=username, timestamps=timestamps, when=when, to=sender.serialize()
  298. )
  299. async def list_accounts(self) -> list[Account]:
  300. resp = await self.request_v1("list_accounts")
  301. return [Account.deserialize(acc) for acc in resp.get("accounts", [])]
  302. async def delete_account(self, username: str, server: bool = False) -> None:
  303. await self.request_v1("delete_account", account=username, server=server)
  304. async def get_linked_devices(self, username: str) -> list[DeviceInfo]:
  305. resp = await self.request_v1("get_linked_devices", account=username)
  306. return [DeviceInfo.deserialize(dev) for dev in resp.get("devices", [])]
  307. async def add_linked_device(self, username: str, uri: str) -> None:
  308. await self.request_v1("add_device", account=username, uri=uri)
  309. async def remove_linked_device(self, username: str, device_id: int) -> None:
  310. await self.request_v1("remove_linked_device", account=username, deviceId=device_id)
  311. async def request_sync(
  312. self,
  313. username: str,
  314. blocked: bool = True,
  315. configuration: bool = True,
  316. contacts: bool = True,
  317. groups: bool = True,
  318. ) -> None:
  319. await self.request_v1(
  320. "request_sync",
  321. account=username,
  322. blocked=blocked,
  323. configuration=configuration,
  324. contacts=contacts,
  325. groups=groups,
  326. )
  327. async def list_contacts(self, username: str, use_cache: bool = False) -> list[Profile]:
  328. kwargs = {"async": use_cache}
  329. resp = await self.request_v1("list_contacts", account=username, **kwargs)
  330. return [Profile.deserialize(contact) for contact in resp["profiles"]]
  331. async def list_groups(self, username: str) -> list[GroupV2]:
  332. resp = await self.request_v1("list_groups", account=username)
  333. return [GroupV2.deserialize(group) for group in resp.get("groups", [])]
  334. async def join_group(self, username: str, uri: str) -> JoinGroupResponse:
  335. resp = await self.request_v1("join_group", account=username, uri=uri)
  336. return JoinGroupResponse.deserialize(resp)
  337. async def leave_group(self, username: str, group_id: GroupID) -> None:
  338. await self.request_v1("leave_group", account=username, groupID=group_id)
  339. async def ban_user(self, username: str, group_id: GroupID, users: list[Address]) -> GroupV2:
  340. serialized_users = [user.serialize() for user in (users or [])]
  341. resp = await self.request_v1(
  342. "ban_user", account=username, group_id=group_id, users=serialized_users
  343. )
  344. return GroupV2.deserialize(resp)
  345. async def unban_user(self, username: str, group_id: GroupID, users: list[Address]) -> GroupV2:
  346. serialized_users = [user.serialize() for user in (users or [])]
  347. resp = await self.request_v1(
  348. "unban_user", account=username, group_id=group_id, users=serialized_users
  349. )
  350. return GroupV2.deserialize(resp)
  351. async def approve_membership(
  352. self, username: str, group_id: GroupID, members: list[Address]
  353. ) -> GroupV2:
  354. serialized_members = [member.serialize() for member in (members or [])]
  355. resp = await self.request_v1(
  356. "approve_membership", account=username, groupID=group_id, members=serialized_members
  357. )
  358. return GroupV2.deserialize(resp)
  359. async def refuse_membership(
  360. self, username: str, group_id: GroupID, members: list[Address], also_ban: bool = False
  361. ) -> GroupV2:
  362. serialized_members = [member.serialize() for member in (members or [])]
  363. resp = await self.request_v1(
  364. "refuse_membership",
  365. account=username,
  366. group_id=group_id,
  367. members=serialized_members,
  368. also_ban=also_ban,
  369. )
  370. return GroupV2.deserialize(resp)
  371. async def update_group(
  372. self,
  373. username: str,
  374. group_id: GroupID,
  375. title: str | None = None,
  376. description: str | None = None,
  377. avatar_path: str | None = None,
  378. add_members: list[Address] | None = None,
  379. remove_members: list[Address] | None = None,
  380. update_access_control: GroupAccessControl | None = None,
  381. update_role: GroupMember | None = None,
  382. ) -> GroupV2 | None:
  383. update_params = {
  384. key: value
  385. for key, value in {
  386. "groupID": group_id,
  387. "avatar": avatar_path,
  388. "title": title,
  389. "description": description,
  390. "addMembers": [addr.serialize() for addr in add_members] if add_members else None,
  391. "removeMembers": (
  392. [addr.serialize() for addr in remove_members] if remove_members else None
  393. ),
  394. "updateAccessControl": (
  395. update_access_control.serialize() if update_access_control else None
  396. ),
  397. "updateRole": (update_role.serialize() if update_role else None),
  398. }.items()
  399. if value is not None
  400. }
  401. resp = await self.request_v1("update_group", account=username, **update_params)
  402. if "v2" in resp:
  403. return GroupV2.deserialize(resp["v2"])
  404. elif "v1" in resp:
  405. raise RuntimeError("v1 groups are no longer supported")
  406. else:
  407. return None
  408. async def accept_invitation(self, username: str, group_id: GroupID) -> GroupV2:
  409. resp = await self.request_v1("accept_invitation", account=username, groupID=group_id)
  410. return GroupV2.deserialize(resp)
  411. async def get_group(
  412. self, username: str, group_id: GroupID, revision: int = -1
  413. ) -> GroupV2 | None:
  414. resp = await self.request_v1(
  415. "get_group", account=username, groupID=group_id, revision=revision
  416. )
  417. if "id" not in resp:
  418. return None
  419. return GroupV2.deserialize(resp)
  420. async def create_group(
  421. self,
  422. username: str,
  423. title: str,
  424. avatar_path: str | None = None,
  425. member_role_administrator: bool = False,
  426. members: list[Address] | None = None,
  427. ) -> GroupV2 | None:
  428. create_params = {
  429. "avatar": avatar_path,
  430. "member_role": "ADMINISTRATOR" if member_role_administrator else "DEFAULT",
  431. "title": title,
  432. "members": [addr.serialize() for addr in members],
  433. }
  434. create_params = {k: v for k, v in create_params.items() if v is not None}
  435. resp = await self.request_v1("create_group", account=username, **create_params)
  436. if "id" not in resp:
  437. return None
  438. return GroupV2.deserialize(resp)
  439. async def get_profile(
  440. self, username: str, address: Address, use_cache: bool = False
  441. ) -> Profile | None:
  442. try:
  443. # async is a reserved keyword, so can't pass it as a normal parameter
  444. kwargs = {"async": use_cache}
  445. resp = await self.request_v1(
  446. "get_profile", account=username, address=address.serialize(), **kwargs
  447. )
  448. except UnexpectedResponse as e:
  449. if e.resp_type == "profile_not_available":
  450. return None
  451. raise
  452. return Profile.deserialize(resp)
  453. async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
  454. resp = await self.request_v1(
  455. "get_identities", account=username, address=address.serialize()
  456. )
  457. return GetIdentitiesResponse.deserialize(resp)
  458. async def set_profile(
  459. self, username: str, name: str | None = None, avatar_path: str | None = None
  460. ) -> None:
  461. args = {}
  462. if name is not None:
  463. args["name"] = name
  464. if avatar_path is not None:
  465. args["avatarFile"] = avatar_path
  466. await self.request_v1("set_profile", account=username, **args)
  467. async def trust(
  468. self,
  469. username: str,
  470. recipient: Address,
  471. trust_level: TrustLevel | str,
  472. safety_number: str | None = None,
  473. qr_code_data: str | None = None,
  474. ) -> None:
  475. args = {}
  476. if safety_number:
  477. if qr_code_data:
  478. raise ValueError("only one of safety_number and qr_code_data must be set")
  479. args["safety_number"] = safety_number
  480. elif qr_code_data:
  481. args["qr_code_data"] = qr_code_data
  482. else:
  483. raise ValueError("safety_number or qr_code_data is required")
  484. await self.request_v1(
  485. "trust",
  486. account=username,
  487. **args,
  488. trust_level=trust_level.value if isinstance(trust_level, TrustLevel) else trust_level,
  489. address=recipient.serialize(),
  490. )
  491. async def find_uuid(self, username: str, number: str) -> UUID | None:
  492. resp = await self.request_v1(
  493. "resolve_address", partial=Address(number=number).serialize(), account=username
  494. )
  495. return Address.deserialize(resp).uuid
  496. async def submit_challenge(
  497. self, username: str, captcha_token: str | None, challenge: str | None
  498. ) -> None:
  499. await self.request_v1(
  500. "submit_challenge", account=username, captcha_token=captcha_token, challenge=challenge
  501. )