Przeglądaj źródła

Make resyncing on startup optional

Tulir Asokan 3 lat temu
rodzic
commit
3a756b42db

+ 1 - 1
mauigpapi/mqtt/__init__.py

@@ -1,3 +1,3 @@
 from .conn import AndroidMQTT
-from .events import Connect, Disconnect
+from .events import Connect, Disconnect, NewSequenceID
 from .subscription import GraphQLSubscription, SkywalkerSubscription

+ 12 - 1
mauigpapi/mqtt/conn.py

@@ -53,7 +53,7 @@ from ..types import (
     ThreadSyncEvent,
     TypingStatus,
 )
-from .events import Connect, Disconnect
+from .events import Connect, Disconnect, NewSequenceID
 from .otclient import MQTToTClient
 from .subscription import GraphQLQueryID, RealtimeTopic, everclear_subscriptions
 from .thrift import ForegroundStateConfig, IncomingMessage, RealtimeClientInfo, RealtimeConfig
@@ -329,6 +329,9 @@ class AndroidMQTT:
                 self.log.trace(f"Got new seq_id: {parsed_item.seq_id}")
                 self._iris_seq_id = parsed_item.seq_id
                 self._iris_snapshot_at_ms = int(time.time() * 1000)
+                asyncio.create_task(
+                    self._dispatch(NewSequenceID(self._iris_seq_id, self._iris_snapshot_at_ms))
+                )
             for part in parsed_item.data:
                 self._on_messager_sync_item(part, parsed_item)
 
@@ -564,6 +567,14 @@ class AndroidMQTT:
         resp_dict = json.loads(resp.payload.decode("utf-8"))
         if resp_dict["error_type"] and resp_dict["error_message"]:
             raise IrisSubscribeError(resp_dict["error_type"], resp_dict["error_message"])
+        latest_seq_id = resp_dict.get("latest_seq_id")
+        if latest_seq_id > self._iris_seq_id:
+            self.log.info(f"Latest sequence ID is {latest_seq_id}, catching up from {seq_id}")
+            self._iris_seq_id = latest_seq_id
+            self._iris_snapshot_at_ms = resp_dict.get("subscribed_at_ms", int(time.time() * 1000))
+            asyncio.create_task(
+                self._dispatch(NewSequenceID(self._iris_seq_id, self._iris_snapshot_at_ms))
+            )
 
     def graphql_subscribe(self, subs: set[str]) -> asyncio.Future:
         self._graphql_subs |= subs

+ 6 - 0
mauigpapi/mqtt/events.py

@@ -24,3 +24,9 @@ class Connect:
 @dataclass
 class Disconnect:
     reason: str
+
+
+@dataclass
+class NewSequenceID:
+    seq_id: int
+    snapshot_at_ms: int

+ 2 - 0
mautrix_instagram/config.py

@@ -63,6 +63,8 @@ class Config(BaseBridgeConfig):
 
         copy("bridge.portal_create_max_age")
         copy("bridge.chat_sync_limit")
+        copy("bridge.chat_create_limit")
+        copy("bridge.resync_on_startup")
         copy("bridge.sync_with_custom_puppets")
         copy("bridge.sync_direct_chat_list")
         copy("bridge.double_puppet_server_map")

+ 6 - 0
mautrix_instagram/db/upgrade.py

@@ -125,3 +125,9 @@ async def upgrade_v6(conn: Connection) -> None:
 @upgrade_table.register(description="Store reaction timestamps")
 async def upgrade_v7(conn: Connection) -> None:
     await conn.execute("ALTER TABLE reaction ADD COLUMN mx_timestamp BIGINT")
+
+
+@upgrade_table.register(description="Store sync sequence ID in user table")
+async def upgrade_v8(conn: Connection) -> None:
+    await conn.execute('ALTER TABLE "user" ADD COLUMN seq_id BIGINT')
+    await conn.execute('ALTER TABLE "user" ADD COLUMN snapshot_at_ms BIGINT')

+ 32 - 11
mautrix_instagram/db/user.py

@@ -35,18 +35,37 @@ class User:
     igpk: int | None
     state: AndroidState | None
     notice_room: RoomID | None
+    seq_id: int | None
+    snapshot_at_ms: int | None
 
-    async def insert(self) -> None:
-        q = 'INSERT INTO "user" (mxid, igpk, state, notice_room) VALUES ($1, $2, $3, $4)'
-        await self.db.execute(
-            q, self.mxid, self.igpk, self.state.json() if self.state else None, self.notice_room
+    @property
+    def _values(self):
+        return (
+            self.mxid,
+            self.igpk,
+            self.state.json() if self.state else None,
+            self.notice_room,
+            self.seq_id,
+            self.snapshot_at_ms,
         )
 
+    async def insert(self) -> None:
+        q = """
+        INSERT INTO "user" (mxid, igpk, state, notice_room, seq_id, snapshot_at_ms)
+        VALUES ($1, $2, $3, $4, $5, $6)
+        """
+        await self.db.execute(q, *self._values)
+
     async def update(self) -> None:
-        q = 'UPDATE "user" SET igpk=$2, state=$3, notice_room=$4 WHERE mxid=$1'
-        await self.db.execute(
-            q, self.mxid, self.igpk, self.state.json() if self.state else None, self.notice_room
-        )
+        q = """
+        UPDATE "user" SET igpk=$2, state=$3, notice_room=$4, seq_id=$5, snapshot_at_ms=$6
+        WHERE mxid=$1
+        """
+        await self.db.execute(q, *self._values)
+
+    async def save_seq_id(self) -> None:
+        q = 'UPDATE "user" SET seq_id=$2, snapshot_at_ms=$3 WHERE mxid=$1'
+        await self.db.execute(q, self.mxid, self.seq_id, self.snapshot_at_ms)
 
     @classmethod
     def _from_row(cls, row: asyncpg.Record) -> User:
@@ -54,9 +73,11 @@ class User:
         state_str = data.pop("state")
         return cls(state=AndroidState.parse_json(state_str) if state_str else None, **data)
 
+    _columns = "mxid, igpk, state, notice_room, seq_id, snapshot_at_ms"
+
     @classmethod
     async def get_by_mxid(cls, mxid: UserID) -> User | None:
-        q = 'SELECT mxid, igpk, state, notice_room FROM "user" WHERE mxid=$1'
+        q = f'SELECT {cls._columns} FROM "user" WHERE mxid=$1'
         row = await cls.db.fetchrow(q, mxid)
         if not row:
             return None
@@ -64,7 +85,7 @@ class User:
 
     @classmethod
     async def get_by_igpk(cls, igpk: int) -> User | None:
-        q = 'SELECT mxid, igpk, state, notice_room FROM "user" WHERE igpk=$1'
+        q = f'SELECT {cls._columns} FROM "user" WHERE igpk=$1'
         row = await cls.db.fetchrow(q, igpk)
         if not row:
             return None
@@ -72,6 +93,6 @@ class User:
 
     @classmethod
     async def all_logged_in(cls) -> list[User]:
-        q = 'SELECT mxid, igpk, state, notice_room FROM "user" WHERE igpk IS NOT NULL'
+        q = f'SELECT {cls._columns} FROM "user" WHERE igpk IS NOT NULL'
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]

+ 7 - 2
mautrix_instagram/example-config.yaml

@@ -102,9 +102,14 @@ bridge:
     displayname_max_length: 100
 
     # Maximum number of seconds since the last activity in a chat to automatically create portals.
-    portal_create_max_age: 86400
+    portal_create_max_age: 259200
     # Maximum number of chats to fetch for startup sync
-    chat_sync_limit: 100
+    chat_sync_limit: 20
+    # Maximum number of chats to create during startup sync
+    chat_create_limit: 10
+    # Should the chat list be synced on startup?
+    # If false, the bridge will try to reconnect to MQTT directly and ask the server to send missed events.
+    resync_on_startup: true
     # Whether or not to use /sync to get read receipts and typing notifications
     # when double puppeting is enabled
     sync_with_custom_puppets: true

