user.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Dict, Optional, AsyncGenerator, TYPE_CHECKING, cast
  17. from uuid import UUID
  18. import asyncio
  19. from mausignald.types import Account, Address
  20. from mautrix.bridge import BaseUser
  21. from mautrix.types import UserID, RoomID
  22. from mautrix.appservice import AppService
  23. from .db import User as DBUser
  24. from .config import Config
  25. from . import puppet as pu, portal as po
  26. if TYPE_CHECKING:
  27. from .__main__ import SignalBridge
  28. class User(DBUser, BaseUser):
  29. by_mxid: Dict[UserID, 'User'] = {}
  30. by_username: Dict[str, 'User'] = {}
  31. config: Config
  32. az: AppService
  33. loop: asyncio.AbstractEventLoop
  34. bridge: 'SignalBridge'
  35. is_admin: bool
  36. permission_level: str
  37. _notice_room_lock: asyncio.Lock
  38. def __init__(self, mxid: UserID, username: Optional[str] = None, uuid: Optional[UUID] = None,
  39. notice_room: Optional[RoomID] = None) -> None:
  40. super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
  41. self._notice_room_lock = asyncio.Lock()
  42. perms = self.config.get_permissions(mxid)
  43. self.is_whitelisted, self.is_admin, self.permission_level = perms
  44. self.log = self.log.getChild(self.mxid)
  45. self.dm_update_lock = asyncio.Lock()
  46. self.command_status = None
  47. @classmethod
  48. def init_cls(cls, bridge: 'SignalBridge') -> None:
  49. cls.bridge = bridge
  50. cls.config = bridge.config
  51. cls.az = bridge.az
  52. cls.loop = bridge.loop
  53. @property
  54. def address(self) -> Optional[Address]:
  55. if not self.username:
  56. return None
  57. return Address(uuid=self.uuid, number=self.username)
  58. async def is_logged_in(self) -> bool:
  59. return bool(self.username)
  60. async def on_signin(self, account: Account) -> None:
  61. self.username = account.username
  62. self.uuid = account.uuid
  63. await self.update()
  64. await self.bridge.signal.subscribe(self.username)
  65. self.loop.create_task(self.sync())
  66. async def _sync_puppet(self) -> None:
  67. puppet = await pu.Puppet.get_by_address(self.address)
  68. if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
  69. self.log.info(f"Automatically enabling custom puppet")
  70. await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
  71. async def sync(self) -> None:
  72. try:
  73. await self._sync_puppet()
  74. except Exception:
  75. self.log.exception("Error while syncing own puppet")
  76. try:
  77. await self._sync()
  78. except Exception:
  79. self.log.exception("Error while syncing")
  80. async def _sync(self) -> None:
  81. create_contact_portal = self.config["bridge.autocreate_contact_portal"]
  82. for contact in await self.bridge.signal.list_contacts(self.username):
  83. self.log.trace("Syncing contact %s", contact)
  84. puppet = await pu.Puppet.get_by_address(contact.address)
  85. if not puppet.name:
  86. profile = await self.bridge.signal.get_profile(self.username, contact.address)
  87. if profile:
  88. self.log.trace("Got profile for %s: %s", contact.address, profile)
  89. else:
  90. # get_profile probably does a request to the servers, so let's not do that unless
  91. # necessary, but maybe we could listen for updates?
  92. profile = None
  93. await puppet.update_info(profile or contact)
  94. if puppet.uuid and create_contact_portal:
  95. portal = await po.Portal.get_by_chat_id(puppet.uuid, self.username, create=True)
  96. await portal.create_matrix_room(self, profile or contact)
  97. create_group_portal = self.config["bridge.autocreate_group_portal"]
  98. for group in await self.bridge.signal.list_groups(self.username):
  99. self.log.trace("Syncing group %s", group)
  100. portal = await po.Portal.get_by_chat_id(group.group_id, create=True)
  101. if create_group_portal:
  102. await portal.create_matrix_room(self, group)
  103. elif portal.mxid:
  104. await portal.update_matrix_room(self, group)
  105. # region Database getters
  106. def _add_to_cache(self) -> None:
  107. self.by_mxid[self.mxid] = self
  108. if self.username:
  109. self.by_username[self.username] = self
  110. @classmethod
  111. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
  112. # Never allow ghosts to be users
  113. if pu.Puppet.get_id_from_mxid(mxid):
  114. return None
  115. try:
  116. return cls.by_mxid[mxid]
  117. except KeyError:
  118. pass
  119. user = cast(cls, await super().get_by_mxid(mxid))
  120. if user is not None:
  121. user._add_to_cache()
  122. return user
  123. if create:
  124. user = cls(mxid)
  125. await user.insert()
  126. user._add_to_cache()
  127. return user
  128. return None
  129. @classmethod
  130. async def get_by_username(cls, username: str) -> Optional['User']:
  131. try:
  132. return cls.by_username[username]
  133. except KeyError:
  134. pass
  135. user = cast(cls, await super().get_by_username(username))
  136. if user is not None:
  137. user._add_to_cache()
  138. return user
  139. return None
  140. @classmethod
  141. async def all_logged_in(cls) -> AsyncGenerator['User', None]:
  142. users = await super().all_logged_in()
  143. user: cls
  144. for user in users:
  145. try:
  146. yield cls.by_mxid[user.mxid]
  147. except KeyError:
  148. user._add_to_cache()
  149. yield user
  150. # endregion