puppet.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, ClassVar
  18. from uuid import UUID
  19. from attr import dataclass
  20. from yarl import URL
  21. import asyncpg
  22. from mausignald.types import Address
  23. from mautrix.types import ContentURI, SyncToken, UserID
  24. from mautrix.util.async_db import Database
  25. fake_db = Database.create("") if TYPE_CHECKING else None
  26. @dataclass
  27. class Puppet:
  28. db: ClassVar[Database] = fake_db
  29. uuid: UUID | None
  30. number: str | None
  31. name: str | None
  32. avatar_hash: str | None
  33. avatar_url: ContentURI | None
  34. name_set: bool
  35. avatar_set: bool
  36. uuid_registered: bool
  37. number_registered: bool
  38. custom_mxid: UserID | None
  39. access_token: str | None
  40. next_batch: SyncToken | None
  41. base_url: URL | None
  42. @property
  43. def _base_url_str(self) -> str | None:
  44. return str(self.base_url) if self.base_url else None
  45. async def insert(self) -> None:
  46. q = (
  47. "INSERT INTO puppet (uuid, number, name, avatar_hash, avatar_url, name_set, "
  48. " avatar_set, uuid_registered, number_registered, "
  49. " custom_mxid, access_token, next_batch, base_url) "
  50. "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)"
  51. )
  52. await self.db.execute(
  53. q,
  54. self.uuid,
  55. self.number,
  56. self.name,
  57. self.avatar_hash,
  58. self.avatar_url,
  59. self.name_set,
  60. self.avatar_set,
  61. self.uuid_registered,
  62. self.number_registered,
  63. self.custom_mxid,
  64. self.access_token,
  65. self.next_batch,
  66. self._base_url_str,
  67. )
  68. async def _set_uuid(self, uuid: UUID) -> None:
  69. async with self.db.acquire() as conn, conn.transaction():
  70. await conn.execute(
  71. "DELETE FROM puppet WHERE uuid=$1 AND number<>$2", uuid, self.number
  72. )
  73. await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
  74. await self._update_number_to_uuid(conn, self.number, str(uuid))
  75. async def _set_number(self, number: str) -> None:
  76. async with self.db.acquire() as conn, conn.transaction():
  77. await conn.execute(
  78. "DELETE FROM puppet WHERE number=$1 AND uuid<>$2", number, self.uuid
  79. )
  80. await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", number, self.uuid)
  81. await self._update_number_to_uuid(conn, number, str(self.uuid))
  82. @staticmethod
  83. async def _update_number_to_uuid(
  84. conn: asyncpg.Connection, old_number: str, new_uuid: str
  85. ) -> None:
  86. try:
  87. async with conn.transaction():
  88. await conn.execute(
  89. "UPDATE portal SET chat_id=$1 WHERE chat_id=$2", new_uuid, old_number
  90. )
  91. except asyncpg.UniqueViolationError:
  92. await conn.execute("DELETE FROM portal WHERE chat_id=$1", old_number)
  93. await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2", new_uuid, old_number)
  94. await conn.execute("UPDATE reaction SET author=$1 WHERE author=$2", new_uuid, old_number)
  95. async def update(self) -> None:
  96. set_columns = (
  97. "name=$3, avatar_hash=$4, avatar_url=$5, name_set=$6, avatar_set=$7, "
  98. "uuid_registered=$8, number_registered=$9, "
  99. "custom_mxid=$10, access_token=$11, next_batch=$12, base_url=$13"
  100. )
  101. q = (
  102. f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
  103. if self.uuid is None
  104. else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1"
  105. )
  106. await self.db.execute(
  107. q,
  108. self.uuid,
  109. self.number,
  110. self.name,
  111. self.avatar_hash,
  112. self.avatar_url,
  113. self.name_set,
  114. self.avatar_set,
  115. self.uuid_registered,
  116. self.number_registered,
  117. self.custom_mxid,
  118. self.access_token,
  119. self.next_batch,
  120. self._base_url_str,
  121. )
  122. @classmethod
  123. def _from_row(cls, row: asyncpg.Record) -> Puppet:
  124. data = {**row}
  125. base_url_str = data.pop("base_url")
  126. base_url = URL(base_url_str) if base_url_str is not None else None
  127. return cls(base_url=base_url, **data)
  128. _select_base = (
  129. "SELECT uuid, number, name, avatar_hash, avatar_url, name_set, avatar_set, "
  130. " uuid_registered, number_registered, custom_mxid, access_token, "
  131. " next_batch, base_url "
  132. "FROM puppet"
  133. )
  134. @classmethod
  135. async def get_by_address(cls, address: Address) -> Puppet | None:
  136. if address.uuid:
  137. if address.number:
  138. row = await cls.db.fetchrow(
  139. f"{cls._select_base} WHERE uuid=$1 OR number=$2", address.uuid, address.number
  140. )
  141. else:
  142. row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
  143. elif address.number:
  144. row = await cls.db.fetchrow(f"{cls._select_base} WHERE number=$1", address.number)
  145. else:
  146. raise ValueError("Invalid address")
  147. if not row:
  148. return None
  149. return cls._from_row(row)
  150. @classmethod
  151. async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
  152. row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
  153. if not row:
  154. return None
  155. return cls._from_row(row)
  156. @classmethod
  157. async def all_with_custom_mxid(cls) -> list[Puppet]:
  158. rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
  159. return [cls._from_row(row) for row in rows]