Browse Source

Maybe make things work without UUIDs

Tulir Asokan 4 years ago
parent
commit
06515b3e20

+ 5 - 4
mausignald/signald.py

@@ -11,7 +11,8 @@ from mautrix.util.logging import TraceLogger
 
 
 from .rpc import SignaldRPCClient
 from .rpc import SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
-from .types import Address, Quote, Attachment, Reaction, Account, Message, Contact, Group, Profile
+from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
+                    Profile, GroupID)
 
 
 T = TypeVar('T')
 T = TypeVar('T')
 EventHandler = Callable[[T], Awaitable[None]]
 EventHandler = Callable[[T], Awaitable[None]]
@@ -96,19 +97,19 @@ class SignaldClient(SignaldRPCClient):
         return [Account.deserialize(acc) for acc in data["accounts"]]
         return [Account.deserialize(acc) for acc in data["accounts"]]
 
 
     @staticmethod
     @staticmethod
-    def _recipient_to_args(recipient: Union[Address, str]) -> Dict[str, Any]:
+    def _recipient_to_args(recipient: Union[Address, GroupID]) -> Dict[str, Any]:
         if isinstance(recipient, Address):
         if isinstance(recipient, Address):
             return {"recipientAddress": recipient.serialize()}
             return {"recipientAddress": recipient.serialize()}
         else:
         else:
             return {"recipientGroupId": recipient}
             return {"recipientGroupId": recipient}
 
 
-    async def react(self, username: str, recipient: Union[Address, str],
+    async def react(self, username: str, recipient: Union[Address, GroupID],
                     reaction: Reaction) -> None:
                     reaction: Reaction) -> None:
         await self.request("react", "send_results", username=username,
         await self.request("react", "send_results", username=username,
                            reaction=reaction.serialize(),
                            reaction=reaction.serialize(),
                            **self._recipient_to_args(recipient))
                            **self._recipient_to_args(recipient))
 
 
-    async def send(self, username: str, recipient: Union[Address, str], body: str,
+    async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
                    quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
                    quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
                    timestamp: Optional[int] = None) -> None:
                    timestamp: Optional[int] = None) -> None:
         serialized_quote = quote.serialize() if quote else None
         serialized_quote = quote.serialize() if quote else None

+ 29 - 5
mausignald/types.py

@@ -3,7 +3,7 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Optional, Dict, Any, List
+from typing import Optional, Dict, Any, List, NewType
 from uuid import UUID
 from uuid import UUID
 
 
 from attr import dataclass
 from attr import dataclass
@@ -11,6 +11,8 @@ import attr
 
 
 from mautrix.types import SerializableAttrs, SerializableEnum
 from mautrix.types import SerializableAttrs, SerializableEnum
 
 
+GroupID = NewType('GroupID', str)
+
 
 
 @dataclass
 @dataclass
 class Account(SerializableAttrs['Account']):
 class Account(SerializableAttrs['Account']):
@@ -23,7 +25,7 @@ class Account(SerializableAttrs['Account']):
     uuid: Optional[UUID] = None
     uuid: Optional[UUID] = None
 
 
 
 
-@dataclass
+@dataclass(frozen=True, eq=False)
 class Address(SerializableAttrs['Address']):
 class Address(SerializableAttrs['Address']):
     number: Optional[str] = None
     number: Optional[str] = None
     uuid: Optional[UUID] = None
     uuid: Optional[UUID] = None
@@ -32,6 +34,28 @@ class Address(SerializableAttrs['Address']):
     def is_valid(self) -> bool:
     def is_valid(self) -> bool:
         return bool(self.number) or bool(self.uuid)
         return bool(self.number) or bool(self.uuid)
 
 
+    @property
+    def best_identifier(self) -> str:
+        return str(self.uuid) if self.uuid else self.number
+
+    def __eq__(self, other: 'Address') -> bool:
+        if not isinstance(other, Address):
+            return False
+        if self.uuid and other.uuid:
+            return self.uuid == other.uuid
+        elif self.number and other.number:
+            return self.number == other.number
+        return False
+
+    def __hash__(self) -> int:
+        if self.uuid:
+            return hash(self.uuid)
+        return hash(self.number)
+
+    @classmethod
+    def parse(cls, value: str) -> 'Address':
+        return Address(number=value) if value.startswith("+") else Address(uuid=UUID(value))
+
 
 
 @dataclass
 @dataclass
 class Contact(SerializableAttrs['Contact']):
 class Contact(SerializableAttrs['Contact']):
@@ -53,8 +77,8 @@ class Profile(SerializableAttrs['Profile']):
 
 
 @dataclass
 @dataclass
 class Group(SerializableAttrs['Group']):
 class Group(SerializableAttrs['Group']):
-    group_id: str = attr.ib(metadata={"json": "groupId"})
-    name: str
+    group_id: GroupID = attr.ib(metadata={"json": "groupId"})
+    name: str = "Unknown group"
 
 
     # Sometimes "UPDATE"
     # Sometimes "UPDATE"
     type: Optional[str] = None
     type: Optional[str] = None
@@ -147,7 +171,7 @@ class TypingAction(SerializableEnum):
 class TypingNotification(SerializableAttrs['TypingNotification']):
 class TypingNotification(SerializableAttrs['TypingNotification']):
     action: TypingAction
     action: TypingAction
     timestamp: int
     timestamp: int
-    group_id: Optional[str] = attr.ib(default=None, metadata={"json": "groupId"})
+    group_id: Optional[GroupID] = attr.ib(default=None, metadata={"json": "groupId"})
 
 
 
 
 @dataclass
 @dataclass

+ 20 - 15
mautrix_signal/db/message.py

@@ -19,9 +19,12 @@ from uuid import UUID
 from attr import dataclass
 from attr import dataclass
 import asyncpg
 import asyncpg
 
 
+from mausignald.types import Address, GroupID
 from mautrix.types import RoomID, EventID
 from mautrix.types import RoomID, EventID
 from mautrix.util.async_db import Database
 from mautrix.util.async_db import Database
 
 
+from ..util import id_to_str
+
 fake_db = Database("") if TYPE_CHECKING else None
 fake_db = Database("") if TYPE_CHECKING else None
 
 
 
 
@@ -31,22 +34,22 @@ class Message:
 
 
     mxid: EventID
     mxid: EventID
     mx_room: RoomID
     mx_room: RoomID
-    sender: UUID
+    sender: Address
     timestamp: int
     timestamp: int
-    signal_chat_id: Union[str, UUID]
+    signal_chat_id: Union[GroupID, Address]
     signal_receiver: str
     signal_receiver: str
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
         q = ("INSERT INTO message (mxid, mx_room, sender, timestamp, signal_chat_id,"
         q = ("INSERT INTO message (mxid, mx_room, sender, timestamp, signal_chat_id,"
              "                     signal_receiver) VALUES ($1, $2, $3, $4, $5, $6)")
              "                     signal_receiver) VALUES ($1, $2, $3, $4, $5, $6)")
