|
@@ -17,10 +17,10 @@ from __future__ import annotations
|
|
|
|
|
|
from typing import TYPE_CHECKING, ClassVar
|
|
from typing import TYPE_CHECKING, ClassVar
|
|
|
|
|
|
|
|
+from asyncpg import Record
|
|
from attr import dataclass
|
|
from attr import dataclass
|
|
-import asyncpg
|
|
|
|
|
|
|
|
-from mautrix.types import ContentURI, RoomID, UserID
|
|
|
|
|
|
+from mautrix.types import BatchID, ContentURI, EventID, RoomID, UserID
|
|
from mautrix.util.async_db import Database
|
|
from mautrix.util.async_db import Database
|
|
|
|
|
|
fake_db = Database.create("") if TYPE_CHECKING else None
|
|
fake_db = Database.create("") if TYPE_CHECKING else None
|
|
@@ -40,6 +40,10 @@ class Portal:
|
|
name_set: bool
|
|
name_set: bool
|
|
avatar_set: bool
|
|
avatar_set: bool
|
|
relay_user_id: UserID | None
|
|
relay_user_id: UserID | None
|
|
|
|
+ first_event_id: EventID | None
|
|
|
|
+ next_batch_id: BatchID | None
|
|
|
|
+ historical_base_insertion_event_id: EventID | None
|
|
|
|
+ cursor: str | none
|
|
|
|
|
|
@property
|
|
@property
|
|
def _values(self):
|
|
def _values(self):
|
|
@@ -54,77 +58,93 @@ class Portal:
|
|
self.name_set,
|
|
self.name_set,
|
|
self.avatar_set,
|
|
self.avatar_set,
|
|
self.relay_user_id,
|
|
self.relay_user_id,
|
|
|
|
+ self.first_event_id,
|
|
|
|
+ self.next_batch_id,
|
|
|
|
+ self.historical_base_insertion_event_id,
|
|
|
|
+ self.cursor,
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ column_names = ",".join(
|
|
|
|
+ (
|
|
|
|
+ "thread_id",
|
|
|
|
+ "receiver",
|
|
|
|
+ "other_user_pk",
|
|
|
|
+ "mxid",
|
|
|
|
+ "name",
|
|
|
|
+ "avatar_url",
|
|
|
|
+ "encrypted",
|
|
|
|
+ "name_set",
|
|
|
|
+ "avatar_set",
|
|
|
|
+ "relay_user_id",
|
|
|
|
+ "first_event_id",
|
|
|
|
+ "next_batch_id",
|
|
|
|
+ "historical_base_insertion_event_id",
|
|
|
|
+ "cursor",
|
|
|
|
+ )
|
|
|
|
+ )
|
|
|
|
+
|
|
async def insert(self) -> None:
|
|
async def insert(self) -> None:
|
|
q = (
|
|
q = (
|
|
- "INSERT INTO portal (thread_id, receiver, other_user_pk, mxid, name, avatar_url, "
|
|
|
|
- " encrypted, name_set, avatar_set, relay_user_id) "
|
|
|
|
- "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
|
|
|
|
|
|
+ f"INSERT INTO portal ({self.column_names}) "
|
|
|
|
+ "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)"
|
|
)
|
|
)
|
|
await self.db.execute(q, *self._values)
|
|
await self.db.execute(q, *self._values)
|
|
|
|
|
|
async def update(self) -> None:
|
|
async def update(self) -> None:
|
|
q = (
|
|
q = (
|
|
"UPDATE portal SET other_user_pk=$3, mxid=$4, name=$5, avatar_url=$6, encrypted=$7,"
|
|
"UPDATE portal SET other_user_pk=$3, mxid=$4, name=$5, avatar_url=$6, encrypted=$7,"
|
|
- " name_set=$8, avatar_set=$9, relay_user_id=$10 "
|
|
|
|
|
|
+ " name_set=$8, avatar_set=$9, relay_user_id=$10, first_event_id=$11,"
|
|
|
|
+ " next_batch_id=$12, historical_base_insertion_event_id=$13,"
|
|
|
|
+ " cursor=$14 "
|
|
"WHERE thread_id=$1 AND receiver=$2"
|
|
"WHERE thread_id=$1 AND receiver=$2"
|
|
)
|
|
)
|
|
await self.db.execute(q, *self._values)
|
|
await self.db.execute(q, *self._values)
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
- def _from_row(cls, row: asyncpg.Record) -> Portal:
|
|
|
|
|
|
+ def _from_row(cls, row: Record | None) -> Portal | None:
|
|
|
|
+ if row is None:
|
|
|
|
+ return None
|
|
return cls(**row)
|
|
return cls(**row)
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
|
|
async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
|
|
- q = (
|
|
|
|
- "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
|
|
|
|
- " name_set, avatar_set, relay_user_id "
|
|
|
|
- "FROM portal WHERE mxid=$1"
|
|
|
|
- )
|
|
|
|
|
|
+ q = f"SELECT {cls.column_names} FROM portal WHERE mxid=$1"
|
|
row = await cls.db.fetchrow(q, mxid)
|
|
row = await cls.db.fetchrow(q, mxid)
|
|
- if not row:
|
|
|
|
- return None
|
|
|
|
return cls._from_row(row)
|
|
return cls._from_row(row)
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
async def get_by_thread_id(
|
|
async def get_by_thread_id(
|
|
cls, thread_id: str, receiver: int, rec_must_match: bool = True
|
|
cls, thread_id: str, receiver: int, rec_must_match: bool = True
|
|
) -> Portal | None:
|
|
) -> Portal | None:
|
|
- q = (
|
|
|
|
- "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
|
|
|
|
- " name_set, avatar_set, relay_user_id "
|
|
|
|
- "FROM portal WHERE thread_id=$1 AND receiver=$2"
|
|
|
|
- )
|
|
|
|
|
|
+ q = f"SELECT {cls.column_names} FROM portal WHERE thread_id=$1 AND receiver=$2"
|
|
if not rec_must_match:
|
|
if not rec_must_match:
|
|
- q = (
|
|
|
|
- "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
|
|
|
|
- " name_set, avatar_set "
|
|
|
|
- "FROM portal WHERE thread_id=$1 AND (receiver=$2 OR receiver=0)"
|
|
|
|
- )
|
|
|
|
|
|
+ q = f"""
|
|
|
|
+ SELECT {cls.column_names}
|
|
|
|
+ FROM portal
|
|
|
|
+ WHERE thread_id=$1
|
|
|
|
+ AND (receiver=$2 OR receiver=0)
|
|
|
|
+ """
|
|
row = await cls.db.fetchrow(q, thread_id, receiver)
|
|
row = await cls.db.fetchrow(q, thread_id, receiver)
|
|
- if not row:
|
|
|
|
- return None
|
|
|
|
return cls._from_row(row)
|
|
return cls._from_row(row)
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
async def find_private_chats_of(cls, receiver: int) -> list[Portal]:
|
|
async def find_private_chats_of(cls, receiver: int) -> list[Portal]:
|
|
- q = (
|
|
|
|
- "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
|
|
|
|
- " name_set, avatar_set, relay_user_id "
|
|
|
|
- "FROM portal WHERE receiver=$1 AND other_user_pk IS NOT NULL"
|
|
|
|
- )
|
|
|
|
|
|
+ q = f"""
|
|
|
|
+ SELECT {cls.column_names}
|
|
|
|
+ FROM portal
|
|
|
|
+ WHERE receiver=$1
|
|
|
|
+ AND other_user_pk IS NOT NULL
|
|
|
|
+ """
|
|
rows = await cls.db.fetch(q, receiver)
|
|
rows = await cls.db.fetch(q, receiver)
|
|
return [cls._from_row(row) for row in rows]
|
|
return [cls._from_row(row) for row in rows]
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
async def find_private_chats_with(cls, other_user: int) -> list[Portal]:
|
|
async def find_private_chats_with(cls, other_user: int) -> list[Portal]:
|
|
- q = (
|
|
|
|
- "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
|
|
|
|
- " name_set, avatar_set, relay_user_id "
|
|
|
|
- "FROM portal WHERE other_user_pk=$1"
|
|
|
|
- )
|
|
|
|
|
|
+ q = f"""
|
|
|
|
+ SELECT {cls.column_names}
|
|
|
|
+ FROM portal
|
|
|
|
+ WHERE other_user_pk=$1
|
|
|
|
+ """
|
|
rows = await cls.db.fetch(q, other_user)
|
|
rows = await cls.db.fetch(q, other_user)
|
|
return [cls._from_row(row) for row in rows]
|
|
return [cls._from_row(row) for row in rows]
|
|
|
|
|
|
@@ -135,10 +155,10 @@ class Portal:
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
async def all_with_room(cls) -> list[Portal]:
|
|
async def all_with_room(cls) -> list[Portal]:
|
|
- q = (
|
|
|
|
- "SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
|
|
|
|
- " name_set, avatar_set, relay_user_id "
|
|
|
|
- "FROM portal WHERE mxid IS NOT NULL"
|
|
|
|
- )
|
|
|
|
|
|
+ q = f"""
|
|
|
|
+ SELECT {cls.column_names}
|
|
|
|
+ FROM portal
|
|
|
|
+ WHERE mxid IS NOT NULL
|
|
|
|
+ """
|
|
rows = await cls.db.fetch(q)
|
|
rows = await cls.db.fetch(q)
|
|
return [cls._from_row(row) for row in rows]
|
|
return [cls._from_row(row) for row in rows]
|