user.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, AsyncGenerator, cast
  18. from asyncio.tasks import sleep
  19. from datetime import datetime
  20. from uuid import UUID
  21. import asyncio
  22. from mautrix.appservice import AppService
  23. from mautrix.bridge import AutologinError, BaseUser, async_getter_lock
  24. from mautrix.types import RoomID, UserID
  25. from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
  26. from mautrix.util.opt_prometheus import Gauge
  27. from mausignald.types import (
  28. Account,
  29. Address,
  30. Group,
  31. GroupV2,
  32. Profile,
  33. WebsocketConnectionState,
  34. WebsocketConnectionStateChangeEvent,
  35. )
  36. from . import portal as po, puppet as pu
  37. from .config import Config
  38. from .db import User as DBUser
  39. if TYPE_CHECKING:
  40. from .__main__ import SignalBridge
  41. METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to Signal")
  42. METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Bridge users logged into Signal")
  43. BridgeState.human_readable_errors.update(
  44. {
  45. "logged-out": "You're not logged into Signal",
  46. "signal-not-connected": None,
  47. }
  48. )
  49. class User(DBUser, BaseUser):
  50. by_mxid: dict[UserID, User] = {}
  51. by_username: dict[str, User] = {}
  52. by_uuid: dict[UUID, User] = {}
  53. config: Config
  54. az: AppService
  55. loop: asyncio.AbstractEventLoop
  56. bridge: "SignalBridge"
  57. relay_whitelisted: bool
  58. is_admin: bool
  59. permission_level: str
  60. _sync_lock: asyncio.Lock
  61. _notice_room_lock: asyncio.Lock
  62. _connected: bool
  63. _websocket_connection_state: BridgeStateEvent | None
  64. _latest_non_transient_disconnect_state: datetime | None
  65. def __init__(
  66. self,
  67. mxid: UserID,
  68. username: str | None = None,
  69. uuid: UUID | None = None,
  70. notice_room: RoomID | None = None,
  71. ) -> None:
  72. super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
  73. BaseUser.__init__(self)
  74. self._notice_room_lock = asyncio.Lock()
  75. self._sync_lock = asyncio.Lock()
  76. self._connected = False
  77. self._websocket_connection_state = None
  78. perms = self.config.get_permissions(mxid)
  79. self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
  80. @classmethod
  81. def init_cls(cls, bridge: "SignalBridge") -> None:
  82. cls.bridge = bridge
  83. cls.config = bridge.config
  84. cls.az = bridge.az
  85. cls.loop = bridge.loop
  86. @property
  87. def address(self) -> Address | None:
  88. if not self.username:
  89. return None
  90. return Address(uuid=self.uuid, number=self.username)
  91. async def is_logged_in(self) -> bool:
  92. return bool(self.username)
  93. async def needs_relay(self, portal: po.Portal) -> bool:
  94. return not await self.is_logged_in() or (
  95. portal.is_direct and portal.receiver != self.username
  96. )
  97. async def logout(self) -> None:
  98. if not self.username:
  99. return
  100. username = self.username
  101. if self.uuid and self.by_uuid.get(self.uuid) == self:
  102. del self.by_uuid[self.uuid]
  103. if self.username and self.by_username.get(self.username) == self:
  104. del self.by_username[self.username]
  105. self.username = None
  106. self.uuid = None
  107. await self.update()
  108. await self.bridge.signal.unsubscribe(username)
  109. # Wait a while for signald to finish disconnecting
  110. await asyncio.sleep(1)
  111. await self.bridge.signal.delete_account(username)
  112. self._track_metric(METRIC_LOGGED_IN, False)
  113. await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT, remote_id=username)
  114. async def fill_bridge_state(self, state: BridgeState) -> None:
  115. await super().fill_bridge_state(state)
  116. if not state.remote_id:
  117. state.remote_id = self.username
  118. if self.address:
  119. puppet = await self.get_puppet()
  120. state.remote_name = puppet.name or self.username
  121. async def get_bridge_states(self) -> list[BridgeState]:
  122. if not self.username:
  123. return []
  124. state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
  125. if self.bridge.signal.is_connected and self._connected:
  126. state.state_event = BridgeStateEvent.CONNECTED
  127. else:
  128. state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
  129. return [state]
  130. async def get_puppet(self) -> pu.Puppet | None:
  131. if not self.address:
  132. return None
  133. return await pu.Puppet.get_by_address(self.address)
  134. async def on_signin(self, account: Account) -> None:
  135. self.username = account.account_id
  136. self.uuid = account.address.uuid
  137. self._add_to_cache()
  138. await self.update()
  139. await self.bridge.signal.subscribe(self.username)
  140. asyncio.create_task(self.sync())
  141. self._track_metric(METRIC_LOGGED_IN, True)
  142. def on_websocket_connection_state_change(
  143. self, evt: WebsocketConnectionStateChangeEvent
  144. ) -> None:
  145. if evt.state == WebsocketConnectionState.CONNECTED:
  146. self.log.info("Connected to Signal")
  147. self._track_metric(METRIC_CONNECTED, True)
  148. self._track_metric(METRIC_LOGGED_IN, True)
  149. self._connected = True
  150. else:
  151. self.log.warning(
  152. f"New websocket state from signald: {evt.state}. Error: {evt.exception}"
  153. )
  154. self._track_metric(METRIC_CONNECTED, False)
  155. self._connected = False
  156. bridge_state = {
  157. # Signald disconnected
  158. WebsocketConnectionState.SOCKET_DISCONNECTED: BridgeStateEvent.TRANSIENT_DISCONNECT,
  159. # Websocket state reported by signald
  160. WebsocketConnectionState.DISCONNECTED: (
  161. None
  162. if self._websocket_connection_state == BridgeStateEvent.BAD_CREDENTIALS
  163. else BridgeStateEvent.TRANSIENT_DISCONNECT
  164. ),
  165. WebsocketConnectionState.CONNECTING: BridgeStateEvent.CONNECTING,
  166. WebsocketConnectionState.CONNECTED: BridgeStateEvent.CONNECTED,
  167. WebsocketConnectionState.RECONNECTING: BridgeStateEvent.TRANSIENT_DISCONNECT,
  168. WebsocketConnectionState.DISCONNECTING: BridgeStateEvent.TRANSIENT_DISCONNECT,
  169. WebsocketConnectionState.AUTHENTICATION_FAILED: BridgeStateEvent.BAD_CREDENTIALS,
  170. WebsocketConnectionState.FAILED: BridgeStateEvent.TRANSIENT_DISCONNECT,
  171. }.get(evt.state)
  172. if bridge_state is None:
  173. self.log.info(f"Websocket state {evt.state} seen. Will not report new Bridge State")
  174. return
  175. now = datetime.now()
  176. if bridge_state == BridgeStateEvent.TRANSIENT_DISCONNECT:
  177. async def wait_report_transient_disconnect():
  178. # Wait for 10 seconds (that should be enough for the bridge to get connected)
  179. # before sending a TRANSIENT_DISCONNECT.
  180. # self._latest_non_transient_disconnect_state will only be None if the bridge is
  181. # still starting.
  182. if self._latest_non_transient_disconnect_state is None:
  183. await sleep(15)
  184. if self._latest_non_transient_disconnect_state is None:
  185. asyncio.create_task(self.push_bridge_state(bridge_state))
  186. # Wait for another minute. If the bridge stays in TRANSIENT_DISCONNECT for that
  187. # long, something terrible has happened (signald failed to restart, the internet
  188. # broke, etc.)
  189. await sleep(60)
  190. if (
  191. self._latest_non_transient_disconnect_state
  192. and now > self._latest_non_transient_disconnect_state
  193. ):
  194. asyncio.create_task(self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR))
  195. else:
  196. self.log.info(
  197. "New state since last TRANSIENT_DISCONNECT push. "
  198. "Not transitioning to UNKNOWN_ERROR."
  199. )
  200. asyncio.create_task(wait_report_transient_disconnect())
  201. else:
  202. asyncio.create_task(self.push_bridge_state(bridge_state))
  203. self._latest_non_transient_disconnect_state = now
  204. self._websocket_connection_state = bridge_state
  205. async def _sync_puppet(self) -> None:
  206. puppet = await pu.Puppet.get_by_address(self.address)
  207. if puppet.uuid and not self.uuid:
  208. self.uuid = puppet.uuid
  209. self.by_uuid[self.uuid] = self
  210. if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
  211. self.log.info("Automatically enabling custom puppet")
  212. try:
  213. await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
  214. except AutologinError as e:
  215. self.log.warning(f"Failed to enable custom puppet: {e}")
  216. async def sync(self) -> None:
  217. await self.sync_puppet()
  218. await self.sync_contacts()
  219. await self.sync_groups()
  220. async def sync_puppet(self) -> None:
  221. try:
  222. async with self._sync_lock:
  223. await self._sync_puppet()
  224. except Exception:
  225. self.log.exception("Error while syncing own puppet")
  226. async def sync_contacts(self) -> None:
  227. try:
  228. async with self._sync_lock:
  229. await self._sync_contacts()
  230. except Exception:
  231. self.log.exception("Error while syncing contacts")
  232. async def sync_groups(self) -> None:
  233. try:
  234. async with self._sync_lock:
  235. await self._sync_groups()
  236. except Exception:
  237. self.log.exception("Error while syncing groups")
  238. async def sync_contact(self, contact: Profile | Address, create_portals: bool = False) -> None:
  239. self.log.trace("Syncing contact %s", contact)
  240. if isinstance(contact, Address):
  241. address = contact
  242. profile = await self.bridge.signal.get_profile(self.username, address, use_cache=True)
  243. if profile and profile.name:
  244. self.log.trace("Got profile for %s: %s", address, profile)
  245. else:
  246. address = contact.address
  247. profile = contact
  248. puppet = await pu.Puppet.get_by_address(address)
  249. await puppet.update_info(profile)
  250. if create_portals:
  251. portal = await po.Portal.get_by_chat_id(
  252. puppet.address, receiver=self.username, create=True
  253. )
  254. await portal.create_matrix_room(self, profile)
  255. async def _sync_group(self, group: Group, create_portals: bool) -> None:
  256. self.log.trace("Syncing group %s", group)
  257. portal = await po.Portal.get_by_chat_id(group.group_id, create=True)
  258. if create_portals:
  259. await portal.create_matrix_room(self, group)
  260. elif portal.mxid:
  261. await portal.update_matrix_room(self, group)
  262. async def _sync_group_v2(self, group: GroupV2, create_portals: bool) -> None:
  263. self.log.trace("Syncing group %s", group.id)
  264. portal = await po.Portal.get_by_chat_id(group.id, create=True)
  265. if create_portals:
  266. await portal.create_matrix_room(self, group)
  267. elif portal.mxid:
  268. await portal.update_matrix_room(self, group)
  269. async def _sync_contacts(self) -> None:
  270. create_contact_portal = self.config["bridge.autocreate_contact_portal"]
  271. for contact in await self.bridge.signal.list_contacts(self.username):
  272. try:
  273. await self.sync_contact(contact, create_contact_portal)
  274. except Exception:
  275. self.log.exception(f"Failed to sync contact {contact.address}")
  276. async def _sync_groups(self) -> None:
  277. create_group_portal = self.config["bridge.autocreate_group_portal"]
  278. for group in await self.bridge.signal.list_groups(self.username):
  279. group_id = group.group_id if isinstance(group, Group) else group.id
  280. try:
  281. if isinstance(group, Group):
  282. await self._sync_group(group, create_group_portal)
  283. elif isinstance(group, GroupV2):
  284. await self._sync_group_v2(group, create_group_portal)
  285. else:
  286. self.log.warning("Unknown return type in list_groups: %s", type(group))
  287. except Exception:
  288. self.log.exception(f"Failed to sync group {group_id}")
  289. # region Database getters
  290. def _add_to_cache(self) -> None:
  291. self.by_mxid[self.mxid] = self
  292. if self.username:
  293. self.by_username[self.username] = self
  294. if self.uuid:
  295. self.by_uuid[self.uuid] = self
  296. @classmethod
  297. @async_getter_lock
  298. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> User | None:
  299. # Never allow ghosts to be users
  300. if pu.Puppet.get_id_from_mxid(mxid):
  301. return None
  302. try:
  303. return cls.by_mxid[mxid]
  304. except KeyError:
  305. pass
  306. user = cast(cls, await super().get_by_mxid(mxid))
  307. if user is not None:
  308. user._add_to_cache()
  309. return user
  310. if create:
  311. user = cls(mxid)
  312. await user.insert()
  313. user._add_to_cache()
  314. return user
  315. return None
  316. @classmethod
  317. @async_getter_lock
  318. async def get_by_username(cls, username: str) -> User | None:
  319. try:
  320. return cls.by_username[username]
  321. except KeyError:
  322. pass
  323. user = cast(cls, await super().get_by_username(username))
  324. if user is not None:
  325. user._add_to_cache()
  326. return user
  327. return None
  328. @classmethod
  329. @async_getter_lock
  330. async def get_by_uuid(cls, uuid: UUID) -> User | None:
  331. try:
  332. return cls.by_uuid[uuid]
  333. except KeyError:
  334. pass
  335. user = cast(cls, await super().get_by_uuid(uuid))
  336. if user is not None:
  337. user._add_to_cache()
  338. return user
  339. return None
  340. @classmethod
  341. async def get_by_address(cls, address: Address) -> User | None:
  342. if address.uuid:
  343. return await cls.get_by_uuid(address.uuid)
  344. elif address.number:
  345. return await cls.get_by_username(address.number)
  346. else:
  347. raise ValueError("Given address is blank")
  348. @classmethod
  349. async def all_logged_in(cls) -> AsyncGenerator[User, None]:
  350. users = await super().all_logged_in()
  351. user: cls
  352. for user in users:
  353. try:
  354. yield cls.by_mxid[user.mxid]
  355. except KeyError:
  356. user._add_to_cache()
  357. yield user
  358. # endregion