소스 검색

Merge pull request #199 from mautrix/sumner/bri-1634-it-takes-forever-for-mautrix-signal-to

login: use wait_for_scan to figure out when the user scans the QR code
Sumner Evans 3 년 전
부모
커밋
9495b7383c
4개의 변경된 파일152개의 추가작업 그리고 48개의 파일을 삭제
  1. 3 0
      mausignald/signald.py
  2. 2 0
      mautrix_signal/config.py
  3. 1 1
      mautrix_signal/example-config.yaml
  4. 146 47
      mautrix_signal/web/provisioning_api.py

+ 3 - 0
mausignald/signald.py

@@ -150,6 +150,9 @@ class SignaldClient(SignaldRPCClient):
     async def start_link(self) -> LinkSession:
         return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
 
+    async def wait_for_scan(self, session_id: str) -> None:
+        await self.request_v1("wait_for_scan", session_id=session_id)
+
     async def finish_link(
         self, session_id: str, device_name: str = "mausignald", overwrite: bool = False
     ) -> Account:

+ 2 - 0
mautrix_signal/config.py

@@ -91,6 +91,8 @@ class Config(BaseBridgeConfig):
 
         copy("bridge.provisioning.enabled")
         copy("bridge.provisioning.prefix")
+        if base["bridge.provisioning.prefix"].endswith("/v1"):
+            base["bridge.provisioning.prefix"] = base["bridge.provisioning.prefix"][: -len("/v1")]
         copy("bridge.provisioning.shared_secret")
         if base["bridge.provisioning.shared_secret"] == "generate":
             base["bridge.provisioning.shared_secret"] = self._new_token()

+ 1 - 1
mautrix_signal/example-config.yaml

@@ -198,7 +198,7 @@ bridge:
         # Whether or not the provisioning API should be enabled.
         enabled: true
         # The prefix to use in the provisioning API endpoints.
-        prefix: /_matrix/provision/v1
+        prefix: /_matrix/provision
         # The shared secret to authorize users of the API.
         # Set to "generate" to generate and save a new token.
         shared_secret: generate

+ 146 - 47
mautrix_signal/web/provisioning_api.py

@@ -25,7 +25,7 @@ from aiohttp import web
 from mausignald.errors import InternalError, TimeoutException
 from mausignald.types import Account, Address
 from mautrix.types import UserID
-from mautrix.util.bridge_state import BridgeStateEvent
+from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
 from mautrix.util.logging import TraceLogger
 
 from .. import user as u
@@ -43,17 +43,30 @@ class ProvisioningAPI:
         self.bridge = bridge
         self.app = web.Application()
         self.shared_secret = shared_secret
-        self.app.router.add_get("/api/whoami", self.status)
-        self.app.router.add_options("/api/link", self.login_options)
-        self.app.router.add_options("/api/link/wait", self.login_options)
-        # self.app.router.add_options("/api/register", self.login_options)
-        # self.app.router.add_options("/api/register/code", self.login_options)
-        self.app.router.add_options("/api/logout", self.login_options)
-        self.app.router.add_post("/api/link", self.link)
-        self.app.router.add_post("/api/link/wait", self.link_wait)
-        # self.app.router.add_post("/api/register", self.register)
-        # self.app.router.add_post("/api/register/code", self.register_code)
-        self.app.router.add_post("/api/logout", self.logout)
+
+        # Whoami
+        self.app.router.add_get("/v1/api/whoami", self.status)
+        self.app.router.add_get("/v2/whoami", self.status)
+
+        # Logout
+        self.app.router.add_options("/v1/api/logout", self.login_options)
+        self.app.router.add_post("/v1/api/logout", self.logout)
+        self.app.router.add_options("/v2/logout", self.login_options)
+        self.app.router.add_post("/v2/logout", self.logout)
+
+        # Link API (will be deprecated soon)
+        self.app.router.add_options("/v1/api/link", self.login_options)
+        self.app.router.add_options("/v1/api/link/wait", self.login_options)
+        self.app.router.add_post("/v1/api/link", self.link)
+        self.app.router.add_post("/v1/api/link/wait", self.link_wait)
+
+        # New Login API
+        self.app.router.add_options("/v2/link/new", self.login_options)
+        self.app.router.add_options("/v2/link/wait/scan", self.login_options)
+        self.app.router.add_options("/v2/link/wait/account", self.login_options)
+        self.app.router.add_post("/v2/link/new", self.link_new)
+        self.app.router.add_post("/v2/link/wait/scan", self.link_wait_for_scan)
+        self.app.router.add_post("/v2/link/wait/account", self.link_wait_for_account)
 
     @property
     def _acao_headers(self) -> dict[str, str]:
@@ -135,63 +148,34 @@ class ProvisioningAPI:
                 }
         return web.json_response(data, headers=self._acao_headers)
 
-    async def link(self, request: web.Request) -> web.Response:
-        user = await self.check_token(request)
-
-        if await user.is_logged_in():
-            raise web.HTTPConflict(
-                text="""{"error": "You're already logged in"}""", headers=self._headers
-            )
-
-        try:
-            data = await request.json()
-        except json.JSONDecodeError:
-            raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
-
-        device_name = data.get("device_name", "Mautrix-Signal bridge")
-        sess = await self.bridge.signal.start_link()
-
-        user.command_status = {
-            "action": "Link",
-            "session_id": sess.session_id,
-            "device_name": device_name,
-        }
-
-        self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
-        return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
-
     async def _shielded_link(self, user: "u.User", session_id: str, device_name: str) -> Account:
         try:
             self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
             account = await self.bridge.signal.finish_link(
-                session_id=session_id, overwrite=True, device_name=device_name
+                session_id=session_id, device_name=device_name, overwrite=True
             )
         except TimeoutException:
             self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
             raise
         except Exception:
             self.log.exception(
-                "Fatal error while waiting for linking to finish (session {session_id})"
+                f"Fatal error while waiting for linking to finish (session {session_id})"
             )
             raise
         else:
             await user.on_signin(account)
             return account
 
