rpc.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. # Copyright (c) 2020 Tulir Asokan
  2. #
  3. # This Source Code Form is subject to the terms of the Mozilla Public
  4. # License, v. 2.0. If a copy of the MPL was not distributed with this
  5. # file, You can obtain one at http://mozilla.org/MPL/2.0/.
from typing import Optional, Dict, List, Callable, Awaitable, Any, Tuple, Set
from uuid import UUID, uuid4
import asyncio
import logging
import json

from mautrix.util.logging import TraceLogger

from .errors import NotConnected, UnexpectedError, UnexpectedResponse, make_response_error
  13. EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
  14. # These are synthetic RPC events for registering callbacks on socket
  15. # connect and disconnect.
  16. CONNECT_EVENT = "_socket_connected"
  17. DISCONNECT_EVENT = "_socket_disconnected"
  18. class SignaldRPCClient:
  19. loop: asyncio.AbstractEventLoop
  20. log: TraceLogger
  21. socket_path: str
  22. _reader: Optional[asyncio.StreamReader]
  23. _writer: Optional[asyncio.StreamWriter]
  24. _communicate_task: Optional[asyncio.Task]
  25. _response_waiters: Dict[UUID, asyncio.Future]
  26. _rpc_event_handlers: Dict[str, List[EventHandler]]
  27. def __init__(self, socket_path: str, log: Optional[TraceLogger] = None,
  28. loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
  29. self.socket_path = socket_path
  30. self.log = log or logging.getLogger("mausignald")
  31. self.loop = loop or asyncio.get_event_loop()
  32. self._reader = None
  33. self._writer = None
  34. self._communicate_task = None
  35. self._response_waiters = {}
  36. self._rpc_event_handlers = {CONNECT_EVENT: [], DISCONNECT_EVENT: []}
  37. self.add_rpc_handler(DISCONNECT_EVENT, self._abandon_responses)
  38. async def connect(self) -> None:
  39. if self._writer is not None:
  40. return
  41. initial_connect = self.loop.create_future()
  42. self._communicate_task = self.loop.create_task(self._communicate_forever(initial_connect))
  43. await initial_connect
  44. async def _communicate_forever(self, initial_connect: Optional[asyncio.Future] = None) -> None:
  45. while True:
  46. try:
  47. self._reader, self._writer = await asyncio.open_unix_connection(self.socket_path)
  48. except OSError as e:
  49. self.log.error(f"Connection to {self.socket_path} failed: {e}")
  50. await asyncio.sleep(5)
  51. continue
  52. read_loop = self.loop.create_task(self._try_read_loop())
  53. await self._run_rpc_handler(CONNECT_EVENT, {})
  54. if initial_connect:
  55. initial_connect.set_result(True)
  56. initial_connect = None
  57. await read_loop
  58. await self._run_rpc_handler(DISCONNECT_EVENT, {})
  59. async def disconnect(self) -> None:
  60. if self._writer is not None:
  61. self._writer.write_eof()
  62. await self._writer.drain()
  63. if self._communicate_task:
  64. self._communicate_task.cancel()
  65. self._communicate_task = None
  66. self._writer = None
  67. self._reader = None
  68. def add_rpc_handler(self, method: str, handler: EventHandler) -> None:
  69. self._rpc_event_handlers.setdefault(method, []).append(handler)
  70. def remove_rpc_handler(self, method: str, handler: EventHandler) -> None:
  71. self._rpc_event_handlers.setdefault(method, []).remove(handler)
  72. async def _run_rpc_handler(self, command: str, req: Dict[str, Any]) -> None:
  73. try:
  74. handlers = self._rpc_event_handlers[command]
  75. except KeyError:
  76. self.log.warning("No handlers for RPC request %s", command)
  77. self.log.trace("Data unhandled request: %s", req)
  78. else:
  79. for handler in handlers:
  80. try:
  81. await handler(req)
  82. except Exception:
  83. self.log.exception("Exception in RPC event handler")
  84. def _run_response_handlers(self, req_id: UUID, command: str, req: Any) -> None:
  85. try:
  86. waiter = self._response_waiters.pop(req_id)
  87. except KeyError:
  88. self.log.debug(f"Nobody waiting for response to {req_id}")
  89. return
  90. data = req.get("data")
  91. if command == "unexpected_error":
  92. try:
  93. waiter.set_exception(UnexpectedError(data["message"]))
  94. except KeyError:
  95. waiter.set_exception(UnexpectedError("Unexpected error with no message"))
  96. elif data and "error" in data and isinstance(data["error"], (str, dict)):
  97. waiter.set_exception(make_response_error(data["error"]))
  98. elif "error" in req and isinstance(req["error"], (str, dict)):
  99. waiter.set_exception(make_response_error(req["error"]))
  100. else:
  101. waiter.set_result((command, data))
  102. async def _handle_incoming_line(self, line: str) -> None:
  103. try:
  104. req = json.loads(line)
  105. except json.JSONDecodeError:
  106. self.log.debug(f"Got non-JSON data from server: {line}")
  107. return
  108. try:
  109. req_type = req["type"]
  110. except KeyError:
  111. self.log.debug(f"Got invalid request from server: {line}")
  112. return
  113. self.log.trace("Got data from server: %s", req)
  114. req_id = req.get("id")
  115. if req_id is None:
  116. self.loop.create_task(self._run_rpc_handler(req_type, req))
  117. else:
  118. self._run_response_handlers(UUID(req_id), req_type, req)
  119. async def _try_read_loop(self) -> None:
  120. try:
  121. await self._read_loop()
  122. except Exception:
  123. self.log.exception("Fatal error in read loop")
  124. async def _read_loop(self) -> None:
  125. while self._reader is not None and not self._reader.at_eof():
  126. line = await self._reader.readline()
  127. if not line:
  128. continue
  129. try:
  130. line_str = line.decode("utf-8")
  131. except UnicodeDecodeError:
  132. self.log.exception("Got non-unicode request from server: %s", line)
  133. continue
  134. try:
  135. await self._handle_incoming_line(line_str)
  136. except Exception:
  137. self.log.exception("Failed to handle incoming request %s", line_str)
  138. self.log.debug("Reader disconnected")
  139. self._reader = None
  140. self._writer = None
  141. def _create_request(self, command: str, req_id: Optional[UUID] = None, **data: Any
  142. ) -> Tuple[asyncio.Future, Dict[str, Any]]:
  143. req_id = req_id or uuid4()
  144. req = {"id": str(req_id), "type": command, **data}
  145. self.log.trace("Request %s: %s %s", req_id, command, data)
  146. return self._wait_response(req_id), req
  147. def _wait_response(self, req_id: UUID) -> asyncio.Future:
  148. try:
  149. future = self._response_waiters[req_id]
  150. except KeyError:
  151. future = self._response_waiters[req_id] = self.loop.create_future()
  152. return future
  153. async def _abandon_responses(self, unused_data: Dict[str, Any]) -> None:
  154. for req_id, waiter in self._response_waiters.items():
  155. if not waiter.done():
  156. self.log.trace(f"Abandoning response for {req_id}")
  157. waiter.set_exception(
  158. NotConnected("Disconnected from signald before RPC completed"))
  159. async def _send_request(self, data: Dict[str, Any]) -> None:
  160. if self._writer is None:
  161. raise NotConnected("Not connected to signald")
  162. self._writer.write(json.dumps(data).encode("utf-8"))
  163. self._writer.write(b"\n")
  164. await self._writer.drain()
  165. self.log.trace("Sent data to server server: %s", data)
  166. async def _raw_request(self, command: str, req_id: Optional[UUID] = None, **data: Any
  167. ) -> Tuple[str, Dict[str, Any]]:
  168. future, data = self._create_request(command, req_id, **data)
  169. await self._send_request(data)
  170. return await future
  171. async def request(self, command: str, expected_response: str, **data: Any) -> Any:
  172. resp_type, resp_data = await self._raw_request(command, **data)
  173. if resp_type != expected_response:
  174. raise UnexpectedResponse(resp_type, resp_data)
  175. return resp_data
  176. async def request_v0(self, command: str, expected_response: str, **data: Any) -> Any:
  177. return await self.request(command, expected_response, version="v0", **data)
  178. async def request_v1(self, command: str, **data: Any) -> Any:
  179. return await self.request(command, expected_response=command, version="v1", **data)
  180. async def request_nowait(self, command: str, **data: Any) -> None:
  181. _, req = self._create_request(command, **data)
  182. await self._send_request(req)