-        await self.db.execute(q, self.mxid, self.mx_room, self.sender, self.timestamp,
-                              str(self.signal_chat_id), self.signal_receiver)
+        await self.db.execute(q, self.mxid, self.mx_room, self.sender.best_identifier,
+                              self.timestamp, id_to_str(self.signal_chat_id), self.signal_receiver)
 
 
     async def delete(self) -> None:
     async def delete(self) -> None:
         q = ("DELETE FROM message WHERE sender=$1 AND timestamp=$2"
         q = ("DELETE FROM message WHERE sender=$1 AND timestamp=$2"
              "                          AND signal_chat_id=$3 AND signal_receiver=$4")
              "                          AND signal_chat_id=$3 AND signal_receiver=$4")
-        await self.db.execute(q, self.sender, self.timestamp, str(self.signal_chat_id),
-                              self.signal_receiver)
+        await self.db.execute(q, self.sender.best_identifier, self.timestamp,
+                              id_to_str(self.signal_chat_id), self.signal_receiver)
 
 
     @classmethod
     @classmethod
     async def delete_all(cls, room_id: RoomID) -> None:
     async def delete_all(cls, room_id: RoomID) -> None:
@@ -55,11 +58,11 @@ class Message:
     @classmethod
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Message':
     def _from_row(cls, row: asyncpg.Record) -> 'Message':
         data = {**row}
         data = {**row}
+        chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
         if data["signal_receiver"]:
-            chat_id = UUID(data.pop("signal_chat_id"))
-        else:
-            chat_id = data.pop("signal_chat_id")
-        return cls(signal_chat_id=chat_id, **data)
+            chat_id = Address.parse(chat_id)
+        sender = Address.parse(data.pop("sender"))
+        return cls(signal_chat_id=chat_id, sender=sender, **data)
 
 
     @classmethod
     @classmethod
     async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Message']:
     async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Message']:
@@ -71,12 +74,14 @@ class Message:
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def get_by_signal_id(cls, sender: UUID, timestamp: int, signal_chat_id: Union[str, UUID],
-                               signal_receiver: str = "") -> Optional['Message']:
+    async def get_by_signal_id(cls, sender: Address, timestamp: int,
+                               signal_chat_id: Union[GroupID, Address], signal_receiver: str = ""
+                               ) -> Optional['Message']:
         q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
         q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
              "FROM message WHERE sender=$1 AND timestamp=$2"
              "FROM message WHERE sender=$1 AND timestamp=$2"
              "                   AND signal_chat_id=$3 AND signal_receiver=$4")
              "                   AND signal_chat_id=$3 AND signal_receiver=$4")
-        row = await cls.db.fetchrow(q, sender, timestamp, str(signal_chat_id), signal_receiver)
+        row = await cls.db.fetchrow(q, sender.best_identifier, timestamp,
+                                    id_to_str(signal_chat_id), signal_receiver)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
@@ -89,10 +94,10 @@ class Message:
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 
     @classmethod
     @classmethod
-    async def find_by_sender_timestamp(cls, sender: UUID, timestamp: int) -> Optional['Message']:
+    async def find_by_sender_timestamp(cls, sender: Address, timestamp: int) -> Optional['Message']:
         q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
         q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
              "FROM message WHERE sender=$1 AND timestamp=$2")
              "FROM message WHERE sender=$1 AND timestamp=$2")
-        row = await cls.db.fetchrow(q, sender, timestamp)
+        row = await cls.db.fetchrow(q, sender.best_identifier, timestamp)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)

+ 18 - 12
mautrix_signal/db/portal.py

@@ -19,9 +19,12 @@ from uuid import UUID
 from attr import dataclass
 from attr import dataclass
 import asyncpg
 import asyncpg
 
 
+from mausignald.types import Address, GroupID
 from mautrix.types import RoomID, ContentURI
 from mautrix.types import RoomID, ContentURI
 from mautrix.util.async_db import Database
 from mautrix.util.async_db import Database
 
 
+from ..util import id_to_str
+
 fake_db = Database("") if TYPE_CHECKING else None
 fake_db = Database("") if TYPE_CHECKING else None
 
 
 
 
@@ -29,7 +32,7 @@ fake_db = Database("") if TYPE_CHECKING else None
 class Portal:
 class Portal:
     db: ClassVar[Database] = fake_db
     db: ClassVar[Database] = fake_db
 
 
-    chat_id: Union[UUID, str]
+    chat_id: Union[GroupID, Address]
     receiver: str
     receiver: str
     mxid: Optional[RoomID]
     mxid: Optional[RoomID]
     name: Optional[str]
     name: Optional[str]
@@ -37,26 +40,29 @@ class Portal:
     avatar_url: Optional[ContentURI]
     avatar_url: Optional[ContentURI]
     encrypted: bool
     encrypted: bool
 
 
+    @property
+    def chat_id_str(self) -> str:
+        return id_to_str(self.chat_id)
+
     async def insert(self) -> None:
     async def insert(self) -> None:
         q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
         q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
              "                    encrypted) "
              "                    encrypted) "
              "VALUES ($1, $2, $3, $4, $5, $6, $7)")
              "VALUES ($1, $2, $3, $4, $5, $6, $7)")
-        await self.db.execute(q, str(self.chat_id), self.receiver, self.mxid, self.name,
+        await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
                               self.avatar_hash, self.avatar_url, self.encrypted)
                               self.avatar_hash, self.avatar_url, self.encrypted)
 
 
     async def update(self) -> None:
     async def update(self) -> None:
         q = ("UPDATE portal SET mxid=$3, name=$4, avatar_hash=$5, avatar_url=$6, encrypted=$7 "
         q = ("UPDATE portal SET mxid=$3, name=$4, avatar_hash=$5, avatar_url=$6, encrypted=$7 "
              "WHERE chat_id=$1 AND receiver=$2")
              "WHERE chat_id=$1 AND receiver=$2")
-        await self.db.execute(q, str(self.chat_id), self.receiver, self.mxid, self.name,
+        await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
                               self.avatar_hash, self.avatar_url, self.encrypted)
                               self.avatar_hash, self.avatar_url, self.encrypted)
 
 
     @classmethod
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
     def _from_row(cls, row: asyncpg.Record) -> 'Portal':
         data = {**row}
         data = {**row}
+        chat_id = data.pop("chat_id")
         if data["receiver"]:
         if data["receiver"]:
-            chat_id = UUID(data.pop("chat_id"))
-        else:
-            chat_id = data.pop("chat_id")
+            chat_id = Address.parse(chat_id)
         return cls(chat_id=chat_id, **data)
         return cls(chat_id=chat_id, **data)
 
 
     @classmethod
     @classmethod
@@ -69,27 +75,27 @@ class Portal:
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def get_by_chat_id(cls, chat_id: Union[UUID, str], receiver: str = ""
+    async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
                              ) -> Optional['Portal']:
                              ) -> Optional['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
              "FROM portal WHERE chat_id=$1 AND receiver=$2")
