signald.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # Copyright (c) 2020 Tulir Asokan
  2. #
  3. # This Source Code Form is subject to the terms of the Mozilla Public
  4. # License, v. 2.0. If a copy of the MPL was not distributed with this
  5. # file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
  7. import asyncio
  8. from mautrix.util.logging import TraceLogger
  9. from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
  10. from .errors import UnexpectedError, UnexpectedResponse
  11. from .types import (
  12. Address,
  13. Quote,
  14. Attachment,
  15. Reaction,
  16. Account,
  17. Message,
  18. DeviceInfo,
  19. Group,
  20. Profile,
  21. GroupID,
  22. GetIdentitiesResponse,
  23. GroupV2,
  24. Mention,
  25. LinkSession,
  26. WebsocketConnectionState,
  27. WebsocketConnectionStateChangeEvent,
  28. )
# Type variable standing for an event class / event object in the handler
# type hints of SignaldClient (add_event_handler, _run_event_handler, ...).
T = TypeVar("T")
# An event handler: an async callable that receives one event object and returns None.
EventHandler = Callable[[T], Awaitable[None]]
  31. class SignaldClient(SignaldRPCClient):
  32. _event_handlers: Dict[Type[T], List[EventHandler]]
  33. _subscriptions: Set[str]
  34. def __init__(
  35. self,
  36. socket_path: str = "/var/run/signald/signald.sock",
  37. log: Optional[TraceLogger] = None,
  38. loop: Optional[asyncio.AbstractEventLoop] = None,
  39. ) -> None:
  40. super().__init__(socket_path, log, loop)
  41. self._event_handlers = {}
  42. self._subscriptions = set()
  43. self.add_rpc_handler("message", self._parse_message)
  44. self.add_rpc_handler(
  45. "websocket_connection_state_change", self._websocket_connection_state_change
  46. )
  47. self.add_rpc_handler("version", self._log_version)
  48. self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
  49. self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
  50. def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  51. self._event_handlers.setdefault(event_class, []).append(handler)
  52. def remove_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  53. self._event_handlers.setdefault(event_class, []).remove(handler)
  54. async def _run_event_handler(self, event: T) -> None:
  55. try:
  56. handlers = self._event_handlers[type(event)]
  57. except KeyError:
  58. self.log.warning(f"No handlers for {type(event)}")
  59. else:
  60. for handler in handlers:
  61. try:
  62. await handler(event)
  63. except Exception:
  64. self.log.exception("Exception in event handler")
  65. async def _parse_message(self, data: Dict[str, Any]) -> None:
  66. event_type = data["type"]
  67. event_data = data["data"]
  68. event_class = {
  69. "message": Message,
  70. }[event_type]
  71. event = event_class.deserialize(event_data)
  72. await self._run_event_handler(event)
  73. async def _log_version(self, data: Dict[str, Any]) -> None:
  74. name = data["data"]["name"]
  75. version = data["data"]["version"]
  76. self.log.info(f"Connected to {name} v{version}")
  77. async def _websocket_connection_state_change(self, change_event: Dict[str, Any]) -> None:
  78. evt = WebsocketConnectionStateChangeEvent.deserialize(change_event["data"])
  79. await self._run_event_handler(evt)
  80. async def subscribe(self, username: str) -> bool:
  81. try:
  82. await self.request("subscribe", "subscribed", username=username)
  83. self._subscriptions.add(username)
  84. return True
  85. except UnexpectedError as e:
  86. self.log.debug("Failed to subscribe to %s: %s", username, e)
  87. evt = WebsocketConnectionStateChangeEvent(
  88. state=(
  89. WebsocketConnectionState.AUTHENTICATION_FAILED
  90. if str(e) == "[401] Authorization failed!"
  91. else WebsocketConnectionState.DISCONNECTED
  92. ),
  93. account=username,
  94. )
  95. await self._run_event_handler(evt)
  96. return False
  97. async def unsubscribe(self, username: str) -> bool:
  98. try:
  99. await self.request("unsubscribe", "unsubscribed", username=username)
  100. self._subscriptions.remove(username)
  101. return True
  102. except UnexpectedError as e:
  103. self.log.debug("Failed to unsubscribe from %s: %s", username, e)
  104. return False
  105. async def _resubscribe(self, unused_data: Dict[str, Any]) -> None:
  106. if self._subscriptions:
  107. self.log.debug("Resubscribing to users")
  108. for username in list(self._subscriptions):
  109. await self.subscribe(username)
  110. async def _on_disconnect(self, *_) -> None:
  111. if self._subscriptions:
  112. self.log.debug("Notifying of disconnection from users")
  113. for username in self._subscriptions:
  114. evt = WebsocketConnectionStateChangeEvent(
  115. state=WebsocketConnectionState.SOCKET_DISCONNECTED,
  116. account=username,
  117. exception="Disconnected from signald",
  118. )
  119. await self._run_event_handler(evt)
  120. async def register(
  121. self, phone: str, voice: bool = False, captcha: Optional[str] = None
  122. ) -> str:
  123. resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
  124. return resp["account_id"]
  125. async def verify(self, username: str, code: str) -> Account:
  126. resp = await self.request_v1("verify", account=username, code=code)
  127. return Account.deserialize(resp)
  128. async def start_link(self) -> LinkSession:
  129. return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
  130. async def finish_link(
  131. self, session_id: str, device_name: str = "mausignald", overwrite: bool = False
  132. ) -> Account:
  133. resp = await self.request_v1(
  134. "finish_link", device_name=device_name, session_id=session_id, overwrite=overwrite
  135. )
  136. return Account.deserialize(resp)
  137. @staticmethod
  138. def _recipient_to_args(
  139. recipient: Union[Address, GroupID], simple_name: bool = False
  140. ) -> Dict[str, Any]:
  141. if isinstance(recipient, Address):
  142. recipient = recipient.serialize()
  143. field_name = "address" if simple_name else "recipientAddress"
  144. else:
  145. field_name = "group" if simple_name else "recipientGroupId"
  146. return {field_name: recipient}
  147. async def react(
  148. self, username: str, recipient: Union[Address, GroupID], reaction: Reaction
  149. ) -> None:
  150. await self.request_v1(
  151. "react",
  152. username=username,
  153. reaction=reaction.serialize(),
  154. **self._recipient_to_args(recipient),
  155. )
  156. async def remote_delete(
  157. self, username: str, recipient: Union[Address, GroupID], timestamp: int
  158. ) -> None:
  159. await self.request_v1(
  160. "remote_delete",
  161. account=username,
  162. timestamp=timestamp,
  163. **self._recipient_to_args(recipient, simple_name=True),
  164. )
  165. async def send(
  166. self,
  167. username: str,
  168. recipient: Union[Address, GroupID],
  169. body: str,
  170. quote: Optional[Quote] = None,
  171. attachments: Optional[List[Attachment]] = None,
  172. mentions: Optional[List[Mention]] = None,
  173. timestamp: Optional[int] = None,
  174. ) -> None:
  175. serialized_quote = quote.serialize() if quote else None
  176. serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
  177. serialized_mentions = [mention.serialize() for mention in (mentions or [])]
  178. resp = await self.request_v1(
  179. "send",
  180. username=username,
  181. messageBody=body,
  182. attachments=serialized_attachments,
  183. quote=serialized_quote,
  184. mentions=serialized_mentions,
  185. timestamp=timestamp,
  186. **self._recipient_to_args(recipient),
  187. )
  188. errors = []
  189. # We handle unregisteredFailure a little differently than other errors. If there are no
  190. # successful sends, then we show an error with the unregisteredFailure details, otherwise
  191. # we ignore it.
  192. unregistered_failures = []
  193. successful_send_count = 0
  194. results = resp.get("results", [])
  195. for result in results:
  196. address = result.get("addres", {})
  197. number = address.get("number") or address.get("uuid")
  198. proof_required_failure = result.get("proof_required_failure")
  199. if result.get("networkFailure", False):
  200. errors.append(f"Network failure occurred while sending message to {number}.")
  201. elif result.get("unregisteredFailure", False):
  202. unregistered_failures.append(
  203. f"Unregistered failure occurred while sending message to {number}."
  204. )
  205. elif result.get("identityFailure", ""):
  206. errors.append(
  207. f"Identity failure occurred while sending message to {number}. New identity: "
  208. f"{result['identityFailure']}"
  209. )
  210. elif proof_required_failure:
  211. options = proof_required_failure.get("options")
  212. self.log.warning(
  213. f"Proof Required Failure {options}. "
  214. f"Retry after: {proof_required_failure.get('retry_after')}. "
  215. f"Token: {proof_required_failure.get('token')}. "
  216. f"Message: {proof_required_failure.get('message')}. "
  217. )
  218. errors.append(
  219. f"Proof required failure occurred while sending message to {number}. Message: "
  220. f"{proof_required_failure.get('message')}"
  221. )
  222. if "RECAPTCHA" in options:
  223. errors.append("RECAPTCHA required.")
  224. elif "PUSH_CHALLENGE" in options:
  225. # Just submit the challenge automatically.
  226. await self.request_v1("submit_challenge")
  227. else:
  228. successful_send_count += 1
  229. self.log.info(
  230. f"Successfully sent message to {successful_send_count}/{len(results)} users in {recipient}"
  231. )
  232. if errors or successful_send_count == 0:
  233. raise Exception("\n".join(errors + unregistered_failures))
  234. async def send_receipt(
  235. self,
  236. username: str,
  237. sender: Address,
  238. timestamps: List[int],
  239. when: Optional[int] = None,
  240. read: bool = False,
  241. ) -> None:
  242. if not read:
  243. # TODO implement
  244. return
  245. await self.request_v1(
  246. "mark_read", account=username, timestamps=timestamps, when=when, to=sender.serialize()
  247. )
  248. async def list_accounts(self) -> List[Account]:
  249. resp = await self.request_v1("list_accounts")
  250. return [Account.deserialize(acc) for acc in resp.get("accounts", [])]
  251. async def delete_account(self, username: str, server: bool = False) -> None:
  252. await self.request_v1("delete_account", account=username, server=server)
  253. async def get_linked_devices(self, username: str) -> List[DeviceInfo]:
  254. resp = await self.request_v1("get_linked_devices", account=username)
  255. return [DeviceInfo.deserialize(dev) for dev in resp.get("devices", [])]
  256. async def remove_linked_device(self, username: str, device_id: int) -> None:
  257. await self.request_v1("remove_linked_device", account=username, deviceId=device_id)
  258. async def list_contacts(self, username: str) -> List[Profile]:
  259. resp = await self.request_v1("list_contacts", account=username)
  260. return [Profile.deserialize(contact) for contact in resp["profiles"]]
  261. async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
  262. resp = await self.request_v1("list_groups", account=username)
  263. legacy = [Group.deserialize(group) for group in resp.get("legacyGroups", [])]
  264. v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
  265. return legacy + v2
  266. async def update_group(
  267. self,
  268. username: str,
  269. group_id: GroupID,
  270. title: Optional[str] = None,
  271. avatar_path: Optional[str] = None,
  272. add_members: Optional[List[Address]] = None,
  273. remove_members: Optional[List[Address]] = None,
  274. ) -> Union[Group, GroupV2, None]:
  275. update_params = {
  276. key: value
  277. for key, value in {
  278. "groupID": group_id,
  279. "avatar": avatar_path,
  280. "title": title,
  281. "addMembers": [addr.serialize() for addr in add_members] if add_members else None,
  282. "removeMembers": (
  283. [addr.serialize() for addr in remove_members] if remove_members else None
  284. ),
  285. }.items()
  286. if value is not None
  287. }
  288. resp = await self.request_v1("update_group", account=username, **update_params)
  289. if "v1" in resp:
  290. return Group.deserialize(resp["v1"])
  291. elif "v2" in resp:
  292. return GroupV2.deserialize(resp["v2"])
  293. else:
  294. return None
  295. async def accept_invitation(self, username: str, group_id: GroupID) -> GroupV2:
  296. resp = await self.request_v1("accept_invitation", account=username, groupID=group_id)
  297. return GroupV2.deserialize(resp)
  298. async def get_group(
  299. self, username: str, group_id: GroupID, revision: int = -1
  300. ) -> Optional[GroupV2]:
  301. resp = await self.request_v1(
  302. "get_group", account=username, groupID=group_id, revision=revision
  303. )
  304. if "id" not in resp:
  305. return None
  306. return GroupV2.deserialize(resp)
  307. async def get_profile(
  308. self, username: str, address: Address, use_cache: bool = False
  309. ) -> Optional[Profile]:
  310. try:
  311. # async is a reserved keyword, so can't pass it as a normal parameter
  312. kwargs = {"async": use_cache}
  313. resp = await self.request_v1(
  314. "get_profile", account=username, address=address.serialize(), **kwargs
  315. )
  316. except UnexpectedResponse as e:
  317. if e.resp_type == "profile_not_available":
  318. return None
  319. raise
  320. return Profile.deserialize(resp)
  321. async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
  322. resp = await self.request_v1(
  323. "get_identities", account=username, address=address.serialize()
  324. )
  325. return GetIdentitiesResponse.deserialize(resp)
  326. async def set_profile(
  327. self, username: str, name: Optional[str] = None, avatar_path: Optional[str] = None
  328. ) -> None:
  329. args = {}
  330. if name is not None:
  331. args["name"] = name
  332. if avatar_path is not None:
  333. args["avatarFile"] = avatar_path
  334. await self.request_v1("set_profile", account=username, **args)
  335. async def trust(
  336. self,
  337. username: str,
  338. recipient: Address,
  339. trust_level: str,
  340. safety_number: Optional[str] = None,
  341. qr_code_data: Optional[str] = None,
  342. ) -> None:
  343. args = {}
  344. if safety_number:
  345. if qr_code_data:
  346. raise ValueError("only one of safety_number and qr_code_data must be set")
  347. args["safety_number"] = safety_number
  348. elif qr_code_data:
  349. args["qr_code_data"] = qr_code_data
  350. else:
  351. raise ValueError("safety_number or qr_code_data is required")
  352. await self.request_v1(
  353. "trust",
  354. account=username,
  355. **args,
  356. trust_level=trust_level,
  357. address=recipient.serialize(),
  358. )