瀏覽代碼

Offload MQTT reconnection handling to client

This removes a bunch of logic we had added to handle MQTT connection
issues and instead utilises the client's built-in mechanisms. This is
achieved by temporarily overriding the keepalive value while doing
send/receive, to force quicker detection of connection issues.
Nick Barrett 2 年之前
父節點
當前提交
ef120abcd7
共有 1 個文件被更改,包括 37 次插入55 次删除
  1. 37 55
      mauigpapi/mqtt/conn.py

+ 37 - 55
mauigpapi/mqtt/conn.py

@@ -80,10 +80,9 @@ ACTIVITY_INDICATOR_REGEX = re.compile(
 
 INBOX_THREAD_REGEX = re.compile(r"/direct_v2/inbox/threads/([\w_]+)")
 
-REQUEST_PUBLISH_TIMEOUT = 5
-REQUEST_RESPONSE_TIMEOUT = 30
-
-RECONNECT_ATTEMPTS = 5
+REQUEST_TIMEOUT = 60 * 3
+DEFAULT_KEEPALIVE = 60
+REQUEST_KEEPALIVE = 5
 
 
 class AndroidMQTT:
@@ -529,20 +528,12 @@ class AndroidMQTT:
     # endregion
 
     async def _reconnect(self) -> None:
-        if self._client.is_connected():
-            self.log.debug("Trying to reconnect to MQTT (currently connected)")
-        else:
-            self.log.debug("Trying to reconnect to MQTT (currently not connected)")
-        attempts = 0
-        while True:
-            try:
-                self._client.reconnect()
-                return
-            except (SocketError, OSError, pmc.WebsocketConnectionError) as e:
-                self.log.exception("Error on attempt %d reconnecting to MQTT", attempts)
-                attempts += 1
-                if attempts > RECONNECT_ATTEMPTS:
-                    raise MQTTReconnectionError("MQTT reconnection failed") from e
+        try:
+            self._client.reconnect()
+            return
+        except (SocketError, OSError, pmc.WebsocketConnectionError) as e:
+            self.log.exception("Error reconnecting to MQTT")
+            raise MQTTReconnectionError("MQTT reconnection failed") from e
 
     def add_event_handler(
         self, evt_type: Type[T], handler: Callable[[T], Awaitable[None]]
@@ -585,7 +576,7 @@ class AndroidMQTT:
         skywalker_subs: set[str] | None = None,
         seq_id: int = None,
         snapshot_at_ms: int = None,
-        retry_limit: int = 5,
+        retry_limit: int = 10,
     ) -> None:
         self._graphql_subs = graphql_subs or set()
         self._skywalker_subs = skywalker_subs or set()
@@ -642,9 +633,11 @@ class AndroidMQTT:
             on_proxy_change=lambda: self._dispatch(ProxyUpdate()),
             max_retries=retry_limit,
             retryable_exceptions=(MQTTNotConnected, MQTTReconnectionError),
-            # Wait 1s * errors, max 10s for fast reconnect or die
-            max_wait_seconds=10,
+            # Wait 1s * errors, max 5s for fast reconnect or die
+            max_wait_seconds=5,
             multiply_wait_seconds=1,
+            # If connection stable for >1h, reset the error counter
+            reset_after_seconds=3600,
         )
 
         if self._event_dispatcher_task:
@@ -668,6 +661,18 @@ class AndroidMQTT:
         if not fut.done():
             fut.set_exception(asyncio.TimeoutError("MQTT request timed out"))
 
+    # The following two functions mutate the client keepalive (cheeky) to temporarily increase
+    # ping attempts during read/write to MQTT. If things are flowing this should change nothing,
+    # as pings only send when idle. It should, however, allow the client to detect a bad MQTT
+    # connection much quicker than the default keepalive.
+    def set_request_keepalive(self):
+        self._client._keepalive = REQUEST_KEEPALIVE
+
+    def maybe_reset_keepalive(self):
+        # Reset the keepalive back to the default value if we have no pending publish/receive
+        if not self._response_waiters and not self._publish_waiters:
+            self._client._keepalive = DEFAULT_KEEPALIVE
+
     def publish(self, topic: RealtimeTopic, payload: str | bytes | dict) -> asyncio.Future:
         if isinstance(payload, dict):
             payload = json.dumps(payload)
