Browse Source

Fix backfilling portals created when receiving new message

Tulir Asokan 2 years ago
parent
commit
fcf532e78e
2 changed files with 15 additions and 18 deletions
  1. 0 9
      mautrix_instagram/portal.py
  2. 15 9
      mautrix_instagram/user.py

+ 0 - 9
mautrix_instagram/portal.py

@@ -1642,15 +1642,6 @@ class Portal(DBPortal, BasePortal):
         self.log.debug(
         self.log.debug(
             f"Handling Instagram message {item.item_id} ({item.client_context}) by {item.user_id}"
             f"Handling Instagram message {item.item_id} ({item.client_context}) by {item.user_id}"
         )
         )
-        if not self.mxid:
-            thread = await source.client.get_thread(item.thread_id)
-            mxid = await self.create_matrix_room(source, thread.thread)
-            if not mxid:
-                # Failed to create
-                return
-
-            if self.config["bridge.backfill.enable"] and self.config["bridge.backfill.msc2716"]:
-                await self.enqueue_immediate_backfill(source, 0)
 
 
         intent = sender.intent_for(self)
         intent = sender.intent_for(self)
         background_task.create(intent.set_typing(self.mxid, timeout=0))
         background_task.create(intent.set_typing(self.mxid, timeout=0))

+ 15 - 9
mautrix_instagram/user.py

@@ -559,7 +559,9 @@ class User(DBUser, BaseUser):
             info = {"challenge": e.body.challenge.serialize() if e.body.challenge else None}
             info = {"challenge": e.body.challenge.serialize() if e.body.challenge else None}
         await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error=error_code, info=info)
         await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error=error_code, info=info)
 
 
-    async def _sync_thread(self, thread: Thread) -> bool:
+    async def _sync_thread(
+        self, thread: Thread, enqueue_backfill: bool = True, portal: po.Portal | None = None
+    ) -> bool:
         """
         """
         Sync a specific thread. Returns whether the thread had messages after the last message in
         Sync a specific thread. Returns whether the thread had messages after the last message in
         the database before the sync.
         the database before the sync.
@@ -569,8 +571,11 @@ class User(DBUser, BaseUser):
         forward_messages = thread.items
         forward_messages = thread.items
 
 
         assert self.client
         assert self.client
-        portal = await po.Portal.get_by_thread(thread, self.igpk)
-        assert portal
+        if not portal:
+            portal = await po.Portal.get_by_thread(thread, self.igpk)
+            assert portal
+        else:
+            assert portal.thread_id == thread.thread_id
 
 
         # Create or update the Matrix room
         # Create or update the Matrix room
         if not portal.mxid:
         if not portal.mxid:
@@ -655,7 +660,7 @@ class User(DBUser, BaseUser):
 
 
             await portal._update_read_receipts(thread.last_seen_at)
             await portal._update_read_receipts(thread.last_seen_at)
 
 
-        if self.config["bridge.backfill.msc2716"]:
+        if self.config["bridge.backfill.msc2716"] and enqueue_backfill:
             await portal.enqueue_immediate_backfill(self, 1)
             await portal.enqueue_immediate_backfill(self, 1)
         return len(forward_messages) > 0
         return len(forward_messages) > 0
 
 
@@ -1077,11 +1082,12 @@ class User(DBUser, BaseUser):
     async def handle_message(self, evt: MessageSyncEvent) -> None:
     async def handle_message(self, evt: MessageSyncEvent) -> None:
         portal = await po.Portal.get_by_thread_id(evt.message.thread_id, receiver=self.igpk)
         portal = await po.Portal.get_by_thread_id(evt.message.thread_id, receiver=self.igpk)
         if not portal or not portal.mxid:
         if not portal or not portal.mxid:
-            self.log.debug("Got message in thread with no portal, getting info...")
+            self.log.debug(
+                "Got message in thread with no portal, getting info and syncing thread..."
+            )
             resp = await self.client.get_thread(evt.message.thread_id)
             resp = await self.client.get_thread(evt.message.thread_id)
             portal = await po.Portal.get_by_thread(resp.thread, self.igpk)
             portal = await po.Portal.get_by_thread(resp.thread, self.igpk)
-            self.log.debug("Got info for unknown portal, creating room")
-            await portal.create_matrix_room(self, resp.thread)
+            await self._sync_thread(resp.thread, enqueue_backfill=False, portal=portal)
             if not portal.mxid:
             if not portal.mxid:
                 self.log.warning(
                 self.log.warning(
                     "Room creation appears to have failed, "
                     "Room creation appears to have failed, "
@@ -1114,18 +1120,18 @@ class User(DBUser, BaseUser):
         portal = await po.Portal.get_by_thread(evt, receiver=self.igpk)
         portal = await po.Portal.get_by_thread(evt, receiver=self.igpk)
         if portal.mxid:
         if portal.mxid:
             self.log.debug("Got thread sync event for %s with existing portal", portal.thread_id)
             self.log.debug("Got thread sync event for %s with existing portal", portal.thread_id)
-            await portal.update_matrix_room(self, evt)
         elif evt.is_group:
         elif evt.is_group:
             self.log.debug(
             self.log.debug(
                 "Got thread sync event for group %s without existing portal, creating room",
                 "Got thread sync event for group %s without existing portal, creating room",
                 portal.thread_id,
                 portal.thread_id,
             )
             )
-            await portal.create_matrix_room(self, evt)
         else:
         else:
             self.log.debug(
             self.log.debug(
                 "Got thread sync event for DM %s without existing portal, ignoring",
                 "Got thread sync event for DM %s without existing portal, ignoring",
                 portal.thread_id,
                 portal.thread_id,
             )
             )
+            return
+        await self._sync_thread(evt, enqueue_backfill=False, portal=portal)
 
 
     async def handle_thread_remove(self, evt: ThreadRemoveEvent) -> None:
     async def handle_thread_remove(self, evt: ThreadRemoveEvent) -> None:
         self.log.debug("Got thread remove event: %s", evt.serialize())
         self.log.debug("Got thread remove event: %s", evt.serialize())