Эх сурвалжийг харах

Merge pull request #199 from mautrix/sumner/bri-1634-it-takes-forever-for-mautrix-signal-to

login: use wait_for_scan to figure out when the user scans the QR code
Sumner Evans 3 жил өмнө
parent
commit
9495b7383c

+ 3 - 0
mausignald/signald.py

@@ -150,6 +150,9 @@ class SignaldClient(SignaldRPCClient):
     async def start_link(self) -> LinkSession:
     async def start_link(self) -> LinkSession:
         return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
         return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
 
 
+    async def wait_for_scan(self, session_id: str) -> None:
+        await self.request_v1("wait_for_scan", session_id=session_id)
+
     async def finish_link(
     async def finish_link(
         self, session_id: str, device_name: str = "mausignald", overwrite: bool = False
         self, session_id: str, device_name: str = "mausignald", overwrite: bool = False
     ) -> Account:
     ) -> Account:

+ 2 - 0
mautrix_signal/config.py

@@ -91,6 +91,8 @@ class Config(BaseBridgeConfig):
 
 
         copy("bridge.provisioning.enabled")
         copy("bridge.provisioning.enabled")
         copy("bridge.provisioning.prefix")
         copy("bridge.provisioning.prefix")
+        if base["bridge.provisioning.prefix"].endswith("/v1"):
+            base["bridge.provisioning.prefix"] = base["bridge.provisioning.prefix"][: -len("/v1")]
         copy("bridge.provisioning.shared_secret")
         copy("bridge.provisioning.shared_secret")
         if base["bridge.provisioning.shared_secret"] == "generate":
         if base["bridge.provisioning.shared_secret"] == "generate":
             base["bridge.provisioning.shared_secret"] = self._new_token()
             base["bridge.provisioning.shared_secret"] = self._new_token()

+ 1 - 1
mautrix_signal/example-config.yaml

@@ -198,7 +198,7 @@ bridge:
         # Whether or not the provisioning API should be enabled.
         # Whether or not the provisioning API should be enabled.
         enabled: true
         enabled: true
         # The prefix to use in the provisioning API endpoints.
         # The prefix to use in the provisioning API endpoints.
-        prefix: /_matrix/provision/v1
+        prefix: /_matrix/provision
         # The shared secret to authorize users of the API.
         # The shared secret to authorize users of the API.
         # Set to "generate" to generate and save a new token.
         # Set to "generate" to generate and save a new token.
         shared_secret: generate
         shared_secret: generate

+ 146 - 47
mautrix_signal/web/provisioning_api.py

@@ -25,7 +25,7 @@ from aiohttp import web
 from mausignald.errors import InternalError, TimeoutException
 from mausignald.errors import InternalError, TimeoutException
 from mausignald.types import Account, Address
 from mausignald.types import Account, Address
 from mautrix.types import UserID
 from mautrix.types import UserID
-from mautrix.util.bridge_state import BridgeStateEvent
+from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
 from .. import user as u
 from .. import user as u
@@ -43,17 +43,30 @@ class ProvisioningAPI:
         self.bridge = bridge
         self.bridge = bridge
         self.app = web.Application()
         self.app = web.Application()
         self.shared_secret = shared_secret
         self.shared_secret = shared_secret
-        self.app.router.add_get("/api/whoami", self.status)
-        self.app.router.add_options("/api/link", self.login_options)
-        self.app.router.add_options("/api/link/wait", self.login_options)
-        # self.app.router.add_options("/api/register", self.login_options)
-        # self.app.router.add_options("/api/register/code", self.login_options)
-        self.app.router.add_options("/api/logout", self.login_options)
-        self.app.router.add_post("/api/link", self.link)
-        self.app.router.add_post("/api/link/wait", self.link_wait)
-        # self.app.router.add_post("/api/register", self.register)
-        # self.app.router.add_post("/api/register/code", self.register_code)
-        self.app.router.add_post("/api/logout", self.logout)
+
+        # Whoami
+        self.app.router.add_get("/v1/api/whoami", self.status)
+        self.app.router.add_get("/v2/whoami", self.status)
+
+        # Logout
+        self.app.router.add_options("/v1/api/logout", self.login_options)
+        self.app.router.add_post("/v1/api/logout", self.logout)
+        self.app.router.add_options("/v2/logout", self.login_options)
+        self.app.router.add_post("/v2/logout", self.logout)
+
+        # Link API (will be deprecated soon)
+        self.app.router.add_options("/v1/api/link", self.login_options)
+        self.app.router.add_options("/v1/api/link/wait", self.login_options)
+        self.app.router.add_post("/v1/api/link", self.link)
+        self.app.router.add_post("/v1/api/link/wait", self.link_wait)
+
+        # New Login API
+        self.app.router.add_options("/v2/link/new", self.login_options)
+        self.app.router.add_options("/v2/link/wait/scan", self.login_options)
+        self.app.router.add_options("/v2/link/wait/account", self.login_options)
+        self.app.router.add_post("/v2/link/new", self.link_new)
+        self.app.router.add_post("/v2/link/wait/scan", self.link_wait_for_scan)
+        self.app.router.add_post("/v2/link/wait/account", self.link_wait_for_account)
 
 
     @property
     @property
     def _acao_headers(self) -> dict[str, str]:
     def _acao_headers(self) -> dict[str, str]:
@@ -135,63 +148,34 @@ class ProvisioningAPI:
                 }
                 }
         return web.json_response(data, headers=self._acao_headers)
         return web.json_response(data, headers=self._acao_headers)
 
 
-    async def link(self, request: web.Request) -> web.Response:
-        user = await self.check_token(request)
-
-        if await user.is_logged_in():
-            raise web.HTTPConflict(
-                text="""{"error": "You're already logged in"}""", headers=self._headers
-            )
-
-        try:
-            data = await request.json()
-        except json.JSONDecodeError:
-            raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
-
-        device_name = data.get("device_name", "Mautrix-Signal bridge")
-        sess = await self.bridge.signal.start_link()
-
-        user.command_status = {
-            "action": "Link",
-            "session_id": sess.session_id,
-            "device_name": device_name,
-        }
-
-        self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
-        return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
-
     async def _shielded_link(self, user: "u.User", session_id: str, device_name: str) -> Account:
     async def _shielded_link(self, user: "u.User", session_id: str, device_name: str) -> Account:
         try:
         try:
             self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
             self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
             account = await self.bridge.signal.finish_link(
             account = await self.bridge.signal.finish_link(
-                session_id=session_id, overwrite=True, device_name=device_name
+                session_id=session_id, device_name=device_name, overwrite=True
             )
             )
         except TimeoutException:
         except TimeoutException:
             self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
             self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
             raise
             raise
         except Exception:
         except Exception:
             self.log.exception(
             self.log.exception(
-                "Fatal error while waiting for linking to finish (session {session_id})"
+                f"Fatal error while waiting for linking to finish (session {session_id})"
             )
             )
             raise
             raise
         else:
         else:
             await user.on_signin(account)
             await user.on_signin(account)
             return account
             return account
 
 
-    async def link_wait(self, request: web.Request) -> web.Response:
-        user = await self.check_token(request)
-        if not user.command_status or user.command_status["action"] != "Link":
-            raise web.HTTPBadRequest(
-                text='{"error": "No Signal linking started"}', headers=self._headers
-            )
-        session_id = user.command_status["session_id"]
-        device_name = user.command_status["device_name"]
+    async def _try_shielded_link(
+        self, user: "u.User", session_id: str, device_name: str
+    ) -> web.Response:
         try:
         try:
             account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
             account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
         except asyncio.CancelledError:
         except asyncio.CancelledError:
             self.log.warning(
             self.log.warning(
                 f"Client cancelled link wait request ({session_id}) before it finished"
                 f"Client cancelled link wait request ({session_id}) before it finished"
             )
             )
