瀏覽代碼

Improve MQTT disconnection handling

Tulir Asokan 4 年之前
父節點
當前提交
992c09b7a9

+ 1 - 1
mauigpapi/errors/__init__.py

@@ -1,5 +1,5 @@
 from .base import IGError
-from .mqtt import IGMQTTError, NotLoggedIn, NotConnected
+from .mqtt import IGMQTTError, MQTTNotLoggedIn, MQTTNotConnected
 from .state import IGUserIDNotFoundError, IGCookieNotFoundError, IGNoCheckpointError
 from .response import (IGResponseError, IGActionSpamError, IGNotFoundError, IGRateLimitError,
                        IGCheckpointError, IGUserHasLoggedOutError, IGLoginRequiredError,

+ 2 - 2
mauigpapi/errors/mqtt.py

@@ -20,9 +20,9 @@ class IGMQTTError(IGError):
     pass
 
 
-class NotLoggedIn(IGMQTTError):
+class MQTTNotLoggedIn(IGMQTTError):
     pass
 
 
-class NotConnected(IGMQTTError):
+class MQTTNotConnected(IGMQTTError):
     pass

+ 4 - 4
mauigpapi/mqtt/conn.py

@@ -31,7 +31,7 @@ from paho.mqtt.client import MQTTMessage, WebsocketConnectionError
 from yarl import URL
 from mautrix.util.logging import TraceLogger
 
-from ..errors import NotLoggedIn, NotConnected
+from ..errors import MQTTNotLoggedIn, MQTTNotConnected
 from ..state import AndroidState
 from ..types import (CommandResponse, ThreadItemType, ThreadAction, ReactionStatus, TypingStatus,
                      IrisPayload, PubsubPayload, AppPresenceEventPayload, RealtimeDirectEvent,
@@ -363,7 +363,7 @@ class AndroidMQTT:
             self._client.reconnect()
         except (SocketError, OSError, WebsocketConnectionError) as e:
             # TODO custom class
-            raise NotLoggedIn("MQTT reconnection failed") from e
+            raise MQTTNotLoggedIn("MQTT reconnection failed") from e
 
     def add_event_handler(self, evt_type: Type[T], handler: Callable[[T], Awaitable[None]]
                           ) -> None:
@@ -414,10 +414,10 @@ class AndroidMQTT:
                     # See https://github.com/eclipse/paho.mqtt.python/issues/340
                     await self._dispatch(Disconnect(reason="Connection lost, retrying"))
                 elif rc == paho.mqtt.client.MQTT_ERR_CONN_REFUSED:
-                    raise NotLoggedIn("MQTT connection refused")
+                    raise MQTTNotLoggedIn("MQTT connection refused")
                 elif rc == paho.mqtt.client.MQTT_ERR_NO_CONN:
                     if exit_if_not_connected:
-                        raise NotConnected("MQTT error: no connection")
+                        raise MQTTNotConnected("MQTT error: no connection")
                     await self._dispatch(Disconnect(reason="MQTT Error: no connection, retrying"))
                 else:
                     err = paho.mqtt.client.error_string(rc)

+ 1 - 1
mautrix_instagram/__main__.py

@@ -73,7 +73,7 @@ class InstagramBridge(Bridge):
         await super().start()
 
     def prepare_stop(self) -> None:
-        self.add_shutdown_actions(user.stop() for user in User.by_igpk.values())
+        self.add_shutdown_actions(user.stop_listen() for user in User.by_igpk.values())
         self.log.debug("Stopping puppet syncers")
         for puppet in Puppet.by_custom_mxid.values():
             puppet.stop()

+ 26 - 1
mautrix_instagram/commands/conn.py

@@ -46,6 +46,12 @@ async def ping(evt: CommandEvent) -> None:
         user = user_info.user
         await evt.reply(f"You're logged in as {user.full_name} ([@{user.username}]"
                         f"(https://instagram.com/{user.username}), user ID: {user.pk})")
+    if evt.sender.is_connected:
+        await evt.reply("MQTT connection is active")
+    elif evt.sender.mqtt and evt.sender._listen_task:
+        await evt.reply("MQTT connection is reconnecting")
+    else:
+        await evt.reply("MQTT not connected")
 
 
 @command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
@@ -55,4 +61,23 @@ async def sync(evt: CommandEvent) -> None:
     await evt.reply("Synchronization complete")
 
 
-# TODO connect/disconnect MQTT commands
+@command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
+                 help_text="Connect to Instagram", aliases=["reconnect"])
+async def connect(evt: CommandEvent) -> None:
+    if evt.sender.is_connected:
+        await evt.reply("You're already connected to Instagram.")
+        return
+    await evt.sender.stop_listen()
+    evt.sender.shutdown = False
+    await evt.sender.start_listen()
+    await evt.reply("Restarted connection to Instagram.")
+
+
+@command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
+                 help_text="Disconnect from Instagram")
+async def disconnect(evt: CommandEvent) -> None:
+    if not evt.sender.mqtt:
+        await evt.reply("You're not connected to Instagram.")
+    await evt.sender.stop_listen()
+    evt.sender.shutdown = False
+    await evt.reply("Successfully disconnected from Instagram.")

+ 1 - 0
mautrix_instagram/config.py

@@ -79,6 +79,7 @@ class Config(BaseBridgeConfig):
         copy("bridge.delivery_receipts")
         copy("bridge.delivery_error_reports")
         copy("bridge.resend_bridge_info")
+        copy("bridge.unimportant_bridge_notices")
 
         copy("bridge.provisioning.enabled")
         copy("bridge.provisioning.prefix")

+ 3 - 0
mautrix_instagram/example-config.yaml

@@ -150,6 +150,9 @@ bridge:
     # This field will automatically be changed back to false after it,
     # except if the config file is not writable.
     resend_bridge_info: false
+    # Whether or not unimportant bridge notices should be sent to the user.
+    # (e.g. connected, disconnected but will retry)
+    unimportant_bridge_notices: true
 
     # Provisioning API part of the web server for automated portal creation and fetching information.
     # Used by things like mautrix-manager (https://github.com/tulir/mautrix-manager).

+ 26 - 9
mautrix_instagram/portal.py

@@ -130,11 +130,13 @@ class Portal(DBPortal, BasePortal):
             except Exception:
                 self.log.exception("Failed to send delivery receipt for %s", event_id)
 
-    async def _send_bridge_error(self, msg: str) -> None:
+    async def _send_bridge_error(self, msg: str, event_type: str = "message",
+                                 confirmed: bool = False) -> None:
         if self.config["bridge.delivery_error_reports"]:
+            error_type = "was not" if confirmed else "may not have been"
             await self._send_message(self.main_intent, TextMessageEventContent(
                 msgtype=MessageType.NOTICE,
-                body=f"\u26a0 Your message may not have been bridged: {msg}"))
+                body=f"\u26a0 Your {event_type} {error_type} bridged: {msg}"))
 
     async def _upsert_reaction(self, existing: DBReaction, intent: IntentAPI, mxid: EventID,
                                message: DBMessage, sender: Union['u.User', 'p.Puppet'],
@@ -155,13 +157,22 @@ class Portal(DBPortal, BasePortal):
 
     async def handle_matrix_message(self, sender: 'u.User', message: MessageEventContent,
                                     event_id: EventID) -> None:
-        if not sender.client:
-            self.log.debug(f"Ignoring message {event_id} as user is not connected")
-            return
-        elif ((message.get(self.bridge.real_user_content_key,
-                           False) and await p.Puppet.get_by_custom_mxid(sender.mxid))):
+        try:
+            await self._handle_matrix_message(sender, message, event_id)
+        except Exception:
+            self.log.exception(f"Fatal error handling Matrix event {event_id}")
+            await self._send_bridge_error("Fatal error in message handling "
+                                          "(see logs for more details)")
+
+    async def _handle_matrix_message(self, sender: 'u.User', message: MessageEventContent,
+                                     event_id: EventID) -> None:
+        if ((message.get(self.bridge.real_user_content_key, False)
+             and await p.Puppet.get_by_custom_mxid(sender.mxid))):
             self.log.debug(f"Ignoring puppet-sent message by confirmed puppet user {sender.mxid}")
             return
+        elif not sender.is_connected:
+            await self._send_bridge_error("You're not connected to Instagram", confirmed=True)
+            return
         request_id = str(uuid4())
         self._reqid_dedup.add(request_id)
         if message.msgtype in (MessageType.EMOTE, MessageType.TEXT):
@@ -194,7 +205,8 @@ class Portal(DBPortal, BasePortal):
                                                      upload_id=upload_resp.upload_id,
                                                      allow_full_aspect_ratio="1")
             else:
