Browse Source

db/message: add bulk_insert function

Signed-off-by: Sumner Evans <sumner@beeper.com>
Sumner Evans 2 years ago
parent
commit
643d67523e
1 changed files with 16 additions and 9 deletions
  1. 16 9
      mautrix_instagram/db/message.py

+ 16 - 9
mautrix_instagram/db/message.py

@@ -18,9 +18,10 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, ClassVar
 
 from attr import dataclass
+import attr
 
 from mautrix.types import EventID, RoomID
-from mautrix.util.async_db import Database
+from mautrix.util.async_db import Database, Scheme
 
 fake_db = Database.create("") if TYPE_CHECKING else None
 
@@ -37,18 +38,16 @@ class Message:
     sender: int
     ig_timestamp: int | None
 
+    _columns = "mxid, mx_room, item_id, client_context, receiver, sender, ig_timestamp"
+    _insert_query = f"INSERT INTO message ({_columns}) VALUES ($1, $2, $3, $4, $5, $6, $7)"
+
     @property
     def ig_timestamp_ms(self) -> int:
         return (self.ig_timestamp // 1000) if self.ig_timestamp else 0
 
     async def insert(self) -> None:
-        q = """
-            INSERT INTO message (mxid, mx_room, item_id, client_context, receiver, sender,
-                                 ig_timestamp)
-            VALUES ($1, $2, $3, $4, $5, $6, $7)
-        """
         await self.db.execute(
-            q,
+            self._insert_query,
             self.mxid,
             self.mx_room,
             self.item_id,
@@ -58,6 +57,16 @@ class Message:
             self.ig_timestamp,
         )
 
+    @classmethod
+    async def bulk_insert(cls, messages: list[Message]) -> None:
+        columns = cls._columns.split(", ")
+        records = [attr.astuple(message) for message in messages]
+        async with cls.db.acquire() as conn, conn.transaction():
+            if cls.db.scheme == Scheme.POSTGRES:
+                await conn.copy_records_to_table("message", records=records, columns=columns)
+            else:
+                await conn.executemany(cls._insert_query, records)
+
     async def delete(self) -> None:
         q = "DELETE FROM message WHERE item_id=$1 AND receiver=$2"
         await self.db.execute(q, self.item_id, self.receiver)
@@ -66,8 +75,6 @@ class Message:
     async def delete_all(cls, room_id: RoomID) -> None:
         await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)
 
-    _columns = "mxid, mx_room, item_id, client_context, receiver, sender, ig_timestamp"
-
     @classmethod
     async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None:
         q = f"SELECT {cls._columns} FROM message WHERE mxid=$1 AND mx_room=$2"