Преглед изворни кода

Improve MQTT disconnection handling

Tulir Asokan пре 4 година
родитељ
комит
992c09b7a9

+ 1 - 1
mauigpapi/errors/__init__.py

@@ -1,5 +1,5 @@
 from .base import IGError
 from .base import IGError
-from .mqtt import IGMQTTError, NotLoggedIn, NotConnected
+from .mqtt import IGMQTTError, MQTTNotLoggedIn, MQTTNotConnected
 from .state import IGUserIDNotFoundError, IGCookieNotFoundError, IGNoCheckpointError
 from .state import IGUserIDNotFoundError, IGCookieNotFoundError, IGNoCheckpointError
 from .response import (IGResponseError, IGActionSpamError, IGNotFoundError, IGRateLimitError,
 from .response import (IGResponseError, IGActionSpamError, IGNotFoundError, IGRateLimitError,
                        IGCheckpointError, IGUserHasLoggedOutError, IGLoginRequiredError,
                        IGCheckpointError, IGUserHasLoggedOutError, IGLoginRequiredError,

+ 2 - 2
mauigpapi/errors/mqtt.py

@@ -20,9 +20,9 @@ class IGMQTTError(IGError):
     pass
     pass
 
 
 
 
-class NotLoggedIn(IGMQTTError):
+class MQTTNotLoggedIn(IGMQTTError):
     pass
     pass
 
 
 
 
-class NotConnected(IGMQTTError):
+class MQTTNotConnected(IGMQTTError):
     pass
     pass

+ 4 - 4
mauigpapi/mqtt/conn.py

@@ -31,7 +31,7 @@ from paho.mqtt.client import MQTTMessage, WebsocketConnectionError
 from yarl import URL
 from yarl import URL
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from ..errors import NotLoggedIn, NotConnected
+from ..errors import MQTTNotLoggedIn, MQTTNotConnected
 from ..state import AndroidState
 from ..state import AndroidState
 from ..types import (CommandResponse, ThreadItemType, ThreadAction, ReactionStatus, TypingStatus,
 from ..types import (CommandResponse, ThreadItemType, ThreadAction, ReactionStatus, TypingStatus,
                      IrisPayload, PubsubPayload, AppPresenceEventPayload, RealtimeDirectEvent,
                      IrisPayload, PubsubPayload, AppPresenceEventPayload, RealtimeDirectEvent,
@@ -363,7 +363,7 @@ class AndroidMQTT:
             self._client.reconnect()
             self._client.reconnect()
         except (SocketError, OSError, WebsocketConnectionError) as e:
         except (SocketError, OSError, WebsocketConnectionError) as e:
             # TODO custom class
             # TODO custom class
-            raise NotLoggedIn("MQTT reconnection failed") from e
+            raise MQTTNotLoggedIn("MQTT reconnection failed") from e
 
 
     def add_event_handler(self, evt_type: Type[T], handler: Callable[[T], Awaitable[None]]
     def add_event_handler(self, evt_type: Type[T], handler: Callable[[T], Awaitable[None]]
                           ) -> None:
                           ) -> None:
@@ -414,10 +414,10 @@ class AndroidMQTT:
                     # See https://github.com/eclipse/paho.mqtt.python/issues/340
                     # See https://github.com/eclipse/paho.mqtt.python/issues/340
                     await self._dispatch(Disconnect(reason="Connection lost, retrying"))
                     await self._dispatch(Disconnect(reason="Connection lost, retrying"))
                 elif rc == paho.mqtt.client.MQTT_ERR_CONN_REFUSED:
                 elif rc == paho.mqtt.client.MQTT_ERR_CONN_REFUSED:
-                    raise NotLoggedIn("MQTT connection refused")
+                    raise MQTTNotLoggedIn("MQTT connection refused")
                 elif rc == paho.mqtt.client.MQTT_ERR_NO_CONN:
                 elif rc == paho.mqtt.client.MQTT_ERR_NO_CONN:
                     if exit_if_not_connected:
                     if exit_if_not_connected:
-                        raise NotConnected("MQTT error: no connection")
+                        raise MQTTNotConnected("MQTT error: no connection")
                     await self._dispatch(Disconnect(reason="MQTT Error: no connection, retrying"))
                     await self._dispatch(Disconnect(reason="MQTT Error: no connection, retrying"))
                 else:
                 else:
                     err = paho.mqtt.client.error_string(rc)
                     err = paho.mqtt.client.error_string(rc)

+ 1 - 1
mautrix_instagram/__main__.py

@@ -73,7 +73,7 @@ class InstagramBridge(Bridge):
         await super().start()
         await super().start()
 
 
     def prepare_stop(self) -> None:
     def prepare_stop(self) -> None:
-        self.add_shutdown_actions(user.stop() for user in User.by_igpk.values())
+        self.add_shutdown_actions(user.stop_listen() for user in User.by_igpk.values())
         self.log.debug("Stopping puppet syncers")
         self.log.debug("Stopping puppet syncers")
         for puppet in Puppet.by_custom_mxid.values():
         for puppet in Puppet.by_custom_mxid.values():
             puppet.stop()
             puppet.stop()

+ 26 - 1
mautrix_instagram/commands/conn.py

@@ -46,6 +46,12 @@ async def ping(evt: CommandEvent) -> None:
         user = user_info.user
         user = user_info.user
         await evt.reply(f"You're logged in as {user.full_name} ([@{user.username}]"
         await evt.reply(f"You're logged in as {user.full_name} ([@{user.username}]"
                         f"(https://instagram.com/{user.username}), user ID: {user.pk})")
                         f"(https://instagram.com/{user.username}), user ID: {user.pk})")
+    if evt.sender.is_connected:
+        await evt.reply("MQTT connection is active")
+    elif evt.sender.mqtt and evt.sender._listen_task:
+        await evt.reply("MQTT connection is reconnecting")
+    else:
+        await evt.reply("MQTT not connected")
 
 
 
 
 @command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
 @command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
@@ -55,4 +61,23 @@ async def sync(evt: CommandEvent) -> None:
     await evt.reply("Synchronization complete")
     await evt.reply("Synchronization complete")
 
 
 
 
-# TODO connect/disconnect MQTT commands
+@command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
+                 help_text="Connect to Instagram", aliases=["reconnect"])
+async def connect(evt: CommandEvent) -> None:
+    if evt.sender.is_connected:
+        await evt.reply("You're already connected to Instagram.")
+        return
+    await evt.sender.stop_listen()
+    evt.sender.shutdown = False
+    await evt.sender.start_listen()
+    await evt.reply("Restarted connection to Instagram.")
+
+
+@command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
+                 help_text="Disconnect from Instagram")
+async def disconnect(evt: CommandEvent) -> None:
+    if not evt.sender.mqtt:
+        await evt.reply("You're not connected to Instagram.")
+    await evt.sender.stop_listen()
+    evt.sender.shutdown = False
+    await evt.reply("Successfully disconnected from Instagram.")

+ 1 - 0
mautrix_instagram/config.py

@@ -79,6 +79,7 @@ class Config(BaseBridgeConfig):
         copy("bridge.delivery_receipts")
         copy("bridge.delivery_receipts")
         copy("bridge.delivery_error_reports")
         copy("bridge.delivery_error_reports")
         copy("bridge.resend_bridge_info")
         copy("bridge.resend_bridge_info")
+        copy("bridge.unimportant_bridge_notices")
 
 
         copy("bridge.provisioning.enabled")
         copy("bridge.provisioning.enabled")
         copy("bridge.provisioning.prefix")
         copy("bridge.provisioning.prefix")

+ 3 - 0
mautrix_instagram/example-config.yaml

@@ -150,6 +150,9 @@ bridge:
     # This field will automatically be changed back to false after it,
     # This field will automatically be changed back to false after it,
     # except if the config file is not writable.
     # except if the config file is not writable.
     resend_bridge_info: false
     resend_bridge_info: false
+    # Whether or not unimportant bridge notices should be sent to the user.
+    # (e.g. connected, disconnected but will retry)
+    unimportant_bridge_notices: true
 
 
     # Provisioning API part of the web server for automated portal creation and fetching information.
     # Provisioning API part of the web server for automated portal creation and fetching information.
     # Used by things like mautrix-manager (https://github.com/tulir/mautrix-manager).
     # Used by things like mautrix-manager (https://github.com/tulir/mautrix-manager).

+ 26 - 9
mautrix_instagram/portal.py

@@ -130,11 +130,13 @@ class Portal(DBPortal, BasePortal):
             except Exception:
             except Exception:
                 self.log.exception("Failed to send delivery receipt for %s", event_id)
                 self.log.exception("Failed to send delivery receipt for %s", event_id)
 
 
-    async def _send_bridge_error(self, msg: str) -> None:
+    async def _send_bridge_error(self, msg: str, event_type: str = "message",
+                                 confirmed: bool = False) -> None:
         if self.config["bridge.delivery_error_reports"]:
         if self.config["bridge.delivery_error_reports"]:
+            error_type = "was not" if confirmed else "may not have been"
             await self._send_message(self.main_intent, TextMessageEventContent(
             await self._send_message(self.main_intent, TextMessageEventContent(
                 msgtype=MessageType.NOTICE,
                 msgtype=MessageType.NOTICE,
-                body=f"\u26a0 Your message may not have been bridged: {msg}"))
+                body=f"\u26a0 Your {event_type} {error_type} bridged: {msg}"))
 
 
     async def _upsert_reaction(self, existing: DBReaction, intent: IntentAPI, mxid: EventID,
     async def _upsert_reaction(self, existing: DBReaction, intent: IntentAPI, mxid: EventID,
                                message: DBMessage, sender: Union['u.User', 'p.Puppet'],
                                message: DBMessage, sender: Union['u.User', 'p.Puppet'],
@@ -155,13 +157,22 @@ class Portal(DBPortal, BasePortal):
 
 
     async def handle_matrix_message(self, sender: 'u.User', message: MessageEventContent,
     async def handle_matrix_message(self, sender: 'u.User', message: MessageEventContent,
                                     event_id: EventID) -> None:
                                     event_id: EventID) -> None:
-        if not sender.client:
-            self.log.debug(f"Ignoring message {event_id} as user is not connected")
-            return
-        elif ((message.get(self.bridge.real_user_content_key,
-                           False) and await p.Puppet.get_by_custom_mxid(sender.mxid))):
+        try:
+            await self._handle_matrix_message(sender, message, event_id)
+        except Exception:
+            self.log.exception(f"Fatal error handling Matrix event {event_id}")
+            await self._send_bridge_error("Fatal error in message handling "
+                                          "(see logs for more details)")
+
+    async def _handle_matrix_message(self, sender: 'u.User', message: MessageEventContent,
+                                     event_id: EventID) -> None:
+        if ((message.get(self.bridge.real_user_content_key, False)
+             and await p.Puppet.get_by_custom_mxid(sender.mxid))):
             self.log.debug(f"Ignoring puppet-sent message by confirmed puppet user {sender.mxid}")
             self.log.debug(f"Ignoring puppet-sent message by confirmed puppet user {sender.mxid}")
             return
             return
+        elif not sender.is_connected:
+            await self._send_bridge_error("You're not connected to Instagram", confirmed=True)
+            return
         request_id = str(uuid4())
         request_id = str(uuid4())
         self._reqid_dedup.add(request_id)
         self._reqid_dedup.add(request_id)
         if message.msgtype in (MessageType.EMOTE, MessageType.TEXT):
         if message.msgtype in (MessageType.EMOTE, MessageType.TEXT):
@@ -194,7 +205,8 @@ class Portal(DBPortal, BasePortal):
                                                      upload_id=upload_resp.upload_id,
                                                      upload_id=upload_resp.upload_id,
                                                      allow_full_aspect_ratio="1")
                                                      allow_full_aspect_ratio="1")
             else:
             else:
-                await self._send_bridge_error("Non-image files are currently not supported")
+                await self._send_bridge_error("Non-image files are currently not supported",
+                                              confirmed=True)
                 return
                 return
         else:
         else:
             return
             return
@@ -234,6 +246,10 @@ class Portal(DBPortal, BasePortal):
                                       redaction_event_id: EventID) -> None:
                                       redaction_event_id: EventID) -> None:
         if not self.mxid:
         if not self.mxid:
             return
             return
+        elif not sender.is_connected:
+            await self._send_bridge_error("You're not connected to Instagram",
+                                          event_type="redaction", confirmed=True)
+            return
 
 
         # TODO implement
         # TODO implement
         reaction = await DBReaction.get_by_mxid(event_id, self.mxid)
         reaction = await DBReaction.get_by_mxid(event_id, self.mxid)
@@ -268,7 +284,8 @@ class Portal(DBPortal, BasePortal):
     async def _handle_matrix_typing(self, users: Set[UserID], status: TypingStatus) -> None:
     async def _handle_matrix_typing(self, users: Set[UserID], status: TypingStatus) -> None:
         for mxid in users:
         for mxid in users:
             user = await u.User.get_by_mxid(mxid, create=False)
             user = await u.User.get_by_mxid(mxid, create=False)
-            if not user or not await user.is_logged_in() or user.remote_typing_status == status:
+            if ((not user or not await user.is_logged_in() or user.remote_typing_status == status
+                 or not user.is_connected)):
                 continue
                 continue
             user.remote_typing_status = None
             user.remote_typing_status = None
             await user.mqtt.indicate_activity(self.thread_id, status)
             await user.mqtt.indicate_activity(self.thread_id, status)

+ 53 - 10
mautrix_instagram/user.py

@@ -24,7 +24,7 @@ from mauigpapi import AndroidAPI, AndroidState, AndroidMQTT
 from mauigpapi.mqtt import Connect, Disconnect, GraphQLSubscription, SkywalkerSubscription
 from mauigpapi.mqtt import Connect, Disconnect, GraphQLSubscription, SkywalkerSubscription
 from mauigpapi.types import (CurrentUser, MessageSyncEvent, Operation, RealtimeDirectEvent,
 from mauigpapi.types import (CurrentUser, MessageSyncEvent, Operation, RealtimeDirectEvent,
                              ActivityIndicatorData, TypingStatus, ThreadSyncEvent)
                              ActivityIndicatorData, TypingStatus, ThreadSyncEvent)
-from mauigpapi.errors import IGNotLoggedInError
+from mauigpapi.errors import IGNotLoggedInError, MQTTNotLoggedIn, MQTTNotConnected
 from mautrix.bridge import BaseUser
 from mautrix.bridge import BaseUser
 from mautrix.types import UserID, RoomID, EventID, TextMessageEventContent, MessageType
 from mautrix.types import UserID, RoomID, EventID, TextMessageEventContent, MessageType
 from mautrix.appservice import AppService
 from mautrix.appservice import AppService
@@ -64,6 +64,8 @@ class User(DBUser, BaseUser):
     _notice_room_lock: asyncio.Lock
     _notice_room_lock: asyncio.Lock
     _notice_send_lock: asyncio.Lock
     _notice_send_lock: asyncio.Lock
     _is_logged_in: bool
     _is_logged_in: bool