-        row = await cls.db.fetchrow(q, str(chat_id), receiver)
+        row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
     async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
-        q =( "SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
+        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
              "FROM portal WHERE receiver=$1")
              "FROM portal WHERE receiver=$1")
         rows = await cls.db.fetch(q, receiver)
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 
     @classmethod
     @classmethod
-    async def find_private_chats_with(cls, other_user: UUID) -> List['Portal']:
+    async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
         q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted "
-             "FROM portal WHERE chat_id=$1::text AND receiver<>''")
-        rows = await cls.db.fetch(q, other_user)
+             "FROM portal WHERE chat_id=$1 AND receiver<>''")
+        rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 
     @classmethod
     @classmethod

+ 6 - 1
mautrix_signal/db/puppet.py

@@ -55,7 +55,12 @@ class Puppet:
         if self.uuid:
         if self.uuid:
             raise ValueError("Can't re-set UUID for puppet")
             raise ValueError("Can't re-set UUID for puppet")
         self.uuid = uuid
         self.uuid = uuid
-        await self.db.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
+        async with self.db.acquire() as conn, conn.transaction():
+            await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
+            uuid = str(uuid)
+            await conn.execute("UPDATE portal SET chat_id=$1 WHERE chat_id=$2", uuid, self.number)
+            await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2", uuid, self.number)
+            await conn.execute("UPDATE reaction SET author=$1 WHERE author=$2", uuid, self.number)
 
 
     async def update(self) -> None:
     async def update(self) -> None:
         if self.uuid is None:
         if self.uuid is None:

+ 26 - 18
mautrix_signal/db/reaction.py

@@ -13,15 +13,18 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Optional, ClassVar, TYPE_CHECKING
+from typing import Optional, ClassVar, Union, TYPE_CHECKING
 from uuid import UUID
 from uuid import UUID
 
 
 from attr import dataclass
 from attr import dataclass
 import asyncpg
 import asyncpg
 
 
+from mausignald.types import Address, GroupID
 from mautrix.types import RoomID, EventID
 from mautrix.types import RoomID, EventID
 from mautrix.util.async_db import Database
 from mautrix.util.async_db import Database
 
 
+from ..util import id_to_str
+
 fake_db = Database("") if TYPE_CHECKING else None
 fake_db = Database("") if TYPE_CHECKING else None
 
 
 
 
@@ -31,42 +34,45 @@ class Reaction:
 
 
     mxid: EventID
     mxid: EventID
     mx_room: RoomID
     mx_room: RoomID
-    signal_chat_id: str
+    signal_chat_id: Union[GroupID, Address]
     signal_receiver: str
     signal_receiver: str
-    msg_author: UUID
+    msg_author: Address
     msg_timestamp: int
     msg_timestamp: int
-    author: UUID
+    author: Address
     emoji: str
     emoji: str
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
         q = ("INSERT INTO reaction (mxid, mx_room, signal_chat_id, signal_receiver, msg_author,"
         q = ("INSERT INTO reaction (mxid, mx_room, signal_chat_id, signal_receiver, msg_author,"
              "                      msg_timestamp, author, emoji) "
              "                      msg_timestamp, author, emoji) "
              "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)")
              "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)")
