فهرست منبع

Maybe make things work without UUIDs

Tulir Asokan 4 سال پیش
والد
کامیت
06515b3e20

+ 5 - 4
mausignald/signald.py

@@ -11,7 +11,8 @@ from mautrix.util.logging import TraceLogger
 
 from .rpc import SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
-from .types import Address, Quote, Attachment, Reaction, Account, Message, Contact, Group, Profile
+from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
+                    Profile, GroupID)
 
 T = TypeVar('T')
 EventHandler = Callable[[T], Awaitable[None]]
@@ -96,19 +97,19 @@ class SignaldClient(SignaldRPCClient):
         return [Account.deserialize(acc) for acc in data["accounts"]]
 
     @staticmethod
-    def _recipient_to_args(recipient: Union[Address, str]) -> Dict[str, Any]:
+    def _recipient_to_args(recipient: Union[Address, GroupID]) -> Dict[str, Any]:
         if isinstance(recipient, Address):
             return {"recipientAddress": recipient.serialize()}
         else:
             return {"recipientGroupId": recipient}
 
-    async def react(self, username: str, recipient: Union[Address, str],
+    async def react(self, username: str, recipient: Union[Address, GroupID],
                     reaction: Reaction) -> None:
         await self.request("react", "send_results", username=username,
                            reaction=reaction.serialize(),
                            **self._recipient_to_args(recipient))
 
-    async def send(self, username: str, recipient: Union[Address, str], body: str,
+    async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
                    quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
                    timestamp: Optional[int] = None) -> None:
         serialized_quote = quote.serialize() if quote else None

+ 29 - 5
mausignald/types.py

@@ -3,7 +3,7 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Optional, Dict, Any, List
+from typing import Optional, Dict, Any, List, NewType
 from uuid import UUID
 
 from attr import dataclass
@@ -11,6 +11,8 @@ import attr
 
 from mautrix.types import SerializableAttrs, SerializableEnum
 
+GroupID = NewType('GroupID', str)
+
 
 @dataclass
 class Account(SerializableAttrs['Account']):
@@ -23,7 +25,7 @@ class Account(SerializableAttrs['Account']):
     uuid: Optional[UUID] = None
 
 
-@dataclass
+@dataclass(frozen=True, eq=False)
 class Address(SerializableAttrs['Address']):
     number: Optional[str] = None
     uuid: Optional[UUID] = None
@@ -32,6 +34,28 @@ class Address(SerializableAttrs['Address']):
     def is_valid(self) -> bool:
         return bool(self.number) or bool(self.uuid)
 
+    @property
+    def best_identifier(self) -> str:
+        return str(self.uuid) if self.uuid else self.number
+
+    def __eq__(self, other: 'Address') -> bool:
+        if not isinstance(other, Address):
+            return False
+        if self.uuid and other.uuid:
+            return self.uuid == other.uuid
+        elif self.number and other.number:
+            return self.number == other.number
+        return False
+
+    def __hash__(self) -> int:
+        if self.uuid:
+            return hash(self.uuid)
+        return hash(self.number)
+
+    @classmethod
+    def parse(cls, value: str) -> 'Address':
+        return Address(number=value) if value.startswith("+") else Address(uuid=UUID(value))
+
 
 @dataclass
 class Contact(SerializableAttrs['Contact']):
@@ -53,8 +77,8 @@ class Profile(SerializableAttrs['Profile']):
 
 @dataclass
 class Group(SerializableAttrs['Group']):
-    group_id: str = attr.ib(metadata={"json": "groupId"})
-    name: str
+    group_id: GroupID = attr.ib(metadata={"json": "groupId"})
+    name: str = "Unknown group"
 
     # Sometimes "UPDATE"
     type: Optional[str] = None
@@ -147,7 +171,7 @@ class TypingAction(SerializableEnum):
 class TypingNotification(SerializableAttrs['TypingNotification']):
     action: TypingAction
     timestamp: int
-    group_id: Optional[str] = attr.ib(default=None, metadata={"json": "groupId"})
+    group_id: Optional[GroupID] = attr.ib(default=None, metadata={"json": "groupId"})
 
 
 @dataclass

+ 20 - 15
mautrix_signal/db/message.py

@@ -19,9 +19,12 @@ from uuid import UUID
 from attr import dataclass
 import asyncpg
 
+from mausignald.types import Address, GroupID
 from mautrix.types import RoomID, EventID
 from mautrix.util.async_db import Database
 
+from ..util import id_to_str
+
 fake_db = Database("") if TYPE_CHECKING else None
 
 
@@ -31,22 +34,22 @@ class Message:
 
     mxid: EventID
     mx_room: RoomID
-    sender: UUID
+    sender: Address
     timestamp: int
-    signal_chat_id: Union[str, UUID]
+    signal_chat_id: Union[GroupID, Address]
     signal_receiver: str
 
     async def insert(self) -> None:
         q = ("INSERT INTO message (mxid, mx_room, sender, timestamp, signal_chat_id,"
              "                     signal_receiver) VALUES ($1, $2, $3, $4, $5, $6)")
-        await self.db.execute(q, self.mxid, self.mx_room, self.sender, self.timestamp,
-                              str(self.signal_chat_id), self.signal_receiver)
+        await self.db.execute(q, self.mxid, self.mx_room, self.sender.best_identifier,
+                              self.timestamp, id_to_str(self.signal_chat_id), self.signal_receiver)
 
     async def delete(self) -> None:
         q = ("DELETE FROM message WHERE sender=$1 AND timestamp=$2"
              "                          AND signal_chat_id=$3 AND signal_receiver=$4")
-        await self.db.execute(q, self.sender, self.timestamp, str(self.signal_chat_id),
-                              self.signal_receiver)
+        await self.db.execute(q, self.sender.best_identifier, self.timestamp,
+                              id_to_str(self.signal_chat_id), self.signal_receiver)
 
     @classmethod
     async def delete_all(cls, room_id: RoomID) -> None:
@@ -55,11 +58,11 @@ class Message:
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Message':
         data = {**row}
+        chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
-            chat_id = UUID(data.pop("signal_chat_id"))
-        else:
-            chat_id = data.pop("signal_chat_id")
-        return cls(signal_chat_id=chat_id, **data)
+            chat_id = Address.parse(chat_id)
+        sender = Address.parse(data.pop("sender"))
+        return cls(signal_chat_id=chat_id, sender=sender, **data)
 
     @classmethod
     async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Message']:
@@ -71,12 +74,14 @@ class Message:
         return cls._from_row(row)
 
     @classmethod
-    async def get_by_signal_id(cls, sender: UUID, timestamp: int, signal_chat_id: Union[str, UUID],
-                               signal_receiver: str = "") -> Optional['Message']:
+    async def get_by_signal_id(cls, sender: Address, timestamp: int,
+                               signal_chat_id: Union[GroupID, Address], signal_receiver: str = ""
+                               ) -> Optional['Message']:
         q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
              "FROM message WHERE sender=$1 AND timestamp=$2"
              "                   AND signal_chat_id=$3 AND signal_receiver=$4")
