signald.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # Copyright (c) 2020 Tulir Asokan
  2. #
  3. # This Source Code Form is subject to the terms of the Mozilla Public
  4. # License, v. 2.0. If a copy of the MPL was not distributed with this
  5. # file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
  7. import asyncio
  8. from mautrix.util.logging import TraceLogger
  9. from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
  10. from .errors import UnexpectedError, UnexpectedResponse
  11. from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
  12. Profile, GroupID, GetIdentitiesResponse, GroupV2, Mention, LinkSession,
  13. WebsocketConnectionState, WebsocketConnectionStateChangeEvent)
  14. T = TypeVar('T')
  15. EventHandler = Callable[[T], Awaitable[None]]
class SignaldClient(SignaldRPCClient):
    """Async client for the signald socket built on SignaldRPCClient.

    Adds typed wrappers around signald requests (registration, linking, sending,
    contacts, groups, profiles, identities) and dispatches deserialized events to
    handlers registered with add_event_handler().
    """
    # Event class -> handlers awaited, in registration order, for events whose type is
    # exactly that class.
    _event_handlers: Dict[Type[T], List[EventHandler]]
    # Usernames successfully subscribed; _resubscribe() replays them on reconnect.
    _subscriptions: Set[str]
  19. def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
  20. log: Optional[TraceLogger] = None,
  21. loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
  22. super().__init__(socket_path, log, loop)
  23. self._event_handlers = {}
  24. self._subscriptions = set()
  25. self.add_rpc_handler("message", self._parse_message)
  26. self.add_rpc_handler("websocket_connection_state_change",
  27. self._websocket_connection_state_change)
  28. self.add_rpc_handler("version", self._log_version)
  29. self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
  30. self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
  31. def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  32. self._event_handlers.setdefault(event_class, []).append(handler)
  33. def remove_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  34. self._event_handlers.setdefault(event_class, []).remove(handler)
  35. async def _run_event_handler(self, event: T) -> None:
  36. try:
  37. handlers = self._event_handlers[type(event)]
  38. except KeyError:
  39. self.log.warning(f"No handlers for {type(event)}")
  40. else:
  41. for handler in handlers:
  42. try:
  43. await handler(event)
  44. except Exception:
  45. self.log.exception("Exception in event handler")
  46. async def _parse_message(self, data: Dict[str, Any]) -> None:
  47. event_type = data["type"]
  48. event_data = data["data"]
  49. event_class = {
  50. "message": Message,
  51. }[event_type]
  52. event = event_class.deserialize(event_data)
  53. await self._run_event_handler(event)
  54. async def _log_version(self, data: Dict[str, Any]) -> None:
  55. name = data["data"]["name"]
  56. version = data["data"]["version"]
  57. self.log.info(f"Connected to {name} v{version}")
  58. async def _websocket_connection_state_change(self, change_event: Dict[str, Any]) -> None:
  59. evt = WebsocketConnectionStateChangeEvent.deserialize(change_event["data"])
  60. await self._run_event_handler(evt)
  61. async def subscribe(self, username: str) -> bool:
  62. try:
  63. await self.request("subscribe", "subscribed", username=username)
  64. self._subscriptions.add(username)
  65. return True
  66. except UnexpectedError as e:
  67. self.log.debug("Failed to subscribe to %s: %s", username, e)
  68. evt = WebsocketConnectionStateChangeEvent(
  69. state=(
  70. WebsocketConnectionState.AUTHENTICATION_FAILED
  71. if str(e) == "[401] Authorization failed!"
  72. else WebsocketConnectionState.DISCONNECTED
  73. ),
  74. account=username,
  75. )
  76. await self._run_event_handler(evt)
  77. return False
  78. async def unsubscribe(self, username: str) -> bool:
  79. try:
  80. await self.request("unsubscribe", "unsubscribed", username=username)
  81. self._subscriptions.remove(username)
  82. return True
  83. except UnexpectedError as e:
  84. self.log.debug("Failed to unsubscribe from %s: %s", username, e)
  85. return False
  86. async def _resubscribe(self, unused_data: Dict[str, Any]) -> None:
  87. if self._subscriptions:
  88. self.log.debug("Resubscribing to users")
  89. for username in list(self._subscriptions):
  90. await self.subscribe(username)
  91. async def _on_disconnect(self, *_) -> None:
  92. if self._subscriptions:
  93. self.log.debug("Notifying of disconnection from users")
  94. for username in self._subscriptions:
  95. evt = WebsocketConnectionStateChangeEvent(
  96. state=WebsocketConnectionState.SOCKET_DISCONNECTED,
  97. account=username,
  98. exception="Disconnected from signald"
  99. )
  100. await self._run_event_handler(evt)
  101. async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
  102. ) -> str:
  103. resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
  104. return resp["account_id"]
  105. async def verify(self, username: str, code: str) -> Account:
  106. resp = await self.request_v1("verify", account=username, code=code)
  107. return Account.deserialize(resp)
  108. async def start_link(self) -> LinkSession:
  109. return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
  110. async def finish_link(self, session_id: str, device_name: str = "mausignald",
  111. overwrite: bool = False) -> Account:
  112. resp = await self.request_v1("finish_link", device_name=device_name, session_id=session_id,
  113. overwrite=overwrite)
  114. return Account.deserialize(resp)
  115. @staticmethod
  116. def _recipient_to_args(recipient: Union[Address, GroupID], simple_name: bool = False
  117. ) -> Dict[str, Any]:
  118. if isinstance(recipient, Address):
  119. recipient = recipient.serialize()
  120. field_name = "address" if simple_name else "recipientAddress"
  121. else:
  122. field_name = "group" if simple_name else "recipientGroupId"
  123. return {field_name: recipient}
  124. async def react(self, username: str, recipient: Union[Address, GroupID],
  125. reaction: Reaction) -> None:
  126. await self.request_v1("react", username=username, reaction=reaction.serialize(),
  127. **self._recipient_to_args(recipient))
  128. async def remote_delete(self, username: str, recipient: Union[Address, GroupID], timestamp: int
  129. ) -> None:
  130. await self.request_v1("remote_delete", account=username, timestamp=timestamp,
  131. **self._recipient_to_args(recipient, simple_name=True))
  132. async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
  133. quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
  134. mentions: Optional[List[Mention]] = None, timestamp: Optional[int] = None
  135. ) -> None:
  136. serialized_quote = quote.serialize() if quote else None
  137. serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
  138. serialized_mentions = [mention.serialize() for mention in (mentions or [])]
  139. resp = await self.request_v1("send", username=username, messageBody=body,
  140. attachments=serialized_attachments, quote=serialized_quote,
  141. mentions=serialized_mentions, timestamp=timestamp,
  142. **self._recipient_to_args(recipient))
  143. errors = []
  144. # We handle unregisteredFailure a little differently than other errors. If there are no
  145. # successful sends, then we show an error with the unregisteredFailure details, otherwise
  146. # we ignore it.
  147. unregistered_failures = []
  148. successful_send_count = 0
  149. results = resp.get("results", [])
  150. for result in results:
  151. number = (
  152. result.get("address", {}).get("number") or result.get("address", {}).get("uuid")
  153. )
  154. proof_required_failure = result.get("proof_required_failure")
  155. if result.get("networkFailure", False):
  156. errors.append(f"Network failure occurred while sending message to {number}.")
  157. elif result.get("unregisteredFailure", False):
  158. unregistered_failures.append(
  159. f"Unregistered failure occurred while sending message to {number}."
  160. )
  161. elif result.get("identityFailure", ""):
  162. errors.append(
  163. f"Identity failure occurred while sending message to {number}. New identity: "
  164. f"{result['identityFailure']}")
  165. elif proof_required_failure:
  166. options = proof_required_failure.get('options')
  167. self.log.warning(
  168. f"Proof Required Failure {options}. "
  169. f"Retry after: {proof_required_failure.get('retry_after')}. "
  170. f"Token: {proof_required_failure.get('token')}. "
  171. f"Message: {proof_required_failure.get('message')}. "
  172. )
  173. errors.append(
  174. f"Proof required failure occurred while sending message to {number}. Message: "
  175. f"{proof_required_failure.get('message')}"
  176. )
  177. if "RECAPTCHA" in options:
  178. errors.append("RECAPTCHA required.")
  179. elif "PUSH_CHALLENGE" in options:
  180. # Just submit the challenge automatically.
  181. await self.request_v1("submit_challenge")
  182. else:
  183. successful_send_count += 1
  184. self.log.info(f"Successfully sent message to {successful_send_count}/{len(results)} users in {recipient}")
  185. if errors or successful_send_count == 0:
  186. raise Exception("\n".join(errors + unregistered_failures))
  187. async def send_receipt(self, username: str, sender: Address, timestamps: List[int],
  188. when: Optional[int] = None, read: bool = False) -> None:
  189. if not read:
  190. # TODO implement
  191. return
  192. await self.request_v1("mark_read", account=username, timestamps=timestamps, when=when,
  193. to=sender.serialize())
  194. async def list_accounts(self) -> List[Account]:
  195. resp = await self.request_v1("list_accounts")
  196. return [Account.deserialize(acc) for acc in resp.get("accounts", [])]
  197. async def delete_account(self, username: str, server: bool = False) -> None:
  198. await self.request_v1("delete_account", account=username, server=server)
  199. async def get_linked_devices(self, username: str) -> List[DeviceInfo]:
  200. resp = await self.request_v1("get_linked_devices", account=username)
  201. return [DeviceInfo.deserialize(dev) for dev in resp.get("devices", [])]
  202. async def remove_linked_device(self, username: str, device_id: int) -> None:
  203. await self.request_v1("remove_linked_device", account=username, deviceId=device_id)
  204. async def list_contacts(self, username: str) -> List[Profile]:
  205. resp = await self.request_v1("list_contacts", account=username)
  206. return [Profile.deserialize(contact) for contact in resp["profiles"]]
  207. async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
  208. resp = await self.request_v1("list_groups", account=username)
  209. legacy = [Group.deserialize(group) for group in resp.get("legacyGroups", [])]
  210. v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
  211. return legacy + v2
  212. async def update_group(self, username: str, group_id: GroupID, title: Optional[str] = None,
  213. avatar_path: Optional[str] = None,
  214. add_members: Optional[List[Address]] = None,
  215. remove_members: Optional[List[Address]] = None
  216. ) -> Union[Group, GroupV2, None]:
  217. update_params = {key: value for key, value in {
  218. "groupID": group_id,
  219. "avatar": avatar_path,
  220. "title": title,
  221. "addMembers": [addr.serialize() for addr in add_members] if add_members else None,
  222. "removeMembers": ([addr.serialize() for addr in remove_members]
  223. if remove_members else None),
  224. }.items() if value is not None}
  225. resp = await self.request_v1("update_group", account=username, **update_params)
  226. if "v1" in resp:
  227. return Group.deserialize(resp["v1"])
  228. elif "v2" in resp:
  229. return GroupV2.deserialize(resp["v2"])
  230. else:
  231. return None
  232. async def accept_invitation(self, username: str, group_id: GroupID) -> GroupV2:
  233. resp = await self.request_v1("accept_invitation", account=username, groupID=group_id)
  234. return GroupV2.deserialize(resp)
  235. async def get_group(self, username: str, group_id: GroupID, revision: int = -1
  236. ) -> Optional[GroupV2]:
  237. resp = await self.request_v1("get_group", account=username, groupID=group_id,
  238. revision=revision)
  239. if "id" not in resp:
  240. return None
  241. return GroupV2.deserialize(resp)
  242. async def get_profile(self, username: str, address: Address, use_cache: bool = False
  243. ) -> Optional[Profile]:
  244. try:
  245. # async is a reserved keyword, so can't pass it as a normal parameter
  246. kwargs = {"async": use_cache}
  247. resp = await self.request_v1("get_profile", account=username,
  248. address=address.serialize(), **kwargs)
  249. except UnexpectedResponse as e:
  250. if e.resp_type == "profile_not_available":
  251. return None
  252. raise
  253. return Profile.deserialize(resp)
  254. async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
  255. resp = await self.request_v1("get_identities", account=username,
  256. address=address.serialize())
  257. return GetIdentitiesResponse.deserialize(resp)
  258. async def set_profile(self, username: str, name: Optional[str] = None,
  259. avatar_path: Optional[str] = None) -> None:
  260. args = {}
  261. if name is not None:
  262. args["name"] = name
  263. if avatar_path is not None:
  264. args["avatarFile"] = avatar_path
  265. await self.request_v1("set_profile", account=username, **args)
    async def trust(self, username: str, recipient: Address, trust_level: str,
                    safety_number: Optional[str] = None, qr_code_data: Optional[str] = None
                    ) -> None:
        """Set the trust level of ``recipient``'s identity key for ``username``.

        The identity is identified by exactly one of ``safety_number`` or
        ``qr_code_data``; ``trust_level`` is forwarded to signald as-is.

        Raises:
            ValueError: if both or neither of ``safety_number`` and ``qr_code_data``
                are given. Empty strings count as "not given".
        """
        args = {}
        if safety_number:
            if qr_code_data:
                raise ValueError("only one of safety_number and qr_code_data must be set")
            args["safety_number"] = safety_number
        elif qr_code_data:
            args["qr_code_data"] = qr_code_data
        else:
            raise ValueError("safety_number or qr_code_data is required")
        await self.request_v1("trust", account=username, **args, trust_level=trust_level,
                              address=recipient.serialize())