-                await self._send_bridge_error("Non-image files are currently not supported")
+                await self._send_bridge_error("Non-image files are currently not supported",
+                                              confirmed=True)
                 return
         else:
             return
@@ -234,6 +246,10 @@ class Portal(DBPortal, BasePortal):
                                       redaction_event_id: EventID) -> None:
         if not self.mxid:
             return
+        elif not sender.is_connected:
+            await self._send_bridge_error("You're not connected to Instagram",
+                                          event_type="redaction", confirmed=True)
+            return
 
         # TODO implement
         reaction = await DBReaction.get_by_mxid(event_id, self.mxid)
@@ -268,7 +284,8 @@ class Portal(DBPortal, BasePortal):
     async def _handle_matrix_typing(self, users: Set[UserID], status: TypingStatus) -> None:
         for mxid in users:
             user = await u.User.get_by_mxid(mxid, create=False)
-            if not user or not await user.is_logged_in() or user.remote_typing_status == status:
+            if ((not user or not await user.is_logged_in() or user.remote_typing_status == status
+                 or not user.is_connected)):
                 continue
             user.remote_typing_status = None
             await user.mqtt.indicate_activity(self.thread_id, status)

+ 53 - 10
mautrix_instagram/user.py

@@ -24,7 +24,7 @@ from mauigpapi import AndroidAPI, AndroidState, AndroidMQTT
 from mauigpapi.mqtt import Connect, Disconnect, GraphQLSubscription, SkywalkerSubscription
 from mauigpapi.types import (CurrentUser, MessageSyncEvent, Operation, RealtimeDirectEvent,
                              ActivityIndicatorData, TypingStatus, ThreadSyncEvent)
-from mauigpapi.errors import IGNotLoggedInError
+from mauigpapi.errors import IGNotLoggedInError, MQTTNotLoggedIn, MQTTNotConnected
 from mautrix.bridge import BaseUser
 from mautrix.types import UserID, RoomID, EventID, TextMessageEventContent, MessageType
 from mautrix.appservice import AppService
@@ -64,6 +64,8 @@ class User(DBUser, BaseUser):
     _notice_room_lock: asyncio.Lock
     _notice_send_lock: asyncio.Lock
     _is_logged_in: bool
