Bläddra i källkod

Merge remote-tracking branch 'maltee1/bridge_join_rules'

Tulir Asokan 2 år sedan
förälder
incheckning
ab235c501f
3 ändrade filer med 87 tillägg och 6 borttagningar
  1. 21 4
      mautrix_signal/commands/signal.py
  2. 3 0
      mautrix_signal/matrix.py
  3. 63 2
      mautrix_signal/portal.py

+ 21 - 4
mautrix_signal/commands/signal.py

@@ -24,7 +24,14 @@ from mausignald.errors import UnknownIdentityKey, UnregisteredUserError
 from mausignald.types import Address, GroupID, TrustLevel
 from mautrix.appservice import IntentAPI
 from mautrix.bridge.commands import SECTION_ADMIN, HelpSection, command_handler
-from mautrix.types import ContentURI, EventID, EventType, PowerLevelStateEventContent, RoomID
+from mautrix.types import (
+    ContentURI,
+    EventID,
+    EventType,
+    JoinRule,
+    PowerLevelStateEventContent,
+    RoomID,
+)
 
 from .. import portal as po, puppet as pu
 from ..util import normalize_number, user_has_power_level
@@ -327,7 +334,7 @@ async def create(evt: CommandEvent) -> EventID:
     if evt.portal:
         return await evt.reply("This is already a portal room.")
 
-    title, about, levels, encrypted, avatar_url = await get_initial_state(
+    title, about, levels, encrypted, avatar_url, join_rule = await get_initial_state(
         evt.az.intent, evt.room_id
     )
 
@@ -342,7 +349,7 @@ async def create(evt: CommandEvent) -> EventID:
     )
     await warn_missing_power(levels, evt)
 
-    await portal.create_signal_group(evt.sender, levels)
+    await portal.create_signal_group(evt.sender, levels, join_rule)
     await evt.reply(f"Signal chat created. ID: {portal.chat_id}")
 
 
@@ -526,13 +533,21 @@ async def _locked_confirm_bridge(
 
 async def get_initial_state(
     intent: IntentAPI, room_id: RoomID
-) -> tuple[str | None, str | None, PowerLevelStateEventContent | None, bool, ContentURI | None]:
+) -> tuple[
+    str | None,
+    str | None,
+    PowerLevelStateEventContent | None,
+    bool,
+    ContentURI | None,
+    JoinRule | None,
+]:
     state = await intent.get_state(room_id)
     title: str | None = None
     about: str | None = None
     levels: PowerLevelStateEventContent | None = None
     encrypted: bool = False
     avatar_url: ContentURI | None = None
+    join_rule: JoinRule | None = None
     for event in state:
         try:
             if event.type == EventType.ROOM_NAME:
@@ -547,6 +562,8 @@ async def get_initial_state(
                 encrypted = True
             elif event.type == EventType.ROOM_AVATAR:
                 avatar_url = event.content.url
+            elif event.type == EventType.ROOM_JOIN_RULES:
+                join_rule = event.content.join_rule
         except KeyError:
             # Some state event probably has empty content
             pass

+ 3 - 0
mautrix_signal/matrix.py

@@ -247,6 +247,7 @@ class MatrixHandler(BaseMatrixHandler):
             EventType.ROOM_TOPIC,
             EventType.ROOM_AVATAR,
             EventType.ROOM_POWER_LEVELS,
+            EventType.ROOM_JOIN_RULES,
         ):
             return
 
@@ -265,6 +266,8 @@ class MatrixHandler(BaseMatrixHandler):
             await portal.handle_matrix_topic(user, evt.content.topic)
         elif evt.type == EventType.ROOM_POWER_LEVELS:
             await portal.handle_matrix_power_level(user, evt.content, evt.unsigned.prev_content)
+        elif evt.type == EventType.ROOM_JOIN_RULES:
+            await portal.handle_matrix_join_rules(user, evt.content.join_rule)
 
     async def allow_message(self, user: u.User) -> bool:
         return user.relay_whitelisted

+ 63 - 2
mautrix_signal/portal.py

