puppet.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, ClassVar, List, TYPE_CHECKING
  17. from uuid import UUID
  18. from attr import dataclass
  19. from yarl import URL
  20. import asyncpg
  21. from mausignald.types import Address
  22. from mautrix.types import UserID, SyncToken
  23. from mautrix.util.async_db import Database
  24. fake_db = Database("") if TYPE_CHECKING else None
  25. @dataclass
  26. class Puppet:
  27. db: ClassVar[Database] = fake_db
  28. uuid: Optional[UUID]
  29. number: Optional[str]
  30. name: Optional[str]
  31. uuid_registered: bool
  32. number_registered: bool
  33. custom_mxid: Optional[UserID]
  34. access_token: Optional[str]
  35. next_batch: Optional[SyncToken]
  36. base_url: Optional[URL]
  37. async def insert(self) -> None:
  38. q = ("INSERT INTO puppet (uuid, number, name, uuid_registered, number_registered, "
  39. " custom_mxid, access_token, next_batch, base_url) "
  40. "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)")
  41. await self.db.execute(q, self.uuid, self.number, self.name, self.uuid_registered,
  42. self.number_registered, self.custom_mxid, self.access_token,
  43. self.next_batch, str(self.base_url) if self.base_url else None)
  44. async def _set_uuid(self, uuid: UUID) -> None:
  45. if self.uuid:
  46. raise ValueError("Can't re-set UUID for puppet")
  47. self.uuid = uuid
  48. await self.db.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
  49. async def update(self) -> None:
  50. if self.uuid is None:
  51. q = ("UPDATE puppet SET uuid=$1, name=$3, uuid_registered=$4, number_registered=$5, "
  52. " custom_mxid=$6, access_token=$7, next_batch=$8, base_url=$9 "
  53. "WHERE number=$2")
  54. else:
  55. q = ("UPDATE puppet SET number=$2, name=$3, uuid_registered=$4, number_registered=$5, "
  56. " custom_mxid=$6, access_token=$7, next_batch=$8, base_url=$9 "
  57. "WHERE uuid=$1")
  58. await self.db.execute(q, self.uuid, self.number, self.name, self.uuid_registered,
  59. self.number_registered, self.custom_mxid, self.access_token,
  60. self.next_batch, str(self.base_url) if self.base_url else None)
  61. @classmethod
  62. def _from_row(cls, row: asyncpg.Record) -> 'Puppet':
  63. data = {**row}
  64. base_url_str = data.pop("base_url")
  65. base_url = URL(base_url_str) if base_url_str is not None else None
  66. return cls(base_url=base_url, **data)
  67. @classmethod
  68. async def get_by_address(cls, address: Address) -> Optional['Puppet']:
  69. select = ("SELECT uuid, number, name, uuid_registered, "
  70. " number_registered, custom_mxid, access_token, next_batch, base_url "
  71. "FROM puppet")
  72. if address.uuid:
  73. if address.number:
  74. row = await cls.db.fetchrow(f"{select} WHERE uuid=$1 OR number=$2",
  75. address.uuid, address.number)
  76. else:
  77. row = await cls.db.fetchrow(f"{select} WHERE uuid=$1", address.uuid)
  78. elif address.number:
  79. row = await cls.db.fetchrow(f"{select} WHERE number=$1", address.number)
  80. else:
  81. raise ValueError("Invalid address")
  82. if not row:
  83. return None
  84. return cls._from_row(row)
  85. @classmethod
  86. async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
  87. q = ("SELECT uuid, number, name, uuid_registered, number_registered,"
  88. " custom_mxid, access_token, next_batch, base_url "
  89. "FROM puppet WHERE custom_mxid=$1")
  90. row = await cls.db.fetchrow(q, mxid)
  91. if not row:
  92. return None
  93. return cls._from_row(row)
  94. @classmethod
  95. async def all_with_custom_mxid(cls) -> List['Puppet']:
  96. q = ("SELECT uuid, number, name, uuid_registered, number_registered,"
  97. " custom_mxid, access_token, next_batch, base_url "
  98. "FROM puppet WHERE custom_mxid IS NOT NULL")
  99. rows = await cls.db.fetch(q)
  100. return [cls._from_row(row) for row in rows]