signald.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright (c) 2020 Tulir Asokan
  2. #
  3. # This Source Code Form is subject to the terms of the Mozilla Public
  4. # License, v. 2.0. If a copy of the MPL was not distributed with this
  5. # file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, TypeVar, Type
  7. from uuid import uuid4
  8. import asyncio
  9. from mautrix.util.logging import TraceLogger
  10. from .rpc import SignaldRPCClient
  11. from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
  12. from .types import Address, Quote, Attachment, Reaction, Account, Message, Contact, Group, Profile
  13. T = TypeVar('T')
  14. EventHandler = Callable[[T], Awaitable[None]]
  15. class SignaldClient(SignaldRPCClient):
  16. _event_handlers: Dict[Type[T], List[EventHandler]]
  17. def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
  18. log: Optional[TraceLogger] = None,
  19. loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
  20. super().__init__(socket_path, log, loop)
  21. self._event_handlers = {}
  22. self.add_rpc_handler("message", self._parse_message)
  23. def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  24. self._event_handlers.setdefault(event_class, []).append(handler)
  25. def remove_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  26. self._event_handlers.setdefault(event_class, []).remove(handler)
  27. async def _run_event_handler(self, event: T) -> None:
  28. try:
  29. handlers = self._event_handlers[type(event)]
  30. except KeyError:
  31. self.log.warning(f"No handlers for {type(event)}")
  32. else:
  33. for handler in handlers:
  34. try:
  35. await handler(event)
  36. except Exception:
  37. self.log.exception("Exception in event handler")
  38. async def _parse_message(self, data: Dict[str, Any]) -> None:
  39. event_type = data["type"]
  40. event_data = data["data"]
  41. event_class = {
  42. "message": Message,
  43. }[event_type]
  44. event = event_class.deserialize(event_data)
  45. await self._run_event_handler(event)
  46. async def subscribe(self, username: str) -> bool:
  47. try:
  48. await self.request("subscribe", "subscribed", username=username)
  49. return True
  50. except UnexpectedError as e:
  51. self.log.debug("Failed to subscribe to %s: %s", username, e)
  52. return False
  53. async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
  54. ) -> str:
  55. resp = await self.request("register", "verification_required", username=phone,
  56. voice=voice, captcha=captcha)
  57. return resp["username"]
  58. async def verify(self, username: str, code: str) -> Account:
  59. resp = await self.request("verify", "verification_succeeded", username=username, code=code)
  60. return Account.deserialize(resp)
  61. async def link(self, url_callback: Callable[[str], Awaitable[None]],
  62. device_name: str = "mausignald") -> Account:
  63. req_id = uuid4()
  64. resp_type, resp = await self._raw_request("link", req_id, deviceName=device_name)
  65. if resp_type == "linking_error":
  66. raise make_linking_error(resp)
  67. elif resp_type != "linking_uri":
  68. raise UnexpectedResponse(resp_type, resp)
  69. self.loop.create_task(url_callback(resp["uri"]))
  70. resp_type, resp = await self._wait_response(req_id)
  71. if resp_type == "linking_error":
  72. raise make_linking_error(resp)
  73. elif resp_type != "linking_successful":
  74. raise UnexpectedResponse(resp_type, resp)
  75. return Account.deserialize(resp)
  76. async def list_accounts(self) -> List[Account]:
  77. data = await self.request("list_accounts", "account_list")
  78. return [Account.deserialize(acc) for acc in data["accounts"]]
  79. @staticmethod
  80. def _recipient_to_args(recipient: Union[Address, str]) -> Dict[str, Any]:
  81. if isinstance(recipient, Address):
  82. return {"recipientAddress": recipient.serialize()}
  83. else:
  84. return {"recipientGroupId": recipient}
  85. async def react(self, username: str, recipient: Union[Address, str],
  86. reaction: Reaction) -> None:
  87. await self.request("react", "send_results", username=username,
  88. reaction=reaction.serialize(),
  89. **self._recipient_to_args(recipient))
  90. async def send(self, username: str, recipient: Union[Address, str], body: str,
  91. quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
  92. timestamp: Optional[int] = None) -> None:
  93. serialized_quote = quote.serialize() if quote else None
  94. serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
  95. await self.request("send", "send_results", username=username, messageBody=body,
  96. attachments=serialized_attachments, quote=serialized_quote,
  97. timestamp=timestamp, **self._recipient_to_args(recipient))
  98. # TODO return something?
  99. async def send_receipt(self, username: str, sender: Address, timestamps: List[int],
  100. when: Optional[int] = None, read: bool = False) -> None:
  101. await self.request_nowait("mark_read" if read else "mark_delivered", username=username,
  102. timestamps=timestamps, when=when,
  103. recipientAddress=sender.serialize())
  104. async def list_contacts(self, username: str) -> List[Contact]:
  105. contacts = await self.request("list_contacts", "contact_list", username=username)
  106. return [Contact.deserialize(contact) for contact in contacts]
  107. async def list_groups(self, username: str) -> List[Group]:
  108. resp = await self.request("list_groups", "group_list", username=username)
  109. return [Group.deserialize(group) for group in resp["groups"]]
  110. async def get_profile(self, username: str, address: Address) -> Optional[Profile]:
  111. try:
  112. resp = await self.request("get_profile", "profile", username=username,
  113. recipientAddress=address.serialize())
  114. except UnexpectedResponse as e:
  115. if e.resp_type == "profile_not_available":
  116. return None
  117. raise
  118. return Profile.deserialize(resp)
  119. async def set_profile(self, username: str, new_name: str) -> None:
  120. await self.request("set_profile", "profile_set", username=username, name=new_name)