瀏覽代碼

Add asyncio shields around linking and other RPC requests

Tulir Asokan 4 年之前
父節點
當前提交
4f7d14b688
共有 2 個文件被更改,包括 19 次插入6 次删除
  1. 1 1
      mausignald/rpc.py
  2. 18 5
      mautrix_signal/web/provisioning_api.py

+ 1 - 1
mausignald/rpc.py

@@ -198,7 +198,7 @@ class SignaldRPCClient:
                            ) -> Tuple[str, Dict[str, Any]]:
         future, data = self._create_request(command, req_id, **data)
         await self._send_request(data)
-        return await future
+        return await asyncio.shield(future)
 
     async def request(self, command: str, expected_response: str, **data: Any) -> Any:
         resp_type, resp_data = await self._raw_request(command, **data)

+ 18 - 5
mautrix_signal/web/provisioning_api.py

@@ -15,11 +15,12 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from typing import Awaitable, Dict, TYPE_CHECKING
 import logging
+import asyncio
 import json
 
 from aiohttp import web
 
-from mausignald.types import Address
+from mausignald.types import Address, Account
 from mausignald.errors import TimeoutException
 from mautrix.types import UserID
 from mautrix.util.logging import TraceLogger
@@ -128,6 +129,19 @@ class ProvisioningAPI:
 
         return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
 
+    async def _shielded_link(self, user: 'u.User', session_id: str, device_name: str) -> Account:
+        try:
+            account = await self.bridge.signal.finish_link(session_id=session_id,
+                                                           device_name=device_name)
+        except TimeoutException:
+            raise
+        except Exception:
+            self.log.exception("Fatal error while waiting for linking to finish")
+            raise
+        else:
+            await user.on_signin(account)
+            return account
+
     async def link_wait(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         if not user.command_status or user.command_status["action"] != "Link":
@@ -136,17 +150,16 @@ class ProvisioningAPI:
         session_id = user.command_status["session_id"]
         device_name = user.command_status["device_name"]
         try:
-            account = await self.bridge.signal.finish_link(session_id=session_id,
-                                                           device_name=device_name)
+            account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
+        except asyncio.CancelledError:
+            self.log.warning("Client cancelled link wait request before it finished")
         except TimeoutException:
             raise web.HTTPBadRequest(text='{"error": "Signal linking timed out"}',
                                      headers=self._headers)
         except Exception:
-            self.log.exception("Fatal error while waiting for linking to finish")
             raise web.HTTPInternalServerError(text='{"error": "Fatal error in Signal linking"}',
                                               headers=self._headers)
         else:
-            await user.on_signin(account)
             return web.json_response(account.address.serialize())
 
     async def logout(self, request: web.Request) -> web.Response: