Переглянути джерело

Switch to v1 for most requests

Tulir Asokan 4 роки тому
батько
коміт
f56a913afe

+ 15 - 3
mausignald/errors.py

@@ -46,19 +46,31 @@ class TimeoutException(ResponseError):
     pass
 
 
+class UnknownIdentityKey(ResponseError):
+    pass
+
+
 class UserAlreadyExistsError(ResponseError):
     def __init__(self, data: Dict[str, Any]) -> None:
         super().__init__(data, message_override="You're already logged in")
 
 
+class RequestValidationFailure(ResponseError):
+    def __init__(self, data: Dict[str, Any]) -> None:
+        super().__init__(data, message_override=", ".join(data["validationResults"]))
+
+
 response_error_types = {
     "invalid_request": InvalidRequest,
     "TimeoutException": TimeoutException,
     "UserAlreadyExists": UserAlreadyExistsError,
+    "RequestValidationFailure": RequestValidationFailure,
+    "UnknownIdentityKey": UnknownIdentityKey,
 }
 
 
 def make_response_error(data: Dict[str, Any]) -> ResponseError:
-    if isinstance(data, str):
-        return UnknownResponseError(data)
-    return response_error_types.get(data["type"], ResponseError)(data)
+    error_data = data["error"]
+    if isinstance(error_data, str):
+        error_data = {"message": error_data}
+    return response_error_types.get(data["error_type"], ResponseError)(error_data)

+ 3 - 10
mausignald/rpc.py

@@ -113,10 +113,10 @@ class SignaldRPCClient:
                 waiter.set_exception(UnexpectedError(data["message"]))
             except KeyError:
                 waiter.set_exception(UnexpectedError("Unexpected error with no message"))
-        elif data and "error" in data and isinstance(data["error"], (str, dict)):
-            waiter.set_exception(make_response_error(data["error"]))
+        # elif data and "error" in data and isinstance(data["error"], (str, dict)):
+        #     waiter.set_exception(make_response_error(data))
         elif "error" in req and isinstance(req["error"], (str, dict)):
-            waiter.set_exception(make_response_error(req["error"]))
+            waiter.set_exception(make_response_error(req))
         else:
             waiter.set_result((command, data))
 
@@ -206,12 +206,5 @@ class SignaldRPCClient:
             raise UnexpectedResponse(resp_type, resp_data)
         return resp_data
 
-    async def request_v0(self, command: str, expected_response: str, **data: Any) -> Any:
-        return await self.request(command, expected_response, version="v0", **data)
-
     async def request_v1(self, command: str, **data: Any) -> Any:
         return await self.request(command, expected_response=command, version="v1", **data)
-
-    async def request_nowait(self, command: str, **data: Any) -> None:
-        _, req = self._create_request(command, **data)
-        await self._send_request(req)

+ 27 - 18
mausignald/signald.py

@@ -101,12 +101,11 @@ class SignaldClient(SignaldRPCClient):
 
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
                        ) -> str:
-        resp = await self.request("register", "verification_required", username=phone,
-                                  voice=voice, captcha=captcha)
-        return resp["username"]
+        resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
+        return resp["account_id"]
 
     async def verify(self, username: str, code: str) -> Account:
-        resp = await self.request("verify", "verification_succeeded", username=username, code=code)
+        resp = await self.request_v1("verify", account=username, code=code)
         return Account.deserialize(resp)
 
     async def start_link(self) -> LinkSession:
@@ -143,13 +142,15 @@ class SignaldClient(SignaldRPCClient):
 
     async def send_receipt(self, username: str, sender: Address, timestamps: List[int],
                            when: Optional[int] = None, read: bool = False) -> None:
-        await self.request_nowait("mark_read" if read else "mark_delivered", username=username,
-                                  timestamps=timestamps, when=when,
-                                  recipientAddress=sender.serialize())
+        if not read:
+            # TODO implement
+            return
+        await self.request_v1("mark_read", account=username, timestamps=timestamps, when=when,
+                              to=sender.serialize())
 
-    async def list_contacts(self, username: str) -> List[Contact]:
-        contacts = await self.request_v0("list_contacts", "contact_list", username=username)
-        return [Contact.deserialize(contact) for contact in contacts]
+    async def list_contacts(self, username: str) -> List[Profile]:
+        resp = await self.request_v1("list_contacts", account=username)
+        return [Profile.deserialize(contact) for contact in resp["profiles"]]
 
     async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
         resp = await self.request_v1("list_groups", account=username)
@@ -201,8 +202,8 @@ class SignaldClient(SignaldRPCClient):
         return Profile.deserialize(resp)
 
     async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
-        resp = await self.request_v0("get_identities", "identities", username=username,
-                                     recipientAddress=address.serialize())
+        resp = await self.request_v1("get_identities", account=username,
+                                     address=address.serialize())
         return GetIdentitiesResponse.deserialize(resp)
 
     async def set_profile(self, username: str, name: Optional[str] = None,
@@ -214,9 +215,17 @@ class SignaldClient(SignaldRPCClient):
             args["avatarFile"] = avatar_path
         await self.request_v1("set_profile", account=username, **args)
 
-    async def trust(self, username: str, recipient: Address, fingerprint: str, trust_level: str
-                    ) -> str:
-        resp = await self.request_v0("trust", "trusted_safety_number", username=username,
-                                     fingerprint=fingerprint, trust_level=trust_level,
-                                     recipientAddress=recipient.serialize())
-        return resp["message"]
+    async def trust(self, username: str, recipient: Address, trust_level: str,
+                    safety_number: Optional[str] = None, qr_code_data: Optional[str] = None
+                    ) -> None:
+        args = {}
+        if safety_number:
+            if qr_code_data:
+                raise ValueError("only one of safety_number and qr_code_data must be set")
+            args["safety_number"] = safety_number
+        elif qr_code_data:
+            args["qr_code_data"] = qr_code_data
+        else:
+            raise ValueError("safety_number or qr_code_data is required")
+        await self.request_v1("trust", account=username, **args, trust_level=trust_level,
+                              address=recipient.serialize())

+ 1 - 2
mausignald/types.py

@@ -70,14 +70,13 @@ class TrustLevel(SerializableEnum):
 class Identity(SerializableAttrs['Identity']):
     trust_level: TrustLevel
     added: int
-    fingerprint: str
     safety_number: str
     qr_code_data: str
-    address: Address
 
 
 @dataclass
 class GetIdentitiesResponse(SerializableAttrs['GetIdentitiesResponse']):
+    address: Address
     identities: List[Identity]
 
 

+ 14 - 7
mautrix_signal/commands/signal.py

@@ -18,7 +18,9 @@ import base64
 import json
 
 from mautrix.bridge.commands import HelpSection, command_handler, SECTION_ADMIN
+from mautrix.types import EventID
 from mausignald.types import Address
+from mausignald.errors import UnknownIdentityKey
 
 from .. import puppet as pu, portal as po
 from .auth import make_qr, remove_extra_chars
@@ -101,7 +103,7 @@ async def safety_number(evt: CommandEvent) -> None:
     for identity in resp.identities:
         if identity.added > most_recent.added:
             most_recent = identity
-    uuid = most_recent.address.uuid or "unknown"
+    uuid = resp.address.uuid or "unknown"
     await evt.reply(f"### {puppet.name}\n\n"
                     f"**UUID:** {uuid}  \n"
                     f"**Trust level:** {most_recent.trust_level}  \n"
@@ -123,15 +125,20 @@ async def set_profile_name(evt: CommandEvent) -> None:
 @command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
                  help_text="Mark another user's safety number as trusted",
                  help_args="<_recipient phone_> <_safety number_>")
-async def mark_trusted(evt: CommandEvent) -> None:
+async def mark_trusted(evt: CommandEvent) -> EventID:
+    if len(evt.args) < 2:
+        return await evt.reply("**Usage:** `$cmdprefix+sp mark-trusted <recipient phone> "
+                               "<safety number>`")
     number = evt.args[0].translate(remove_extra_chars)
     safety_num = "".join(evt.args[1:]).replace("\n", "")
     if len(safety_num) != 60 or not safety_num.isdecimal():
-        await evt.reply("That doesn't look like a valid safety number")
-        return
-    msg = await evt.bridge.signal.trust(evt.sender.username, Address(number=number),
-                                        fingerprint=safety_num, trust_level="TRUSTED_VERIFIED")
-    await evt.reply(msg)
+        return await evt.reply("That doesn't look like a valid safety number")
+    try:
+        await evt.bridge.signal.trust(evt.sender.username, Address(number=number),
+                                      safety_number=safety_num, trust_level="TRUSTED_VERIFIED")
+    except UnknownIdentityKey as e:
+        return await evt.reply(f"Failed to mark {number} as trusted: {e}")
+    return await evt.reply(f"Successfully marked {number} as trusted")
 
 
 @command_handler(needs_admin=False, needs_auth=True, help_section=SECTION_SIGNAL,

+ 12 - 10
mautrix_signal/user.py

@@ -20,7 +20,7 @@ import asyncio
 import os.path
 import shutil
 
-from mausignald.types import Account, Address, Contact, Group, GroupV2, ListenEvent, ListenAction
+from mausignald.types import Account, Address, Profile, Group, GroupV2, ListenEvent, ListenAction
 from mautrix.bridge import BaseUser, async_getter_lock
 from mautrix.types import UserID, RoomID
 from mautrix.appservice import AppService
@@ -145,21 +145,23 @@ class User(DBUser, BaseUser):
         except Exception:
             self.log.exception("Error while syncing groups")
 
-    async def sync_contact(self, contact: Union[Contact, Address], create_portals: bool = False
+    async def sync_contact(self, contact: Union[Profile, Address], create_portals: bool = False
                            ) -> None:
         self.log.trace("Syncing contact %s", contact)
-        address = contact.address if isinstance(contact, Contact) else contact
-        puppet = await pu.Puppet.get_by_address(address)
-        profile = await self.bridge.signal.get_profile(self.username, address)
-        if profile and profile.name:
-            self.log.trace("Got profile for %s: %s", address, profile)
+        if isinstance(contact, Address):
+            address = contact
+            profile = await self.bridge.signal.get_profile(self.username, address)
+            if profile and profile.name:
+                self.log.trace("Got profile for %s: %s", address, profile)
         else:
-            profile = None
-        await puppet.update_info(profile or contact)
+            address = contact.address
+            profile = contact
+        puppet = await pu.Puppet.get_by_address(address)
+        await puppet.update_info(profile)
         if create_portals:
             portal = await po.Portal.get_by_chat_id(puppet.address, receiver=self.username,
                                                     create=True)
-            await portal.create_matrix_room(self, profile or contact)
+            await portal.create_matrix_room(self, profile)
 
     async def _sync_group(self, group: Group, create_portals: bool) -> None:
         self.log.trace("Syncing group %s", group)