Parcourir la source

Add support for accepting Signal group invites

Tulir Asokan il y a 4 ans
Parent
commit
90017f3c60

+ 10 - 0
ROADMAP.md

@@ -17,6 +17,11 @@
   * [x] Group info changes
     * [x] Name
     * [x] Avatar
+  * [ ] Membership actions
+    * [x] Join (accept invite)
+    * [ ] Invite
+    * [ ] Leave
+    * [ ] Kick
   * [ ] Typing notifications
   * [ ] Read receipts (currently partial support, only marks last message)
   * [x] Delivery receipts (sent after message is bridged)
@@ -40,6 +45,11 @@
     * [ ] Real time
       * [x] Groups
       * [ ] Users
+  * [ ] Membership actions
+    * [x] Join
+    * [x] Invite
+    * [ ] Request join (via invite link)
+    * [ ] Kick / leave
   * [x] Group permissions
   * [x] Typing notifications
   * [x] Read receipts

+ 8 - 0
mausignald/errors.py

@@ -54,6 +54,12 @@ class ResponseError(RPCError):
         super().__init__(message_override or data["message"])
 
 
+class UnknownResponseError(ResponseError):
+    def __init__(self, message: str) -> None:
+        self.data = {}
+        super(RPCError, self).__init__(message)
+
+
 class InvalidRequest(ResponseError):
     def __init__(self, data: Dict[str, Any]) -> None:
         super().__init__(data, ", ".join(data.get("validationResults", "")))
@@ -65,4 +71,6 @@ response_error_types = {
 
 
 def make_response_error(data: Dict[str, Any]) -> ResponseError:
+    if isinstance(data, str):
+        return UnknownResponseError(data)
     return response_error_types.get(data["type"], ResponseError)(data)

+ 9 - 4
mausignald/rpc.py

@@ -11,7 +11,8 @@ import json
 
 from mautrix.util.logging import TraceLogger
 
-from .errors import NotConnected, UnexpectedError, UnexpectedResponse, make_response_error
+from .errors import (NotConnected, UnexpectedError, UnexpectedResponse, RPCError,
+                     make_response_error)
 
 EventHandler = Callable[[Dict[str, Any]], Awaitable[None]]
 
@@ -101,12 +102,13 @@ class SignaldRPCClient:
                 except Exception:
                     self.log.exception("Exception in RPC event handler")
 
-    def _run_response_handlers(self, req_id: UUID, command: str, data: Any) -> None:
+    def _run_response_handlers(self, req_id: UUID, command: str, req: Any) -> None:
         try:
             waiter = self._response_waiters.pop(req_id)
         except KeyError:
             self.log.debug(f"Nobody waiting for response to {req_id}")
             return
+        data = req.get("data")
         if command == "unexpected_error":
             try:
                 waiter.set_exception(UnexpectedError(data["message"]))
@@ -114,6 +116,8 @@ class SignaldRPCClient:
                 waiter.set_exception(UnexpectedError("Unexpected error with no message"))
         elif data and "error" in data:
             waiter.set_exception(make_response_error(data["error"]))
+        elif "error" in req:
+            waiter.set_exception(make_response_error(req["error"]))
         else:
             waiter.set_result((command, data))
 
@@ -135,7 +139,7 @@ class SignaldRPCClient:
         if req_id is None:
             self.loop.create_task(self._run_rpc_handler(req_type, req))
         else:
-            self._run_response_handlers(UUID(req_id), req_type, req.get("data"))
+            self._run_response_handlers(UUID(req_id), req_type, req)
 
     async def _try_read_loop(self) -> None:
         try:
@@ -179,7 +183,8 @@ class SignaldRPCClient:
         for req_id, waiter in self._response_waiters.items():
             if not waiter.done():
                 self.log.trace(f"Abandoning response for {req_id}")
-                waiter.set_exception(NotConnected("Disconnected from signald before RPC completed"))
+                waiter.set_exception(
+                    NotConnected("Disconnected from signald before RPC completed"))
 
     async def _send_request(self, data: Dict[str, Any]) -> None:
         if self._writer is None:

+ 13 - 2
mausignald/signald.py

@@ -197,6 +197,11 @@ class SignaldClient(SignaldRPCClient):
         else:
             return None
 
+    async def accept_invitation(self, username: str, group_id: GroupID) -> GroupV2:
+        resp = await self.request("accept_invitation", "accept_invitation", version="v1",
+                                  account=username, groupID=group_id)
+        return GroupV2.deserialize(resp)
+
     async def get_group(self, username: str, group_id: GroupID, revision: int = -1
                         ) -> Optional[GroupV2]:
         resp = await self.request("get_group", "get_group", account=username, groupID=group_id,
@@ -220,5 +225,11 @@ class SignaldClient(SignaldRPCClient):
                                   recipientAddress=address.serialize())
         return GetIdentitiesResponse.deserialize(resp)
 
-    async def set_profile(self, username: str, new_name: str) -> None:
-        await self.request("set_profile", "profile_set", username=username, name=new_name)
+    async def set_profile(self, username: str, name: Optional[str] = None,
+                          avatar_path: Optional[str] = None) -> None:
+        args = {}
+        if name is not None:
+            args["name"] = name
+        if avatar_path is not None:
+            args["avatarFile"] = avatar_path
+        await self.request("set_profile", "set_profile", account=username, version="v1", **args)

+ 3 - 1
mautrix_signal/commands/auth.py

@@ -102,7 +102,9 @@ async def enter_register_code(evt: CommandEvent) -> None:
             raise
     else:
         await evt.sender.on_signin(account)
-        await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
+        await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}."
+                        f"\n\n**N.B.** You must set a Signal profile name with `$cmdprefix+sp "
+                        f"set-profile-name <name>` before you can participate in new groups.")
 
 
 @command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,

