signald.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # Copyright (c) 2020 Tulir Asokan
  2. #
  3. # This Source Code Form is subject to the terms of the Mozilla Public
  4. # License, v. 2.0. If a copy of the MPL was not distributed with this
  5. # file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, TypeVar, Type
  7. from uuid import uuid4
  8. import asyncio
  9. from mautrix.util.logging import TraceLogger
  10. from .rpc import SignaldRPCClient
  11. from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
  12. from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
  13. Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2)
# Generic placeholder for the class of event a handler is registered for.
T = TypeVar('T')
# An event handler is an async callable that takes one event object and returns nothing.
EventHandler = Callable[[T], Awaitable[None]]
class SignaldClient(SignaldRPCClient):
    """signald RPC client that dispatches incoming payloads as typed events.

    On top of ``SignaldRPCClient`` this class keeps a registry of async event handlers
    keyed by event class, and provides coroutine helpers that wrap individual signald
    requests.
    """
    # Registered event handlers, keyed by the event class they were added for.
    _event_handlers: Dict[Type[T], List[EventHandler]]
  18. def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
  19. log: Optional[TraceLogger] = None,
  20. loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
  21. super().__init__(socket_path, log, loop)
  22. self._event_handlers = {}
  23. self.add_rpc_handler("message", self._parse_message)
  24. self.add_rpc_handler("listen_started", self._parse_listen_start)
  25. self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
  26. self.add_rpc_handler("version", self._log_version)
  27. def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  28. self._event_handlers.setdefault(event_class, []).append(handler)
  29. def remove_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  30. self._event_handlers.setdefault(event_class, []).remove(handler)
  31. async def _run_event_handler(self, event: T) -> None:
  32. try:
  33. handlers = self._event_handlers[type(event)]
  34. except KeyError:
  35. self.log.warning(f"No handlers for {type(event)}")
  36. else:
  37. for handler in handlers:
  38. try:
  39. await handler(event)
  40. except Exception:
  41. self.log.exception("Exception in event handler")
  42. async def _parse_message(self, data: Dict[str, Any]) -> None:
  43. event_type = data["type"]
  44. event_data = data["data"]
  45. event_class = {
  46. "message": Message,
  47. }[event_type]
  48. event = event_class.deserialize(event_data)
  49. await self._run_event_handler(event)
  50. async def _log_version(self, data: Dict[str, Any]) -> None:
  51. name = data["data"]["name"]
  52. version = data["data"]["version"]
  53. self.log.info(f"Connected to {name} v{version}")
  54. async def _parse_listen_start(self, data: Dict[str, Any]) -> None:
  55. evt = ListenEvent(action=ListenAction.STARTED, username=data["data"])
  56. await self._run_event_handler(evt)
  57. async def _parse_listen_stop(self, data: Dict[str, Any]) -> None:
  58. evt = ListenEvent(action=ListenAction.STOPPED, username=data["data"],
  59. exception=data.get("exception", None))
  60. await self._run_event_handler(evt)
  61. async def subscribe(self, username: str) -> bool:
  62. try:
  63. await self.request("subscribe", "subscribed", username=username)
  64. return True
  65. except UnexpectedError as e:
  66. self.log.debug("Failed to subscribe to %s: %s", username, e)
  67. return False
  68. async def unsubscribe(self, username: str) -> bool:
  69. try:
  70. await self.request("unsubscribe", "unsubscribed", username=username)
  71. return True
  72. except UnexpectedError as e:
  73. self.log.debug("Failed to unsubscribe from %s: %s", username, e)
  74. return False
  75. async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
  76. ) -> str:
  77. resp = await self.request("register", "verification_required", username=phone,
  78. voice=voice, captcha=captcha)
  79. return resp["username"]
  80. async def verify(self, username: str, code: str) -> Account:
  81. resp = await self.request("verify", "verification_succeeded", username=username, code=code)
  82. return Account.deserialize(resp)
  83. async def link(self, url_callback: Callable[[str], Awaitable[None]],
  84. device_name: str = "mausignald") -> Account:
  85. req_id = uuid4()
  86. resp_type, resp = await self._raw_request("link", req_id, deviceName=device_name)
  87. if resp_type == "linking_error":
  88. raise make_linking_error(resp)
  89. elif resp_type != "linking_uri":
  90. raise UnexpectedResponse(resp_type, resp)
  91. self.loop.create_task(url_callback(resp["uri"]))
  92. resp_type, resp = await self._wait_response(req_id)
  93. if resp_type == "linking_error":
  94. raise make_linking_error(resp)
  95. elif resp_type != "linking_successful":
  96. raise UnexpectedResponse(resp_type, resp)
  97. return Account.deserialize(resp)
  98. async def list_accounts(self) -> List[Account]:
  99. data = await self.request("list_accounts", "account_list")
  100. return [Account.deserialize(acc) for acc in data["accounts"]]
  101. @staticmethod
  102. def _recipient_to_args(recipient: Union[Address, GroupID]) -> Dict[str, Any]:
  103. if isinstance(recipient, Address):
  104. return {"recipientAddress": recipient.serialize()}
  105. else:
  106. return {"recipientGroupId": recipient}
  107. async def react(self, username: str, recipient: Union[Address, GroupID],
  108. reaction: Reaction) -> None:
  109. await self.request("react", "send_results", username=username,
  110. reaction=reaction.serialize(),
  111. **self._recipient_to_args(recipient))
  112. async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
  113. quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
  114. timestamp: Optional[int] = None) -> None:
  115. serialized_quote = quote.serialize() if quote else None
  116. serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
  117. await self.request("send", "send_results", username=username, messageBody=body,
  118. attachments=serialized_attachments, quote=serialized_quote,
  119. timestamp=timestamp, **self._recipient_to_args(recipient))
  120. # TODO return something?
  121. async def send_receipt(self, username: str, sender: Address, timestamps: List[int],
  122. when: Optional[int] = None, read: bool = False) -> None:
  123. await self.request_nowait("mark_read" if read else "mark_delivered", username=username,
  124. timestamps=timestamps, when=when,
  125. recipientAddress=sender.serialize())
  126. async def list_contacts(self, username: str) -> List[Contact]:
  127. contacts = await self.request("list_contacts", "contact_list", username=username)
  128. return [Contact.deserialize(contact) for contact in contacts]
  129. async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
  130. resp = await self.request("list_groups", "group_list", username=username)
  131. return ([Group.deserialize(group) for group in resp["groups"]]
  132. + [GroupV2.deserialize(group) for group in resp["groupsv2"]])
  133. async def get_group(self, username: str, group_id: GroupID) -> Optional[GroupV2]:
  134. resp = await self.request("get_group", "get_group", account=username, groupID=group_id)
  135. if "id" not in resp:
  136. return None
  137. return GroupV2.deserialize(resp)
  138. async def get_profile(self, username: str, address: Address) -> Optional[Profile]:
  139. try:
  140. resp = await self.request("get_profile", "profile", username=username,
  141. recipientAddress=address.serialize())
  142. except UnexpectedResponse as e:
  143. if e.resp_type == "profile_not_available":
  144. return None
  145. raise
  146. return Profile.deserialize(resp)
  147. async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
  148. resp = await self.request("get_identities", "identities", username=username,
  149. recipientAddress=address.serialize())
  150. return GetIdentitiesResponse.deserialize(resp)
  151. async def set_profile(self, username: str, new_name: str) -> None:
  152. await self.request("set_profile", "profile_set", username=username, name=new_name)