Преглед изворни кода

Switch to signald v1 protocol for linking

Tulir Asokan пре 4 година
родитељ
комит
8738afa35a

+ 11 - 23
mausignald/errors.py

@@ -21,33 +21,10 @@ class UnexpectedResponse(RPCError):
         self.data = data
 
 
-class LinkingError(RPCError):
-    def __init__(self, message: str, number: int) -> None:
-        super().__init__(message)
-        self.number = number
-
-
 class NotConnected(RPCError):
     pass
 
 
-class LinkingTimeout(LinkingError):
-    pass
-
-
-class LinkingConflict(LinkingError):
-    pass
-
-
-def make_linking_error(data: Dict[str, Any]) -> LinkingError:
-    message = data["message"]
-    msg_number = data.get("msg_number")
-    return {
-        1: LinkingTimeout,
-        3: LinkingConflict,
-    }.get(msg_number, LinkingError)(message, msg_number)
-
-
 class ResponseError(RPCError):
     def __init__(self, data: Dict[str, Any], message_override: Optional[str] = None) -> None:
         self.data = data
@@ -65,8 +42,19 @@ class InvalidRequest(ResponseError):
         super().__init__(data, ", ".join(data.get("validationResults", "")))
 
 
+class TimeoutException(ResponseError):
+    pass
+
+
+class UserAlreadyExistsError(ResponseError):
+    def __init__(self, data: Dict[str, Any]) -> None:
+        super().__init__(data, message_override="You're already logged in")
+
+
 response_error_types = {
     "invalid_request": InvalidRequest,
+    "TimeoutException": TimeoutException,
+    "UserAlreadyExists": UserAlreadyExistsError,
 }
 
 

+ 1 - 2
mausignald/rpc.py

@@ -11,8 +11,7 @@ import json
 
 from mautrix.util.logging import TraceLogger
 
-from .errors import (NotConnected, UnexpectedError, UnexpectedResponse, RPCError,
-                     make_response_error)
+from .errors import NotConnected, UnexpectedError, UnexpectedResponse, make_response_error
 
 EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
 

+ 6 - 23
mausignald/signald.py

@@ -4,16 +4,15 @@
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
-from uuid import uuid4
 import asyncio
 
 from mautrix.util.logging import TraceLogger
 
 from .rpc import CONNECT_EVENT, SignaldRPCClient
-from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
+from .errors import UnexpectedError, UnexpectedResponse
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
-                    Mention)
+                    Mention, LinkSession)
 
 T = TypeVar('T')
 EventHandler = Callable[[T], Awaitable[None]]
@@ -110,29 +109,13 @@ class SignaldClient(SignaldRPCClient):
         resp = await self.request("verify", "verification_succeeded", username=username, code=code)
         return Account.deserialize(resp)
 
-    async def link(self, url_callback: Callable[[str], Awaitable[None]],
-                   device_name: str = "mausignald") -> Account:
-        req_id = uuid4()
-        resp_type, resp = await self._raw_request("link", req_id, deviceName=device_name)
-        if resp_type == "linking_error":
-            raise make_linking_error(resp)
-        elif resp_type != "linking_uri":
-            raise UnexpectedResponse(resp_type, resp)
-
-        self.loop.create_task(url_callback(resp["uri"]))
-
-        resp_type, resp = await self._wait_response(req_id)
-        if resp_type == "linking_error":
-            raise make_linking_error(resp)
-        elif resp_type != "linking_successful":
-            raise UnexpectedResponse(resp_type, resp)
+    async def start_link(self) -> LinkSession:
+        return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
 
+    async def finish_link(self, session_id: str, device_name: str = "mausignald") -> Account:
+        resp = await self.request_v1("finish_link", device_name=device_name, session_id=session_id)
         return Account.deserialize(resp)
 
-    async def list_accounts(self) -> List[Account]:
-        data = await self.request_v0("list_accounts", "account_list")
-        return [Account.deserialize(acc) for acc in data["accounts"]]
-
     @staticmethod
     def _recipient_to_args(recipient: Union[Address, GroupID]) -> Dict[str, Any]:
         if isinstance(recipient, Address):

+ 13 - 11
mausignald/types.py

@@ -14,17 +14,6 @@ from mautrix.types import SerializableAttrs, SerializableEnum
 GroupID = NewType('GroupID', str)
 
 
-@dataclass
-class Account(SerializableAttrs['Account']):
-    device_id: int = attr.ib(metadata={"json": "deviceId"})
-    username: str
-    filename: str
-    registered: bool
-    has_keys: bool
-    subscribed: bool
-    uuid: Optional[UUID] = None
-
-
 @dataclass(frozen=True, eq=False)
 class Address(SerializableAttrs['Address']):
     number: Optional[str] = None
@@ -57,6 +46,19 @@ class Address(SerializableAttrs['Address']):
         return Address(number=value) if value.startswith("+") else Address(uuid=UUID(value))
 
 
