Parcourir la source

Drop support for phone number IDs and v1 groups

Tulir Asokan il y a 2 ans
Parent
commit
2ec415648c

+ 11 - 21
mausignald/signald.py

@@ -20,9 +20,9 @@ from .types import (
     DeviceInfo,
     ErrorMessage,
     GetIdentitiesResponse,
-    Group,
     GroupAccessControl,
     GroupID,
+    GroupMember,
     GroupV2,
     IncomingMessage,
     JoinGroupResponse,
@@ -342,11 +342,9 @@ class SignaldClient(SignaldRPCClient):
         resp = await self.request_v1("list_contacts", account=username, **kwargs)
         return [Profile.deserialize(contact) for contact in resp["profiles"]]
 
-    async def list_groups(self, username: str) -> list[Group | GroupV2]:
+    async def list_groups(self, username: str) -> list[GroupV2]:
         resp = await self.request_v1("list_groups", account=username)
-        legacy = [Group.deserialize(group) for group in resp.get("legacyGroups", [])]
-        v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
-        return legacy + v2
+        return [GroupV2.deserialize(group) for group in resp.get("groups", [])]
 
     async def join_group(self, username: str, uri: str) -> JoinGroupResponse:
         resp = await self.request_v1("join_group", account=username, uri=uri)
@@ -355,27 +353,19 @@ class SignaldClient(SignaldRPCClient):
     async def leave_group(self, username: str, group_id: GroupID) -> None:
         await self.request_v1("leave_group", account=username, groupID=group_id)
 
-    async def ban_user(
-        self, username: str, group_id: GroupID, users: list[Address]
-    ) -> Group | GroupV2:
+    async def ban_user(self, username: str, group_id: GroupID, users: list[Address]) -> GroupV2:
         serialized_users = [user.serialize() for user in (users or [])]
         resp = await self.request_v1(
             "ban_user", account=username, group_id=group_id, users=serialized_users
         )
-        legacy = [Group.deserialize(group) for group in resp.get("legacyGroups", [])]
-        v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
-        return legacy + v2
+        return GroupV2.deserialize(resp)
 
-    async def unban_user(
-        self, username: str, group_id: GroupID, users: list[Address]
-    ) -> Group | GroupV2:
+    async def unban_user(self, username: str, group_id: GroupID, users: list[Address]) -> GroupV2:
         serialized_users = [user.serialize() for user in (users or [])]
         resp = await self.request_v1(
             "unban_user", account=username, group_id=group_id, users=serialized_users
         )
-        legacy = [Group.deserialize(group) for group in resp.get("legacyGroups", [])]
-        v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
-        return legacy + v2
+        return GroupV2.deserialize(resp)
 
     async def update_group(
         self,
@@ -388,7 +378,7 @@ class SignaldClient(SignaldRPCClient):
         remove_members: list[Address] | None = None,
         update_access_control: GroupAccessControl | None = None,
         update_role: GroupMember | None = None,
-    ) -> Group | GroupV2 | None:
+    ) -> GroupV2 | None:
         update_params = {
             key: value
             for key, value in {
@@ -408,10 +398,10 @@ class SignaldClient(SignaldRPCClient):
             if value is not None
         }
         resp = await self.request_v1("update_group", account=username, **update_params)
-        if "v1" in resp:
-            return Group.deserialize(resp["v1"])
-        elif "v2" in resp:
+        if "v2" in resp:
             return GroupV2.deserialize(resp["v2"])
+        elif "v1" in resp:
+            raise RuntimeError("v1 groups are no longer supported")
         else:
             return None
 

+ 1 - 1
mausignald/types.py

@@ -16,8 +16,8 @@ GroupID = NewType("GroupID", str)
 
 @dataclass(frozen=True, eq=False)
 class Address(SerializableAttrs):
-    number: Optional[str] = None
     uuid: Optional[UUID] = None
+    number: Optional[str] = None
 
     @property
     def is_valid(self) -> bool:

+ 10 - 4
mautrix_signal/commands/signal.py

@@ -57,8 +57,11 @@ async def _get_puppet_from_cmd(evt: CommandEvent) -> pu.Puppet | None:
         )
         return None
 
-    puppet: pu.Puppet = await pu.Puppet.get_by_address(Address(number=phone))
-    if not puppet.uuid and evt.sender.username:
+    puppet: pu.Puppet = await pu.Puppet.get_by_number(phone)
+    if not puppet:
+        if not evt.sender.username:
+            await evt.reply("UUID of user not known")
+            return None
         try:
             uuid = await evt.bridge.signal.find_uuid(evt.sender.username, puppet.number)
         except UnregisteredUserError:
@@ -66,7 +69,10 @@ async def _get_puppet_from_cmd(evt: CommandEvent) -> pu.Puppet | None:
             return None
 
         if uuid:
-            await puppet.handle_uuid_receive(uuid)
+            puppet = await pu.Puppet.get_by_uuid(uuid)
+        else:
+            await evt.reply("UUID of user not found")
+            return None
     return puppet
 
 
@@ -172,7 +178,7 @@ async def safety_number(evt: CommandEvent) -> None:
             return
         evt.args = evt.args[1:]
     if len(evt.args) == 0 and evt.portal and evt.portal.is_direct:
-        puppet = await pu.Puppet.get_by_address(evt.portal.chat_id)
+        puppet = await pu.Puppet.get_by_uuid(evt.portal.chat_id.uuid)
     else:
         puppet = await _get_puppet_from_cmd(evt)
     if not puppet:

+ 1 - 9
mautrix_signal/db/__init__.py

@@ -1,6 +1,3 @@
-import sqlite3
-import uuid
-
 from mautrix.util.async_db import Database
 
 from .disappearing_message import DisappearingMessage
@@ -10,6 +7,7 @@ from .puppet import Puppet
 from .reaction import Reaction
 from .upgrade import upgrade_table
 from .user import User
+from .util import ensure_uuid
 
 
 def init(db: Database) -> None:
@@ -17,12 +15,6 @@ def init(db: Database) -> None:
         table.db = db
 
 
-# TODO should this be in mautrix-python?
-sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
-sqlite3.register_converter(
-    "UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b)
-)
-
 __all__ = [
     "upgrade_table",
     "init",

+ 22 - 31
mautrix_signal/db/message.py

@@ -16,15 +16,16 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, ClassVar
+from uuid import UUID
 
 from attr import dataclass
 import asyncpg
 
-from mausignald.types import Address, GroupID
+from mausignald.types import GroupID
 from mautrix.types import EventID, RoomID
 from mautrix.util.async_db import Database, Scheme
 
-from ..util import id_to_str
+from .util import ensure_uuid
 
 fake_db = Database.create("") if TYPE_CHECKING else None
 
@@ -35,9 +36,9 @@ class Message:
 
     mxid: EventID
     mx_room: RoomID
-    sender: Address
+    sender: UUID
     timestamp: int
-    signal_chat_id: GroupID | Address
+    signal_chat_id: GroupID | UUID
     signal_receiver: str
 
     async def insert(self) -> None:
@@ -49,9 +50,9 @@ class Message:
             q,
             self.mxid,
             self.mx_room,
-            self.sender.best_identifier,
+            self.sender,
             self.timestamp,
-            id_to_str(self.signal_chat_id),
+            self.signal_chat_id,
             self.signal_receiver,
         )
 
@@ -62,9 +63,9 @@ class Message:
         """
         await self.db.execute(
             q,
-            self.sender.best_identifier,
+            self.sender,
             self.timestamp,
-            id_to_str(self.signal_chat_id),
+            self.signal_chat_id,
             self.signal_receiver,
         )
 
@@ -73,12 +74,14 @@ class Message:
         await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> Message:
+    def _from_row(cls, row: asyncpg.Record | None) -> Message | None:
+        if row is None:
+            return None
         data = {**row}
         chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
-            chat_id = Address.parse(chat_id)
-        sender = Address.parse(data.pop("sender"))
+            chat_id = ensure_uuid(chat_id)
+        sender = ensure_uuid(data.pop("sender"))
         return cls(signal_chat_id=chat_id, sender=sender, **data)
 
     @classmethod
@@ -87,29 +90,23 @@ class Message:
         SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
         WHERE mxid=$1 AND mx_room=$2
         """
-        row = await cls.db.fetchrow(q, mxid, mx_room)
-        if not row:
-            return None
-        return cls._from_row(row)
+        return cls._from_row(await cls.db.fetchrow(q, mxid, mx_room))
 
     @classmethod
     async def get_by_signal_id(
         cls,
-        sender: Address,
+        sender: UUID,
         timestamp: int,
-        signal_chat_id: GroupID | Address,
+        signal_chat_id: GroupID | UUID,
         signal_receiver: str = "",
     ) -> Message | None:
         q = """
         SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
         WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
         """
