portal.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, ClassVar
  18. from attr import dataclass
  19. import asyncpg
  20. from mausignald.types import Address, GroupID
  21. from mautrix.types import ContentURI, RoomID, UserID
  22. from mautrix.util.async_db import Database
  23. from ..util import id_to_str
  24. fake_db = Database.create("") if TYPE_CHECKING else None
  25. @dataclass
  26. class Portal:
  27. db: ClassVar[Database] = fake_db
  28. chat_id: GroupID | Address
  29. receiver: str
  30. mxid: RoomID | None
  31. name: str | None
  32. topic: str | None
  33. avatar_hash: str | None
  34. avatar_url: ContentURI | None
  35. name_set: bool
  36. avatar_set: bool
  37. revision: int
  38. encrypted: bool
  39. relay_user_id: UserID | None
  40. expiration_time: int | None
  41. @property
  42. def chat_id_str(self) -> str:
  43. return id_to_str(self.chat_id)
  44. @property
  45. def _values(self):
  46. return (
  47. self.chat_id_str,
  48. self.receiver,
  49. self.mxid,
  50. self.name,
  51. self.topic,
  52. self.avatar_hash,
  53. self.avatar_url,
  54. self.name_set,
  55. self.avatar_set,
  56. self.revision,
  57. self.encrypted,
  58. self.relay_user_id,
  59. self.expiration_time,
  60. )
  61. async def insert(self) -> None:
  62. q = """
  63. INSERT INTO portal (
  64. chat_id, receiver, mxid, name, topic, avatar_hash, avatar_url, name_set, avatar_set,
  65. revision, encrypted, relay_user_id, expiration_time
  66. ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
  67. """
  68. await self.db.execute(q, *self._values)
  69. async def update(self) -> None:
  70. q = """
  71. UPDATE portal SET mxid=$3, name=$4, topic=$5, avatar_hash=$6, avatar_url=$7, name_set=$8,
  72. avatar_set=$9, revision=$10, encrypted=$11, relay_user_id=$12,
  73. expiration_time=$13
  74. WHERE chat_id=$1 AND receiver=$2
  75. """
  76. await self.db.execute(q, *self._values)
  77. @classmethod
  78. def _from_row(cls, row: asyncpg.Record) -> Portal:
  79. data = {**row}
  80. chat_id = data.pop("chat_id")
  81. if data["receiver"]:
  82. chat_id = Address.parse(chat_id)
  83. return cls(chat_id=chat_id, **data)
  84. _columns = (
  85. "chat_id, receiver, mxid, name, topic, avatar_hash, avatar_url, name_set, avatar_set, "
  86. "revision, encrypted, relay_user_id, expiration_time"
  87. )
  88. @classmethod
  89. async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
  90. q = f"SELECT {cls._columns} FROM portal WHERE mxid=$1"
  91. row = await cls.db.fetchrow(q, mxid)
  92. if not row:
  93. return None
  94. return cls._from_row(row)
  95. @classmethod
  96. async def get_by_chat_id(cls, chat_id: GroupID | Address, receiver: str = "") -> Portal | None:
  97. q = f"SELECT {cls._columns} FROM portal WHERE chat_id=$1 AND receiver=$2"
  98. row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
  99. if not row:
  100. return None
  101. return cls._from_row(row)
  102. @classmethod
  103. async def find_private_chats_of(cls, receiver: str) -> list[Portal]:
  104. q = f"SELECT {cls._columns} FROM portal WHERE receiver=$1"
  105. rows = await cls.db.fetch(q, receiver)
  106. return [cls._from_row(row) for row in rows]
  107. @classmethod
  108. async def find_private_chats_with(cls, other_user: Address) -> list[Portal]:
  109. q = f"SELECT {cls._columns} FROM portal WHERE chat_id=$1 AND receiver<>''"
  110. rows = await cls.db.fetch(q, other_user.best_identifier)
  111. return [cls._from_row(row) for row in rows]
  112. @classmethod
  113. async def all_with_room(cls) -> list[Portal]:
  114. q = f"SELECT {cls._columns} FROM portal WHERE mxid IS NOT NULL"
  115. rows = await cls.db.fetch(q)
  116. return [cls._from_row(row) for row in rows]