user.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Union, Dict, Optional, AsyncGenerator, TYPE_CHECKING, cast
  17. from collections import defaultdict
  18. from uuid import UUID
  19. import asyncio
  20. import os.path
  21. import shutil
  22. from mausignald.types import Account, Address, Contact, Group, GroupV2, ListenEvent, ListenAction
  23. from mautrix.bridge import BaseUser, async_getter_lock
  24. from mautrix.types import UserID, RoomID
  25. from mautrix.appservice import AppService
  26. from mautrix.util.opt_prometheus import Gauge
  27. from .db import User as DBUser
  28. from .config import Config
  29. from . import puppet as pu, portal as po
  30. if TYPE_CHECKING:
  31. from .__main__ import SignalBridge
  # Gauge (from mautrix's optional Prometheus wrapper) that User.on_listen sets
  # through _track_metric when a user's Signal connection starts or stops.
  METRIC_CONNECTED = Gauge(
      'bridge_connected',
      'Bridge users connected to Signal',
  )
  class User(DBUser, BaseUser):
      """A Matrix user of the bridge, optionally logged in to a Signal account.

      Instances are kept in three class-level caches so repeated lookups by Matrix ID,
      by ``username`` (used as the ``number`` of the user's Signal address) or by UUID
      return the same object.
      """

      by_mxid: Dict[UserID, 'User'] = {}
      by_username: Dict[str, 'User'] = {}
      by_uuid: Dict[UUID, 'User'] = {}
      # Bridge-wide objects, attached to the class by init_cls().
      config: Config
      az: AppService
      loop: asyncio.AbstractEventLoop
      bridge: 'SignalBridge'

      is_admin: bool
      permission_level: str

      _notice_room_lock: asyncio.Lock
      # Strong reference to the background sync task started by on_signin().
      _sync_task: Optional[asyncio.Task]

      def __init__(self, mxid: UserID, username: Optional[str] = None, uuid: Optional[UUID] = None,
                   notice_room: Optional[RoomID] = None) -> None:
          super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
          self._notice_room_lock = asyncio.Lock()
          perms = self.config.get_permissions(mxid)
          self.is_whitelisted, self.is_admin, self.permission_level = perms
          self.log = self.log.getChild(self.mxid)
          self.dm_update_lock = asyncio.Lock()
          self.command_status = None
          self._metric_value = defaultdict(lambda: False)
          self._sync_task = None

      @classmethod
      def init_cls(cls, bridge: 'SignalBridge') -> None:
          """Attach bridge-wide objects to the class.

          Must run before any instance is created, because __init__ reads ``self.config``.
          """
          cls.bridge = bridge
          cls.config = bridge.config
          cls.az = bridge.az
          cls.loop = bridge.loop

      @property
      def address(self) -> Optional[Address]:
          """The user's Signal address, or None when no account is logged in."""
          if not self.username:
              return None
          return Address(uuid=self.uuid, number=self.username)

      async def is_logged_in(self) -> bool:
          """Return True if a Signal account (username) is attached to this user."""
          return bool(self.username)

      async def logout(self) -> None:
          """Detach the Signal account and delete its local signald data.

          Does nothing when not logged in. Otherwise the method:

          - evicts this instance from the username/UUID caches;
          - clears and persists the login fields;
          - unsubscribes from signald;
          - removes ``<signal.data_dir>/<username>`` and the ``<username>.d/``
            directory next to it.

          File removal is best-effort: failures are logged, not raised.
          """
          if not self.username:
              return
          username = self.username
          # Only evict cache entries that point at this very instance.
          if self.uuid and self.by_uuid.get(self.uuid) is self:
              del self.by_uuid[self.uuid]
          if self.by_username.get(username) is self:
              del self.by_username[username]
          self.username = None
          self.uuid = None
          await self.update()
          await self.bridge.signal.unsubscribe(username)
          # Wait a while for signald to finish disconnecting
          await asyncio.sleep(1)
          path = os.path.join(self.config["signal.data_dir"], username)
          extra_dir = f"{path}.d/"
          self.log.debug("Removing %s", path)
          try:
              os.remove(path)
          except OSError as e:
              # The account is already detached at this point. A file that can't be
              # removed (missing, permission denied, ...) must not abort the cleanup.
              self.log.warning(f"Failed to remove signald data file: {e}")
          self.log.debug("Removing %s", extra_dir)
          shutil.rmtree(extra_dir, ignore_errors=True)

      async def on_signin(self, account: Account) -> None:
          """Store a freshly signed-in account, subscribe to it and start a background sync."""
          self.username = account.username
          self.uuid = account.uuid
          self._add_to_cache()
          await self.update()
          await self.bridge.signal.subscribe(self.username)
          # Keep a reference: the event loop only holds weak references to tasks.
          self._sync_task = self.loop.create_task(self.sync())

      def on_listen(self, evt: ListenEvent) -> None:
          """Log a signald listen event and update the connection metric."""
          if evt.action == ListenAction.STARTED:
              self.log.info("Connected to Signal")
              self._track_metric(METRIC_CONNECTED, True)
          elif evt.action == ListenAction.STOPPED:
              if evt.exception:
                  self.log.warning(f"Disconnected from Signal: {evt.exception}")
              else:
                  self.log.info("Disconnected from Signal")
              self._track_metric(METRIC_CONNECTED, False)
          else:
              self.log.warning(f"Unrecognized listen action {evt.action}")

      async def _sync_puppet(self) -> None:
          """Sync the puppet that represents this user's own Signal account.

          Adopts (and persists) the puppet's UUID when this user doesn't have one yet.
          Switches the puppet to this Matrix account when ``can_auto_login`` allows it.
          """
          puppet = await pu.Puppet.get_by_address(self.address)
          if puppet.uuid and not self.uuid:
              self.uuid = puppet.uuid
              self.by_uuid[self.uuid] = self
              # Save the newly learned UUID, not just the in-memory cache entry.
              await self.update()
          if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
              self.log.info("Automatically enabling custom puppet")
              await puppet.switch_mxid(access_token="auto", mxid=self.mxid)

      async def sync(self) -> None:
          """Sync own puppet, contacts and groups.

          A failure in one step is logged and does not stop the later steps.
          """
          try:
              await self._sync_puppet()
          except Exception:
              self.log.exception("Error while syncing own puppet")
          try:
              await self._sync_contacts()
          except Exception:
              self.log.exception("Error while syncing contacts")
          try:
              await self._sync_groups()
          except Exception:
              self.log.exception("Error while syncing groups")

      async def sync_contact(self, contact: Union[Contact, Address], create_portals: bool = False
                             ) -> None:
          """Update a contact's puppet and optionally create the private-chat portal.

          The contact's signald profile is used when it has a name; otherwise the
          contact object itself is passed on.
          """
          self.log.trace("Syncing contact %s", contact)
          address = contact.address if isinstance(contact, Contact) else contact
          puppet = await pu.Puppet.get_by_address(address)
          profile = await self.bridge.signal.get_profile(self.username, address)
          if profile and profile.name:
              self.log.trace("Got profile for %s: %s", address, profile)
          else:
              profile = None
          await puppet.update_info(profile or contact)
          if create_portals:
              portal = await po.Portal.get_by_chat_id(puppet.address, receiver=self.username,
                                                      create=True)
              await portal.create_matrix_room(self, profile or contact)

      async def _sync_group(self, group: Group, create_portals: bool) -> None:
          """Sync a ``Group``.

          Creates the Matrix room if ``create_portals`` is set; otherwise updates the
          room only if the portal already has one.
          """
          self.log.trace("Syncing group %s", group)
          portal = await po.Portal.get_by_chat_id(group.group_id, create=True)
          if create_portals:
              await portal.create_matrix_room(self, group)
          elif portal.mxid:
              await portal.update_matrix_room(self, group)

      async def _sync_group_v2(self, group: GroupV2, create_portals: bool) -> None:
          """Sync a ``GroupV2``; same room handling as ``_sync_group``."""
          self.log.trace("Syncing group %s", group.id)
          portal = await po.Portal.get_by_chat_id(group.id, create=True)
          if create_portals:
              await portal.create_matrix_room(self, group)
          elif portal.mxid:
              await portal.update_matrix_room(self, group)

      async def _sync_contacts(self) -> None:
          """Sync every contact signald lists for this account, one failure at a time."""
          create_contact_portal = self.config["bridge.autocreate_contact_portal"]
          for contact in await self.bridge.signal.list_contacts(self.username):
              try:
                  await self.sync_contact(contact, create_contact_portal)
              except Exception:
                  self.log.exception(f"Failed to sync contact {contact.address}")

      async def _sync_groups(self) -> None:
          """Sync every group signald lists for this account.

          Each group is handled independently. A failure, or an item of unexpected
          type, is logged and the loop continues with the next group.
          """
          create_group_portal = self.config["bridge.autocreate_group_portal"]
          for group in await self.bridge.signal.list_groups(self.username):
              if isinstance(group, Group):
                  group_id = group.group_id
                  sync_group = self._sync_group
              elif isinstance(group, GroupV2):
                  group_id = group.id
                  sync_group = self._sync_group_v2
              else:
                  # Checked before touching any ID attribute. An unknown type is reported
                  # here instead of raising AttributeError and aborting the whole loop.
                  self.log.warning("Unknown return type in list_groups: %s", type(group))
                  continue
              try:
                  await sync_group(group, create_group_portal)
              except Exception:
                  self.log.exception(f"Failed to sync group {group_id}")

      # region Database getters

      def _add_to_cache(self) -> None:
          """Register this instance in each class-level cache it has a key for."""
          self.by_mxid[self.mxid] = self
          if self.username:
              self.by_username[self.username] = self
          if self.uuid:
              self.by_uuid[self.uuid] = self

      @classmethod
      @async_getter_lock
      async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
          """Return the user for a Matrix ID, looked up in order:

          1. the cache;
          2. the database;
          3. a newly inserted row, if ``create`` is set.

          Returns None for puppet (ghost) Matrix IDs, or when the user is not found
          and ``create`` is false.
          """
          # Never allow ghosts to be users
          if pu.Puppet.get_id_from_mxid(mxid):
              return None
          try:
              return cls.by_mxid[mxid]
          except KeyError:
              pass

          user = cast(cls, await super().get_by_mxid(mxid))
          if user is not None:
              user._add_to_cache()
              return user

          if create:
              user = cls(mxid)
              await user.insert()
              user._add_to_cache()
              return user

          return None

      @classmethod
      @async_getter_lock
      async def get_by_username(cls, username: str) -> Optional['User']:
          """Return the user with this username (cache first, then database), or None."""
          try:
              return cls.by_username[username]
          except KeyError:
              pass

          user = cast(cls, await super().get_by_username(username))
          if user is not None:
              user._add_to_cache()
              return user

          return None

      @classmethod
      @async_getter_lock
      async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
          """Return the user with this Signal UUID (cache first, then database), or None."""
          try:
              return cls.by_uuid[uuid]
          except KeyError:
              pass

          user = cast(cls, await super().get_by_uuid(uuid))
          if user is not None:
              user._add_to_cache()
              return user

          return None

      @classmethod
      async def get_by_address(cls, address: Address) -> Optional['User']:
          """Look a user up by the address's UUID if set, otherwise by its number.

          Raises:
              ValueError: if the address has neither a UUID nor a number.
          """
          if address.uuid:
              return await cls.get_by_uuid(address.uuid)
          elif address.number:
              return await cls.get_by_username(address.number)
          else:
              raise ValueError("Given address is blank")

      @classmethod
      async def all_logged_in(cls) -> AsyncGenerator['User', None]:
          """Yield the users the database layer reports as logged in.

          An already-cached instance is yielded in place of the freshly loaded one,
          so callers always get the canonical object.
          """
          users = await super().all_logged_in()
          user: cls
          for user in users:
              # Plain lookup keeps the yield outside any try/except.
              cached = cls.by_mxid.get(user.mxid)
              if cached is not None:
                  yield cached
              else:
                  user._add_to_cache()
                  yield user

      # endregion