provisioning_api.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING
  18. import asyncio
  19. import json
  20. import logging
  21. from aiohttp import web
  22. from mausignald.errors import (
  23. InternalError,
  24. ScanTimeoutError,
  25. TimeoutException,
  26. UnregisteredUserError,
  27. )
  28. from mausignald.types import Account, Address, Profile
  29. from mautrix.types import JSON, UserID
  30. from mautrix.util.logging import TraceLogger
  31. from .. import portal as po, puppet as pu, user as u
  32. from ..util import normalize_number
  33. from .segment_analytics import init as init_segment, track
  34. if TYPE_CHECKING:
  35. from ..__main__ import SignalBridge
  36. class ProvisioningAPI:
  37. log: TraceLogger = logging.getLogger("mau.web.provisioning")
  38. app: web.Application
  39. bridge: "SignalBridge"
  40. def __init__(
  41. self,
  42. bridge: "SignalBridge",
  43. shared_secret: str,
  44. segment_key: str | None,
  45. segment_user_id: str | None,
  46. ) -> None:
  47. self.bridge = bridge
  48. self.app = web.Application()
  49. self.shared_secret = shared_secret
  50. if segment_key:
  51. init_segment(segment_key, segment_user_id)
  52. # Whoami
  53. self.app.router.add_get("/v1/api/whoami", self.status)
  54. self.app.router.add_get("/v2/whoami", self.status)
  55. # Logout
  56. self.app.router.add_options("/v1/api/logout", self.login_options)
  57. self.app.router.add_post("/v1/api/logout", self.logout)
  58. self.app.router.add_options("/v2/logout", self.login_options)
  59. self.app.router.add_post("/v2/logout", self.logout)
  60. # Link API (will be deprecated soon)
  61. self.app.router.add_options("/v1/api/link", self.login_options)
  62. self.app.router.add_options("/v1/api/link/wait", self.login_options)
  63. self.app.router.add_post("/v1/api/link", self.link)
  64. self.app.router.add_post("/v1/api/link/wait", self.link_wait)
  65. # New Login API
  66. self.app.router.add_options("/v2/link/new", self.login_options)
  67. self.app.router.add_options("/v2/link/wait/scan", self.login_options)
  68. self.app.router.add_options("/v2/link/wait/account", self.login_options)
  69. self.app.router.add_post("/v2/link/new", self.link_new)
  70. self.app.router.add_post("/v2/link/wait/scan", self.link_wait_for_scan)
  71. self.app.router.add_post("/v2/link/wait/account", self.link_wait_for_account)
  72. # Start new chat API
  73. self.app.router.add_get("/v2/contacts", self.list_contacts)
  74. self.app.router.add_get("/v2/sync", self.request_sync)
  75. self.app.router.add_get("/v2/resolve_identifier/{number}", self.resolve_identifier)
  76. self.app.router.add_post("/v2/pm/{number}", self.start_pm)
  77. @property
  78. def _acao_headers(self) -> dict[str, str]:
  79. return {
  80. "Access-Control-Allow-Origin": "*",
  81. "Access-Control-Allow-Headers": "Authorization, Content-Type",
  82. "Access-Control-Allow-Methods": "POST, OPTIONS",
  83. }
  84. @property
  85. def _headers(self) -> dict[str, str]:
  86. return {
  87. **self._acao_headers,
  88. "Content-Type": "application/json",
  89. }
  90. async def login_options(self, _: web.Request) -> web.Response:
  91. return web.Response(status=200, headers=self._headers)
  92. async def check_token(self, request: web.Request) -> "u.User":
  93. try:
  94. token = request.headers["Authorization"]
  95. token = token[len("Bearer ") :]
  96. except KeyError:
  97. raise web.HTTPBadRequest(
  98. text='{"error": "Missing Authorization header"}', headers=self._headers
  99. )
  100. except IndexError:
  101. raise web.HTTPBadRequest(
  102. text='{"error": "Malformed Authorization header"}', headers=self._headers
  103. )
  104. if token != self.shared_secret:
  105. raise web.HTTPForbidden(text='{"error": "Invalid token"}', headers=self._headers)
  106. try:
  107. user_id = request.query["user_id"]
  108. except KeyError:
  109. raise web.HTTPBadRequest(
  110. text='{"error": "Missing user_id query param"}', headers=self._headers
  111. )
  112. try:
  113. if not self.bridge.signal.is_connected:
  114. await self.bridge.signal.wait_for_connected(timeout=10)
  115. except asyncio.TimeoutError:
  116. raise web.HTTPServiceUnavailable(
  117. text=json.dumps({"error": "Cannot connect to signald"}), headers=self._headers
  118. )
  119. return await u.User.get_by_mxid(UserID(user_id))
  120. async def check_token_and_logged_in(self, request: web.Request) -> "u.User":
  121. user = await self.check_token(request)
  122. if not await user.is_logged_in():
  123. error = {"error": "You're not logged in"}
  124. raise web.HTTPNotFound(text=json.dumps(error), headers=self._headers)
  125. return user
  126. async def status(self, request: web.Request) -> web.Response:
  127. user = await self.check_token(request)
  128. data = {
  129. "permissions": user.permission_level,
  130. "mxid": user.mxid,
  131. "signal": None,
  132. }
  133. if await user.is_logged_in():
  134. try:
  135. profile = await self.bridge.signal.get_profile(
  136. username=user.username, address=user.address
  137. )
  138. except Exception as e:
  139. self.log.exception(f"Failed to get {user.username}'s profile for whoami")
  140. await user.handle_auth_failure(e)
  141. data["signal"] = {
  142. "number": user.username,
  143. "ok": False,
  144. "error": str(e),
  145. }
  146. else:
  147. addr = profile.address if profile else None
  148. number = addr.number if addr else None
  149. uuid = addr.uuid if addr else None
  150. data["signal"] = {
  151. "number": number or user.username,
  152. "uuid": str(uuid or user.uuid or ""),
  153. "name": profile.name if profile else None,
  154. "ok": True,
  155. }
  156. return web.json_response(data, headers=self._acao_headers)
  157. async def _shielded_link(self, user: "u.User", session_id: str, device_name: str) -> Account:
  158. try:
  159. self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
  160. account = await self.bridge.signal.finish_link(
  161. session_id=session_id, device_name=device_name, overwrite=True
  162. )
  163. except TimeoutException:
  164. self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
  165. raise
  166. except Exception:
  167. self.log.exception(
  168. f"Fatal error while waiting for linking to finish (session {session_id})"
  169. )
  170. raise
  171. else:
  172. await user.on_signin(account)
  173. return account
  174. async def _try_shielded_link(
  175. self, user: "u.User", session_id: str, device_name: str
  176. ) -> web.Response:
  177. try:
  178. account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
  179. except asyncio.CancelledError:
  180. self.log.warning(
  181. f"Client cancelled link wait request ({session_id}) before it finished"
  182. )
  183. raise
  184. except (TimeoutException, ScanTimeoutError):
  185. raise web.HTTPBadRequest(
  186. text='{"error": "Signal linking timed out"}', headers=self._headers
  187. )
  188. except InternalError:
  189. raise web.HTTPInternalServerError(
  190. text='{"error": "Fatal error in Signal linking"}', headers=self._headers
  191. )
  192. except Exception:
  193. raise web.HTTPInternalServerError(
  194. text='{"error": "Fatal error in Signal linking"}', headers=self._headers
  195. )
  196. else:
  197. return web.json_response(account.address.serialize())
  198. # region Old Link API
  199. async def link(self, request: web.Request) -> web.Response:
  200. user = await self.check_token(request)
  201. if await user.is_logged_in():
  202. raise web.HTTPConflict(
  203. text="""{"error": "You're already logged in"}""", headers=self._headers
  204. )
  205. try:
  206. data = await request.json()
  207. except json.JSONDecodeError:
  208. raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
  209. device_name = data.get("device_name", "Mautrix-Signal bridge")
  210. sess = await self.bridge.signal.start_link()
  211. user.command_status = {
  212. "action": "Link",
  213. "session_id": sess.session_id,
  214. "device_name": device_name,
  215. }
  216. self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
  217. return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
  218. async def link_wait(self, request: web.Request) -> web.Response:
  219. user = await self.check_token(request)
  220. if not user.command_status or user.command_status["action"] != "Link":
  221. raise web.HTTPBadRequest(
  222. text='{"error": "No Signal linking started"}', headers=self._headers
  223. )
  224. session_id = user.command_status["session_id"]
  225. device_name = user.command_status["device_name"]
  226. return await self._try_shielded_link(user, session_id, device_name)
  227. # endregion
  228. # region New Link API
  229. async def _get_request_data(self, request: web.Request) -> tuple[u.User, dict]:
  230. user = await self.check_token(request)
  231. if await user.is_logged_in():
  232. error_text = """{"error": "You're already logged in"}"""
  233. raise web.HTTPConflict(text=error_text, headers=self._headers)
  234. try:
  235. return user, (await request.json())
  236. except json.JSONDecodeError:
  237. raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
  238. async def link_new(self, request: web.Request) -> web.Response:
  239. """
  240. Starts a new link session.
  241. Params: none
  242. Returns a JSON object with the following fields:
  243. * session_id: a session ID that should be used for all future link-related commands
  244. (wait_for_scan and wait_for_account).
  245. * uri: a URI that should be used to display the QR code.
  246. """
  247. user, _ = await self._get_request_data(request)
  248. self.log.debug(f"Getting session ID and link URI for {user.mxid}")
  249. try:
  250. sess = await self.bridge.signal.start_link()
  251. track(user, "$link_new_success")
  252. self.log.debug(
  253. f"Returning session ID and link URI for {user.mxid} / {sess.session_id}"
  254. )
  255. return web.json_response(sess.serialize(), headers=self._acao_headers)
  256. except Exception as e:
  257. error = {"error": f"Getting a new link failed: {e}"}
  258. track(user, "$link_new_failed", error)
  259. raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
  260. async def link_wait_for_scan(self, request: web.Request) -> web.Response:
  261. """
  262. Waits for the QR code associated with the provided session ID to be scanned.
  263. Params: a JSON object with the following field:
  264. * session_id: a session ID that you got from a call to /link/v2/new.
  265. """
  266. user, request_data = await self._get_request_data(request)
  267. try:
  268. session_id = request_data["session_id"]
  269. except KeyError:
  270. error_text = '{"error": "session_id not provided"}'
  271. raise web.HTTPBadRequest(text=error_text, headers=self._headers)
  272. try:
  273. await self.bridge.signal.wait_for_scan(session_id)
  274. track(user, "$qrcode_scanned")
  275. except Exception as e:
  276. error = {"error": f"Failed waiting for scan. Error: {e}"}
  277. self.log.exception(error["error"])
  278. track(user, "$qrcode_scan_failed", error)
  279. raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
  280. else:
  281. return web.json_response({}, headers=self._acao_headers)
  282. async def link_wait_for_account(self, request: web.Request) -> web.Response:
  283. """
  284. Waits for the link to the user's phone to complete.
  285. Params: a JSON object with the following fields:
  286. * session_id: a session ID that you got from a call to /link/v2/new.
  287. * device_name: the device name that will show up in Linked Devices on the user's device.
  288. Returns: a JSON object representing the user's account.
  289. """
  290. user, request_data = await self._get_request_data(request)
  291. try:
  292. session_id = request_data["session_id"]
  293. device_name = request_data.get("device_name", "Mautrix-Signal bridge")
  294. except KeyError:
  295. error = {"error": "session_id not provided"}
  296. track(user, "$wait_for_account_failed", error)
  297. raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
  298. try:
  299. resp = await self._try_shielded_link(user, session_id, device_name)
  300. track(user, "$wait_for_account_success")
  301. return resp
  302. except Exception as e:
  303. error = {"error": f"Failed waiting for account. Error: {e}"}
  304. self.log.exception(error["error"])
  305. track(user, "$wait_for_account_failed", error)
  306. raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
  307. # endregion
  308. # region Logout
  309. async def logout(self, request: web.Request) -> web.Response:
  310. try:
  311. user = await self.check_token_and_logged_in(request)
  312. await user.logout()
  313. return web.json_response({}, headers=self._acao_headers)
  314. except web.HTTPNotFound:
  315. return web.json_response({"error": "You're not logged in"}, headers=self._acao_headers)
  316. # endregion
  317. # region Start new chat API
  318. async def list_contacts(self, request: web.Request) -> web.Response:
  319. user = await self.check_token_and_logged_in(request)
  320. contacts = await self.bridge.signal.list_contacts(user.username, use_cache=True)
  321. async def transform(profile: Profile) -> JSON:
  322. assert profile.address
  323. puppet = await pu.Puppet.get_by_address(profile.address, create=False)
  324. avatar_url = puppet.avatar_url if puppet else None
  325. return {
  326. "name": profile.name,
  327. "contact_name": profile.contact_name,
  328. "profile_name": profile.profile_name,
  329. "avatar_url": avatar_url,
  330. "address": profile.address.serialize(),
  331. }
  332. return web.json_response(
  333. {
  334. c.address.number: await transform(c)
  335. for c in contacts
  336. if c.address and c.address.number
  337. },
  338. headers=self._acao_headers,
  339. )
  340. async def request_sync(self, request: web.Request) -> web.Response:
  341. user = await self.check_token_and_logged_in(request)
  342. await self.bridge.signal.request_sync(user.username)
  343. return web.json_response({}, headers=self._acao_headers)
  344. async def _resolve_identifier(self, number: str, user: u.User) -> pu.Puppet:
  345. try:
  346. number = normalize_number(number)
  347. except Exception as e:
  348. raise web.HTTPBadRequest(text=json.dumps({"error": str(e)}), headers=self._headers)
  349. try:
  350. puppet: pu.Puppet = await pu.Puppet.get_by_number(
  351. number, resolve_via=user.username, raise_resolve=True
  352. )
  353. except UnregisteredUserError:
  354. error = {"error": f"The phone number {number} is not a registered Signal account"}
  355. raise web.HTTPNotFound(text=json.dumps(error), headers=self._headers)
  356. except Exception:
  357. self.log.exception(f"Unknown error fetching UUID for {number}")
  358. error = {"error": "Unknown error while fetching UUID"}
  359. raise web.HTTPInternalServerError(text=json.dumps(error), headers=self._headers)
  360. if not puppet:
  361. error = {
  362. "error": (
  363. f"The phone number {number} doesn't seem to be a registered Signal account"
  364. )
  365. }
  366. raise web.HTTPNotFound(text=json.dumps(error), headers=self._headers)
  367. return puppet
  368. async def start_pm(self, request: web.Request) -> web.Response:
  369. user = await self.check_token_and_logged_in(request)
  370. puppet = await self._resolve_identifier(request.match_info["number"], user)
  371. portal = await po.Portal.get_by_chat_id(puppet.uuid, receiver=user.username, create=True)
  372. assert portal, "Portal.get_by_chat_id with create=True can't return None"
  373. if portal.mxid:
  374. await portal.main_intent.invite_user(portal.mxid, user.mxid)
  375. just_created = False
  376. else:
  377. await portal.create_matrix_room(user, puppet.address)
  378. just_created = True
  379. return web.json_response(
  380. {
  381. "room_id": portal.mxid,
  382. "just_created": just_created,
  383. "chat_id": puppet.address.serialize(),
  384. "other_user": {
  385. "mxid": puppet.mxid,
  386. "displayname": puppet.name,
  387. "avatar_url": puppet.avatar_url,
  388. },
  389. },
  390. headers=self._acao_headers,
  391. status=201 if just_created else 200,
  392. )
  393. async def resolve_identifier(self, request: web.Request) -> web.Response:
  394. user = await self.check_token_and_logged_in(request)
  395. puppet = await self._resolve_identifier(request.match_info["number"], user)
  396. portal = await po.Portal.get_by_chat_id(puppet.uuid, receiver=user.username, create=False)
  397. return web.json_response(
  398. {
  399. "room_id": portal.mxid if portal else None,
  400. "chat_id": puppet.address.serialize(),
  401. "other_user": {
  402. "mxid": puppet.mxid,
  403. "displayname": puppet.name,
  404. "avatar_url": puppet.avatar_url,
  405. },
  406. },
  407. headers=self._acao_headers,
  408. )
  409. # endregion