Bläddra i källkod

Switch to signald v1 protocol for linking

Tulir Asokan 4 år sedan
förälder
incheckning
8738afa35a

+ 11 - 23
mausignald/errors.py

@@ -21,33 +21,10 @@ class UnexpectedResponse(RPCError):
         self.data = data
         self.data = data
 
 
 
 
-class LinkingError(RPCError):
-    def __init__(self, message: str, number: int) -> None:
-        super().__init__(message)
-        self.number = number
-
-
 class NotConnected(RPCError):
 class NotConnected(RPCError):
     pass
     pass
 
 
 
 
-class LinkingTimeout(LinkingError):
-    pass
-
-
-class LinkingConflict(LinkingError):
-    pass
-
-
-def make_linking_error(data: Dict[str, Any]) -> LinkingError:
-    message = data["message"]
-    msg_number = data.get("msg_number")
-    return {
-        1: LinkingTimeout,
-        3: LinkingConflict,
-    }.get(msg_number, LinkingError)(message, msg_number)
-
-
 class ResponseError(RPCError):
 class ResponseError(RPCError):
     def __init__(self, data: Dict[str, Any], message_override: Optional[str] = None) -> None:
     def __init__(self, data: Dict[str, Any], message_override: Optional[str] = None) -> None:
         self.data = data
         self.data = data
@@ -65,8 +42,19 @@ class InvalidRequest(ResponseError):
         super().__init__(data, ", ".join(data.get("validationResults", "")))
         super().__init__(data, ", ".join(data.get("validationResults", "")))
 
 
 
 
+class TimeoutException(ResponseError):
+    pass
+
+
+class UserAlreadyExistsError(ResponseError):
+    def __init__(self, data: Dict[str, Any]) -> None:
+        super().__init__(data, message_override="You're already logged in")
+
+
 response_error_types = {
 response_error_types = {
     "invalid_request": InvalidRequest,
     "invalid_request": InvalidRequest,
+    "TimeoutException": TimeoutException,
+    "UserAlreadyExists": UserAlreadyExistsError,
 }
 }
 
 
 
 

+ 1 - 2
mausignald/rpc.py

@@ -11,8 +11,7 @@ import json
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from .errors import (NotConnected, UnexpectedError, UnexpectedResponse, RPCError,
-                     make_response_error)
+from .errors import NotConnected, UnexpectedError, UnexpectedResponse, make_response_error
 
 
 EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
 EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
 
 

+ 6 - 23
mausignald/signald.py

@@ -4,16 +4,15 @@
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
 from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
-from uuid import uuid4
 import asyncio
 import asyncio
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
 from .rpc import CONNECT_EVENT, SignaldRPCClient
 from .rpc import CONNECT_EVENT, SignaldRPCClient
-from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
+from .errors import UnexpectedError, UnexpectedResponse
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
                     Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2,
-                    Mention)
+                    Mention, LinkSession)
 
 
 T = TypeVar('T')
 T = TypeVar('T')
 EventHandler = Callable[[T], Awaitable[None]]
 EventHandler = Callable[[T], Awaitable[None]]
@@ -110,29 +109,13 @@ class SignaldClient(SignaldRPCClient):
         resp = await self.request("verify", "verification_succeeded", username=username, code=code)
         resp = await self.request("verify", "verification_succeeded", username=username, code=code)
         return Account.deserialize(resp)
         return Account.deserialize(resp)
 
 
-    async def link(self, url_callback: Callable[[str], Awaitable[None]],
-                   device_name: str = "mausignald") -> Account:
-        req_id = uuid4()
-        resp_type, resp = await self._raw_request("link", req_id, deviceName=device_name)
-        if resp_type == "linking_error":
-            raise make_linking_error(resp)
-        elif resp_type != "linking_uri":
-            raise UnexpectedResponse(resp_type, resp)
-
-        self.loop.create_task(url_callback(resp["uri"]))
-
-        resp_type, resp = await self._wait_response(req_id)
-        if resp_type == "linking_error":
-            raise make_linking_error(resp)
-        elif resp_type != "linking_successful":
-            raise UnexpectedResponse(resp_type, resp)
+    async def start_link(self) -> LinkSession:
+        return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
 
 
+    async def finish_link(self, session_id: str, device_name: str = "mausignald") -> Account:
+        resp = await self.request_v1("finish_link", device_name=device_name, session_id=session_id)
         return Account.deserialize(resp)
         return Account.deserialize(resp)
 
 