-        await self.db.execute(q, self.mxid, self.mx_room, str(self.signal_chat_id),
-                              self.signal_receiver, self.msg_author, self.msg_timestamp,
-                              self.author, self.emoji)
+        await self.db.execute(q, self.mxid, self.mx_room, id_to_str(self.signal_chat_id),
+                              self.signal_receiver, self.msg_author.best_identifier,
+                              self.msg_timestamp, self.author.best_identifier, self.emoji)
 
 
     async def edit(self, mx_room: RoomID, mxid: EventID, emoji: str) -> None:
     async def edit(self, mx_room: RoomID, mxid: EventID, emoji: str) -> None:
         await self.db.execute("UPDATE reaction SET mxid=$1, mx_room=$2, emoji=$3 "
         await self.db.execute("UPDATE reaction SET mxid=$1, mx_room=$2, emoji=$3 "
                               "WHERE signal_chat_id=$4 AND signal_receiver=$5"
                               "WHERE signal_chat_id=$4 AND signal_receiver=$5"
                               "      AND msg_author=$6 AND msg_timestamp=$7 AND author=$8",
                               "      AND msg_author=$6 AND msg_timestamp=$7 AND author=$8",
-                              mxid, mx_room, emoji, str(self.signal_chat_id), self.signal_receiver,
-                              self.msg_author, self.msg_timestamp, self.author)
+                              mxid, mx_room, emoji, id_to_str(self.signal_chat_id),
+                              self.signal_receiver, self.msg_author.best_identifier,
+                              self.msg_timestamp, self.author.best_identifier)
 
 
     async def delete(self) -> None:
     async def delete(self) -> None:
         q = ("DELETE FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
         q = ("DELETE FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
              "                           AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
              "                           AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
-        await self.db.execute(q, str(self.signal_chat_id), self.signal_receiver, self.msg_author,
-                              self.msg_timestamp, self.author)
+        await self.db.execute(q, id_to_str(self.signal_chat_id), self.signal_receiver,
+                              self.msg_author.best_identifier, self.msg_timestamp,
+                              self.author.best_identifier)
 
 
     @classmethod
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> 'Reaction':
     def _from_row(cls, row: asyncpg.Record) -> 'Reaction':
         data = {**row}
         data = {**row}
+        chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
         if data["signal_receiver"]:
-            chat_id = UUID(data.pop("signal_chat_id"))
-        else:
-            chat_id = data.pop("signal_chat_id")
-        return cls(signal_chat_id=chat_id, **data)
+            chat_id = Address.parse(chat_id)
+        msg_author = Address.parse(data.pop("msg_author"))
+        author = Address.parse(data.pop("author"))
+        return cls(signal_chat_id=chat_id, msg_author=msg_author, author=author, **data)
 
 
     @classmethod
     @classmethod
     async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Reaction']:
     async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Reaction']:
@@ -79,13 +85,15 @@ class Reaction:
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def get_by_signal_id(cls, chat_id: str, receiver: str, msg_author: UUID,
-                               msg_timestamp: int, author: UUID) -> Optional['Reaction']:
+    async def get_by_signal_id(cls, chat_id: Union[GroupID, Address], receiver: str,
+                               msg_author: Address, msg_timestamp: int, author: Address
+                               ) -> Optional['Reaction']:
         q = ("SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
         q = ("SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
              "       msg_author, msg_timestamp, author, emoji "
              "       msg_author, msg_timestamp, author, emoji "
              "FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
              "FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
              "                    AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
              "                    AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
-        row = await cls.db.fetchrow(q, chat_id, receiver, msg_author, msg_timestamp, author)
+        row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver, msg_author.best_identifier,
+                                    msg_timestamp, author.best_identifier)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)

+ 14 - 0
mautrix_signal/db/upgrade.py

@@ -100,3 +100,17 @@ async def upgrade_v2(conn: Connection) -> None:
 @upgrade_table.register(description="Add double-puppeting base_url to puppe table")
 @upgrade_table.register(description="Add double-puppeting base_url to puppe table")
 async def upgrade_v3(conn: Connection) -> None:
 async def upgrade_v3(conn: Connection) -> None:
     await conn.execute("ALTER TABLE puppet ADD COLUMN base_url TEXT")
     await conn.execute("ALTER TABLE puppet ADD COLUMN base_url TEXT")
+
+
+@upgrade_table.register(description="Allow phone numbers as message sender identifiers")
+async def upgrade_v4(conn: Connection) -> None:
+    cname = await conn.fetchval("SELECT constraint_name FROM information_schema.table_constraints "
+                                "WHERE table_name='reaction' AND constraint_name LIKE '%_fkey'")
+    await conn.execute(f"ALTER TABLE reaction DROP CONSTRAINT {cname}")
+    await conn.execute("ALTER TABLE reaction ALTER COLUMN msg_author SET DATA TYPE TEXT")
+    await conn.execute("ALTER TABLE reaction ALTER COLUMN author SET DATA TYPE TEXT")
+    await conn.execute("ALTER TABLE message ALTER COLUMN sender SET DATA TYPE TEXT")
+    await conn.execute(f"ALTER TABLE reaction ADD CONSTRAINT {cname} "
+                       "FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) "
+                       "  REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) "
+                       "  ON DELETE CASCADE ON UPDATE CASCADE")

+ 2 - 4
mautrix_signal/matrix.py

@@ -15,8 +15,6 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from typing import List, Union, TYPE_CHECKING
 from typing import List, Union, TYPE_CHECKING
 
 
-from mausignald.types import Address
-
 from mautrix.bridge import BaseMatrixHandler
 from mautrix.bridge import BaseMatrixHandler
 from mautrix.types import (Event, ReactionEvent, MessageEvent, StateEvent, EncryptedEvent, RoomID,
 from mautrix.types import (Event, ReactionEvent, MessageEvent, StateEvent, EncryptedEvent, RoomID,
                            EventID, UserID, ReactionEventContent, RelationType, EventType,
                            EventID, UserID, ReactionEventContent, RelationType, EventType,
@@ -93,7 +91,7 @@ class MatrixHandler(BaseMatrixHandler):
             return
             return
 
 
         user.log.trace(f"Sending read receipt for {message.timestamp} to {message.sender}")
         user.log.trace(f"Sending read receipt for {message.timestamp} to {message.sender}")
-        await self.signal.send_receipt(user.username, Address(uuid=message.sender),
+        await self.signal.send_receipt(user.username, message.sender,
                                        timestamps=[message.timestamp], when=data.ts, read=True)
                                        timestamps=[message.timestamp], when=data.ts, read=True)
 
 
     async def handle_typing(self, room_id: RoomID, typing: List[UserID]) -> None:
     async def handle_typing(self, room_id: RoomID, typing: List[UserID]) -> None:
@@ -118,4 +116,4 @@ class MatrixHandler(BaseMatrixHandler):
         if evt.type == EventType.TYPING:
         if evt.type == EventType.TYPING:
             await self.handle_typing(evt.room_id, evt.content.user_ids)
             await self.handle_typing(evt.room_id, evt.content.user_ids)
         else:
         else:
-            super().handle_ephemeral_event(evt)
+            await super().handle_ephemeral_event(evt)

+ 54 - 58
mautrix_signal/portal.py

@@ -13,8 +13,8 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import (Dict, Tuple, Optional, List, Deque, Set, Any, Union, AsyncGenerator,
-                    Awaitable, TYPE_CHECKING, cast)
+from typing import (Dict, Tuple, Optional, List, Deque, Any, Union, AsyncGenerator, Awaitable,
+                    TYPE_CHECKING, cast)
 from collections import deque
 from collections import deque
 from uuid import UUID, uuid4
 from uuid import UUID, uuid4
 import mimetypes
 import mimetypes
@@ -25,7 +25,7 @@ import time
 import os
 import os
 
 
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
-                              Attachment)
+                              Attachment, GroupID)
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal
 from mautrix.bridge import BasePortal
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
@@ -57,7 +57,7 @@ ChatInfo = Union[Group, Contact, Profile, Address]
 
 
 class Portal(DBPortal, BasePortal):
 class Portal(DBPortal, BasePortal):
     by_mxid: Dict[RoomID, 'Portal'] = {}
     by_mxid: Dict[RoomID, 'Portal'] = {}
-    by_chat_id: Dict[Tuple[Union[str, UUID], str], 'Portal'] = {}
+    by_chat_id: Dict[Tuple[str, str], 'Portal'] = {}
     config: Config
     config: Config
     matrix: 'm.MatrixHandler'
     matrix: 'm.MatrixHandler'
     signal: 's.SignalHandler'
     signal: 's.SignalHandler'
@@ -66,16 +66,16 @@ class Portal(DBPortal, BasePortal):
 
 
     _main_intent: Optional[IntentAPI]
     _main_intent: Optional[IntentAPI]
     _create_room_lock: asyncio.Lock
     _create_room_lock: asyncio.Lock
-    _msgts_dedup: Deque[Tuple[UUID, int]]
-    _reaction_dedup: Deque[Tuple[UUID, int, str]]
+    _msgts_dedup: Deque[Tuple[Address, int]]
+    _reaction_dedup: Deque[Tuple[Address, int, str]]
     _reaction_lock: asyncio.Lock
     _reaction_lock: asyncio.Lock
 
 
-    def __init__(self, chat_id: Union[str, UUID], receiver: str, mxid: Optional[RoomID] = None,
+    def __init__(self, chat_id: Union[GroupID, Address], receiver: str, mxid: Optional[RoomID] = None,
                  name: Optional[str] = None, avatar_hash: Optional[str] = None,
                  name: Optional[str] = None, avatar_hash: Optional[str] = None,
                  avatar_url: Optional[ContentURI] = None, encrypted: bool = False) -> None:
                  avatar_url: Optional[ContentURI] = None, encrypted: bool = False) -> None:
         super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted)
         super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted)
         self._create_room_lock = asyncio.Lock()
         self._create_room_lock = asyncio.Lock()
-        self.log = self.log.getChild(str(chat_id))
+        self.log = self.log.getChild(self.chat_id_str)
         self._main_intent = None
         self._main_intent = None
         self._msgts_dedup = deque(maxlen=100)
         self._msgts_dedup = deque(maxlen=100)
         self._reaction_dedup = deque(maxlen=100)
         self._reaction_dedup = deque(maxlen=100)
@@ -90,14 +90,15 @@ class Portal(DBPortal, BasePortal):
 
 
     @property
     @property
     def is_direct(self) -> bool:
     def is_direct(self) -> bool:
