user.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Union, cast
  17. from asyncio.tasks import sleep
  18. from datetime import datetime
  19. from uuid import UUID
  20. import asyncio
  21. from mautrix.appservice import AppService
  22. from mautrix.bridge import AutologinError, BaseUser, async_getter_lock
  23. from mautrix.types import RoomID, UserID
  24. from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
  25. from mautrix.util.opt_prometheus import Gauge
  26. from mausignald.types import (
  27. Account,
  28. Address,
  29. Group,
  30. GroupV2,
  31. Profile,
  32. WebsocketConnectionState,
  33. WebsocketConnectionStateChangeEvent,
  34. )
  35. from . import portal as po
  36. from . import puppet as pu
  37. from .config import Config
  38. from .db import User as DBUser
  39. if TYPE_CHECKING:
  40. from .__main__ import SignalBridge
  41. METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to Signal")
  42. METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Bridge users logged into Signal")
  43. BridgeState.human_readable_errors.update(
  44. {
  45. "logged-out": "You're not logged into Signal",
  46. "signal-not-connected": None,
  47. }
  48. )
  49. class User(DBUser, BaseUser):
  50. by_mxid: Dict[UserID, "User"] = {}
  51. by_username: Dict[str, "User"] = {}
  52. by_uuid: Dict[UUID, "User"] = {}
  53. config: Config
  54. az: AppService
  55. loop: asyncio.AbstractEventLoop
  56. bridge: "SignalBridge"
  57. relay_whitelisted: bool
  58. is_admin: bool
  59. permission_level: str
  60. _sync_lock: asyncio.Lock
  61. _notice_room_lock: asyncio.Lock
  62. _connected: bool
  63. _websocket_connection_state: Optional[WebsocketConnectionState]
  64. _latest_non_transient_disconnect_state: Optional[datetime]
  65. def __init__(
  66. self,
  67. mxid: UserID,
  68. username: Optional[str] = None,
  69. uuid: Optional[UUID] = None,
  70. notice_room: Optional[RoomID] = None,
  71. ) -> None:
  72. super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
  73. BaseUser.__init__(self)
  74. self._notice_room_lock = asyncio.Lock()
  75. self._sync_lock = asyncio.Lock()
  76. self._connected = False
  77. self._websocket_connection_state = None
  78. perms = self.config.get_permissions(mxid)
  79. self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
  80. @classmethod
  81. def init_cls(cls, bridge: "SignalBridge") -> None:
  82. cls.bridge = bridge
  83. cls.config = bridge.config
  84. cls.az = bridge.az
  85. cls.loop = bridge.loop
  86. @property
  87. def address(self) -> Optional[Address]:
  88. if not self.username:
  89. return None
  90. return Address(uuid=self.uuid, number=self.username)
  91. async def is_logged_in(self) -> bool:
  92. return bool(self.username)
  93. async def logout(self) -> None:
  94. if not self.username:
  95. return
  96. username = self.username
  97. if self.uuid and self.by_uuid.get(self.uuid) == self:
  98. del self.by_uuid[self.uuid]
  99. if self.username and self.by_username.get(self.username) == self:
  100. del self.by_username[self.username]
  101. self.username = None
  102. self.uuid = None
  103. await self.update()
  104. await self.bridge.signal.unsubscribe(username)
  105. # Wait a while for signald to finish disconnecting
  106. await asyncio.sleep(1)
  107. await self.bridge.signal.delete_account(username)
  108. self._track_metric(METRIC_LOGGED_IN, False)
  109. await self.push_bridge_state(BridgeStateEvent.LOGGED_OUT, remote_id=username)
  110. async def fill_bridge_state(self, state: BridgeState) -> None:
  111. await super().fill_bridge_state(state)
  112. if not state.remote_id:
  113. state.remote_id = self.username
  114. if self.address:
  115. puppet = await self.get_puppet()
  116. state.remote_name = puppet.name or self.username
  117. async def get_bridge_states(self) -> List[BridgeState]:
  118. if not self.username:
  119. return []
  120. state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
  121. if self.bridge.signal.is_connected and self._connected:
  122. state.state_event = BridgeStateEvent.CONNECTED
  123. else:
  124. state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
  125. return [state]
  126. async def get_puppet(self) -> Optional["pu.Puppet"]:
  127. if not self.address:
  128. return None
  129. return await pu.Puppet.get_by_address(self.address)
  130. async def on_signin(self, account: Account) -> None:
  131. self.username = account.account_id
  132. self.uuid = account.address.uuid
  133. self._add_to_cache()
  134. await self.update()
  135. await self.bridge.signal.subscribe(self.username)
  136. asyncio.create_task(self.sync())
  137. self._track_metric(METRIC_LOGGED_IN, True)
  138. def on_websocket_connection_state_change(
  139. self, evt: WebsocketConnectionStateChangeEvent
  140. ) -> None:
  141. if evt.state == WebsocketConnectionState.CONNECTED:
  142. self.log.info("Connected to Signal")
  143. self._track_metric(METRIC_CONNECTED, True)
  144. self._track_metric(METRIC_LOGGED_IN, True)
  145. self._connected = True
  146. else:
  147. self.log.warning(
  148. f"New websocket state from signald: {evt.state}. Error: {evt.exception}"
  149. )
  150. self._track_metric(METRIC_CONNECTED, False)
  151. self._connected = False
  152. bridge_state = {
  153. # Signald disconnected
  154. WebsocketConnectionState.SOCKET_DISCONNECTED: BridgeStateEvent.TRANSIENT_DISCONNECT,
  155. # Websocket state reported by signald
  156. WebsocketConnectionState.DISCONNECTED: (
  157. None
  158. if self._websocket_connection_state == BridgeStateEvent.BAD_CREDENTIALS
  159. else BridgeStateEvent.TRANSIENT_DISCONNECT
  160. ),
  161. WebsocketConnectionState.CONNECTING: BridgeStateEvent.CONNECTING,
  162. WebsocketConnectionState.CONNECTED: BridgeStateEvent.CONNECTED,
  163. WebsocketConnectionState.RECONNECTING: BridgeStateEvent.TRANSIENT_DISCONNECT,
  164. WebsocketConnectionState.DISCONNECTING: BridgeStateEvent.TRANSIENT_DISCONNECT,
  165. WebsocketConnectionState.AUTHENTICATION_FAILED: BridgeStateEvent.BAD_CREDENTIALS,
  166. WebsocketConnectionState.FAILED: BridgeStateEvent.TRANSIENT_DISCONNECT,
  167. }.get(evt.state)
  168. if bridge_state is None:
  169. self.log.info(f"Websocket state {evt.state} seen. Will not report new Bridge State")
  170. return
  171. now = datetime.now()
  172. if bridge_state == BridgeStateEvent.TRANSIENT_DISCONNECT:
  173. async def wait_report_transient_disconnect():
  174. # Wait for 10 seconds (that should be enough for the bridge to get connected)
  175. # before sending a TRANSIENT_DISCONNECT.
  176. # self._latest_non_transient_disconnect_state will only be None if the bridge is
  177. # still starting.
  178. if self._latest_non_transient_disconnect_state is None:
  179. await sleep(15)
  180. if self._latest_non_transient_disconnect_state is None:
  181. asyncio.create_task(self.push_bridge_state(bridge_state))
  182. # Wait for another minute. If the bridge stays in TRANSIENT_DISCONNECT for that
  183. # long, something terrible has happened (signald failed to restart, the internet
  184. # broke, etc.)
  185. await sleep(60)
  186. if (
  187. self._latest_non_transient_disconnect_state
  188. and now > self._latest_non_transient_disconnect_state
  189. ):
  190. asyncio.create_task(self.push_bridge_state(BridgeStateEvent.UNKNOWN_ERROR))
  191. else:
  192. self.log.info(
  193. "New state since last TRANSIENT_DISCONNECT push. "
  194. "Not transitioning to UNKNOWN_ERROR."
  195. )
  196. asyncio.create_task(wait_report_transient_disconnect())
  197. else:
  198. asyncio.create_task(self.push_bridge_state(bridge_state))
  199. self._latest_non_transient_disconnect_state = now
  200. self._websocket_connection_state = bridge_state
  201. async def _sync_puppet(self) -> None:
  202. puppet = await pu.Puppet.get_by_address(self.address)
  203. if puppet.uuid and not self.uuid:
  204. self.uuid = puppet.uuid
  205. self.by_uuid[self.uuid] = self
  206. if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
  207. self.log.info("Automatically enabling custom puppet")
  208. try:
  209. await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
  210. except AutologinError as e:
  211. self.log.warning(f"Failed to enable custom puppet: {e}")
  212. async def sync(self) -> None:
  213. await self.sync_puppet()
  214. await self.sync_contacts()
  215. await self.sync_groups()
  216. async def sync_puppet(self) -> None:
  217. try:
  218. async with self._sync_lock:
  219. await self._sync_puppet()
  220. except Exception:
  221. self.log.exception("Error while syncing own puppet")
  222. async def sync_contacts(self) -> None:
  223. try:
  224. async with self._sync_lock:
  225. await self._sync_contacts()
  226. except Exception:
  227. self.log.exception("Error while syncing contacts")
  228. async def sync_groups(self) -> None:
  229. try:
  230. async with self._sync_lock:
  231. await self._sync_groups()
  232. except Exception:
  233. self.log.exception("Error while syncing groups")
  234. async def sync_contact(
  235. self, contact: Union[Profile, Address], create_portals: bool = False
  236. ) -> None:
  237. self.log.trace("Syncing contact %s", contact)
  238. if isinstance(contact, Address):
  239. address = contact
  240. profile = await self.bridge.signal.get_profile(self.username, address, use_cache=True)
  241. if profile and profile.name:
  242. self.log.trace("Got profile for %s: %s", address, profile)
  243. else:
  244. address = contact.address
  245. profile = contact
  246. puppet = await pu.Puppet.get_by_address(address)
  247. await puppet.update_info(profile)
  248. if create_portals:
  249. portal = await po.Portal.get_by_chat_id(
  250. puppet.address, receiver=self.username, create=True
  251. )
  252. await portal.create_matrix_room(self, profile)
  253. async def _sync_group(self, group: Group, create_portals: bool) -> None:
  254. self.log.trace("Syncing group %s", group)
  255. portal = await po.Portal.get_by_chat_id(group.group_id, create=True)
  256. if create_portals:
  257. await portal.create_matrix_room(self, group)
  258. elif portal.mxid:
  259. await portal.update_matrix_room(self, group)
  260. async def _sync_group_v2(self, group: GroupV2, create_portals: bool) -> None:
  261. self.log.trace("Syncing group %s", group.id)
  262. portal = await po.Portal.get_by_chat_id(group.id, create=True)
  263. if create_portals:
  264. await portal.create_matrix_room(self, group)
  265. elif portal.mxid:
  266. await portal.update_matrix_room(self, group)
  267. async def _sync_contacts(self) -> None:
  268. create_contact_portal = self.config["bridge.autocreate_contact_portal"]
  269. for contact in await self.bridge.signal.list_contacts(self.username):
  270. try:
  271. await self.sync_contact(contact, create_contact_portal)
  272. except Exception:
  273. self.log.exception(f"Failed to sync contact {contact.address}")
  274. async def _sync_groups(self) -> None:
  275. create_group_portal = self.config["bridge.autocreate_group_portal"]
  276. for group in await self.bridge.signal.list_groups(self.username):
  277. group_id = group.group_id if isinstance(group, Group) else group.id
  278. try:
  279. if isinstance(group, Group):
  280. await self._sync_group(group, create_group_portal)
  281. elif isinstance(group, GroupV2):
  282. await self._sync_group_v2(group, create_group_portal)
  283. else:
  284. self.log.warning("Unknown return type in list_groups: %s", type(group))
  285. except Exception:
  286. self.log.exception(f"Failed to sync group {group_id}")
  287. # region Database getters
  288. def _add_to_cache(self) -> None:
  289. self.by_mxid[self.mxid] = self
  290. if self.username:
  291. self.by_username[self.username] = self
  292. if self.uuid:
  293. self.by_uuid[self.uuid] = self
  294. @classmethod
  295. @async_getter_lock
  296. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional["User"]:
  297. # Never allow ghosts to be users
  298. if pu.Puppet.get_id_from_mxid(mxid):
  299. return None
  300. try:
  301. return cls.by_mxid[mxid]
  302. except KeyError:
  303. pass
  304. user = cast(cls, await super().get_by_mxid(mxid))
  305. if user is not None:
  306. user._add_to_cache()
  307. return user
  308. if create:
  309. user = cls(mxid)
  310. await user.insert()
  311. user._add_to_cache()
  312. return user
  313. return None
  314. @classmethod
  315. @async_getter_lock
  316. async def get_by_username(cls, username: str) -> Optional["User"]:
  317. try:
  318. return cls.by_username[username]
  319. except KeyError:
  320. pass
  321. user = cast(cls, await super().get_by_username(username))
  322. if user is not None:
  323. user._add_to_cache()
  324. return user
  325. return None
  326. @classmethod
  327. @async_getter_lock
  328. async def get_by_uuid(cls, uuid: UUID) -> Optional["User"]:
  329. try:
  330. return cls.by_uuid[uuid]
  331. except KeyError:
  332. pass
  333. user = cast(cls, await super().get_by_uuid(uuid))
  334. if user is not None:
  335. user._add_to_cache()
  336. return user
  337. return None
  338. @classmethod
  339. async def get_by_address(cls, address: Address) -> Optional["User"]:
  340. if address.uuid:
  341. return await cls.get_by_uuid(address.uuid)
  342. elif address.number:
  343. return await cls.get_by_username(address.number)
  344. else:
  345. raise ValueError("Given address is blank")
  346. @classmethod
  347. async def all_logged_in(cls) -> AsyncGenerator["User", None]:
  348. users = await super().all_logged_in()
  349. user: cls
  350. for user in users:
  351. try:
  352. yield cls.by_mxid[user.mxid]
  353. except KeyError:
  354. user._add_to_cache()
  355. yield user
  356. # endregion