+ 7 - 0
mautrix_signal/commands/signal.py

@@ -115,6 +115,13 @@ async def safety_number(evt: CommandEvent) -> None:
         await evt.main_intent.send_message(evt.room_id, content)
 
 
+@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
+                 help_text="Set your Signal profile name", help_args="<_name_>")
+async def set_profile_name(evt: CommandEvent) -> None:
+    await evt.bridge.signal.set_profile(evt.sender.username, name=" ".join(evt.args))
+    await evt.reply("Successfully updated profile name")
+
+
 @command_handler(needs_admin=False, needs_auth=True, help_section=SECTION_SIGNAL,
                  help_text="Sync data from Signal")
 async def sync(evt: CommandEvent) -> None:

+ 8 - 0
mautrix_signal/db/user.py

@@ -58,6 +58,14 @@ class User:
             return None
         return cls(**row)
 
+    @classmethod
+    async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
+        q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE uuid=$1'
+        row = await cls.db.fetchrow(q, uuid)
+        if not row:
+            return None
+        return cls(**row)
+
     @classmethod
     async def all_logged_in(cls) -> List['User']:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username IS NOT NULL'

+ 11 - 0
mautrix_signal/matrix.py

@@ -68,6 +68,17 @@ class MatrixHandler(BaseMatrixHandler):
 
         await portal.handle_matrix_leave(user)
 
+    async def handle_join(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None:
+        portal = await po.Portal.get_by_mxid(room_id)
+        if not portal:
+            return
+
+        user = await u.User.get_by_mxid(user_id, create=False)
+        if not user:
+            return
+
+        await portal.handle_matrix_join(user)
+
     @classmethod
     async def handle_reaction(cls, room_id: RoomID, user_id: UserID, event_id: EventID,
                               content: ReactionEventContent) -> None:

+ 54 - 16
mautrix_signal/portal.py

@@ -13,7 +13,7 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import (Dict, Tuple, Optional, List, Deque, Any, Union, AsyncGenerator, Awaitable,
+from typing import (Dict, Tuple, Optional, List, Deque, Any, Union, AsyncGenerator, Awaitable, Set,
                     Callable, TYPE_CHECKING, cast)
 from collections import deque
 from uuid import UUID, uuid4
@@ -27,12 +27,13 @@ import os
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
                               Attachment, GroupID, GroupV2ID, GroupV2, Mention, Sticker,
                               GroupAccessControl, AccessControlMode, GroupMemberRole)
+from mausignald.errors import RPCError
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
                            MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
                            ImageInfo, VideoInfo, FileInfo, AudioInfo, PowerLevelStateEventContent)
-from mautrix.errors import MatrixError, MForbidden
+from mautrix.errors import MatrixError, MForbidden, IntentError
 
 from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
 from .config import Config
@@ -78,6 +79,7 @@ class Portal(DBPortal, BasePortal):
     _msgts_dedup: Deque[Tuple[Address, int]]
     _reaction_dedup: Deque[Tuple[Address, int, str]]
     _reaction_lock: asyncio.Lock
