Răsfoiți Sursa

Fix sending messages across MQTT reset (#97)

* Fix up publish/response waiters with timeouts like FB

* Maintain message_response_waiter across MQTT resets

This parallels maintaining the response_waiters across MQTT resets.

Currently after MQTT reconnects the next waiting message sends
immediately, then gets it's message_response, then another
message_response comes in for the command sent immediately before the
MQTT reset. This change should block after MQTT reconnects until the
message_response from the last command comes in, then it can go about
it's way.
Scott Weber 2 ani în urmă
părinte
comite
2d3ca01565
1 a modificat fișierele cu 53 adăugiri și 24 ștergeri
  1. 53 24
      mauigpapi/mqtt/conn.py

+ 53 - 24
mauigpapi/mqtt/conn.py

@@ -80,6 +80,8 @@ ACTIVITY_INDICATOR_REGEX = re.compile(
 
 INBOX_THREAD_REGEX = re.compile(r"/direct_v2/inbox/threads/([\w_]+)")
 
+REQUEST_TIMEOUT = 30
+
 
 class AndroidMQTT:
     _loop: asyncio.AbstractEventLoop
@@ -172,24 +174,10 @@ class AndroidMQTT:
                     proxy_password=proxy_url.password,
                 )
 
-    def _clear_response_waiters(self) -> None:
-        for waiter in self._response_waiters.values():
-            if not waiter.done():
-                waiter.set_exception(
-                    MQTTNotConnected("MQTT disconnected before request returned response")
-                )
+    def _clear_publish_waiters(self) -> None:
         for waiter in self._publish_waiters.values():
             if not waiter.done():
-                waiter.set_exception(
-                    MQTTNotConnected("MQTT disconnected before request was published")
-                )
-        if self._message_response_waiter and not self._message_response_waiter.done():
-            self._message_response_waiter.set_exception(
-                MQTTNotConnected("MQTT disconnected before message send returned response")
-            )
-            self._message_response_waiter = None
-            self._message_response_waiter_id = None
-        self._response_waiters = {}
+                waiter.set_exception(MQTTNotConnected("MQTT disconnected before PUBACK received"))
         self._publish_waiters = {}
 
     def _form_client_id(self) -> bytes:
@@ -275,7 +263,7 @@ class AndroidMQTT:
     def _on_disconnect_handler(self, client: MQTToTClient, _: Any, rc: int) -> None:
         err_str = "Generic error." if rc == pmc.MQTT_ERR_NOMEM else pmc.error_string(rc)
         self.log.debug(f"MQTT disconnection code %d: %s", rc, err_str)
-        self._clear_response_waiters()
+        self._clear_publish_waiters()
 
     async def _post_connect(self) -> None:
         await self._dispatch(Connect())
@@ -516,8 +504,15 @@ class AndroidMQTT:
                         "No handler for MQTT message in %s: %s", topic.value, message.payload
                     )
                 else:
-                    self.log.trace("Got response %s: %s", topic.value, message.payload)
-                    waiter.set_result(message)
+                    if not waiter.done():
+                        waiter.set_result(message)
+                        self.log.trace("Got response %s: %s", topic.value, message.payload)
+                    else:
+                        self.log.debug(
+                            "Got response in %s, but waiter was already cancelled: %s",
+                            topic,
+                            message.payload,
+                        )
         except Exception:
             self.log.exception("Error in incoming MQTT message handler")
             self.log.trace("Errored MQTT payload: %s", message.payload)
@@ -649,6 +644,16 @@ class AndroidMQTT:
 
     # region Basic outgoing MQTT
 
+    @staticmethod
+    def _publish_cancel_later(fut: asyncio.Future) -> None:
+        if not fut.done():
+            fut.set_exception(asyncio.TimeoutError("MQTT publish timed out"))
+
+    @staticmethod
+    def _request_cancel_later(fut: asyncio.Future) -> None:
+        if not fut.done():
+            fut.set_exception(asyncio.TimeoutError("MQTT request timed out"))
+
     def publish(self, topic: RealtimeTopic, payload: str | bytes | dict) -> asyncio.Future:
         if isinstance(payload, dict):
             payload = json.dumps(payload)
@@ -658,7 +663,9 @@ class AndroidMQTT:
         payload = zlib.compress(payload, level=9)
         info = self._client.publish(topic.encoded, payload, qos=1)
         self.log.trace(f"Published message ID: {info.mid}")
-        fut = asyncio.Future()
+        fut = self._loop.create_future()
+        timeout_handle = self._loop.call_later(REQUEST_TIMEOUT, self._publish_cancel_later, fut)
+        fut.add_done_callback(lambda _: timeout_handle.cancel())
         self._publish_waiters[info.mid] = fut
         return fut
 
@@ -670,13 +677,26 @@ class AndroidMQTT:
         timeout: int | None = None,
     ) -> pmc.MQTTMessage:
         async with self._response_waiter_locks[response]:
-            fut = asyncio.Future()
+            fut = self._loop.create_future()
             self._response_waiters[response] = fut
-            await self.publish(topic, payload)
+            try:
+                await self.publish(topic, payload)
+            except asyncio.TimeoutError:
+                self.log.warning("Publish timed out - try forcing reconnect")
+                self._client.reconnect()
+            except MQTTNotConnected:
+                self.log.warning(
+                    "MQTT disconnected before PUBACK - wait a hot minute, we should get "
+                    "the response after we auto reconnect"
+                )
             self.log.trace(
                 f"Request published to {topic.value}, waiting for response {response.name}"
             )
-            return await asyncio.wait_for(fut, timeout)
+            timeout_handle = self._loop.call_later(
+                timeout or REQUEST_TIMEOUT, self._request_cancel_later, fut
+            )
+            fut.add_done_callback(lambda _: timeout_handle.cancel())
+            return await fut
 
     async def iris_subscribe(self, seq_id: int, snapshot_at_ms: int) -> None:
         self.log.debug(f"Requesting iris subscribe {seq_id}/{snapshot_at_ms}")
@@ -757,7 +777,16 @@ class AndroidMQTT:
             fut = self._message_response_waiter = asyncio.Future()
             self._message_response_waiter_id = client_context
             self.log.debug(f"Publishing {action} to {thread_id} with {client_context}")
-            await self.publish(RealtimeTopic.SEND_MESSAGE, req)
+            try:
+                await self.publish(RealtimeTopic.SEND_MESSAGE, req)
+            except asyncio.TimeoutError:
+                self.log.warning("Publish timed out - try forcing reconnect")
+                self._client.reconnect()
+            except MQTTNotConnected:
+                self.log.warning(
+                    "MQTT disconnected before PUBACK - wait a hot minute, we should get "
+                    "the response after we auto reconnect"
+                )
             self.log.trace(
                 f"Request published to {RealtimeTopic.SEND_MESSAGE}, "
                 f"waiting for response {RealtimeTopic.SEND_MESSAGE_RESPONSE}"