Переглянути джерело

Retry iris subscribe if it doesn't respond or fails

Tulir Asokan 4 роки тому
батько
коміт
7dd2b1ad6a
3 змінених файлів з 29 додано та 7 видалено
  1. 1 1
      mauigpapi/errors/__init__.py
  2. 5 0
      mauigpapi/errors/mqtt.py
  3. 23 6
      mauigpapi/mqtt/conn.py

+ 1 - 1
mauigpapi/errors/__init__.py

@@ -1,5 +1,5 @@
 from .base import IGError
 from .base import IGError
-from .mqtt import IGMQTTError, MQTTNotLoggedIn, MQTTNotConnected
+from .mqtt import IGMQTTError, MQTTNotLoggedIn, MQTTNotConnected, IrisSubscribeError
 from .state import IGUserIDNotFoundError, IGCookieNotFoundError, IGNoCheckpointError
 from .state import IGUserIDNotFoundError, IGCookieNotFoundError, IGNoCheckpointError
 from .response import (IGResponseError, IGActionSpamError, IGNotFoundError, IGRateLimitError,
 from .response import (IGResponseError, IGActionSpamError, IGNotFoundError, IGRateLimitError,
                        IGCheckpointError, IGUserHasLoggedOutError, IGLoginRequiredError,
                        IGCheckpointError, IGUserHasLoggedOutError, IGLoginRequiredError,

+ 5 - 0
mauigpapi/errors/mqtt.py

@@ -26,3 +26,8 @@ class MQTTNotLoggedIn(IGMQTTError):
 
 
 class MQTTNotConnected(IGMQTTError):
 class MQTTNotConnected(IGMQTTError):
     pass
     pass
+
+
+class IrisSubscribeError(IGMQTTError):
+    def __init__(self, type: str, message: str) -> None:
+        super().__init__(f"{type}: {message}")

+ 23 - 6
mauigpapi/mqtt/conn.py

@@ -31,7 +31,7 @@ from paho.mqtt.client import MQTTMessage, WebsocketConnectionError
 from yarl import URL
 from yarl import URL
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from ..errors import MQTTNotLoggedIn, MQTTNotConnected
+from ..errors import MQTTNotLoggedIn, MQTTNotConnected, IrisSubscribeError
 from ..state import AndroidState
 from ..state import AndroidState
 from ..types import (CommandResponse, ThreadItemType, ThreadAction, ReactionStatus, TypingStatus,
 from ..types import (CommandResponse, ThreadItemType, ThreadAction, ReactionStatus, TypingStatus,
                      IrisPayload, PubsubPayload, AppPresenceEventPayload, RealtimeDirectEvent,
                      IrisPayload, PubsubPayload, AppPresenceEventPayload, RealtimeDirectEvent,
@@ -205,7 +205,20 @@ class AndroidMQTT:
             res = await self.skywalker_subscribe(self._skywalker_subs)
             res = await self.skywalker_subscribe(self._skywalker_subs)
             self.log.trace("Skywalker subscribe response: %s", res)
             self.log.trace("Skywalker subscribe response: %s", res)
         if self._iris_seq_id:
         if self._iris_seq_id:
-            await self.iris_subscribe(self._iris_seq_id, self._iris_snapshot_at_ms)
+            retry = 0
+            while True:
+                try:
+                    await self.iris_subscribe(self._iris_seq_id, self._iris_snapshot_at_ms)
+                    break
+                except (asyncio.TimeoutError, IrisSubscribeError) as e:
+                    self.log.exception("Error requesting iris subscribe")
+                    retry += 1
+                    if retry >= 5:
+                        self._disconnect_error = e
+                        self.disconnect()
+                        break
+                    await asyncio.sleep(5)
+                    self.log.debug("Retrying iris subscribe")
 
 
     def _on_publish_handler(self, client: MQTToTClient, _: Any, mid: int) -> None:
     def _on_publish_handler(self, client: MQTToTClient, _: Any, mid: int) -> None:
         try:
         try:
@@ -474,21 +487,25 @@ class AndroidMQTT:
         return fut
         return fut
 
 
     async def request(self, topic: RealtimeTopic, response: RealtimeTopic,
     async def request(self, topic: RealtimeTopic, response: RealtimeTopic,
-                      payload: Union[str, bytes, dict]) -> MQTTMessage:
+                      payload: Union[str, bytes, dict], timeout: Optional[int] = None
+                      ) -> MQTTMessage:
         async with self._response_waiter_locks[response]:
         async with self._response_waiter_locks[response]:
             fut = asyncio.Future()
             fut = asyncio.Future()
             self._response_waiters[response] = fut
             self._response_waiters[response] = fut
             await self.publish(topic, payload)
             await self.publish(topic, payload)
             self.log.trace(f"Request published to {topic.value}, "
             self.log.trace(f"Request published to {topic.value}, "
                            f"waiting for response {response.name}")
                            f"waiting for response {response.name}")
-            return await fut
+            return await asyncio.wait_for(fut, timeout)
 
 
     async def iris_subscribe(self, seq_id: int, snapshot_at_ms: int) -> None:
     async def iris_subscribe(self, seq_id: int, snapshot_at_ms: int) -> None:
         self.log.debug(f"Requesting iris subscribe {seq_id}/{snapshot_at_ms}")
         self.log.debug(f"Requesting iris subscribe {seq_id}/{snapshot_at_ms}")
         resp = await self.request(RealtimeTopic.SUB_IRIS, RealtimeTopic.SUB_IRIS_RESPONSE,
         resp = await self.request(RealtimeTopic.SUB_IRIS, RealtimeTopic.SUB_IRIS_RESPONSE,
-                                  {"seq_id": seq_id, "snapshot_at_ms": snapshot_at_ms})
-        # TODO check succeeded and raise error if needed
+                                  {"seq_id": seq_id, "snapshot_at_ms": snapshot_at_ms},
+                                  timeout=20 * 1000)
         self.log.debug("Iris subscribe response: %s", resp.payload.decode("utf-8"))
         self.log.debug("Iris subscribe response: %s", resp.payload.decode("utf-8"))
+        resp_dict = json.loads(resp.payload.decode("utf-8"))
+        if resp_dict["error_type"] and resp_dict["error_message"]:
+            raise IrisSubscribeError(resp_dict["error_type"], resp_dict["error_message"])
 
 
     def graphql_subscribe(self, subs: Set[str]) -> asyncio.Future:
     def graphql_subscribe(self, subs: Set[str]) -> asyncio.Future:
         self._graphql_subs |= subs
         self._graphql_subs |= subs