-        return isinstance(self.chat_id, UUID)
+        return isinstance(self.chat_id, Address)
 
 
-    @property
-    def recipient(self) -> Union[str, Address]:
-        if self.is_direct:
-            return Address(uuid=self.chat_id)
-        else:
-            return self.chat_id
+    def handle_uuid_receive(self, uuid: UUID) -> None:
+        if not self.is_direct or self.chat_id.uuid:
+            raise ValueError("handle_uuid_receive can only be used for private chat portals with "
+                             "a phone number chat_id")
+        del self.by_chat_id[(self.chat_id_str, self.receiver)]
+        self.chat_id = Address(uuid=uuid)
+        self.by_chat_id[(self.chat_id_str, self.receiver)] = self
 
 
     @classmethod
     @classmethod
     def init_cls(cls, bridge: 'SignalBridge') -> None:
     def init_cls(cls, bridge: 'SignalBridge') -> None:
@@ -131,9 +132,10 @@ class Portal(DBPortal, BasePortal):
             await existing.edit(emoji=emoji, mxid=mxid, mx_room=message.mx_room)
             await existing.edit(emoji=emoji, mxid=mxid, mx_room=message.mx_room)
         else:
         else:
             self.log.debug(f"_upsert_reaction inserting {mxid} (message: {message.mxid})")
             self.log.debug(f"_upsert_reaction inserting {mxid} (message: {message.mxid})")
-            await DBReaction(mxid=mxid, mx_room=message.mx_room, emoji=emoji, author=sender.uuid,
+            await DBReaction(mxid=mxid, mx_room=message.mx_room, emoji=emoji,
                              signal_chat_id=self.chat_id, signal_receiver=self.receiver,
                              signal_chat_id=self.chat_id, signal_receiver=self.receiver,
-                             msg_author=message.sender, msg_timestamp=message.timestamp).insert()
+                             msg_author=message.sender, msg_timestamp=message.timestamp,
+                             author=sender.address).insert()
 
 
     # endregion
     # endregion
     # region Matrix event handling
     # region Matrix event handling
@@ -168,14 +170,14 @@ class Portal(DBPortal, BasePortal):
             self.log.debug(f"Ignoring puppet-sent message by confirmed puppet user {sender.mxid}")
             self.log.debug(f"Ignoring puppet-sent message by confirmed puppet user {sender.mxid}")
             return
             return
         request_id = int(time.time() * 1000)
         request_id = int(time.time() * 1000)
-        self._msgts_dedup.appendleft((sender.uuid, request_id))
+        self._msgts_dedup.appendleft((sender.address, request_id))
 
 
         quote = None
         quote = None
         if message.get_reply_to():
         if message.get_reply_to():
             reply = await DBMessage.get_by_mxid(message.get_reply_to(), self.mxid)
             reply = await DBMessage.get_by_mxid(message.get_reply_to(), self.mxid)
             # TODO include actual text? either store in db or fetch event from homeserver
             # TODO include actual text? either store in db or fetch event from homeserver
             if reply is not None:
             if reply is not None:
-                quote = Quote(id=reply.timestamp, author=Address(uuid=reply.sender), text="")
+                quote = Quote(id=reply.timestamp, author=reply.sender, text="")
 
 
         text = message.body
         text = message.body
         attachments: Optional[List[Attachment]] = None
         attachments: Optional[List[Attachment]] = None
@@ -188,9 +190,9 @@ class Portal(DBPortal, BasePortal):
             attachments = [attachment]
             attachments = [attachment]
             text = None
             text = None
             self.log.trace("Formed outgoing attachment %s", attachment)
             self.log.trace("Formed outgoing attachment %s", attachment)
-        await self.signal.send(username=sender.username, recipient=self.recipient, body=text,
+        await self.signal.send(username=sender.username, recipient=self.chat_id, body=text,
                                quote=quote, attachments=attachments, timestamp=request_id)
                                quote=quote, attachments=attachments, timestamp=request_id)
-        msg = DBMessage(mxid=event_id, mx_room=self.mxid, sender=sender.uuid, timestamp=request_id,
+        msg = DBMessage(mxid=event_id, mx_room=self.mxid, sender=sender.address, timestamp=request_id,
                         signal_chat_id=self.chat_id, signal_receiver=self.receiver)
                         signal_chat_id=self.chat_id, signal_receiver=self.receiver)
         await msg.insert()
         await msg.insert()
         await self._send_delivery_receipt(event_id)
         await self._send_delivery_receipt(event_id)
@@ -212,7 +214,7 @@ class Portal(DBPortal, BasePortal):
             return
             return
 
 
         existing = await DBReaction.get_by_signal_id(self.chat_id, self.receiver, message.sender,
         existing = await DBReaction.get_by_signal_id(self.chat_id, self.receiver, message.sender,
-                                                     message.timestamp, sender.uuid)
+                                                     message.timestamp, sender.address)
         if existing and existing.emoji == emoji:
         if existing and existing.emoji == emoji:
             return
             return
 
 
@@ -220,9 +222,9 @@ class Portal(DBPortal, BasePortal):
         self._reaction_dedup.appendleft(dedup_id)
         self._reaction_dedup.appendleft(dedup_id)
         async with self._reaction_lock:
         async with self._reaction_lock:
             reaction = Reaction(emoji=emoji, remove=False,
             reaction = Reaction(emoji=emoji, remove=False,
-                                target_author=Address(uuid=message.sender),
+                                target_author=message.sender,
                                 target_sent_timestamp=message.timestamp)
                                 target_sent_timestamp=message.timestamp)
-            await self.signal.react(username=sender.username, recipient=self.recipient,
+            await self.signal.react(username=sender.username, recipient=self.chat_id,
                                     reaction=reaction)
                                     reaction=reaction)
             await self._upsert_reaction(existing, self.main_intent, event_id, sender, message,
             await self._upsert_reaction(existing, self.main_intent, event_id, sender, message,
                                         emoji)
                                         emoji)
@@ -239,9 +241,9 @@ class Portal(DBPortal, BasePortal):
             try:
             try:
                 await reaction.delete()
                 await reaction.delete()
                 remove_reaction = Reaction(emoji=reaction.emoji, remove=True,
                 remove_reaction = Reaction(emoji=reaction.emoji, remove=True,
-                                           target_author=Address(uuid=reaction.msg_author),
+                                           target_author=reaction.msg_author,
                                            target_sent_timestamp=reaction.msg_timestamp)
                                            target_sent_timestamp=reaction.msg_timestamp)
-                await self.signal.react(username=sender.username, recipient=self.recipient,
+                await self.signal.react(username=sender.username, recipient=self.chat_id,
                                         reaction=remove_reaction)
                                         reaction=remove_reaction)
                 await self._send_delivery_receipt(redaction_event_id)
                 await self._send_delivery_receipt(redaction_event_id)
                 self.log.trace(f"Removed {reaction} after Matrix redaction")
                 self.log.trace(f"Removed {reaction} after Matrix redaction")
@@ -263,21 +265,17 @@ class Portal(DBPortal, BasePortal):
     # region Signal event handling
     # region Signal event handling
 
 
     @staticmethod
     @staticmethod
-    async def _find_address_uuid(address: Address) -> Optional[UUID]:
-        if address.uuid:
-            return address.uuid
+    async def _resolve_address(address: Address) -> Address:
         puppet = await p.Puppet.get_by_address(address, create=False)
         puppet = await p.Puppet.get_by_address(address, create=False)
-        if puppet and puppet.uuid:
-            return puppet.uuid
-        return None
+        return puppet.address
 
 
     async def _find_quote_event_id(self, quote: Optional[Quote]
     async def _find_quote_event_id(self, quote: Optional[Quote]
                                    ) -> Optional[Union[MessageEvent, EventID]]:
                                    ) -> Optional[Union[MessageEvent, EventID]]:
         if not quote:
         if not quote:
             return None
             return None
 
 
-        author_uuid = await self._find_address_uuid(quote.author)
-        reply_msg = await DBMessage.get_by_signal_id(author_uuid, quote.id,
+        author_address = await self._resolve_address(quote.author)
+        reply_msg = await DBMessage.get_by_signal_id(author_address, quote.id,
                                                      self.chat_id, self.receiver)
                                                      self.chat_id, self.receiver)
         if not reply_msg:
         if not reply_msg:
             return None
             return None
@@ -291,13 +289,13 @@ class Portal(DBPortal, BasePortal):
 
 
     async def handle_signal_message(self, source: 'u.User', sender: 'p.Puppet',
     async def handle_signal_message(self, source: 'u.User', sender: 'p.Puppet',
                                     message: MessageData) -> None:
                                     message: MessageData) -> None:
-        if (sender.uuid, message.timestamp) in self._msgts_dedup:
+        if (sender.address, message.timestamp) in self._msgts_dedup:
             self.log.debug(f"Ignoring message {message.timestamp} by {sender.uuid}"
             self.log.debug(f"Ignoring message {message.timestamp} by {sender.uuid}"
                            " as it was already handled (message.timestamp in dedup queue)")
                            " as it was already handled (message.timestamp in dedup queue)")
             await self.signal.send_receipt(source.username, sender.address,
             await self.signal.send_receipt(source.username, sender.address,
                                            timestamps=[message.timestamp])
                                            timestamps=[message.timestamp])
             return
             return
-        old_message = await DBMessage.get_by_signal_id(sender.uuid, message.timestamp,
+        old_message = await DBMessage.get_by_signal_id(sender.address, message.timestamp,
                                                        self.chat_id, self.receiver)
                                                        self.chat_id, self.receiver)
         if old_message is not None:
         if old_message is not None:
             self.log.debug(f"Ignoring message {message.timestamp} by {sender.uuid}"
             self.log.debug(f"Ignoring message {message.timestamp} by {sender.uuid}"
@@ -307,7 +305,7 @@ class Portal(DBPortal, BasePortal):
             return
             return
         self.log.debug(f"Started handling message {message.timestamp} by {sender.uuid}")
         self.log.debug(f"Started handling message {message.timestamp} by {sender.uuid}")
         self.log.trace(f"Message content: {message}")
         self.log.trace(f"Message content: {message}")
-        self._msgts_dedup.appendleft((sender.uuid, message.timestamp))
+        self._msgts_dedup.appendleft((sender.address, message.timestamp))
         intent = sender.intent_for(self)
         intent = sender.intent_for(self)
         await intent.set_typing(self.mxid, False)
         await intent.set_typing(self.mxid, False)
         event_id = None
         event_id = None
@@ -345,7 +343,7 @@ class Portal(DBPortal, BasePortal):
 
 
         if event_id:
         if event_id:
             msg = DBMessage(mxid=event_id, mx_room=self.mxid,
             msg = DBMessage(mxid=event_id, mx_room=self.mxid,
-                            sender=sender.uuid, timestamp=message.timestamp,
+                            sender=sender.address, timestamp=message.timestamp,
                             signal_chat_id=self.chat_id, signal_receiver=self.receiver)
                             signal_chat_id=self.chat_id, signal_receiver=self.receiver)
             await msg.insert()
             await msg.insert()
             await self.signal.send_receipt(source.username, sender.address,
             await self.signal.send_receipt(source.username, sender.address,
@@ -404,20 +402,16 @@ class Portal(DBPortal, BasePortal):
         return content
         return content
 
 
     async def handle_signal_reaction(self, sender: 'p.Puppet', reaction: Reaction) -> None:
     async def handle_signal_reaction(self, sender: 'p.Puppet', reaction: Reaction) -> None:
-        author_uuid = await self._find_address_uuid(reaction.target_author)
+        author_address = await self._resolve_address(reaction.target_author)
         target_id = reaction.target_sent_timestamp
         target_id = reaction.target_sent_timestamp
-        if author_uuid is None:
-            self.log.warning(f"Failed to handle reaction from {sender.uuid}: "
-                             f"couldn't find UUID of {reaction.target_author}")
-            return
         async with self._reaction_lock:
         async with self._reaction_lock:
-            dedup_id = (author_uuid, target_id, reaction.emoji)
+            dedup_id = (author_address, target_id, reaction.emoji)
             if dedup_id in self._reaction_dedup:
             if dedup_id in self._reaction_dedup:
                 return
                 return
             self._reaction_dedup.appendleft(dedup_id)
             self._reaction_dedup.appendleft(dedup_id)
 
 
         existing = await DBReaction.get_by_signal_id(self.chat_id, self.receiver,
         existing = await DBReaction.get_by_signal_id(self.chat_id, self.receiver,
-                                                     author_uuid, target_id, sender.uuid)
+                                                     author_address, target_id, sender.address)
 
 
         if reaction.remove:
         if reaction.remove:
             if existing:
             if existing:
@@ -431,7 +425,7 @@ class Portal(DBPortal, BasePortal):
         elif existing and existing.emoji == reaction.emoji:
         elif existing and existing.emoji == reaction.emoji:
             return
             return
 
 
-        message = await DBMessage.get_by_signal_id(author_uuid, target_id,
+        message = await DBMessage.get_by_signal_id(author_address, target_id,
                                                    self.chat_id, self.receiver)
                                                    self.chat_id, self.receiver)
         if not message:
         if not message:
             self.log.debug(f"Ignoring reaction to unknown message {target_id}")
             self.log.debug(f"Ignoring reaction to unknown message {target_id}")
@@ -440,7 +434,7 @@ class Portal(DBPortal, BasePortal):
         intent = sender.intent_for(self)
         intent = sender.intent_for(self)
         # TODO add variation selectors to emoji before sending to Matrix
         # TODO add variation selectors to emoji before sending to Matrix
         mxid = await intent.react(message.mx_room, message.mxid, reaction.emoji)
         mxid = await intent.react(message.mx_room, message.mxid, reaction.emoji)
-        self.log.debug(f"{sender.uuid} reacted to {message.mxid} -> {mxid}")
+        self.log.debug(f"{sender.address} reacted to {message.mxid} -> {mxid}")
         await self._upsert_reaction(existing, intent, mxid, sender, message, reaction.emoji)
         await self._upsert_reaction(existing, intent, mxid, sender, message, reaction.emoji)
 
 
     # endregion
     # endregion
@@ -451,7 +445,7 @@ class Portal(DBPortal, BasePortal):
             if not isinstance(info, (Contact, Profile, Address)):
             if not isinstance(info, (Contact, Profile, Address)):
                 raise ValueError(f"Unexpected type for direct chat update_info: {type(info)}")
                 raise ValueError(f"Unexpected type for direct chat update_info: {type(info)}")
             if not self.name:
             if not self.name:
-                puppet = await p.Puppet.get_by_address(Address(uuid=self.chat_id))
+                puppet = await p.Puppet.get_by_address(self.chat_id)
                 if not puppet.name:
                 if not puppet.name:
                     await puppet.update_info(info)
                     await puppet.update_info(info)
                 self.name = puppet.name
                 self.name = puppet.name
@@ -619,18 +613,18 @@ class Portal(DBPortal, BasePortal):
         if self.config["bridge.encryption.default"] and self.matrix.e2ee:
         if self.config["bridge.encryption.default"] and self.matrix.e2ee:
             self.encrypted = True
             self.encrypted = True
             initial_state.append({
             initial_state.append({
-                "type": "m.room.encryption",
+                "type": str(EventType.ROOM_ENCRYPTION),
                 "content": {"algorithm": "m.megolm.v1.aes-sha2"},
                 "content": {"algorithm": "m.megolm.v1.aes-sha2"},
             })
             })
             if self.is_direct:
             if self.is_direct:
                 invites.append(self.az.bot_mxid)
                 invites.append(self.az.bot_mxid)
-        if source.uuid == self.chat_id:
+        if self.is_direct and source.address == self.chat_id:
             name = self.name = "Signal Note to Self"
             name = self.name = "Signal Note to Self"
         elif self.encrypted or self.private_chat_portal_meta or not self.is_direct:
         elif self.encrypted or self.private_chat_portal_meta or not self.is_direct:
             name = self.name
             name = self.name
         if self.avatar_url:
         if self.avatar_url:
             initial_state.append({
             initial_state.append({
-                "type": "m.room.avatar",
+                "type": str(EventType.ROOM_AVATAR),
                 "content": {"url": self.avatar_url},
                 "content": {"url": self.avatar_url},
             })
             })
         if self.config["appservice.community_id"]:
         if self.config["appservice.community_id"]:
@@ -638,10 +632,9 @@ class Portal(DBPortal, BasePortal):
                 "type": "m.room.related_groups",
                 "type": "m.room.related_groups",
                 "content": {"groups": [self.config["appservice.community_id"]]},
                 "content": {"groups": [self.config["appservice.community_id"]]},
             })
             })
