Browse Source

Add initial group v2 support

Tulir Asokan 4 years ago
parent
commit
bd3fa8fcef
5 changed files with 68 additions and 17 deletions
  1. 4 3
      mausignald/signald.py
  2. 18 1
      mausignald/types.py
  3. 21 8
      mautrix_signal/portal.py
  4. 9 2
      mautrix_signal/signal.py
  5. 16 3
      mautrix_signal/user.py

+ 4 - 3
mausignald/signald.py

@@ -12,7 +12,7 @@ from mautrix.util.logging import TraceLogger
 from .rpc import SignaldRPCClient
 from .rpc import SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
 from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
 from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
-                    Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction)
+                    Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2)
 
 
 T = TypeVar('T')
 T = TypeVar('T')
 EventHandler = Callable[[T], Awaitable[None]]
 EventHandler = Callable[[T], Awaitable[None]]
@@ -146,9 +146,10 @@ class SignaldClient(SignaldRPCClient):
         contacts = await self.request("list_contacts", "contact_list", username=username)
         contacts = await self.request("list_contacts", "contact_list", username=username)
         return [Contact.deserialize(contact) for contact in contacts]
         return [Contact.deserialize(contact) for contact in contacts]
 
 
-    async def list_groups(self, username: str) -> List[Group]:
+    async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
         resp = await self.request("list_groups", "group_list", username=username)
         resp = await self.request("list_groups", "group_list", username=username)
-        return [Group.deserialize(group) for group in resp["groups"]]
+        return ([Group.deserialize(group) for group in resp["groups"]]
+                + [GroupV2.deserialize(group) for group in resp["groupsv2"]])
 
 
     async def get_profile(self, username: str, address: Address) -> Optional[Profile]:
     async def get_profile(self, username: str, address: Address) -> Optional[Profile]:
         try:
         try:

+ 18 - 1
mausignald/types.py

@@ -108,6 +108,22 @@ class Group(SerializableAttrs['Group']):
     avatar_id: int = attr.ib(default=0, metadata={"json": "avatarId"})
     avatar_id: int = attr.ib(default=0, metadata={"json": "avatarId"})
 
 
 
 
+@dataclass
+class GroupV2ID(SerializableAttrs['GroupV2ID']):
+    id: GroupID
+    revision: int
+
+
+@dataclass
+class GroupV2(GroupV2ID, SerializableAttrs['GroupV2']):
+    title: str
+    master_key: str = attr.ib(metadata={"json": "masterKey"})
+    members: List[Address]
+    pending_members: List[Address] = attr.ib(metadata={"json": "pendingMembers"})
+    requesting_members: List[Address] = attr.ib(metadata={"json": "requestingMembers"})
+    timer: int
+
+
 @dataclass
 @dataclass
 class Attachment(SerializableAttrs['Attachment']):
 class Attachment(SerializableAttrs['Attachment']):
     width: int = 0
     width: int = 0
@@ -161,9 +177,10 @@ class MessageData(SerializableAttrs['MessageData']):
     reaction: Optional[Reaction] = None
     reaction: Optional[Reaction] = None
     attachments: List[Attachment] = attr.ib(factory=lambda: [])
     attachments: List[Attachment] = attr.ib(factory=lambda: [])
     sticker: Optional[Sticker] = None
     sticker: Optional[Sticker] = None
-    # TODO mentions (although signald doesn't support group v2 yet)
+    # TODO mentions
 
 
     group: Optional[Group] = None
     group: Optional[Group] = None
+    group_v2: Optional[GroupV2ID] = attr.ib(default=None, metadata={"json": "groupV2"})
 
 
     end_session: bool = attr.ib(default=False, metadata={"json": "endSession"})
     end_session: bool = attr.ib(default=False, metadata={"json": "endSession"})
     expires_in_seconds: int = attr.ib(default=0, metadata={"json": "expiresInSeconds"})
     expires_in_seconds: int = attr.ib(default=0, metadata={"json": "expiresInSeconds"})

+ 21 - 8
mautrix_signal/portal.py

@@ -25,7 +25,7 @@ import time
 import os
 import os
 
 
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
-                              Attachment, GroupID)
+                              Attachment, GroupID, GroupV2ID, GroupV2)
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal
 from mautrix.bridge import BasePortal
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
@@ -52,7 +52,7 @@ except ImportError:
 
 
 StateBridge = EventType.find("m.bridge", EventType.Class.STATE)
 StateBridge = EventType.find("m.bridge", EventType.Class.STATE)
 StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STATE)
 StateHalfShotBridge = EventType.find("uk.half-shot.bridge", EventType.Class.STATE)