@@ -69,7 +69,6 @@ from mautrix.types import (
     FileInfo,
     ImageInfo,
     JoinRule,
-    LocationMessageEventContent,
     MediaMessageEventContent,
     Membership,
     MessageEvent,
@@ -1037,6 +1036,32 @@ class Portal(DBPortal, BasePortal):
                     await self.signal.get_group(sender.username, self.chat_id)
                 )
 
+    async def handle_matrix_join_rules(self, sender: u.User, join_rule: JoinRule) -> None:
+        if join_rule == JoinRule.PUBLIC:
+            link_access = AccessControlMode.ANY
+        elif join_rule == JoinRule.INVITE:
+            link_access = AccessControlMode.UNSATISFIABLE
+        else:
+            link_access = AccessControlMode.ADMINISTRATOR
+        sender, is_relay = await self.get_relay_sender(sender, "join_rule change")
+        if not sender:
+            return
+
+        try:
+            update_meta = await self.signal.update_group(
+                sender.username,
+                self.chat_id,
+                update_access_control=GroupAccessControl(
+                    attributes=None, members=None, link=link_access
+                ),
+            )
+            self.revision = update_meta.revision
+        except Exception as e:
+            self.log.exception(f"Failed to update Signal link access control: {e}")
+            await self._update_join_rules(
+                await self.signal.get_group(sender.username, self.chat_id)
+            )
+
     # endregion
     # region Signal event handling
 
@@ -1715,7 +1740,7 @@ class Portal(DBPortal, BasePortal):
     # region Matrix -> Signal metadata
 
     async def create_signal_group(
-        self, source: u.User, levels: PowerLevelStateEventContent
+        self, source: u.User, levels: PowerLevelStateEventContent, join_rule: JoinRule
     ) -> None:
         user_mxids = await self.az.intent.get_room_members(
             self.mxid, (Membership.JOIN, Membership.INVITE)
@@ -1749,6 +1774,7 @@ class Portal(DBPortal, BasePortal):
         if self.topic:
             await self.signal.update_group(source.username, self.chat_id, description=self.topic)
         await self.handle_matrix_power_level(source, levels)
+        await self.handle_matrix_join_rules(source, join_rule)
         await self.update()
         await self.update_bridge_info()
 
@@ -1818,6 +1844,10 @@ class Portal(DBPortal, BasePortal):
             await self._update_power_levels(info)
         except Exception:
             self.log.warning("Error updating power levels", exc_info=True)
+        try:
+            await self._update_join_rules(info)
+        except:
+            self.log.warning("Error updating join rules", exc_info=True)
         if changed:
             await self.update_bridge_info()
             await self.update()
@@ -2033,6 +2063,37 @@ class Portal(DBPortal, BasePortal):
         power_levels = await self._get_power_levels(power_levels, info=info, is_initial=False)
         await self.main_intent.set_power_levels(self.mxid, power_levels)
 
+    async def _update_join_rules(self, info: ChatInfo) -> None:
+        if not self.mxid:
+            return
+        link_access = info.access_control.link
+        old_join_rule = await self._get_join_rule()
+        if link_access == AccessControlMode.ANY:
+            if self.config["bridge.public_portals"]:
+                join_rule = JoinRule.PUBLIC
+            elif old_join_rule and old_join_rule == JoinRule.PUBLIC:
+                return
+        elif link_access == AccessControlMode.ADMINISTRATOR:
+            if old_join_rule and (
+                old_join_rule == JoinRule.KNOCK or old_join_rule == JoinRule.RESTRICTED
+            ):
+                return
+            else:
+                join_rule = JoinRule.KNOCK
+        else:
+            join_rule = JoinRule.INVITE
+        await self.main_intent.set_join_rule(self.mxid, join_rule)
+
+    async def _get_join_rule(self) -> JoinRule:
+        state = await self.main_intent.get_state(self.mxid)
+        for event in state:
+            try:
+                if event.type == EventType.ROOM_JOIN_RULES:
+                    return event.content.join_rule
+            except KeyError:
+                pass
+        return None
+
     # endregion
     # region Bridge info state event