user.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Union, Dict, Optional, AsyncGenerator, TYPE_CHECKING, cast
  17. from collections import defaultdict
  18. from uuid import UUID
  19. import asyncio
  20. from mausignald.types import Account, Address, Profile, Group, GroupV2, ListenEvent, ListenAction
  21. from mautrix.bridge import BaseUser, async_getter_lock
  22. from mautrix.types import UserID, RoomID
  23. from mautrix.appservice import AppService
  24. from mautrix.util.opt_prometheus import Gauge
  25. from .db import User as DBUser
  26. from .config import Config
  27. from . import puppet as pu, portal as po
  28. if TYPE_CHECKING:
  29. from .__main__ import SignalBridge
  30. METRIC_CONNECTED = Gauge('bridge_connected', 'Bridge users connected to Signal')
  31. METRIC_LOGGED_IN = Gauge('bridge_logged_in', 'Bridge users logged into Signal')
  32. class User(DBUser, BaseUser):
  33. by_mxid: Dict[UserID, 'User'] = {}
  34. by_username: Dict[str, 'User'] = {}
  35. by_uuid: Dict[UUID, 'User'] = {}
  36. config: Config
  37. az: AppService
  38. loop: asyncio.AbstractEventLoop
  39. bridge: 'SignalBridge'
  40. is_admin: bool
  41. permission_level: str
  42. _sync_lock: asyncio.Lock
  43. _notice_room_lock: asyncio.Lock
  44. def __init__(self, mxid: UserID, username: Optional[str] = None, uuid: Optional[UUID] = None,
  45. notice_room: Optional[RoomID] = None) -> None:
  46. super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
  47. self._notice_room_lock = asyncio.Lock()
  48. self._sync_lock = asyncio.Lock()
  49. perms = self.config.get_permissions(mxid)
  50. self.is_whitelisted, self.is_admin, self.permission_level = perms
  51. self.log = self.log.getChild(self.mxid)
  52. self.dm_update_lock = asyncio.Lock()
  53. self.command_status = None
  54. self._metric_value = defaultdict(lambda: False)
  55. @classmethod
  56. def init_cls(cls, bridge: 'SignalBridge') -> None:
  57. cls.bridge = bridge
  58. cls.config = bridge.config
  59. cls.az = bridge.az
  60. cls.loop = bridge.loop
  61. @property
  62. def address(self) -> Optional[Address]:
  63. if not self.username:
  64. return None
  65. return Address(uuid=self.uuid, number=self.username)
  66. async def is_logged_in(self) -> bool:
  67. return bool(self.username)
  68. async def logout(self) -> None:
  69. if not self.username:
  70. return
  71. username = self.username
  72. if self.uuid and self.by_uuid.get(self.uuid) == self:
  73. del self.by_uuid[self.uuid]
  74. if self.username and self.by_username.get(self.username) == self:
  75. del self.by_username[self.username]
  76. self.username = None
  77. self.uuid = None
  78. await self.update()
  79. await self.bridge.signal.unsubscribe(username)
  80. # Wait a while for signald to finish disconnecting
  81. await asyncio.sleep(1)
  82. self.bridge.signal.delete_data(username)
  83. self._track_metric(METRIC_LOGGED_IN, False)
  84. async def on_signin(self, account: Account) -> None:
  85. self.username = account.account_id
  86. self.uuid = account.address.uuid
  87. self._add_to_cache()
  88. await self.update()
  89. await self.bridge.signal.subscribe(self.username)
  90. asyncio.create_task(self.sync())
  91. self._track_metric(METRIC_LOGGED_IN, True)
  92. def on_listen(self, evt: ListenEvent) -> None:
  93. if evt.action == ListenAction.STARTED:
  94. self.log.info("Connected to Signal")
  95. self._track_metric(METRIC_CONNECTED, True)
  96. self._track_metric(METRIC_LOGGED_IN, True)
  97. elif evt.action == ListenAction.STOPPED:
  98. if evt.exception:
  99. self.log.warning(f"Disconnected from Signal: {evt.exception}")
  100. else:
  101. self.log.info("Disconnected from Signal")
  102. self._track_metric(METRIC_CONNECTED, False)
  103. else:
  104. self.log.warning(f"Unrecognized listen action {evt.action}")
  105. async def _sync_puppet(self) -> None:
  106. puppet = await pu.Puppet.get_by_address(self.address)
  107. if puppet.uuid and not self.uuid:
  108. self.uuid = puppet.uuid
  109. self.by_uuid[self.uuid] = self
  110. if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
  111. self.log.info(f"Automatically enabling custom puppet")
  112. await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
  113. async def sync(self) -> None:
  114. await self.sync_puppet()
  115. await self.sync_contacts()
  116. await self.sync_groups()
  117. async def sync_puppet(self) -> None:
  118. try:
  119. async with self._sync_lock:
  120. await self._sync_puppet()
  121. except Exception:
  122. self.log.exception("Error while syncing own puppet")
  123. async def sync_contacts(self) -> None:
  124. try:
  125. async with self._sync_lock:
  126. await self._sync_contacts()
  127. except Exception:
  128. self.log.exception("Error while syncing contacts")
  129. async def sync_groups(self) -> None:
  130. try:
  131. async with self._sync_lock:
  132. await self._sync_groups()
  133. except Exception:
  134. self.log.exception("Error while syncing groups")
  135. async def sync_contact(self, contact: Union[Profile, Address], create_portals: bool = False
  136. ) -> None:
  137. self.log.trace("Syncing contact %s", contact)
  138. if isinstance(contact, Address):
  139. address = contact
  140. profile = await self.bridge.signal.get_profile(self.username, address)
  141. if profile and profile.name:
  142. self.log.trace("Got profile for %s: %s", address, profile)
  143. else:
  144. address = contact.address
  145. profile = contact
  146. puppet = await pu.Puppet.get_by_address(address)
  147. await puppet.update_info(profile)
  148. if create_portals:
  149. portal = await po.Portal.get_by_chat_id(puppet.address, receiver=self.username,
  150. create=True)
  151. await portal.create_matrix_room(self, profile)
  152. async def _sync_group(self, group: Group, create_portals: bool) -> None:
  153. self.log.trace("Syncing group %s", group)
  154. portal = await po.Portal.get_by_chat_id(group.group_id, create=True)
  155. if create_portals:
  156. await portal.create_matrix_room(self, group)
  157. elif portal.mxid:
  158. await portal.update_matrix_room(self, group)
  159. async def _sync_group_v2(self, group: GroupV2, create_portals: bool) -> None:
  160. self.log.trace("Syncing group %s", group.id)
  161. portal = await po.Portal.get_by_chat_id(group.id, create=True)
  162. if create_portals:
  163. await portal.create_matrix_room(self, group)
  164. elif portal.mxid:
  165. await portal.update_matrix_room(self, group)
  166. async def _sync_contacts(self) -> None:
  167. create_contact_portal = self.config["bridge.autocreate_contact_portal"]
  168. for contact in await self.bridge.signal.list_contacts(self.username):
  169. try:
  170. await self.sync_contact(contact, create_contact_portal)
  171. except Exception:
  172. self.log.exception(f"Failed to sync contact {contact.address}")
  173. async def _sync_groups(self) -> None:
  174. create_group_portal = self.config["bridge.autocreate_group_portal"]
  175. for group in await self.bridge.signal.list_groups(self.username):
  176. group_id = group.group_id if isinstance(group, Group) else group.id
  177. try:
  178. if isinstance(group, Group):
  179. await self._sync_group(group, create_group_portal)
  180. elif isinstance(group, GroupV2):
  181. await self._sync_group_v2(group, create_group_portal)
  182. else:
  183. self.log.warning("Unknown return type in list_groups: %s", type(group))
  184. except Exception:
  185. self.log.exception(f"Failed to sync group {group_id}")
  186. # region Database getters
  187. def _add_to_cache(self) -> None:
  188. self.by_mxid[self.mxid] = self
  189. if self.username:
  190. self.by_username[self.username] = self
  191. if self.uuid:
  192. self.by_uuid[self.uuid] = self
  193. @classmethod
  194. @async_getter_lock
  195. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
  196. # Never allow ghosts to be users
  197. if pu.Puppet.get_id_from_mxid(mxid):
  198. return None
  199. try:
  200. return cls.by_mxid[mxid]
  201. except KeyError:
  202. pass
  203. user = cast(cls, await super().get_by_mxid(mxid))
  204. if user is not None:
  205. user._add_to_cache()
  206. return user
  207. if create:
  208. user = cls(mxid)
  209. await user.insert()
  210. user._add_to_cache()
  211. return user
  212. return None
  213. @classmethod
  214. @async_getter_lock
  215. async def get_by_username(cls, username: str) -> Optional['User']:
  216. try:
  217. return cls.by_username[username]
  218. except KeyError:
  219. pass
  220. user = cast(cls, await super().get_by_username(username))
  221. if user is not None:
  222. user._add_to_cache()
  223. return user
  224. return None
  225. @classmethod
  226. @async_getter_lock
  227. async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
  228. try:
  229. return cls.by_uuid[uuid]
  230. except KeyError:
  231. pass
  232. user = cast(cls, await super().get_by_uuid(uuid))
  233. if user is not None:
  234. user._add_to_cache()
  235. return user
  236. return None
  237. @classmethod
  238. async def get_by_address(cls, address: Address) -> Optional['User']:
  239. if address.uuid:
  240. return await cls.get_by_uuid(address.uuid)
  241. elif address.number:
  242. return await cls.get_by_username(address.number)
  243. else:
  244. raise ValueError("Given address is blank")
  245. @classmethod
  246. async def all_logged_in(cls) -> AsyncGenerator['User', None]:
  247. users = await super().all_logged_in()
  248. user: cls
  249. for user in users:
  250. try:
  251. yield cls.by_mxid[user.mxid]
  252. except KeyError:
  253. user._add_to_cache()
  254. yield user
  255. # endregion