浏览代码

Add initial group v2 support

Tulir Asokan 4 年之前
父节点
当前提交
bd3fa8fcef
共有 5 个文件被更改,包括 68 次插入17 次删除
  1. 4 3
      mausignald/signald.py
  2. 18 1
      mausignald/types.py
  3. 21 8
      mautrix_signal/portal.py
  4. 9 2
      mautrix_signal/signal.py
  5. 16 3
      mautrix_signal/user.py

+ 4 - 3
mausignald/signald.py

@@ -12,7 +12,7 @@ from mautrix.util.logging import TraceLogger
 from .rpc import SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
-                    Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction)
+                    Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2)
 
 T = TypeVar('T')
 EventHandler = Callable[[T], Awaitable[None]]
@@ -146,9 +146,10 @@ class SignaldClient(SignaldRPCClient):
         contacts = await self.request("list_contacts", "contact_list", username=username)
         return [Contact.deserialize(contact) for contact in contacts]
 
-    async def list_groups(self, username: str) -> List[Group]:
+    async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
         resp = await self.request("list_groups", "group_list", username=username)
-        return [Group.deserialize(group) for group in resp["groups"]]
+        return ([Group.deserialize(group) for group in resp["groups"]]
+                + [GroupV2.deserialize(group) for group in resp["groupsv2"]])
 
     async def get_profile(self, username: str, address: Address) -> Optional[Profile]:
         try:

+ 18 - 1
mausignald/types.py

@@ -108,6 +108,22 @@ class Group(SerializableAttrs['Group']):
     avatar_id: int = attr.ib(default=0, metadata={"json": "avatarId"})
 
 
+@dataclass
+class GroupV2ID(SerializableAttrs['GroupV2ID']):
+    id: GroupID
+    revision: int
+
+
+@dataclass
+class GroupV2(GroupV2ID, SerializableAttrs['GroupV2']):
+    title: str
+    master_key: str = attr.ib(metadata={"json": "masterKey"})
+    members: List[Address]
+    pending_members: List[Address] = attr.ib(metadata={"json": "pendingMembers"})
+    requesting_members: List[Address] = attr.ib(metadata={"json": "requestingMembers"})
+    timer: int
+
+
 @dataclass
 class Attachment(SerializableAttrs['Attachment']):
     width: int = 0
@@ -161,9 +177,10 @@ class MessageData(SerializableAttrs['MessageData']):
     reaction: Optional[Reaction] = None
     attachments: List[Attachment] = attr.ib(factory=lambda: [])
     sticker: Optional[Sticker] = None
-    # TODO mentions (although signald doesn't support group v2 yet)
+    # TODO mentions
 
     group: Optional[Group] = None
+    group_v2: Optional[GroupV2ID] = attr.ib(default=None, metadata={"json": "groupV2"})
 
     end_session: bool = attr.ib(default=False, metadata={"json": "endSession"})
     expires_in_seconds: int = attr.ib(default=0, metadata={"json": "expiresInSeconds"})

+ 21 - 8
mautrix_signal/portal.py

@@ -25,7 +25,7 @@ import time
 import os
 
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
-                              Attachment, GroupID)
+                              Attachment, GroupID, GroupV2ID, GroupV2)
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
@@ -52,7 +52,7 @@ except ImportError:
 
 StateBridge = EventType.find("m.bridge", EventType.Class.STATE)
 StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STATE)
-ChatInfo = Union[Group, Contact, Profile, Address]
+ChatInfo = Union[Group, GroupV2, GroupV2ID, Contact, Profile, Address]
 
 
 class Portal(DBPortal, BasePortal):
@@ -451,9 +451,14 @@ class Portal(DBPortal, BasePortal):
                 self.name = puppet.name
             return
 
-        if not isinstance(info, Group):
+        if isinstance(info, Group):
+            changed = await self._update_name(info.name)
+        elif isinstance(info, GroupV2):
+            changed = await self._update_name(info.title)
+        elif isinstance(info, GroupV2ID):
+            return
+        else:
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
-        changed = await self._update_name(info.name)
         changed = await self._update_avatar()
         await self._update_participants(info.members)
         if changed:
@@ -498,6 +503,7 @@ class Portal(DBPortal, BasePortal):
         return True
 
     async def _update_participants(self, participants: List[Address]) -> None:
+        # TODO add support for pending_members and maybe requesting_members?
         if not self.mxid or not participants:
             return
 
@@ -549,7 +555,7 @@ class Portal(DBPortal, BasePortal):
     # region Creating Matrix rooms
 
     async def update_matrix_room(self, source: 'u.User', info: ChatInfo) -> None:
-        if not self.is_direct and not isinstance(info, Group):
+        if not self.is_direct and not isinstance(info, (Group, GroupV2, GroupV2ID)):
             raise ValueError(f"Unexpected type for updating group portal: {type(info)}")
         elif self.is_direct and not isinstance(info, (Contact, Profile, Address)):
             raise ValueError(f"Unexpected type for updating direct chat portal: {type(info)}")
@@ -559,13 +565,20 @@ class Portal(DBPortal, BasePortal):
             self.log.exception("Failed to update portal")
 
     async def create_matrix_room(self, source: 'u.User', info: ChatInfo) -> Optional[RoomID]:
-        if not self.is_direct and not isinstance(info, Group):
+        if not self.is_direct and not isinstance(info, (Group, GroupV2, GroupV2ID)):
             raise ValueError(f"Unexpected type for creating group portal: {type(info)}")
         elif self.is_direct and not isinstance(info, (Contact, Profile, Address)):
             raise ValueError(f"Unexpected type for creating direct chat portal: {type(info)}")
-        if isinstance(info, Group):
+        if isinstance(info, Group) and not info.members:
+            groups = await self.signal.list_groups(source.username)
+            info = next((g for g in groups
+                         if isinstance(g, Group) and g.group_id == info.group_id), info)
+        elif isinstance(info, GroupV2ID):
             groups = await self.signal.list_groups(source.username)
-            info = next((g for g in groups if g.group_id == info.group_id), info)
+            try:
+                info = next(g for g in groups if isinstance(g, GroupV2) and g.id == info.id)
+            except StopIteration as e:
+                raise ValueError("Couldn't get full group v2 info") from e
         if self.mxid:
             await self.update_matrix_room(source, info)
             return self.mxid

+ 9 - 2
mautrix_signal/signal.py

@@ -73,7 +73,13 @@ class SignalHandler(SignaldClient):
     @staticmethod
     async def handle_message(user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
                              addr_override: Optional[Address] = None) -> None:
-        if msg.group:
+        if msg.group_v2:
+            portal = await po.Portal.get_by_chat_id(msg.group_v2.id, create=False)
+            # TODO get group info for missing v2 groups and create portal when necessary
+            if not portal:
+                user.log.debug(f"Dropping message in unknown v2 group {msg.group_v2.id}")
+                return
+        elif msg.group:
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
         else:
             portal = await po.Portal.get_by_chat_id(addr_override or sender.address,
@@ -83,7 +89,8 @@ class SignalHandler(SignaldClient):
                                  " double puppeting enabled")
                 return
         if not portal.mxid:
-            await portal.create_matrix_room(user, msg.group or addr_override or sender.address)
+            await portal.create_matrix_room(user, (msg.group_v2 or msg.group
+                                                   or addr_override or sender.address))
         if msg.reaction:
             await portal.handle_signal_reaction(sender, msg.reaction)
         if msg.body or msg.attachments or msg.sticker:

+ 16 - 3
mautrix_signal/user.py

@@ -13,12 +13,12 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Dict, Optional, AsyncGenerator, TYPE_CHECKING, cast
+from typing import Dict, Optional, AsyncGenerator, Union, TYPE_CHECKING, cast
 from collections import defaultdict
 from uuid import UUID
 import asyncio
 
-from mausignald.types import Account, Address, Contact, Group, ListenEvent, ListenAction
+from mausignald.types import Account, Address, Contact, Group, GroupV2, ListenEvent, ListenAction
 from mautrix.bridge import BaseUser
 from mautrix.types import UserID, RoomID
 from mautrix.appservice import AppService
@@ -134,6 +134,14 @@ class User(DBUser, BaseUser):
         elif portal.mxid:
             await portal.update_matrix_room(self, group)
 
+    async def _sync_group_v2(self, group: GroupV2, create_portals: bool) -> None:
+        self.log.trace("Syncing group %s", group.id)
+        portal = await po.Portal.get_by_chat_id(group.id, create=True)
+        if create_portals:
+            await portal.create_matrix_room(self, group)
+        elif portal.mxid:
+            await portal.update_matrix_room(self, group)
+
     async def _sync(self) -> None:
         create_contact_portal = self.config["bridge.autocreate_contact_portal"]
         for contact in await self.bridge.signal.list_contacts(self.username):
@@ -144,7 +152,12 @@ class User(DBUser, BaseUser):
         create_group_portal = self.config["bridge.autocreate_group_portal"]
         for group in await self.bridge.signal.list_groups(self.username):
             try:
-                await self._sync_group(group, create_group_portal)
+                if isinstance(group, Group):
+                    await self._sync_group(group, create_group_portal)
+                elif isinstance(group, GroupV2):
+                    await self._sync_group_v2(group, create_group_portal)
+                else:
+                    self.log.warning("Unknown return type in list_groups: %s", type(group))
             except Exception:
                 self.log.exception(f"Failed to sync group {group.group_id}")