-        row = await cls.db.fetchrow(
-            q, sender.best_identifier, timestamp, id_to_str(signal_chat_id), signal_receiver
+        return cls._from_row(
+            await cls.db.fetchrow(q, sender, timestamp, signal_chat_id, signal_receiver)
         )
-        if not row:
-            return None
-        return cls._from_row(row)
 
     @classmethod
     async def find_by_timestamps(cls, timestamps: list[int]) -> list[Message]:
@@ -129,15 +126,12 @@ class Message:
         return [cls._from_row(row) for row in rows]
 
     @classmethod
-    async def find_by_sender_timestamp(cls, sender: Address, timestamp: int) -> Message | None:
+    async def find_by_sender_timestamp(cls, sender: UUID, timestamp: int) -> Message | None:
         q = """
         SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
         WHERE sender=$1 AND timestamp=$2
         """
-        row = await cls.db.fetchrow(q, sender.best_identifier, timestamp)
-        if not row:
-            return None
-        return cls._from_row(row)
+        return cls._from_row(await cls.db.fetchrow(q, sender, timestamp))
 
     @classmethod
     async def get_first_before(cls, mx_room: RoomID, timestamp: int) -> Message | None:
@@ -147,7 +141,4 @@ class Message:
         ORDER BY timestamp DESC
         LIMIT 1
         """
-        row = await cls.db.fetchrow(q, mx_room, timestamp)
-        if not row:
-            return None
-        return cls._from_row(row)
+        return cls._from_row(await cls.db.fetchrow(q, mx_room, timestamp))

+ 15 - 18
mautrix_signal/db/portal.py

@@ -16,15 +16,16 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, ClassVar
+from uuid import UUID
 
 from attr import dataclass
 import asyncpg
 
-from mausignald.types import Address, GroupID
+from mausignald.types import GroupID
 from mautrix.types import ContentURI, RoomID, UserID
 from mautrix.util.async_db import Database
 
-from ..util import id_to_str
+from .util import ensure_uuid
 
 fake_db = Database.create("") if TYPE_CHECKING else None
 
@@ -33,7 +34,7 @@ fake_db = Database.create("") if TYPE_CHECKING else None
 class Portal:
     db: ClassVar[Database] = fake_db
 
-    chat_id: GroupID | Address
+    chat_id: GroupID | UUID
     receiver: str
     mxid: RoomID | None
     name: str | None
@@ -49,12 +50,12 @@ class Portal:
 
     @property
     def chat_id_str(self) -> str:
-        return id_to_str(self.chat_id)
+        return str(self.chat_id)
 
     @property
     def _values(self):
         return (
-            self.chat_id_str,
+            self.chat_id,
             self.receiver,
             self.mxid,
             self.name,
@@ -88,11 +89,13 @@ class Portal:
         await self.db.execute(q, *self._values)
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> Portal:
+    def _from_row(cls, row: asyncpg.Record | None) -> Portal | None:
+        if row is None:
+            return None
         data = {**row}
         chat_id = data.pop("chat_id")
         if data["receiver"]:
-            chat_id = Address.parse(chat_id)
+            chat_id = ensure_uuid(chat_id)
         return cls(chat_id=chat_id, **data)
 
     _columns = (
@@ -103,18 +106,12 @@ class Portal:
     @classmethod
     async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
         q = f"SELECT {cls._columns} FROM portal WHERE mxid=$1"
-        row = await cls.db.fetchrow(q, mxid)
-        if not row:
-            return None
-        return cls._from_row(row)
+        return cls._from_row(await cls.db.fetchrow(q, mxid))
 
     @classmethod
-    async def get_by_chat_id(cls, chat_id: GroupID | Address, receiver: str = "") -> Portal | None:
+    async def get_by_chat_id(cls, chat_id: GroupID | UUID, receiver: str = "") -> Portal | None:
         q = f"SELECT {cls._columns} FROM portal WHERE chat_id=$1 AND receiver=$2"
-        row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
-        if not row:
-            return None
-        return cls._from_row(row)
+        return cls._from_row(await cls.db.fetchrow(q, chat_id, receiver))
 
     @classmethod
     async def find_private_chats_of(cls, receiver: str) -> list[Portal]:
@@ -123,9 +120,9 @@ class Portal:
         return [cls._from_row(row) for row in rows]
 
     @classmethod
-    async def find_private_chats_with(cls, other_user: Address) -> list[Portal]:
+    async def find_private_chats_with(cls, other_user: UUID) -> list[Portal]:
         q = f"SELECT {cls._columns} FROM portal WHERE chat_id=$1 AND receiver<>''"
-        rows = await cls.db.fetch(q, other_user.best_identifier)
+        rows = await cls.db.fetch(q, other_user)
         return [cls._from_row(row) for row in rows]
 
     @classmethod

+ 41 - 89
mautrix_signal/db/puppet.py

@@ -22,9 +22,8 @@ from attr import dataclass
 from yarl import URL
 import asyncpg
 
-from mausignald.types import Address
 from mautrix.types import ContentURI, SyncToken, UserID
-from mautrix.util.async_db import Connection, Database
+from mautrix.util.async_db import Database
 
 fake_db = Database.create("") if TYPE_CHECKING else None
 
@@ -33,7 +32,7 @@ fake_db = Database.create("") if TYPE_CHECKING else None
 class Puppet:
     db: ClassVar[Database] = fake_db
 
-    uuid: UUID | None
+    uuid: UUID
     number: str | None
     name: str | None
     name_quality: int
@@ -41,9 +40,7 @@ class Puppet:
     avatar_url: ContentURI | None
     name_set: bool
     avatar_set: bool
-
-    uuid_registered: bool
-    number_registered: bool
+    is_registered: bool
 
     custom_mxid: UserID | None
     access_token: str | None
@@ -54,15 +51,9 @@ class Puppet:
     def _base_url_str(self) -> str | None:
         return str(self.base_url) if self.base_url else None
 
-    async def insert(self) -> None:
-        q = """
-        INSERT INTO puppet (uuid, number, name, name_quality, avatar_hash, avatar_url, name_set,
-                            avatar_set, uuid_registered, number_registered,
-                            custom_mxid, access_token, next_batch, base_url)
-        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
-        """
-        await self.db.execute(
-            q,
+    @property
+    def _values(self):
+        return (
             self.uuid,
             self.number,
             self.name,
@@ -71,73 +62,43 @@ class Puppet:
             self.avatar_url,
             self.name_set,
             self.avatar_set,
-            self.uuid_registered,
-            self.number_registered,
+            self.is_registered,
             self.custom_mxid,
             self.access_token,
             self.next_batch,
             self._base_url_str,
         )
 
-    async def _set_uuid(self, uuid: UUID) -> None:
-        async with self.db.acquire() as conn, conn.transaction():
-            await conn.execute(
-                "DELETE FROM puppet WHERE uuid=$1 AND number<>$2", uuid, self.number
-            )
-            await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
-            await self._update_number_to_uuid(conn, self.number, str(uuid))
+    async def insert(self) -> None:
+        q = """
+        INSERT INTO puppet (uuid, number, name, name_quality, avatar_hash, avatar_url,
+                            name_set, avatar_set, is_registered,
+                            custom_mxid, access_token, next_batch, base_url)
+        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
+        """
+        await self.db.execute(q, *self._values)
 
     async def _set_number(self, number: str) -> None:
         async with self.db.acquire() as conn, conn.transaction():
             await conn.execute(
-                "DELETE FROM puppet WHERE number=$1 AND uuid<>$2", number, self.uuid
+                "UPDATE puppet SET number=null WHERE number=$1 AND uuid<>$2", number, self.uuid
             )
             await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", number, self.uuid)
-            await self._update_number_to_uuid(conn, number, str(self.uuid))
-
-    @staticmethod
-    async def _update_number_to_uuid(conn: Connection, old_number: str, new_uuid: str) -> None:
-        try:
-            async with conn.transaction():
-                await conn.execute(
-                    "UPDATE portal SET chat_id=$1 WHERE chat_id=$2", new_uuid, old_number
-                )
-        except asyncpg.UniqueViolationError:
-            await conn.execute("DELETE FROM portal WHERE chat_id=$1", old_number)
-        await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2", new_uuid, old_number)
-        await conn.execute("UPDATE reaction SET author=$1 WHERE author=$2", new_uuid, old_number)
 
     async def update(self) -> None:
-        set_columns = (
-            "name=$3, name_quality=$4, avatar_hash=$5, avatar_url=$6, name_set=$7, avatar_set=$8, "
-            "uuid_registered=$9, number_registered=$10, "
-            "custom_mxid=$11, access_token=$12, next_batch=$13, base_url=$14"
-        )
-        q = (
-            f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
-            if self.uuid is None
-            else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1"
-        )
-        await self.db.execute(
-            q,
-            self.uuid,
-            self.number,
-            self.name,
-            self.name_quality,
-            self.avatar_hash,
-            self.avatar_url,
-            self.name_set,
-            self.avatar_set,
-            self.uuid_registered,
-            self.number_registered,
-            self.custom_mxid,
-            self.access_token,
-            self.next_batch,
-            self._base_url_str,
-        )
+        q = """
+        UPDATE puppet
+        SET number=$2, name=$3, name_quality=$4, avatar_hash=$5, avatar_url=$6,
+            name_set=$7, avatar_set=$8, is_registered=$9,
+            custom_mxid=$10, access_token=$11, next_batch=$12, base_url=$13
+        WHERE uuid=$1
+        """
+        await self.db.execute(q, *self._values)
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> Puppet:
+    def _from_row(cls, row: asyncpg.Record | None) -> Puppet | None:
+        if not row:
+            return None
         data = {**row}
         base_url_str = data.pop("base_url")
         base_url = URL(base_url_str) if base_url_str is not None else None
@@ -145,36 +106,27 @@ class Puppet:
 
     _select_base = (
         "SELECT uuid, number, name, name_quality, avatar_hash, avatar_url, name_set, avatar_set, "
-        "       uuid_registered, number_registered, custom_mxid, access_token, "
-        "       next_batch, base_url "
+        "       is_registered, custom_mxid, access_token, next_batch, base_url "
         "FROM puppet"
     )
 
     @classmethod
-    async def get_by_address(cls, address: Address) -> Puppet | None:
-        if address.uuid:
-            if address.number:
-                row = await cls.db.fetchrow(
-                    f"{cls._select_base} WHERE uuid=$1 OR number=$2", address.uuid, address.number
-                )
-            else:
-                row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
-        elif address.number:
-            row = await cls.db.fetchrow(f"{cls._select_base} WHERE number=$1", address.number)
-        else:
-            raise ValueError("Invalid address")
-        if not row:
-            return None
-        return cls._from_row(row)
+    async def get_by_uuid(cls, uuid: UUID) -> Puppet | None:
+        return cls._from_row(await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", uuid))
+
+    @classmethod
+    async def get_by_number(cls, number: str) -> Puppet | None:
+        return cls._from_row(await cls.db.fetchrow(f"{cls._select_base} WHERE number=$1", number))
 
     @classmethod
     async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
-        row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
-        if not row:
-            return None
-        return cls._from_row(row)
+        return cls._from_row(
+            await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
+        )
 
     @classmethod
     async def all_with_custom_mxid(cls) -> list[Puppet]:
-        rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
-        return [cls._from_row(row) for row in rows]
+        return [
+            cls._from_row(row)
+            for row in await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
+        ]

+ 34 - 35
mautrix_signal/db/reaction.py

@@ -16,15 +16,16 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, ClassVar
+from uuid import UUID
 
 from attr import dataclass
 import asyncpg
 
-from mausignald.types import Address, GroupID
+from mausignald.types import GroupID
 from mautrix.types import EventID, RoomID
 from mautrix.util.async_db import Database
 
-from ..util import id_to_str
+from .util import ensure_uuid
 
 fake_db = Database.create("") if TYPE_CHECKING else None
 
@@ -35,11 +36,11 @@ class Reaction:
 
     mxid: EventID
     mx_room: RoomID
-    signal_chat_id: GroupID | Address
+    signal_chat_id: GroupID | UUID
     signal_receiver: str
-    msg_author: Address
+    msg_author: UUID
     msg_timestamp: int
-    author: Address
+    author: UUID
     emoji: str
 
     async def insert(self) -> None:
@@ -52,11 +53,11 @@ class Reaction:
             q,
             self.mxid,
             self.mx_room,
-            id_to_str(self.signal_chat_id),
+            self.signal_chat_id,
             self.signal_receiver,
-            self.msg_author.best_identifier,
+            self.msg_author,
             self.msg_timestamp,
-            self.author.best_identifier,
+            self.author,
             self.emoji,
         )
 
@@ -68,11 +69,11 @@ class Reaction:
             mxid,
             mx_room,
             emoji,
-            id_to_str(self.signal_chat_id),
+            self.signal_chat_id,
             self.signal_receiver,
-            self.msg_author.best_identifier,
+            self.msg_author,
             self.msg_timestamp,
-            self.author.best_identifier,
+            self.author,
         )
 
     async def delete(self) -> None:
@@ -82,21 +83,23 @@ class Reaction:
         )
         await self.db.execute(
             q,
-            id_to_str(self.signal_chat_id),
+            self.signal_chat_id,
             self.signal_receiver,
-            self.msg_author.best_identifier,
+            self.msg_author,
             self.msg_timestamp,
-            self.author.best_identifier,
+            self.author,
         )
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> Reaction:
+    def _from_row(cls, row: asyncpg.Record | None) -> Reaction | None:
+        if row is None:
+            return None
         data = {**row}
         chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
-            chat_id = Address.parse(chat_id)
-        msg_author = Address.parse(data.pop("msg_author"))
-        author = Address.parse(data.pop("author"))
+            chat_id = ensure_uuid(chat_id)
+        msg_author = ensure_uuid(data.pop("msg_author"))
+        author = ensure_uuid(data.pop("author"))
         return cls(signal_chat_id=chat_id, msg_author=msg_author, author=author, **data)
 
     @classmethod
@@ -106,19 +109,16 @@ class Reaction:
             "       msg_author, msg_timestamp, author, emoji "
             "FROM reaction WHERE mxid=$1 AND mx_room=$2"
         )
-        row = await cls.db.fetchrow(q, mxid, mx_room)
-        if not row:
-            return None
-        return cls._from_row(row)
+        return cls._from_row(await cls.db.fetchrow(q, mxid, mx_room))
 
     @classmethod
     async def get_by_signal_id(
         cls,
-        chat_id: GroupID | Address,
+        chat_id: GroupID | UUID,
         receiver: str,
-        msg_author: Address,
+        msg_author: UUID,
         msg_timestamp: int,
-        author: Address,
+        author: UUID,
     ) -> Reaction | None:
         q = (
             "SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
@@ -126,14 +126,13 @@ class Reaction:
             "FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
             "                    AND msg_author=$3 AND msg_timestamp=$4 AND author=$5"
         )
-        row = await cls.db.fetchrow(
-            q,
-            id_to_str(chat_id),
-            receiver,
-            msg_author.best_identifier,
-            msg_timestamp,
-            author.best_identifier,
+        return cls._from_row(
+            await cls.db.fetchrow(
+                q,
+                chat_id,
+                receiver,
+                msg_author,
+                msg_timestamp,
+                author,
+            )
         )
-        if not row:
-            return None
-        return cls._from_row(row)

+ 1 - 0
mautrix_signal/db/upgrade/__init__.py

@@ -13,4 +13,5 @@ from . import (
     v08_disappearing_messages,
     v09_group_topic,
     v10_puppet_name_quality,
+    v11_drop_number_support,
 )

+ 13 - 12
mautrix_signal/db/upgrade/v00_latest_revision.py

@@ -18,7 +18,7 @@ from mautrix.util.async_db import Connection
 from . import upgrade_table
 
 
-@upgrade_table.register(description="Initial revision", upgrades_to=10)
+@upgrade_table.register(description="Initial revision", upgrades_to=11)
 async def upgrade_latest(conn: Connection) -> None:
     await conn.execute(
         """CREATE TABLE portal (
@@ -49,7 +49,7 @@ async def upgrade_latest(conn: Connection) -> None:
     )
     await conn.execute(
         """CREATE TABLE puppet (
-            uuid         UUID UNIQUE,
+            uuid         UUID PRIMARY KEY,
             number       TEXT UNIQUE,
             name         TEXT,
             name_quality INTEGER NOT NULL DEFAULT 0,
@@ -58,8 +58,7 @@ async def upgrade_latest(conn: Connection) -> None:
             name_set     BOOLEAN NOT NULL DEFAULT false,
             avatar_set   BOOLEAN NOT NULL DEFAULT false,
 
-            uuid_registered   BOOLEAN NOT NULL DEFAULT false,
-            number_registered BOOLEAN NOT NULL DEFAULT false,
+            is_registered BOOLEAN NOT NULL DEFAULT false,
 
             custom_mxid  TEXT,
             access_token TEXT,
@@ -82,14 +81,14 @@ async def upgrade_latest(conn: Connection) -> None:
         """CREATE TABLE message (
             mxid    TEXT NOT NULL,
             mx_room TEXT NOT NULL,
-            sender          TEXT,
+            sender          UUID,
             timestamp       BIGINT,
             signal_chat_id  TEXT,
             signal_receiver TEXT,
 
             PRIMARY KEY (sender, timestamp, signal_chat_id, signal_receiver),
-            FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
-                ON UPDATE CASCADE ON DELETE CASCADE,
+            FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver) ON DELETE CASCADE,
+            FOREIGN KEY (sender) REFERENCES puppet(uuid) ON DELETE CASCADE,
             UNIQUE (mxid, mx_room)
         )"""
     )
@@ -100,16 +99,18 @@ async def upgrade_latest(conn: Connection) -> None:
 
             signal_chat_id  TEXT   NOT NULL,
             signal_receiver TEXT   NOT NULL,
-            msg_author      TEXT   NOT NULL,
+            msg_author      UUID   NOT NULL,
             msg_timestamp   BIGINT NOT NULL,
-            author          TEXT   NOT NULL,
+            author          UUID   NOT NULL,
 
             emoji TEXT NOT NULL,
 
             PRIMARY KEY (signal_chat_id, signal_receiver, msg_author, msg_timestamp, author),
-            FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver)
-                REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
-                ON DELETE CASCADE ON UPDATE CASCADE,
+            CONSTRAINT reaction_message_fkey
+                FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver)
+                    REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
+                    ON DELETE CASCADE,
+            FOREIGN KEY (author) REFERENCES puppet(uuid) ON DELETE CASCADE,
             UNIQUE (mxid, mx_room)
         )"""
     )

