Эх сурвалжийг харах

Add support for accepting Signal group invites

Tulir Asokan 4 жил өмнө
parent
commit
90017f3c60

+ 10 - 0
ROADMAP.md

@@ -17,6 +17,11 @@
   * [x] Group info changes
   * [x] Group info changes
     * [x] Name
     * [x] Name
     * [x] Avatar
     * [x] Avatar
+  * [ ] Membership actions
+    * [x] Join (accept invite)
+    * [ ] Invite
+    * [ ] Leave
+    * [ ] Kick
   * [ ] Typing notifications
   * [ ] Typing notifications
   * [ ] Read receipts (currently partial support, only marks last message)
   * [ ] Read receipts (currently partial support, only marks last message)
   * [x] Delivery receipts (sent after message is bridged)
   * [x] Delivery receipts (sent after message is bridged)
@@ -40,6 +45,11 @@
     * [ ] Real time
     * [ ] Real time
       * [x] Groups
       * [x] Groups
       * [ ] Users
       * [ ] Users
+  * [ ] Membership actions
+    * [x] Join
+    * [x] Invite
+    * [ ] Request join (via invite link)
+    * [ ] Kick / leave
   * [x] Group permissions
   * [x] Group permissions
   * [x] Typing notifications
   * [x] Typing notifications
   * [x] Read receipts
   * [x] Read receipts

+ 8 - 0
mausignald/errors.py

@@ -54,6 +54,12 @@ class ResponseError(RPCError):
         super().__init__(message_override or data["message"])
         super().__init__(message_override or data["message"])
 
 
 
 
+class UnknownResponseError(ResponseError):
+    def __init__(self, message: str) -> None:
+        self.data = {}
+        super(RPCError, self).__init__(message)
+
+
 class InvalidRequest(ResponseError):
 class InvalidRequest(ResponseError):
     def __init__(self, data: Dict[str, Any]) -> None:
     def __init__(self, data: Dict[str, Any]) -> None:
         super().__init__(data, ", ".join(data.get("validationResults", "")))
         super().__init__(data, ", ".join(data.get("validationResults", "")))
@@ -65,4 +71,6 @@ response_error_types = {
 
 
 
 
 def make_response_error(data: Dict[str, Any]) -> ResponseError:
 def make_response_error(data: Dict[str, Any]) -> ResponseError:
+    if isinstance(data, str):
+        return UnknownResponseError(data)
     return response_error_types.get(data["type"], ResponseError)(data)
     return response_error_types.get(data["type"], ResponseError)(data)

+ 9 - 4
mausignald/rpc.py

@@ -11,7 +11,8 @@ import json
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from .errors import NotConnected, UnexpectedError, UnexpectedResponse, make_response_error
+from .errors import (NotConnected, UnexpectedError, UnexpectedResponse, RPCError,
+                     make_response_error)
 
 
 EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
 EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
 
 
@@ -101,12 +102,13 @@ class SignaldRPCClient:
                 except Exception:
                 except Exception:
                     self.log.exception("Exception in RPC event handler")
                     self.log.exception("Exception in RPC event handler")
 
 
-    def _run_response_handlers(self, req_id: UUID, command: str, data: Any) -> None:
+    def _run_response_handlers(self, req_id: UUID, command: str, req: Any) -> None:
         try:
         try:
             waiter = self._response_waiters.pop(req_id)
             waiter = self._response_waiters.pop(req_id)
         except KeyError:
         except KeyError:
             self.log.debug(f"Nobody waiting for response to {req_id}")
             self.log.debug(f"Nobody waiting for response to {req_id}")
             return
             return
+        data = req.get("data")
         if command == "unexpected_error":
         if command == "unexpected_error":
             try:
             try:
                 waiter.set_exception(UnexpectedError(data["message"]))
                 waiter.set_exception(UnexpectedError(data["message"]))
@@ -114,6 +116,8 @@ class SignaldRPCClient:
                 waiter.set_exception(UnexpectedError("Unexpected error with no message"))
                 waiter.set_exception(UnexpectedError("Unexpected error with no message"))
         elif data and "error" in data:
         elif data and "error" in data:
             waiter.set_exception(make_response_error(data["error"]))
             waiter.set_exception(make_response_error(data["error"]))
+        elif "error" in req:
+            waiter.set_exception(make_response_error(req["error"]))
         else:
         else:
             waiter.set_result((command, data))
             waiter.set_result((command, data))
 
 
@@ -135,7 +139,7 @@ class SignaldRPCClient:
         if req_id is None:
         if req_id is None:
             self.loop.create_task(self._run_rpc_handler(req_type, req))
             self.loop.create_task(self._run_rpc_handler(req_type, req))
         else:
         else:
-            self._run_response_handlers(UUID(req_id), req_type, req.get("data"))
+            self._run_response_handlers(UUID(req_id), req_type, req)
 
 
     async def _try_read_loop(self) -> None:
     async def _try_read_loop(self) -> None:
         try:
         try:
@@ -179,7 +183,8 @@ class SignaldRPCClient:
         for req_id, waiter in self._response_waiters.items():
         for req_id, waiter in self._response_waiters.items():
             if not waiter.done():
             if not waiter.done():
                 self.log.trace(f"Abandoning response for {req_id}")
                 self.log.trace(f"Abandoning response for {req_id}")
-                waiter.set_exception(NotConnected("Disconnected from signald before RPC completed"))
+                waiter.set_exception(
+                    NotConnected("Disconnected from signald before RPC completed"))
 
 
     async def _send_request(self, data: Dict[str, Any]) -> None:
     async def _send_request(self, data: Dict[str, Any]) -> None:
         if self._writer is None:
         if self._writer is None:

+ 13 - 2
mausignald/signald.py

@@ -197,6 +197,11 @@ class SignaldClient(SignaldRPCClient):
         else:
         else:
             return None
             return None
 
 
+    async def accept_invitation(self, username: str, group_id: GroupID) -> GroupV2:
+        resp = await self.request("accept_invitation", "accept_invitation", version="v1",
+                                  account=username, groupID=group_id)
+        return GroupV2.deserialize(resp)
+
     async def get_group(self, username: str, group_id: GroupID, revision: int = -1
     async def get_group(self, username: str, group_id: GroupID, revision: int = -1
                         ) -> Optional[GroupV2]:
                         ) -> Optional[GroupV2]:
         resp = await self.request("get_group", "get_group", account=username, groupID=group_id,
         resp = await self.request("get_group", "get_group", account=username, groupID=group_id,
@@ -220,5 +225,11 @@ class SignaldClient(SignaldRPCClient):
                                   recipientAddress=address.serialize())
                                   recipientAddress=address.serialize())
         return GetIdentitiesResponse.deserialize(resp)
         return GetIdentitiesResponse.deserialize(resp)
 
 
-    async def set_profile(self, username: str, new_name: str) -> None:
-        await self.request("set_profile", "profile_set", username=username, name=new_name)
+    async def set_profile(self, username: str, name: Optional[str] = None,
+                          avatar_path: Optional[str] = None) -> None:
+        args = {}
+        if name is not None:
+            args["name"] = name
+        if avatar_path is not None:
+            args["avatarFile"] = avatar_path
+        await self.request("set_profile", "set_profile", account=username, version="v1", **args)

+ 3 - 1
mautrix_signal/commands/auth.py

@@ -102,7 +102,9 @@ async def enter_register_code(evt: CommandEvent) -> None:
             raise
             raise
     else:
     else:
         await evt.sender.on_signin(account)
         await evt.sender.on_signin(account)
-        await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
+        await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}."
+                        f"\n\n**N.B.** You must set a Signal profile name with `$cmdprefix+sp "
+                        f"set-profile-name <name>` before you can participate in new groups.")
 
 
 
 
 @command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,
 @command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,

+ 7 - 0
mautrix_signal/commands/signal.py

@@ -115,6 +115,13 @@ async def safety_number(evt: CommandEvent) -> None:
         await evt.main_intent.send_message(evt.room_id, content)
         await evt.main_intent.send_message(evt.room_id, content)
 
 
 
 
+@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
+                 help_text="Set your Signal profile name", help_args="<_name_>")
+async def set_profile_name(evt: CommandEvent) -> None:
+    await evt.bridge.signal.set_profile(evt.sender.username, name=" ".join(evt.args))
+    await evt.reply("Successfully updated profile name")
+
+
 @command_handler(needs_admin=False, needs_auth=True, help_section=SECTION_SIGNAL,
 @command_handler(needs_admin=False, needs_auth=True, help_section=SECTION_SIGNAL,
                  help_text="Sync data from Signal")
                  help_text="Sync data from Signal")
 async def sync(evt: CommandEvent) -> None:
 async def sync(evt: CommandEvent) -> None:

+ 8 - 0
mautrix_signal/db/user.py

@@ -58,6 +58,14 @@ class User:
             return None
             return None
         return cls(**row)
         return cls(**row)
 
 
+    @classmethod
+    async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
+        q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE uuid=$1'
+        row = await cls.db.fetchrow(q, uuid)
+        if not row:
+            return None
+        return cls(**row)
+
     @classmethod
     @classmethod
     async def all_logged_in(cls) -> List['User']:
     async def all_logged_in(cls) -> List['User']:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username IS NOT NULL'
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username IS NOT NULL'

+ 11 - 0
mautrix_signal/matrix.py

@@ -68,6 +68,17 @@ class MatrixHandler(BaseMatrixHandler):
 
 
         await portal.handle_matrix_leave(user)
         await portal.handle_matrix_leave(user)
 
 
+    async def handle_join(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None:
+        portal = await po.Portal.get_by_mxid(room_id)
+        if not portal:
+            return
+
+        user = await u.User.get_by_mxid(user_id, create=False)
+        if not user:
+            return
+
+        await portal.handle_matrix_join(user)
+
     @classmethod
     @classmethod
     async def handle_reaction(cls, room_id: RoomID, user_id: UserID, event_id: EventID,
     async def handle_reaction(cls, room_id: RoomID, user_id: UserID, event_id: EventID,
                               content: ReactionEventContent) -> None:
                               content: ReactionEventContent) -> None:

+ 54 - 16
mautrix_signal/portal.py

@@ -13,7 +13,7 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import (Dict, Tuple, Optional, List, Deque, Any, Union, AsyncGenerator, Awaitable,
+from typing import (Dict, Tuple, Optional, List, Deque, Any, Union, AsyncGenerator, Awaitable, Set,
                     Callable, TYPE_CHECKING, cast)
                     Callable, TYPE_CHECKING, cast)
 from collections import deque
 from collections import deque
 from uuid import UUID, uuid4
 from uuid import UUID, uuid4
@@ -27,12 +27,13 @@ import os
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
                               Attachment, GroupID, GroupV2ID, GroupV2, Mention, Sticker,
                               Attachment, GroupID, GroupV2ID, GroupV2, Mention, Sticker,
                               GroupAccessControl, AccessControlMode, GroupMemberRole)
                               GroupAccessControl, AccessControlMode, GroupMemberRole)
+from mausignald.errors import RPCError
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
                            MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
                            MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
                            ImageInfo, VideoInfo, FileInfo, AudioInfo, PowerLevelStateEventContent)
                            ImageInfo, VideoInfo, FileInfo, AudioInfo, PowerLevelStateEventContent)
-from mautrix.errors import MatrixError, MForbidden
+from mautrix.errors import MatrixError, MForbidden, IntentError
 
 
 from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
 from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
 from .config import Config
 from .config import Config
@@ -78,6 +79,7 @@ class Portal(DBPortal, BasePortal):
     _msgts_dedup: Deque[Tuple[Address, int]]
     _msgts_dedup: Deque[Tuple[Address, int]]
     _reaction_dedup: Deque[Tuple[Address, int, str]]
     _reaction_dedup: Deque[Tuple[Address, int, str]]
     _reaction_lock: asyncio.Lock
     _reaction_lock: asyncio.Lock
+    _pending_members: Optional[Set[UUID]]
 
 
     def __init__(self, chat_id: Union[GroupID, Address], receiver: str,
     def __init__(self, chat_id: Union[GroupID, Address], receiver: str,
                  mxid: Optional[RoomID] = None, name: Optional[str] = None,
                  mxid: Optional[RoomID] = None, name: Optional[str] = None,
@@ -93,6 +95,7 @@ class Portal(DBPortal, BasePortal):
         self._reaction_dedup = deque(maxlen=100)
         self._reaction_dedup = deque(maxlen=100)
         self._last_participant_update = set()
         self._last_participant_update = set()
         self._reaction_lock = asyncio.Lock()
         self._reaction_lock = asyncio.Lock()
+        self._pending_members = None
 
 
     @property
     @property
     def main_intent(self) -> IntentAPI:
     def main_intent(self) -> IntentAPI:
@@ -269,6 +272,25 @@ class Portal(DBPortal, BasePortal):
             except Exception:
             except Exception:
                 self.log.exception("Removing reaction failed")
                 self.log.exception("Removing reaction failed")
 
 
+    async def handle_matrix_join(self, user: 'u.User') -> None:
+        if self._pending_members is None:
+            self.log.debug(f"{user.mxid} ({user.uuid}) joined room, but pending_members is None,"
+                           " updating chat info")
+            await self.update_info(user, GroupV2ID(id=self.chat_id))
+        if self._pending_members is None:
+            self.log.warning(f"Didn't get pending member list after info update, "
+                             f"{user.mxid} ({user.uuid}) may not be in the group on Signal.")
+        elif user.uuid in self._pending_members:
+            self.log.debug(f"{user.mxid} ({user.uuid}) joined room, accepting invite on Signal")
+            try:
+                resp = await self.signal.accept_invitation(user.username, self.chat_id)
+                self._pending_members.remove(user.uuid)
+            except RPCError as e:
+                await self.main_intent.send_notice(self.mxid, "\u26a0 Failed to accept invite "
+                                                              f"on Signal: {e}")
+            else:
+                await self.update_info(user, resp)
+
     async def handle_matrix_leave(self, user: 'u.User') -> None:
     async def handle_matrix_leave(self, user: 'u.User') -> None:
         if self.is_direct:
         if self.is_direct:
             self.log.info(f"{user.mxid} left private chat portal with {self.chat_id}")
             self.log.info(f"{user.mxid} left private chat portal with {self.chat_id}")
@@ -562,7 +584,7 @@ class Portal(DBPortal, BasePortal):
             if existing:
             if existing:
                 try:
                 try:
                     await sender.intent_for(self).redact(existing.mx_room, existing.mxid)
                     await sender.intent_for(self).redact(existing.mx_room, existing.mxid)
-                except MForbidden:
+                except IntentError:
                     await self.main_intent.redact(existing.mx_room, existing.mxid)
                     await self.main_intent.redact(existing.mx_room, existing.mxid)
                 await existing.delete()
                 await existing.delete()
                 self.log.trace(f"Removed {existing} after Signal removal")
                 self.log.trace(f"Removed {existing} after Signal removal")
@@ -632,7 +654,7 @@ class Portal(DBPortal, BasePortal):
         else:
         else:
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
         changed = await self._update_avatar(info, sender) or changed
         changed = await self._update_avatar(info, sender) or changed
-        await self._update_participants(source, info.members)
+        await self._update_participants(source, info)
         try:
         try:
             await self._update_power_levels(info)
             await self._update_power_levels(info)
         except Exception:
         except Exception:
@@ -687,7 +709,7 @@ class Portal(DBPortal, BasePortal):
         if puppet:
         if puppet:
             try:
             try:
                 await action(puppet.intent_for(self))
                 await action(puppet.intent_for(self))
-            except MForbidden:
+            except (MForbidden, IntentError):
                 await action(self.main_intent)
                 await action(self.main_intent)
         else:
         else:
             await action(self.main_intent)
             await action(self.main_intent)
@@ -714,17 +736,33 @@ class Portal(DBPortal, BasePortal):
             self.avatar_set = False
             self.avatar_set = False
         return True
         return True
 
 
-    async def _update_participants(self, source: 'u.User', participants: List[Address]) -> None:
-        # TODO add support for pending_members and maybe requesting_members?
-        if not self.mxid or not participants:
+    async def _update_participants(self, source: 'u.User', info: ChatInfo) -> None:
+        if not self.mxid or not isinstance(info, (Group, GroupV2)):
             return
             return
 
 
-        for address in participants:
+        pending_members = info.pending_members if isinstance(info, GroupV2) else []
+        self._pending_members = {addr.uuid for addr in pending_members}
+
+        for address in info.members:
+            user = await u.User.get_by_address(address)
+            if user:
+                await self.main_intent.invite_user(self.mxid, user.mxid)
+
             puppet = await p.Puppet.get_by_address(address)
             puppet = await p.Puppet.get_by_address(address)
             if not puppet.name:
             if not puppet.name:
                 await source.sync_contact(address)
                 await source.sync_contact(address)
             await puppet.intent_for(self).ensure_joined(self.mxid)
             await puppet.intent_for(self).ensure_joined(self.mxid)
 
 
+        for address in pending_members:
+            user = await u.User.get_by_address(address)
+            if user:
+                await self.main_intent.invite_user(self.mxid, user.mxid)
+
+            puppet = await p.Puppet.get_by_address(address)
+            if not puppet.name:
+                await source.sync_contact(address)
+            await self.main_intent.invite_user(self.mxid, puppet.intent_for(self).mxid)
+
     async def _update_power_levels(self, info: ChatInfo) -> None:
     async def _update_power_levels(self, info: ChatInfo) -> None:
         if not self.mxid:
         if not self.mxid:
             return
             return
@@ -806,11 +844,11 @@ class Portal(DBPortal, BasePortal):
             return await self._create_matrix_room(source, info)
             return await self._create_matrix_room(source, info)
 
 
     async def _update_matrix_room(self, source: 'u.User', info: ChatInfo) -> None:
     async def _update_matrix_room(self, source: 'u.User', info: ChatInfo) -> None:
-        await self.main_intent.invite_user(self.mxid, source.mxid, check_cache=True)
-        puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
-        if puppet:
-            did_join = await puppet.intent.ensure_joined(self.mxid)
-            if did_join and self.is_direct:
+        if self.is_direct:
+            puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
+            if puppet:
+                await self.main_intent.invite_user(self.mxid, source.mxid, check_cache=True)
+                await puppet.intent.ensure_joined(self.mxid)
                 await source.update_direct_chats({self.main_intent.mxid: [self.mxid]})
                 await source.update_direct_chats({self.main_intent.mxid: [self.mxid]})
 
 
         await self.update_info(source, info)
         await self.update_info(source, info)
@@ -828,7 +866,7 @@ class Portal(DBPortal, BasePortal):
         else:
         else:
             if isinstance(info, GroupV2):
             if isinstance(info, GroupV2):
                 ac = info.access_control
                 ac = info.access_control
-                for detail in info.member_detail:
+                for detail in info.member_detail + info.pending_member_detail:
                     puppet = await p.Puppet.get_by_address(Address(uuid=detail.uuid))
                     puppet = await p.Puppet.get_by_address(Address(uuid=detail.uuid))
                     level = 50 if detail.role == GroupMemberRole.ADMINISTRATOR else 0
                     level = 50 if detail.role == GroupMemberRole.ADMINISTRATOR else 0
                     levels.users[puppet.intent_for(self).mxid] = level
                     levels.users[puppet.intent_for(self).mxid] = level
@@ -916,7 +954,7 @@ class Portal(DBPortal, BasePortal):
         self.log.debug(f"Matrix room created: {self.mxid}")
         self.log.debug(f"Matrix room created: {self.mxid}")
         self.by_mxid[self.mxid] = self
         self.by_mxid[self.mxid] = self
         if not self.is_direct:
         if not self.is_direct:
-            await self._update_participants(source, info.members)
+            await self._update_participants(source, info)
         else:
         else:
             puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
             puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
             if puppet:
             if puppet:

+ 1 - 0
mautrix_signal/puppet.py

@@ -128,6 +128,7 @@ class Puppet(DBPuppet, BasePuppet):
         user = await u.User.get_by_username(self.number)
         user = await u.User.get_by_username(self.number)
         if user and not user.uuid:
         if user and not user.uuid:
             user.uuid = self.uuid
             user.uuid = self.uuid
+            user.by_uuid[user.uuid] = user
             await user.update()
             await user.update()
         await self._set_uuid(uuid)
         await self._set_uuid(uuid)
         self.by_uuid[self.uuid] = self
         self.by_uuid[self.uuid] = self

+ 35 - 0
mautrix_signal/user.py

@@ -39,6 +39,7 @@ METRIC_CONNECTED = Gauge('bridge_connected', 'Bridge users connected to Signal')
 class User(DBUser, BaseUser):
 class User(DBUser, BaseUser):
     by_mxid: Dict[UserID, 'User'] = {}
     by_mxid: Dict[UserID, 'User'] = {}
     by_username: Dict[str, 'User'] = {}
     by_username: Dict[str, 'User'] = {}
+    by_uuid: Dict[UUID, 'User'] = {}
     config: Config
     config: Config
     az: AppService
     az: AppService
     loop: asyncio.AbstractEventLoop
     loop: asyncio.AbstractEventLoop
@@ -80,6 +81,10 @@ class User(DBUser, BaseUser):
         if not self.username:
         if not self.username:
             return
             return
         username = self.username
         username = self.username
+        if self.uuid and self.by_uuid.get(self.uuid) == self:
+            del self.by_uuid[self.uuid]
+        if self.username and self.by_username.get(self.username) == self:
+            del self.by_username[self.username]
         self.username = None
         self.username = None
         self.uuid = None
         self.uuid = None
         await self.update()
         await self.update()
@@ -99,6 +104,7 @@ class User(DBUser, BaseUser):
     async def on_signin(self, account: Account) -> None:
     async def on_signin(self, account: Account) -> None:
         self.username = account.username
         self.username = account.username
         self.uuid = account.uuid
         self.uuid = account.uuid
+        self._add_to_cache()
         await self.update()
         await self.update()
         await self.bridge.signal.subscribe(self.username)
         await self.bridge.signal.subscribe(self.username)
         self.loop.create_task(self.sync())
         self.loop.create_task(self.sync())
@@ -118,6 +124,9 @@ class User(DBUser, BaseUser):
 
 
     async def _sync_puppet(self) -> None:
     async def _sync_puppet(self) -> None:
         puppet = await pu.Puppet.get_by_address(self.address)
         puppet = await pu.Puppet.get_by_address(self.address)
+        if puppet.uuid and not self.uuid:
+            self.uuid = puppet.uuid
+            self.by_uuid[self.uuid] = self
         if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
         if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
             self.log.info(f"Automatically enabling custom puppet")
             self.log.info(f"Automatically enabling custom puppet")
             await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
             await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
@@ -196,6 +205,8 @@ class User(DBUser, BaseUser):
         self.by_mxid[self.mxid] = self
         self.by_mxid[self.mxid] = self
         if self.username:
         if self.username:
             self.by_username[self.username] = self
             self.by_username[self.username] = self
+        if self.uuid:
+            self.by_uuid[self.uuid] = self
 
 
     @classmethod
     @classmethod
     @async_getter_lock
     @async_getter_lock
@@ -236,6 +247,30 @@ class User(DBUser, BaseUser):
 
 
         return None
         return None
 
 
+    @classmethod
+    @async_getter_lock
+    async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
+        try:
+            return cls.by_uuid[uuid]
+        except KeyError:
+            pass
+
+        user = cast(cls, await super().get_by_uuid(uuid))
+        if user is not None:
+            user._add_to_cache()
+            return user
+
+        return None
+
+    @classmethod
+    async def get_by_address(cls, address: Address) -> Optional['User']:
+        if address.uuid:
+            return await cls.get_by_uuid(address.uuid)
+        elif address.number:
+            return await cls.get_by_username(address.number)
+        else:
+            raise ValueError("Given address is blank")
+
     @classmethod
     @classmethod
     async def all_logged_in(cls) -> AsyncGenerator['User', None]:
     async def all_logged_in(cls) -> AsyncGenerator['User', None]:
         users = await super().all_logged_in()
         users = await super().all_logged_in()