+ 61 - 16
mautrix_instagram/user.py

@@ -32,7 +32,13 @@ from mauigpapi.errors import (
     MQTTNotConnected,
     MQTTNotLoggedIn,
 )
-from mauigpapi.mqtt import Connect, Disconnect, GraphQLSubscription, SkywalkerSubscription
+from mauigpapi.mqtt import (
+    Connect,
+    Disconnect,
+    GraphQLSubscription,
+    NewSequenceID,
+    SkywalkerSubscription,
+)
 from mauigpapi.types import (
     ActivityIndicatorData,
     CurrentUser,
@@ -90,6 +96,7 @@ class User(DBUser, BaseUser):
     client: AndroidAPI | None
     mqtt: AndroidMQTT | None
     _listen_task: asyncio.Task | None = None
+    _seq_id_save_task: asyncio.Task | None
 
     permission_level: str
     username: str | None
@@ -107,8 +114,17 @@ class User(DBUser, BaseUser):
         igpk: int | None = None,
         state: AndroidState | None = None,
         notice_room: RoomID | None = None,
+        seq_id: int | None = None,
+        snapshot_at_ms: int | None = None,
     ) -> None:
-        super().__init__(mxid=mxid, igpk=igpk, state=state, notice_room=notice_room)
+        super().__init__(
+            mxid=mxid,
+            igpk=igpk,
+            state=state,
+            notice_room=notice_room,
+            seq_id=seq_id,
+            snapshot_at_ms=snapshot_at_ms,
+        )
         BaseUser.__init__(self)
         self._notice_room_lock = asyncio.Lock()
         self._notice_send_lock = asyncio.Lock()
@@ -123,6 +139,7 @@ class User(DBUser, BaseUser):
         self.shutdown = False
         self._listen_task = None
         self.remote_typing_status = None
+        self._seq_id_save_task = None
 
     @classmethod
     def init_cls(cls, bridge: "InstagramBridge") -> AsyncIterable[Awaitable[None]]:
@@ -203,6 +220,7 @@ class User(DBUser, BaseUser):
         )
         self.mqtt.add_event_handler(Connect, self.on_connect)
         self.mqtt.add_event_handler(Disconnect, self.on_disconnect)
+        self.mqtt.add_event_handler(NewSequenceID, self.update_seq_id)
         self.mqtt.add_event_handler(MessageSyncEvent, self.handle_message)
         self.mqtt.add_event_handler(ThreadSyncEvent, self.handle_thread_sync)
         self.mqtt.add_event_handler(RealtimeDirectEvent, self.handle_rtd)
@@ -210,7 +228,11 @@ class User(DBUser, BaseUser):
         await self.update()
 
         self.loop.create_task(self._try_sync_puppet(user))
-        self.loop.create_task(self._try_sync())
+        if not self.seq_id or self.config["bridge.resync_on_startup"]:
+            self.loop.create_task(self._try_sync())
+        else:
+            self.log.debug("Connecting to MQTT directly as resync_on_startup is false")
+            self.start_listen()
 
     async def on_connect(self, evt: Connect) -> None:
         self.log.debug("Connected to Instagram")
@@ -364,7 +386,7 @@ class User(DBUser, BaseUser):
                         )
                         await asyncio.sleep(minutes * 60)
             else:
-                await self.start_listen()
+                self.start_listen()
         finally:
             self._is_refreshing = False
 
@@ -407,12 +429,12 @@ class User(DBUser, BaseUser):
             info = {"challenge": e.body.challenge.serialize() if e.body.challenge else None}
         await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error=error_code, info=info)
 
-    async def _sync_thread(self, thread: Thread, min_active_at: int) -> None:
+    async def _sync_thread(self, thread: Thread, allow_create: bool) -> None:
         portal = await po.Portal.get_by_thread(thread, self.igpk)
         if portal.mxid:
             self.log.debug(f"{thread.thread_id} has a portal, syncing and backfilling...")
             await portal.update_matrix_room(self, thread, backfill=True)
-        elif thread.last_activity_at > min_active_at:
+        elif allow_create:
             self.log.debug(f"{thread.thread_id} has been active recently, creating portal...")
             await portal.create_matrix_room(self, thread)
         else:
@@ -443,17 +465,25 @@ class User(DBUser, BaseUser):
                 await self._handle_checkpoint(e, on="sync")
                 return
 
+        self.seq_id = resp.seq_id
+        self.snapshot_at_ms = resp.snapshot_at_ms
+        await self.save_seq_id()
+
         if not self._listen_task:
-            await self.start_listen(resp.seq_id, resp.snapshot_at_ms)
+            self.start_listen()
 
         max_age = self.config["bridge.portal_create_max_age"] * 1_000_000
         limit = self.config["bridge.chat_sync_limit"]
+        create_limit = self.config["bridge.chat_create_limit"]
         min_active_at = (time.time() * 1_000_000) - max_age
         i = 0
         await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
         async for thread in self.client.iter_inbox(start_at=resp):
             try:
-                await self._sync_thread(thread, min_active_at)
+                await self._sync_thread(
+                    thread=thread,
+                    allow_create=thread.last_activity_at > min_active_at and i < create_limit,
+                )
             except Exception:
                 self.log.exception(f"Error syncing thread {thread.thread_id}")
             i += 1
@@ -464,17 +494,12 @@ class User(DBUser, BaseUser):
         except Exception:
             self.log.exception("Error updating direct chat list")
 
-    async def start_listen(
-        self, seq_id: int | None = None, snapshot_at_ms: int | None = None
-    ) -> None:
+    def start_listen(self) -> None:
         self.shutdown = False
-        if not seq_id:
-            resp = await self.client.get_inbox(limit=1)
-            seq_id, snapshot_at_ms = resp.seq_id, resp.snapshot_at_ms
-        task = self.listen(seq_id=seq_id, snapshot_at_ms=snapshot_at_ms)
+        task = self._listen(seq_id=self.seq_id, snapshot_at_ms=self.snapshot_at_ms)
         self._listen_task = self.loop.create_task(task)
 
-    async def listen(self, seq_id: int, snapshot_at_ms: int) -> None:
+    async def _listen(self, seq_id: int, snapshot_at_ms: int) -> None:
         try:
             await self.mqtt.listen(
                 graphql_subs={
@@ -565,12 +590,32 @@ class User(DBUser, BaseUser):
         self.client = None
         self.mqtt = None
         self.state = None
+        self.seq_id = None
+        self.snapshot_at_ms = None
         self._is_logged_in = False
         await self.update()
 
     # endregion
     # region Event handlers
 
+    async def _save_seq_id_after_sleep(self) -> None:
+        await asyncio.sleep(120)
+        self._seq_id_save_task = None
+        self.log.trace("Saving sequence ID %d/%d", self.seq_id, self.snapshot_at_ms)
+        try:
+            await self.save_seq_id()
+        except Exception:
+            self.log.exception("Error saving sequence ID")
+
+    async def update_seq_id(self, evt: NewSequenceID) -> None:
+        self.seq_id = evt.seq_id
+        self.snapshot_at_ms = evt.snapshot_at_ms
+        if not self._seq_id_save_task or self._seq_id_save_task.done():
+            self.log.trace("Starting seq id save task (%d/%d)", evt.seq_id, evt.snapshot_at_ms)
+            self._seq_id_save_task = asyncio.create_task(self._save_seq_id_after_sleep())
+        else:
+            self.log.trace("Not starting seq id save task (%d/%d)", evt.seq_id, evt.snapshot_at_ms)
+
     @async_time(METRIC_MESSAGE)
     async def handle_message(self, evt: MessageSyncEvent) -> None:
         portal = await po.Portal.get_by_thread_id(evt.message.thread_id, receiver=self.igpk)