puppet.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, ClassVar
  18. from uuid import UUID
  19. from attr import dataclass
  20. from yarl import URL
  21. import asyncpg
  22. from mausignald.types import Address
  23. from mautrix.types import ContentURI, SyncToken, UserID
  24. from mautrix.util.async_db import Database
  25. fake_db = Database.create("") if TYPE_CHECKING else None
  26. @dataclass
  27. class Puppet:
  28. db: ClassVar[Database] = fake_db
  29. uuid: UUID | None
  30. number: str | None
  31. name: str | None
  32. name_quality: int
  33. avatar_hash: str | None
  34. avatar_url: ContentURI | None
  35. name_set: bool
  36. avatar_set: bool
  37. uuid_registered: bool
  38. number_registered: bool
  39. custom_mxid: UserID | None
  40. access_token: str | None
  41. next_batch: SyncToken | None
  42. base_url: URL | None
  43. @property
  44. def _base_url_str(self) -> str | None:
  45. return str(self.base_url) if self.base_url else None
  46. async def insert(self) -> None:
  47. q = """
  48. INSERT INTO puppet (uuid, number, name, name_quality, avatar_hash, avatar_url, name_set,
  49. avatar_set, uuid_registered, number_registered,
  50. custom_mxid, access_token, next_batch, base_url)
  51. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
  52. """
  53. await self.db.execute(
  54. q,
  55. self.uuid,
  56. self.number,
  57. self.name,
  58. self.name_quality,
  59. self.avatar_hash,
  60. self.avatar_url,
  61. self.name_set,
  62. self.avatar_set,
  63. self.uuid_registered,
  64. self.number_registered,
  65. self.custom_mxid,
  66. self.access_token,
  67. self.next_batch,
  68. self._base_url_str,
  69. )
  70. async def _set_uuid(self, uuid: UUID) -> None:
  71. async with self.db.acquire() as conn, conn.transaction():
  72. await conn.execute(
  73. "DELETE FROM puppet WHERE uuid=$1 AND number<>$2", uuid, self.number
  74. )
  75. await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
  76. await self._update_number_to_uuid(conn, self.number, str(uuid))
  77. async def _set_number(self, number: str) -> None:
  78. async with self.db.acquire() as conn, conn.transaction():
  79. await conn.execute(
  80. "DELETE FROM puppet WHERE number=$1 AND uuid<>$2", number, self.uuid
  81. )
  82. await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", number, self.uuid)
  83. await self._update_number_to_uuid(conn, number, str(self.uuid))
  84. @staticmethod
  85. async def _update_number_to_uuid(
  86. conn: asyncpg.Connection, old_number: str, new_uuid: str
  87. ) -> None:
  88. try:
  89. async with conn.transaction():
  90. await conn.execute(
  91. "UPDATE portal SET chat_id=$1 WHERE chat_id=$2", new_uuid, old_number
  92. )
  93. except asyncpg.UniqueViolationError:
  94. await conn.execute("DELETE FROM portal WHERE chat_id=$1", old_number)
  95. await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2", new_uuid, old_number)
  96. await conn.execute("UPDATE reaction SET author=$1 WHERE author=$2", new_uuid, old_number)
  97. async def update(self) -> None:
  98. set_columns = (
  99. "name=$3, name_quality=$4, avatar_hash=$5, avatar_url=$6, name_set=$7, avatar_set=$8, "
  100. "uuid_registered=$9, number_registered=$10, "
  101. "custom_mxid=$11, access_token=$12, next_batch=$13, base_url=$14"
  102. )
  103. q = (
  104. f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
  105. if self.uuid is None
  106. else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1"
  107. )
  108. await self.db.execute(
  109. q,
  110. self.uuid,
  111. self.number,
  112. self.name,
  113. self.name_quality,
  114. self.avatar_hash,
  115. self.avatar_url,
  116. self.name_set,
  117. self.avatar_set,
  118. self.uuid_registered,
  119. self.number_registered,
  120. self.custom_mxid,
  121. self.access_token,
  122. self.next_batch,
  123. self._base_url_str,
  124. )
  125. @classmethod
  126. def _from_row(cls, row: asyncpg.Record) -> Puppet:
  127. data = {**row}
  128. base_url_str = data.pop("base_url")
  129. base_url = URL(base_url_str) if base_url_str is not None else None
  130. return cls(base_url=base_url, **data)
  131. _select_base = (
  132. "SELECT uuid, number, name, name_quality, avatar_hash, avatar_url, name_set, avatar_set, "
  133. " uuid_registered, number_registered, custom_mxid, access_token, "
  134. " next_batch, base_url "
  135. "FROM puppet"
  136. )
  137. @classmethod
  138. async def get_by_address(cls, address: Address) -> Puppet | None:
  139. if address.uuid:
  140. if address.number:
  141. row = await cls.db.fetchrow(
  142. f"{cls._select_base} WHERE uuid=$1 OR number=$2", address.uuid, address.number
  143. )
  144. else:
  145. row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
  146. elif address.number:
  147. row = await cls.db.fetchrow(f"{cls._select_base} WHERE number=$1", address.number)
  148. else:
  149. raise ValueError("Invalid address")
  150. if not row:
  151. return None
  152. return cls._from_row(row)
  153. @classmethod
  154. async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
  155. row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
  156. if not row:
  157. return None
  158. return cls._from_row(row)
  159. @classmethod
  160. async def all_with_custom_mxid(cls) -> list[Puppet]:
  161. rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
  162. return [cls._from_row(row) for row in rows]