provisioning_api.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING
  18. import asyncio
  19. import json
  20. import logging
  21. from aiohttp import web
  22. from attr import asdict
  23. from mausignald.errors import InternalError, TimeoutException
  24. from mausignald.types import Account, Address
  25. from mautrix.types import UserID
  26. from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
  27. from mautrix.util.logging import TraceLogger
  28. from .. import user as u
  29. if TYPE_CHECKING:
  30. from ..__main__ import SignalBridge
  31. class ProvisioningAPI:
  32. log: TraceLogger = logging.getLogger("mau.web.provisioning")
  33. app: web.Application
  34. bridge: "SignalBridge"
  35. def __init__(self, bridge: "SignalBridge", shared_secret: str) -> None:
  36. self.bridge = bridge
  37. self.app = web.Application()
  38. self.shared_secret = shared_secret
  39. # Whoami
  40. self.app.router.add_get("/v1/api/whoami", self.status)
  41. self.app.router.add_get("/v2/whoami", self.status)
  42. # Logout
  43. self.app.router.add_options("/v1/api/logout", self.login_options)
  44. self.app.router.add_post("/v1/api/logout", self.logout)
  45. self.app.router.add_options("/v2/logout", self.login_options)
  46. self.app.router.add_post("/v2/logout", self.logout)
  47. # Link API (will be deprecated soon)
  48. self.app.router.add_options("/v1/api/link", self.login_options)
  49. self.app.router.add_options("/v1/api/link/wait", self.login_options)
  50. self.app.router.add_post("/v1/api/link", self.link)
  51. self.app.router.add_post("/v1/api/link/wait", self.link_wait)
  52. # New Login API
  53. self.app.router.add_options("/v2/link/new", self.login_options)
  54. self.app.router.add_options("/v2/link/wait/scan", self.login_options)
  55. self.app.router.add_options("/v2/link/wait/account", self.login_options)
  56. self.app.router.add_post("/v2/link/new", self.link_new)
  57. self.app.router.add_post("/v2/link/wait/scan", self.link_wait_for_scan)
  58. self.app.router.add_post("/v2/link/wait/account", self.link_wait_for_account)
  59. @property
  60. def _acao_headers(self) -> dict[str, str]:
  61. return {
  62. "Access-Control-Allow-Origin": "*",
  63. "Access-Control-Allow-Headers": "Authorization, Content-Type",
  64. "Access-Control-Allow-Methods": "POST, OPTIONS",
  65. }
  66. @property
  67. def _headers(self) -> dict[str, str]:
  68. return {
  69. **self._acao_headers,
  70. "Content-Type": "application/json",
  71. }
  72. async def login_options(self, _: web.Request) -> web.Response:
  73. return web.Response(status=200, headers=self._headers)
  74. async def check_token(self, request: web.Request) -> "u.User":
  75. try:
  76. token = request.headers["Authorization"]
  77. token = token[len("Bearer ") :]
  78. except KeyError:
  79. raise web.HTTPBadRequest(
  80. text='{"error": "Missing Authorization header"}', headers=self._headers
  81. )
  82. except IndexError:
  83. raise web.HTTPBadRequest(
  84. text='{"error": "Malformed Authorization header"}', headers=self._headers
  85. )
  86. if token != self.shared_secret:
  87. raise web.HTTPForbidden(text='{"error": "Invalid token"}', headers=self._headers)
  88. try:
  89. user_id = request.query["user_id"]
  90. except KeyError:
  91. raise web.HTTPBadRequest(
  92. text='{"error": "Missing user_id query param"}', headers=self._headers
  93. )
  94. if not self.bridge.signal.is_connected:
  95. await self.bridge.signal.wait_for_connected()
  96. return await u.User.get_by_mxid(UserID(user_id))
  97. async def status(self, request: web.Request) -> web.Response:
  98. user = await self.check_token(request)
  99. data = {
  100. "permissions": user.permission_level,
  101. "mxid": user.mxid,
  102. "signal": None,
  103. }
  104. if await user.is_logged_in():
  105. try:
  106. profile = await self.bridge.signal.get_profile(
  107. username=user.username, address=Address(number=user.username)
  108. )
  109. except Exception as e:
  110. self.log.exception(f"Failed to get {user.username}'s profile for whoami")
  111. auth_failed = "org.whispersystems.signalservice.api.push.exceptions.AuthorizationFailedException"
  112. if isinstance(e, InternalError) and auth_failed in e.data.get("exceptions", []):
  113. await user.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error=str(e))
  114. data["signal"] = {
  115. "number": user.username,
  116. "ok": False,
  117. "error": str(e),
  118. }
  119. else:
  120. addr = profile.address if profile else None
  121. number = addr.number if addr else None
  122. uuid = addr.uuid if addr else None
  123. data["signal"] = {
  124. "number": number or user.username,
  125. "uuid": str(uuid or user.uuid or ""),
  126. "name": profile.name if profile else None,
  127. "ok": True,
  128. }
  129. return web.json_response(data, headers=self._acao_headers)
  130. async def _shielded_link(self, user: "u.User", session_id: str, device_name: str) -> Account:
  131. try:
  132. self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
  133. account = await self.bridge.signal.finish_link(
  134. session_id=session_id, device_name=device_name, overwrite=True
  135. )
  136. except TimeoutException:
  137. self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
  138. raise
  139. except Exception:
  140. self.log.exception(
  141. f"Fatal error while waiting for linking to finish (session {session_id})"
  142. )
  143. raise
  144. else:
  145. await user.on_signin(account)
  146. return account
  147. async def _try_shielded_link(
  148. self, user: "u.User", session_id: str, device_name: str
  149. ) -> web.Response:
  150. try:
  151. account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
  152. except asyncio.CancelledError:
  153. error_text = f"Client cancelled link wait request ({session_id}) before it finished"
  154. self.log.warning(error_text)
  155. raise web.HTTPInternalServerError(
  156. text=f'{{"error": "{error_text}"}}', headers=self._headers
  157. )
  158. except TimeoutException:
  159. raise web.HTTPBadRequest(
  160. text='{"error": "Signal linking timed out"}', headers=self._headers
  161. )
  162. except InternalError as ie:
  163. if "java.io.IOException" in ie.exceptions:
  164. raise web.HTTPBadRequest(
  165. text='{"error": "Signald websocket disconnected before linking finished"}',
  166. headers=self._headers,
  167. )
  168. raise web.HTTPInternalServerError(
  169. text='{"error": "Fatal error in Signal linking"}', headers=self._headers
  170. )
  171. except Exception:
  172. raise web.HTTPInternalServerError(
  173. text='{"error": "Fatal error in Signal linking"}', headers=self._headers
  174. )
  175. else:
  176. return web.json_response(account.address.serialize())
  177. # region Old Link API
  178. async def link(self, request: web.Request) -> web.Response:
  179. user = await self.check_token(request)
  180. if await user.is_logged_in():
  181. raise web.HTTPConflict(
  182. text="""{"error": "You're already logged in"}""", headers=self._headers
  183. )
  184. try:
  185. data = await request.json()
  186. except json.JSONDecodeError:
  187. raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
  188. device_name = data.get("device_name", "Mautrix-Signal bridge")
  189. sess = await self.bridge.signal.start_link()
  190. user.command_status = {
  191. "action": "Link",
  192. "session_id": sess.session_id,
  193. "device_name": device_name,
  194. }
  195. self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
  196. return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
  197. async def link_wait(self, request: web.Request) -> web.Response:
  198. user = await self.check_token(request)
  199. if not user.command_status or user.command_status["action"] != "Link":
  200. raise web.HTTPBadRequest(
  201. text='{"error": "No Signal linking started"}', headers=self._headers
  202. )
  203. session_id = user.command_status["session_id"]
  204. device_name = user.command_status["device_name"]
  205. return await self._try_shielded_link(user, session_id, device_name)
  206. # endregion
  207. # region New Link API
  208. async def _get_request_data(self, request: web.Request) -> tuple[u.User, web.Response]:
  209. user = await self.check_token(request)
  210. if await user.is_logged_in():
  211. error_text = """{"error": "You're already logged in"}"""
  212. raise web.HTTPConflict(text=error_text, headers=self._headers)
  213. try:
  214. return user, (await request.json())
  215. except json.JSONDecodeError:
  216. raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
  217. async def link_new(self, request: web.Request) -> web.Response:
  218. """
  219. Starts a new link session.
  220. Params: none
  221. Returns a JSON object with the following fields:
  222. * session_id: a session ID that should be used for all future link-related commands
  223. (wait_for_scan and wait_for_account).
  224. * uri: a URI that should be used to display the QR code.
  225. """
  226. user, _ = await self._get_request_data(request)
  227. self.log.debug(f"Getting session ID and link URI for {user.mxid}")
  228. sess = await self.bridge.signal.start_link()
  229. self.log.debug(f"Returning session ID and link URI for {user.mxid} / {sess.session_id}")
  230. return web.json_response(asdict(sess), headers=self._acao_headers)
  231. async def link_wait_for_scan(self, request: web.Request) -> web.Response:
  232. """
  233. Waits for the QR code associated with the provided session ID to be scanned.
  234. Params: a JSON object with the following field:
  235. * session_id: a session ID that you got from a call to /link/v2/new.
  236. """
  237. _, request_data = await self._get_request_data(request)
  238. try:
  239. session_id = request_data["session_id"]
  240. except KeyError:
  241. error_text = '{"error": "session_id not provided"}'
  242. raise web.HTTPBadRequest(text=error_text, headers=self._headers)
  243. try:
  244. await self.bridge.signal.wait_for_scan(session_id)
  245. except Exception as e:
  246. error_text = f"Failed waiting for scan. Error: {e}"
  247. self.log.exception(error_text)
  248. self.log.info(e.__class__)
  249. raise web.HTTPBadRequest(text=error_text, headers=self._headers)
  250. else:
  251. return web.json_response({}, headers=self._acao_headers)
  252. async def link_wait_for_account(self, request: web.Request) -> web.Response:
  253. """
  254. Waits for the link to the user's phone to complete.
  255. Params: a JSON object with the following fields:
  256. * session_id: a session ID that you got from a call to /link/v2/new.
  257. * device_name: the device name that will show up in Linked Devices on the user's device.
  258. Returns: a JSON object representing the user's account.
  259. """
  260. user, request_data = await self._get_request_data(request)
  261. try:
  262. session_id = request_data["session_id"]
  263. device_name = request_data.get("device_name", "Mautrix-Signal bridge")
  264. except KeyError:
  265. error_text = '{"error": "session_id not provided"}'
  266. raise web.HTTPBadRequest(text=error_text, headers=self._headers)
  267. return await self._try_shielded_link(user, session_id, device_name)
  268. async def logout(self, request: web.Request) -> web.Response:
  269. user = await self.check_token(request)
  270. if not await user.is_logged_in():
  271. raise web.HTTPNotFound(
  272. text="""{"error": "You're not logged in"}""", headers=self._headers
  273. )
  274. await user.logout()
  275. return web.json_response({}, headers=self._acao_headers)