+            raise
         except TimeoutException:
         except TimeoutException:
             raise web.HTTPBadRequest(
             raise web.HTTPBadRequest(
                 text='{"error": "Signal linking timed out"}', headers=self._headers
                 text='{"error": "Signal linking timed out"}', headers=self._headers
@@ -212,6 +196,121 @@ class ProvisioningAPI:
         else:
         else:
             return web.json_response(account.address.serialize())
             return web.json_response(account.address.serialize())
 
 
+    # region Old Link API
+
+    async def link(self, request: web.Request) -> web.Response:
+        user = await self.check_token(request)
+
+        if await user.is_logged_in():
+            raise web.HTTPConflict(
+                text="""{"error": "You're already logged in"}""", headers=self._headers
+            )
+
+        try:
+            data = await request.json()
+        except json.JSONDecodeError:
+            raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
+
+        device_name = data.get("device_name", "Mautrix-Signal bridge")
+        sess = await self.bridge.signal.start_link()
+
+        user.command_status = {
+            "action": "Link",
+            "session_id": sess.session_id,
+            "device_name": device_name,
+        }
+
+        self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
+        return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
+
+    async def link_wait(self, request: web.Request) -> web.Response:
+        user = await self.check_token(request)
+        if not user.command_status or user.command_status["action"] != "Link":
+            raise web.HTTPBadRequest(
+                text='{"error": "No Signal linking started"}', headers=self._headers
+            )
+        session_id = user.command_status["session_id"]
+        device_name = user.command_status["device_name"]
+        return await self._try_shielded_link(user, session_id, device_name)
+
+    # endregion
+
+    # region New Link API
+
+    async def _get_request_data(self, request: web.Request) -> tuple[u.User, dict]:
+        user = await self.check_token(request)
+        if await user.is_logged_in():
+            error_text = """{"error": "You're already logged in"}"""
+            raise web.HTTPConflict(text=error_text, headers=self._headers)
+
+        try:
+            return user, (await request.json())
+        except json.JSONDecodeError:
+            raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
+
+    async def link_new(self, request: web.Request) -> web.Response:
+        """
+        Starts a new link session.
+
+        Params: none
+
+        Returns a JSON object with the following fields:
+
+        * session_id: a session ID that should be used for all future link-related commands
+          (wait_for_scan and wait_for_account).
+        * uri: a URI that should be used to display the QR code.
+        """
+        user, _ = await self._get_request_data(request)
+        self.log.debug(f"Getting session ID and link URI for {user.mxid}")
+        sess = await self.bridge.signal.start_link()
+        self.log.debug(f"Returning session ID and link URI for {user.mxid} / {sess.session_id}")
+        return web.json_response(sess.serialize(), headers=self._acao_headers)
+
+    async def link_wait_for_scan(self, request: web.Request) -> web.Response:
+        """
+        Waits for the QR code associated with the provided session ID to be scanned.
+
+        Params: a JSON object with the following field:
+
+        * session_id: a session ID that you got from a call to /link/v2/new.
+        """
+        _, request_data = await self._get_request_data(request)
+        try:
+            session_id = request_data["session_id"]
+        except KeyError:
+            error_text = '{"error": "session_id not provided"}'
+            raise web.HTTPBadRequest(text=error_text, headers=self._headers)
+
+        try:
+            await self.bridge.signal.wait_for_scan(session_id)
+        except Exception as e:
+            error_text = f"Failed waiting for scan. Error: {e}"
+            self.log.exception(error_text)
+            raise web.HTTPBadRequest(text=error_text, headers=self._headers)
+        else:
+            return web.json_response({}, headers=self._acao_headers)
+
+    async def link_wait_for_account(self, request: web.Request) -> web.Response:
+        """
+        Waits for the link to the user's phone to complete.
+
+        Params: a JSON object with the following fields:
+
+        * session_id: a session ID that you got from a call to /link/v2/new.
+        * device_name: the device name that will show up in Linked Devices on the user's device.
+
+        Returns: a JSON object representing the user's account.
+        """
+        user, request_data = await self._get_request_data(request)
+        try:
+            session_id = request_data["session_id"]
+            device_name = request_data.get("device_name", "Mautrix-Signal bridge")
+        except KeyError:
+            error_text = '{"error": "session_id not provided"}'
+            raise web.HTTPBadRequest(text=error_text, headers=self._headers)
+
+        return await self._try_shielded_link(user, session_id, device_name)
+
     async def logout(self, request: web.Request) -> web.Response:
     async def logout(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         user = await self.check_token(request)
         if not await user.is_logged_in():
         if not await user.is_logged_in():