Selaa lähdekoodia

db/portal: update model with new fields

Signed-off-by: Sumner Evans <sumner@beeper.com>
Sumner Evans 2 vuotta sitten
vanhempi
sitoutus
f61abcfa13
2 muutettua tiedostoa jossa 67 lisäystä ja 41 poistoa
  1. 61 41
      mautrix_instagram/db/portal.py
  2. 6 0
      mautrix_instagram/portal.py

+ 61 - 41
mautrix_instagram/db/portal.py

@@ -17,10 +17,10 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, ClassVar
 
+from asyncpg import Record
 from attr import dataclass
-import asyncpg
 
-from mautrix.types import ContentURI, RoomID, UserID
+from mautrix.types import BatchID, ContentURI, EventID, RoomID, UserID
 from mautrix.util.async_db import Database
 
 fake_db = Database.create("") if TYPE_CHECKING else None
@@ -40,6 +40,10 @@ class Portal:
     name_set: bool
     avatar_set: bool
     relay_user_id: UserID | None
+    first_event_id: EventID | None
+    next_batch_id: BatchID | None
+    historical_base_insertion_event_id: EventID | None
+    cursor: str | none
 
     @property
     def _values(self):
@@ -54,77 +58,93 @@ class Portal:
             self.name_set,
             self.avatar_set,
             self.relay_user_id,
+            self.first_event_id,
+            self.next_batch_id,
+            self.historical_base_insertion_event_id,
+            self.cursor,
         )
 
+    column_names = ",".join(
+        (
+            "thread_id",
+            "receiver",
+            "other_user_pk",
+            "mxid",
+            "name",
+            "avatar_url",
+            "encrypted",
+            "name_set",
+            "avatar_set",
+            "relay_user_id",
+            "first_event_id",
+            "next_batch_id",
+            "historical_base_insertion_event_id",
+            "cursor",
+        )
+    )
+
     async def insert(self) -> None:
         q = (
-            "INSERT INTO portal (thread_id, receiver, other_user_pk, mxid, name, avatar_url, "
-            "                    encrypted, name_set, avatar_set, relay_user_id) "
-            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
+            f"INSERT INTO portal ({self.column_names}) "
+            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)"
         )
         await self.db.execute(q, *self._values)
 
     async def update(self) -> None:
         q = (
             "UPDATE portal SET other_user_pk=$3, mxid=$4, name=$5, avatar_url=$6, encrypted=$7,"
-            "                  name_set=$8, avatar_set=$9, relay_user_id=$10 "
+            "                  name_set=$8, avatar_set=$9, relay_user_id=$10, first_event_id=$11,"
+            "                  next_batch_id=$12, historical_base_insertion_event_id=$13,"
+            "                  cursor=$14 "
             "WHERE thread_id=$1 AND receiver=$2"
         )
         await self.db.execute(q, *self._values)
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> Portal:
+    def _from_row(cls, row: Record | None) -> Portal | None:
+        if row is None:
+            return None
         return cls(**row)
 
     @classmethod
     async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
-        q = (
-            "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
-            "       name_set, avatar_set, relay_user_id "
-            "FROM portal WHERE mxid=$1"
-        )
+        q = f"SELECT {cls.column_names} FROM portal WHERE mxid=$1"
         row = await cls.db.fetchrow(q, mxid)
-        if not row:
-            return None
         return cls._from_row(row)
 
     @classmethod
     async def get_by_thread_id(
         cls, thread_id: str, receiver: int, rec_must_match: bool = True
     ) -> Portal | None:
-        q = (
-            "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
-            "       name_set, avatar_set, relay_user_id "
-            "FROM portal WHERE thread_id=$1 AND receiver=$2"
-        )
+        q = f"SELECT {cls.column_names} FROM portal WHERE thread_id=$1 AND receiver=$2"
         if not rec_must_match:
-            q = (
-                "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
-                "       name_set, avatar_set "
-                "FROM portal WHERE thread_id=$1 AND (receiver=$2 OR receiver=0)"
-            )
+            q = f"""
+                SELECT {cls.column_names}
+                FROM portal
+                WHERE thread_id=$1
+                    AND (receiver=$2 OR receiver=0)
+            """
         row = await cls.db.fetchrow(q, thread_id, receiver)
-        if not row:
-            return None
         return cls._from_row(row)
 
     @classmethod
     async def find_private_chats_of(cls, receiver: int) -> list[Portal]:
-        q = (
-            "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
-            "       name_set, avatar_set, relay_user_id "
-            "FROM portal WHERE receiver=$1 AND other_user_pk IS NOT NULL"
-        )
+        q = f"""
+            SELECT {cls.column_names}
+            FROM portal
+            WHERE receiver=$1
+                AND other_user_pk IS NOT NULL
+        """
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
     async def find_private_chats_with(cls, other_user: int) -> list[Portal]:
-        q = (
-            "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
-            "       name_set, avatar_set, relay_user_id "
-            "FROM portal WHERE other_user_pk=$1"
-        )
+        q = f"""
+            SELECT {cls.column_names}
+            FROM portal
+            WHERE other_user_pk=$1
+        """
         rows = await cls.db.fetch(q, other_user)
         return [cls._from_row(row) for row in rows]
 
@@ -135,10 +155,10 @@ class Portal:
 
     @classmethod
     async def all_with_room(cls) -> list[Portal]:
-        q = (
-            "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
-            "       name_set, avatar_set, relay_user_id "
-            "FROM portal WHERE mxid IS NOT NULL"
-        )
+        q = f"""
+            SELECT {cls.column_names}
+            FROM portal
+            WHERE mxid IS NOT NULL
+        """
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]

+ 6 - 0
mautrix_instagram/portal.py

@@ -153,6 +153,9 @@ class Portal(DBPortal, BasePortal):
         name_set: bool = False,
         avatar_set: bool = False,
         relay_user_id: UserID | None = None,
+        first_event_id: EventID | None = None,
+        next_batch_id: BatchID | None = None,
+        historical_base_insertion_event_id: EventID | None = None,
     ) -> None:
         super().__init__(
             thread_id,
@@ -165,6 +168,9 @@ class Portal(DBPortal, BasePortal):
             name_set,
             avatar_set,
             relay_user_id,
+            first_event_id,
+            next_batch_id,
+            historical_base_insertion_event_id,
         )
         self._create_room_lock = asyncio.Lock()
         self.log = self.log.getChild(thread_id)