+@dataclass
+class Account(SerializableAttrs['Account']):
+    account_id: str
+    device_id: int
+    address: Address
+
+
+@dataclass
+class LinkSession(SerializableAttrs['LinkSession']):
+    uri: str
+    session_id: str
+
+
 @dataclass
 class TrustLevel(SerializableEnum):
     TRUSTED_UNVERIFIED = "TRUSTED_UNVERIFIED"

+ 18 - 8
mautrix_signal/commands/auth.py

@@ -16,7 +16,7 @@
 from typing import Union
 import io
 
-from mausignald.errors import UnexpectedResponse
+from mausignald.errors import UnexpectedResponse, TimeoutException
 from mautrix.client import Client
 from mautrix.bridge import custom_puppet as cpu
 from mautrix.appservice import IntentAPI
@@ -59,13 +59,23 @@ async def link(evt: CommandEvent) -> None:
     # TODO make default device name configurable
     device_name = " ".join(evt.args) or "Mautrix-Signal bridge"
 
-    async def callback(uri: str) -> None:
-        content = await make_qr(evt.az.intent, uri)
-        await evt.az.intent.send_message(evt.room_id, content)
-
-    account = await evt.bridge.signal.link(callback, device_name=device_name)
-    await evt.sender.on_signin(account)
-    await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
+    sess = await evt.bridge.signal.start_link()
+    content = await make_qr(evt.az.intent, sess.uri)
+    event_id = await evt.az.intent.send_message(evt.room_id, content)
+    try:
+        account = await evt.bridge.signal.finish_link(session_id=sess.session_id,
+                                                      device_name=device_name)
+    except TimeoutException:
+        await evt.reply("Linking timed out, please try again.")
+    except Exception:
+        evt.log.exception("Fatal error while waiting for linking to finish")
+        await evt.reply("Fatal error while waiting for linking to finish "
+                        "(see logs for more details)")
+    else:
+        await evt.sender.on_signin(account)
+        await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
+    finally:
+        await evt.main_intent.redact(evt.room_id, event_id)
 
 
 @command_handler(needs_auth=False, management_only=True, help_section=SECTION_AUTH,

+ 2 - 2
mautrix_signal/user.py

@@ -102,8 +102,8 @@ class User(DBUser, BaseUser):
         shutil.rmtree(extra_dir, ignore_errors=True)
 
     async def on_signin(self, account: Account) -> None:
-        self.username = account.username
-        self.uuid = account.uuid
+        self.username = account.account_id
+        self.uuid = account.address.uuid
         self._add_to_cache()
         await self.update()
         await self.bridge.signal.subscribe(self.username)

+ 21 - 20
mautrix_signal/web/provisioning_api.py

@@ -15,13 +15,12 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from typing import Awaitable, Dict, TYPE_CHECKING
 import logging
-import asyncio
 import json
 
 from aiohttp import web
 
-from mausignald.types import Address, Account
-from mausignald.errors import LinkingTimeout
+from mausignald.types import Address
+from mausignald.errors import TimeoutException
 from mautrix.types import UserID
 from mautrix.util.logging import TraceLogger
 
@@ -119,37 +118,39 @@ class ProvisioningAPI:
             raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
 
         device_name = data.get("device_name", "Mautrix-Signal bridge")
-        uri_future = asyncio.Future()
-
-        async def _callback(uri: str) -> None:
-            uri_future.set_result(uri)
-
-        async def _link() -> Account:
-            account = await self.bridge.signal.link(_callback, device_name=device_name)
-            await user.on_signin(account)
-            return account
+        sess = await self.bridge.signal.start_link()
 
         user.command_status = {
             "action": "Link",
-            "task": self.bridge.loop.create_task(_link()),
+            "session_id": sess.session_id,
+            "device_name": device_name,
         }
 
-        return web.json_response({"uri": await uri_future}, headers=self._acao_headers)
+        return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
 
     async def link_wait(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         if not user.command_status or user.command_status["action"] != "Link":
             raise web.HTTPBadRequest(text='{"error": "No Signal linking started"}',
                                      headers=self._headers)
+        session_id = user.command_status["session_id"]
+        device_name = user.command_status["device_name"]
         try:
-            account = await user.command_status["task"]
-        except LinkingTimeout:
+            account = await self.bridge.signal.finish_link(session_id=session_id,
+                                                           device_name=device_name)
+        except TimeoutException:
             raise web.HTTPBadRequest(text='{"error": "Signal linking timed out"}',
                                      headers=self._headers)
-        return web.json_response({
-            "number": account.username,
-            "uuid": str(account.uuid),
-        })
+        except Exception:
+            self.log.exception("Fatal error while waiting for linking to finish")
+            raise web.HTTPInternalServerError(text='{"error": "Fatal error in Signal linking"}',
+                                              headers=self._headers)
+        else:
+            await user.on_signin(account)
+            return web.json_response({
+                "number": account.username,
+                "uuid": str(account.uuid),
+            })
 
     async def logout(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)