-    async def list_accounts(self) -> List[Account]:
-        data = await self.request_v0("list_accounts", "account_list")
-        return [Account.deserialize(acc) for acc in data["accounts"]]
-
     @staticmethod
     @staticmethod
     def _recipient_to_args(recipient: Union[Address, GroupID]) -> Dict[str, Any]:
     def _recipient_to_args(recipient: Union[Address, GroupID]) -> Dict[str, Any]:
         if isinstance(recipient, Address):
         if isinstance(recipient, Address):

+ 13 - 11
mausignald/types.py

@@ -14,17 +14,6 @@ from mautrix.types import SerializableAttrs, SerializableEnum
 GroupID = NewType('GroupID', str)
 GroupID = NewType('GroupID', str)
 
 
 
 
-@dataclass
-class Account(SerializableAttrs['Account']):
-    device_id: int = attr.ib(metadata={"json": "deviceId"})
-    username: str
-    filename: str
-    registered: bool
-    has_keys: bool
-    subscribed: bool
-    uuid: Optional[UUID] = None
-
-
 @dataclass(frozen=True, eq=False)
 @dataclass(frozen=True, eq=False)
 class Address(SerializableAttrs['Address']):
 class Address(SerializableAttrs['Address']):
     number: Optional[str] = None
     number: Optional[str] = None
@@ -57,6 +46,19 @@ class Address(SerializableAttrs['Address']):
         return Address(number=value) if value.startswith("+") else Address(uuid=UUID(value))
         return Address(number=value) if value.startswith("+") else Address(uuid=UUID(value))
 
 
 
 
+@dataclass
+class Account(SerializableAttrs['Account']):
+    account_id: str
+    device_id: int
+    address: Address
+
+
+@dataclass
+class LinkSession(SerializableAttrs['LinkSession']):
+    uri: str
+    session_id: str
+
+
 @dataclass
 @dataclass
 class TrustLevel(SerializableEnum):
 class TrustLevel(SerializableEnum):
     TRUSTED_UNVERIFIED = "TRUSTED_UNVERIFIED"
     TRUSTED_UNVERIFIED = "TRUSTED_UNVERIFIED"

+ 18 - 8
mautrix_signal/commands/auth.py

@@ -16,7 +16,7 @@
 from typing import Union
 from typing import Union
 import io
 import io
 
 
-from mausignald.errors import UnexpectedResponse
+from mausignald.errors import UnexpectedResponse, TimeoutException
 from mautrix.client import Client
 from mautrix.client import Client
 from mautrix.bridge import custom_puppet as cpu
 from mautrix.bridge import custom_puppet as cpu
 from mautrix.appservice import IntentAPI
 from mautrix.appservice import IntentAPI
@@ -59,13 +59,23 @@ async def link(evt: CommandEvent) -> None:
     # TODO make default device name configurable
     # TODO make default device name configurable
     device_name = " ".join(evt.args) or "Mautrix-Signal bridge"
     device_name = " ".join(evt.args) or "Mautrix-Signal bridge"
 
 
