Эх сурвалжийг харах

Handle group info changes

Tulir Asokan 4 жил өмнө
parent
commit
87685caa81

+ 36 - 18
mauigpapi/mqtt/conn.py

@@ -36,7 +36,8 @@ from ..state import AndroidState
 from ..types import (CommandResponse, ThreadItemType, ThreadAction, ReactionStatus, TypingStatus,
 from ..types import (CommandResponse, ThreadItemType, ThreadAction, ReactionStatus, TypingStatus,
                      IrisPayload, PubsubPayload, AppPresenceEventPayload, RealtimeDirectEvent,
                      IrisPayload, PubsubPayload, AppPresenceEventPayload, RealtimeDirectEvent,
                      RealtimeZeroProvisionPayload, ClientConfigUpdatePayload, MessageSyncEvent,
                      RealtimeZeroProvisionPayload, ClientConfigUpdatePayload, MessageSyncEvent,
-                     MessageSyncMessage, LiveVideoCommentPayload, PubsubEvent)
+                     MessageSyncMessage, LiveVideoCommentPayload, PubsubEvent, IrisPayloadData,
+                     ThreadSyncEvent)
 from .thrift import RealtimeConfig, RealtimeClientInfo, ForegroundStateConfig, IncomingMessage
 from .thrift import RealtimeConfig, RealtimeClientInfo, ForegroundStateConfig, IncomingMessage
 from .otclient import MQTToTClient
 from .otclient import MQTToTClient
 from .subscription import everclear_subscriptions, RealtimeTopic, GraphQLQueryID
 from .subscription import everclear_subscriptions, RealtimeTopic, GraphQLQueryID
@@ -52,6 +53,9 @@ T = TypeVar('T')
 ACTIVITY_INDICATOR_REGEX = re.compile(
 ACTIVITY_INDICATOR_REGEX = re.compile(
     r"/direct_v2/threads/([\w_]+)/activity_indicator_id/([\w_]+)")
     r"/direct_v2/threads/([\w_]+)/activity_indicator_id/([\w_]+)")
 
 
+INBOX_THREAD_REGEX = re.compile(
+    r"/direct_v2/inbox/threads/([\w_]+)")
+
 
 
 class AndroidMQTT:
 class AndroidMQTT:
     _loop: asyncio.AbstractEventLoop
     _loop: asyncio.AbstractEventLoop
@@ -216,8 +220,8 @@ class AndroidMQTT:
             assert blank == ""
             assert blank == ""
             assert direct_v2 == "direct_v2"
             assert direct_v2 == "direct_v2"
             assert threads == "threads"
             assert threads == "threads"
-        except (AssertionError, ValueError) as e:
-            self.log.debug(f"Got {e} while parsing path {path}")
+        except (AssertionError, ValueError, IndexError) as e:
+            self.log.debug(f"Got {e!r} while parsing path {path}")
             raise
             raise
         additional = {
         additional = {
             "thread_id": thread_id
             "thread_id": thread_id
@@ -243,27 +247,41 @@ class AndroidMQTT:
         self.log.trace("Parsed path %s -> %s", path, additional)
         self.log.trace("Parsed path %s -> %s", path, additional)
         return additional
         return additional
 
 
+    def _on_messager_sync_item(self, part: IrisPayloadData, parsed_item: IrisPayload) -> None:
+        if part.path.startswith("/direct_v2/threads/"):
+            raw_message = {
+                "path": part.path,
+                "op": part.op,
+                **self._parse_direct_thread_path(part.path),
+            }
+            try:
+                raw_message = {
+                    **raw_message,
+                    **json.loads(part.value),
+                }
+            except (json.JSONDecodeError, TypeError):
+                raw_message["value"] = part.value
+            message = MessageSyncMessage.deserialize(raw_message)
+            evt = MessageSyncEvent(iris=parsed_item, message=message)
+        elif part.path.startswith("/direct_v2/inbox/threads/"):
+            raw_message = {
+                "path": part.path,
+                "op": part.op,
+                **json.loads(part.value),
+            }
+            evt = ThreadSyncEvent.deserialize(raw_message)
+        else:
+            self.log.warning(f"Unsupported path {part.path}")
+            return
+        self._loop.create_task(self._dispatch(evt))
+
     def _on_message_sync(self, payload: bytes) -> None:
     def _on_message_sync(self, payload: bytes) -> None:
         parsed = json.loads(payload.decode("utf-8"))
         parsed = json.loads(payload.decode("utf-8"))
         self.log.trace("Got message sync event: %s", parsed)
         self.log.trace("Got message sync event: %s", parsed)
         for sync_item in parsed:
         for sync_item in parsed:
             parsed_item = IrisPayload.deserialize(sync_item)
             parsed_item = IrisPayload.deserialize(sync_item)
             for part in parsed_item.data:
             for part in parsed_item.data:
-                raw_message = {
-                    "path": part.path,
-                    "op": part.op,
-                    **self._parse_direct_thread_path(part.path),
-                }
-                try:
-                    raw_message = {
-                        **raw_message,
-                        **json.loads(part.value),
-                    }
-                except (json.JSONDecodeError, TypeError):
-                    raw_message["value"] = part.value
-                message = MessageSyncMessage.deserialize(raw_message)
-                evt = MessageSyncEvent(iris=parsed_item, message=message)
-                self._loop.create_task(self._dispatch(evt))
+                self._on_messager_sync_item(part, parsed_item)
 
 
     def _on_pubsub(self, payload: bytes) -> None:
     def _on_pubsub(self, payload: bytes) -> None:
         parsed_thrift = IncomingMessage.from_thrift(payload)
         parsed_thrift = IncomingMessage.from_thrift(payload)

+ 1 - 1
mauigpapi/types/__init__.py

@@ -23,5 +23,5 @@ from .mqtt import (Operation, ThreadAction, ReactionStatus, TypingStatus, Comman
                    AppPresenceEvent, ZeroProductProvisioningEvent, RealtimeZeroProvisionPayload,
                    AppPresenceEvent, ZeroProductProvisioningEvent, RealtimeZeroProvisionPayload,
                    ClientConfigUpdatePayload, ClientConfigUpdateEvent, RealtimeDirectData,
                    ClientConfigUpdatePayload, ClientConfigUpdateEvent, RealtimeDirectData,
                    RealtimeDirectEvent, LiveVideoSystemComment, LiveVideoCommentEvent,
                    RealtimeDirectEvent, LiveVideoSystemComment, LiveVideoCommentEvent,
-                   LiveVideoComment, LiveVideoCommentPayload)
+                   LiveVideoComment, LiveVideoCommentPayload, ThreadSyncEvent)
 from .challenge import ChallengeStateResponse, ChallengeStateData
 from .challenge import ChallengeStateResponse, ChallengeStateData

+ 8 - 1
mauigpapi/types/mqtt.py

@@ -21,7 +21,8 @@ import attr
 
 
 from mautrix.types import SerializableAttrs, SerializableEnum, JSON
 from mautrix.types import SerializableAttrs, SerializableEnum, JSON
 
 
-from .thread import ThreadItem
+from .thread import Thread
+from .thread_item import ThreadItem
 from .account import BaseResponseUser
 from .account import BaseResponseUser
 
 
 
 
@@ -103,6 +104,12 @@ class MessageSyncEvent(SerializableAttrs['MessageSyncEvent']):
     message: MessageSyncMessage
     message: MessageSyncMessage
 
 
 
 
+@dataclass
+class ThreadSyncEvent(Thread, SerializableAttrs['ThreadSyncEvent']):
+    path: str
+    op: Operation
+
+
 @dataclass(kw_only=True)
 @dataclass(kw_only=True)
 class PubsubPublishMetadata(SerializableAttrs['PubsubPublishMetadata']):
 class PubsubPublishMetadata(SerializableAttrs['PubsubPublishMetadata']):
     publish_time_ms: str
     publish_time_ms: str

+ 5 - 4
mauigpapi/types/thread.py

@@ -46,6 +46,7 @@ class Thread(SerializableAttrs['Thread']):
     thread_v2_id: str
     thread_v2_id: str
 
 
     users: List[ThreadUser]
     users: List[ThreadUser]
+    # left_users: List[TODO]
     inviter: BaseResponseUser
     inviter: BaseResponseUser
     admin_user_ids: List[int]
     admin_user_ids: List[int]
 
 
@@ -78,11 +79,11 @@ class Thread(SerializableAttrs['Thread']):
     theme: ThreadTheme
     theme: ThreadTheme
     last_seen_at: Dict[int, ThreadUserLastSeenAt]
     last_seen_at: Dict[int, ThreadUserLastSeenAt]
 
 
-    newest_cursor: str
-    oldest_cursor: str
+    newest_cursor: Optional[str] = None
+    oldest_cursor: Optional[str] = None
     next_cursor: Optional[str] = None
     next_cursor: Optional[str] = None
-    prev_cursor: str
-    last_permanent_item: ThreadItem
+    prev_cursor: Optional[str] = None
+    last_permanent_item: Optional[ThreadItem] = None
     items: List[ThreadItem]
     items: List[ThreadItem]
 
 
     # These might only be in single thread requests and not inbox
     # These might only be in single thread requests and not inbox

+ 9 - 1
mautrix_instagram/user.py

@@ -23,7 +23,7 @@ import time
 from mauigpapi import AndroidAPI, AndroidState, AndroidMQTT
 from mauigpapi import AndroidAPI, AndroidState, AndroidMQTT
 from mauigpapi.mqtt import Connect, Disconnect, GraphQLSubscription, SkywalkerSubscription
 from mauigpapi.mqtt import Connect, Disconnect, GraphQLSubscription, SkywalkerSubscription
 from mauigpapi.types import (CurrentUser, MessageSyncEvent, Operation, RealtimeDirectEvent,
 from mauigpapi.types import (CurrentUser, MessageSyncEvent, Operation, RealtimeDirectEvent,
-                             ActivityIndicatorData, TypingStatus)
+                             ActivityIndicatorData, TypingStatus, ThreadSyncEvent)
 from mauigpapi.errors import IGNotLoggedInError
 from mauigpapi.errors import IGNotLoggedInError
 from mautrix.bridge import BaseUser
 from mautrix.bridge import BaseUser
 from mautrix.types import UserID, RoomID, EventID, TextMessageEventContent, MessageType
 from mautrix.types import UserID, RoomID, EventID, TextMessageEventContent, MessageType
@@ -39,6 +39,7 @@ if TYPE_CHECKING:
     from .__main__ import InstagramBridge
     from .__main__ import InstagramBridge
 
 
 METRIC_MESSAGE = Summary("bridge_on_message", "calls to handle_message")
 METRIC_MESSAGE = Summary("bridge_on_message", "calls to handle_message")
+METRIC_THREAD_SYNC = Summary("bridge_on_thread_sync", "calls to handle_thread_sync")
 METRIC_RTD = Summary("bridge_on_rtd", "calls to handle_rtd")
 METRIC_RTD = Summary("bridge_on_rtd", "calls to handle_rtd")
 METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
 METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
 METRIC_CONNECTED = Gauge("bridge_connected", "Bridged users connected to Instagram")
 METRIC_CONNECTED = Gauge("bridge_connected", "Bridged users connected to Instagram")
@@ -129,6 +130,7 @@ class User(DBUser, BaseUser):
         self.mqtt.add_event_handler(Connect, self.on_connect)
         self.mqtt.add_event_handler(Connect, self.on_connect)
         self.mqtt.add_event_handler(Disconnect, self.on_disconnect)
         self.mqtt.add_event_handler(Disconnect, self.on_disconnect)
         self.mqtt.add_event_handler(MessageSyncEvent, self.handle_message)
         self.mqtt.add_event_handler(MessageSyncEvent, self.handle_message)
+        self.mqtt.add_event_handler(ThreadSyncEvent, self.handle_thread_sync)
         self.mqtt.add_event_handler(RealtimeDirectEvent, self.handle_rtd)
         self.mqtt.add_event_handler(RealtimeDirectEvent, self.handle_rtd)
 
 
         await self.update()
         await self.update()
@@ -281,6 +283,12 @@ class User(DBUser, BaseUser):
         elif evt.message.op == Operation.REPLACE:
         elif evt.message.op == Operation.REPLACE:
             await portal.handle_instagram_update(evt.message)
             await portal.handle_instagram_update(evt.message)
 
 
+    @async_time(METRIC_THREAD_SYNC)
+    async def handle_thread_sync(self, evt: ThreadSyncEvent) -> None:
+        self.log.trace("Received thread sync event %s", evt)
+        portal = await po.Portal.get_by_thread(evt, receiver=self.igpk)
+        await portal.create_matrix_room(self, evt)
+
     @async_time(METRIC_RTD)
     @async_time(METRIC_RTD)
     async def handle_rtd(self, evt: RealtimeDirectEvent) -> None:
     async def handle_rtd(self, evt: RealtimeDirectEvent) -> None:
         if not isinstance(evt.value, ActivityIndicatorData):
         if not isinstance(evt.value, ActivityIndicatorData):