-        row = await cls.db.fetchrow(q, sender, timestamp, str(signal_chat_id), signal_receiver)
+        row = await cls.db.fetchrow(q, sender.best_identifier, timestamp,
+                                    id_to_str(signal_chat_id), signal_receiver)
         if not row:
             return None
         return cls._from_row(row)
@@ -89,10 +94,10 @@ class Message:
         return [cls._from_row(row) for row in rows]
 
     @classmethod
-    async def find_by_sender_timestamp(cls, sender: UUID, timestamp: int) -> Optional['Message']:
+    async def find_by_sender_timestamp(cls, sender: Address, timestamp: int) -> Optional['Message']:
         q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
              "FROM message WHERE sender=$1 AND timestamp=$2")
-        row = await cls.db.fetchrow(q, sender, timestamp)
+        row = await cls.db.fetchrow(q, sender.best_identifier, timestamp)
         if not row:
             return None
         return cls._from_row(row)

+ 18 - 12
mautrix_signal/db/portal.py

@@ -19,9 +19,12 @@ from uuid import UUID
 from attr import dataclass
 import asyncpg
 
+from mausignald.types import Address, GroupID
 from mautrix.types import RoomID, ContentURI
 from mautrix.util.async_db import Database
 
+from ..util import id_to_str
+
 fake_db = Database("") if TYPE_CHECKING else None
 
 
@@ -29,7 +32,7 @@ fake_db = Database("") if TYPE_CHECKING else None
 class Portal:
     db: ClassVar[Database] = fake_db
 
-    chat_id: Union[UUID, str]
+    chat_id: Union[GroupID, Address]
     receiver: str
     mxid: Optional[RoomID]
     name: Optional[str]
@@ -37,26 +40,29 @@ class Portal:
     avatar_url: Optional[ContentURI]
     encrypted: bool
 
+    @property
+    def chat_id_str(self) -> str:
+        return id_to_str(self.chat_id)
+
     async def insert(self) -> None:
         q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
              "                    encrypted) "
              "VALUES ($1, $2, $3, $4, $5, $6, $7)")
-        await self.db.execute(q, str(self.chat_id), self.receiver, self.mxid, self.name,
+        await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
                               self.avatar_hash, self.avatar_url, self.encrypted)
 
     async def update(self) -> None:
         q = ("UPDATE portal SET mxid=$3, name=$4, avatar_hash=$5, avatar_url=$6, encrypted=$7 "
              "WHERE chat_id=$1 AND receiver=$2")
-        await self.db.execute(q, str(self.chat_id), self.receiver, self.mxid, self.name,
+        await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
                               self.avatar_hash, self.avatar_url, self.encrypted)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
         data = {**row}
+        chat_id = data.pop("chat_id")
         if data["receiver"]:
-            chat_id = UUID(data.pop("chat_id"))
-        else:
-            chat_id = data.pop("chat_id")
+            chat_id = Address.parse(chat_id)
         return cls(chat_id=chat_id, **data)
 
     @classmethod
@@ -69,27 +75,27 @@ class Portal:
         return cls._from_row(row)
 
     @classmethod