+ 97 - 0
mautrix_signal/db/upgrade/v11_drop_number_support.py

@@ -0,0 +1,97 @@
+# mautrix-signal - A Matrix-Signal puppeting bridge
+# Copyright (C) 2022 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from mautrix.util.async_db import Connection, Scheme
+
+from . import upgrade_table
+
+
+@upgrade_table.register(description="Drop support for phone numbers as puppet identifiers")
+async def upgrade_v11(conn: Connection, scheme: Scheme) -> None:
+    await conn.execute("DELETE FROM portal WHERE chat_id LIKE '+%'")
+    await conn.execute("DELETE FROM message WHERE sender LIKE '+%'")
+    await conn.execute("DELETE FROM reaction WHERE author LIKE '+%'")
+    await conn.execute("DELETE FROM puppet WHERE uuid IS NULL")
+    if scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
+        await conn.execute(
+            """
+            ALTER TABLE puppet
+                DROP CONSTRAINT puppet_uuid_key,
+                ADD CONSTRAINT puppet_pkey PRIMARY KEY (uuid)
+            """
+        )
+        await conn.execute("ALTER TABLE puppet DROP COLUMN number_registered")
+        await conn.execute("ALTER TABLE puppet RENAME COLUMN uuid_registered TO is_registered")
+        await conn.execute(
+            "ALTER TABLE reaction DROP CONSTRAINT reaction_msg_author_msg_timestamp_signal_chat_id_signal_re_fkey"
+        )
+        await conn.execute("ALTER TABLE message ALTER COLUMN sender TYPE UUID USING sender::uuid")
+        await conn.execute(
+            "ALTER TABLE reaction ALTER COLUMN msg_author TYPE UUID USING msg_author::uuid"
+        )
+        await conn.execute(
+            """
+            ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey
+                FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver)
+                REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
+                ON DELETE CASCADE
+            """
+        )
+        await conn.execute("ALTER TABLE reaction ALTER COLUMN author TYPE UUID USING author::uuid")
+        await conn.execute(
+            """
+            ALTER TABLE message ADD CONSTRAINT message_sender_fkey
+                FOREIGN KEY (sender) REFERENCES puppet(uuid) ON DELETE CASCADE
+            """
+        )
+        await conn.execute(
+            """
+            ALTER TABLE reaction ADD CONSTRAINT reaction_author_fkey
+                FOREIGN KEY (author) REFERENCES puppet(uuid) ON DELETE CASCADE
+            """
+        )
+    else:
+        await conn.execute(
+            """CREATE TABLE new_puppet (
+                uuid         UUID PRIMARY KEY,
+                number       TEXT UNIQUE,
+                name         TEXT,
+                name_quality INTEGER NOT NULL DEFAULT 0,
+                avatar_hash  TEXT,
+                avatar_url   TEXT,
+                name_set     BOOLEAN NOT NULL DEFAULT false,
+                avatar_set   BOOLEAN NOT NULL DEFAULT false,
+
+                is_registered BOOLEAN NOT NULL DEFAULT false,
+
+                custom_mxid  TEXT,
+                access_token TEXT,
+                next_batch   TEXT,
+                base_url     TEXT
+            )"""
+        )
+        await conn.execute(
+            """
+            INSERT INTO new_puppet (
+                uuid, number, name, name_quality, avatar_hash, avatar_url, name_set, avatar_set,
+                is_registered, custom_mxid, access_token, next_batch, base_url
+            )
+            SELECT uuid, number, name, name_quality, avatar_hash, avatar_url, name_set, avatar_set,
+                   uuid_registered, custom_mxid, access_token, next_batch, base_url
+            FROM puppet
+            """
+        )
+        await conn.execute("DROP TABLE puppet")
+        await conn.execute("ALTER TABLE new_puppet RENAME TO puppet")

