|
@@ -87,6 +87,8 @@ class AndroidMQTT:
|
|
_message_response_waiters: dict[str, asyncio.Future]
|
|
_message_response_waiters: dict[str, asyncio.Future]
|
|
_disconnect_error: Exception | None
|
|
_disconnect_error: Exception | None
|
|
_event_handlers: dict[Type[T], list[Callable[[T], Awaitable[None]]]]
|
|
_event_handlers: dict[Type[T], list[Callable[[T], Awaitable[None]]]]
|
|
|
|
+ _outgoing_events: asyncio.Queue
|
|
|
|
+ _event_dispatcher_task: asyncio.Task | None
|
|
|
|
|
|
# region Initialization
|
|
# region Initialization
|
|
|
|
|
|
@@ -106,6 +108,8 @@ class AndroidMQTT:
|
|
self._disconnect_error = None
|
|
self._disconnect_error = None
|
|
self._response_waiter_locks = defaultdict(lambda: asyncio.Lock())
|
|
self._response_waiter_locks = defaultdict(lambda: asyncio.Lock())
|
|
self._event_handlers = defaultdict(lambda: [])
|
|
self._event_handlers = defaultdict(lambda: [])
|
|
|
|
+ self._event_dispatcher_task = None
|
|
|
|
+ self._outgoing_events = asyncio.Queue()
|
|
self.log = log or logging.getLogger("mauigpapi.mqtt")
|
|
self.log = log or logging.getLogger("mauigpapi.mqtt")
|
|
self._loop = loop or asyncio.get_event_loop()
|
|
self._loop = loop or asyncio.get_event_loop()
|
|
self.state = state
|
|
self.state = state
|
|
@@ -292,7 +296,7 @@ class AndroidMQTT:
|
|
self.log.trace("Parsed path %s -> %s", path, additional)
|
|
self.log.trace("Parsed path %s -> %s", path, additional)
|
|
return additional
|
|
return additional
|
|
|
|
|
|
- def _on_messager_sync_item(self, part: IrisPayloadData, parsed_item: IrisPayload) -> None:
|
|
|
|
|
|
+ def _on_messager_sync_item(self, part: IrisPayloadData, parsed_item: IrisPayload) -> bool:
|
|
if part.path.startswith("/direct_v2/threads/"):
|
|
if part.path.startswith("/direct_v2/threads/"):
|
|
raw_message = {
|
|
raw_message = {
|
|
"path": part.path,
|
|
"path": part.path,
|
|
@@ -317,12 +321,14 @@ class AndroidMQTT:
|
|
evt = ThreadSyncEvent.deserialize(raw_message)
|
|
evt = ThreadSyncEvent.deserialize(raw_message)
|
|
else:
|
|
else:
|
|
self.log.warning(f"Unsupported path {part.path}")
|
|
self.log.warning(f"Unsupported path {part.path}")
|
|
- return
|
|
|
|
- self._loop.create_task(self._dispatch(evt))
|
|
|
|
|
|
+ return False
|
|
|
|
+ self._outgoing_events.put_nowait(evt)
|
|
|
|
+ return True
|
|
|
|
|
|
def _on_message_sync(self, payload: bytes) -> None:
|
|
def _on_message_sync(self, payload: bytes) -> None:
|
|
parsed = json.loads(payload.decode("utf-8"))
|
|
parsed = json.loads(payload.decode("utf-8"))
|
|
self.log.trace("Got message sync event: %s", parsed)
|
|
self.log.trace("Got message sync event: %s", parsed)
|
|
|
|
+ has_items = False
|
|
for sync_item in parsed:
|
|
for sync_item in parsed:
|
|
parsed_item = IrisPayload.deserialize(sync_item)
|
|
parsed_item = IrisPayload.deserialize(sync_item)
|
|
if self._iris_seq_id < parsed_item.seq_id:
|
|
if self._iris_seq_id < parsed_item.seq_id:
|
|
@@ -333,7 +339,9 @@ class AndroidMQTT:
|
|
self._dispatch(NewSequenceID(self._iris_seq_id, self._iris_snapshot_at_ms))
|
|
self._dispatch(NewSequenceID(self._iris_seq_id, self._iris_snapshot_at_ms))
|
|
)
|
|
)
|
|
for part in parsed_item.data:
|
|
for part in parsed_item.data:
|
|
- self._on_messager_sync_item(part, parsed_item)
|
|
|
|
|
|
+ has_items = self._on_messager_sync_item(part, parsed_item) or has_items
|
|
|
|
+ if has_items and not self._event_dispatcher_task:
|
|
|
|
+ self._event_dispatcher_task = asyncio.create_task(self._dispatcher_loop())
|
|
|
|
|
|
def _on_pubsub(self, payload: bytes) -> None:
|
|
def _on_pubsub(self, payload: bytes) -> None:
|
|
parsed_thrift = IncomingMessage.from_thrift(payload)
|
|
parsed_thrift = IncomingMessage.from_thrift(payload)
|
|
@@ -456,6 +464,26 @@ class AndroidMQTT:
|
|
def disconnect(self) -> None:
|
|
def disconnect(self) -> None:
|
|
self._client.disconnect()
|
|
self._client.disconnect()
|
|
|
|
|
|
|
|
+ async def _dispatcher_loop(self) -> None:
|
|
|
|
+ loop_id = f"{hex(id(self))}#{time.monotonic()}"
|
|
|
|
+ self.log.debug(f"Dispatcher loop {loop_id} starting")
|
|
|
|
+ try:
|
|
|
|
+ while True:
|
|
|
|
+ evt = await self._outgoing_events.get()
|
|
|
|
+ await asyncio.shield(self._dispatch(evt))
|
|
|
|
+ except asyncio.CancelledError:
|
|
|
|
+ tasks = self._outgoing_events
|
|
|
|
+ self._outgoing_events = asyncio.Queue()
|
|
|
|
+ if not tasks.empty():
|
|
|
|
+ self.log.debug(
|
|
|
|
+ f"Dispatcher loop {loop_id} stopping after dispatching {tasks.qsize()} events"
|
|
|
|
+ )
|
|
|
|
+ while not tasks.empty():
|
|
|
|
+ await self._dispatch(tasks.get_nowait())
|
|
|
|
+ raise
|
|
|
|
+ finally:
|
|
|
|
+ self.log.debug(f"Dispatcher loop {loop_id} stopped")
|
|
|
|
+
|
|
async def listen(
|
|
async def listen(
|
|
self,
|
|
self,
|
|
graphql_subs: set[str] | None = None,
|
|
graphql_subs: set[str] | None = None,
|
|
@@ -518,6 +546,9 @@ class AndroidMQTT:
|
|
connection_retries += 1
|
|
connection_retries += 1
|
|
else:
|
|
else:
|
|
connection_retries = 0
|
|
connection_retries = 0
|
|
|
|
+ if self._event_dispatcher_task:
|
|
|
|
+ self._event_dispatcher_task.cancel()
|
|
|
|
+ self._event_dispatcher_task = None
|
|
if self._disconnect_error:
|
|
if self._disconnect_error:
|
|
self.log.info("disconnect_error is set, raising and clearing variable")
|
|
self.log.info("disconnect_error is set, raising and clearing variable")
|
|
err = self._disconnect_error
|
|
err = self._disconnect_error
|