portal.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, ClassVar
  18. from uuid import UUID
  19. from attr import dataclass
  20. import asyncpg
  21. from mausignald.types import GroupID
  22. from mautrix.types import ContentURI, RoomID, UserID
  23. from mautrix.util.async_db import Database
  24. from .util import ensure_uuid
  25. fake_db = Database.create("") if TYPE_CHECKING else None
  26. @dataclass
  27. class Portal:
  28. db: ClassVar[Database] = fake_db
  29. chat_id: GroupID | UUID
  30. receiver: str
  31. mxid: RoomID | None
  32. name: str | None
  33. topic: str | None
  34. avatar_hash: str | None
  35. avatar_url: ContentURI | None
  36. name_set: bool
  37. avatar_set: bool
  38. revision: int
  39. encrypted: bool
  40. relay_user_id: UserID | None
  41. expiration_time: int | None
  42. @property
  43. def _values(self):
  44. return (
  45. str(self.chat_id),
  46. self.receiver,
  47. self.mxid,
  48. self.name,
  49. self.topic,
  50. self.avatar_hash,
  51. self.avatar_url,
  52. self.name_set,
  53. self.avatar_set,
  54. self.revision,
  55. self.encrypted,
  56. self.relay_user_id,
  57. self.expiration_time,
  58. )
  59. async def insert(self) -> None:
  60. q = """
  61. INSERT INTO portal (
  62. chat_id, receiver, mxid, name, topic, avatar_hash, avatar_url, name_set, avatar_set,
  63. revision, encrypted, relay_user_id, expiration_time
  64. ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
  65. """
  66. await self.db.execute(q, *self._values)
  67. async def update(self) -> None:
  68. q = """
  69. UPDATE portal SET mxid=$3, name=$4, topic=$5, avatar_hash=$6, avatar_url=$7, name_set=$8,
  70. avatar_set=$9, revision=$10, encrypted=$11, relay_user_id=$12,
  71. expiration_time=$13
  72. WHERE chat_id=$1 AND receiver=$2
  73. """
  74. await self.db.execute(q, *self._values)
  75. @classmethod
  76. def _from_row(cls, row: asyncpg.Record | None) -> Portal | None:
  77. if row is None:
  78. return None
  79. data = {**row}
  80. chat_id = data.pop("chat_id")
  81. if data["receiver"]:
  82. chat_id = ensure_uuid(chat_id)
  83. return cls(chat_id=chat_id, **data)
  84. _columns = (
  85. "chat_id, receiver, mxid, name, topic, avatar_hash, avatar_url, name_set, avatar_set, "
  86. "revision, encrypted, relay_user_id, expiration_time"
  87. )
  88. @classmethod
  89. async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
  90. q = f"SELECT {cls._columns} FROM portal WHERE mxid=$1"
  91. return cls._from_row(await cls.db.fetchrow(q, mxid))
  92. @classmethod
  93. async def get_by_chat_id(cls, chat_id: GroupID | UUID, receiver: str = "") -> Portal | None:
  94. q = f"SELECT {cls._columns} FROM portal WHERE chat_id=$1 AND receiver=$2"
  95. return cls._from_row(await cls.db.fetchrow(q, str(chat_id), receiver))
  96. @classmethod
  97. async def find_private_chats_of(cls, receiver: str) -> list[Portal]:
  98. q = f"SELECT {cls._columns} FROM portal WHERE receiver=$1"
  99. rows = await cls.db.fetch(q, receiver)
  100. return [cls._from_row(row) for row in rows]
  101. @classmethod
  102. async def find_private_chats_with(cls, other_user: UUID) -> list[Portal]:
  103. q = f"SELECT {cls._columns} FROM portal WHERE chat_id=$1 AND receiver<>''"
  104. rows = await cls.db.fetch(q, str(other_user))
  105. return [cls._from_row(row) for row in rows]
  106. @classmethod
  107. async def all_with_room(cls) -> list[Portal]:
  108. q = f"SELECT {cls._columns} FROM portal WHERE mxid IS NOT NULL"
  109. rows = await cls.db.fetch(q)
  110. return [cls._from_row(row) for row in rows]