+    _is_connected: bool
+    shutdown: bool
     remote_typing_status: Optional[TypingStatus]
 
     def __init__(self, mxid: UserID, igpk: Optional[int] = None,
@@ -81,6 +83,8 @@ class User(DBUser, BaseUser):
         self.dm_update_lock = asyncio.Lock()
         self._metric_value = defaultdict(lambda: False)
         self._is_logged_in = False
+        self._is_connected = False
+        self.shutdown = False
         self._listen_task = None
         self.command_status = None
         self.remote_typing_status = None
@@ -108,6 +112,10 @@ class User(DBUser, BaseUser):
     def api_log(self) -> TraceLogger:
         return self.ig_base_log.getChild("http").getChild(self.mxid)
 
+    @property
+    def is_connected(self) -> bool:
+        return bool(self.client) and bool(self.mqtt) and self._is_connected
+
     async def connect(self) -> None:
         client = AndroidAPI(self.state, log=self.api_log)
 
@@ -116,7 +124,8 @@ class User(DBUser, BaseUser):
         except IGNotLoggedInError as e:
             self.log.warning(f"Failed to connect to Instagram: {e}")
             # TODO show reason?
-            await self.send_bridge_notice("You have been logged out of Instagram")
+            await self.send_bridge_notice("You have been logged out of Instagram",
+                                          important=True)
             return
         self.client = client
         self._is_logged_in = True
@@ -141,10 +150,13 @@ class User(DBUser, BaseUser):
     async def on_connect(self, evt: Connect) -> None:
         self.log.debug("Connected to Instagram")
         self._track_metric(METRIC_CONNECTED, True)
+        self._is_connected = True
+        await self.send_bridge_notice("Connected to Instagram")
 
     async def on_disconnect(self, evt: Disconnect) -> None:
         self.log.debug("Disconnected from Instagram")
         self._track_metric(METRIC_CONNECTED, False)
+        self._is_connected = False
 
     # TODO this stuff could probably be moved to mautrix-python
     async def get_notice_room(self) -> RoomID:
@@ -162,6 +174,9 @@ class User(DBUser, BaseUser):
 
     async def send_bridge_notice(self, text: str, edit: Optional[EventID] = None,
                                  important: bool = False) -> Optional[EventID]:
+        if not important and not self.config["bridge.unimportant_bridge_notices"]:
+            self.log.debug("Not sending unimportant bridge notice: %s", text)
+            return
         event_id = None
         try:
             self.log.debug("Sending bridge notice: %s", text)
@@ -218,18 +233,46 @@ class User(DBUser, BaseUser):
                 self.log.debug(f"{thread.thread_id} is not active and doesn't have a portal")
         await self.update_direct_chats()
 
-        self._listen_task = self.loop.create_task(self.mqtt.listen(
-            graphql_subs={GraphQLSubscription.app_presence(),
-                          GraphQLSubscription.direct_typing(self.state.user_id),
-                          GraphQLSubscription.direct_status()},
-            skywalker_subs={SkywalkerSubscription.direct_sub(self.state.user_id),
-                            SkywalkerSubscription.live_sub(self.state.user_id)},
-            seq_id=resp.seq_id, snapshot_at_ms=resp.snapshot_at_ms))
+        if not self._listen_task:
+            await self.start_listen(resp.seq_id, resp.snapshot_at_ms)
+
+    async def start_listen(self, seq_id: Optional[int] = None, snapshot_at_ms: Optional[int] = None) -> None:
+        if not seq_id:
+            resp = await self.client.get_inbox(limit=1)
+            seq_id, snapshot_at_ms = resp.seq_id, resp.snapshot_at_ms
+        task = self.listen(seq_id=seq_id, snapshot_at_ms=snapshot_at_ms)
+        self._listen_task = self.loop.create_task(task)
 
-    async def stop(self) -> None:
+    async def listen(self, seq_id: int, snapshot_at_ms: int) -> None:
+        try:
+            await self.mqtt.listen(
+                graphql_subs={GraphQLSubscription.app_presence(),
+                              GraphQLSubscription.direct_typing(self.state.user_id),
+                              GraphQLSubscription.direct_status()},
+                skywalker_subs={SkywalkerSubscription.direct_sub(self.state.user_id),
+                                SkywalkerSubscription.live_sub(self.state.user_id)},
+                seq_id=seq_id, snapshot_at_ms=snapshot_at_ms)
+        except Exception:
+            self.log.exception("Fatal error in listener")
+            await self.send_bridge_notice("Fatal error in listener (see logs for more info)",
+                                          important=True)
+            self.mqtt.disconnect()
+            self._is_connected = False
+            self._track_metric(METRIC_CONNECTED, False)
+        else:
+            if not self.shutdown:
+                await self.send_bridge_notice("Instagram connection closed without error")
+        finally:
+            self._listen_task = None
+
+    async def stop_listen(self) -> None:
+        self.shutdown = True
         if self.mqtt:
             self.mqtt.disconnect()
+            if self._listen_task:
+                await self._listen_task
         self._track_metric(METRIC_CONNECTED, False)
+        self._is_connected = False
         await self.update()
 
     async def logout(self) -> None: