# signald.py
# Copyright (c) 2020 Tulir Asokan
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, TypeVar, Type
  7. from uuid import uuid4
  8. import asyncio
  9. from mautrix.util.logging import TraceLogger
  10. from .rpc import SignaldRPCClient
  11. from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
  12. from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
  13. Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2)
  # Generic placeholder for the event type used in the handler annotations.
  T = TypeVar("T")
  # An event handler is an async callable that takes one event and returns nothing.
  EventHandler = Callable[[T], Awaitable[None]]
  16. class SignaldClient(SignaldRPCClient):
  17. _event_handlers: Dict[Type[T], List[EventHandler]]
  def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
               log: Optional[TraceLogger] = None,
               loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
      """
      Initialize the parent RPC client and register the built-in RPC handlers.

      All three arguments are forwarded unchanged to the parent initializer.
      The event handler registry starts out empty.
      """
      super().__init__(socket_path, log, loop)
      self._event_handlers = {}
      # RPC message type -> method of this class that processes it.
      builtin_handlers = (
          ("message", self._parse_message),
          ("listen_started", self._parse_listen_start),
          ("listen_stopped", self._parse_listen_stop),
          ("version", self._log_version),
      )
      for command, callback in builtin_handlers:
          self.add_rpc_handler(command, callback)
  def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
      """Append ``handler`` to the list of handlers stored for ``event_class``."""
      try:
          self._event_handlers[event_class].append(handler)
      except KeyError:
          # First handler for this class: start a new list.
          self._event_handlers[event_class] = [handler]
  def remove_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
      """
      Remove ``handler`` from the list of handlers stored for ``event_class``.

      Raises:
          ValueError: If ``handler`` is not registered for ``event_class``
              (including when nothing is registered for that class at all).
      """
      try:
          handlers = self._event_handlers[event_class]
      except KeyError:
          # Look the class up without setdefault(): a failed removal must not
          # leave a new empty list behind in the registry.
          raise ValueError(
              f"{handler!r} is not registered for {event_class!r}"
          ) from None
      handlers.remove(handler)
  async def _run_event_handler(self, event: T) -> None:
      """
      Await every handler stored for the exact type of ``event``, in order.

      Logs a warning when there is no entry for that type. An exception raised
      by a handler is logged and does not prevent the remaining handlers from
      running.
      """
      callbacks = self._event_handlers.get(type(event))
      if callbacks is None:
          self.log.warning(f"No handlers for {type(event)}")
          return
      for callback in callbacks:
          try:
              await callback(event)
          except Exception:
              self.log.exception("Exception in event handler")
  async def _parse_message(self, data: Dict[str, Any]) -> None:
      """
      Deserialize an incoming RPC payload into an event and dispatch it.

      ``data["type"]`` selects the event class and ``data["data"]`` is passed
      to that class's ``deserialize``.

      Raises:
          KeyError: If either key is missing or the type has no known class.
      """
      # Event type name -> class used to deserialize the payload.
      known_types = {
          "message": Message,
      }
      kind = data["type"]
      payload = data["data"]
      parsed = known_types[kind].deserialize(payload)
      await self._run_event_handler(parsed)
  async def _log_version(self, data: Dict[str, Any]) -> None:
      """Log the name and version found under ``data["data"]`` at info level."""
      info = data["data"]
      name, version = info["name"], info["version"]
      self.log.info(f"Connected to {name} v{version}")
  async def _parse_listen_start(self, data: Dict[str, Any]) -> None:
      """Dispatch a STARTED ListenEvent whose username is ``data["data"]``."""
      await self._run_event_handler(
          ListenEvent(action=ListenAction.STARTED, username=data["data"])
      )
  async def _parse_listen_stop(self, data: Dict[str, Any]) -> None:
      """
      Dispatch a STOPPED ListenEvent built from an incoming RPC message.

      The event carries ``data["data"]`` as the username and the optional
      top-level "exception" value (``None`` when the key is absent).
      """
      # A stop notification must be reported as STOPPED; the original reused
      # STARTED here, making stop events indistinguishable from start events.
      # NOTE(review): assumes ListenAction defines STOPPED — confirm in .types.
      evt = ListenEvent(action=ListenAction.STOPPED, username=data["data"],
                        exception=data.get("exception"))
      await self._run_event_handler(evt)
  async def subscribe(self, username: str) -> bool:
      """
      Send a "subscribe" request for ``username``.

      Returns:
          True when the request completes, False when it raises
          UnexpectedError (the failure is logged at debug level).
      """
      try:
          await self.request("subscribe", "subscribed", username=username)
      except UnexpectedError as err:
          self.log.debug("Failed to subscribe to %s: %s", username, err)
          return False
      else:
          return True
  async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
                     ) -> str:
      """
      Send a "register" request, expecting a "verification_required" response.

      Args:
          phone: Sent as the ``username`` field of the request.
          voice: Sent as the ``voice`` field of the request.
          captcha: Sent as the ``captcha`` field of the request.

      Returns:
          The "username" value of the response.
      """
      params = {"username": phone, "voice": voice, "captcha": captcha}
      response = await self.request("register", "verification_required", **params)
      return response["username"]
  async def verify(self, username: str, code: str) -> Account:
      """
      Send a "verify" request with ``username`` and ``code``.

      Returns:
          The Account deserialized from the "verification_succeeded" response.
      """
      return Account.deserialize(
          await self.request("verify", "verification_succeeded",
                             username=username, code=code)
      )
  async def link(self, url_callback: Callable[[str], Awaitable[None]],
                 device_name: str = "mausignald") -> Account:
      """
      Run the two-step "link" request and return the resulting account.

      Sends a "link" request, passes the URI from the "linking_uri" response to
      ``url_callback`` in a background task, then waits for a second response
      to the same request ID.

      Args:
          url_callback: Coroutine function that receives the linking URI.
          device_name: Sent with the request as ``deviceName``.

      Returns:
          The Account deserialized from the "linking_successful" response.

      Raises:
          UnexpectedResponse: If either response has an unexpected type.
          Exception: Whatever ``make_linking_error`` builds when either
              response is a "linking_error".
      """
      req_id = uuid4()
      resp_type, resp = await self._raw_request("link", req_id, deviceName=device_name)
      if resp_type == "linking_error":
          raise make_linking_error(resp)
      elif resp_type != "linking_uri":
          raise UnexpectedResponse(resp_type, resp)

      # asyncio only keeps weak references to tasks, so hold a strong one for
      # as long as this coroutine is waiting; otherwise the callback task
      # could be garbage-collected before it finishes.
      callback_task = self.loop.create_task(url_callback(resp["uri"]))  # noqa: F841

      resp_type, resp = await self._wait_response(req_id)
      if resp_type == "linking_error":
          raise make_linking_error(resp)
      elif resp_type != "linking_successful":
          raise UnexpectedResponse(resp_type, resp)
      return Account.deserialize(resp)
  async def list_accounts(self) -> List[Account]:
      """
      Send a "list_accounts" request.

      Returns:
          One Account per entry under "accounts" in the "account_list" response.
      """
      response = await self.request("list_accounts", "account_list")
      return list(map(Account.deserialize, response["accounts"]))
  @staticmethod
  def _recipient_to_args(recipient: Union[Address, GroupID]) -> Dict[str, Any]:
      """
      Build the request keyword argument that identifies ``recipient``.

      Returns:
          ``{"recipientAddress": ...}`` with the serialized address when
          ``recipient`` is an Address, otherwise ``{"recipientGroupId": ...}``
          with the value passed through unchanged.
      """
      if not isinstance(recipient, Address):
          return {"recipientGroupId": recipient}
      return {"recipientAddress": recipient.serialize()}
  async def react(self, username: str, recipient: Union[Address, GroupID],
                  reaction: Reaction) -> None:
      """
      Send a "react" request, expecting a "send_results" response.

      The request carries ``username``, the serialized ``reaction`` and the
      recipient argument produced by ``_recipient_to_args``.
      """
      serialized_reaction = reaction.serialize()
      target = self._recipient_to_args(recipient)
      await self.request("react", "send_results", username=username,
                         reaction=serialized_reaction, **target)
  async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
                 quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
                 timestamp: Optional[int] = None) -> None:
      """
      Send a "send" request, expecting a "send_results" response.

      Args:
          username: Sent as ``username``.
          recipient: Converted to a request argument by ``_recipient_to_args``.
          body: Sent as ``messageBody``.
          quote: Serialized and sent as ``quote``; ``None`` is sent when falsy.
          attachments: Each item is serialized and sent under ``attachments``;
              an empty list is sent when falsy.
          timestamp: Sent as ``timestamp``.
      """
      params: Dict[str, Any] = {"username": username, "messageBody": body}
      params["quote"] = quote.serialize() if quote else None
      params["attachments"] = (
          [item.serialize() for item in attachments] if attachments else []
      )
      params["timestamp"] = timestamp
      params.update(self._recipient_to_args(recipient))
      await self.request("send", "send_results", **params)
      # TODO return something?
  async def send_receipt(self, username: str, sender: Address, timestamps: List[int],
                         when: Optional[int] = None, read: bool = False) -> None:
      """
      Send a receipt for ``timestamps`` via ``request_nowait``.

      The command is "mark_read" when ``read`` is true, otherwise
      "mark_delivered". The serialized ``sender`` is sent as
      ``recipientAddress``.
      """
      if read:
          command = "mark_read"
      else:
          command = "mark_delivered"
      await self.request_nowait(command, username=username, timestamps=timestamps,
                                when=when, recipientAddress=sender.serialize())
  async def list_contacts(self, username: str) -> List[Contact]:
      """
      Send a "list_contacts" request for ``username``.

      Returns:
          One Contact per item of the "contact_list" response, which is
          iterated directly.
      """
      response = await self.request("list_contacts", "contact_list", username=username)
      return list(map(Contact.deserialize, response))
  async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
      """
      Send a "list_groups" request for ``username``.

      Returns:
          Group objects from the response's "groups" list followed by GroupV2
          objects from its "groupsv2" list.
      """
      resp = await self.request("list_groups", "group_list", username=username)
      result: List[Union[Group, GroupV2]] = [
          Group.deserialize(entry) for entry in resp["groups"]
      ]
      result.extend(GroupV2.deserialize(entry) for entry in resp["groupsv2"])
      return result
  async def get_profile(self, username: str, address: Address) -> Optional[Profile]:
      """
      Send a "get_profile" request for ``address``.

      Returns:
          The Profile deserialized from the "profile" response, or ``None``
          when the response type is "profile_not_available".

      Raises:
          UnexpectedResponse: For any other unexpected response type.
      """
      try:
          resp = await self.request("get_profile", "profile", username=username,
                                    recipientAddress=address.serialize())
      except UnexpectedResponse as err:
          if err.resp_type != "profile_not_available":
              raise
          return None
      return Profile.deserialize(resp)
  async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
      """
      Send a "get_identities" request for ``address``.

      Returns:
          The GetIdentitiesResponse deserialized from the "identities" response.
      """
      serialized_address = address.serialize()
      identities = await self.request("get_identities", "identities", username=username,
                                      recipientAddress=serialized_address)
      return GetIdentitiesResponse.deserialize(identities)
  async def set_profile(self, username: str, new_name: str) -> None:
      """
      Send a "set_profile" request, expecting a "profile_set" response.

      ``new_name`` is sent as the ``name`` field of the request.
      """
      await self.request(
          "set_profile",
          "profile_set",
          username=username,
          name=new_name,
      )