rpc.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # Copyright (c) 2022 Tulir Asokan
  2. #
  3. # This Source Code Form is subject to the terms of the Mozilla Public
  4. # License, v. 2.0. If a copy of the MPL was not distributed with this
  5. # file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. from __future__ import annotations
  7. from typing import Any, Awaitable, Callable, Dict
  8. from uuid import UUID, uuid4
  9. import asyncio
  10. import json
  11. import logging
  12. from mautrix.util import background_task
  13. from mautrix.util.logging import TraceLogger
  14. from .errors import NotConnected, UnexpectedError, UnexpectedResponse, make_response_error
  15. EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
  16. # These are synthetic RPC events for registering callbacks on socket
  17. # connect and disconnect.
  18. CONNECT_EVENT = "_socket_connected"
  19. DISCONNECT_EVENT = "_socket_disconnected"
  20. _SOCKET_LIMIT = 1024 * 1024 # 1 MiB
  21. class SignaldRPCClient:
  22. loop: asyncio.AbstractEventLoop
  23. log: TraceLogger
  24. socket_path: str
  25. _reader: asyncio.StreamReader | None
  26. _writer: asyncio.StreamWriter | None
  27. is_connected: bool
  28. _connect_future: asyncio.Future
  29. _communicate_task: asyncio.Task | None
  30. _response_waiters: dict[UUID, asyncio.Future]
  31. _rpc_event_handlers: dict[str, list[EventHandler]]
  32. def __init__(
  33. self,
  34. socket_path: str,
  35. log: TraceLogger | None = None,
  36. loop: asyncio.AbstractEventLoop | None = None,
  37. ) -> None:
  38. self.socket_path = socket_path
  39. self.log = log or logging.getLogger("mausignald")
  40. self.loop = loop or asyncio.get_event_loop()
  41. self._reader = None
  42. self._writer = None
  43. self._communicate_task = None
  44. self.is_connected = False
  45. self._connect_future = self.loop.create_future()
  46. self._response_waiters = {}
  47. self._rpc_event_handlers = {CONNECT_EVENT: [], DISCONNECT_EVENT: []}
  48. self.add_rpc_handler(DISCONNECT_EVENT, self._abandon_responses)
  49. async def wait_for_connected(self, timeout: int | None = None) -> bool:
  50. if self.is_connected:
  51. return True
  52. await asyncio.wait_for(asyncio.shield(self._connect_future), timeout)
  53. return self.is_connected
  54. async def connect(self) -> None:
  55. if self._writer is not None:
  56. return
  57. self._communicate_task = asyncio.create_task(self._communicate_forever())
  58. await self._connect_future
  59. async def _communicate_forever(self) -> None:
  60. while True:
  61. try:
  62. await self._communicate()
  63. except Exception:
  64. self.log.exception("Unknown error in signald socket")
  65. await asyncio.sleep(30)
  66. async def _communicate(self) -> None:
  67. try:
  68. self.log.debug(f"Connecting to {self.socket_path}...")
  69. self._reader, self._writer = await asyncio.open_unix_connection(
  70. self.socket_path, limit=_SOCKET_LIMIT
  71. )
  72. except OSError as e:
  73. self.log.error(f"Connection to {self.socket_path} failed: {e}")
  74. await asyncio.sleep(5)
  75. return
  76. read_loop = asyncio.create_task(self._try_read_loop())
  77. self.is_connected = True
  78. background_task.create(self._run_rpc_handler(CONNECT_EVENT, {}))
  79. self._connect_future.set_result(True)
  80. await read_loop
  81. self.is_connected = False
  82. self._connect_future = self.loop.create_future()
  83. await self._run_rpc_handler(DISCONNECT_EVENT, {})
  84. async def disconnect(self) -> None:
  85. if self._writer is not None:
  86. self._writer.write_eof()
  87. await self._writer.drain()
  88. if self._communicate_task:
  89. self._communicate_task.cancel()
  90. self._communicate_task = None
  91. self._writer = None
  92. self._reader = None
  93. self.is_connected = False
  94. self._connect_future = self.loop.create_future()
  95. def add_rpc_handler(self, method: str, handler: EventHandler) -> None:
  96. self._rpc_event_handlers.setdefault(method, []).append(handler)
  97. def remove_rpc_handler(self, method: str, handler: EventHandler) -> None:
  98. self._rpc_event_handlers.setdefault(method, []).remove(handler)
  99. async def _run_rpc_handler(self, command: str, req: dict[str, Any]) -> None:
  100. try:
  101. handlers = self._rpc_event_handlers[command]
  102. except KeyError:
  103. self.log.warning("No handlers for RPC request %s", command)
  104. self.log.trace("Data unhandled request: %s", req)
  105. else:
  106. for handler in handlers:
  107. try:
  108. await handler(req)
  109. except Exception:
  110. self.log.exception("Exception in RPC event handler")
  111. def _run_response_handlers(self, req_id: UUID, command: str, req: Any) -> None:
  112. try:
  113. waiter = self._response_waiters.pop(req_id)
  114. except KeyError:
  115. self.log.debug(f"Nobody waiting for response to {req_id}")
  116. return
  117. data = req.get("data")
  118. if command == "unexpected_error":
  119. try:
  120. waiter.set_exception(UnexpectedError(data["message"]))
  121. except KeyError:
  122. waiter.set_exception(UnexpectedError("Unexpected error with no message"))
  123. # elif data and "error" in data and isinstance(data["error"], (str, dict)):
  124. # waiter.set_exception(make_response_error(data))
  125. elif "error" in req and isinstance(req["error"], (str, dict)):
  126. waiter.set_exception(make_response_error(req))
  127. else:
  128. waiter.set_result((command, data))
  129. async def _handle_incoming_line(self, line: str) -> None:
  130. try:
  131. req = json.loads(line)
  132. except json.JSONDecodeError:
  133. self.log.debug(f"Got non-JSON data from server: {line}")
  134. return
  135. try:
  136. req_type = req["type"]
  137. except KeyError:
  138. self.log.debug(f"Got invalid request from server: {line}")
  139. return
  140. self.log.trace("Got data from server: %s", req)
  141. req_id = req.get("id")
  142. if req_id is None:
  143. background_task.create(self._run_rpc_handler(req_type, req))
  144. else:
  145. self._run_response_handlers(UUID(req_id), req_type, req)
  146. async def _try_read_loop(self) -> None:
  147. try:
  148. await self._read_loop()
  149. except Exception:
  150. self.log.exception("Fatal error in read loop")
  151. else:
  152. self.log.debug("Reader disconnected")
  153. finally:
  154. self._reader = None
  155. self._writer = None
  156. async def _read_loop(self) -> None:
  157. while self._reader is not None and not self._reader.at_eof():
  158. line = await self._reader.readline()
  159. if not line:
  160. continue
  161. try:
  162. line_str = line.decode("utf-8")
  163. except UnicodeDecodeError:
  164. self.log.exception("Got non-unicode request from server: %s", line)
  165. continue
  166. try:
  167. await self._handle_incoming_line(line_str)
  168. except Exception:
  169. self.log.exception("Failed to handle incoming request %s", line_str)
  170. def _create_request(
  171. self, command: str, req_id: UUID | None = None, **data: Any
  172. ) -> tuple[asyncio.Future, dict[str, Any]]:
  173. req_id = req_id or uuid4()
  174. req = {"id": str(req_id), "type": command, **data}
  175. self.log.debug("Request %s: %s", req_id, command)
  176. self.log.trace("Request %s: %s with data: %s", req_id, command, data)
  177. return self._wait_response(req_id), req
  178. def _wait_response(self, req_id: UUID) -> asyncio.Future:
  179. try:
  180. future = self._response_waiters[req_id]
  181. except KeyError:
  182. future = self._response_waiters[req_id] = self.loop.create_future()
  183. return future
  184. async def _abandon_responses(self, unused_data: dict[str, Any]) -> None:
  185. for req_id, waiter in self._response_waiters.items():
  186. if not waiter.done():
  187. self.log.trace(f"Abandoning response for {req_id}")
  188. waiter.set_exception(
  189. NotConnected("Disconnected from signald before RPC completed")
  190. )
  191. async def _send_request(self, data: dict[str, Any]) -> None:
  192. if self._writer is None:
  193. raise NotConnected("Not connected to signald")
  194. self._writer.write(json.dumps(data).encode("utf-8"))
  195. self._writer.write(b"\n")
  196. await self._writer.drain()
  197. self.log.trace("Sent data to server server: %s", data)
  198. async def _raw_request(
  199. self, command: str, req_id: UUID | None = None, **data: Any
  200. ) -> tuple[str, dict[str, Any]]:
  201. future, data = self._create_request(command, req_id, **data)
  202. await self._send_request(data)
  203. return await asyncio.shield(future)
  204. async def _request(self, command: str, expected_response: str, **data: Any) -> Any:
  205. resp_type, resp_data = await self._raw_request(command, **data)
  206. if resp_type != expected_response:
  207. raise UnexpectedResponse(resp_type, resp_data)
  208. return resp_data
  209. async def request_v1(self, command: str, **data: Any) -> Any:
  210. return await self._request(command, expected_response=command, version="v1", **data)