Эх сурвалжийг харах

Process events from Instagram in order

Tulir Asokan 3 жил өмнө
parent
commit
904ebbcdd6
1 өөрчлөгдсөн 35 нэмэгдсэн , 4 устгасан
  1. 35 4
      mauigpapi/mqtt/conn.py

+ 35 - 4
mauigpapi/mqtt/conn.py

@@ -87,6 +87,8 @@ class AndroidMQTT:
     _message_response_waiters: dict[str, asyncio.Future]
     _disconnect_error: Exception | None
     _event_handlers: dict[Type[T], list[Callable[[T], Awaitable[None]]]]
+    _outgoing_events: asyncio.Queue
+    _event_dispatcher_task: asyncio.Task | None
 
     # region Initialization
 
@@ -106,6 +108,8 @@ class AndroidMQTT:
         self._disconnect_error = None
         self._response_waiter_locks = defaultdict(lambda: asyncio.Lock())
         self._event_handlers = defaultdict(lambda: [])
+        self._event_dispatcher_task = None
+        self._outgoing_events = asyncio.Queue()
         self.log = log or logging.getLogger("mauigpapi.mqtt")
         self._loop = loop or asyncio.get_event_loop()
         self.state = state
@@ -292,7 +296,7 @@ class AndroidMQTT:
         self.log.trace("Parsed path %s -> %s", path, additional)
         return additional
 
-    def _on_messager_sync_item(self, part: IrisPayloadData, parsed_item: IrisPayload) -> None:
+    def _on_messager_sync_item(self, part: IrisPayloadData, parsed_item: IrisPayload) -> bool:
         if part.path.startswith("/direct_v2/threads/"):
             raw_message = {
                 "path": part.path,
@@ -317,12 +321,14 @@ class AndroidMQTT:
             evt = ThreadSyncEvent.deserialize(raw_message)
         else:
             self.log.warning(f"Unsupported path {part.path}")
-            return
-        self._loop.create_task(self._dispatch(evt))
+            return False
+        self._outgoing_events.put_nowait(evt)
+        return True
 
     def _on_message_sync(self, payload: bytes) -> None:
         parsed = json.loads(payload.decode("utf-8"))
         self.log.trace("Got message sync event: %s", parsed)
+        has_items = False
         for sync_item in parsed:
             parsed_item = IrisPayload.deserialize(sync_item)
             if self._iris_seq_id < parsed_item.seq_id:
@@ -333,7 +339,9 @@ class AndroidMQTT:
                     self._dispatch(NewSequenceID(self._iris_seq_id, self._iris_snapshot_at_ms))
                 )
             for part in parsed_item.data:
-                self._on_messager_sync_item(part, parsed_item)
+                has_items = self._on_messager_sync_item(part, parsed_item) or has_items
+        if has_items and not self._event_dispatcher_task:
+            self._event_dispatcher_task = asyncio.create_task(self._dispatcher_loop())
 
     def _on_pubsub(self, payload: bytes) -> None:
         parsed_thrift = IncomingMessage.from_thrift(payload)
@@ -456,6 +464,26 @@ class AndroidMQTT:
     def disconnect(self) -> None:
         self._client.disconnect()
 
+    async def _dispatcher_loop(self) -> None:
+        loop_id = f"{hex(id(self))}#{time.monotonic()}"
+        self.log.debug(f"Dispatcher loop {loop_id} starting")
+        try:
+            while True:
+                evt = await self._outgoing_events.get()
+                await asyncio.shield(self._dispatch(evt))
+        except asyncio.CancelledError:
+            tasks = self._outgoing_events
+            self._outgoing_events = asyncio.Queue()
+            if not tasks.empty():
+                self.log.debug(
+                    f"Dispatcher loop {loop_id} stopping after dispatching {tasks.qsize()} events"
+                )
+            while not tasks.empty():
+                await self._dispatch(tasks.get_nowait())
+            raise
+        finally:
+            self.log.debug(f"Dispatcher loop {loop_id} stopped")
+
     async def listen(
         self,
         graphql_subs: set[str] | None = None,
@@ -518,6 +546,9 @@ class AndroidMQTT:
                 connection_retries += 1
             else:
                 connection_retries = 0
+        if self._event_dispatcher_task:
+            self._event_dispatcher_task.cancel()
+            self._event_dispatcher_task = None
         if self._disconnect_error:
             self.log.info("disconnect_error is set, raising and clearing variable")
             err = self._disconnect_error