signald.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # Copyright (c) 2020 Tulir Asokan
  2. #
  3. # This Source Code Form is subject to the terms of the Mozilla Public
  4. # License, v. 2.0. If a copy of the MPL was not distributed with this
  5. # file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Type, TypeVar, Union
  7. import asyncio
  8. from mautrix.util.logging import TraceLogger
  9. from .errors import UnexpectedError, UnexpectedResponse
  10. from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
  11. from .types import (
  12. Account,
  13. Address,
  14. Attachment,
  15. DeviceInfo,
  16. GetIdentitiesResponse,
  17. Group,
  18. GroupID,
  19. GroupV2,
  20. LinkSession,
  21. Mention,
  22. Message,
  23. Profile,
  24. Quote,
  25. Reaction,
  26. WebsocketConnectionState,
  27. WebsocketConnectionStateChangeEvent,
  28. )
# Type variable standing in for the event class a handler is registered for.
T = TypeVar("T")
# Signature of an event handler: an async callable that takes one event object
# and returns nothing.
EventHandler = Callable[[T], Awaitable[None]]
  31. class SignaldClient(SignaldRPCClient):
  32. _event_handlers: Dict[Type[T], List[EventHandler]]
  33. _subscriptions: Set[str]
  34. def __init__(
  35. self,
  36. socket_path: str = "/var/run/signald/signald.sock",
  37. log: Optional[TraceLogger] = None,
  38. loop: Optional[asyncio.AbstractEventLoop] = None,
  39. ) -> None:
  40. super().__init__(socket_path, log, loop)
  41. self._event_handlers = {}
  42. self._subscriptions = set()
  43. self.add_rpc_handler("message", self._parse_message)
  44. self.add_rpc_handler(
  45. "websocket_connection_state_change", self._websocket_connection_state_change
  46. )
  47. self.add_rpc_handler("version", self._log_version)
  48. self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
  49. self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
  50. def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  51. self._event_handlers.setdefault(event_class, []).append(handler)
  52. def remove_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  53. self._event_handlers.setdefault(event_class, []).remove(handler)
  54. async def _run_event_handler(self, event: T) -> None:
  55. try:
  56. handlers = self._event_handlers[type(event)]
  57. except KeyError:
  58. self.log.warning(f"No handlers for {type(event)}")
  59. else:
  60. for handler in handlers:
  61. try:
  62. await handler(event)
  63. except Exception:
  64. self.log.exception("Exception in event handler")
  65. async def _parse_message(self, data: Dict[str, Any]) -> None:
  66. event_type = data["type"]
  67. event_data = data["data"]
  68. event_class = {
  69. "message": Message,
  70. }[event_type]
  71. event = event_class.deserialize(event_data)
  72. await self._run_event_handler(event)
  73. async def _log_version(self, data: Dict[str, Any]) -> None:
  74. name = data["data"]["name"]
  75. version = data["data"]["version"]
  76. self.log.info(f"Connected to {name} v{version}")
  77. async def _websocket_connection_state_change(self, change_event: Dict[str, Any]) -> None:
  78. evt = WebsocketConnectionStateChangeEvent.deserialize(change_event["data"])
  79. await self._run_event_handler(evt)
  80. async def subscribe(self, username: str) -> bool:
  81. try:
  82. await self.request("subscribe", "subscribed", username=username)
  83. self._subscriptions.add(username)
  84. return True
  85. except UnexpectedError as e:
  86. self.log.debug("Failed to subscribe to %s: %s", username, e)
  87. evt = WebsocketConnectionStateChangeEvent(
  88. state=(
  89. WebsocketConnectionState.AUTHENTICATION_FAILED
  90. if str(e) == "[401] Authorization failed!"
  91. else WebsocketConnectionState.DISCONNECTED
  92. ),
  93. account=username,
  94. )
  95. await self._run_event_handler(evt)
  96. return False
  97. async def unsubscribe(self, username: str) -> bool:
  98. try:
  99. await self.request("unsubscribe", "unsubscribed", username=username)
  100. self._subscriptions.remove(username)
  101. return True
  102. except UnexpectedError as e:
  103. self.log.debug("Failed to unsubscribe from %s: %s", username, e)
  104. return False
  105. async def _resubscribe(self, unused_data: Dict[str, Any]) -> None:
  106. if self._subscriptions:
  107. self.log.debug("Resubscribing to users")
  108. for username in list(self._subscriptions):
  109. await self.subscribe(username)
  110. async def _on_disconnect(self, *_) -> None:
  111. if self._subscriptions:
  112. self.log.debug("Notifying of disconnection from users")
  113. for username in self._subscriptions:
  114. evt = WebsocketConnectionStateChangeEvent(
  115. state=WebsocketConnectionState.SOCKET_DISCONNECTED,
  116. account=username,
  117. exception="Disconnected from signald",
  118. )
  119. await self._run_event_handler(evt)
  120. async def register(
  121. self, phone: str, voice: bool = False, captcha: Optional[str] = None
  122. ) -> str:
  123. resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
  124. return resp["account_id"]
  125. async def verify(self, username: str, code: str) -> Account:
  126. resp = await self.request_v1("verify", account=username, code=code)
  127. return Account.deserialize(resp)
  128. async def start_link(self) -> LinkSession:
  129. return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
  130. async def finish_link(
  131. self, session_id: str, device_name: str = "mausignald", overwrite: bool = False
  132. ) -> Account:
  133. resp = await self.request_v1(
  134. "finish_link", device_name=device_name, session_id=session_id, overwrite=overwrite
  135. )
  136. return Account.deserialize(resp)
  137. @staticmethod
  138. def _recipient_to_args(
  139. recipient: Union[Address, GroupID], simple_name: bool = False
  140. ) -> Dict[str, Any]:
  141. if isinstance(recipient, Address):
  142. recipient = recipient.serialize()
  143. field_name = "address" if simple_name else "recipientAddress"
  144. else:
  145. field_name = "group" if simple_name else "recipientGroupId"
  146. return {field_name: recipient}
  147. async def react(
  148. self, username: str, recipient: Union[Address, GroupID], reaction: Reaction
  149. ) -> None:
  150. await self.request_v1(
  151. "react",
  152. username=username,
  153. reaction=reaction.serialize(),
  154. **self._recipient_to_args(recipient),
  155. )
  156. async def remote_delete(
  157. self, username: str, recipient: Union[Address, GroupID], timestamp: int
  158. ) -> None:
  159. await self.request_v1(
  160. "remote_delete",
  161. account=username,
  162. timestamp=timestamp,
  163. **self._recipient_to_args(recipient, simple_name=True),
  164. )
  165. async def send(
  166. self,
  167. username: str,
  168. recipient: Union[Address, GroupID],
  169. body: str,
  170. quote: Optional[Quote] = None,
  171. attachments: Optional[List[Attachment]] = None,
  172. mentions: Optional[List[Mention]] = None,
  173. timestamp: Optional[int] = None,
  174. ) -> None:
  175. serialized_quote = quote.serialize() if quote else None
  176. serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
  177. serialized_mentions = [mention.serialize() for mention in (mentions or [])]
  178. resp = await self.request_v1(
  179. "send",
  180. username=username,
  181. messageBody=body,
  182. attachments=serialized_attachments,
  183. quote=serialized_quote,
  184. mentions=serialized_mentions,
  185. timestamp=timestamp,
  186. **self._recipient_to_args(recipient),
  187. )
  188. errors = []
  189. # We handle unregisteredFailure a little differently than other errors. If there are no
  190. # successful sends, then we show an error with the unregisteredFailure details, otherwise
  191. # we ignore it.
  192. unregistered_failures = []
  193. successful_send_count = 0
  194. results = resp.get("results", [])
  195. for result in results:
  196. address = result.get("addres", {})
  197. number = address.get("number") or address.get("uuid")
  198. proof_required_failure = result.get("proof_required_failure")
  199. if result.get("networkFailure", False):
  200. errors.append(f"Network failure occurred while sending message to {number}.")
  201. elif result.get("unregisteredFailure", False):
  202. unregistered_failures.append(
  203. f"Unregistered failure occurred while sending message to {number}."
  204. )
  205. elif result.get("identityFailure", ""):
  206. errors.append(
  207. f"Identity failure occurred while sending message to {number}. New identity: "
  208. f"{result['identityFailure']}"
  209. )
  210. elif proof_required_failure:
  211. options = proof_required_failure.get("options")
  212. self.log.warning(
  213. f"Proof Required Failure {options}. "
  214. f"Retry after: {proof_required_failure.get('retry_after')}. "
  215. f"Token: {proof_required_failure.get('token')}. "
  216. f"Message: {proof_required_failure.get('message')}. "
  217. )
  218. errors.append(
  219. f"Proof required failure occurred while sending message to {number}. Message: "
  220. f"{proof_required_failure.get('message')}"
  221. )
  222. if "RECAPTCHA" in options:
  223. errors.append("RECAPTCHA required.")
  224. elif "PUSH_CHALLENGE" in options:
  225. # Just submit the challenge automatically.
  226. await self.request_v1("submit_challenge")
  227. else:
  228. successful_send_count += 1
  229. self.log.info(
  230. f"Successfully sent message to {successful_send_count}/{len(results)} users in {recipient}"
  231. )
  232. if errors or successful_send_count == 0:
  233. raise Exception("\n".join(errors + unregistered_failures))
  234. async def send_receipt(
  235. self,
  236. username: str,
  237. sender: Address,
  238. timestamps: List[int],
  239. when: Optional[int] = None,
  240. read: bool = False,
  241. ) -> None:
  242. if not read:
  243. # TODO implement
  244. return
  245. await self.request_v1(
  246. "mark_read", account=username, timestamps=timestamps, when=when, to=sender.serialize()
  247. )
  248. async def list_accounts(self) -> List[Account]:
  249. resp = await self.request_v1("list_accounts")
  250. return [Account.deserialize(acc) for acc in resp.get("accounts", [])]
  251. async def delete_account(self, username: str, server: bool = False) -> None:
  252. await self.request_v1("delete_account", account=username, server=server)
  253. async def get_linked_devices(self, username: str) -> List[DeviceInfo]:
  254. resp = await self.request_v1("get_linked_devices", account=username)
  255. return [DeviceInfo.deserialize(dev) for dev in resp.get("devices", [])]
  256. async def remove_linked_device(self, username: str, device_id: int) -> None:
  257. await self.request_v1("remove_linked_device", account=username, deviceId=device_id)
  258. async def list_contacts(self, username: str) -> List[Profile]:
  259. resp = await self.request_v1("list_contacts", account=username)
  260. return [Profile.deserialize(contact) for contact in resp["profiles"]]
  261. async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
  262. resp = await self.request_v1("list_groups", account=username)
  263. legacy = [Group.deserialize(group) for group in resp.get("legacyGroups", [])]
  264. v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
  265. return legacy + v2
  266. async def update_group(
  267. self,
  268. username: str,
  269. group_id: GroupID,
  270. title: Optional[str] = None,
  271. avatar_path: Optional[str] = None,
  272. add_members: Optional[List[Address]] = None,
  273. remove_members: Optional[List[Address]] = None,
  274. ) -> Union[Group, GroupV2, None]:
  275. update_params = {
  276. key: value
  277. for key, value in {
  278. "groupID": group_id,
  279. "avatar": avatar_path,
  280. "title": title,
  281. "addMembers": [addr.serialize() for addr in add_members] if add_members else None,
  282. "removeMembers": (
  283. [addr.serialize() for addr in remove_members] if remove_members else None
  284. ),
  285. }.items()
  286. if value is not None
  287. }
  288. resp = await self.request_v1("update_group", account=username, **update_params)
  289. if "v1" in resp:
  290. return Group.deserialize(resp["v1"])
  291. elif "v2" in resp:
  292. return GroupV2.deserialize(resp["v2"])
  293. else:
  294. return None
  295. async def accept_invitation(self, username: str, group_id: GroupID) -> GroupV2:
  296. resp = await self.request_v1("accept_invitation", account=username, groupID=group_id)
  297. return GroupV2.deserialize(resp)
  298. async def get_group(
  299. self, username: str, group_id: GroupID, revision: int = -1
  300. ) -> Optional[GroupV2]:
  301. resp = await self.request_v1(
  302. "get_group", account=username, groupID=group_id, revision=revision
  303. )
  304. if "id" not in resp:
  305. return None
  306. return GroupV2.deserialize(resp)
  307. async def get_profile(
  308. self, username: str, address: Address, use_cache: bool = False
  309. ) -> Optional[Profile]:
  310. try:
  311. # async is a reserved keyword, so can't pass it as a normal parameter
  312. kwargs = {"async": use_cache}
  313. resp = await self.request_v1(
  314. "get_profile", account=username, address=address.serialize(), **kwargs
  315. )
  316. except UnexpectedResponse as e:
  317. if e.resp_type == "profile_not_available":
  318. return None
  319. raise
  320. return Profile.deserialize(resp)
  321. async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
  322. resp = await self.request_v1(
  323. "get_identities", account=username, address=address.serialize()
  324. )
  325. return GetIdentitiesResponse.deserialize(resp)
  326. async def set_profile(
  327. self, username: str, name: Optional[str] = None, avatar_path: Optional[str] = None
  328. ) -> None:
  329. args = {}
  330. if name is not None:
  331. args["name"] = name
  332. if avatar_path is not None:
  333. args["avatarFile"] = avatar_path
  334. await self.request_v1("set_profile", account=username, **args)
  335. async def trust(
  336. self,
  337. username: str,
  338. recipient: Address,
  339. trust_level: str,
  340. safety_number: Optional[str] = None,
  341. qr_code_data: Optional[str] = None,
  342. ) -> None:
  343. args = {}
  344. if safety_number:
  345. if qr_code_data:
  346. raise ValueError("only one of safety_number and qr_code_data must be set")
  347. args["safety_number"] = safety_number
  348. elif qr_code_data:
  349. args["qr_code_data"] = qr_code_data
  350. else:
  351. raise ValueError("safety_number or qr_code_data is required")
  352. await self.request_v1(
  353. "trust",
  354. account=username,
  355. **args,
  356. trust_level=trust_level,
  357. address=recipient.serialize(),
  358. )