puppet.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, ClassVar, List, TYPE_CHECKING
  17. from uuid import UUID
  18. from attr import dataclass
  19. from yarl import URL
  20. import asyncpg
  21. from mausignald.types import Address
  22. from mautrix.types import UserID, SyncToken, ContentURI
  23. from mautrix.util.async_db import Database
  # Placeholder for Puppet.db. Static type checkers take the first branch, which
  # gives the attribute a Database type. At runtime TYPE_CHECKING is False, so
  # no Database is created at import time and the placeholder is None.
  # Presumably the real Database is assigned to Puppet.db during startup elsewhere.
  if TYPE_CHECKING:
      fake_db = Database.create("")
  else:
      fake_db = None
  @dataclass
  class Puppet:
      """Database model for one row of the ``puppet`` table.

      A puppet is identified by ``uuid``, ``number``, or both; either may be
      ``None``. The class-level ``db`` handle is shared by all instances and is
      used for every query below.
      """

      db: ClassVar[Database] = fake_db

      uuid: Optional[UUID]
      number: Optional[str]
      name: Optional[str]
      avatar_hash: Optional[str]
      avatar_url: Optional[ContentURI]
      name_set: bool
      avatar_set: bool
      uuid_registered: bool
      number_registered: bool
      custom_mxid: Optional[UserID]
      access_token: Optional[str]
      next_batch: Optional[SyncToken]
      base_url: Optional[URL]

      @property
      def _base_url_str(self) -> Optional[str]:
          """Return ``base_url`` as text for storage, or ``None`` when it is unset/falsy."""
          return str(self.base_url) if self.base_url else None

      async def insert(self) -> None:
          """Insert this puppet as a new row, writing every column."""
          q = (
              "INSERT INTO puppet (uuid, number, name, avatar_hash, avatar_url, name_set, "
              " avatar_set, uuid_registered, number_registered, "
              " custom_mxid, access_token, next_batch, base_url) "
              "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)"
          )
          await self.db.execute(
              q,
              self.uuid,
              self.number,
              self.name,
              self.avatar_hash,
              self.avatar_url,
              self.name_set,
              self.avatar_set,
              self.uuid_registered,
              self.number_registered,
              self.custom_mxid,
              self.access_token,
              self.next_batch,
              self._base_url_str,
          )

      async def _set_uuid(self, uuid: UUID) -> None:
          """Attach ``uuid`` to the row currently identified by ``self.number``.

          Any other row already holding ``uuid`` is deleted first. Then portals,
          messages and reactions keyed by the number are re-keyed to the uuid.
          Everything runs in one transaction. ``self.uuid`` is not modified here.
          Assumes ``self.number`` is set, since the row is located by number.
          """
          async with self.db.acquire() as conn, conn.transaction():
              # IS DISTINCT FROM (not <>) so that a row holding this uuid with a
              # NULL number is also removed. With <> the comparison against NULL
              # is unknown, that row survives, and the UPDATE below would give two
              # rows the same uuid (presumably violating a unique constraint).
              await conn.execute(
                  "DELETE FROM puppet WHERE uuid=$1 AND number IS DISTINCT FROM $2",
                  uuid,
                  self.number,
              )
              await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
              await self._update_number_to_uuid(conn, self.number, str(uuid))

      async def _set_number(self, number: str) -> None:
          """Attach ``number`` to the row currently identified by ``self.uuid``.

          Any other row already holding ``number`` is deleted first. Then portals,
          messages and reactions keyed by ``number`` are re-keyed to this uuid.
          Everything runs in one transaction. ``self.number`` is not modified here.
          Assumes ``self.uuid`` is set, since the row is located by uuid.
          """
          async with self.db.acquire() as conn, conn.transaction():
              # IS DISTINCT FROM (not <>) so that a number-only row (NULL uuid)
              # is also removed before the UPDATE reassigns the number.
              await conn.execute(
                  "DELETE FROM puppet WHERE number=$1 AND uuid IS DISTINCT FROM $2",
                  number,
                  self.uuid,
              )
              await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", number, self.uuid)
              await self._update_number_to_uuid(conn, number, str(self.uuid))

      @staticmethod
      async def _update_number_to_uuid(
          conn: asyncpg.Connection, old_number: str, new_uuid: str
      ) -> None:
          """Re-key rows that reference ``old_number`` so they reference ``new_uuid``.

          Updates ``portal.chat_id``, ``message.sender`` and ``reaction.author``.
          Both callers in this class invoke it inside their own transaction on ``conn``.
          """
          try:
              # A nested transaction is a savepoint in asyncpg, so a unique
              # violation rolls back only this UPDATE, not the caller's transaction.
              async with conn.transaction():
                  await conn.execute(
                      "UPDATE portal SET chat_id=$1 WHERE chat_id=$2", new_uuid, old_number
                  )
          except asyncpg.UniqueViolationError:
              # Presumably a portal keyed by new_uuid already exists, so the
              # number-keyed duplicate is dropped instead of being re-keyed.
              # NOTE(review): this deletes every portal with chat_id=old_number,
              # including any that would not have conflicted. Confirm against the
              # portal table's unique key.
              await conn.execute("DELETE FROM portal WHERE chat_id=$1", old_number)
          await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2", new_uuid, old_number)
          await conn.execute("UPDATE reaction SET author=$1 WHERE author=$2", new_uuid, old_number)

      async def update(self) -> None:
          """Write all columns of this puppet back to its existing row.

          The row is matched by ``uuid`` when it is set. Otherwise it is matched
          by ``number``, and ``uuid=$1`` then simply writes NULL.
          """
          set_columns = (
              "name=$3, avatar_hash=$4, avatar_url=$5, name_set=$6, avatar_set=$7, "
              "uuid_registered=$8, number_registered=$9, "
              "custom_mxid=$10, access_token=$11, next_batch=$12, base_url=$13"
          )
          q = (
              f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
              if self.uuid is None
              else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1"
          )
          await self.db.execute(
              q,
              self.uuid,
              self.number,
              self.name,
              self.avatar_hash,
              self.avatar_url,
              self.name_set,
              self.avatar_set,
              self.uuid_registered,
              self.number_registered,
              self.custom_mxid,
              self.access_token,
              self.next_batch,
              self._base_url_str,
          )

      @classmethod
      def _from_row(cls, row: asyncpg.Record) -> "Puppet":
          """Build a ``Puppet`` from a record selected with ``_select_base``."""
          data = {**row}
          # base_url is stored as text; turn it back into a URL object.
          base_url_str = data.pop("base_url")
          base_url = URL(base_url_str) if base_url_str is not None else None
          return cls(base_url=base_url, **data)

      # The selected column names must match the dataclass field names, because
      # _from_row passes each record column to the constructor as a keyword argument.
      _select_base = (
          "SELECT uuid, number, name, avatar_hash, avatar_url, name_set, avatar_set, "
          " uuid_registered, number_registered, custom_mxid, access_token, "
          " next_batch, base_url "
          "FROM puppet"
      )

      @classmethod
      async def get_by_address(cls, address: Address) -> Optional["Puppet"]:
          """Look up a puppet by the address's uuid and/or number.

          Returns ``None`` when no row matches.

          Raises:
              ValueError: if the address has neither a uuid nor a number.
          """
          if address.uuid:
              if address.number:
                  # NOTE(review): if one row holds the uuid and a different row
                  # holds the number, both match, and fetchrow returns whichever
                  # the database yields first (there is no ORDER BY preference).
                  row = await cls.db.fetchrow(
                      f"{cls._select_base} WHERE uuid=$1 OR number=$2", address.uuid, address.number
                  )
              else:
                  row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
          elif address.number:
              row = await cls.db.fetchrow(f"{cls._select_base} WHERE number=$1", address.number)
          else:
              raise ValueError("Invalid address")
          if not row:
              return None
          return cls._from_row(row)

      @classmethod
      async def get_by_custom_mxid(cls, mxid: UserID) -> Optional["Puppet"]:
          """Return the puppet whose ``custom_mxid`` equals ``mxid``, or ``None``."""
          row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
          if not row:
              return None
          return cls._from_row(row)

      @classmethod
      async def all_with_custom_mxid(cls) -> List["Puppet"]:
          """Return every puppet that has a non-NULL ``custom_mxid``."""
          rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
          return [cls._from_row(row) for row in rows]