+ 14 - 0
mautrix_signal/db/util.py

@@ -0,0 +1,14 @@
+from uuid import UUID
+import sqlite3
+
+
+def ensure_uuid(val: bytes | str | UUID) -> UUID:
+    if not isinstance(val, UUID):
+        if isinstance(val, bytes):
+            val = val.decode("utf-8")
+        return UUID(val)
+    return val
+
+
+sqlite3.register_adapter(UUID, str)
+sqlite3.register_converter("UUID", ensure_uuid)

+ 1 - 1
mautrix_signal/formatter.py

@@ -61,7 +61,7 @@ async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
 
             text_chunks.append(before)
             html_chunks.append(html.escape(before))
-            puppet = await pu.Puppet.get_by_address(Address(uuid=mention.uuid))
+            puppet = await pu.Puppet.get_by_uuid(mention.uuid)
             name = add_surrogate(puppet.name or puppet.mxid)
             text_chunks.append(name)
             html_chunks.append(f'<a href="https://matrix.to/#/{puppet.mxid}">{name}</a>')

+ 69 - 75
mautrix_signal/portal.py

@@ -97,7 +97,6 @@ from .db import (
     Reaction as DBReaction,
 )
 from .formatter import matrix_to_signal, signal_to_matrix
-from .util import id_to_str
 
 if TYPE_CHECKING:
     from .__main__ import SignalBridge
@@ -144,15 +143,15 @@ class Portal(DBPortal, BasePortal):
 
     _main_intent: IntentAPI | None
     _create_room_lock: asyncio.Lock
-    _msgts_dedup: deque[tuple[Address, int]]
-    _reaction_dedup: deque[tuple[Address, int, str, Address, bool]]
+    _msgts_dedup: deque[tuple[UUID, int]]
+    _reaction_dedup: deque[tuple[UUID, int, str, UUID, bool]]
     _reaction_lock: asyncio.Lock
     _pending_members: set[UUID] | None
     _expiration_lock: asyncio.Lock
 
     def __init__(
         self,
-        chat_id: GroupID | Address,
+        chat_id: GroupID | UUID,
         receiver: str,
         mxid: RoomID | None = None,
         name: str | None = None,
@@ -201,22 +200,12 @@ class Portal(DBPortal, BasePortal):
 
     @property
     def is_direct(self) -> bool:
-        return isinstance(self.chat_id, Address)
+        return isinstance(self.chat_id, UUID)
 
     @property
     def disappearing_enabled(self) -> bool:
         return self.is_direct or self.config["signal.enable_disappearing_messages_in_groups"]
 
-    def handle_uuid_receive(self, uuid: UUID) -> None:
-        if not self.is_direct or self.chat_id.uuid:
-            raise ValueError(
-                "handle_uuid_receive can only be used for private chat portals with a phone "
-                "number chat_id"
-            )
-        del self.by_chat_id[(self.chat_id_str, self.receiver)]
-        self.chat_id = Address(uuid=uuid)
-        self.by_chat_id[(self.chat_id_str, self.receiver)] = self
-
     @classmethod
     def init_cls(cls, bridge: "SignalBridge") -> None:
         cls.config = bridge.config
@@ -265,7 +254,7 @@ class Portal(DBPortal, BasePortal):
                 signal_receiver=self.receiver,
                 msg_author=message.sender,
                 msg_timestamp=message.timestamp,
-                author=sender.address,
+                author=sender.uuid,
             ).insert()
 
     # endregion
@@ -426,14 +415,14 @@ class Portal(DBPortal, BasePortal):
             await self.apply_relay_message_format(orig_sender, message)
 
         request_id = int(time.time() * 1000)
-        self._msgts_dedup.appendleft((sender.address, request_id))
+        self._msgts_dedup.appendleft((sender.uuid, request_id))
 
         quote = None
         if message.get_reply_to():
             reply = await DBMessage.get_by_mxid(message.get_reply_to(), self.mxid)
             # TODO include actual text? either store in db or fetch event from homeserver
             if reply is not None:
-                quote = Quote(id=reply.timestamp, author=reply.sender, text="")
+                quote = Quote(id=reply.timestamp, author=Address(uuid=reply.sender), text="")
                 # TODO only send this when it's actually a reply to an attachment?
                 #      Neither Signal Android nor iOS seem to care though, so this works too
                 quote.attachments = [QuotedAttachment("", "")]
@@ -445,8 +434,11 @@ class Portal(DBPortal, BasePortal):
         if message.msgtype.is_text:
             text, mentions = await matrix_to_signal(message)
             message_previews = message.get(BEEPER_LINK_PREVIEWS_KEY, [])
