Browse Source

backfill: run the queue on user connect

Signed-off-by: Sumner Evans <sumner@beeper.com>
Sumner Evans 2 years ago
parent
commit
a26ec4400e
3 changed files with 332 additions and 40 deletions
  1. 18 5
      mauigpapi/http/thread.py
  2. 9 6
      mautrix_instagram/portal.py
  3. 305 29
      mautrix_instagram/user.py

+ 18 - 5
mauigpapi/http/thread.py

@@ -16,15 +16,17 @@
 from __future__ import annotations
 
 from typing import AsyncIterable, Type
+import asyncio
 import json
 
+from mauigpapi.errors.response import IGRateLimitError
+
 from ..types import (
     CommandResponse,
     DMInboxResponse,
     DMThreadResponse,
     Thread,
     ThreadAction,
-    ThreadItem,
     ThreadItemType,
 )
 from .base import BaseAndroidAPI, T
@@ -58,14 +60,16 @@ class ThreadAPI(BaseAndroidAPI):
         self,
         start_at: DMInboxResponse | None = None,
         local_limit: int | None = None,
-    ) -> AsyncIterable[Thread]:
+        rate_limit_exceeded_backoff: float = 60.0,
+    ) -> AsyncIterable[tuple[Thread, int | None, str | None]]:
+        print("ITER INBOX")
         thread_counter = 0
         if start_at:
             cursor = start_at.inbox.oldest_cursor
             seq_id = start_at.seq_id
             has_more = start_at.inbox.has_older
             for thread in start_at.inbox.threads:
-                yield thread
+                yield thread, seq_id, cursor
                 thread_counter += 1
                 if local_limit and thread_counter >= local_limit:
                     return
@@ -74,12 +78,21 @@ class ThreadAPI(BaseAndroidAPI):
             seq_id = None
             has_more = True
         while has_more:
-            resp = await self.get_inbox(message_limit=10, cursor=cursor, seq_id=seq_id)
+            try:
+                resp = await self.get_inbox(message_limit=10, cursor=cursor, seq_id=seq_id)
+            except IGRateLimitError:
+                self.log.warning(
+                    "Fetching more threads failed due to rate limit. Waiting for "
+                    f"{rate_limit_exceeded_backoff} seconds before resuming."
+                )
+                await asyncio.sleep(rate_limit_exceeded_backoff)
+                continue
+
             seq_id = resp.seq_id
             cursor = resp.inbox.oldest_cursor
             has_more = resp.inbox.has_older
             for thread in resp.inbox.threads:
-                yield thread
+                yield thread, seq_id, cursor
                 thread_counter += 1
                 if local_limit and thread_counter >= local_limit:
                     return

+ 9 - 6
mautrix_instagram/portal.py

@@ -1435,7 +1435,8 @@ class Portal(DBPortal, BasePortal):
                 return
 
             if self.config["bridge.backfill.enable"]:
-                await self.enqueue_immediate_backfill(source, 0)
+                if self.config["bridge.backfill.msc2716"]:
+                    await self.enqueue_immediate_backfill(source, 0)
 
         intent = sender.intent_for(self)
         asyncio.create_task(intent.set_typing(self.mxid, is_typing=False))
@@ -1771,8 +1772,8 @@ class Portal(DBPortal, BasePortal):
             await Backfill.new(
                 source.mxid,
                 priority,
-                self.fbid,
-                self.fb_receiver,
+                self.thread_id,
+                self.receiver,
                 self.config["bridge.backfill.incremental.max_pages"],
                 self.config["bridge.backfill.incremental.page_delay"],
                 self.config["bridge.backfill.incremental.post_batch_delay"],
@@ -1835,7 +1836,9 @@ class Portal(DBPortal, BasePortal):
 
         pages_backfilled = 0
         for i in range(pages_to_backfill):
-            base_insertion_event_id = await self.backfill_message_page(source, messages)
+            base_insertion_event_id = await self.backfill_message_page(
+                source, list(reversed(messages))
+            )
             self.cursor = cursor
             await self.save()
             pages_backfilled += 1
@@ -1881,8 +1884,8 @@ class Portal(DBPortal, BasePortal):
                 source.mxid,
                 # Always enqueue subsequent backfills at the lowest priority
                 2,
-                self.fbid,
-                self.fb_receiver,
+                self.thread_id,
+                self.receiver,
                 backfill_request.num_pages,
                 backfill_request.page_delay,
                 backfill_request.post_batch_delay,

+ 305 - 29
mautrix_instagram/user.py

@@ -15,7 +15,9 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
+from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, Callable, cast
+from datetime import datetime, timedelta
+from functools import partial
 import asyncio
 import logging
 import time
@@ -53,16 +55,18 @@ from mauigpapi.types import (
     ThreadSyncEvent,
     TypingStatus,
 )
+from mauigpapi.types.direct_inbox import DMInbox, DMInboxResponse
 from mautrix.appservice import AppService
 from mautrix.bridge import BaseUser, async_getter_lock
 from mautrix.types import EventID, MessageType, RoomID, TextMessageEventContent, UserID
 from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
 from mautrix.util.logging import TraceLogger
 from mautrix.util.opt_prometheus import Gauge, Summary, async_time
+from mautrix.util.simple_lock import SimpleLock
 
 from . import portal as po, puppet as pu
 from .config import Config
-from .db import Portal as DBPortal, User as DBUser
+from .db import Backfill, Message as DBMessage, Portal as DBPortal, User as DBUser
 
 if TYPE_CHECKING:
     from .__main__ import InstagramBridge
@@ -111,6 +115,9 @@ class User(DBUser, BaseUser):
     client: AndroidAPI | None
     mqtt: AndroidMQTT | None
     _listen_task: asyncio.Task | None = None
+    _sync_lock: SimpleLock
+    _backfill_loop_task: asyncio.Task | None
+    _thread_sync_task: asyncio.Task | None
     _seq_id_save_task: asyncio.Task | None
 
     permission_level: str
@@ -131,6 +138,9 @@ class User(DBUser, BaseUser):
         notice_room: RoomID | None = None,
         seq_id: int | None = None,
         snapshot_at_ms: int | None = None,
+        oldest_cursor: str | None = None,
+        total_backfilled_portals: int | None = None,
+        thread_sync_completed: bool = False,
     ) -> None:
         super().__init__(
             mxid=mxid,
@@ -139,6 +149,9 @@ class User(DBUser, BaseUser):
             notice_room=notice_room,
             seq_id=seq_id,
             snapshot_at_ms=snapshot_at_ms,
+            oldest_cursor=oldest_cursor,
+            total_backfilled_portals=total_backfilled_portals,
+            thread_sync_completed=thread_sync_completed,
         )
         BaseUser.__init__(self)
         self._notice_room_lock = asyncio.Lock()
@@ -152,7 +165,12 @@ class User(DBUser, BaseUser):
         self._is_connected = False
         self._is_refreshing = False
         self.shutdown = False
+        self._sync_lock = SimpleLock(
+            "Waiting for thread sync to finish before handling %s", log=self.log
+        )
         self._listen_task = None
+        self._thread_sync_task = None
+        self._backfill_loop_task = None
         self.remote_typing_status = None
         self._seq_id_save_task = None
 
@@ -258,12 +276,46 @@ class User(DBUser, BaseUser):
         await self.update()
 
         self.loop.create_task(self._try_sync_puppet(user))
-        if not self.seq_id or self.config["bridge.max_startup_thread_sync_count"]:
-            self.loop.create_task(self._try_sync())
+
+        # Backfill requests are handled one at a time by a single loop task so as not to overload
+        # the homeserver. Users can configure their backfill stages to be more or less aggressive
+        # with backfilling to try and avoid getting banned.
+        if not self._backfill_loop_task or self._backfill_loop_task.done():
+            self._backfill_loop_task = asyncio.create_task(self._handle_backfill_requests_loop())
+
+        if not self.seq_id:
+            await self._try_sync()
         else:
             self.log.debug("Connecting to MQTT directly as resync_on_startup is false")
             self.start_listen()
 
+        if self.config["bridge.backfill.enable"]:
+            if self._thread_sync_task and not self._thread_sync_task.done():
+                self.log.warning("Cancelling existing background thread sync task")
+                self._thread_sync_task.cancel()
+            self._thread_sync_task = asyncio.create_task(self.backfill_threads())
+
+    async def _handle_backfill_requests_loop(self) -> None:
+        while True:
+            await self._sync_lock.wait("backfill request")
+            req = await Backfill.get_next(self.mxid)
+            if not req:
+                await asyncio.sleep(30)
+                continue
+            self.log.info("Backfill request %s", req)
+            try:
+                portal = await po.Portal.get_by_thread_id(
+                    req.portal_thread_id, receiver=req.portal_receiver
+                )
+                await req.mark_dispatched()
+                await portal.backfill(self, req)
+                await req.mark_done()
+            except Exception as e:
+                self.log.exception("Failed to backfill portal %s: %s", req.portal_thread_id, e)
+
+                # Don't try again to backfill this portal for a minute.
+                await req.set_cooldown_timeout(60)
+
     async def on_connect(self, evt: Connect) -> None:
         self.log.debug("Connected to Instagram")
         self._track_metric(METRIC_CONNECTED, True)
@@ -464,16 +516,96 @@ class User(DBUser, BaseUser):
             info = {"challenge": e.body.challenge.serialize() if e.body.challenge else None}
         await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error=error_code, info=info)
 
-    async def _sync_thread(self, thread: Thread, allow_create: bool) -> None:
+    async def _sync_thread(self, thread: Thread) -> bool:
+        """
+        Sync a specific thread. Returns whether the thread contained any messages newer than the
+        last message that was stored in the database before the sync started.
+        """
+        self.log.debug(f"Syncing thread {thread.thread_id}")
+
+        forward_messages = thread.items
+
+        assert self.client
         portal = await po.Portal.get_by_thread(thread, self.igpk)
-        if portal.mxid:
-            self.log.debug(f"{thread.thread_id} has a portal, syncing and backfilling...")
-            await portal.update_matrix_room(self, thread)
-        elif allow_create:
-            self.log.debug(f"{thread.thread_id} has been active recently, creating portal...")
+        assert portal
+
+        # Create or update the Matrix room
+        if not portal.mxid:
             await portal.create_matrix_room(self, thread)
         else:
-            self.log.debug(f"{thread.thread_id} is not active and doesn't have a portal")
+            await portal.update_matrix_room(self, thread)
+
+        last_message = await DBMessage.get_last(portal.mxid)
+        if last_message:
+            original_number_of_messages = len(thread.items)
+            new_messages = [
+                m for m in thread.items if last_message.ig_timestamp_ms < m.timestamp_ms
+            ]
+            forward_messages = new_messages
+
+            portal.log.debug(
+                f"{len(new_messages)}/{original_number_of_messages} messages are after most recent"
+                " message."
+            )
+
+            # Fetch more messages until we get back to messages that have been bridged already.
+            cursor = thread.prev_cursor
+            while len(new_messages) > 0 and len(new_messages) == original_number_of_messages:
+                await asyncio.sleep(self.config["bridge.backfill.incremental.page_delay"])
+
+                portal.log.debug("Fetching more messages for forward backfill")
+                resp = await self.client.get_thread(portal.thread_id, cursor=cursor)
+                if len(resp.thread.items) == 0:
+                    break
+                original_number_of_messages = len(resp.thread.items)
+                new_messages = [
+                    m for m in resp.thread.items if last_message.ig_timestamp_ms < m.timestamp_ms
+                ]
+                forward_messages = new_messages + forward_messages
+                cursor = resp.thread.prev_cursor
+                portal.log.debug(
+                    f"{len(new_messages)}/{original_number_of_messages} messages are after most "
+                    "recent message."
+                )
+            portal.cursor = cursor
+            await portal.update()
+        elif not portal.first_event_id:
+            self.log.debug(
+                f"Skipping backfilling {portal.thread_id} as the first event ID is not known"
+            )
+            return False
+
+        if forward_messages:
+            mark_read = thread.read_state == 0 or (
+                (hours := self.config["bridge.backfill.unread_hours_threshold"]) > 0
+                and (
+                    datetime.fromtimestamp(forward_messages[0].timestamp_ms / 1000)
+                    < datetime.now() - timedelta(hours=hours)
+                )
+            )
+            base_insertion_event_id = await portal.backfill_message_page(
+                self,
+                list(reversed(forward_messages)),
+                forward=True,
+                last_message=last_message,
+                mark_read=mark_read,
+            )
+            if not self.bridge.homeserver_software.is_hungry:
+                await portal.send_post_backfill_dummy(
+                    forward_messages[0].timestamp, base_insertion_event_id=base_insertion_event_id
+                )
+            if (
+                mark_read
+                and not self.bridge.homeserver_software.is_hungry
+                and (puppet := await self.get_puppet())
+            ):
+                last_message = await DBMessage.get_last(portal.mxid)
+                if last_message:
+                    await puppet.intent_for(portal).mark_read(portal.mxid, last_message.mxid)
+
+        if self.config["bridge.backfill.msc2716"]:
+            await portal.enqueue_immediate_backfill(self, 1)
+        return len(forward_messages) > 0
 
     async def _maybe_update_proxy(self, source: str) -> None:
         if not self._listen_task:
@@ -482,7 +614,10 @@ class User(DBUser, BaseUser):
         else:
             self.log.debug(f"Not updating proxy: listen_task is still running? (caller: {source})")
 
-    async def sync(self) -> None:
+    async def sync(self, increment_total_backfilled_portals: bool = False) -> None:
+        await self.run_with_sync_lock(partial(self._sync, increment_total_backfilled_portals))
+
+    async def _sync(self, increment_total_backfilled_portals: bool = False) -> None:
         sleep_minutes = 2
         errors = 0
         while True:
@@ -533,28 +668,158 @@ class User(DBUser, BaseUser):
         if not self._listen_task:
             self.start_listen(is_after_sync=True)
 
-        max_age = self.config["bridge.portal_create_max_age"] * 1_000_000
-        limit = self.config["bridge.backfill.max_conversations"]  # TODO
-        create_limit = self.config["bridge.backfill.max_conversations"]  # TODO
-        min_active_at = (time.time() * 1_000_000) - max_age
-        i = 0
-        await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
-        async for thread in self.client.iter_inbox(start_at=resp):
-            try:
-                await self._sync_thread(
-                    thread=thread,
-                    allow_create=thread.last_activity_at > min_active_at and i < create_limit,
-                )
-            except Exception:
-                self.log.exception(f"Error syncing thread {thread.thread_id}")
-            i += 1
-            if i >= limit:
-                break
+        sync_count = min(
+            self.config["bridge.backfill.max_conversations"],
+            self.config["bridge.max_startup_thread_sync_count"],
+        )
+        self.log.debug(f"Fetching {sync_count} threads, 20 at a time...")
+
+        local_limit: int | None = sync_count
+        if sync_count == 0:
+            return
+        elif sync_count < 0:
+            local_limit = None
+
+        await self._sync_threads_with_delay(
+            self.client.iter_inbox(start_at=resp, local_limit=local_limit),
+            stop_when_threads_have_no_messages_to_backfill=True,
+            increment_total_backfilled_portals=increment_total_backfilled_portals,
+            local_limit=local_limit,
+        )
+
         try:
             await self.update_direct_chats()
         except Exception:
             self.log.exception("Error updating direct chat list")
 
+    async def backfill_threads(self):
+        try:
+            await self.run_with_sync_lock(self._backfill_threads)
+        except Exception:
+            self.log.exception("Error in thread backfill loop")
+
+    async def _backfill_threads(self):
+        assert self.client
+        if not self.config["bridge.backfill.enable"]:
+            return
+
+        max_conversations = self.config["bridge.backfill.max_conversations"] or 0
+        if 0 <= max_conversations <= (self.total_backfilled_portals or 0):
+            self.log.info("Backfill max_conversations count reached, not syncing any more portals")
+            return
+        elif self.thread_sync_completed:
+            self.log.debug("Thread backfill is marked as completed, not syncing more portals")
+            return
+        local_limit = (
+            max_conversations - (self.total_backfilled_portals or 0)
+            if max_conversations >= 0
+            else None
+        )
+
+        start_at = None
+        if self.oldest_cursor:
+            start_at = DMInboxResponse(
+                status="",
+                seq_id=self.seq_id,
+                snapshot_at_ms=0,
+                pending_requests_total=0,
+                has_pending_top_requests=False,
+                viewer=None,
+                inbox=DMInbox(
+                    threads=[],
+                    has_older=True,
+                    unseen_count=0,
+                    unseen_count_ts=0,
+                    blended_inbox_enabled=False,
+                    oldest_cursor=self.oldest_cursor,
+                ),
+            )
+        backoff = self.config.get("bridge.backfill.backoff.thread_list", 300)
+        await self._sync_threads_with_delay(
+            self.client.iter_inbox(
+                start_at,
+                local_limit=local_limit,
+                rate_limit_exceeded_backoff=backoff,
+            ),
+            increment_total_backfilled_portals=True,
+            local_limit=local_limit,
+        )
+        await self.update_direct_chats()
+
+    async def _sync_threads_with_delay(
+        self,
+        threads: AsyncIterable[tuple[Thread, int | None, str | None]],
+        increment_total_backfilled_portals: bool = False,
+        stop_when_threads_have_no_messages_to_backfill: bool = False,
+        local_limit: int | None = None,
+    ):
+        sync_delay = self.config["bridge.backfill.min_sync_thread_delay"]
+        last_thread_sync_ts = 0.0
+        found_thread_count = 0
+        async for thread, seq_id, cursor in threads:
+            found_thread_count += 1
+            now = time.monotonic()
+            if last_thread_sync_ts is not None and now < last_thread_sync_ts + sync_delay:
+                delay = last_thread_sync_ts + sync_delay - now
+                self.log.debug("Thread sync is happening too quickly. Waiting for %ds", delay)
+                await asyncio.sleep(delay)
+
+            last_thread_sync_ts = now
+            had_new_messages = await self._sync_thread(thread)
+            if not had_new_messages and stop_when_threads_have_no_messages_to_backfill:
+                self.log.debug("Got to threads with no new messages. Stopping sync.")
+                return
+
+            if increment_total_backfilled_portals:
+                self.total_backfilled_portals = (self.total_backfilled_portals or 0) + 1
+            if seq_id:
+                self.seq_id = seq_id
+            if cursor:
+                self.oldest_cursor = cursor
+            await self.update()
+        if local_limit is None or found_thread_count < local_limit:
+            if local_limit is None:
+                self.log.info(
+                    "Reached end of thread list with no limit, marking thread sync as completed"
+                )
+            else:
+                self.log.info(
+                    f"Reached end of thread list (got {found_thread_count} with "
+                    f"limit {local_limit}), marking thread sync as completed"
+                )
+            self.thread_sync_completed = True
+        await self.update()
+
+    async def run_with_sync_lock(self, func: Callable[[], Awaitable]):
+        with self._sync_lock:
+            retry_count = 0
+            while retry_count < 5:
+                try:
+                    retry_count += 1
+                    await func()
+
+                    # The sync was successful. Exit the loop.
+                    return
+                except IGNotLoggedInError as e:
+                    await self.send_bridge_notice(
+                        f"You have been logged out of Instagram: {e!s}",
+                        important=True,
+                        state_event=BridgeStateEvent.BAD_CREDENTIALS,
+                        error_code="ig-auth-error",
+                        error_message=str(e),
+                    )
+                    await self.logout(error=e)
+                    return
+                except Exception:
+                    self.log.exception(
+                        "Failed to sync threads. Waiting 30 seconds before retrying sync."
+                    )
+                    await asyncio.sleep(30)
+
+            # If we get here, it means that the sync has failed five times. If this happens, most
+            # likely something very bad has happened.
+            self.log.error("Failed to sync threads five times. Will not retry.")
+
     def start_listen(self, is_after_sync: bool = False) -> None:
         self.shutdown = False
         task = self._listen(
@@ -678,7 +943,17 @@ class User(DBUser, BaseUser):
         self._is_connected = False
         await self.update()
 
+    def stop_backfill_tasks(self) -> None:
+        if self._backfill_loop_task:
+            self._backfill_loop_task.cancel()
+            self._backfill_loop_task = None
+        if self._thread_sync_task:
+            self._thread_sync_task.cancel()
+            self._thread_sync_task = None
+
     async def logout(self, error: IGNotLoggedInError | None = None) -> None:
+        await self.stop_listen()
+        self.stop_backfill_tasks()
         if self.client and error is None:
             try:
                 await self.client.logout(one_tap_app_login=False)
@@ -712,6 +987,7 @@ class User(DBUser, BaseUser):
         self.state = None
         self.seq_id = None
         self.snapshot_at_ms = None
+        self.thread_sync_completed = False
         self._is_logged_in = False
         await self.update()