-    async def link_wait(self, request: web.Request) -> web.Response:
-        user = await self.check_token(request)
-        if not user.command_status or user.command_status["action"] != "Link":
-            raise web.HTTPBadRequest(
-                text='{"error": "No Signal linking started"}', headers=self._headers
-            )
-        session_id = user.command_status["session_id"]
-        device_name = user.command_status["device_name"]
+    async def _try_shielded_link(
+        self, user: "u.User", session_id: str, device_name: str
+    ) -> web.Response:
         try:
             account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
         except asyncio.CancelledError:
             self.log.warning(
                 f"Client cancelled link wait request ({session_id}) before it finished"
             )
+            raise
         except TimeoutException:
             raise web.HTTPBadRequest(
                 text='{"error": "Signal linking timed out"}', headers=self._headers
@@ -212,6 +196,121 @@ class ProvisioningAPI:
         else:
             return web.json_response(account.address.serialize())
 
+    # region Old Link API
+
+    async def link(self, request: web.Request) -> web.Response:
+        user = await self.check_token(request)
+
+        if await user.is_logged_in():
+            raise web.HTTPConflict(
+                text="""{"error": "You're already logged in"}""", headers=self._headers
+            )
+
+        try:
+            data = await request.json()
+        except json.JSONDecodeError:
+            raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
+
+        device_name = data.get("device_name", "Mautrix-Signal bridge")
+        sess = await self.bridge.signal.start_link()
+
+        user.command_status = {
+            "action": "Link",
+            "session_id": sess.session_id,
+            "device_name": device_name,
+        }
+
+        self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
+        return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
+
+    async def link_wait(self, request: web.Request) -> web.Response:
+        user = await self.check_token(request)
+        if not user.command_status or user.command_status["action"] != "Link":
+            raise web.HTTPBadRequest(
+                text='{"error": "No Signal linking started"}', headers=self._headers
+            )
+        session_id = user.command_status["session_id"]
+        device_name = user.command_status["device_name"]
+        return await self._try_shielded_link(user, session_id, device_name)
+
+    # endregion
+
+    # region New Link API
+
+    async def _get_request_data(self, request: web.Request) -> tuple[u.User, dict]:
+        user = await self.check_token(request)
+        if await user.is_logged_in():
+            error_text = """{"error": "You're already logged in"}"""
+            raise web.HTTPConflict(text=error_text, headers=self._headers)
+
+        try:
+            return user, (await request.json())
+        except json.JSONDecodeError:
+            raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
+
+    async def link_new(self, request: web.Request) -> web.Response:
+        """
+        Starts a new link session.
+
+        Params: none
+
+        Returns a JSON object with the following fields:
+
+        * session_id: a session ID that should be used for all future link-related commands
+          (wait_for_scan and wait_for_account).
+        * uri: a URI that should be used to display the QR code.
+        """
+        user, _ = await self._get_request_data(request)
+        self.log.debug(f"Getting session ID and link URI for {user.mxid}")
+        sess = await self.bridge.signal.start_link()
+        self.log.debug(f"Returning session ID and link URI for {user.mxid} / {sess.session_id}")
+        return web.json_response(sess.serialize(), headers=self._acao_headers)
+
+    async def link_wait_for_scan(self, request: web.Request) -> web.Response:
+        """
+        Waits for the QR code associated with the provided session ID to be scanned.
+
+        Params: a JSON object with the following field:
+
+        * session_id: a session ID that you got from a call to /link/v2/new.
+        """
+        _, request_data = await self._get_request_data(request)
+        try:
+            session_id = request_data["session_id"]
+        except KeyError:
+            error_text = '{"error": "session_id not provided"}'
+            raise web.HTTPBadRequest(text=error_text, headers=self._headers)
+
+        try:
+            await self.bridge.signal.wait_for_scan(session_id)
+        except Exception as e:
+            error_text = f"Failed waiting for scan. Error: {e}"
+            self.log.exception(error_text)
+            raise web.HTTPBadRequest(text=error_text, headers=self._headers)
+        else:
+            return web.json_response({}, headers=self._acao_headers)
+
+    async def link_wait_for_account(self, request: web.Request) -> web.Response:
+        """
+        Waits for the link to the user's phone to complete.
+
+        Params: a JSON object with the following fields:
+
+        * session_id: a session ID that you got from a call to /link/v2/new.
+        * device_name: the device name that will show up in Linked Devices on the user's device.
+
+        Returns: a JSON object representing the user's account.
+        """
+        user, request_data = await self._get_request_data(request)
+        try:
+            session_id = request_data["session_id"]
+            device_name = request_data.get("device_name", "Mautrix-Signal bridge")
+        except KeyError:
+            error_text = '{"error": "session_id not provided"}'
+            raise web.HTTPBadRequest(text=error_text, headers=self._headers)
+
+        return await self._try_shielded_link(user, session_id, device_name)
+
     async def logout(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         if not await user.is_logged_in():