provisioning_api.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING
  18. import asyncio
  19. import json
  20. import logging
  21. from aiohttp import web
  22. from mausignald.errors import (
  23. InternalError,
  24. ScanTimeoutError,
  25. TimeoutException,
  26. UnregisteredUserError,
  27. )
  28. from mausignald.types import Account, Address, Profile
  29. from mautrix.types import UserID
  30. from mautrix.util.logging import TraceLogger
  31. from .. import portal as po, puppet as pu, user as u
  32. from ..util import normalize_number
  33. from .segment_analytics import init as init_segment, track
  34. if TYPE_CHECKING:
  35. from ..__main__ import SignalBridge
  36. class ProvisioningAPI:
  37. log: TraceLogger = logging.getLogger("mau.web.provisioning")
  38. app: web.Application
  39. bridge: "SignalBridge"
  40. def __init__(
  41. self, bridge: "SignalBridge", shared_secret: str, segment_key: str | None
  42. ) -> None:
  43. self.bridge = bridge
  44. self.app = web.Application()
  45. self.shared_secret = shared_secret
  46. if segment_key:
  47. init_segment(segment_key)
  48. # Whoami
  49. self.app.router.add_get("/v1/api/whoami", self.status)
  50. self.app.router.add_get("/v2/whoami", self.status)
  51. # Logout
  52. self.app.router.add_options("/v1/api/logout", self.login_options)
  53. self.app.router.add_post("/v1/api/logout", self.logout)
  54. self.app.router.add_options("/v2/logout", self.login_options)
  55. self.app.router.add_post("/v2/logout", self.logout)
  56. # Link API (will be deprecated soon)
  57. self.app.router.add_options("/v1/api/link", self.login_options)
  58. self.app.router.add_options("/v1/api/link/wait", self.login_options)
  59. self.app.router.add_post("/v1/api/link", self.link)
  60. self.app.router.add_post("/v1/api/link/wait", self.link_wait)
  61. # New Login API
  62. self.app.router.add_options("/v2/link/new", self.login_options)
  63. self.app.router.add_options("/v2/link/wait/scan", self.login_options)
  64. self.app.router.add_options("/v2/link/wait/account", self.login_options)
  65. self.app.router.add_post("/v2/link/new", self.link_new)
  66. self.app.router.add_post("/v2/link/wait/scan", self.link_wait_for_scan)
  67. self.app.router.add_post("/v2/link/wait/account", self.link_wait_for_account)
  68. # Start new chat API
  69. self.app.router.add_get("/v2/contacts", self.list_contacts)
  70. self.app.router.add_post("/v2/pm/{number}", self.start_pm)
  71. @property
  72. def _acao_headers(self) -> dict[str, str]:
  73. return {
  74. "Access-Control-Allow-Origin": "*",
  75. "Access-Control-Allow-Headers": "Authorization, Content-Type",
  76. "Access-Control-Allow-Methods": "POST, OPTIONS",
  77. }
  78. @property
  79. def _headers(self) -> dict[str, str]:
  80. return {
  81. **self._acao_headers,
  82. "Content-Type": "application/json",
  83. }
  84. async def login_options(self, _: web.Request) -> web.Response:
  85. return web.Response(status=200, headers=self._headers)
  86. async def check_token(self, request: web.Request) -> "u.User":
  87. try:
  88. token = request.headers["Authorization"]
  89. token = token[len("Bearer ") :]
  90. except KeyError:
  91. raise web.HTTPBadRequest(
  92. text='{"error": "Missing Authorization header"}', headers=self._headers
  93. )
  94. except IndexError:
  95. raise web.HTTPBadRequest(
  96. text='{"error": "Malformed Authorization header"}', headers=self._headers
  97. )
  98. if token != self.shared_secret:
  99. raise web.HTTPForbidden(text='{"error": "Invalid token"}', headers=self._headers)
  100. try:
  101. user_id = request.query["user_id"]
  102. except KeyError:
  103. raise web.HTTPBadRequest(
  104. text='{"error": "Missing user_id query param"}', headers=self._headers
  105. )
  106. if not self.bridge.signal.is_connected:
  107. await self.bridge.signal.wait_for_connected()
  108. return await u.User.get_by_mxid(UserID(user_id))
  109. async def check_token_and_logged_in(self, request: web.Request) -> "u.User":
  110. user = await self.check_token(request)
  111. if not await user.is_logged_in():
  112. error = {"error": "You're not logged in"}
  113. raise web.HTTPNotFound(text=json.dumps(error), headers=self._headers)
  114. return user
  115. async def status(self, request: web.Request) -> web.Response:
  116. user = await self.check_token(request)
  117. data = {
  118. "permissions": user.permission_level,
  119. "mxid": user.mxid,
  120. "signal": None,
  121. }
  122. if await user.is_logged_in():
  123. try:
  124. profile = await self.bridge.signal.get_profile(
  125. username=user.username, address=Address(number=user.username)
  126. )
  127. except Exception as e:
  128. self.log.exception(f"Failed to get {user.username}'s profile for whoami")
  129. await user.handle_auth_failure(e)
  130. data["signal"] = {
  131. "number": user.username,
  132. "ok": False,
  133. "error": str(e),
  134. }
  135. else:
  136. addr = profile.address if profile else None
  137. number = addr.number if addr else None
  138. uuid = addr.uuid if addr else None
  139. data["signal"] = {
  140. "number": number or user.username,
  141. "uuid": str(uuid or user.uuid or ""),
  142. "name": profile.name if profile else None,
  143. "ok": True,
  144. }
  145. return web.json_response(data, headers=self._acao_headers)
  146. async def _shielded_link(self, user: "u.User", session_id: str, device_name: str) -> Account:
  147. try:
  148. self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
  149. account = await self.bridge.signal.finish_link(
  150. session_id=session_id, device_name=device_name, overwrite=True
  151. )
  152. except TimeoutException:
  153. self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
  154. raise
  155. except Exception:
  156. self.log.exception(
  157. f"Fatal error while waiting for linking to finish (session {session_id})"
  158. )
  159. raise
  160. else:
  161. await user.on_signin(account)
  162. return account
  163. async def _try_shielded_link(
  164. self, user: "u.User", session_id: str, device_name: str
  165. ) -> web.Response:
  166. try:
  167. account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
  168. except asyncio.CancelledError:
  169. self.log.warning(
  170. f"Client cancelled link wait request ({session_id}) before it finished"
  171. )
  172. raise
  173. except TimeoutException:
  174. raise web.HTTPBadRequest(
  175. text='{"error": "Signal linking timed out"}', headers=self._headers
  176. )
  177. except ScanTimeoutError:
  178. raise web.HTTPBadRequest(
  179. text='{"error": "Signald websocket disconnected before linking finished"}',
  180. headers=self._headers,
  181. )
  182. except InternalError:
  183. raise web.HTTPInternalServerError(
  184. text='{"error": "Fatal error in Signal linking"}', headers=self._headers
  185. )
  186. except Exception:
  187. raise web.HTTPInternalServerError(
  188. text='{"error": "Fatal error in Signal linking"}', headers=self._headers
  189. )
  190. else:
  191. return web.json_response(account.address.serialize())
  192. # region Old Link API
  193. async def link(self, request: web.Request) -> web.Response:
  194. user = await self.check_token(request)
  195. if await user.is_logged_in():
  196. raise web.HTTPConflict(
  197. text="""{"error": "You're already logged in"}""", headers=self._headers
  198. )
  199. try:
  200. data = await request.json()
  201. except json.JSONDecodeError:
  202. raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
  203. device_name = data.get("device_name", "Mautrix-Signal bridge")
  204. sess = await self.bridge.signal.start_link()
  205. user.command_status = {
  206. "action": "Link",
  207. "session_id": sess.session_id,
  208. "device_name": device_name,
  209. }
  210. self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
  211. return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
  212. async def link_wait(self, request: web.Request) -> web.Response:
  213. user = await self.check_token(request)
  214. if not user.command_status or user.command_status["action"] != "Link":
  215. raise web.HTTPBadRequest(
  216. text='{"error": "No Signal linking started"}', headers=self._headers
  217. )
  218. session_id = user.command_status["session_id"]
  219. device_name = user.command_status["device_name"]
  220. return await self._try_shielded_link(user, session_id, device_name)
  221. # endregion
  222. # region New Link API
  223. async def _get_request_data(self, request: web.Request) -> tuple[u.User, dict]:
  224. user = await self.check_token(request)
  225. if await user.is_logged_in():
  226. error_text = """{"error": "You're already logged in"}"""
  227. raise web.HTTPConflict(text=error_text, headers=self._headers)
  228. try:
  229. return user, (await request.json())
  230. except json.JSONDecodeError:
  231. raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
  232. async def link_new(self, request: web.Request) -> web.Response:
  233. """
  234. Starts a new link session.
  235. Params: none
  236. Returns a JSON object with the following fields:
  237. * session_id: a session ID that should be used for all future link-related commands
  238. (wait_for_scan and wait_for_account).
  239. * uri: a URI that should be used to display the QR code.
  240. """
  241. user, _ = await self._get_request_data(request)
  242. self.log.debug(f"Getting session ID and link URI for {user.mxid}")
  243. try:
  244. sess = await self.bridge.signal.start_link()
  245. track(user, "$link_new_success")
  246. self.log.debug(
  247. f"Returning session ID and link URI for {user.mxid} / {sess.session_id}"
  248. )
  249. return web.json_response(sess.serialize(), headers=self._acao_headers)
  250. except Exception as e:
  251. error = {"error": f"Getting a new link failed: {e}"}
  252. track(user, "$link_new_failed", error)
  253. raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
  254. async def link_wait_for_scan(self, request: web.Request) -> web.Response:
  255. """
  256. Waits for the QR code associated with the provided session ID to be scanned.
  257. Params: a JSON object with the following field:
  258. * session_id: a session ID that you got from a call to /link/v2/new.
  259. """
  260. user, request_data = await self._get_request_data(request)
  261. try:
  262. session_id = request_data["session_id"]
  263. except KeyError:
  264. error_text = '{"error": "session_id not provided"}'
  265. raise web.HTTPBadRequest(text=error_text, headers=self._headers)
  266. try:
  267. await self.bridge.signal.wait_for_scan(session_id)
  268. track(user, "$qrcode_scanned")
  269. except Exception as e:
  270. error = {"error": f"Failed waiting for scan. Error: {e}"}
  271. self.log.exception(error["error"])
  272. track(user, "$qrcode_scan_failed", error)
  273. raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
  274. else:
  275. return web.json_response({}, headers=self._acao_headers)
  276. async def link_wait_for_account(self, request: web.Request) -> web.Response:
  277. """
  278. Waits for the link to the user's phone to complete.
  279. Params: a JSON object with the following fields:
  280. * session_id: a session ID that you got from a call to /link/v2/new.
  281. * device_name: the device name that will show up in Linked Devices on the user's device.
  282. Returns: a JSON object representing the user's account.
  283. """
  284. user, request_data = await self._get_request_data(request)
  285. try:
  286. session_id = request_data["session_id"]
  287. device_name = request_data.get("device_name", "Mautrix-Signal bridge")
  288. except KeyError:
  289. error = {"error": "session_id not provided"}
  290. track(user, "$wait_for_account_failed", error)
  291. raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
  292. try:
  293. resp = await self._try_shielded_link(user, session_id, device_name)
  294. track(user, "$wait_for_account_success")
  295. return resp
  296. except Exception as e:
  297. error = {"error": f"Failed waiting for account. Error: {e}"}
  298. self.log.exception(error["error"])
  299. track(user, "$wait_for_account_failed", error)
  300. raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
  301. # endregion
  302. # region Logout
  303. async def logout(self, request: web.Request) -> web.Response:
  304. user = await self.check_token_and_logged_in(request)
  305. await user.logout()
  306. return web.json_response({}, headers=self._acao_headers)
  307. # endregion
  308. # region Start new chat API
  309. async def list_contacts(self, request: web.Request) -> web.Response:
  310. user = await self.check_token_and_logged_in(request)
  311. contacts = await self.bridge.signal.list_contacts(user.username)
  312. return web.json_response(
  313. {
  314. str(c.address.number): {
  315. "name": c.name,
  316. "address": Address.serialize(c.address),
  317. }
  318. for c in contacts
  319. if c.address is not None
  320. },
  321. headers=self._acao_headers,
  322. )
  323. async def start_pm(self, request: web.Request) -> web.Response:
  324. user = await self.check_token_and_logged_in(request)
  325. number = normalize_number(request.match_info.get("number"))
  326. puppet: pu.Puppet = await pu.Puppet.get_by_address(Address(number=number))
  327. if not puppet.uuid and user.username:
  328. try:
  329. uuid = await self.bridge.signal.find_uuid(user.username, puppet.number)
  330. if uuid:
  331. await puppet.handle_uuid_receive(uuid)
  332. except UnregisteredUserError:
  333. error = {"error": f"The phone number {number} is not a registered Signal account"}
  334. raise web.HTTPNotFound(text=json.dumps(error), headers=self._headers)
  335. except Exception as e:
  336. raise web.HTTPBadRequest(reason=str(e), headers=self._headers)
  337. if not puppet:
  338. error = {"error": f"No puppet was found for {number}"}
  339. raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
  340. portal = await po.Portal.get_by_chat_id(
  341. puppet.address, receiver=user.username, create=True
  342. )
  343. if not portal:
  344. error = {"error": f"Failed finding a portal for {puppet.address}"}
  345. raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
  346. if portal.mxid:
  347. await portal.main_intent.invite_user(portal.mxid, user.mxid)
  348. error = {
  349. "error": f"You already have a PM with {number}",
  350. "room_id": f"{portal.mxid}",
  351. }
  352. raise web.HTTPConflict(text=json.dumps(error), headers=self._headers)
  353. room_id = await portal.create_matrix_room(user, puppet.address)
  354. return web.json_response({"room_id": room_id}, headers=self._acao_headers)
  355. # endregion