-    async def callback(uri: str) -> None:
-        content = await make_qr(evt.az.intent, uri)
-        await evt.az.intent.send_message(evt.room_id, content)
-
-    account = await evt.bridge.signal.link(callback, device_name=device_name)
-    await evt.sender.on_signin(account)
-    await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
+    sess = await evt.bridge.signal.start_link()
+    content = await make_qr(evt.az.intent, sess.uri)
+    event_id = await evt.az.intent.send_message(evt.room_id, content)
+    try:
+        account = await evt.bridge.signal.finish_link(session_id=sess.session_id,
+                                                      device_name=device_name)
+    except TimeoutException:
+        await evt.reply("Linking timed out, please try again.")
+    except Exception:
+        evt.log.exception("Fatal error while waiting for linking to finish")
+        await evt.reply("Fatal error while waiting for linking to finish "
+                        "(see logs for more details)")
+    else:
+        await evt.sender.on_signin(account)
+        await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
+    finally:
+        await evt.main_intent.redact(evt.room_id, event_id)
 
 
 
 
 @command_handler(needs_auth=False, management_only=True, help_section=SECTION_AUTH,
 @command_handler(needs_auth=False, management_only=True, help_section=SECTION_AUTH,

+ 2 - 2
mautrix_signal/user.py

@@ -102,8 +102,8 @@ class User(DBUser, BaseUser):
         shutil.rmtree(extra_dir, ignore_errors=True)
         shutil.rmtree(extra_dir, ignore_errors=True)
 
 
     async def on_signin(self, account: Account) -> None:
     async def on_signin(self, account: Account) -> None:
-        self.username = account.username
-        self.uuid = account.uuid
+        self.username = account.account_id
+        self.uuid = account.address.uuid
         self._add_to_cache()
         self._add_to_cache()
         await self.update()
         await self.update()
         await self.bridge.signal.subscribe(self.username)
         await self.bridge.signal.subscribe(self.username)

+ 21 - 20
mautrix_signal/web/provisioning_api.py

@@ -15,13 +15,12 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from typing import Awaitable, Dict, TYPE_CHECKING
 from typing import Awaitable, Dict, TYPE_CHECKING
 import logging
 import logging
-import asyncio
 import json
 import json
 
 
 from aiohttp import web
 from aiohttp import web
 
 
-from mausignald.types import Address, Account
-from mausignald.errors import LinkingTimeout
+from mausignald.types import Address
+from mausignald.errors import TimeoutException
 from mautrix.types import UserID
 from mautrix.types import UserID
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
@@ -119,37 +118,39 @@ class ProvisioningAPI:
             raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
             raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
 
 
         device_name = data.get("device_name", "Mautrix-Signal bridge")
         device_name = data.get("device_name", "Mautrix-Signal bridge")
-        uri_future = asyncio.Future()
-
-        async def _callback(uri: str) -> None:
-            uri_future.set_result(uri)
-
-        async def _link() -> Account:
-            account = await self.bridge.signal.link(_callback, device_name=device_name)
-            await user.on_signin(account)
-            return account
+        sess = await self.bridge.signal.start_link()
 
 
         user.command_status = {
         user.command_status = {
             "action": "Link",
             "action": "Link",
-            "task": self.bridge.loop.create_task(_link()),
+            "session_id": sess.session_id,
+            "device_name": device_name,
         }
         }
 
 
-        return web.json_response({"uri": await uri_future}, headers=self._acao_headers)
+        return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
 
 
     async def link_wait(self, request: web.Request) -> web.Response:
     async def link_wait(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         user = await self.check_token(request)
         if not user.command_status or user.command_status["action"] != "Link":
         if not user.command_status or user.command_status["action"] != "Link":
             raise web.HTTPBadRequest(text='{"error": "No Signal linking started"}',
             raise web.HTTPBadRequest(text='{"error": "No Signal linking started"}',
                                      headers=self._headers)
                                      headers=self._headers)
+        session_id = user.command_status["session_id"]
+        device_name = user.command_status["device_name"]
         try:
         try:
-            account = await user.command_status["task"]
-        except LinkingTimeout:
+            account = await self.bridge.signal.finish_link(session_id=session_id,
+                                                           device_name=device_name)
+        except TimeoutException:
             raise web.HTTPBadRequest(text='{"error": "Signal linking timed out"}',
             raise web.HTTPBadRequest(text='{"error": "Signal linking timed out"}',
                                      headers=self._headers)
                                      headers=self._headers)
-        return web.json_response({
-            "number": account.username,
-            "uuid": str(account.uuid),
-        })
+        except Exception:
+            self.log.exception("Fatal error while waiting for linking to finish")
+            raise web.HTTPInternalServerError(text='{"error": "Fatal error in Signal linking"}',
+                                              headers=self._headers)
+        else:
+            await user.on_signin(account)
+            return web.json_response({
+                "number": account.username,
+                "uuid": str(account.uuid),
+            })
 
 
     async def logout(self, request: web.Request) -> web.Response:
     async def logout(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         user = await self.check_token(request)