-        #Allow chaning of room avatar and name in direct chats
         if self.is_direct:
         if self.is_direct:
             initial_state.append({
             initial_state.append({
-                "type": "m.room.power_levels",
+                "type": str(EventType.ROOM_POWER_LEVELS),
                 "content": {"users": {self.main_intent.mxid: 100},
                 "content": {"users": {self.main_intent.mxid: 100},
                             "events": {"m.room.avatar": 0, "m.room.name": 0}}
                             "events": {"m.room.avatar": 0, "m.room.name": 0}}
             })
             })
@@ -689,7 +682,7 @@ class Portal(DBPortal, BasePortal):
         if self.mxid:
         if self.mxid:
             self.by_mxid[self.mxid] = self
             self.by_mxid[self.mxid] = self
         if self.is_direct:
         if self.is_direct:
-            puppet = await p.Puppet.get_by_address(Address(uuid=self.chat_id))
+            puppet = await p.Puppet.get_by_address(self.chat_id)
             self._main_intent = puppet.default_mxid_intent
             self._main_intent = puppet.default_mxid_intent
         elif not self.is_direct:
         elif not self.is_direct:
             self._main_intent = self.az.intent
             self._main_intent = self.az.intent
@@ -709,7 +702,7 @@ class Portal(DBPortal, BasePortal):
         return cls._db_to_portals(super().all_with_room())
         return cls._db_to_portals(super().all_with_room())
 
 
     @classmethod
     @classmethod
-    def find_private_chats_with(cls, other_user: UUID) -> AsyncGenerator['Portal', None]:
+    def find_private_chats_with(cls, other_user: Address) -> AsyncGenerator['Portal', None]:
         return cls._db_to_portals(super().find_private_chats_with(other_user))
         return cls._db_to_portals(super().find_private_chats_with(other_user))
 
 
     @classmethod
     @classmethod
@@ -718,7 +711,7 @@ class Portal(DBPortal, BasePortal):
         portals = await query
         portals = await query
         for index, portal in enumerate(portals):
         for index, portal in enumerate(portals):
             try:
             try:
-                yield cls.by_chat_id[(portal.chat_id, portal.receiver)]
+                yield cls.by_chat_id[(portal.chat_id_str, portal.receiver)]
             except KeyError:
             except KeyError:
                 await portal._postinit()
                 await portal._postinit()
                 yield portal
                 yield portal
@@ -738,14 +731,17 @@ class Portal(DBPortal, BasePortal):
         return None
         return None
 
 
     @classmethod
     @classmethod
-    async def get_by_chat_id(cls, chat_id: Union[UUID, str], receiver: str = "",
+    async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = "",
                              create: bool = False) -> Optional['Portal']:
                              create: bool = False) -> Optional['Portal']:
         if isinstance(chat_id, str):
         if isinstance(chat_id, str):
             receiver = ""
             receiver = ""
+        elif not isinstance(chat_id, Address):
+            raise ValueError(f"Invalid chat ID type {type(chat_id)}")
         elif not receiver:
         elif not receiver:
             raise ValueError("Direct chats must have a receiver")
             raise ValueError("Direct chats must have a receiver")
         try:
         try:
-            return cls.by_chat_id[(chat_id, receiver)]
+            best_id = chat_id.best_identifier if isinstance(chat_id, Address) else chat_id
+            return cls.by_chat_id[(best_id, receiver)]
         except KeyError:
         except KeyError:
             pass
             pass
 
 

+ 12 - 4
mautrix_signal/puppet.py

@@ -28,7 +28,7 @@ from mautrix.util.simple_template import SimpleTemplate
 
 
 from .db import Puppet as DBPuppet
 from .db import Puppet as DBPuppet
 from .config import Config
 from .config import Config
-from . import portal as p
+from . import portal as p, user as u
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .__main__ import SignalBridge
     from .__main__ import SignalBridge
@@ -90,7 +90,7 @@ class Puppet(DBPuppet, BasePuppet):
         return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
         return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
 
 
     def intent_for(self, portal: 'p.Portal') -> IntentAPI:
     def intent_for(self, portal: 'p.Portal') -> IntentAPI:
-        if portal.chat_id == self.uuid:
+        if portal.chat_id == self.address:
             return self.default_mxid_intent
             return self.default_mxid_intent
         return self.intent
         return self.intent
 
 