-    async def get_by_chat_id(cls, chat_id: Union[UUID, str], receiver: str = ""
+    async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
                              ) -> Optional['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
-        row = await cls.db.fetchrow(q, str(chat_id), receiver)
+        row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
             return None
         return cls._from_row(row)
 
     @classmethod
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
-        q =( "SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
              "FROM portal WHERE receiver=$1")
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
-    async def find_private_chats_with(cls, other_user: UUID) -> List['Portal']:
+    async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
-             "FROM portal WHERE chat_id=$1::text AND receiver<>''")
-        rows = await cls.db.fetch(q, other_user)
+             "FROM portal WHERE chat_id=$1 AND receiver<>''")
+        rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
 
     @classmethod

+ 6 - 1
mautrix_signal/db/puppet.py

@@ -55,7 +55,12 @@ class Puppet:
         if self.uuid:
             raise ValueError("Can't re-set UUID for puppet")
         self.uuid = uuid
-        await self.db.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
+        async with self.db.acquire() as conn, conn.transaction():
+            await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
+            uuid = str(uuid)
+            await conn.execute("UPDATE portal SET chat_id=$1 WHERE chat_id=$2", uuid, self.number)
+            await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2", uuid, self.number)
+            await conn.execute("UPDATE reaction SET author=$1 WHERE author=$2", uuid, self.number)
 
     async def update(self) -> None:
         if self.uuid is None:

+ 26 - 18
mautrix_signal/db/reaction.py

@@ -13,15 +13,18 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Optional, ClassVar, TYPE_CHECKING
+from typing import Optional, ClassVar, Union, TYPE_CHECKING
 from uuid import UUID
 
 from attr import dataclass
 import asyncpg
 
+from mausignald.types import Address, GroupID
 from mautrix.types import RoomID, EventID
 from mautrix.util.async_db import Database
 
+from ..util import id_to_str
+
 fake_db = Database("") if TYPE_CHECKING else None
 
 
@@ -31,42 +34,45 @@ class Reaction:
 
     mxid: EventID
     mx_room: RoomID
-    signal_chat_id: str
+    signal_chat_id: Union[GroupID, Address]
     signal_receiver: str
-    msg_author: UUID
+    msg_author: Address
     msg_timestamp: int
-    author: UUID
+    author: Address
     emoji: str
 
     async def insert(self) -> None:
         q = ("INSERT INTO reaction (mxid, mx_room, signal_chat_id, signal_receiver, msg_author,"
              "                      msg_timestamp, author, emoji) "
              "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)")
