|
@@ -32,7 +32,13 @@ from mauigpapi.errors import (
|
|
|
MQTTNotConnected,
|
|
|
MQTTNotLoggedIn,
|
|
|
)
|
|
|
-from mauigpapi.mqtt import Connect, Disconnect, GraphQLSubscription, SkywalkerSubscription
|
|
|
+from mauigpapi.mqtt import (
|
|
|
+ Connect,
|
|
|
+ Disconnect,
|
|
|
+ GraphQLSubscription,
|
|
|
+ NewSequenceID,
|
|
|
+ SkywalkerSubscription,
|
|
|
+)
|
|
|
from mauigpapi.types import (
|
|
|
ActivityIndicatorData,
|
|
|
CurrentUser,
|
|
@@ -90,6 +96,7 @@ class User(DBUser, BaseUser):
|
|
|
client: AndroidAPI | None
|
|
|
mqtt: AndroidMQTT | None
|
|
|
_listen_task: asyncio.Task | None = None
|
|
|
+ _seq_id_save_task: asyncio.Task | None
|
|
|
|
|
|
permission_level: str
|
|
|
username: str | None
|
|
@@ -107,8 +114,17 @@ class User(DBUser, BaseUser):
|
|
|
igpk: int | None = None,
|
|
|
state: AndroidState | None = None,
|
|
|
notice_room: RoomID | None = None,
|
|
|
+ seq_id: int | None = None,
|
|
|
+ snapshot_at_ms: int | None = None,
|
|
|
) -> None:
|
|
|
- super().__init__(mxid=mxid, igpk=igpk, state=state, notice_room=notice_room)
|
|
|
+ super().__init__(
|
|
|
+ mxid=mxid,
|
|
|
+ igpk=igpk,
|
|
|
+ state=state,
|
|
|
+ notice_room=notice_room,
|
|
|
+ seq_id=seq_id,
|
|
|
+ snapshot_at_ms=snapshot_at_ms,
|
|
|
+ )
|
|
|
BaseUser.__init__(self)
|
|
|
self._notice_room_lock = asyncio.Lock()
|
|
|
self._notice_send_lock = asyncio.Lock()
|
|
@@ -123,6 +139,7 @@ class User(DBUser, BaseUser):
|
|
|
self.shutdown = False
|
|
|
self._listen_task = None
|
|
|
self.remote_typing_status = None
|
|
|
+ self._seq_id_save_task = None
|
|
|
|
|
|
@classmethod
|
|
|
def init_cls(cls, bridge: "InstagramBridge") -> AsyncIterable[Awaitable[None]]:
|
|
@@ -203,6 +220,7 @@ class User(DBUser, BaseUser):
|
|
|
)
|
|
|
self.mqtt.add_event_handler(Connect, self.on_connect)
|
|
|
self.mqtt.add_event_handler(Disconnect, self.on_disconnect)
|
|
|
+ self.mqtt.add_event_handler(NewSequenceID, self.update_seq_id)
|
|
|
self.mqtt.add_event_handler(MessageSyncEvent, self.handle_message)
|
|
|
self.mqtt.add_event_handler(ThreadSyncEvent, self.handle_thread_sync)
|
|
|
self.mqtt.add_event_handler(RealtimeDirectEvent, self.handle_rtd)
|
|
@@ -210,7 +228,11 @@ class User(DBUser, BaseUser):
|
|
|
await self.update()
|
|
|
|
|
|
self.loop.create_task(self._try_sync_puppet(user))
|
|
|
- self.loop.create_task(self._try_sync())
|
|
|
+ if not self.seq_id or self.config["bridge.resync_on_startup"]:
|
|
|
+ self.loop.create_task(self._try_sync())
|
|
|
+ else:
|
|
|
+ self.log.debug("Connecting to MQTT directly as resync_on_startup is false")
|
|
|
+ self.start_listen()
|
|
|
|
|
|
async def on_connect(self, evt: Connect) -> None:
|
|
|
self.log.debug("Connected to Instagram")
|
|
@@ -364,7 +386,7 @@ class User(DBUser, BaseUser):
|
|
|
)
|
|
|
await asyncio.sleep(minutes * 60)
|
|
|
else:
|
|
|
- await self.start_listen()
|
|
|
+ self.start_listen()
|
|
|
finally:
|
|
|
self._is_refreshing = False
|
|
|
|
|
@@ -407,12 +429,12 @@ class User(DBUser, BaseUser):
|
|
|
info = {"challenge": e.body.challenge.serialize() if e.body.challenge else None}
|
|
|
await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error=error_code, info=info)
|
|
|
|
|
|
- async def _sync_thread(self, thread: Thread, min_active_at: int) -> None:
|
|
|
+ async def _sync_thread(self, thread: Thread, allow_create: bool) -> None:
|
|
|
portal = await po.Portal.get_by_thread(thread, self.igpk)
|
|
|
if portal.mxid:
|
|
|
self.log.debug(f"{thread.thread_id} has a portal, syncing and backfilling...")
|
|
|
await portal.update_matrix_room(self, thread, backfill=True)
|
|
|
- elif thread.last_activity_at > min_active_at:
|
|
|
+ elif allow_create:
|
|
|
self.log.debug(f"{thread.thread_id} has been active recently, creating portal...")
|
|
|
await portal.create_matrix_room(self, thread)
|
|
|
else:
|
|
@@ -443,17 +465,25 @@ class User(DBUser, BaseUser):
|
|
|
await self._handle_checkpoint(e, on="sync")
|
|
|
return
|
|
|
|
|
|
+ self.seq_id = resp.seq_id
|
|
|
+ self.snapshot_at_ms = resp.snapshot_at_ms
|
|
|
+ await self.save_seq_id()
|
|
|
+
|
|
|
if not self._listen_task:
|
|
|
- await self.start_listen(resp.seq_id, resp.snapshot_at_ms)
|
|
|
+ self.start_listen()
|
|
|
|
|
|
max_age = self.config["bridge.portal_create_max_age"] * 1_000_000
|
|
|
limit = self.config["bridge.chat_sync_limit"]
|
|
|
+ create_limit = self.config["bridge.chat_create_limit"]
|
|
|
min_active_at = (time.time() * 1_000_000) - max_age
|
|
|
i = 0
|
|
|
await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
|
|
|
async for thread in self.client.iter_inbox(start_at=resp):
|
|
|
try:
|
|
|
- await self._sync_thread(thread, min_active_at)
|
|
|
+ await self._sync_thread(
|
|
|
+ thread=thread,
|
|
|
+ allow_create=thread.last_activity_at > min_active_at and i < create_limit,
|
|
|
+ )
|
|
|
except Exception:
|
|
|
self.log.exception(f"Error syncing thread {thread.thread_id}")
|
|
|
i += 1
|
|
@@ -464,17 +494,12 @@ class User(DBUser, BaseUser):
|
|
|
except Exception:
|
|
|
self.log.exception("Error updating direct chat list")
|
|
|
|
|
|
- async def start_listen(
|
|
|
- self, seq_id: int | None = None, snapshot_at_ms: int | None = None
|
|
|
- ) -> None:
|
|
|
+ def start_listen(self) -> None:
|
|
|
self.shutdown = False
|
|
|
- if not seq_id:
|
|
|
- resp = await self.client.get_inbox(limit=1)
|
|
|
- seq_id, snapshot_at_ms = resp.seq_id, resp.snapshot_at_ms
|
|
|
- task = self.listen(seq_id=seq_id, snapshot_at_ms=snapshot_at_ms)
|
|
|
+ task = self._listen(seq_id=self.seq_id, snapshot_at_ms=self.snapshot_at_ms)
|
|
|
self._listen_task = self.loop.create_task(task)
|
|
|
|
|
|
- async def listen(self, seq_id: int, snapshot_at_ms: int) -> None:
|
|
|
+ async def _listen(self, seq_id: int, snapshot_at_ms: int) -> None:
|
|
|
try:
|
|
|
await self.mqtt.listen(
|
|
|
graphql_subs={
|
|
@@ -565,12 +590,32 @@ class User(DBUser, BaseUser):
|
|
|
self.client = None
|
|
|
self.mqtt = None
|
|
|
self.state = None
|
|
|
+ self.seq_id = None
|
|
|
+ self.snapshot_at_ms = None
|
|
|
self._is_logged_in = False
|
|
|
await self.update()
|
|
|
|
|
|
# endregion
|
|
|
# region Event handlers
|
|
|
|
|
|
+ async def _save_seq_id_after_sleep(self) -> None:
|
|
|
+ await asyncio.sleep(120)
|
|
|
+ self._seq_id_save_task = None
|
|
|
+ self.log.trace("Saving sequence ID %d/%d", self.seq_id, self.snapshot_at_ms)
|
|
|
+ try:
|
|
|
+ await self.save_seq_id()
|
|
|
+ except Exception:
|
|
|
+ self.log.exception("Error saving sequence ID")
|
|
|
+
|
|
|
+ async def update_seq_id(self, evt: NewSequenceID) -> None:
|
|
|
+ self.seq_id = evt.seq_id
|
|
|
+ self.snapshot_at_ms = evt.snapshot_at_ms
|
|
|
+ if not self._seq_id_save_task or self._seq_id_save_task.done():
|
|
|
+ self.log.trace("Starting seq id save task (%d/%d)", evt.seq_id, evt.snapshot_at_ms)
|
|
|
+ self._seq_id_save_task = asyncio.create_task(self._save_seq_id_after_sleep())
|
|
|
+ else:
|
|
|
+ self.log.trace("Not starting seq id save task (%d/%d)", evt.seq_id, evt.snapshot_at_ms)
|
|
|
+
|
|
|
@async_time(METRIC_MESSAGE)
|
|
|
async def handle_message(self, evt: MessageSyncEvent) -> None:
|
|
|
portal = await po.Portal.get_by_thread_id(evt.message.thread_id, receiver=self.igpk)
|