+    _pending_members: Optional[Set[UUID]]
 
     def __init__(self, chat_id: Union[GroupID, Address], receiver: str,
                  mxid: Optional[RoomID] = None, name: Optional[str] = None,
@@ -93,6 +95,7 @@ class Portal(DBPortal, BasePortal):
         self._reaction_dedup = deque(maxlen=100)
         self._last_participant_update = set()
         self._reaction_lock = asyncio.Lock()
+        self._pending_members = None
 
     @property
     def main_intent(self) -> IntentAPI:
@@ -269,6 +272,25 @@ class Portal(DBPortal, BasePortal):
             except Exception:
                 self.log.exception("Removing reaction failed")
 
+    async def handle_matrix_join(self, user: 'u.User') -> None:
+        if self._pending_members is None:
+            self.log.debug(f"{user.mxid} ({user.uuid}) joined room, but pending_members is None,"
+                           " updating chat info")
+            await self.update_info(user, GroupV2ID(id=self.chat_id))
+        if self._pending_members is None:
+            self.log.warning(f"Didn't get pending member list after info update, "
+                             f"{user.mxid} ({user.uuid}) may not be in the group on Signal.")
+        elif user.uuid in self._pending_members:
+            self.log.debug(f"{user.mxid} ({user.uuid}) joined room, accepting invite on Signal")
+            try:
+                resp = await self.signal.accept_invitation(user.username, self.chat_id)
+                self._pending_members.remove(user.uuid)
+            except RPCError as e:
+                await self.main_intent.send_notice(self.mxid, "\u26a0 Failed to accept invite "
+                                                              f"on Signal: {e}")
+            else:
+                await self.update_info(user, resp)
+
     async def handle_matrix_leave(self, user: 'u.User') -> None:
         if self.is_direct:
             self.log.info(f"{user.mxid} left private chat portal with {self.chat_id}")
@@ -562,7 +584,7 @@ class Portal(DBPortal, BasePortal):
             if existing:
                 try:
                     await sender.intent_for(self).redact(existing.mx_room, existing.mxid)
-                except MForbidden:
+                except IntentError:
                     await self.main_intent.redact(existing.mx_room, existing.mxid)
                 await existing.delete()
                 self.log.trace(f"Removed {existing} after Signal removal")
@@ -632,7 +654,7 @@ class Portal(DBPortal, BasePortal):
         else:
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
         changed = await self._update_avatar(info, sender) or changed
-        await self._update_participants(source, info.members)
+        await self._update_participants(source, info)
         try:
             await self._update_power_levels(info)
         except Exception:
@@ -687,7 +709,7 @@ class Portal(DBPortal, BasePortal):
         if puppet:
             try:
                 await action(puppet.intent_for(self))
-            except MForbidden:
+            except (MForbidden, IntentError):
                 await action(self.main_intent)
         else:
             await action(self.main_intent)
@@ -714,17 +736,33 @@ class Portal(DBPortal, BasePortal):
             self.avatar_set = False
         return True
 
-    async def _update_participants(self, source: 'u.User', participants: List[Address]) -> None:
-        # TODO add support for pending_members and maybe requesting_members?
-        if not self.mxid or not participants:
+    async def _update_participants(self, source: 'u.User', info: ChatInfo) -> None:
+        if not self.mxid or not isinstance(info, (Group, GroupV2)):
             return
 
-        for address in participants:
+        pending_members = info.pending_members if isinstance(info, GroupV2) else []
+        self._pending_members = {addr.uuid for addr in pending_members}
+
+        for address in info.members:
+            user = await u.User.get_by_address(address)
+            if user:
+                await self.main_intent.invite_user(self.mxid, user.mxid)
+
             puppet = await p.Puppet.get_by_address(address)
             if not puppet.name:
                 await source.sync_contact(address)
             await puppet.intent_for(self).ensure_joined(self.mxid)
 
+        for address in pending_members:
+            user = await u.User.get_by_address(address)
+            if user:
+                await self.main_intent.invite_user(self.mxid, user.mxid)
+
+            puppet = await p.Puppet.get_by_address(address)
+            if not puppet.name:
+                await source.sync_contact(address)
+            await self.main_intent.invite_user(self.mxid, puppet.intent_for(self).mxid)
+
     async def _update_power_levels(self, info: ChatInfo) -> None:
         if not self.mxid:
             return
@@ -806,11 +844,11 @@ class Portal(DBPortal, BasePortal):
             return await self._create_matrix_room(source, info)
 
     async def _update_matrix_room(self, source: 'u.User', info: ChatInfo) -> None:
-        await self.main_intent.invite_user(self.mxid, source.mxid, check_cache=True)
-        puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
-        if puppet:
-            did_join = await puppet.intent.ensure_joined(self.mxid)
-            if did_join and self.is_direct:
+        if self.is_direct:
+            puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
+            if puppet:
+                await self.main_intent.invite_user(self.mxid, source.mxid, check_cache=True)
+                await puppet.intent.ensure_joined(self.mxid)
                 await source.update_direct_chats({self.main_intent.mxid: [self.mxid]})
 
         await self.update_info(source, info)
@@ -828,7 +866,7 @@ class Portal(DBPortal, BasePortal):
         else:
             if isinstance(info, GroupV2):
                 ac = info.access_control
-                for detail in info.member_detail:
+                for detail in info.member_detail + info.pending_member_detail:
                     puppet = await p.Puppet.get_by_address(Address(uuid=detail.uuid))
                     level = 50 if detail.role == GroupMemberRole.ADMINISTRATOR else 0
                     levels.users[puppet.intent_for(self).mxid] = level
@@ -916,7 +954,7 @@ class Portal(DBPortal, BasePortal):
         self.log.debug(f"Matrix room created: {self.mxid}")
         self.by_mxid[self.mxid] = self
         if not self.is_direct:
-            await self._update_participants(source, info.members)
+            await self._update_participants(source, info)
         else:
             puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
             if puppet:

+ 1 - 0
mautrix_signal/puppet.py

@@ -128,6 +128,7 @@ class Puppet(DBPuppet, BasePuppet):
         user = await u.User.get_by_username(self.number)
         if user and not user.uuid:
             user.uuid = self.uuid
+            user.by_uuid[user.uuid] = user
             await user.update()
         await self._set_uuid(uuid)
         self.by_uuid[self.uuid] = self

+ 35 - 0
mautrix_signal/user.py

@@ -39,6 +39,7 @@ METRIC_CONNECTED = Gauge('bridge_connected', 'Bridge users connected to Signal')
 class User(DBUser, BaseUser):
     by_mxid: Dict[UserID, 'User'] = {}
     by_username: Dict[str, 'User'] = {}
+    by_uuid: Dict[UUID, 'User'] = {}
     config: Config
     az: AppService
     loop: asyncio.AbstractEventLoop
@@ -80,6 +81,10 @@ class User(DBUser, BaseUser):
         if not self.username:
             return
         username = self.username
+        if self.uuid and self.by_uuid.get(self.uuid) == self:
+            del self.by_uuid[self.uuid]
+        if self.username and self.by_username.get(self.username) == self:
+            del self.by_username[self.username]
         self.username = None
         self.uuid = None
         await self.update()
@@ -99,6 +104,7 @@ class User(DBUser, BaseUser):
     async def on_signin(self, account: Account) -> None:
         self.username = account.username
         self.uuid = account.uuid
+        self._add_to_cache()
         await self.update()
         await self.bridge.signal.subscribe(self.username)
         self.loop.create_task(self.sync())
@@ -118,6 +124,9 @@ class User(DBUser, BaseUser):
 
     async def _sync_puppet(self) -> None:
         puppet = await pu.Puppet.get_by_address(self.address)
+        if puppet.uuid and not self.uuid:
+            self.uuid = puppet.uuid
+            self.by_uuid[self.uuid] = self
         if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
             self.log.info(f"Automatically enabling custom puppet")
             await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
@@ -196,6 +205,8 @@ class User(DBUser, BaseUser):
         self.by_mxid[self.mxid] = self
         if self.username:
             self.by_username[self.username] = self
+        if self.uuid:
+            self.by_uuid[self.uuid] = self
 
     @classmethod
     @async_getter_lock
@@ -236,6 +247,30 @@ class User(DBUser, BaseUser):
 
         return None
 
+    @classmethod
+    @async_getter_lock
+    async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
+        try:
+            return cls.by_uuid[uuid]
+        except KeyError:
+            pass
+
+        user = cast(cls, await super().get_by_uuid(uuid))
+        if user is not None:
+            user._add_to_cache()
+            return user
+
+        return None
+
+    @classmethod
+    async def get_by_address(cls, address: Address) -> Optional['User']:
+        if address.uuid:
+            return await cls.get_by_uuid(address.uuid)
+        elif address.number:
+            return await cls.get_by_username(address.number)
+        else:
+            raise ValueError("Given address is blank")
+
     @classmethod
     async def all_logged_in(cls) -> AsyncGenerator['User', None]:
         users = await super().all_logged_in()