portal.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, ClassVar, List, Union, TYPE_CHECKING
  17. from attr import dataclass
  18. import asyncpg
  19. from mausignald.types import Address, GroupID
  20. from mautrix.types import RoomID, ContentURI, UserID
  21. from mautrix.util.async_db import Database
  22. from ..util import id_to_str
  23. fake_db = Database.create("") if TYPE_CHECKING else None
  24. @dataclass
  25. class Portal:
  26. db: ClassVar[Database] = fake_db
  27. chat_id: Union[GroupID, Address]
  28. receiver: str
  29. mxid: Optional[RoomID]
  30. name: Optional[str]
  31. avatar_hash: Optional[str]
  32. avatar_url: Optional[ContentURI]
  33. name_set: bool
  34. avatar_set: bool
  35. revision: int
  36. encrypted: bool
  37. relay_user_id: Optional[UserID]
  38. expiration_time: Optional[int]
  39. @property
  40. def chat_id_str(self) -> str:
  41. return id_to_str(self.chat_id)
  42. async def insert(self) -> None:
  43. q = """
  44. INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set,
  45. avatar_set, revision, encrypted, relay_user_id, expiration_time)
  46. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
  47. """
  48. await self.db.execute(
  49. q,
  50. self.chat_id_str,
  51. self.receiver,
  52. self.mxid,
  53. self.name,
  54. self.avatar_hash,
  55. self.avatar_url,
  56. self.name_set,
  57. self.avatar_set,
  58. self.revision,
  59. self.encrypted,
  60. self.relay_user_id,
  61. self.expiration_time,
  62. )
  63. async def update(self) -> None:
  64. q = """
  65. UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5,
  66. avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9,
  67. expiration_time=$10
  68. WHERE chat_id=$11 AND receiver=$12
  69. """
  70. await self.db.execute(
  71. q,
  72. self.mxid,
  73. self.name,
  74. self.avatar_hash,
  75. self.avatar_url,
  76. self.name_set,
  77. self.avatar_set,
  78. self.revision,
  79. self.encrypted,
  80. self.relay_user_id,
  81. self.expiration_time,
  82. self.chat_id_str,
  83. self.receiver,
  84. )
  85. @classmethod
  86. def _from_row(cls, row: asyncpg.Record) -> "Portal":
  87. data = {**row}
  88. chat_id = data.pop("chat_id")
  89. if data["receiver"]:
  90. chat_id = Address.parse(chat_id)
  91. return cls(chat_id=chat_id, **data)
  92. @classmethod
  93. async def get_by_mxid(cls, mxid: RoomID) -> Optional["Portal"]:
  94. q = """
  95. SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
  96. revision, encrypted, relay_user_id, expiration_time
  97. FROM portal
  98. WHERE mxid=$1
  99. """
  100. row = await cls.db.fetchrow(q, mxid)
  101. if not row:
  102. return None
  103. return cls._from_row(row)
  104. @classmethod
  105. async def get_by_chat_id(
  106. cls, chat_id: Union[GroupID, Address], receiver: str = ""
  107. ) -> Optional["Portal"]:
  108. q = """
  109. SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
  110. revision, encrypted, relay_user_id, expiration_time
  111. FROM portal
  112. WHERE chat_id=$1 AND receiver=$2
  113. """
  114. row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
  115. if not row:
  116. return None
  117. return cls._from_row(row)
  118. @classmethod
  119. async def find_private_chats_of(cls, receiver: str) -> List["Portal"]:
  120. q = """
  121. SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
  122. revision, encrypted, relay_user_id, expiration_time
  123. FROM portal
  124. WHERE receiver=$1
  125. """
  126. rows = await cls.db.fetch(q, receiver)
  127. return [cls._from_row(row) for row in rows]
  128. @classmethod
  129. async def find_private_chats_with(cls, other_user: Address) -> List["Portal"]:
  130. q = """
  131. SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
  132. revision, encrypted, relay_user_id, expiration_time
  133. FROM portal
  134. WHERE chat_id=$1 AND receiver<>''
  135. """
  136. rows = await cls.db.fetch(q, other_user.best_identifier)
  137. return [cls._from_row(row) for row in rows]
  138. @classmethod
  139. async def all_with_room(cls) -> List["Portal"]:
  140. q = """
  141. SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
  142. revision, encrypted, relay_user_id, expiration_time
  143. FROM portal
  144. WHERE mxid IS NOT NULL
  145. """
  146. rows = await cls.db.fetch(q)
  147. return [cls._from_row(row) for row in rows]