signald.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. # Copyright (c) 2020 Tulir Asokan
  2. #
  3. # This Source Code Form is subject to the terms of the Mozilla Public
  4. # License, v. 2.0. If a copy of the MPL was not distributed with this
  5. # file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
  7. import asyncio
  8. from mautrix.util.logging import TraceLogger
  9. from .rpc import CONNECT_EVENT, SignaldRPCClient
  10. from .errors import UnexpectedError, UnexpectedResponse
  11. from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
  12. Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
  13. Mention, LinkSession)
  14. T = TypeVar('T')
  15. EventHandler = Callable[[T], Awaitable[None]]
  16. class SignaldClient(SignaldRPCClient):
  17. _event_handlers: Dict[Type[T], List[EventHandler]]
  18. _subscriptions: Set[str]
  19. def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
  20. log: Optional[TraceLogger] = None,
  21. loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
  22. super().__init__(socket_path, log, loop)
  23. self._event_handlers = {}
  24. self._subscriptions = set()
  25. self.add_rpc_handler("message", self._parse_message)
  26. self.add_rpc_handler("listen_started", self._parse_listen_start)
  27. self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
  28. self.add_rpc_handler("version", self._log_version)
  29. self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
  30. def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  31. self._event_handlers.setdefault(event_class, []).append(handler)
  32. def remove_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
  33. self._event_handlers.setdefault(event_class, []).remove(handler)
  34. async def _run_event_handler(self, event: T) -> None:
  35. try:
  36. handlers = self._event_handlers[type(event)]
  37. except KeyError:
  38. self.log.warning(f"No handlers for {type(event)}")
  39. else:
  40. for handler in handlers:
  41. try:
  42. await handler(event)
  43. except Exception:
  44. self.log.exception("Exception in event handler")
  45. async def _parse_message(self, data: Dict[str, Any]) -> None:
  46. event_type = data["type"]
  47. event_data = data["data"]
  48. event_class = {
  49. "message": Message,
  50. }[event_type]
  51. event = event_class.deserialize(event_data)
  52. await self._run_event_handler(event)
  53. async def _log_version(self, data: Dict[str, Any]) -> None:
  54. name = data["data"]["name"]
  55. version = data["data"]["version"]
  56. self.log.info(f"Connected to {name} v{version}")
  57. async def _parse_listen_start(self, data: Dict[str, Any]) -> None:
  58. evt = ListenEvent(action=ListenAction.STARTED, username=data["data"])
  59. await self._run_event_handler(evt)
  60. async def _parse_listen_stop(self, data: Dict[str, Any]) -> None:
  61. evt = ListenEvent(action=ListenAction.STOPPED, username=data["data"],
  62. exception=data.get("exception", None))
  63. await self._run_event_handler(evt)
  64. async def subscribe(self, username: str) -> bool:
  65. try:
  66. await self.request("subscribe", "subscribed", username=username)
  67. self._subscriptions.add(username)
  68. return True
  69. except UnexpectedError as e:
  70. self.log.debug("Failed to subscribe to %s: %s", username, e)
  71. return False
  72. async def unsubscribe(self, username: str) -> bool:
  73. try:
  74. await self.request("unsubscribe", "unsubscribed", username=username)
  75. self._subscriptions.remove(username)
  76. return True
  77. except UnexpectedError as e:
  78. self.log.debug("Failed to unsubscribe from %s: %s", username, e)
  79. return False
  80. async def _resubscribe(self, unused_data: Dict[str, Any]) -> None:
  81. if self._subscriptions:
  82. self.log.debug("Resubscribing to users")
  83. for username in list(self._subscriptions):
  84. await self.subscribe(username)
  85. async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
  86. ) -> str:
  87. resp = await self.request("register", "verification_required", username=phone,
  88. voice=voice, captcha=captcha)
  89. return resp["username"]
  90. async def verify(self, username: str, code: str) -> Account:
  91. resp = await self.request("verify", "verification_succeeded", username=username, code=code)
  92. return Account.deserialize(resp)
  93. async def start_link(self) -> LinkSession:
  94. return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
  95. async def finish_link(self, session_id: str, device_name: str = "mausignald") -> Account:
  96. resp = await self.request_v1("finish_link", device_name=device_name, session_id=session_id)
  97. return Account.deserialize(resp)
  98. @staticmethod
  99. def _recipient_to_args(recipient: Union[Address, GroupID]) -> Dict[str, Any]:
  100. if isinstance(recipient, Address):
  101. return {"recipientAddress": recipient.serialize()}
  102. else:
  103. return {"recipientGroupId": recipient}
  104. async def react(self, username: str, recipient: Union[Address, GroupID],
  105. reaction: Reaction) -> None:
  106. await self.request_v1("react", username=username, reaction=reaction.serialize(),
  107. **self._recipient_to_args(recipient))
  108. async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
  109. quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
  110. mentions: Optional[List[Mention]] = None, timestamp: Optional[int] = None
  111. ) -> None:
  112. serialized_quote = quote.serialize() if quote else None
  113. serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
  114. serialized_mentions = [mention.serialize() for mention in (mentions or [])]
  115. await self.request_v1("send", username=username, messageBody=body,
  116. attachments=serialized_attachments, quote=serialized_quote,
  117. mentions=serialized_mentions, timestamp=timestamp,
  118. **self._recipient_to_args(recipient))
  119. # TODO return something?
  120. async def send_receipt(self, username: str, sender: Address, timestamps: List[int],
  121. when: Optional[int] = None, read: bool = False) -> None:
  122. await self.request_nowait("mark_read" if read else "mark_delivered", username=username,
  123. timestamps=timestamps, when=when,
  124. recipientAddress=sender.serialize())
  125. async def list_contacts(self, username: str) -> List[Contact]:
  126. contacts = await self.request_v0("list_contacts", "contact_list", username=username)
  127. return [Contact.deserialize(contact) for contact in contacts]
  128. async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
  129. resp = await self.request_v1("list_groups", account=username)
  130. legacy = [Group.deserialize(group) for group in resp.get("legacyGroups", [])]
  131. v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
  132. return legacy + v2
  133. async def update_group(self, username: str, group_id: GroupID, title: Optional[str] = None,
  134. avatar_path: Optional[str] = None,
  135. add_members: Optional[List[Address]] = None,
  136. remove_members: Optional[List[Address]] = None
  137. ) -> Union[Group, GroupV2, None]:
  138. update_params = {key: value for key, value in {
  139. "groupID": group_id,
  140. "avatar": avatar_path,
  141. "title": title,
  142. "addMembers": [addr.serialize() for addr in add_members] if add_members else None,
  143. "removeMembers": ([addr.serialize() for addr in remove_members]
  144. if remove_members else None),
  145. }.items() if value is not None}
  146. resp = await self.request_v1("update_group", account=username, **update_params)
  147. if "v1" in resp:
  148. return Group.deserialize(resp["v1"])
  149. elif "v2" in resp:
  150. return GroupV2.deserialize(resp["v2"])
  151. else:
  152. return None
  153. async def accept_invitation(self, username: str, group_id: GroupID) -> GroupV2:
  154. resp = await self.request_v1("accept_invitation", account=username, groupID=group_id)
  155. return GroupV2.deserialize(resp)
  156. async def get_group(self, username: str, group_id: GroupID, revision: int = -1
  157. ) -> Optional[GroupV2]:
  158. resp = await self.request_v1("get_group", account=username, groupID=group_id,
  159. revision=revision)
  160. if "id" not in resp:
  161. return None
  162. return GroupV2.deserialize(resp)
  163. async def get_profile(self, username: str, address: Address) -> Optional[Profile]:
  164. try:
  165. resp = await self.request_v1("get_profile", account=username,
  166. address=address.serialize())
  167. except UnexpectedResponse as e:
  168. if e.resp_type == "profile_not_available":
  169. return None
  170. raise
  171. return Profile.deserialize(resp)
  172. async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
  173. resp = await self.request_v0("get_identities", "identities", username=username,
  174. recipientAddress=address.serialize())
  175. return GetIdentitiesResponse.deserialize(resp)
  176. async def set_profile(self, username: str, name: Optional[str] = None,
  177. avatar_path: Optional[str] = None) -> None:
  178. args = {}
  179. if name is not None:
  180. args["name"] = name
  181. if avatar_path is not None:
  182. args["avatarFile"] = avatar_path
  183. await self.request_v1("set_profile", account=username, **args)
  184. async def trust(self, username: str, recipient: Address, fingerprint: str, trust_level: str
  185. ) -> str:
  186. resp = await self.request_v0("trust", "trusted_safety_number", username=username,
  187. fingerprint=fingerprint, trust_level=trust_level,
  188. recipientAddress=recipient.serialize())
  189. return resp["message"]