-        await self.db.execute(q, self.mxid, self.mx_room, str(self.signal_chat_id),
-                              self.signal_receiver, self.msg_author, self.msg_timestamp,
-                              self.author, self.emoji)
+        await self.db.execute(q, self.mxid, self.mx_room, id_to_str(self.signal_chat_id),
+                              self.signal_receiver, self.msg_author.best_identifier,
+                              self.msg_timestamp, self.author.best_identifier, self.emoji)
 
     async def edit(self, mx_room: RoomID, mxid: EventID, emoji: str) -> None:
         await self.db.execute("UPDATE reaction SET mxid=$1, mx_room=$2, emoji=$3 "
                               "WHERE signal_chat_id=$4 AND signal_receiver=$5"
                               "      AND msg_author=$6 AND msg_timestamp=$7 AND author=$8",
-                              mxid, mx_room, emoji, str(self.signal_chat_id), self.signal_receiver,
-                              self.msg_author, self.msg_timestamp, self.author)
+                              mxid, mx_room, emoji, id_to_str(self.signal_chat_id),
+                              self.signal_receiver, self.msg_author.best_identifier,
+                              self.msg_timestamp, self.author.best_identifier)
 
     async def delete(self) -> None:
         q = ("DELETE FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
              "                           AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
-        await self.db.execute(q, str(self.signal_chat_id), self.signal_receiver, self.msg_author,
-                              self.msg_timestamp, self.author)
+        await self.db.execute(q, id_to_str(self.signal_chat_id), self.signal_receiver,
+                              self.msg_author.best_identifier, self.msg_timestamp,
+                              self.author.best_identifier)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Reaction':
         data = {**row}
+        chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
-            chat_id = UUID(data.pop("signal_chat_id"))
-        else:
-            chat_id = data.pop("signal_chat_id")
-        return cls(signal_chat_id=chat_id, **data)
+            chat_id = Address.parse(chat_id)
+        msg_author = Address.parse(data.pop("msg_author"))
+        author = Address.parse(data.pop("author"))
+        return cls(signal_chat_id=chat_id, msg_author=msg_author, author=author, **data)
 
     @classmethod
     async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Reaction']:
@@ -79,13 +85,15 @@ class Reaction:
         return cls._from_row(row)
 
     @classmethod
-    async def get_by_signal_id(cls, chat_id: str, receiver: str, msg_author: UUID,
-                               msg_timestamp: int, author: UUID) -> Optional['Reaction']:
+    async def get_by_signal_id(cls, chat_id: Union[GroupID, Address], receiver: str,
+                               msg_author: Address, msg_timestamp: int, author: Address
+                               ) -> Optional['Reaction']:
         q = ("SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
              "       msg_author, msg_timestamp, author, emoji "
              "FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
              "                    AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
-        row = await cls.db.fetchrow(q, chat_id, receiver, msg_author, msg_timestamp, author)
+        row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver, msg_author.best_identifier,
+                                    msg_timestamp, author.best_identifier)
         if not row:
             return None
         return cls._from_row(row)

+ 14 - 0
mautrix_signal/db/upgrade.py

@@ -100,3 +100,17 @@ async def upgrade_v2(conn: Connection) -> None:
 @upgrade_table.register(description="Add double-puppeting base_url to puppe table")
 async def upgrade_v3(conn: Connection) -> None:
     await conn.execute("ALTER TABLE puppet ADD COLUMN base_url TEXT")
+
+
+@upgrade_table.register(description="Allow phone numbers as message sender identifiers")
+async def upgrade_v4(conn: Connection) -> None:
+    cname = await conn.fetchval("SELECT constraint_name FROM information_schema.table_constraints "
+                                "WHERE table_name='reaction' AND constraint_name LIKE '%_fkey'")
+    await conn.execute(f"ALTER TABLE reaction DROP CONSTRAINT {cname}")
+    await conn.execute("ALTER TABLE reaction ALTER COLUMN msg_author SET DATA TYPE TEXT")
+    await conn.execute("ALTER TABLE reaction ALTER COLUMN author SET DATA TYPE TEXT")
+    await conn.execute("ALTER TABLE message ALTER COLUMN sender SET DATA TYPE TEXT")
+    await conn.execute(f"ALTER TABLE reaction ADD CONSTRAINT {cname} "
+                       "FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) "
+                       "  REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) "
+                       "  ON DELETE CASCADE ON UPDATE CASCADE")

+ 2 - 4
mautrix_signal/matrix.py

@@ -15,8 +15,6 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from typing import List, Union, TYPE_CHECKING
 
-from mausignald.types import Address
-
 from mautrix.bridge import BaseMatrixHandler
 from mautrix.types import (Event, ReactionEvent, MessageEvent, StateEvent, EncryptedEvent, RoomID,
                            EventID, UserID, ReactionEventContent, RelationType, EventType,
@@ -93,7 +91,7 @@ class MatrixHandler(BaseMatrixHandler):
             return
 
         user.log.trace(f"Sending read receipt for {message.timestamp} to {message.sender}")
-        await self.signal.send_receipt(user.username, Address(uuid=message.sender),
+        await self.signal.send_receipt(user.username, message.sender,
                                        timestamps=[message.timestamp], when=data.ts, read=True)
 
     async def handle_typing(self, room_id: RoomID, typing: List[UserID]) -> None:
@@ -118,4 +116,4 @@ class MatrixHandler(BaseMatrixHandler):
         if evt.type == EventType.TYPING:
             await self.handle_typing(evt.room_id, evt.content.user_ids)
         else:
-            super().handle_ephemeral_event(evt)
+            await super().handle_ephemeral_event(evt)

+ 54 - 58
mautrix_signal/portal.py

@@ -13,8 +13,8 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import (Dict, Tuple, Optional, List, Deque, Set, Any, Union, AsyncGenerator,
-                    Awaitable, TYPE_CHECKING, cast)
+from typing import (Dict, Tuple, Optional, List, Deque, Any, Union, AsyncGenerator, Awaitable,
+                    TYPE_CHECKING, cast)
 from collections import deque
 from uuid import UUID, uuid4
 import mimetypes
@@ -25,7 +25,7 @@ import time
 import os
 
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
-                              Attachment)
+                              Attachment, GroupID)
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
@@ -57,7 +57,7 @@ ChatInfo = Union[Group, Contact, Profile, Address]
 
 class Portal(DBPortal, BasePortal):
     by_mxid: Dict[RoomID, 'Portal'] = {}
-    by_chat_id: Dict[Tuple[Union[str, UUID], str], 'Portal'] = {}
+    by_chat_id: Dict[Tuple[str, str], 'Portal'] = {}
     config: Config
     matrix: 'm.MatrixHandler'
     signal: 's.SignalHandler'
@@ -66,16 +66,16 @@ class Portal(DBPortal, BasePortal):
 
     _main_intent: Optional[IntentAPI]
     _create_room_lock: asyncio.Lock
-    _msgts_dedup: Deque[Tuple[UUID, int]]
-    _reaction_dedup: Deque[Tuple[UUID, int, str]]
+    _msgts_dedup: Deque[Tuple[Address, int]]
+    _reaction_dedup: Deque[Tuple[Address, int, str]]
     _reaction_lock: asyncio.Lock
 
-    def __init__(self, chat_id: Union[str, UUID], receiver: str, mxid: Optional[RoomID] = None,
+    def __init__(self, chat_id: Union[GroupID, Address], receiver: str, mxid: Optional[RoomID] = None,
                  name: Optional[str] = None, avatar_hash: Optional[str] = None,
                  avatar_url: Optional[ContentURI] = None, encrypted: bool = False) -> None:
         super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted)
         self._create_room_lock = asyncio.Lock()
-        self.log = self.log.getChild(str(chat_id))
+        self.log = self.log.getChild(self.chat_id_str)
         self._main_intent = None
         self._msgts_dedup = deque(maxlen=100)
         self._reaction_dedup = deque(maxlen=100)
@@ -90,14 +90,15 @@ class Portal(DBPortal, BasePortal):
 
     @property
     def is_direct(self) -> bool:
-        return isinstance(self.chat_id, UUID)
+        return isinstance(self.chat_id, Address)
 
-    @property
-    def recipient(self) -> Union[str, Address]:
-        if self.is_direct:
-            return Address(uuid=self.chat_id)
-        else:
-            return self.chat_id
+    def handle_uuid_receive(self, uuid: UUID) -> None:
+        if not self.is_direct or self.chat_id.uuid:
+            raise ValueError("handle_uuid_receive can only be used for private chat portals with "
+                             "a phone number chat_id")
+        del self.by_chat_id[(self.chat_id_str, self.receiver)]
+        self.chat_id = Address(uuid=uuid)
+        self.by_chat_id[(self.chat_id_str, self.receiver)] = self
 
     @classmethod
     def init_cls(cls, bridge: 'SignalBridge') -> None:
@@ -131,9 +132,10 @@ class Portal(DBPortal, BasePortal):
             await existing.edit(emoji=emoji, mxid=mxid, mx_room=message.mx_room)
         else:
             self.log.debug(f"_upsert_reaction inserting {mxid} (message: {message.mxid})")
-            await DBReaction(mxid=mxid, mx_room=message.mx_room, emoji=emoji, author=sender.uuid,
+            await DBReaction(mxid=mxid, mx_room=message.mx_room, emoji=emoji,
                              signal_chat_id=self.chat_id, signal_receiver=self.receiver,
-                             msg_author=message.sender, msg_timestamp=message.timestamp).insert()
+                             msg_author=message.sender, msg_timestamp=message.timestamp,
+                             author=sender.address).insert()
 
     # endregion
     # region Matrix event handling
@@ -168,14 +170,14 @@ class Portal(DBPortal, BasePortal):
             self.log.debug(f"Ignoring puppet-sent message by confirmed puppet user {sender.mxid}")
             return
         request_id = int(time.time() * 1000)
-        self._msgts_dedup.appendleft((sender.uuid, request_id))
+        self._msgts_dedup.appendleft((sender.address, request_id))
 
         quote = None
         if message.get_reply_to():
             reply = await DBMessage.get_by_mxid(message.get_reply_to(), self.mxid)
             # TODO include actual text? either store in db or fetch event from homeserver
             if reply is not None:
-                quote = Quote(id=reply.timestamp, author=Address(uuid=reply.sender), text="")
+                quote = Quote(id=reply.timestamp, author=reply.sender, text="")
 
         text = message.body
         attachments: Optional[List[Attachment]] = None
@@ -188,9 +190,9 @@ class Portal(DBPortal, BasePortal):
             attachments = [attachment]
             text = None
             self.log.trace("Formed outgoing attachment %s", attachment)
-        await self.signal.send(username=sender.username, recipient=self.recipient, body=text,
+        await self.signal.send(username=sender.username, recipient=self.chat_id, body=text,
                                quote=quote, attachments=attachments, timestamp=request_id)
-        msg = DBMessage(mxid=event_id, mx_room=self.mxid, sender=sender.uuid, timestamp=request_id,
+        msg = DBMessage(mxid=event_id, mx_room=self.mxid, sender=sender.address, timestamp=request_id,
                         signal_chat_id=self.chat_id, signal_receiver=self.receiver)
         await msg.insert()
         await self._send_delivery_receipt(event_id)
@@ -212,7 +214,7 @@ class Portal(DBPortal, BasePortal):
             return
 
         existing = await DBReaction.get_by_signal_id(self.chat_id, self.receiver, message.sender,
-                                                     message.timestamp, sender.uuid)
+                                                     message.timestamp, sender.address)
         if existing and existing.emoji == emoji:
             return
 
@@ -220,9 +222,9 @@ class Portal(DBPortal, BasePortal):
         self._reaction_dedup.appendleft(dedup_id)
         async with self._reaction_lock:
             reaction = Reaction(emoji=emoji, remove=False,
-                                target_author=Address(uuid=message.sender),
+                                target_author=message.sender,
                                 target_sent_timestamp=message.timestamp)
-            await self.signal.react(username=sender.username, recipient=self.recipient,
+            await self.signal.react(username=sender.username, recipient=self.chat_id,
                                     reaction=reaction)
             await self._upsert_reaction(existing, self.main_intent, event_id, sender, message,
                                         emoji)
@@ -239,9 +241,9 @@ class Portal(DBPortal, BasePortal):
             try:
                 await reaction.delete()
                 remove_reaction = Reaction(emoji=reaction.emoji, remove=True,
-                                           target_author=Address(uuid=reaction.msg_author),
+                                           target_author=reaction.msg_author,
                                            target_sent_timestamp=reaction.msg_timestamp)
-                await self.signal.react(username=sender.username, recipient=self.recipient,
+                await self.signal.react(username=sender.username, recipient=self.chat_id,
                                         reaction=remove_reaction)
                 await self._send_delivery_receipt(redaction_event_id)
                 self.log.trace(f"Removed {reaction} after Matrix redaction")
@@ -263,21 +265,17 @@ class Portal(DBPortal, BasePortal):
     # region Signal event handling
 
     @staticmethod
-    async def _find_address_uuid(address: Address) -> Optional[UUID]:
-        if address.uuid:
-            return address.uuid
+    async def _resolve_address(address: Address) -> Address:
         puppet = await p.Puppet.get_by_address(address, create=False)
-        if puppet and puppet.uuid:
-            return puppet.uuid
-        return None
+        return puppet.address
 
     async def _find_quote_event_id(self, quote: Optional[Quote]
                                    ) -> Optional[Union[MessageEvent, EventID]]:
         if not quote:
             return None
 
-        author_uuid = await self._find_address_uuid(quote.author)
-        reply_msg = await DBMessage.get_by_signal_id(author_uuid, quote.id,
+        author_address = await self._resolve_address(quote.author)
+        reply_msg = await DBMessage.get_by_signal_id(author_address, quote.id,
                                                      self.chat_id, self.receiver)
         if not reply_msg:
             return None
@@ -291,13 +289,13 @@ class Portal(DBPortal, BasePortal):
 
     async def handle_signal_message(self, source: 'u.User', sender: 'p.Puppet',
                                     message: MessageData) -> None:
-        if (sender.uuid, message.timestamp) in self._msgts_dedup:
+        if (sender.address, message.timestamp) in self._msgts_dedup:
             self.log.debug(f"Ignoring message {message.timestamp} by {sender.uuid}"
                            " as it was already handled (message.timestamp in dedup queue)")
             await self.signal.send_receipt(source.username, sender.address,
                                            timestamps=[message.timestamp])
             return
-        old_message = await DBMessage.get_by_signal_id(sender.uuid, message.timestamp,
+        old_message = await DBMessage.get_by_signal_id(sender.address, message.timestamp,
                                                        self.chat_id, self.receiver)
         if old_message is not None:
             self.log.debug(f"Ignoring message {message.timestamp} by {sender.uuid}"
@@ -307,7 +305,7 @@ class Portal(DBPortal, BasePortal):
             return
         self.log.debug(f"Started handling message {message.timestamp} by {sender.uuid}")
         self.log.trace(f"Message content: {message}")
-        self._msgts_dedup.appendleft((sender.uuid, message.timestamp))
+        self._msgts_dedup.appendleft((sender.address, message.timestamp))
         intent = sender.intent_for(self)
         await intent.set_typing(self.mxid, False)
         event_id = None
@@ -345,7 +343,7 @@ class Portal(DBPortal, BasePortal):
 
         if event_id:
             msg = DBMessage(mxid=event_id, mx_room=self.mxid,
-                            sender=sender.uuid, timestamp=message.timestamp,
+                            sender=sender.address, timestamp=message.timestamp,
                             signal_chat_id=self.chat_id, signal_receiver=self.receiver)
             await msg.insert()
             await self.signal.send_receipt(source.username, sender.address,
@@ -404,20 +402,16 @@ class Portal(DBPortal, BasePortal):
         return content
 
     async def handle_signal_reaction(self, sender: 'p.Puppet', reaction: Reaction) -> None:
-        author_uuid = await self._find_address_uuid(reaction.target_author)
+        author_address = await self._resolve_address(reaction.target_author)
         target_id = reaction.target_sent_timestamp
-        if author_uuid is None:
-            self.log.warning(f"Failed to handle reaction from {sender.uuid}: "
-                             f"couldn't find UUID of {reaction.target_author}")
-            return
         async with self._reaction_lock:
-            dedup_id = (author_uuid, target_id, reaction.emoji)
+            dedup_id = (author_address, target_id, reaction.emoji)
             if dedup_id in self._reaction_dedup:
                 return
             self._reaction_dedup.appendleft(dedup_id)
 
         existing = await DBReaction.get_by_signal_id(self.chat_id, self.receiver,
-                                                     author_uuid, target_id, sender.uuid)
+                                                     author_address, target_id, sender.address)
 
         if reaction.remove:
             if existing:
@@ -431,7 +425,7 @@ class Portal(DBPortal, BasePortal):
         elif existing and existing.emoji == reaction.emoji:
             return
 
-        message = await DBMessage.get_by_signal_id(author_uuid, target_id,
+        message = await DBMessage.get_by_signal_id(author_address, target_id,
                                                    self.chat_id, self.receiver)
         if not message:
             self.log.debug(f"Ignoring reaction to unknown message {target_id}")
@@ -440,7 +434,7 @@ class Portal(DBPortal, BasePortal):
         intent = sender.intent_for(self)
         # TODO add variation selectors to emoji before sending to Matrix
         mxid = await intent.react(message.mx_room, message.mxid, reaction.emoji)
-        self.log.debug(f"{sender.uuid} reacted to {message.mxid} -> {mxid}")
+        self.log.debug(f"{sender.address} reacted to {message.mxid} -> {mxid}")
         await self._upsert_reaction(existing, intent, mxid, sender, message, reaction.emoji)
 
     # endregion
@@ -451,7 +445,7 @@ class Portal(DBPortal, BasePortal):
             if not isinstance(info, (Contact, Profile, Address)):
                 raise ValueError(f"Unexpected type for direct chat update_info: {type(info)}")
             if not self.name:
-                puppet = await p.Puppet.get_by_address(Address(uuid=self.chat_id))
+                puppet = await p.Puppet.get_by_address(self.chat_id)
                 if not puppet.name:
                     await puppet.update_info(info)
                 self.name = puppet.name
@@ -619,18 +613,18 @@ class Portal(DBPortal, BasePortal):
         if self.config["bridge.encryption.default"] and self.matrix.e2ee:
             self.encrypted = True
             initial_state.append({
-                "type": "m.room.encryption",
+                "type": str(EventType.ROOM_ENCRYPTION),
                 "content": {"algorithm": "m.megolm.v1.aes-sha2"},
             })
             if self.is_direct:
                 invites.append(self.az.bot_mxid)
-        if source.uuid == self.chat_id:
+        if self.is_direct and source.address == self.chat_id:
             name = self.name = "Signal Note to Self"
         elif self.encrypted or self.private_chat_portal_meta or not self.is_direct:
             name = self.name
         if self.avatar_url:
             initial_state.append({
-                "type": "m.room.avatar",
+                "type": str(EventType.ROOM_AVATAR),
                 "content": {"url": self.avatar_url},
             })
         if self.config["appservice.community_id"]:
@@ -638,10 +632,9 @@ class Portal(DBPortal, BasePortal):
                 "type": "m.room.related_groups",
                 "content": {"groups": [self.config["appservice.community_id"]]},
             })
-        #Allow chaning of room avatar and name in direct chats
         if self.is_direct:
             initial_state.append({
-                "type": "m.room.power_levels",
+                "type": str(EventType.ROOM_POWER_LEVELS),
                 "content": {"users": {self.main_intent.mxid: 100},
                             "events": {"m.room.avatar": 0, "m.room.name": 0}}
             })
@@ -689,7 +682,7 @@ class Portal(DBPortal, BasePortal):
         if self.mxid:
             self.by_mxid[self.mxid] = self
         if self.is_direct:
-            puppet = await p.Puppet.get_by_address(Address(uuid=self.chat_id))
+            puppet = await p.Puppet.get_by_address(self.chat_id)
             self._main_intent = puppet.default_mxid_intent
         elif not self.is_direct:
             self._main_intent = self.az.intent
@@ -709,7 +702,7 @@ class Portal(DBPortal, BasePortal):
         return cls._db_to_portals(super().all_with_room())
 
     @classmethod
-    def find_private_chats_with(cls, other_user: UUID) -> AsyncGenerator['Portal', None]:
+    def find_private_chats_with(cls, other_user: Address) -> AsyncGenerator['Portal', None]:
         return cls._db_to_portals(super().find_private_chats_with(other_user))
 
     @classmethod
@@ -718,7 +711,7 @@ class Portal(DBPortal, BasePortal):
         portals = await query
         for index, portal in enumerate(portals):
             try:
-                yield cls.by_chat_id[(portal.chat_id, portal.receiver)]
+                yield cls.by_chat_id[(portal.chat_id_str, portal.receiver)]
             except KeyError:
                 await portal._postinit()
                 yield portal
@@ -738,14 +731,17 @@ class Portal(DBPortal, BasePortal):
         return None
 
     @classmethod
-    async def get_by_chat_id(cls, chat_id: Union[UUID, str], receiver: str = "",
+    async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = "",
                              create: bool = False) -> Optional['Portal']:
         if isinstance(chat_id, str):
             receiver = ""
+        elif not isinstance(chat_id, Address):
+            raise ValueError(f"Invalid chat ID type {type(chat_id)}")
         elif not receiver:
             raise ValueError("Direct chats must have a receiver")
         try:
-            return cls.by_chat_id[(chat_id, receiver)]
+            best_id = chat_id.best_identifier if isinstance(chat_id, Address) else chat_id
+            return cls.by_chat_id[(best_id, receiver)]
         except KeyError:
             pass
 

+ 12 - 4
mautrix_signal/puppet.py

@@ -28,7 +28,7 @@ from mautrix.util.simple_template import SimpleTemplate
 
 from .db import Puppet as DBPuppet
 from .config import Config
-from . import portal as p
+from . import portal as p, user as u
 
 if TYPE_CHECKING:
     from .__main__ import SignalBridge
@@ -90,7 +90,7 @@ class Puppet(DBPuppet, BasePuppet):
         return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
 
     def intent_for(self, portal: 'p.Portal') -> IntentAPI:
-        if portal.chat_id == self.uuid:
+        if portal.chat_id == self.address:
             return self.default_mxid_intent
         return self.intent
 
@@ -118,14 +118,22 @@ class Puppet(DBPuppet, BasePuppet):
 
     async def _handle_uuid_receive(self, uuid: UUID) -> None:
         self.log.debug(f"Found UUID for user: {uuid}")
+        user = await u.User.get_by_username(self.number)
+        if user and not user.uuid:
+            user.uuid = self.uuid
+            await user.update()
         await self._set_uuid(uuid)
         self.by_uuid[self.uuid] = self
+        async for portal in p.Portal.find_private_chats_with(Address(number=self.number)):
+            self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
+            portal.handle_uuid_receive(self.uuid)
         prev_intent = self.default_mxid_intent
         self.default_mxid = self.get_mxid_from_id(self.address)
         self.default_mxid_intent = self.az.intent.user(self.default_mxid)
         self.intent = self._fresh_intent()
         await self.intent.ensure_registered()
-        await self.intent.set_displayname(self.name)
+        if self.name:
+            await self.intent.set_displayname(self.name)
         self.log = self.log.getChild(str(uuid))
         self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {self.default_mxid_intent}")
         for room_id in await prev_intent.get_joined_rooms():
@@ -186,7 +194,7 @@ class Puppet(DBPuppet, BasePuppet):
         return False
 
     async def _update_portal_names(self) -> None:
-        async for portal in p.Portal.find_private_chats_with(self.uuid):
+        async for portal in p.Portal.find_private_chats_with(self.address):
             if portal.receiver == self.number:
                 # This is a note to self chat, don't change the name
                 continue

+ 4 - 9
mautrix_signal/signal.py

@@ -42,10 +42,6 @@ class SignalHandler(SignaldClient):
 
     async def on_message(self, evt: Message) -> None:
         sender = await pu.Puppet.get_by_address(evt.source)
-        if not sender.uuid:
-            self.log.debug("Got message sender puppet with no UUID, not handling message")
-            self.log.trace("Message content: %s", evt)
-            return
         user = await u.User.get_by_username(evt.username)
         # TODO add lots of logging
 
@@ -74,8 +70,7 @@ class SignalHandler(SignaldClient):
         if msg.group:
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
         else:
-            portal = await po.Portal.get_by_chat_id(addr_override.uuid
-                                                    if addr_override else sender.uuid,
+            portal = await po.Portal.get_by_chat_id(addr_override or sender.address,
                                                     receiver=user.username, create=True)
             if addr_override and not sender.is_real_user:
                 portal.log.debug(f"Ignoring own message {msg.timestamp} as user doesn't have"
@@ -94,9 +89,9 @@ class SignalHandler(SignaldClient):
     async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None:
         for receipt in receipts:
             puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
-            if not puppet or not puppet.uuid:
+            if not puppet:
                 continue
-            message = await DBMessage.find_by_sender_timestamp(puppet.uuid, receipt.timestamp)
+            message = await DBMessage.find_by_sender_timestamp(puppet.address, receipt.timestamp)
             if not message:
                 continue
             portal = await po.Portal.get_by_mxid(message.mx_room)
@@ -110,7 +105,7 @@ class SignalHandler(SignaldClient):
         if typing.group_id:
             portal = await po.Portal.get_by_chat_id(typing.group_id)
         else:
-            portal = await po.Portal.get_by_chat_id(sender.uuid, receiver=user.username)
+            portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
         if not portal or not portal.mxid:
             return
         is_typing = typing.action == TypingAction.STARTED

+ 2 - 2
mautrix_signal/user.py

@@ -106,8 +106,8 @@ class User(DBUser, BaseUser):
                 # necessary, but maybe we could listen for updates?
                 profile = None
             await puppet.update_info(profile or contact)
-            if puppet.uuid and create_contact_portal:
-                portal = await po.Portal.get_by_chat_id(puppet.uuid, self.username, create=True)
+            if create_contact_portal:
+                portal = await po.Portal.get_by_chat_id(puppet.address, self.username, create=True)
                 await portal.create_matrix_room(self, profile or contact)
 
         create_group_portal = self.config["bridge.autocreate_group_portal"]

+ 1 - 0
mautrix_signal/util/__init__.py

@@ -1 +1,2 @@
 from .color_log import ColorFormatter
+from .id_to_str import id_to_str

+ 24 - 0
mautrix_signal/util/id_to_str.py

@@ -0,0 +1,24 @@
+# mautrix-signal - A Matrix-Signal puppeting bridge
+# Copyright (C) 2020 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from typing import Union
+
+from mausignald.types import Address, GroupID
+
+
+def id_to_str(identifier: Union[Address, GroupID]) -> str:
+    if isinstance(identifier, Address):
+        return identifier.best_identifier
+    return identifier

+ 1 - 1
requirements.txt

@@ -4,5 +4,5 @@ commonmark>=0.8,<0.10
 aiohttp>=3,<3.7
 yarl>=1,<1.6
 attrs>=19.1
-mautrix==0.8.0.beta7
+mautrix==0.8.0.beta9
 asyncpg>=0.20,<0.22