+    _is_connected: bool
+    shutdown: bool
     remote_typing_status: Optional[TypingStatus]
     remote_typing_status: Optional[TypingStatus]
 
 
     def __init__(self, mxid: UserID, igpk: Optional[int] = None,
     def __init__(self, mxid: UserID, igpk: Optional[int] = None,
@@ -81,6 +83,8 @@ class User(DBUser, BaseUser):
         self.dm_update_lock = asyncio.Lock()
         self.dm_update_lock = asyncio.Lock()
         self._metric_value = defaultdict(lambda: False)
         self._metric_value = defaultdict(lambda: False)
         self._is_logged_in = False
         self._is_logged_in = False
+        self._is_connected = False
+        self.shutdown = False
         self._listen_task = None
         self._listen_task = None
         self.command_status = None
         self.command_status = None
         self.remote_typing_status = None
         self.remote_typing_status = None
@@ -108,6 +112,10 @@ class User(DBUser, BaseUser):
     def api_log(self) -> TraceLogger:
     def api_log(self) -> TraceLogger:
         return self.ig_base_log.getChild("http").getChild(self.mxid)
         return self.ig_base_log.getChild("http").getChild(self.mxid)
 
 
+    @property
+    def is_connected(self) -> bool:
+        return bool(self.client) and bool(self.mqtt) and self._is_connected
+
     async def connect(self) -> None:
     async def connect(self) -> None:
         client = AndroidAPI(self.state, log=self.api_log)
         client = AndroidAPI(self.state, log=self.api_log)
 
 
@@ -116,7 +124,8 @@ class User(DBUser, BaseUser):
         except IGNotLoggedInError as e:
         except IGNotLoggedInError as e:
             self.log.warning(f"Failed to connect to Instagram: {e}")
             self.log.warning(f"Failed to connect to Instagram: {e}")
             # TODO show reason?
             # TODO show reason?
-            await self.send_bridge_notice("You have been logged out of Instagram")
+            await self.send_bridge_notice("You have been logged out of Instagram",
+                                          important=True)
             return
             return
         self.client = client
         self.client = client
         self._is_logged_in = True
         self._is_logged_in = True
@@ -141,10 +150,13 @@ class User(DBUser, BaseUser):
     async def on_connect(self, evt: Connect) -> None:
     async def on_connect(self, evt: Connect) -> None:
         self.log.debug("Connected to Instagram")
         self.log.debug("Connected to Instagram")
         self._track_metric(METRIC_CONNECTED, True)
         self._track_metric(METRIC_CONNECTED, True)
+        self._is_connected = True
+        await self.send_bridge_notice("Connected to Instagram")
 
 
     async def on_disconnect(self, evt: Disconnect) -> None:
     async def on_disconnect(self, evt: Disconnect) -> None:
         self.log.debug("Disconnected from Instagram")
         self.log.debug("Disconnected from Instagram")
         self._track_metric(METRIC_CONNECTED, False)
         self._track_metric(METRIC_CONNECTED, False)
+        self._is_connected = False
 
 
     # TODO this stuff could probably be moved to mautrix-python
     # TODO this stuff could probably be moved to mautrix-python
     async def get_notice_room(self) -> RoomID:
     async def get_notice_room(self) -> RoomID:
@@ -162,6 +174,9 @@ class User(DBUser, BaseUser):
 
 
     async def send_bridge_notice(self, text: str, edit: Optional[EventID] = None,
     async def send_bridge_notice(self, text: str, edit: Optional[EventID] = None,
                                  important: bool = False) -> Optional[EventID]:
                                  important: bool = False) -> Optional[EventID]:
+        if not important and not self.config["bridge.unimportant_bridge_notices"]:
+            self.log.debug("Not sending unimportant bridge notice: %s", text)
+            return
         event_id = None
         event_id = None
         try:
         try:
             self.log.debug("Sending bridge notice: %s", text)
             self.log.debug("Sending bridge notice: %s", text)
@@ -218,18 +233,46 @@ class User(DBUser, BaseUser):
                 self.log.debug(f"{thread.thread_id} is not active and doesn't have a portal")
                 self.log.debug(f"{thread.thread_id} is not active and doesn't have a portal")
         await self.update_direct_chats()
         await self.update_direct_chats()
 
 
-        self._listen_task = self.loop.create_task(self.mqtt.listen(
-            graphql_subs={GraphQLSubscription.app_presence(),
-                          GraphQLSubscription.direct_typing(self.state.user_id),
-                          GraphQLSubscription.direct_status()},
-            skywalker_subs={SkywalkerSubscription.direct_sub(self.state.user_id),
-                            SkywalkerSubscription.live_sub(self.state.user_id)},
-            seq_id=resp.seq_id, snapshot_at_ms=resp.snapshot_at_ms))
+        if not self._listen_task:
+            await self.start_listen(resp.seq_id, resp.snapshot_at_ms)
+
+    async def start_listen(self, seq_id: Optional[int] = None, snapshot_at_ms: Optional[int] = None) -> None:
+        if not seq_id:
+            resp = await self.client.get_inbox(limit=1)
+            seq_id, snapshot_at_ms = resp.seq_id, resp.snapshot_at_ms
+        task = self.listen(seq_id=seq_id, snapshot_at_ms=snapshot_at_ms)
+        self._listen_task = self.loop.create_task(task)
 
 
-    async def stop(self) -> None:
+    async def listen(self, seq_id: int, snapshot_at_ms: int) -> None:
+        try:
+            await self.mqtt.listen(
+                graphql_subs={GraphQLSubscription.app_presence(),
+                              GraphQLSubscription.direct_typing(self.state.user_id),
+                              GraphQLSubscription.direct_status()},
+                skywalker_subs={SkywalkerSubscription.direct_sub(self.state.user_id),
+                                SkywalkerSubscription.live_sub(self.state.user_id)},
+                seq_id=seq_id, snapshot_at_ms=snapshot_at_ms)
+        except Exception:
+            self.log.exception("Fatal error in listener")
+            await self.send_bridge_notice("Fatal error in listener (see logs for more info)",
+                                          important=True)
+            self.mqtt.disconnect()
+            self._is_connected = False
+            self._track_metric(METRIC_CONNECTED, False)
+        else:
+            if not self.shutdown:
+                await self.send_bridge_notice("Instagram connection closed without error")
+        finally:
+            self._listen_task = None
+
+    async def stop_listen(self) -> None:
+        self.shutdown = True
         if self.mqtt:
         if self.mqtt:
             self.mqtt.disconnect()
             self.mqtt.disconnect()
+            if self._listen_task:
+                await self._listen_task
         self._track_metric(METRIC_CONNECTED, False)
         self._track_metric(METRIC_CONNECTED, False)
+        self._is_connected = False
         await self.update()
         await self.update()
 
 
     async def logout(self) -> None:
     async def logout(self) -> None: