portal.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, ClassVar
  18. from attr import dataclass
  19. from mautrix.types import ContentURI, RoomID, UserID
  20. from mautrix.util.async_db import Database
  21. import asyncpg
  22. from mausignald.types import Address, GroupID
  23. from ..util import id_to_str
  24. fake_db = Database.create("") if TYPE_CHECKING else None
  25. @dataclass
  26. class Portal:
  27. db: ClassVar[Database] = fake_db
  28. chat_id: GroupID | Address
  29. receiver: str
  30. mxid: RoomID | None
  31. name: str | None
  32. avatar_hash: str | None
  33. avatar_url: ContentURI | None
  34. name_set: bool
  35. avatar_set: bool
  36. revision: int
  37. encrypted: bool
  38. relay_user_id: UserID | None
  39. expiration_time: int | None
  40. @property
  41. def chat_id_str(self) -> str:
  42. return id_to_str(self.chat_id)
  43. async def insert(self) -> None:
  44. q = """
  45. INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set,
  46. avatar_set, revision, encrypted, relay_user_id, expiration_time)
  47. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
  48. """
  49. await self.db.execute(
  50. q,
  51. self.chat_id_str,
  52. self.receiver,
  53. self.mxid,
  54. self.name,
  55. self.avatar_hash,
  56. self.avatar_url,
  57. self.name_set,
  58. self.avatar_set,
  59. self.revision,
  60. self.encrypted,
  61. self.relay_user_id,
  62. self.expiration_time,
  63. )
  64. async def update(self) -> None:
  65. q = """
  66. UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5,
  67. avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9,
  68. expiration_time=$10
  69. WHERE chat_id=$11 AND receiver=$12
  70. """
  71. await self.db.execute(
  72. q,
  73. self.mxid,
  74. self.name,
  75. self.avatar_hash,
  76. self.avatar_url,
  77. self.name_set,
  78. self.avatar_set,
  79. self.revision,
  80. self.encrypted,
  81. self.relay_user_id,
  82. self.expiration_time,
  83. self.chat_id_str,
  84. self.receiver,
  85. )
  86. @classmethod
  87. def _from_row(cls, row: asyncpg.Record) -> Portal:
  88. data = {**row}
  89. chat_id = data.pop("chat_id")
  90. if data["receiver"]:
  91. chat_id = Address.parse(chat_id)
  92. return cls(chat_id=chat_id, **data)
  93. @classmethod
  94. async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
  95. q = """
  96. SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
  97. revision, encrypted, relay_user_id, expiration_time FROM portal
  98. WHERE mxid=$1
  99. """
  100. row = await cls.db.fetchrow(q, mxid)
  101. if not row:
  102. return None
  103. return cls._from_row(row)
  104. @classmethod
  105. async def get_by_chat_id(cls, chat_id: GroupID | Address, receiver: str = "") -> Portal | None:
  106. q = """
  107. SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
  108. revision, encrypted, relay_user_id, expiration_time FROM portal
  109. WHERE chat_id=$1 AND receiver=$2
  110. """
  111. row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
  112. if not row:
  113. return None
  114. return cls._from_row(row)
  115. @classmethod
  116. async def find_private_chats_of(cls, receiver: str) -> list[Portal]:
  117. q = """
  118. SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
  119. revision, encrypted, relay_user_id, expiration_time FROM portal
  120. WHERE receiver=$1
  121. """
  122. rows = await cls.db.fetch(q, receiver)
  123. return [cls._from_row(row) for row in rows]
  124. @classmethod
  125. async def find_private_chats_with(cls, other_user: Address) -> list[Portal]:
  126. q = """
  127. SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
  128. revision, encrypted, relay_user_id, expiration_time FROM portal
  129. WHERE chat_id=$1 AND receiver<>''
  130. """
  131. rows = await cls.db.fetch(q, other_user.best_identifier)
  132. return [cls._from_row(row) for row in rows]
  133. @classmethod
  134. async def all_with_room(cls) -> list[Portal]:
  135. q = """
  136. SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
  137. revision, encrypted, relay_user_id, expiration_time FROM portal
  138. WHERE mxid IS NOT NULL
  139. """
  140. rows = await cls.db.fetch(q)
  141. return [cls._from_row(row) for row in rows]