-            potential_link_previews = await asyncio.gather(
-                *(self._beeper_link_preview_to_signal(m) for m in message_previews)
+            potential_link_previews = cast(
+                list[LinkPreview | None],
+                await asyncio.gather(
+                    *(self._beeper_link_preview_to_signal(m) for m in message_previews)
+                ),
             )
             link_previews = [p for p in potential_link_previews if p is not None]
         elif message.msgtype.is_media:
@@ -504,7 +496,7 @@ class Portal(DBPortal, BasePortal):
         msg = DBMessage(
             mxid=event_id,
             mx_room=self.mxid,
-            sender=sender.address,
+            sender=sender.uuid,
             timestamp=request_id,
             signal_chat_id=self.chat_id,
             signal_receiver=self.receiver,
@@ -643,18 +635,18 @@ class Portal(DBPortal, BasePortal):
 
         async with self._reaction_lock:
             existing = await DBReaction.get_by_signal_id(
-                self.chat_id, self.receiver, message.sender, message.timestamp, sender.address
+                self.chat_id, self.receiver, message.sender, message.timestamp, sender.uuid
             )
             if existing and existing.emoji == emoji:
                 return
 
-            dedup_id = (message.sender, message.timestamp, emoji, sender.address, False)
+            dedup_id = (message.sender, message.timestamp, emoji, sender.uuid, False)
             self._reaction_dedup.appendleft(dedup_id)
 
             reaction = Reaction(
                 emoji=emoji,
                 remove=False,
-                target_author=message.sender,
+                target_author=Address(uuid=message.sender),
                 target_sent_timestamp=message.timestamp,
             )
             self.log.trace(f"{sender.mxid} reacted to {message.timestamp} with {emoji}")
@@ -713,7 +705,7 @@ class Portal(DBPortal, BasePortal):
                 remove_reaction = Reaction(
                     emoji=reaction.emoji,
                     remove=True,
-                    target_author=reaction.msg_author,
+                    target_author=Address(uuid=reaction.msg_author),
                     target_sent_timestamp=reaction.msg_timestamp,
                 )
                 await self.signal.react(
@@ -825,7 +817,7 @@ class Portal(DBPortal, BasePortal):
             is_banned = False
             if info.banned_members:
                 for member in info.banned_members:
-                    is_banned = user.address.uuid == member.uuid or is_banned
+                    is_banned = user.uuid == member.uuid or is_banned
             if not is_banned:
                 await self.main_intent.unban_user(
                     self.mxid, user.mxid, reason=f"Failed to ban Signal user: {e}"
@@ -846,7 +838,7 @@ class Portal(DBPortal, BasePortal):
             info = await self.signal.get_group(source.username, self.chat_id)
             if info.banned_members:
                 for member in info.banned_members:
-                    if member.uuid == user.address.uuid:
+                    if member.uuid == user.uuid:
                         await self.main_intent.ban_user(
                             self.mxid, user.mxid, reason=f"Failed to unban Signal user: {e}"
                         )
@@ -984,18 +976,18 @@ class Portal(DBPortal, BasePortal):
                     changes[user] = levels.users_default
         if changes:
             for user, level in changes.items():
-                address = p.Puppet.get_id_from_mxid(user)
-                if not address:
+                uuid = p.Puppet.get_id_from_mxid(user)
+                if not uuid:
                     mx_user = await u.User.get_by_mxid(user, create=False)
                     if not mx_user or not mx_user.is_logged_in:
                         continue
-                    address = mx_user.address
-                if not address or not address.uuid:
+                    uuid = mx_user.uuid
+                if not uuid:
                     continue
                 signal_role = (
                     GroupMemberRole.DEFAULT if level < 50 else GroupMemberRole.ADMINISTRATOR
                 )
-                group_member = GroupMember(uuid=address.uuid, role=signal_role)
+                group_member = GroupMember(uuid=uuid, role=signal_role)
                 try:
                     update_meta = await self.signal.update_group(
                         sender.username, self.chat_id, update_role=group_member
@@ -1080,18 +1072,15 @@ class Portal(DBPortal, BasePortal):
     # endregion
     # region Signal event handling
 
-    @staticmethod
-    async def _resolve_address(address: Address) -> Address:
-        puppet = await p.Puppet.get_by_address(address, create=False)
-        return puppet.address
-
     async def _find_quote_event_id(self, quote: Quote | None) -> MessageEvent | EventID | None:
         if not quote:
             return None
 
-        author_address = await self._resolve_address(quote.author)
+        puppet = await p.Puppet.get_by_address(quote.author, create=False)
+        if not puppet:
+            return None
         reply_msg = await DBMessage.get_by_signal_id(
-            author_address, quote.id, self.chat_id, self.receiver
+            puppet.uuid, quote.id, self.chat_id, self.receiver
         )
         if not reply_msg:
             return None
@@ -1148,7 +1137,7 @@ class Portal(DBPortal, BasePortal):
     async def handle_signal_message(
         self, source: u.User, sender: p.Puppet, message: MessageData
     ) -> None:
-        if (sender.address, message.timestamp) in self._msgts_dedup:
+        if (sender.uuid, message.timestamp) in self._msgts_dedup:
             self.log.debug(
                 f"Ignoring message {message.timestamp} by {sender.uuid} as it was already handled "
                 "(message.timestamp in dedup queue)"
@@ -1157,9 +1146,9 @@ class Portal(DBPortal, BasePortal):
                 source.username, sender.address, timestamps=[message.timestamp]
             )
             return
-        self._msgts_dedup.appendleft((sender.address, message.timestamp))
+        self._msgts_dedup.appendleft((sender.uuid, message.timestamp))
         old_message = await DBMessage.get_by_signal_id(
-            sender.address, message.timestamp, self.chat_id, self.receiver
+            sender.uuid, message.timestamp, self.chat_id, self.receiver
         )
         if old_message is not None:
             self.log.debug(
@@ -1258,7 +1247,7 @@ class Portal(DBPortal, BasePortal):
             msg = DBMessage(
                 mxid=event_id,
                 mx_room=self.mxid,
-                sender=sender.address,
+                sender=sender.uuid,
                 timestamp=message.timestamp,
                 signal_chat_id=self.chat_id,
                 signal_receiver=self.receiver,
@@ -1289,6 +1278,9 @@ class Portal(DBPortal, BasePortal):
         else:
             return
         editor = await p.Puppet.get_by_address(group_change.editor)
+        if not editor:
+            self.log.warning(f"Didn't get puppet for group change editor {group_change.editor}")
+            return
         editor_intent = editor.intent_for(self)
         if (
             group_change.delete_members
@@ -1316,7 +1308,7 @@ class Portal(DBPortal, BasePortal):
             levels = await editor.intent_for(self).get_power_levels(self.mxid)
             for group_member in group_change.modify_member_roles:
                 users = [
-                    await p.Puppet.get_by_address(group_member.address),
+                    await p.Puppet.get_by_uuid(group_member.uuid),
                     await u.User.get_by_uuid(group_member.uuid),
                 ]
                 for user in users:
@@ -1337,7 +1329,7 @@ class Portal(DBPortal, BasePortal):
         if group_change.new_banned_members:
             for banned_member in group_change.new_banned_members:
                 users = [
-                    await p.Puppet.get_by_address(banned_member.address),
+                    await p.Puppet.get_by_uuid(banned_member.uuid),
                     await u.User.get_by_uuid(banned_member.uuid),
                 ]
                 for user in users:
@@ -1358,7 +1350,7 @@ class Portal(DBPortal, BasePortal):
         if group_change.new_unbanned_members:
             for banned_member in group_change.new_unbanned_members:
                 users = [
-                    await p.Puppet.get_by_address(banned_member.address),
+                    await p.Puppet.get_by_uuid(banned_member.uuid),
                     await u.User.get_by_uuid(banned_member.uuid),
                 ]
                 for user in users:
@@ -1387,7 +1379,7 @@ class Portal(DBPortal, BasePortal):
                 + (group_change.new_pending_members or [])
                 + (group_change.promote_requesting_members or [])
             ):
-                puppet = await p.Puppet.get_by_address(group_member.address)
+                puppet = await p.Puppet.get_by_uuid(group_member.uuid)
                 await source.sync_contact(group_member.address)
                 users = [puppet, await u.User.get_by_uuid(group_member.uuid)]
                 for user in users:
@@ -1422,7 +1414,7 @@ class Portal(DBPortal, BasePortal):
         if group_change.promote_pending_members:
             for group_member in group_change.promote_pending_members:
                 await source.sync_contact(group_member.address)
-                user = await p.Puppet.get_by_address(group_member.address)
+                user = await p.Puppet.get_by_uuid(group_member.uuid)
                 if not user:
                     continue
                 try:
@@ -1438,7 +1430,7 @@ class Portal(DBPortal, BasePortal):
                     self.log.debug(
                         f"Profile of puppet with uuid {group_member.uuid} is unavailable"
                     )
-                user = await p.Puppet.get_by_address(group_member.address)
+                user = await p.Puppet.get_by_uuid(group_member.uuid)
                 try:
                     await user.intent_for(self).knock_room(self.mxid, reason="via invite link")
                 except (MForbidden, MBadState) as e:
@@ -1699,16 +1691,24 @@ class Portal(DBPortal, BasePortal):
     async def handle_signal_reaction(
         self, sender: p.Puppet, reaction: Reaction, timestamp: int
     ) -> None:
-        author_address = await self._resolve_address(reaction.target_author)
+        author_puppet = await p.Puppet.get_by_address(reaction.target_author, create=False)
+        if not author_puppet:
+            return None
         target_id = reaction.target_sent_timestamp
         async with self._reaction_lock:
-            dedup_id = (author_address, target_id, reaction.emoji, sender.address, reaction.remove)
+            dedup_id = (
+                author_puppet.uuid,
+                target_id,
+                reaction.emoji,
+                sender.uuid,
+                reaction.remove,
+            )
             if dedup_id in self._reaction_dedup:
                 return
             self._reaction_dedup.appendleft(dedup_id)
 
         existing = await DBReaction.get_by_signal_id(
-            self.chat_id, self.receiver, author_address, target_id, sender.address
+            self.chat_id, self.receiver, author_puppet.uuid, target_id, sender.uuid
         )
 
         if reaction.remove:
@@ -1724,7 +1724,7 @@ class Portal(DBPortal, BasePortal):
             return
 
         message = await DBMessage.get_by_signal_id(
-            author_address, target_id, self.chat_id, self.receiver
+            author_puppet.uuid, target_id, self.chat_id, self.receiver
         )
         if not message:
             self.log.debug(f"Ignoring reaction to unknown message {target_id}")
@@ -1733,12 +1733,12 @@ class Portal(DBPortal, BasePortal):
         intent = sender.intent_for(self)
         matrix_emoji = variation_selector.add(reaction.emoji)
         mxid = await intent.react(message.mx_room, message.mxid, matrix_emoji, timestamp=timestamp)
-        self.log.debug(f"{sender.address} reacted to {message.mxid} -> {mxid}")
+        self.log.debug(f"{sender.uuid} reacted to {message.mxid} -> {mxid}")
         await self._upsert_reaction(existing, intent, mxid, sender, message, reaction.emoji)
 
     async def handle_signal_delete(self, sender: p.Puppet, message_ts: int) -> None:
         message = await DBMessage.get_by_signal_id(
-            sender.address, message_ts, self.chat_id, self.receiver
+            sender.uuid, message_ts, self.chat_id, self.receiver
         )
         if not message:
             return
@@ -1888,7 +1888,7 @@ class Portal(DBPortal, BasePortal):
     async def get_dm_puppet(self) -> p.Puppet | None:
         if not self.is_direct:
             return None
-        return await p.Puppet.get_by_address(self.chat_id)
+        return await p.Puppet.get_by_uuid(self.chat_id)
 
     async def update_info_from_puppet(self, puppet: p.Puppet | None = None) -> None:
         if not self.is_direct:
@@ -2023,9 +2023,12 @@ class Portal(DBPortal, BasePortal):
                 try:
                     await self.main_intent.invite_user(self.mxid, user.mxid, check_cache=True)
                 except (MForbidden, IntentError, MBadState) as e:
-                    self.log.debug(f"could not invite {user.mxid}: {e}")
+                    self.log.debug(f"Failed to invite {user.mxid}: {e}")
 
             puppet = await p.Puppet.get_by_address(address)
+            if not puppet:
+                self.log.warning(f"Didn't find puppet for member {address}")
+                continue
             try:
                 await source.sync_contact(address)
             except ProfileUnavailableError:
@@ -2036,7 +2039,7 @@ class Portal(DBPortal, BasePortal):
                 )
             except (MForbidden, IntentError, MBadState) as e:
                 self.log.debug(f"could not invite {user.mxid}: {e}")
-            if not address.uuid in self._pending_members:
+            if address.uuid not in self._pending_members:
                 await puppet.intent_for(self).ensure_joined(self.mxid)
             remove_users.discard(puppet.default_mxid)
 
@@ -2245,7 +2248,7 @@ class Portal(DBPortal, BasePortal):
             if isinstance(info, GroupV2):
                 ac = info.access_control
                 for detail in info.member_detail + info.pending_member_detail:
-                    puppet = await p.Puppet.get_by_address(Address(uuid=detail.uuid))
+                    puppet = await p.Puppet.get_by_uuid(detail.uuid)
                     puppet_mxid = puppet.intent_for(self).mxid
                     current_level = levels.get_user_level(puppet_mxid)
                     if bot_pl > current_level and bot_pl >= 50:
@@ -2324,7 +2327,7 @@ class Portal(DBPortal, BasePortal):
             )
             if self.is_direct:
                 invites.append(self.az.bot_mxid)
-        if self.is_direct and source.address == self.chat_id:
+        if self.is_direct and source.uuid == self.chat_id:
             name = self.name = "Signal Note to Self"
         elif self.encrypted or self.private_chat_portal_meta or not self.is_direct:
             name = self.name
@@ -2416,7 +2419,7 @@ class Portal(DBPortal, BasePortal):
         return cls._db_to_portals(super().all_with_room())
 
     @classmethod
-    def find_private_chats_with(cls, other_user: Address) -> AsyncGenerator[Portal, None]:
+    def find_private_chats_with(cls, other_user: UUID) -> AsyncGenerator[Portal, None]:
         return cls._db_to_portals(super().find_private_chats_with(other_user))
 
     @classmethod
@@ -2431,7 +2434,7 @@ class Portal(DBPortal, BasePortal):
 
     @classmethod
     @async_getter_lock
-    async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
+    async def get_by_mxid(cls, mxid: RoomID, /) -> Portal | None:
         try:
             return cls.by_mxid[mxid]
         except KeyError:
@@ -2445,28 +2448,19 @@ class Portal(DBPortal, BasePortal):
         return None
 
     @classmethod
+    @async_getter_lock
     async def get_by_chat_id(
-        cls, chat_id: GroupID | Address, *, receiver: str = "", create: bool = False
+        cls, chat_id: GroupID | UUID, receiver: str = "", /, *, create: bool
     ) -> Portal | None:
         if isinstance(chat_id, str):
             receiver = ""
-        elif not isinstance(chat_id, Address):
+        elif not isinstance(chat_id, UUID):
             raise ValueError(f"Invalid chat ID type {type(chat_id)}")
         elif not receiver:
             raise ValueError("Direct chats must have a receiver")
-        best_id = id_to_str(chat_id)
-        portal = await cls._get_by_chat_id(best_id, receiver, create=create, chat_id=chat_id)
-        if portal:
-            portal.log.debug(f"get_by_chat_id({chat_id}, {receiver}) -> {hex(id(portal))}")
-        return portal
 
-    @classmethod
-    @async_getter_lock
-    async def _get_by_chat_id(
-        cls, best_id: str, receiver: str, *, create: bool, chat_id: GroupID | Address
-    ) -> Portal | None:
         try:
-            return cls.by_chat_id[(best_id, receiver)]
+            return cls.by_chat_id[(str(chat_id), receiver)]
         except KeyError:
             pass
 

+ 86 - 112
mautrix_signal/puppet.py

@@ -23,6 +23,7 @@ import os.path
 
 from yarl import URL
 
+from mausignald.errors import UnregisteredUserError
 from mausignald.types import Address, Profile
 from mautrix.appservice import IntentAPI
 from mautrix.bridge import BasePuppet, async_getter_lock
@@ -37,7 +38,7 @@ from mautrix.types import (
 )
 from mautrix.util.simple_template import SimpleTemplate
 
-from . import portal as p, user as u
+from . import portal as p, signal, user as u
 from .config import Config
 from .db import Puppet as DBPuppet
 
@@ -58,6 +59,7 @@ class Puppet(DBPuppet, BasePuppet):
     mxid_template: SimpleTemplate[str]
 
     config: Config
+    signal: signal.SignalHandler
 
     default_mxid_intent: IntentAPI
     default_mxid: UserID
@@ -67,7 +69,7 @@ class Puppet(DBPuppet, BasePuppet):
 
     def __init__(
         self,
-        uuid: UUID | None,
+        uuid: UUID,
         number: str | None,
         name: str | None = None,
         name_quality: int = 0,
@@ -75,13 +77,14 @@ class Puppet(DBPuppet, BasePuppet):
         avatar_hash: str | None = None,
         name_set: bool = False,
         avatar_set: bool = False,
-        uuid_registered: bool = False,
-        number_registered: bool = False,
+        is_registered: bool = False,
         custom_mxid: UserID | None = None,
         access_token: str | None = None,
         next_batch: SyncToken | None = None,
         base_url: URL | None = None,
     ) -> None:
+        assert uuid, "UUID must be set for ghosts"
+        assert isinstance(uuid, UUID)
         super().__init__(
             uuid=uuid,
             number=number,
@@ -91,8 +94,7 @@ class Puppet(DBPuppet, BasePuppet):
             avatar_hash=avatar_hash,
             name_set=name_set,
             avatar_set=avatar_set,
-            uuid_registered=uuid_registered,
-            number_registered=number_registered,
+            is_registered=is_registered,
             custom_mxid=custom_mxid,
             access_token=access_token,
             next_batch=next_batch,
@@ -100,7 +102,7 @@ class Puppet(DBPuppet, BasePuppet):
         )
         self.log = self.log.getChild(str(uuid) if uuid else number)
 
-        self.default_mxid = self.get_mxid_from_id(self.address)
+        self.default_mxid = self.get_mxid_from_id(self.uuid)
         self.default_mxid_intent = self.az.intent.user(self.default_mxid)
         self.intent = self._fresh_intent()
 
@@ -111,6 +113,7 @@ class Puppet(DBPuppet, BasePuppet):
     def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
         cls.config = bridge.config
         cls.loop = bridge.loop
+        cls.signal = bridge.signal
         cls.mx = bridge.matrix
         cls.az = bridge.az
         cls.hs_domain = cls.config["homeserver.domain"]
@@ -140,28 +143,10 @@ class Puppet(DBPuppet, BasePuppet):
             return self.default_mxid_intent
         return self.intent
 
-    @property
-    def is_registered(self) -> bool:
-        return self.uuid_registered if self.uuid is not None else self.number_registered
-
-    @is_registered.setter
-    def is_registered(self, value: bool) -> None:
-        if self.uuid is not None:
-            self.uuid_registered = value
-        else:
-            self.number_registered = value
-
     @property
     def address(self) -> Address:
         return Address(uuid=self.uuid, number=self.number)
 
-    async def handle_uuid_receive(self, uuid: UUID) -> None:
-        async with self._uuid_lock:
-            if self.uuid:
-                # Received UUID was handled while this call was waiting
-                return
-            await self._handle_uuid_receive(uuid)
-
     async def handle_number_receive(self, number: str) -> None:
         async with self._uuid_lock:
             if self.number == number:
@@ -171,36 +156,6 @@ class Puppet(DBPuppet, BasePuppet):
             self.number = number
             self.by_number[self.number] = self
             await self._set_number(number)
-            async for portal in p.Portal.find_private_chats_with(Address(number=number)):
-                self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
-                portal.handle_uuid_receive(self.uuid)
-            prev_mxid = self.get_mxid_from_id(Address(number=number))
-            if await self.az.state_store.is_registered(prev_mxid):
-                prev_intent = self.az.intent.user(prev_mxid)
-                await self._migrate_memberships(prev_intent, self.default_mxid_intent)
-
-    async def _handle_uuid_receive(self, uuid: UUID) -> None:
-        self.log.debug(f"Found UUID for user: {uuid}")
-        user = await u.User.get_by_username(self.number)
-        if user and not user.uuid:
-            user.uuid = self.uuid
-            user.by_uuid[user.uuid] = user
-            await user.update()
-        self.uuid = uuid
-        self.by_uuid[self.uuid] = self
-        await self._set_uuid(uuid)
-        async for portal in p.Portal.find_private_chats_with(Address(number=self.number)):
-            self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
-            portal.handle_uuid_receive(self.uuid)
-        prev_intent = self.default_mxid_intent
-        self.default_mxid = self.get_mxid_from_id(self.address)
-        self.default_mxid_intent = self.az.intent.user(self.default_mxid)
-        self.intent = self._fresh_intent()
-        await self.default_mxid_intent.ensure_registered()
-        if self.name:
-            await self.default_mxid_intent.set_displayname(self.name)
-        self.log = Puppet.log.getChild(str(uuid))
-        await self._migrate_memberships(prev_intent, self.default_mxid_intent)
 
     async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
         self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
@@ -235,8 +190,6 @@ class Puppet(DBPuppet, BasePuppet):
     async def update_info(self, info: Profile | Address, source: u.User) -> None:
         update = False
         address = info.address if isinstance(info, Profile) else info
-        if address.uuid and not self.uuid:
-            await self.handle_uuid_receive(address.uuid)
         if address.number and address.number != self.number:
             await self.handle_number_receive(address.number)
             update = True
@@ -362,7 +315,7 @@ class Puppet(DBPuppet, BasePuppet):
         return True
 
     async def _update_portal_meta(self) -> None:
-        async for portal in p.Portal.find_private_chats_with(self.address):
+        async for portal in p.Portal.find_private_chats_with(self.uuid):
             if portal.receiver == self.number:
                 # This is a note to self chat, don't change the name
                 continue
@@ -376,20 +329,13 @@ class Puppet(DBPuppet, BasePuppet):
 
     async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
         portal: p.Portal = await p.Portal.get_by_mxid(room_id)
-        if not portal or not portal.is_direct:
-            return True
-        elif portal.chat_id.uuid and self.uuid:
-            return portal.chat_id.uuid != self.uuid
-        elif portal.chat_id.number and self.number:
-            return portal.chat_id.number != self.number
-        else:
-            return True
+        # Leave all portals except the notes to self room
+        return not (portal and portal.is_direct and portal.chat_id.uuid == self.uuid)
 
     # region Database getters
 
     def _add_to_cache(self) -> None:
-        if self.uuid:
-            self.by_uuid[self.uuid] = self
+        self.by_uuid[self.uuid] = self
         if self.number:
             self.by_number[self.number] = self
         if self.custom_mxid:
@@ -400,10 +346,10 @@ class Puppet(DBPuppet, BasePuppet):
 
     @classmethod
     async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
-        address = cls.get_id_from_mxid(mxid)
-        if not address:
+        uuid = cls.get_id_from_mxid(mxid)
+        if not uuid:
             return None
-        return await cls.get_by_address(address, create)
+        return await cls.get_by_uuid(uuid, create=create)
 
     @classmethod
     @async_getter_lock
@@ -421,59 +367,90 @@ class Puppet(DBPuppet, BasePuppet):
         return None
 
     @classmethod
-    def get_id_from_mxid(cls, mxid: UserID) -> Address | None:
+    def get_id_from_mxid(cls, mxid: UserID) -> UUID | None:
         identifier = cls.mxid_template.parse(mxid)
         if not identifier:
             return None
-        if identifier.startswith("phone_"):
-            return Address(number="+" + identifier[len("phone_") :])
-        else:
+        try:
+            return UUID(identifier.upper())
+        except ValueError:
+            return None
+
+    @classmethod
+    def get_mxid_from_id(cls, uuid: UUID) -> UserID:
+        return UserID(cls.mxid_template.format_full(str(uuid).lower()))
+
+    @classmethod
+    @async_getter_lock
+    async def get_by_number(
+        cls, number: str, /, *, resolve_via: str | None = None, raise_resolve: bool = False
+    ) -> Puppet | None:
+        try:
+            return cls.by_number[number]
+        except KeyError:
+            pass
+
+        puppet = cast(cls, await super().get_by_number(number))
+        if puppet is not None:
+            puppet._add_to_cache()
+            return puppet
+
+        if resolve_via:
+            cls.log.debug(
+                f"Couldn't find puppet with number {number}, resolving UUID via {resolve_via}"
+            )
             try:
-                return Address(uuid=UUID(identifier.upper()))
-            except ValueError:
+                uuid = await cls.signal.find_uuid(resolve_via, number)
+            except UnregisteredUserError:
+                if raise_resolve:
+                    raise
+                cls.log.debug(f"Resolving {number} via {resolve_via} threw UnregisteredUserError")
                 return None
+            except Exception:
+                if raise_resolve:
+                    raise
+                cls.log.exception(f"Failed to resolve {number} via {resolve_via}")
+                return None
+            if uuid:
+                cls.log.debug(f"Found {uuid} for {number} after resolving via {resolve_via}")
+                return await cls.get_by_uuid(uuid, number=number)
+            else:
+                cls.log.debug(f"Didn't find UUID for {number} via {resolve_via}")
+
+        return None
 
     @classmethod
-    def get_mxid_from_id(cls, address: Address) -> UserID:
-        if address.uuid:
-            identifier = str(address.uuid).lower()
-        elif address.number:
-            identifier = f"phone_{address.number.lstrip('+')}"
+    async def get_by_address(
+        cls,
+        address: Address,
+        create: bool = True,
+        resolve_via: str | None = None,
+        raise_resolve: bool = False,
+    ) -> Puppet | None:
+        if not address.uuid:
+            return await cls.get_by_number(
+                address.number, resolve_via=resolve_via, raise_resolve=raise_resolve
+            )
         else:
-            raise ValueError("Empty address")
-        return UserID(cls.mxid_template.format_full(identifier))
+            return await cls.get_by_uuid(address.uuid, create=create, number=address.number)
 
     @classmethod
     @async_getter_lock
-    async def get_by_address(cls, address: Address, create: bool = True) -> Puppet | None:
-        puppet = await cls._get_by_address(address, create)
-        if puppet and address.uuid and not puppet.uuid:
-            # We found a UUID for this user, store it ASAP
-            await puppet.handle_uuid_receive(address.uuid)
-        return puppet
-
-    @classmethod
-    async def _get_by_address(cls, address: Address, create: bool = True) -> Puppet | None:
-        if not address.is_valid:
-            raise ValueError("Empty address")
-        if address.uuid:
-            try:
-                return cls.by_uuid[address.uuid]
-            except KeyError:
-                pass
-        if address.number:
-            try:
-                return cls.by_number[address.number]
-            except KeyError:
-                pass
+    async def get_by_uuid(
+        cls, uuid: UUID, /, *, create: bool = True, number: str | None = None
+    ) -> Puppet | None:
+        try:
+            return cls.by_uuid[uuid]
+        except KeyError:
+            pass
 
-        puppet = cast(cls, await super().get_by_address(address))
+        puppet = cast(cls, await super().get_by_uuid(uuid))
         if puppet is not None:
             puppet._add_to_cache()
             return puppet
 
         if create:
-            puppet = cls(address.uuid, address.number)
+            puppet = cls(uuid, number)
             await puppet.insert()
             puppet._add_to_cache()
             return puppet
@@ -488,10 +465,7 @@ class Puppet(DBPuppet, BasePuppet):
             try:
                 yield cls.by_uuid[puppet.uuid]
             except KeyError:
-                try:
-                    yield cls.by_number[puppet.number]
-                except KeyError:
-                    puppet._add_to_cache()
-                    yield puppet
+                puppet._add_to_cache()
+                yield puppet
 
     # endregion

+ 10 - 2
mautrix_signal/signal.py

@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Awaitable
+from uuid import UUID
 import asyncio
 import logging
 
@@ -67,7 +68,10 @@ class SignalHandler(SignaldClient):
         )
 
     async def on_message(self, evt: IncomingMessage) -> None:
-        sender = await pu.Puppet.get_by_address(evt.source)
+        sender = await pu.Puppet.get_by_address(evt.source, resolve_via=evt.account)
+        if not sender:
+            self.log.warning(f"Didn't find puppet for incoming message {evt.source}")
+            return
         user = await u.User.get_by_username(evt.account)
         # TODO add lots of logging
 
@@ -117,7 +121,11 @@ class SignalHandler(SignaldClient):
             f"{err.data.message}"
         )
 
-        sender = await pu.Puppet.get_by_address(Address.parse(err.data.sender))
+        sender = await pu.Puppet.get_by_address(
+            Address.parse(err.data.sender), resolve_via=err.account
+        )
+        if not sender:
+            return
         user = await u.User.get_by_username(err.account)
         portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
         if not portal or not portal.mxid:

+ 15 - 25
mautrix_signal/user.py

@@ -159,9 +159,7 @@ class User(DBUser, BaseUser):
     async def get_portal_with(self, puppet: pu.Puppet, create: bool = True) -> po.Portal | None:
         if not self.username:
             return None
-        return await po.Portal.get_by_chat_id(
-            puppet.address, receiver=self.username, create=create
-        )
+        return await po.Portal.get_by_chat_id(puppet.uuid, receiver=self.username, create=create)
 
     async def on_signin(self, account: Account) -> None:
         self.username = account.account_id
@@ -248,7 +246,10 @@ class User(DBUser, BaseUser):
         self._websocket_connection_state = bridge_state
 
     async def _sync_puppet(self) -> None:
-        puppet = await pu.Puppet.get_by_address(self.address)
+        puppet = await self.get_puppet()
+        if not puppet:
+            self.log.warning(f"Didn't find puppet for own address {self.address}")
+            return
         if puppet.uuid and not self.uuid:
             self.uuid = puppet.uuid
             self.by_uuid[self.uuid] = self
@@ -306,25 +307,20 @@ class User(DBUser, BaseUser):
             else:
                 address = contact.address
                 profile = contact
-            puppet = await pu.Puppet.get_by_address(address)
+            puppet = await pu.Puppet.get_by_address(address, resolve_via=self.username)
+            if not puppet:
+                self.log.debug(f"Didn't find puppet for {address} while syncing contact")
+                return
             await puppet.update_info(profile or address, self)
             if create_portals:
                 portal = await po.Portal.get_by_chat_id(
-                    puppet.address, receiver=self.username, create=True
+                    puppet.uuid, receiver=self.username, create=True
                 )
                 await portal.create_matrix_room(self, profile or address)
         except Exception as e:
             await self.handle_auth_failure(e)
             raise
 
-    async def _sync_group(self, group: Group, create_portals: bool) -> None:
-        self.log.trace("Syncing group %s", group)
-        portal = await po.Portal.get_by_chat_id(group.group_id, create=True)
-        if create_portals:
-            await portal.create_matrix_room(self, group)
-        elif portal.mxid:
-            await portal.update_matrix_room(self, group)
-
     async def _sync_group_v2(self, group: GroupV2, create_portals: bool) -> None:
         self.log.trace("Syncing group %s", group.id)
         portal = await po.Portal.get_by_chat_id(group.id, create=True)
@@ -344,16 +340,10 @@ class User(DBUser, BaseUser):
     async def _sync_groups(self) -> None:
         create_group_portal = self.config["bridge.autocreate_group_portal"]
         for group in await self.bridge.signal.list_groups(self.username):
-            group_id = group.group_id if isinstance(group, Group) else group.id
             try:
-                if isinstance(group, Group):
-                    await self._sync_group(group, create_group_portal)
-                elif isinstance(group, GroupV2):
-                    await self._sync_group_v2(group, create_group_portal)
-                else:
-                    self.log.warning("Unknown return type in list_groups: %s", type(group))
+                await self._sync_group_v2(group, create_group_portal)
             except Exception:
-                self.log.exception(f"Failed to sync group {group_id}")
+                self.log.exception(f"Failed to sync group {group.id}")
 
     # region Database getters
 
@@ -366,7 +356,7 @@ class User(DBUser, BaseUser):
 
     @classmethod
     @async_getter_lock
-    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> User | None:
+    async def get_by_mxid(cls, mxid: UserID, /, *, create: bool = True) -> User | None:
         # Never allow ghosts to be users
         if pu.Puppet.get_id_from_mxid(mxid):
             return None
@@ -390,7 +380,7 @@ class User(DBUser, BaseUser):
 
     @classmethod
     @async_getter_lock
-    async def get_by_username(cls, username: str) -> User | None:
+    async def get_by_username(cls, username: str, /) -> User | None:
         try:
             return cls.by_username[username]
         except KeyError:
@@ -405,7 +395,7 @@ class User(DBUser, BaseUser):
 
     @classmethod
     @async_getter_lock
-    async def get_by_uuid(cls, uuid: UUID) -> User | None:
+    async def get_by_uuid(cls, uuid: UUID, /) -> User | None:
         try:
             return cls.by_uuid[uuid]
         except KeyError:

+ 0 - 1
mautrix_signal/util/__init__.py

@@ -1,4 +1,3 @@
 from .color_log import ColorFormatter
-from .id_to_str import id_to_str
 from .normalize_number import normalize_number
 from .user_has_power_level import user_has_power_level

+ 0 - 24
mautrix_signal/util/id_to_str.py

@@ -1,24 +0,0 @@
-# mautrix-signal - A Matrix-Signal puppeting bridge
-# Copyright (C) 2020 Tulir Asokan
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Union
-
-from mausignald.types import Address, GroupID
-
-
-def id_to_str(identifier: Union[Address, GroupID]) -> str:
-    if isinstance(identifier, Address):
-        return identifier.best_identifier
-    return identifier

+ 17 - 15
mautrix_signal/web/provisioning_api.py

@@ -369,7 +369,7 @@ class ProvisioningAPI:
 
         async def transform(profile: Profile) -> JSON:
             assert profile.address
-            puppet = await pu.Puppet.get_by_address(profile.address, False)
+            puppet = await pu.Puppet.get_by_address(profile.address, create=False)
             avatar_url = puppet.avatar_url if puppet else None
             return {
                 "name": profile.name,
@@ -394,20 +394,22 @@ class ProvisioningAPI:
         except Exception as e:
             raise web.HTTPBadRequest(text=json.dumps({"error": str(e)}), headers=self._headers)
 
-        puppet: pu.Puppet = await pu.Puppet.get_by_address(Address(number=number))
-        assert puppet, "Puppet.get_by_address with create=True can't return None"
-        if not puppet.uuid:
-            try:
-                uuid = await self.bridge.signal.find_uuid(user.username, puppet.number)
-                if uuid:
-                    await puppet.handle_uuid_receive(uuid)
-            except UnregisteredUserError:
-                error = {"error": f"The phone number {number} is not a registered Signal account"}
-                raise web.HTTPNotFound(text=json.dumps(error), headers=self._headers)
-            except Exception:
-                self.log.exception(f"Unknown error fetching UUID for {puppet.number}")
-                error = {"error": "Unknown error while fetching UUID"}
-                raise web.HTTPInternalServerError(text=json.dumps(error), headers=self._headers)
+        try:
+            puppet: pu.Puppet = await pu.Puppet.get_by_number(number, raise_resolve=True)
+        except UnregisteredUserError:
+            error = {"error": f"The phone number {number} is not a registered Signal account"}
+            raise web.HTTPNotFound(text=json.dumps(error), headers=self._headers)
+        except Exception:
+            self.log.exception(f"Unknown error fetching UUID for {puppet.number}")
+            error = {"error": "Unknown error while fetching UUID"}
+            raise web.HTTPInternalServerError(text=json.dumps(error), headers=self._headers)
+        if not puppet:
+            error = {
+                "error": (
+                    f"The phone number {number} doesn't seem to be a registered Signal account"
+                )
+            }
+            raise web.HTTPNotFound(text=json.dumps(error), headers=self._headers)
         return puppet
 
     async def start_pm(self, request: web.Request) -> web.Response: