Bladeren bron

backfill: implement reaction backfill

Signed-off-by: Sumner Evans <sumner@beeper.com>
Sumner Evans 2 jaren geleden
bovenliggende
commit
3d1c7af23b
1 gewijzigde bestanden met toevoegingen van 88 en 34 verwijderingen
  1. 88 34
      mautrix_instagram/portal.py

+ 88 - 34
mautrix_instagram/portal.py

@@ -19,6 +19,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Unio
 from collections import deque
 from io import BytesIO
 import asyncio
+import base64
+import hashlib
 import html
 import json
 import mimetypes
@@ -74,6 +76,7 @@ from mautrix.types import (
     MessageStatus,
     MessageStatusReason,
     MessageType,
+    ReactionEventContent,
     RelatesTo,
     RelationType,
     RoomID,
@@ -1464,13 +1467,6 @@ class Portal(DBPortal, BasePortal):
         ).insert()
         await self._send_delivery_receipt(event_ids[-1])
 
-        # TODO handle reactions
-        return
-        if isinstance(message, graphql.Message) and message.message_reactions:
-            await self._handle_graphql_reactions(
-                source, created_msgs[0], message.message_reactions, timestamp
-            )
-
     async def convert_instagram_item(
         self, source: u.User, sender: p.Puppet, item: ThreadItem
     ) -> list[ConvertedMessage]:
@@ -1537,9 +1533,17 @@ class Portal(DBPortal, BasePortal):
             converted.append(await self._convert_instagram_unhandled(item))
 
         return converted
-        # TODO HANDLE REACTIONS
-        if is_backfill and item.reactions:
-            await self._handle_instagram_reactions(msg, item.reactions.emojis, is_backfill=True)
+
+    def _deterministic_event_id(
+        self, sender: p.Puppet, item_id: str, part_name: int | None = None
+    ) -> EventID:
+        hash_content = f"{self.mxid}/instagram/{sender.igpk}/{item_id}"
+        if part_name:
+            hash_content += f"/{part_name}"
+        print("HASH CONTENT:", hash_content)
+        hashed = hashlib.sha256(hash_content.encode("utf-8")).digest()
+        b64hash = base64.urlsafe_b64encode(hashed).decode("utf-8").rstrip("=")
+        return EventID(f"${b64hash}:telegram.org")
 
     async def handle_instagram_remove(self, item_id: str) -> None:
         message = await DBMessage.get_by_item_id(item_id, self.receiver)
@@ -1992,20 +1996,21 @@ class Portal(DBPortal, BasePortal):
             )
             added_members.add(mxid)
 
-        message_infos: list[tuple[ThreadItem, int]] = []
-        intents: list[IntentAPI] = []
-
-        for message in message_page:
-            puppet: p.Puppet = await p.Puppet.get_by_pk(message.user_id)
+        async def intent_for(user_id: int) -> tuple[p.Puppet, IntentAPI]:
+            puppet: p.Puppet = await p.Puppet.get_by_pk(user_id)
             if puppet:
                 intent = puppet.intent_for(self)
-                if not puppet.name:
-                    # TODO where to get "info"
-                    await puppet.update_info(info, source)
             else:
                 intent = self.main_intent
             if puppet.is_real_user and not self._can_double_puppet_backfill(intent.mxid):
                 intent = puppet.default_mxid_intent
+            return puppet, intent
+
+        message_infos: list[tuple[ThreadItem | Reaction, int]] = []
+        intents: list[IntentAPI] = []
+
+        for message in message_page:
+            puppet, intent = await intent_for(message.user_id)
 
             # Convert the message
             converted = await self.convert_instagram_item(source, puppet, message)
@@ -2016,6 +2021,7 @@ class Portal(DBPortal, BasePortal):
             if intent.mxid not in current_members:
                 add_member(puppet, intent.mxid)
 
