Sfoglia il codice sorgente

Add support for Matrix->Instagram replies

Tulir Asokan 3 anni fa
parent
commit
e873762d62

+ 1 - 1
ROADMAP.md

@@ -9,7 +9,7 @@
       * [x] Voice messages
       * [ ] Locations
       * [ ] †Files
-    * [ ] Replies
+    * [x] Replies
   * [x] Message redactions
   * [x] Message reactions
   * [ ] Presence

+ 4 - 0
mauigpapi/mqtt/conn.py

@@ -766,6 +766,8 @@ class AndroidMQTT:
         text: str = "",
         shh_mode: bool = False,
         client_context: str | None = None,
+        replied_to_item_id: str | None = None,
+        replied_to_client_context: str | None = None,
     ) -> Awaitable[CommandResponse]:
         return self.send_item(
             thread_id,
@@ -773,6 +775,8 @@ class AndroidMQTT:
             shh_mode=shh_mode,
             item_type=ThreadItemType.TEXT,
             client_context=client_context,
+            replied_to_item_id=replied_to_item_id,
+            replied_to_client_context=replied_to_client_context,
         )
 
     def mark_seen(

+ 14 - 5
mautrix_instagram/db/message.py

@@ -32,15 +32,24 @@ class Message:
     mxid: EventID
     mx_room: RoomID
     item_id: str
+    client_context: str | None
     receiver: int
     sender: int
 
     async def insert(self) -> None:
         q = (
-            "INSERT INTO message (mxid, mx_room, item_id, receiver, sender) "
-            "VALUES ($1, $2, $3, $4, $5)"
+            "INSERT INTO message (mxid, mx_room, item_id, client_context, receiver, sender) "
+            "VALUES ($1, $2, $3, $4, $5, $6)"
+        )
+        await self.db.execute(
+            q,
+            self.mxid,
+            self.mx_room,
+            self.item_id,
+            self.client_context,
+            self.receiver,
+            self.sender,
         )
-        await self.db.execute(q, self.mxid, self.mx_room, self.item_id, self.receiver, self.sender)
 
     async def delete(self) -> None:
         q = "DELETE FROM message WHERE item_id=$1 AND receiver=$2"
@@ -53,7 +62,7 @@ class Message:
     @classmethod
     async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None:
         q = (
-            "SELECT mxid, mx_room, item_id, receiver, sender "
+            "SELECT mxid, mx_room, item_id, client_context, receiver, sender "
             "FROM message WHERE mxid=$1 AND mx_room=$2"
         )
         row = await cls.db.fetchrow(q, mxid, mx_room)
@@ -64,7 +73,7 @@ class Message:
     @classmethod
     async def get_by_item_id(cls, item_id: str, receiver: int) -> Message | None:
         q = (
-            "SELECT mxid, mx_room, item_id, receiver, sender "
+            "SELECT mxid, mx_room, item_id, client_context, receiver, sender "
             "FROM message WHERE item_id=$1 AND receiver=$2"
         )
         row = await cls.db.fetchrow(q, item_id, receiver)

+ 5 - 0
mautrix_instagram/db/upgrade.py

@@ -105,3 +105,8 @@ async def upgrade_v2(conn: Connection) -> None:
 @upgrade_table.register(description="Add relay user field to portal table")
 async def upgrade_v3(conn: Connection) -> None:
     await conn.execute("ALTER TABLE portal ADD COLUMN relay_user_id TEXT")
+
+
+@upgrade_table.register(description="Add client context field to message table")
+async def upgrade_v4(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE message ADD COLUMN client_context TEXT")

+ 13 - 1
mautrix_instagram/portal.py

@@ -376,6 +376,15 @@ class Portal(DBPortal, BasePortal):
         if is_relay:
             await self.apply_relay_message_format(orig_sender, message)
 
+        reply_to = {}
+        if message.get_reply_to():
+            msg = await DBMessage.get_by_mxid(message.get_reply_to(), self.mxid)
+            if msg and msg.client_context:
+                reply_to = {
+                    "replied_to_item_id": msg.item_id,
+                    "replied_to_client_context": msg.client_context,
+                }
+
         request_id = sender.state.gen_client_context()
         self._reqid_dedup.add(request_id)
         self.log.debug(
@@ -388,7 +397,7 @@ class Portal(DBPortal, BasePortal):
                 text = f"/me {text}"
             self.log.trace(f"Sending Matrix text from {event_id} with request ID {request_id}")
             resp = await sender.mqtt.send_text(
-                self.thread_id, text=text, client_context=request_id
+                self.thread_id, text=text, client_context=request_id, **reply_to
             )
         elif message.msgtype.is_media:
             if message.file and decrypt_attachment:
@@ -457,6 +466,7 @@ class Portal(DBPortal, BasePortal):
                     mxid=event_id,
                     mx_room=self.mxid,
                     item_id=resp.payload.item_id,
+                    client_context=resp.payload.client_context,
                     receiver=self.receiver,
                     sender=sender.igpk,
                 ).insert()
@@ -938,6 +948,7 @@ class Portal(DBPortal, BasePortal):
                     mxid=media_event_id,
                     mx_room=self.mxid,
                     item_id=fake_item_id,
+                    client_context=None,
                     receiver=self.receiver,
                     sender=media.user.pk,
                 ).insert()
@@ -1118,6 +1129,7 @@ class Portal(DBPortal, BasePortal):
                 mxid=event_id,
                 mx_room=self.mxid,
                 item_id=item.item_id,
+                client_context=item.client_context,
                 receiver=self.receiver,
                 sender=sender.pk,
             )