Преглед на файлове

Retry iris subscribe if it doesn't respond or fails

Tulir Asokan преди 4 години
родител
ревизия
7dd2b1ad6a
променени са 3 файла, в които са добавени 29 реда и са изтрити 7 реда
  1. 1 1
      mauigpapi/errors/__init__.py
  2. 5 0
      mauigpapi/errors/mqtt.py
  3. 23 6
      mauigpapi/mqtt/conn.py

+ 1 - 1
mauigpapi/errors/__init__.py

@@ -1,5 +1,5 @@
 from .base import IGError
-from .mqtt import IGMQTTError, MQTTNotLoggedIn, MQTTNotConnected
+from .mqtt import IGMQTTError, MQTTNotLoggedIn, MQTTNotConnected, IrisSubscribeError
 from .state import IGUserIDNotFoundError, IGCookieNotFoundError, IGNoCheckpointError
 from .response import (IGResponseError, IGActionSpamError, IGNotFoundError, IGRateLimitError,
                        IGCheckpointError, IGUserHasLoggedOutError, IGLoginRequiredError,

+ 5 - 0
mauigpapi/errors/mqtt.py

@@ -26,3 +26,8 @@ class MQTTNotLoggedIn(IGMQTTError):
 
 class MQTTNotConnected(IGMQTTError):
     pass
+
+
+class IrisSubscribeError(IGMQTTError):
+    def __init__(self, type: str, message: str) -> None:
+        super().__init__(f"{type}: {message}")

+ 23 - 6
mauigpapi/mqtt/conn.py

@@ -31,7 +31,7 @@ from paho.mqtt.client import MQTTMessage, WebsocketConnectionError
 from yarl import URL
 from mautrix.util.logging import TraceLogger
 
-from ..errors import MQTTNotLoggedIn, MQTTNotConnected
+from ..errors import MQTTNotLoggedIn, MQTTNotConnected, IrisSubscribeError
 from ..state import AndroidState
 from ..types import (CommandResponse, ThreadItemType, ThreadAction, ReactionStatus, TypingStatus,
                      IrisPayload, PubsubPayload, AppPresenceEventPayload, RealtimeDirectEvent,
@@ -205,7 +205,20 @@ class AndroidMQTT:
             res = await self.skywalker_subscribe(self._skywalker_subs)
             self.log.trace("Skywalker subscribe response: %s", res)
         if self._iris_seq_id:
-            await self.iris_subscribe(self._iris_seq_id, self._iris_snapshot_at_ms)
+            retry = 0
+            while True:
+                try:
+                    await self.iris_subscribe(self._iris_seq_id, self._iris_snapshot_at_ms)
+                    break
+                except (asyncio.TimeoutError, IrisSubscribeError) as e:
+                    self.log.exception("Error requesting iris subscribe")
+                    retry += 1
+                    if retry >= 5:
+                        self._disconnect_error = e
+                        self.disconnect()
+                        break
+                    await asyncio.sleep(5)
+                    self.log.debug("Retrying iris subscribe")
 
     def _on_publish_handler(self, client: MQTToTClient, _: Any, mid: int) -> None:
         try:
@@ -474,21 +487,25 @@ class AndroidMQTT:
         return fut
 
     async def request(self, topic: RealtimeTopic, response: RealtimeTopic,
-                      payload: Union[str, bytes, dict]) -> MQTTMessage:
+                      payload: Union[str, bytes, dict], timeout: Optional[int] = None
+                      ) -> MQTTMessage:
         async with self._response_waiter_locks[response]:
             fut = asyncio.Future()
             self._response_waiters[response] = fut
             await self.publish(topic, payload)
             self.log.trace(f"Request published to {topic.value}, "
                            f"waiting for response {response.name}")
-            return await fut
+            return await asyncio.wait_for(fut, timeout)
 
     async def iris_subscribe(self, seq_id: int, snapshot_at_ms: int) -> None:
         self.log.debug(f"Requesting iris subscribe {seq_id}/{snapshot_at_ms}")
         resp = await self.request(RealtimeTopic.SUB_IRIS, RealtimeTopic.SUB_IRIS_RESPONSE,
-                                  {"seq_id": seq_id, "snapshot_at_ms": snapshot_at_ms})
-        # TODO check succeeded and raise error if needed
+                                  {"seq_id": seq_id, "snapshot_at_ms": snapshot_at_ms},
+                                  timeout=20 * 1000)
         self.log.debug("Iris subscribe response: %s", resp.payload.decode("utf-8"))
+        resp_dict = json.loads(resp.payload.decode("utf-8"))
+        if resp_dict["error_type"] and resp_dict["error_message"]:
+            raise IrisSubscribeError(resp_dict["error_type"], resp_dict["error_message"])
 
     def graphql_subscribe(self, subs: Set[str]) -> asyncio.Future:
         self._graphql_subs |= subs