@@ -118,14 +118,22 @@ class Puppet(DBPuppet, BasePuppet):
 
 
     async def _handle_uuid_receive(self, uuid: UUID) -> None:
     async def _handle_uuid_receive(self, uuid: UUID) -> None:
         self.log.debug(f"Found UUID for user: {uuid}")
         self.log.debug(f"Found UUID for user: {uuid}")
+        user = await u.User.get_by_username(self.number)
+        if user and not user.uuid:
+            user.uuid = self.uuid
+            await user.update()
         await self._set_uuid(uuid)
         await self._set_uuid(uuid)
         self.by_uuid[self.uuid] = self
         self.by_uuid[self.uuid] = self
+        async for portal in p.Portal.find_private_chats_with(Address(number=self.number)):
+            self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
+            portal.handle_uuid_receive(self.uuid)
         prev_intent = self.default_mxid_intent
         prev_intent = self.default_mxid_intent
         self.default_mxid = self.get_mxid_from_id(self.address)
         self.default_mxid = self.get_mxid_from_id(self.address)
         self.default_mxid_intent = self.az.intent.user(self.default_mxid)
         self.default_mxid_intent = self.az.intent.user(self.default_mxid)
         self.intent = self._fresh_intent()
         self.intent = self._fresh_intent()
         await self.intent.ensure_registered()
         await self.intent.ensure_registered()
-        await self.intent.set_displayname(self.name)
+        if self.name:
+            await self.intent.set_displayname(self.name)
         self.log = self.log.getChild(str(uuid))
         self.log = self.log.getChild(str(uuid))
         self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {self.default_mxid_intent}")
         self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {self.default_mxid_intent}")
         for room_id in await prev_intent.get_joined_rooms():
         for room_id in await prev_intent.get_joined_rooms():
@@ -186,7 +194,7 @@ class Puppet(DBPuppet, BasePuppet):
         return False
         return False
 
 
     async def _update_portal_names(self) -> None:
     async def _update_portal_names(self) -> None:
-        async for portal in p.Portal.find_private_chats_with(self.uuid):
+        async for portal in p.Portal.find_private_chats_with(self.address):
             if portal.receiver == self.number:
             if portal.receiver == self.number:
                 # This is a note to self chat, don't change the name
                 # This is a note to self chat, don't change the name
                 continue
                 continue

+ 4 - 9
mautrix_signal/signal.py

@@ -42,10 +42,6 @@ class SignalHandler(SignaldClient):
 
 
     async def on_message(self, evt: Message) -> None:
     async def on_message(self, evt: Message) -> None:
         sender = await pu.Puppet.get_by_address(evt.source)
         sender = await pu.Puppet.get_by_address(evt.source)
-        if not sender.uuid:
-            self.log.debug("Got message sender puppet with no UUID, not handling message")
-            self.log.trace("Message content: %s", evt)
-            return
         user = await u.User.get_by_username(evt.username)
         user = await u.User.get_by_username(evt.username)
         # TODO add lots of logging
         # TODO add lots of logging
 
 
@@ -74,8 +70,7 @@ class SignalHandler(SignaldClient):
         if msg.group:
         if msg.group:
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
         else:
         else:
-            portal = await po.Portal.get_by_chat_id(addr_override.uuid
-                                                    if addr_override else sender.uuid,
+            portal = await po.Portal.get_by_chat_id(addr_override or sender.address,
                                                     receiver=user.username, create=True)
                                                     receiver=user.username, create=True)
             if addr_override and not sender.is_real_user:
             if addr_override and not sender.is_real_user:
                 portal.log.debug(f"Ignoring own message {msg.timestamp} as user doesn't have"
                 portal.log.debug(f"Ignoring own message {msg.timestamp} as user doesn't have"
@@ -94,9 +89,9 @@ class SignalHandler(SignaldClient):
     async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None:
     async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None:
         for receipt in receipts:
         for receipt in receipts:
             puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
             puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
-            if not puppet or not puppet.uuid:
+            if not puppet:
                 continue
                 continue
-            message = await DBMessage.find_by_sender_timestamp(puppet.uuid, receipt.timestamp)
+            message = await DBMessage.find_by_sender_timestamp(puppet.address, receipt.timestamp)
             if not message:
             if not message:
                 continue
                 continue
             portal = await po.Portal.get_by_mxid(message.mx_room)
             portal = await po.Portal.get_by_mxid(message.mx_room)
@@ -110,7 +105,7 @@ class SignalHandler(SignaldClient):
         if typing.group_id:
         if typing.group_id:
             portal = await po.Portal.get_by_chat_id(typing.group_id)
             portal = await po.Portal.get_by_chat_id(typing.group_id)
         else:
         else:
-            portal = await po.Portal.get_by_chat_id(sender.uuid, receiver=user.username)
+            portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
         if not portal or not portal.mxid:
         if not portal or not portal.mxid:
             return
             return
         is_typing = typing.action == TypingAction.STARTED
         is_typing = typing.action == TypingAction.STARTED

+ 2 - 2
mautrix_signal/user.py

@@ -106,8 +106,8 @@ class User(DBUser, BaseUser):
                 # necessary, but maybe we could listen for updates?
                 # necessary, but maybe we could listen for updates?
                 profile = None
                 profile = None
             await puppet.update_info(profile or contact)
             await puppet.update_info(profile or contact)
-            if puppet.uuid and create_contact_portal:
-                portal = await po.Portal.get_by_chat_id(puppet.uuid, self.username, create=True)
+            if create_contact_portal:
+                portal = await po.Portal.get_by_chat_id(puppet.address, self.username, create=True)
                 await portal.create_matrix_room(self, profile or contact)
                 await portal.create_matrix_room(self, profile or contact)
 
 
         create_group_portal = self.config["bridge.autocreate_group_portal"]
         create_group_portal = self.config["bridge.autocreate_group_portal"]

+ 1 - 0
mautrix_signal/util/__init__.py

@@ -1 +1,2 @@
 from .color_log import ColorFormatter
 from .color_log import ColorFormatter
+from .id_to_str import id_to_str

+ 24 - 0
mautrix_signal/util/id_to_str.py

@@ -0,0 +1,24 @@
+# mautrix-signal - A Matrix-Signal puppeting bridge
+# Copyright (C) 2020 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from typing import Union
+
+from mausignald.types import Address, GroupID
+
+
+def id_to_str(identifier: Union[Address, GroupID]) -> str:
+    if isinstance(identifier, Address):
+        return identifier.best_identifier
+    return identifier

+ 1 - 1
requirements.txt

@@ -4,5 +4,5 @@ commonmark>=0.8,<0.10
 aiohttp>=3,<3.7
 aiohttp>=3,<3.7
 yarl>=1,<1.6
 yarl>=1,<1.6
 attrs>=19.1
 attrs>=19.1
-mautrix==0.8.0.beta7
+mautrix==0.8.0.beta9
 asyncpg>=0.20,<0.22
 asyncpg>=0.20,<0.22