-ChatInfo = Union[Group, Contact, Profile, Address]
+ChatInfo = Union[Group, GroupV2, GroupV2ID, Contact, Profile, Address]
 
 
 
 
 class Portal(DBPortal, BasePortal):
 class Portal(DBPortal, BasePortal):
@@ -451,9 +451,14 @@ class Portal(DBPortal, BasePortal):
                 self.name = puppet.name
                 self.name = puppet.name
             return
             return
 
 
-        if not isinstance(info, Group):
+        if isinstance(info, Group):
+            changed = await self._update_name(info.name)
+        elif isinstance(info, GroupV2):
+            changed = await self._update_name(info.title)
+        elif isinstance(info, GroupV2ID):
+            return
+        else:
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
             raise ValueError(f"Unexpected type for group update_info: {type(info)}")
-        changed = await self._update_name(info.name)
         changed = await self._update_avatar()
         changed = await self._update_avatar()
         await self._update_participants(info.members)
         await self._update_participants(info.members)
         if changed:
         if changed:
@@ -498,6 +503,7 @@ class Portal(DBPortal, BasePortal):
         return True
         return True
 
 
     async def _update_participants(self, participants: List[Address]) -> None:
     async def _update_participants(self, participants: List[Address]) -> None:
+        # TODO add support for pending_members and maybe requesting_members?
         if not self.mxid or not participants:
         if not self.mxid or not participants:
             return
             return
 
 
@@ -549,7 +555,7 @@ class Portal(DBPortal, BasePortal):
     # region Creating Matrix rooms
     # region Creating Matrix rooms
 
 
     async def update_matrix_room(self, source: 'u.User', info: ChatInfo) -> None:
     async def update_matrix_room(self, source: 'u.User', info: ChatInfo) -> None:
-        if not self.is_direct and not isinstance(info, Group):
+        if not self.is_direct and not isinstance(info, (Group, GroupV2, GroupV2ID)):
             raise ValueError(f"Unexpected type for updating group portal: {type(info)}")
             raise ValueError(f"Unexpected type for updating group portal: {type(info)}")
         elif self.is_direct and not isinstance(info, (Contact, Profile, Address)):
         elif self.is_direct and not isinstance(info, (Contact, Profile, Address)):
             raise ValueError(f"Unexpected type for updating direct chat portal: {type(info)}")
             raise ValueError(f"Unexpected type for updating direct chat portal: {type(info)}")
@@ -559,13 +565,20 @@ class Portal(DBPortal, BasePortal):
             self.log.exception("Failed to update portal")
             self.log.exception("Failed to update portal")
 
 
     async def create_matrix_room(self, source: 'u.User', info: ChatInfo) -> Optional[RoomID]:
     async def create_matrix_room(self, source: 'u.User', info: ChatInfo) -> Optional[RoomID]:
-        if not self.is_direct and not isinstance(info, Group):
+        if not self.is_direct and not isinstance(info, (Group, GroupV2, GroupV2ID)):
             raise ValueError(f"Unexpected type for creating group portal: {type(info)}")
             raise ValueError(f"Unexpected type for creating group portal: {type(info)}")
         elif self.is_direct and not isinstance(info, (Contact, Profile, Address)):
         elif self.is_direct and not isinstance(info, (Contact, Profile, Address)):
             raise ValueError(f"Unexpected type for creating direct chat portal: {type(info)}")
             raise ValueError(f"Unexpected type for creating direct chat portal: {type(info)}")
-        if isinstance(info, Group):
+        if isinstance(info, Group) and not info.members:
+            groups = await self.signal.list_groups(source.username)
+            info = next((g for g in groups
+                         if isinstance(g, Group) and g.group_id == info.group_id), info)
+        elif isinstance(info, GroupV2ID):
             groups = await self.signal.list_groups(source.username)
             groups = await self.signal.list_groups(source.username)
-            info = next((g for g in groups if g.group_id == info.group_id), info)
+            try:
+                info = next(g for g in groups if isinstance(g, GroupV2) and g.id == info.id)
+            except StopIteration as e:
+                raise ValueError("Couldn't get full group v2 info") from e
         if self.mxid:
         if self.mxid:
             await self.update_matrix_room(source, info)
             await self.update_matrix_room(source, info)
             return self.mxid
             return self.mxid

