فهرست منبع

provisioning: split wait_for_scan to a separate call

Sumner Evans 3 سال پیش
والد
کامیت
73d62e9cc0
2فایلهای تغییر یافته به همراه24 افزوده شده و 14 حذف شده
  1. 0 2
      mautrix_signal/user.py
  2. 24 12
      mautrix_signal/web/provisioning_api.py

+ 0 - 2
mautrix_signal/user.py

@@ -155,8 +155,6 @@ class User(DBUser, BaseUser):
         self.username = account.account_id
         self.uuid = account.address.uuid
         self._add_to_cache()
-        # Push a remote state immediately so that the client knows that it's doing something.
-        asyncio.create_task(self.push_bridge_state(BridgeStateEvent.CONNECTING))
         await self.update()
         await self.bridge.signal.subscribe(self.username)
         asyncio.create_task(self.sync())

+ 24 - 12
mautrix_signal/web/provisioning_api.py

@@ -50,6 +50,7 @@ class ProvisioningAPI:
         # self.app.router.add_options("/api/register/code", self.login_options)
         self.app.router.add_options("/api/logout", self.login_options)
         self.app.router.add_post("/api/link", self.link)
+        self.app.router.add_post("/api/link/wait_for_scan", self.link_wait_for_scan)
         self.app.router.add_post("/api/link/wait", self.link_wait)
         # self.app.router.add_post("/api/register", self.register)
         # self.app.router.add_post("/api/register/code", self.register_code)
@@ -160,6 +161,28 @@ class ProvisioningAPI:
         self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
         return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
 
+    async def link_wait_for_scan(self, request: web.Request) -> web.Response:
+        user = await self.check_token(request)
+        if not user.command_status or user.command_status["action"] != "Link":
+            raise web.HTTPBadRequest(
+                text='{"error": "No Signal linking started"}', headers=self._headers
+            )
+        session_id = user.command_status["session_id"]
+        try:
+            await self.bridge.signal.wait_for_scan(session_id)
+            status_endpoint = self.bridge.config["homeserver.status_endpoint"]
+            if status_endpoint:
+                state = BridgeState(state_event=BridgeStateEvent.CONNECTING).fill()
+                asyncio.create_task(state.send(status_endpoint, self.bridge.az.as_token, self.log))
+        except Exception as e:
+            self.log.exception(f"Failed waiting for scan. Error: {e}")
+            self.log.info(e.__class__)
+            raise web.HTTPBadRequest(
+                text='{"error": "Failed to wait for scan"}', headers=self._headers
+            )
+        else:
+            return web.json_response()
+
     async def _shielded_link(self, user: "u.User", session_id: str, device_name: str) -> Account:
         try:
             self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
@@ -171,7 +194,7 @@ class ProvisioningAPI:
             raise
         except Exception:
             self.log.exception(
-                "Fatal error while waiting for linking to finish (session {session_id})"
+                f"Fatal error while waiting for linking to finish (session {session_id})"
             )
             raise
         else:
@@ -186,17 +209,6 @@ class ProvisioningAPI:
             )
         session_id = user.command_status["session_id"]
         device_name = user.command_status["device_name"]
-        try:
-            await self.bridge.signal.wait_for_scan(session_id)
-            status_endpoint = self.bridge.config["homeserver.status_endpoint"]
-            if status_endpoint:
-                state = BridgeState(state_event=BridgeStateEvent.CONNECTING).fill()
-                asyncio.create_task(state.send(status_endpoint, self.bridge.az.as_token, self.log))
-        except Exception as e:
-            self.log.exception(f"Failed waiting for scan. Error: {e}")
-            raise web.HTTPBadRequest(
-                text='{"error": "Failed to wait for scan"}', headers=self._headers
-            )
 
         try:
             account = await asyncio.shield(self._shielded_link(user, session_id, device_name))