puppet.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
# mautrix-signal - A Matrix-Signal puppeting bridge
# Copyright (C) 2020 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, ClassVar, List, TYPE_CHECKING
  17. from uuid import UUID
  18. from attr import dataclass
  19. from yarl import URL
  20. import asyncpg
  21. from mausignald.types import Address
  22. from mautrix.types import UserID, SyncToken, ContentURI
  23. from mautrix.util.async_db import Database
  # Typed placeholder used as the default for ``Puppet.db``: static type checkers
  # see a Database, while at runtime (TYPE_CHECKING is False) the value is None.
  fake_db = None if not TYPE_CHECKING else Database("")
  @dataclass
  class Puppet:
      """A Signal user's puppet as stored in the ``puppet`` table.

      A puppet is keyed by its Signal UUID and/or phone number; either may be ``None``
      until it is learned. ``_set_uuid`` and ``_set_number`` fill in the missing
      identifier, delete any other row already holding it, and rewrite references in
      the ``portal``, ``message`` and ``reaction`` tables.
      """

      # Database handle shared by all instances. ``fake_db`` is only a typed
      # placeholder (None at runtime); NOTE(review): presumably assigned at startup.
      db: ClassVar[Database] = fake_db

      uuid: Optional[UUID]
      number: Optional[str]
      name: Optional[str]
      avatar_hash: Optional[str]
      avatar_url: Optional[ContentURI]
      name_set: bool
      avatar_set: bool
      uuid_registered: bool
      number_registered: bool
      custom_mxid: Optional[UserID]
      access_token: Optional[str]
      next_batch: Optional[SyncToken]
      base_url: Optional[URL]

      @property
      def _base_url_str(self) -> Optional[str]:
          """Return ``base_url`` as a string for storage, or ``None`` if it is falsy."""
          return str(self.base_url) if self.base_url else None

      async def insert(self) -> None:
          """Insert this puppet as a new row in the ``puppet`` table."""
          q = ("INSERT INTO puppet (uuid, number, name, avatar_hash, avatar_url, name_set, "
               " avatar_set, uuid_registered, number_registered, "
               " custom_mxid, access_token, next_batch, base_url) "
               "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)")
          await self.db.execute(q, self.uuid, self.number, self.name, self.avatar_hash,
                                self.avatar_url, self.name_set, self.avatar_set,
                                self.uuid_registered, self.number_registered, self.custom_mxid,
                                self.access_token, self.next_batch, self._base_url_str)

      async def _set_uuid(self, uuid: UUID) -> None:
          """Attach ``uuid`` to the row identified by this puppet's phone number.

          In one transaction: delete any other row that already holds ``uuid``, store
          ``uuid`` on the row whose ``number`` equals ``self.number``, then rewrite
          references to the number via ``_update_number_to_uuid``. ``self.uuid`` is
          not modified here.

          Does nothing when ``self.number`` is ``None``, since no row can be matched.
          """
          if self.number is None:
              # ``number=NULL`` never matches, so the UPDATE would be a no-op; return
              # before the NULL-safe DELETE could remove the only row holding ``uuid``.
              return
          async with self.db.acquire() as conn, conn.transaction():
              # IS DISTINCT FROM rather than <>: a UUID-only row has number NULL, and
              # ``NULL <> $2`` is NULL, so <> would leave that duplicate in place.
              await conn.execute("DELETE FROM puppet "
                                 "WHERE uuid=$1 AND number IS DISTINCT FROM $2",
                                 uuid, self.number)
              await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
              await self._update_number_to_uuid(conn, self.number, str(uuid))

      async def _set_number(self, number: str) -> None:
          """Attach ``number`` to the row identified by this puppet's UUID.

          In one transaction: delete any other row that already holds ``number``,
          store ``number`` on the row whose ``uuid`` equals ``self.uuid``, then rewrite
          references to ``number`` so they use the UUID. ``self.number`` is not
          modified here.

          Does nothing when ``self.uuid`` is ``None``: no row can be matched, and the
          reference rewrite would otherwise store the literal string ``"None"``.
          """
          if self.uuid is None:
              return
          async with self.db.acquire() as conn, conn.transaction():
              # IS DISTINCT FROM rather than <>: a number-only row has uuid NULL, and
              # ``NULL <> $2`` is NULL, so <> would leave that duplicate in place.
              await conn.execute("DELETE FROM puppet "
                                 "WHERE number=$1 AND uuid IS DISTINCT FROM $2",
                                 number, self.uuid)
              await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", number, self.uuid)
              await self._update_number_to_uuid(conn, number, str(self.uuid))

      @staticmethod
      async def _update_number_to_uuid(conn: asyncpg.Connection, old_number: str, new_uuid: str
                                       ) -> None:
          """Rewrite references to ``old_number`` so they use ``new_uuid`` instead.

          Updates ``portal.chat_id``, ``message.sender`` and ``reaction.author``. If the
          portal update raises a unique violation, the portal rows whose ``chat_id`` is
          ``old_number`` are deleted instead.
          """
          try:
              # Inside the callers' outer transaction this nested block is a savepoint,
              # so a failed UPDATE rolls back only this statement, not the whole merge.
              async with conn.transaction():
                  await conn.execute("UPDATE portal SET chat_id=$1 WHERE chat_id=$2",
                                     new_uuid, old_number)
          except asyncpg.UniqueViolationError:
              # NOTE(review): this drops every portal row for old_number, including any
              # that would not have conflicted — confirm that is intended.
              await conn.execute("DELETE FROM portal WHERE chat_id=$1", old_number)
          await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2",
                             new_uuid, old_number)
          await conn.execute("UPDATE reaction SET author=$1 WHERE author=$2",
                             new_uuid, old_number)

      async def update(self) -> None:
          """Write this puppet's columns back to its existing row.

          The row is located by ``uuid`` when it is set (``number`` is written too);
          otherwise it is located by ``number`` and ``uuid`` is written as NULL.
          """
          set_columns = (
              "name=$3, avatar_hash=$4, avatar_url=$5, name_set=$6, avatar_set=$7, "
              "uuid_registered=$8, number_registered=$9, "
              "custom_mxid=$10, access_token=$11, next_batch=$12, base_url=$13"
          )
          q = (f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
               if self.uuid is None
               else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1")
          await self.db.execute(q, self.uuid, self.number, self.name, self.avatar_hash,
                                self.avatar_url, self.name_set, self.avatar_set,
                                self.uuid_registered, self.number_registered, self.custom_mxid,
                                self.access_token, self.next_batch, self._base_url_str)

      @classmethod
      def _from_row(cls, row: asyncpg.Record) -> 'Puppet':
          """Build a Puppet from a row, converting the stored ``base_url`` text to a URL."""
          data = {**row}
          base_url_str = data.pop("base_url")
          base_url = URL(base_url_str) if base_url_str is not None else None
          return cls(base_url=base_url, **data)

      # Column list shared by all lookups; its order must match the dataclass fields
      # consumed by ``_from_row`` (which passes the row through as keyword arguments).
      _select_base = ("SELECT uuid, number, name, avatar_hash, avatar_url, name_set, avatar_set, "
                      " uuid_registered, number_registered, custom_mxid, access_token, "
                      " next_batch, base_url "
                      "FROM puppet")

      @classmethod
      async def get_by_address(cls, address: Address) -> Optional['Puppet']:
          """Look up a puppet by the UUID and/or phone number in ``address``.

          When both are present, a row matching either is accepted; the query has no
          ORDER BY, so if two rows match, which one is returned is unspecified.

          Raises:
              ValueError: if ``address`` has neither a UUID nor a number.
          """
          if address.uuid:
              if address.number:
                  row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1 OR number=$2",
                                              address.uuid, address.number)
              else:
                  row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
          elif address.number:
              row = await cls.db.fetchrow(f"{cls._select_base} WHERE number=$1", address.number)
          else:
              raise ValueError("Invalid address")
          if not row:
              return None
          return cls._from_row(row)

      @classmethod
      async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
          """Return the puppet whose ``custom_mxid`` equals ``mxid``, or ``None``."""
          row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
          if not row:
              return None
          return cls._from_row(row)

      @classmethod
      async def all_with_custom_mxid(cls) -> List['Puppet']:
          """Return every puppet that has a ``custom_mxid`` set."""
          rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
          return [cls._from_row(row) for row in rows]