Browse Source

Don't respond to provisioning API requests until signald is connected

Tulir Asokan 4 years ago
parent
commit
bb41dbb46c
3 changed files with 24 additions and 12 deletions
  1. 19 8
      mausignald/rpc.py
  2. 0 2
      mautrix_signal/signal.py
  3. 5 2
      mautrix_signal/web/provisioning_api.py

+ 19 - 8
mausignald/rpc.py

@@ -28,6 +28,8 @@ class SignaldRPCClient:
     socket_path: str
     _reader: Optional[asyncio.StreamReader]
     _writer: Optional[asyncio.StreamWriter]
+    is_connected: bool
+    _connect_future: asyncio.Future
     _communicate_task: Optional[asyncio.Task]
 
     _response_waiters: Dict[UUID, asyncio.Future]
@@ -41,19 +43,26 @@ class SignaldRPCClient:
         self._reader = None
         self._writer = None
         self._communicate_task = None
+        self.is_connected = False
+        self._connect_future = self.loop.create_future()
         self._response_waiters = {}
         self._rpc_event_handlers = {CONNECT_EVENT: [], DISCONNECT_EVENT: []}
         self.add_rpc_handler(DISCONNECT_EVENT, self._abandon_responses)
 
+    async def wait_for_connected(self, timeout: Optional[int] = None) -> bool:
+        if self.is_connected:
+            return True
+        await asyncio.wait_for(asyncio.shield(self._connect_future), timeout)
+        return self.is_connected
+
     async def connect(self) -> None:
         if self._writer is not None:
             return
 
-        initial_connect = self.loop.create_future()
-        self._communicate_task = asyncio.create_task(self._communicate_forever(initial_connect))
-        await initial_connect
+        self._communicate_task = asyncio.create_task(self._communicate_forever())
+        await self._connect_future
 
-    async def _communicate_forever(self, initial_connect: Optional[asyncio.Future] = None) -> None:
+    async def _communicate_forever(self) -> None:
         while True:
             try:
                 self._reader, self._writer = await asyncio.open_unix_connection(self.socket_path)
@@ -63,14 +72,14 @@ class SignaldRPCClient:
                 continue
 
             read_loop = asyncio.create_task(self._try_read_loop())
+            self.is_connected = True
             await self._run_rpc_handler(CONNECT_EVENT, {})
-
-            if initial_connect:
-                initial_connect.set_result(True)
-                initial_connect = None
+            self._connect_future.set_result(True)
 
             await read_loop
+            self.is_connected = False
             await self._run_rpc_handler(DISCONNECT_EVENT, {})
+            self._connect_future = self.loop.create_future()
 
     async def disconnect(self) -> None:
         if self._writer is not None:
@@ -81,6 +90,8 @@ class SignaldRPCClient:
                 self._communicate_task = None
             self._writer = None
             self._reader = None
+            self.is_connected = False
+            self._connect_future = self.loop.create_future()
 
     def add_rpc_handler(self, method: str, handler: EventHandler) -> None:
         self._rpc_event_handlers.setdefault(method, []).append(handler)

+ 0 - 2
mautrix_signal/signal.py

@@ -16,8 +16,6 @@
 from typing import Optional, List, TYPE_CHECKING
 import asyncio
 import logging
-import os.path
-import shutil
 
 from mausignald import SignaldClient
 from mausignald.types import (Message, MessageData, Address, TypingNotification, TypingAction,

+ 5 - 2
mautrix_signal/web/provisioning_api.py

@@ -70,7 +70,7 @@ class ProvisioningAPI:
     async def login_options(self, _: web.Request) -> web.Response:
         return web.Response(status=200, headers=self._headers)
 
-    def check_token(self, request: web.Request) -> Awaitable['u.User']:
+    async def check_token(self, request: web.Request) -> 'u.User':
         try:
             token = request.headers["Authorization"]
             token = token[len("Bearer "):]
@@ -88,7 +88,10 @@ class ProvisioningAPI:
             raise web.HTTPBadRequest(text='{"error": "Missing user_id query param"}',
                                      headers=self._headers)
 
-        return u.User.get_by_mxid(UserID(user_id))
+        if not self.bridge.signal.is_connected:
+            await self.bridge.signal.wait_for_connected()
+
+        return await u.User.get_by_mxid(UserID(user_id))
 
     async def status(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)