@@ -675,13 +680,13 @@ class AndroidMQTT:
             payload = payload.encode("utf-8")
         self.log.trace(f"Publishing message in {topic.value} ({topic.encoded}): {payload}")
         payload = zlib.compress(payload, level=9)
+        self.set_request_keepalive()
         info = self._client.publish(topic.encoded, payload, qos=1)
         self.log.trace(f"Published message ID: {info.mid}")
         fut = self._loop.create_future()
-        timeout_handle = self._loop.call_later(
-            REQUEST_PUBLISH_TIMEOUT, self._publish_cancel_later, fut
-        )
+        timeout_handle = self._loop.call_later(REQUEST_TIMEOUT, self._publish_cancel_later, fut)
         fut.add_done_callback(lambda _: timeout_handle.cancel())
+        fut.add_done_callback(lambda _: self.maybe_reset_keepalive())
         self._publish_waiters[info.mid] = fut
         return fut
 
@@ -695,23 +700,15 @@ class AndroidMQTT:
         async with self._response_waiter_locks[response]:
             fut = self._loop.create_future()
             self._response_waiters[response] = fut
-            try:
-                await self.publish(topic, payload)
-            except asyncio.TimeoutError:
-                self.log.warning("Publish timed out - try forcing reconnect")
-                await self._reconnect()
-            except MQTTNotConnected:
-                self.log.warning(
-                    "MQTT disconnected before PUBACK - wait a hot minute, we should get "
-                    "the response after we auto reconnect"
-                )
+            background_task.create(self.publish(topic, payload))
             self.log.trace(
-                f"Request published to {topic.value}, waiting for response {response.name}"
+                f"Request publish to {topic.value} queued, waiting for response {response.name}"
             )
             timeout_handle = self._loop.call_later(
-                timeout or REQUEST_RESPONSE_TIMEOUT, self._request_cancel_later, fut
+                timeout or REQUEST_TIMEOUT, self._request_cancel_later, fut
             )
             fut.add_done_callback(lambda _: timeout_handle.cancel())
+            fut.add_done_callback(lambda _: self.maybe_reset_keepalive())
             return await fut
 
     async def iris_subscribe(self, seq_id: int, snapshot_at_ms: int) -> None:
@@ -792,29 +789,14 @@ class AndroidMQTT:
                 self.log.warning(f"Waited {lock_wait_dur:.3f} seconds to send {client_context}")
             fut = self._message_response_waiter = asyncio.Future()
             self._message_response_waiter_id = client_context
-            self.log.debug(f"Publishing {action} to {thread_id} with {client_context}")
-            try:
-                await self.publish(RealtimeTopic.SEND_MESSAGE, req)
-            except asyncio.TimeoutError:
-                self.log.warning("Publish timed out - try forcing reconnect")
-                await self._reconnect()
-            except MQTTNotConnected:
-                self.log.warning(
-                    "MQTT disconnected before PUBACK - wait a hot minute, we should get "
-                    "the response after we auto reconnect"
-                )
+            background_task.create(self.publish(RealtimeTopic.SEND_MESSAGE, req))
             self.log.debug(
-                f"Request published to {RealtimeTopic.SEND_MESSAGE}, "
+                f"Request publish to {RealtimeTopic.SEND_MESSAGE} queued, "
                 f"waiting for response {RealtimeTopic.SEND_MESSAGE_RESPONSE}"
             )
-            # If we don't have a response in req timeout / 2, force reconnect
-            reconnect_handle = self._loop.call_later(
-                REQUEST_RESPONSE_TIMEOUT / 2,
-                lambda: self._loop.create_task(self._reconnect()),
-            )
-            fut.add_done_callback(lambda _: reconnect_handle.cancel())
+            fut.add_done_callback(lambda _: self.maybe_reset_keepalive())
             try:
-                resp = await asyncio.wait_for(fut, timeout=REQUEST_RESPONSE_TIMEOUT)
+                resp = await asyncio.wait_for(fut, timeout=REQUEST_TIMEOUT)
             except asyncio.TimeoutError:
                 self.log.error(f"Request with ID {client_context} timed out!")
                 raise