123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- # Copyright (c) 2020 Tulir Asokan
- #
- # This Source Code Form is subject to the terms of the Mozilla Public
- # License, v. 2.0. If a copy of the MPL was not distributed with this
- # file, You can obtain one at http://mozilla.org/MPL/2.0/.
- from typing import Optional, Dict, List, Callable, Awaitable, Any, Tuple
- from uuid import UUID, uuid4
- import asyncio
- import logging
- import json
- from mautrix.util.logging import TraceLogger
- from .errors import UnexpectedError, UnexpectedResponse
# Handlers registered via SignaldRPCClient.add_rpc_handler are awaited with the
# decoded JSON object of each incoming message that carries no "id".
_RPCMessage = Dict[str, Any]
EventHandler = Callable[[_RPCMessage], Awaitable[None]]
class SignaldRPCClient:
    """Client for signald's line-delimited JSON protocol over a unix socket.

    Every outgoing request is written as one JSON object per line carrying an ``id``
    (a fresh UUID4 unless one is supplied) and a ``type``. Incoming lines that carry an
    ``id`` are responses: they resolve the future registered for that id with
    ``(type, data)``, or fail it with :class:`UnexpectedError` when the type is
    ``unexpected_error``. Lines without an ``id`` are dispatched by ``type`` to the
    handlers registered with :meth:`add_rpc_handler`.
    """

    loop: asyncio.AbstractEventLoop
    log: TraceLogger
    socket_path: str
    # Both are None while disconnected.
    _reader: Optional[asyncio.StreamReader]
    _writer: Optional[asyncio.StreamWriter]
    # Pending response futures, keyed by request id.
    _response_waiters: Dict[UUID, asyncio.Future]
    # Handlers for id-less incoming messages, keyed by message type.
    _rpc_event_handlers: Dict[str, List[EventHandler]]

    def __init__(self, socket_path: str, log: Optional[TraceLogger] = None,
                 loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        """
        Args:
            socket_path: Path of the unix socket to connect to.
            log: Logger to use; defaults to the ``mausignald`` logger.
            loop: Event loop used for response futures and the read task; defaults to
                ``asyncio.get_event_loop()``.
        """
        self.socket_path = socket_path
        self.log = log or logging.getLogger("mausignald")
        self.loop = loop or asyncio.get_event_loop()
        self._reader = None
        self._writer = None
        self._response_waiters = {}
        self._rpc_event_handlers = {}

    async def connect(self) -> None:
        """Open the unix socket and start the background read loop.

        Does nothing if a connection is already open.
        """
        if self._writer is not None:
            return
        self._reader, self._writer = await asyncio.open_unix_connection(self.socket_path)
        # Bind the loop to this specific reader so that, after a reconnect, a stale
        # loop can never start consuming the new connection's stream.
        self.loop.create_task(self._try_read_loop(self._reader))

    async def disconnect(self) -> None:
        """Close the connection and fail every request still awaiting a response.

        Pending requests fail with :class:`ConnectionError`. Does nothing if not
        connected.
        """
        writer = self._writer
        if writer is None:
            return
        self._writer = None
        self._reader = None
        # The old read loop no longer owns the client state and will not resolve
        # these futures, so fail them now instead of letting callers hang forever.
        self._fail_waiters("Disconnected from signald")
        try:
            writer.write_eof()
            await writer.drain()
        finally:
            # Always release the socket, even if EOF/drain failed. Closing also makes
            # the old read loop see EOF and stop.
            writer.close()

    def add_rpc_handler(self, method: str, handler: EventHandler) -> None:
        """Register *handler* for incoming id-less messages whose type is *method*."""
        self._rpc_event_handlers.setdefault(method, []).append(handler)

    def remove_rpc_handler(self, method: str, handler: EventHandler) -> None:
        """Unregister *handler* for *method*.

        Raises:
            ValueError: if *handler* is not registered for *method*.
        """
        self._rpc_event_handlers.setdefault(method, []).remove(handler)

    async def _run_rpc_handler(self, command: str, req: Dict[str, Any]) -> None:
        """Await every handler registered for *command* with the whole message.

        An exception in one handler is logged and does not stop the others.
        """
        try:
            handlers = self._rpc_event_handlers[command]
        except KeyError:
            self.log.warning("No handlers for RPC request %s", command)
            self.log.trace("Data unhandled request: %s", req)
        else:
            for handler in handlers:
                try:
                    await handler(req)
                except Exception:
                    self.log.exception("Exception in RPC event handler")

    async def _run_response_handlers(self, req_id: UUID, command: str, data: Any) -> None:
        """Resolve the future waiting for the response to *req_id*.

        ``unexpected_error`` responses fail the future with :class:`UnexpectedError`.
        Any other response type resolves it with ``(command, data)``.
        """
        try:
            waiter = self._response_waiters.pop(req_id)
        except KeyError:
            self.log.debug(f"Nobody waiting for response to {req_id}")
            return
        if waiter.done():
            # The caller already gave up (e.g. it was cancelled). Setting a result now
            # would raise InvalidStateError.
            self.log.debug(f"Waiter for response to {req_id} was already done")
            return
        if command == "unexpected_error":
            try:
                waiter.set_exception(UnexpectedError(data["message"]))
            except (KeyError, TypeError):
                # TypeError covers a response without "data" (data is None). The waiter
                # is already popped, so it must still be failed or the caller hangs.
                waiter.set_exception(UnexpectedError("Unexpected error with no message"))
        else:
            waiter.set_result((command, data))

    async def _handle_incoming_line(self, line: str) -> None:
        """Decode one incoming JSON line and route it as a response or an RPC message."""
        try:
            req = json.loads(line)
        except json.JSONDecodeError:
            self.log.debug(f"Got non-JSON data from server: {line}")
            return
        try:
            req_type = req["type"]
        except (KeyError, TypeError):
            # TypeError: valid JSON that is not an object (list, string, number, ...).
            self.log.debug(f"Got invalid request from server: {line}")
            return
        self.log.trace("Got data from server: %s", req)
        req_id = req.get("id")
        if req_id is None:
            await self._run_rpc_handler(req_type, req)
        else:
            await self._run_response_handlers(UUID(req_id), req_type, req.get("data"))

    async def _try_read_loop(self, reader: Optional[asyncio.StreamReader] = None) -> None:
        """Run the read loop for one connection, log crashes, and clean up afterwards.

        Args:
            reader: The connection's reader; defaults to the current one.
        """
        if reader is None:
            reader = self._reader
        try:
            await self._read_loop(reader)
        except Exception:
            self.log.exception("Fatal error in read loop")
        finally:
            # Runs on normal EOF, on crashes and on cancellation alike.
            self._connection_lost(reader)

    async def _read_loop(self, reader: Optional[asyncio.StreamReader] = None) -> None:
        """Read and dispatch lines until EOF or until *reader* stops being current.

        Args:
            reader: The reader to consume; defaults to the current one.
        """
        if reader is None:
            reader = self._reader
        while reader is not None and self._reader is reader and not reader.at_eof():
            line = await reader.readline()
            if not line:
                continue
            try:
                line_str = line.decode("utf-8")
            except UnicodeDecodeError:
                self.log.exception("Got non-unicode request from server: %s", line)
                continue
            try:
                await self._handle_incoming_line(line_str)
            except Exception:
                self.log.exception("Failed to handle incoming request %s", line_str)
        self.log.debug("Reader disconnected")

    def _connection_lost(self, reader: Optional[asyncio.StreamReader]) -> None:
        """Tear down client state after the read loop for *reader* has ended.

        Does nothing unless *reader* is still the current connection. After
        :meth:`disconnect`, and possibly a new :meth:`connect`, the state is no longer
        this loop's to touch.
        """
        if reader is None or self._reader is not reader:
            return
        writer = self._writer
        self._reader = None
        self._writer = None
        if writer is not None:
            # The read side is finished; release our side of the socket as well.
            writer.close()
        self._fail_waiters("Connection to signald was lost")

    def _fail_waiters(self, message: str) -> None:
        """Fail every pending response future with ``ConnectionError(message)``."""
        waiters = list(self._response_waiters.values())
        self._response_waiters.clear()
        for waiter in waiters:
            if not waiter.done():
                waiter.set_exception(ConnectionError(message))

    def _drop_waiter(self, req_id: UUID, future: asyncio.Future) -> None:
        """Stop tracking *future* if it is still the registered waiter for *req_id*."""
        if self._response_waiters.get(req_id) is future:
            del self._response_waiters[req_id]

    def _create_request(self, command: str, req_id: Optional[UUID] = None, **data: Any
                        ) -> Tuple[asyncio.Future, Dict[str, Any]]:
        """Build a request dict and register a future for its response.

        Returns:
            ``(response future, request dict)``. The request dict is
            ``{"id": str(req_id), "type": command, **data}``.
        """
        req_id = req_id or uuid4()
        req = {"id": str(req_id), "type": command, **data}
        self.log.trace("Request %s: %s %s", req_id, command, data)
        return self._wait_response(req_id), req

    def _wait_response(self, req_id: UUID) -> asyncio.Future:
        """Return the future for *req_id*'s response, creating it if none is pending."""
        try:
            future = self._response_waiters[req_id]
        except KeyError:
            future = self._response_waiters[req_id] = self.loop.create_future()
        return future

    async def _send_request(self, data: Dict[str, Any]) -> None:
        """Write *data* as a single JSON line and wait for the write buffer to drain.

        Raises:
            ConnectionError: if the client is not connected.
        """
        writer = self._writer
        if writer is None:
            raise ConnectionError("Not connected to signald")
        writer.write(json.dumps(data).encode("utf-8") + b"\n")
        await writer.drain()
        self.log.trace("Sent data to server: %s", data)

    async def _raw_request(self, command: str, req_id: Optional[UUID] = None, **data: Any
                           ) -> Tuple[str, Dict[str, Any]]:
        """Send a request and wait for the response carrying the same id.

        Returns:
            ``(response type, response data)``.

        Raises:
            UnexpectedError: if signald answered with ``unexpected_error``.
            ConnectionError: if not connected, or if the connection ends before the
                response arrives.
        """
        future, req = self._create_request(command, req_id, **data)
        try:
            await self._send_request(req)
            return await future
        except BaseException:
            # Sending failed or the caller was cancelled. Nobody will consume this
            # future any more, so stop tracking it instead of leaking it.
            self._drop_waiter(UUID(req["id"]), future)
            raise

    async def request(self, command: str, expected_response: str, **data: Any) -> Any:
        """Send *command* and return the response data.

        Raises:
            UnexpectedResponse: if the response type is not *expected_response*.
        """
        resp_type, resp_data = await self._raw_request(command, **data)
        if resp_type != expected_response:
            raise UnexpectedResponse(resp_type, resp_data)
        return resp_data

    async def request_nowait(self, command: str, **data: Any) -> None:
        """Send *command* without waiting for its response.

        A failed response is logged instead of being raised.
        """
        future, req = self._create_request(command, **data)
        # Nobody awaits this future. Retrieve its outcome so an error response does
        # not surface as asyncio's "Future exception was never retrieved".
        future.add_done_callback(self._log_nowait_failure)
        try:
            await self._send_request(req)
        except BaseException:
            self._drop_waiter(UUID(req["id"]), future)
            raise

    def _log_nowait_failure(self, future: asyncio.Future) -> None:
        """Done-callback for :meth:`request_nowait` futures: retrieve and log errors."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            self.log.warning("Request sent without waiting for a response failed: %s", exc)
|