Parcourir la source

Add overwrite flag when linking device

When combined with https://gitlab.com/signald/signald/-/merge_requests/62, this
should allow linking accounts even if there's an existing failed link.
Tulir Asokan il y a 4 ans
Parent
commit
61f7d663dc

+ 4 - 2
mausignald/signald.py

@@ -111,8 +111,10 @@ class SignaldClient(SignaldRPCClient):
     async def start_link(self) -> LinkSession:
         return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
 
-    async def finish_link(self, session_id: str, device_name: str = "mausignald") -> Account:
-        resp = await self.request_v1("finish_link", device_name=device_name, session_id=session_id)
+    async def finish_link(self, session_id: str, device_name: str = "mausignald",
+                          overwrite: bool = False) -> Account:
+        resp = await self.request_v1("finish_link", device_name=device_name, session_id=session_id,
+                                     overwrite=overwrite)
         return Account.deserialize(resp)
 
     @staticmethod

+ 9 - 1
mautrix_signal/commands/auth.py

@@ -54,6 +54,10 @@ async def link(evt: CommandEvent) -> None:
     if qrcode is None:
         await evt.reply("Can't generate QR code: qrcode and/or PIL not installed")
         return
+    if await evt.sender.is_logged_in():
+        await evt.reply("You're already logged in. "
+                        "If you want to relink, log out with `$cmdprefix+sp logout` first.")
+        return
     # TODO make default device name configurable
     device_name = " ".join(evt.args) or "Mautrix-Signal bridge"
 
@@ -61,7 +65,7 @@ async def link(evt: CommandEvent) -> None:
     content = await make_qr(evt.az.intent, sess.uri)
     event_id = await evt.az.intent.send_message(evt.room_id, content)
     try:
-        account = await evt.bridge.signal.finish_link(session_id=sess.session_id,
+        account = await evt.bridge.signal.finish_link(session_id=sess.session_id, overwrite=True,
                                                       device_name=device_name)
     except TimeoutException:
         await evt.reply("Linking timed out, please try again.")
@@ -82,6 +86,10 @@ async def register(evt: CommandEvent) -> None:
     if len(evt.args) == 0:
         await evt.reply("**Usage**: $cmdprefix+sp register [--voice] [--captcha <token>] <phone>")
         return
+    if await evt.sender.is_logged_in():
+        await evt.reply("You're already logged in. "
+                        "If you want to re-register, log out with `$cmdprefix+sp logout` first.")
+        return
     voice = False
     captcha = None
     while True:

+ 13 - 4
mautrix_signal/web/provisioning_api.py

@@ -123,6 +123,10 @@ class ProvisioningAPI:
     async def link(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
 
+        if await user.is_logged_in():
+            raise web.HTTPConflict(text='''{"error": "You're already logged in"}''',
+                                   headers=self._headers)
+
         try:
             data = await request.json()
         except json.JSONDecodeError:
@@ -137,16 +141,20 @@ class ProvisioningAPI:
             "device_name": device_name,
         }
 
+        self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
         return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
 
     async def _shielded_link(self, user: 'u.User', session_id: str, device_name: str) -> Account:
         try:
-            account = await self.bridge.signal.finish_link(session_id=session_id,
+            self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
+            account = await self.bridge.signal.finish_link(session_id=session_id, overwrite=True,
                                                            device_name=device_name)
         except TimeoutException:
+            self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
             raise
         except Exception:
-            self.log.exception("Fatal error while waiting for linking to finish")
+            self.log.exception("Fatal error while waiting for linking to finish "
+                               f"(session {session_id})")
             raise
         else:
             await user.on_signin(account)
@@ -162,7 +170,8 @@ class ProvisioningAPI:
         try:
             account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
         except asyncio.CancelledError:
-            self.log.warning("Client cancelled link wait request before it finished")
+            self.log.warning(f"Client cancelled link wait request ({session_id})"
+                             " before it finished")
         except TimeoutException:
             raise web.HTTPBadRequest(text='{"error": "Signal linking timed out"}',
                                      headers=self._headers)
@@ -174,7 +183,7 @@ class ProvisioningAPI:
 
     async def logout(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
-        if not user.username:
+        if not await user.is_logged_in():
             raise web.HTTPNotFound(text='''{"error": "You're not logged in"}''',
                                    headers=self._headers)
         await user.logout()