+            d_event_id = None
             for index, (event_type, content) in enumerate(converted):
                 if self.encrypted and self.matrix.e2ee:
                     event_type, content = await self.matrix.e2ee.encrypt(
@@ -2024,6 +2030,9 @@ class Portal(DBPortal, BasePortal):
                 if intent.api.is_real_user and intent.api.bridge_name is not None:
                     content[DOUBLE_PUPPET_SOURCE_KEY] = intent.api.bridge_name
 
+                if self.bridge.homeserver_software.is_hungry:
+                    d_event_id = self._deterministic_event_id(puppet, message.item_id, index)
+
                 message_infos.append((message, index))
                 batch_messages.append(
                     BatchSendEvent(
@@ -2031,10 +2040,30 @@ class Portal(DBPortal, BasePortal):
                         type=event_type,
                         sender=intent.mxid,
                         timestamp=message.timestamp_ms,
+                        event_id=d_event_id,
                     )
                 )
                 intents.append(intent)
 
+            if self.bridge.homeserver_software.is_hungry and message.reactions:
+                for reaction in message.reactions.emojis:
+                    puppet, intent = await intent_for(reaction.sender_id)
+
+                    reaction_event = ReactionEventContent()
+                    reaction_event.relates_to = RelatesTo(
+                        rel_type=RelationType.ANNOTATION, event_id=d_event_id, key=reaction.emoji
+                    )
+
+                    message_infos.append((reaction, 0))
+                    batch_messages.append(
+                        BatchSendEvent(
+                            content=reaction_event,
+                            type=EventType.REACTION,
+                            sender=intent.mxid,
+                            timestamp=reaction.timestamp_ms,
+                        )
+                    )
+
         if not batch_messages:
             return None
 
@@ -2096,7 +2125,7 @@ class Portal(DBPortal, BasePortal):
         )
 
     async def _finish_batch(
-        self, event_ids: list[EventID], message_infos: list[tuple[ThreadItem, int]]
+        self, event_ids: list[EventID], message_infos: list[tuple[ThreadItem | Reaction, int]]
     ):
         # We have to do this slightly annoying processing of the event IDs and message infos so
         # that we only map the last event ID to the message.
@@ -2104,21 +2133,40 @@ class Portal(DBPortal, BasePortal):
         # since there's only ever one event per message.
         current_message = None
         messages = []
-        for event_id, (message, index) in zip(event_ids, message_infos):
-            if index == 0 and current_message:
-                # This means that all of the events for the previous message have been processed,
-                # and the current_message is the most recent event for that message.
-                messages.append(current_message)
-
-            current_message = DBMessage(
-                mxid=event_id,
-                mx_room=self.mxid,
-                item_id=message.item_id,
-                client_context=message.client_context,
-                receiver=self.receiver,
-                sender=message.user_id,
-                ig_timestamp=message.timestamp,
-            )
+        reactions = []
+        message_id = None
+        for event_id, (message_or_reaction, index) in zip(event_ids, message_infos):
+            if isinstance(message_or_reaction, ThreadItem):
+                message = message_or_reaction
+                if index == 0 and current_message:
+                    # This means that all of the events for the previous message have been processed,
+                    # and the current_message is the most recent event for that message.
+                    messages.append(current_message)
+
+                current_message = DBMessage(
+                    mxid=event_id,
+                    mx_room=self.mxid,
+                    item_id=message.item_id,
+                    client_context=message.client_context,
+                    receiver=self.receiver,
+                    sender=message.user_id,
+                    ig_timestamp=message.timestamp,
+                )
+                message_id = message.item_id
+            else:
+                assert message_id
+                reaction = message_or_reaction
+                reactions.append(
+                    DBReaction(
+                        mxid=event_id,
+                        mx_room=self.mxid,
+                        ig_item_id=message_id,
+                        ig_receiver=self.receiver,
+                        ig_sender=reaction.sender_id,
+                        reaction=reaction.emoji,
+                        mx_timestamp=reaction.timestamp_ms,
+                    )
+                )
 
         if current_message:
             messages.append(current_message)
@@ -2128,6 +2176,12 @@ class Portal(DBPortal, BasePortal):
         except Exception:
             self.log.exception("Failed to store batch message IDs")
 
+        try:
+            for reaction in reactions:
+                await reaction.insert()
+        except Exception:
+            self.log.exception("Failed to store backfilled reactions")
+
     async def send_post_backfill_dummy(
         self,
         last_message_ig_timestamp: int,