Bladeren bron

Fix bugs added when dropping phone numbers

Tulir Asokan 2 jaren geleden
bovenliggende
commit
b1e0f87732

+ 7 - 5
mausignald/signald.py

@@ -180,8 +180,10 @@ class SignaldClient(SignaldRPCClient):
 
     @staticmethod
     def _recipient_to_args(
-        recipient: Address | GroupID, simple_name: bool = False
+        recipient: UUID | Address | GroupID, simple_name: bool = False
     ) -> dict[str, Any]:
+        if isinstance(recipient, UUID):
+            recipient = Address(uuid=recipient)
         if isinstance(recipient, Address):
             recipient = recipient.serialize()
             field_name = "address" if simple_name else "recipientAddress"
@@ -192,7 +194,7 @@ class SignaldClient(SignaldRPCClient):
     async def react(
         self,
         username: str,
-        recipient: Address | GroupID,
+        recipient: UUID | Address | GroupID,
         reaction: Reaction,
         req_id: UUID | None = None,
     ) -> None:
@@ -205,7 +207,7 @@ class SignaldClient(SignaldRPCClient):
         )
 
     async def remote_delete(
-        self, username: str, recipient: Address | GroupID, timestamp: int
+        self, username: str, recipient: UUID | Address | GroupID, timestamp: int
     ) -> None:
         await self.request_v1(
             "remote_delete",
@@ -217,7 +219,7 @@ class SignaldClient(SignaldRPCClient):
     async def send_raw(
         self,
         username: str,
-        recipient: Address | GroupID,
+        recipient: UUID | Address | GroupID,
         body: str,
         quote: Quote | None = None,
         attachments: list[Attachment] | None = None,
@@ -247,7 +249,7 @@ class SignaldClient(SignaldRPCClient):
     async def send(
         self,
         username: str,
-        recipient: Address | GroupID,
+        recipient: UUID | Address | GroupID,
         body: str,
         quote: Quote | None = None,
         attachments: list[Attachment] | None = None,

+ 2 - 4
mautrix_signal/commands/signal.py

@@ -105,9 +105,7 @@ async def pm(evt: CommandEvent) -> None:
     puppet = await _get_puppet_from_cmd(evt)
     if not puppet:
         return
-    portal = await po.Portal.get_by_chat_id(
-        puppet.address, receiver=evt.sender.username, create=True
-    )
+    portal = await po.Portal.get_by_chat_id(puppet.uuid, receiver=evt.sender.username, create=True)
     if portal.mxid:
         await evt.reply(
             f"You already have a private chat with {puppet.name}: "
@@ -178,7 +176,7 @@ async def safety_number(evt: CommandEvent) -> None:
             return
         evt.args = evt.args[1:]
     if len(evt.args) == 0 and evt.portal and evt.portal.is_direct:
-        puppet = await pu.Puppet.get_by_uuid(evt.portal.chat_id.uuid)
+        puppet = await evt.portal.get_dm_puppet()
     else:
         puppet = await _get_puppet_from_cmd(evt)
     if not puppet:

+ 2 - 1
mautrix_signal/matrix.py

@@ -17,6 +17,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
+from mausignald.types import Address
 from mautrix.bridge import BaseMatrixHandler, RejectMatrixInvite
 from mautrix.types import (
     Event,
@@ -205,7 +206,7 @@ class MatrixHandler(BaseMatrixHandler):
         try:
             await self.signal.send_receipt(
                 user.username,
-                message.sender,
+                Address(uuid=message.sender),
                 timestamps=[message.timestamp],
                 when=data.ts,
                 read=True,

+ 1 - 1
mautrix_signal/portal.py

@@ -2450,7 +2450,7 @@ class Portal(DBPortal, BasePortal):
     @classmethod
     @async_getter_lock
     async def get_by_chat_id(
-        cls, chat_id: GroupID | UUID, receiver: str = "", /, *, create: bool
+        cls, chat_id: GroupID | UUID, /, *, receiver: str = "", create: bool = False
     ) -> Portal | None:
         if isinstance(chat_id, str):
             receiver = ""

+ 1 - 1
mautrix_signal/puppet.py

@@ -330,7 +330,7 @@ class Puppet(DBPuppet, BasePuppet):
     async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
         portal: p.Portal = await p.Portal.get_by_mxid(room_id)
         # Leave all portals except the notes to self room
-        return not (portal and portal.is_direct and portal.chat_id.uuid == self.uuid)
+        return not (portal and portal.is_direct and portal.chat_id == self.uuid)
 
     # region Database getters
 

+ 9 - 11
mautrix_signal/signal.py

@@ -53,7 +53,7 @@ class SignalHandler(SignaldClient):
     loop: asyncio.AbstractEventLoop
     data_dir: str
     delete_unknown_accounts: bool
-    error_message_events: dict[tuple[Address, str, int], Awaitable[EventID] | None]
+    error_message_events: dict[tuple[UUID, str, int], Awaitable[EventID] | None]
 
     def __init__(self, bridge: "SignalBridge") -> None:
         super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
@@ -102,7 +102,7 @@ class SignalHandler(SignaldClient):
 
         try:
             event_id_future = self.error_message_events.pop(
-                (sender.address, user.username, evt.timestamp)
+                (sender.uuid, user.username, evt.timestamp)
             )
         except KeyError:
             pass
@@ -110,7 +110,7 @@ class SignalHandler(SignaldClient):
             self.log.debug(f"Got previously errored message {evt.timestamp} from {sender.address}")
             event_id = await event_id_future if event_id_future is not None else None
             if event_id is not None:
-                portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
+                portal = await po.Portal.get_by_chat_id(sender.uuid, receiver=user.username)
                 if portal and portal.mxid:
                     await sender.intent_for(portal).redact(portal.mxid, event_id)
 
@@ -127,14 +127,14 @@ class SignalHandler(SignaldClient):
         if not sender:
             return
         user = await u.User.get_by_username(err.account)
-        portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
+        portal = await po.Portal.get_by_chat_id(sender.uuid, receiver=user.username)
         if not portal or not portal.mxid:
             return
 
         # Add the error to the error_message_events dictionary, then wait for 10 seconds until
         # sending an error. If a success for the timestamp comes in before the 10 seconds is up,
         # don't send the error message.
-        error_message_event_key = (sender.address, user.username, err.data.timestamp)
+        error_message_event_key = (sender.uuid, user.username, err.data.timestamp)
         self.error_message_events[error_message_event_key] = None
 
         await asyncio.sleep(10)
@@ -194,7 +194,7 @@ class SignalHandler(SignaldClient):
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
         else:
             portal = await po.Portal.get_by_chat_id(
-                addr_override or sender.address, receiver=user.username, create=True
+                addr_override or sender.uuid, receiver=user.username, create=True
             )
             if addr_override and not sender.is_real_user:
                 portal.log.debug(
@@ -255,9 +255,7 @@ class SignalHandler(SignaldClient):
     @staticmethod
     async def handle_call_message(user: u.User, sender: pu.Puppet, msg: IncomingMessage) -> None:
         assert msg.call_message
-        portal = await po.Portal.get_by_chat_id(
-            sender.address, receiver=user.username, create=True
-        )
+        portal = await po.Portal.get_by_chat_id(sender.uuid, receiver=user.username, create=True)
         if not portal.mxid:
             # FIXME
             # await portal.create_matrix_room(
@@ -298,7 +296,7 @@ class SignalHandler(SignaldClient):
             puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
             if not puppet:
                 continue
-            message = await DBMessage.find_by_sender_timestamp(puppet.address, receipt.timestamp)
+            message = await DBMessage.find_by_sender_timestamp(puppet.uuid, receipt.timestamp)
             if not message:
                 continue
             portal = await po.Portal.get_by_mxid(message.mx_room)
@@ -311,7 +309,7 @@ class SignalHandler(SignaldClient):
         if typing.group_id:
             portal = await po.Portal.get_by_chat_id(typing.group_id)
         else:
-            portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
+            portal = await po.Portal.get_by_chat_id(sender.uuid, receiver=user.username)
         if not portal or not portal.mxid:
             return
         is_typing = typing.action == TypingAction.STARTED

+ 3 - 7
mautrix_signal/web/provisioning_api.py

@@ -150,7 +150,7 @@ class ProvisioningAPI:
         if await user.is_logged_in():
             try:
                 profile = await self.bridge.signal.get_profile(
-                    username=user.username, address=Address(number=user.username)
+                    username=user.username, address=user.address
                 )
             except Exception as e:
                 self.log.exception(f"Failed to get {user.username}'s profile for whoami")
@@ -416,9 +416,7 @@ class ProvisioningAPI:
         user = await self.check_token_and_logged_in(request)
         puppet = await self._resolve_identifier(request.match_info["number"], user)
 
-        portal = await po.Portal.get_by_chat_id(
-            puppet.address, receiver=user.username, create=True
-        )
+        portal = await po.Portal.get_by_chat_id(puppet.uuid, receiver=user.username, create=True)
         assert portal, "Portal.get_by_chat_id with create=True can't return None"
 
         if portal.mxid:
@@ -445,9 +443,7 @@ class ProvisioningAPI:
     async def resolve_identifier(self, request: web.Request) -> web.Response:
         user = await self.check_token_and_logged_in(request)
         puppet = await self._resolve_identifier(request.match_info["number"], user)
-        portal = await po.Portal.get_by_chat_id(
-            puppet.address, receiver=user.username, create=False
-        )
+        portal = await po.Portal.get_by_chat_id(puppet.uuid, receiver=user.username, create=False)
         return web.json_response(
             {
                 "room_id": portal.mxid if portal else None,