# signald.py (9.3 KB, 201 lines)
# Copyright (c) 2020 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
  7. from uuid import uuid4
  8. import asyncio
  9. from mautrix.util.logging import TraceLogger
  10. from .rpc import CONNECT_EVENT, SignaldRPCClient
  11. from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
  12. from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
  13. Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
  14. Mention)
# Event type accepted by an EventHandler.
T = TypeVar('T')
# Coroutine function that receives a single event of type T; registered per event class
# with SignaldClient.add_event_handler().
EventHandler = Callable[[T], Awaitable[None]]
class SignaldClient(SignaldRPCClient):
    """High-level signald client built on top of ``SignaldRPCClient``.

    Wraps signald requests in typed methods and turns incoming ``message``,
    ``listen_started`` and ``listen_stopped`` RPC payloads into event objects that are
    dispatched to handlers registered with ``add_event_handler``.
    """

    # Handlers per event class; dispatch looks them up by the exact ``type(event)``.
    _event_handlers: Dict[Type[T], List[EventHandler]]
    # Usernames successfully subscribed to; re-subscribed when CONNECT_EVENT fires.
    _subscriptions: Set[str]
  20. def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
  21. log: Optional[TraceLogger] = None,
  22. loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
  23. super().__init__(socket_path, log, loop)
  24. self._event_handlers = {}
  25. self._subscriptions = set()
  26. self.add_rpc_handler("message", self._parse_message)
  27. self.add_rpc_handler("listen_started", self._parse_listen_start)
  28. self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
  29. self.add_rpc_handler("version", self._log_version)
  30. self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
  31. def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  32. self._event_handlers.setdefault(event_class, []).append(handler)
  33. def remove_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  34. self._event_handlers.setdefault(event_class, []).remove(handler)
  35. async def _run_event_handler(self, event: T) -> None:
  36. try:
  37. handlers = self._event_handlers[type(event)]
  38. except KeyError:
  39. self.log.warning(f"No handlers for {type(event)}")
  40. else:
  41. for handler in handlers:
  42. try:
  43. await handler(event)
  44. except Exception:
  45. self.log.exception("Exception in event handler")
  46. async def _parse_message(self, data: Dict[str, Any]) -> None:
  47. event_type = data["type"]
  48. event_data = data["data"]
  49. event_class = {
  50. "message": Message,
  51. }[event_type]
  52. event = event_class.deserialize(event_data)
  53. await self._run_event_handler(event)
  54. async def _log_version(self, data: Dict[str, Any]) -> None:
  55. name = data["data"]["name"]
  56. version = data["data"]["version"]
  57. self.log.info(f"Connected to {name} v{version}")
  58. async def _parse_listen_start(self, data: Dict[str, Any]) -> None:
  59. evt = ListenEvent(action=ListenAction.STARTED, username=data["data"])
  60. await self._run_event_handler(evt)
  61. async def _parse_listen_stop(self, data: Dict[str, Any]) -> None:
  62. evt = ListenEvent(action=ListenAction.STOPPED, username=data["data"],
  63. exception=data.get("exception", None))
  64. await self._run_event_handler(evt)
  65. async def subscribe(self, username: str) -> bool:
  66. try:
  67. await self.request("subscribe", "subscribed", username=username)
  68. self._subscriptions.add(username)
  69. return True
  70. except UnexpectedError as e:
  71. self.log.debug("Failed to subscribe to %s: %s", username, e)
  72. return False
  73. async def unsubscribe(self, username: str) -> bool:
  74. try:
  75. await self.request("unsubscribe", "unsubscribed", username=username)
  76. self._subscriptions.remove(username)
  77. return True
  78. except UnexpectedError as e:
  79. self.log.debug("Failed to unsubscribe from %s: %s", username, e)
  80. return False
  81. async def _resubscribe(self, unused_data: Dict[str, Any]) -> None:
  82. if self._subscriptions:
  83. self.log.debug("Resubscribing to users")
  84. for username in list(self._subscriptions):
  85. await self.subscribe(username)
  86. async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
  87. ) -> str:
  88. resp = await self.request("register", "verification_required", username=phone,
  89. voice=voice, captcha=captcha)
  90. return resp["username"]
  91. async def verify(self, username: str, code: str) -> Account:
  92. resp = await self.request("verify", "verification_succeeded", username=username, code=code)
  93. return Account.deserialize(resp)
  94. async def link(self, url_callback: Callable[[str], Awaitable[None]],
  95. device_name: str = "mausignald") -> Account:
  96. req_id = uuid4()
  97. resp_type, resp = await self._raw_request("link", req_id, deviceName=device_name)
  98. if resp_type == "linking_error":
  99. raise make_linking_error(resp)
  100. elif resp_type != "linking_uri":
  101. raise UnexpectedResponse(resp_type, resp)
  102. self.loop.create_task(url_callback(resp["uri"]))
  103. resp_type, resp = await self._wait_response(req_id)
  104. if resp_type == "linking_error":
  105. raise make_linking_error(resp)
  106. elif resp_type != "linking_successful":
  107. raise UnexpectedResponse(resp_type, resp)
  108. return Account.deserialize(resp)
  109. async def list_accounts(self) -> List[Account]:
  110. data = await self.request("list_accounts", "account_list")
  111. return [Account.deserialize(acc) for acc in data["accounts"]]
  112. @staticmethod
  113. def _recipient_to_args(recipient: Union[Address, GroupID]) -> Dict[str, Any]:
  114. if isinstance(recipient, Address):
  115. return {"recipientAddress": recipient.serialize()}
  116. else:
  117. return {"recipientGroupId": recipient}
  118. async def react(self, username: str, recipient: Union[Address, GroupID],
  119. reaction: Reaction) -> None:
  120. await self.request("react", "send_results", username=username,
  121. reaction=reaction.serialize(),
  122. **self._recipient_to_args(recipient))
  123. async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
  124. quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
  125. mentions: Optional[List[Mention]] = None, timestamp: Optional[int] = None
  126. ) -> None:
  127. serialized_quote = quote.serialize() if quote else None
  128. serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
  129. serialized_mentions = [mention.serialize() for mention in (mentions or [])]
  130. await self.request("send", "send", username=username, messageBody=body,
  131. attachments=serialized_attachments, quote=serialized_quote,
  132. mentions=serialized_mentions, timestamp=timestamp,
  133. **self._recipient_to_args(recipient), version="v1")
  134. # TODO return something?
  135. async def send_receipt(self, username: str, sender: Address, timestamps: List[int],
  136. when: Optional[int] = None, read: bool = False) -> None:
  137. await self.request_nowait("mark_read" if read else "mark_delivered", username=username,
  138. timestamps=timestamps, when=when,
  139. recipientAddress=sender.serialize())
  140. async def list_contacts(self, username: str) -> List[Contact]:
  141. contacts = await self.request("list_contacts", "contact_list", username=username)
  142. return [Contact.deserialize(contact) for contact in contacts]
  143. async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
  144. resp = await self.request("list_groups", "group_list", username=username)
  145. return ([Group.deserialize(group) for group in resp["groups"]]
  146. + [GroupV2.deserialize(group) for group in resp["groupsv2"]])
  147. async def get_group(self, username: str, group_id: GroupID, revision: int = -1
  148. ) -> Optional[GroupV2]:
  149. resp = await self.request("get_group", "get_group", account=username, groupID=group_id,
  150. version="v1", revision=revision)
  151. if "id" not in resp:
  152. return None
  153. return GroupV2.deserialize(resp)
  154. async def get_profile(self, username: str, address: Address) -> Optional[Profile]:
  155. try:
  156. resp = await self.request("get_profile", "get_profile", account=username,
  157. address=address.serialize(), version="v1")
  158. except UnexpectedResponse as e:
  159. if e.resp_type == "profile_not_available":
  160. return None
  161. raise
  162. return Profile.deserialize(resp)
  163. async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
  164. resp = await self.request("get_identities", "identities", username=username,
  165. recipientAddress=address.serialize())
  166. return GetIdentitiesResponse.deserialize(resp)
  167. async def set_profile(self, username: str, new_name: str) -> None:
  168. await self.request("set_profile", "profile_set", username=username, name=new_name)