+ 9 - 2
mautrix_signal/signal.py

@@ -73,7 +73,13 @@ class SignalHandler(SignaldClient):
     @staticmethod
     @staticmethod
     async def handle_message(user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
     async def handle_message(user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
                              addr_override: Optional[Address] = None) -> None:
                              addr_override: Optional[Address] = None) -> None:
-        if msg.group:
+        if msg.group_v2:
+            portal = await po.Portal.get_by_chat_id(msg.group_v2.id, create=False)
+            # TODO get group info for missing v2 groups and create portal when necessary
+            if not portal:
+                user.log.debug(f"Dropping message in unknown v2 group {msg.group_v2.id}")
+                return
+        elif msg.group:
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
         else:
         else:
             portal = await po.Portal.get_by_chat_id(addr_override or sender.address,
             portal = await po.Portal.get_by_chat_id(addr_override or sender.address,
@@ -83,7 +89,8 @@ class SignalHandler(SignaldClient):
                                  " double puppeting enabled")
                                  " double puppeting enabled")
                 return
                 return
         if not portal.mxid:
         if not portal.mxid:
-            await portal.create_matrix_room(user, msg.group or addr_override or sender.address)
+            await portal.create_matrix_room(user, (msg.group_v2 or msg.group
+                                                   or addr_override or sender.address))
         if msg.reaction:
         if msg.reaction:
             await portal.handle_signal_reaction(sender, msg.reaction)
             await portal.handle_signal_reaction(sender, msg.reaction)
         if msg.body or msg.attachments or msg.sticker:
         if msg.body or msg.attachments or msg.sticker:

+ 16 - 3
mautrix_signal/user.py

@@ -13,12 +13,12 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Dict, Optional, AsyncGenerator, TYPE_CHECKING, cast
+from typing import Dict, Optional, AsyncGenerator, Union, TYPE_CHECKING, cast
 from collections import defaultdict
 from collections import defaultdict
 from uuid import UUID
 from uuid import UUID
 import asyncio
 import asyncio
 
 
-from mausignald.types import Account, Address, Contact, Group, ListenEvent, ListenAction
+from mausignald.types import Account, Address, Contact, Group, GroupV2, ListenEvent, ListenAction
 from mautrix.bridge import BaseUser
 from mautrix.bridge import BaseUser
 from mautrix.types import UserID, RoomID
 from mautrix.types import UserID, RoomID
 from mautrix.appservice import AppService
 from mautrix.appservice import AppService
@@ -134,6 +134,14 @@ class User(DBUser, BaseUser):
         elif portal.mxid:
         elif portal.mxid:
             await portal.update_matrix_room(self, group)
             await portal.update_matrix_room(self, group)
 
 
+    async def _sync_group_v2(self, group: GroupV2, create_portals: bool) -> None:
+        self.log.trace("Syncing group %s", group.id)
+        portal = await po.Portal.get_by_chat_id(group.id, create=True)
+        if create_portals:
+            await portal.create_matrix_room(self, group)
+        elif portal.mxid:
+            await portal.update_matrix_room(self, group)
+
     async def _sync(self) -> None:
     async def _sync(self) -> None:
         create_contact_portal = self.config["bridge.autocreate_contact_portal"]
         create_contact_portal = self.config["bridge.autocreate_contact_portal"]
         for contact in await self.bridge.signal.list_contacts(self.username):
         for contact in await self.bridge.signal.list_contacts(self.username):
@@ -144,7 +152,12 @@ class User(DBUser, BaseUser):
         create_group_portal = self.config["bridge.autocreate_group_portal"]
         create_group_portal = self.config["bridge.autocreate_group_portal"]
         for group in await self.bridge.signal.list_groups(self.username):
         for group in await self.bridge.signal.list_groups(self.username):
             try:
             try:
-                await self._sync_group(group, create_group_portal)
+                if isinstance(group, Group):
+                    await self._sync_group(group, create_group_portal)
+                elif isinstance(group, GroupV2):
+                    await self._sync_group_v2(group, create_group_portal)
+                else:
+                    self.log.warning("Unknown return type in list_groups: %s", type(group))
             except Exception:
             except Exception:
                 self.log.exception(f"Failed to sync group {group.group_id}")
                 self.log.exception(f"Failed to sync group {group.group_id}")