فهرست منبع

Recheck auth if MQTT disconnects repeatedly

Tulir Asokan 3 سال پیش
والد
کامیت
0d86dd759d
4فایلهای تغییر یافته به همراه67 افزوده شده و 18 حذف شده
  1. 7 1
      mauigpapi/errors/__init__.py
  2. 5 0
      mauigpapi/errors/mqtt.py
  3. 28 16
      mauigpapi/mqtt/conn.py
  4. 27 1
      mautrix_instagram/user.py

+ 7 - 1
mauigpapi/errors/__init__.py

@@ -1,5 +1,11 @@
 from .base import IGError
-from .mqtt import IGMQTTError, IrisSubscribeError, MQTTNotConnected, MQTTNotLoggedIn
+from .mqtt import (
+    IGMQTTError,
+    IrisSubscribeError,
+    MQTTConnectionUnauthorized,
+    MQTTNotConnected,
+    MQTTNotLoggedIn,
+)
 from .response import (
     IGActionSpamError,
     IGBad2FACodeError,

+ 5 - 0
mauigpapi/errors/mqtt.py

@@ -28,6 +28,11 @@ class MQTTNotConnected(IGMQTTError):
     pass
 
 
+class MQTTConnectionUnauthorized(IGMQTTError):
+    def __init__(self) -> None:
+        super().__init__("Server refused connection with error code 5")
+
+
 class IrisSubscribeError(IGMQTTError):
     def __init__(self, type: str, message: str) -> None:
         super().__init__(f"{type}: {message}")

+ 28 - 16
mauigpapi/mqtt/conn.py

@@ -26,13 +26,17 @@ import time
 import urllib.request
 import zlib
 
-from paho.mqtt.client import MQTTMessage, WebsocketConnectionError
 from yarl import URL
-import paho.mqtt.client
+import paho.mqtt.client as pmc
 
 from mautrix.util.logging import TraceLogger
 
-from ..errors import IrisSubscribeError, MQTTNotConnected, MQTTNotLoggedIn
+from ..errors import (
+    IrisSubscribeError,
+    MQTTConnectionUnauthorized,
+    MQTTNotConnected,
+    MQTTNotLoggedIn,
+)
 from ..state import AndroidState
 from ..types import (
     AppPresenceEventPayload,
@@ -116,7 +120,7 @@ class AndroidMQTT:
         self._client = MQTToTClient(
             client_id=self._form_client_id(),
             clean_session=True,
-            protocol=paho.mqtt.client.MQTTv31,
+            protocol=pmc.MQTTv31,
             transport="tcp",
         )
         try:
@@ -149,7 +153,7 @@ class AndroidMQTT:
         self._client.on_message = self._on_message_handler
         self._client.on_publish = self._on_publish_handler
         self._client.on_connect = self._on_connect_handler
-        # self._client.on_disconnect = self._on_disconnect_handler
+        self._client.on_disconnect = self._on_disconnect_handler
         self._client.on_socket_open = self._on_socket_open
         self._client.on_socket_close = self._on_socket_close
         self._client.on_socket_register_write = self._on_socket_register_write
@@ -223,12 +227,20 @@ class AndroidMQTT:
         self, client: MQTToTClient, _: Any, flags: dict[str, Any], rc: int
     ) -> None:
         if rc != 0:
-            err = paho.mqtt.client.connack_string(rc)
+            err = pmc.connack_string(rc)
             self.log.error("MQTT Connection Error: %s (%d)", err, rc)
+            if rc == pmc.CONNACK_REFUSED_NOT_AUTHORIZED:
+                self._disconnect_error = MQTTConnectionUnauthorized()
+                self.disconnect()
             return
 
         self._loop.create_task(self._post_connect())
 
+    def _on_disconnect_handler(self, client: MQTToTClient, _: Any, rc: int) -> None:
+        err_str = "Generic error." if rc == pmc.MQTT_ERR_NOMEM else pmc.error_string(rc)
+        self.log.debug(f"MQTT disconnection code %d: %s", rc, err_str)
+        # self._clear_response_waiters()
+
     async def _post_connect(self) -> None:
         await self._dispatch(Connect())
         self.log.debug("Re-subscribing to things after connect")
@@ -402,7 +414,7 @@ class AndroidMQTT:
         for evt in self._parse_realtime_sub_item(topic, parsed_json):
             self._loop.create_task(self._dispatch(evt))
 
-    def _on_message_handler(self, client: MQTToTClient, _: Any, message: MQTTMessage) -> None:
+    def _on_message_handler(self, client: MQTToTClient, _: Any, message: pmc.MQTTMessage) -> None:
         try:
             topic = RealtimeTopic.decode(message.topic)
             # Instagram Android MQTT messages are always compressed
@@ -445,7 +457,7 @@ class AndroidMQTT:
         try:
             self.log.trace("Trying to reconnect to MQTT")
             self._client.reconnect()
-        except (SocketError, OSError, WebsocketConnectionError) as e:
+        except (SocketError, OSError, pmc.WebsocketConnectionError) as e:
             # TODO custom class
             raise MQTTNotLoggedIn("MQTT reconnection failed") from e
 
@@ -513,20 +525,20 @@ class AndroidMQTT:
 
             # If disconnect() has been called
             # Beware, internal API, may have to change this to something more stable!
-            if self._client._state == paho.mqtt.client.mqtt_cs_disconnecting:
+            if self._client._state == pmc.mqtt_cs_disconnecting:
                 break  # Stop listening
 
-            if rc != paho.mqtt.client.MQTT_ERR_SUCCESS:
+            if rc != pmc.MQTT_ERR_SUCCESS:
                 # If known/expected error
-                if rc == paho.mqtt.client.MQTT_ERR_CONN_LOST:
+                if rc == pmc.MQTT_ERR_CONN_LOST:
                     await self._dispatch(Disconnect(reason="Connection lost, retrying"))
-                elif rc == paho.mqtt.client.MQTT_ERR_NOMEM:
+                elif rc == pmc.MQTT_ERR_NOMEM:
                     # This error is wrongly classified
                     # See https://github.com/eclipse/paho.mqtt.python/issues/340
                     await self._dispatch(Disconnect(reason="Connection lost, retrying"))
-                elif rc == paho.mqtt.client.MQTT_ERR_CONN_REFUSED:
+                elif rc == pmc.MQTT_ERR_CONN_REFUSED:
                     raise MQTTNotLoggedIn("MQTT connection refused")
-                elif rc == paho.mqtt.client.MQTT_ERR_NO_CONN:
+                elif rc == pmc.MQTT_ERR_NO_CONN:
                     if connection_retries > retry_limit:
                         raise MQTTNotConnected(f"Connection failed {connection_retries} times")
                     sleep = connection_retries * 2
@@ -538,7 +550,7 @@ class AndroidMQTT:
                     )
                     await asyncio.sleep(sleep)
                 else:
-                    err = paho.mqtt.client.error_string(rc)
+                    err = pmc.error_string(rc)
                     self.log.error("MQTT Error: %s", err)
                     await self._dispatch(Disconnect(reason=f"MQTT Error: {err}, retrying"))
 
@@ -576,7 +588,7 @@ class AndroidMQTT:
         response: RealtimeTopic,
         payload: str | bytes | dict,
         timeout: int | None = None,
-    ) -> MQTTMessage:
+    ) -> pmc.MQTTMessage:
         async with self._response_waiter_locks[response]:
             fut = asyncio.Future()
             self._response_waiters[response] = fut

+ 27 - 1
mautrix_instagram/user.py

@@ -29,6 +29,7 @@ from mauigpapi.errors import (
     IGRateLimitError,
     IGUserIDNotFoundError,
     IrisSubscribeError,
+    MQTTConnectionUnauthorized,
     MQTTNotConnected,
     MQTTNotLoggedIn,
 )
@@ -509,6 +510,29 @@ class User(DBUser, BaseUser):
         )
         self._listen_task = self.loop.create_task(task)
 
+    async def fetch_user_and_reconnect(self) -> None:
+        self.log.debug("Refetching current user after disconnection")
+        try:
+            resp = await self.client.current_user()
+        except IGNotLoggedInError as e:
+            self.log.warning(f"Failed to reconnect to Instagram: {e}, logging out")
+            await self.logout(error=e)
+            return
+        except (IGChallengeError, IGConsentRequiredError) as e:
+            await self._handle_checkpoint(e, on="reconnect")
+            return
+        except Exception as e:
+            self.log.exception("Error while reconnecting to Instagram")
+            if isinstance(e, IGCheckpointError):
+                self.log.debug("Checkpoint error content: %s", e.body)
+            await self.push_bridge_state(
+                BridgeStateEvent.UNKNOWN_ERROR, info={"python_error": str(e)}
+            )
+            return
+        else:
+            self.log.debug(f"Confirmed current user {resp.user.pk}")
+            self.start_listen()
+
     async def _listen(self, seq_id: int, snapshot_at_ms: int, is_after_sync: bool) -> None:
         try:
             await self.mqtt.listen(
@@ -538,7 +562,8 @@ class User(DBUser, BaseUser):
             else:
                 self.log.warning(f"Got IrisSubscribeError {e}, refreshing...")
                 asyncio.create_task(self.refresh())
-        except (MQTTNotConnected, MQTTNotLoggedIn) as e:
+        except (MQTTNotConnected, MQTTNotLoggedIn, MQTTConnectionUnauthorized) as e:
+            self.log.warning(f"Unexpected connection error: {e}")
             await self.send_bridge_notice(
                 f"Error in listener: {e}",
                 important=True,
@@ -546,6 +571,7 @@ class User(DBUser, BaseUser):
                 error_code="ig-connection-error",
             )
             self.mqtt.disconnect()
+            asyncio.create_task(self.fetch_user_and_reconnect())
         except Exception as e:
             self.log.exception("Fatal error in listener")
             await self.send_bridge_notice(