provisioning_api.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import TYPE_CHECKING, Awaitable, Dict
  17. import asyncio
  18. import json
  19. import logging
  20. from aiohttp import web
  21. from mautrix.types import UserID
  22. from mautrix.util.logging import TraceLogger
  23. from mausignald.errors import InternalError, TimeoutException
  24. from mausignald.types import Account, Address
  25. from .. import user as u
  26. if TYPE_CHECKING:
  27. from ..__main__ import SignalBridge
  28. class ProvisioningAPI:
  29. log: TraceLogger = logging.getLogger("mau.web.provisioning")
  30. app: web.Application
  31. bridge: "SignalBridge"
  32. def __init__(self, bridge: "SignalBridge", shared_secret: str) -> None:
  33. self.bridge = bridge
  34. self.app = web.Application()
  35. self.shared_secret = shared_secret
  36. self.app.router.add_get("/api/whoami", self.status)
  37. self.app.router.add_options("/api/link", self.login_options)
  38. self.app.router.add_options("/api/link/wait", self.login_options)
  39. # self.app.router.add_options("/api/register", self.login_options)
  40. # self.app.router.add_options("/api/register/code", self.login_options)
  41. self.app.router.add_options("/api/logout", self.login_options)
  42. self.app.router.add_post("/api/link", self.link)
  43. self.app.router.add_post("/api/link/wait", self.link_wait)
  44. # self.app.router.add_post("/api/register", self.register)
  45. # self.app.router.add_post("/api/register/code", self.register_code)
  46. self.app.router.add_post("/api/logout", self.logout)
  47. @property
  48. def _acao_headers(self) -> Dict[str, str]:
  49. return {
  50. "Access-Control-Allow-Origin": "*",
  51. "Access-Control-Allow-Headers": "Authorization, Content-Type",
  52. "Access-Control-Allow-Methods": "POST, OPTIONS",
  53. }
  54. @property
  55. def _headers(self) -> Dict[str, str]:
  56. return {
  57. **self._acao_headers,
  58. "Content-Type": "application/json",
  59. }
  60. async def login_options(self, _: web.Request) -> web.Response:
  61. return web.Response(status=200, headers=self._headers)
  62. async def check_token(self, request: web.Request) -> "u.User":
  63. try:
  64. token = request.headers["Authorization"]
  65. token = token[len("Bearer ") :]
  66. except KeyError:
  67. raise web.HTTPBadRequest(
  68. text='{"error": "Missing Authorization header"}', headers=self._headers
  69. )
  70. except IndexError:
  71. raise web.HTTPBadRequest(
  72. text='{"error": "Malformed Authorization header"}', headers=self._headers
  73. )
  74. if token != self.shared_secret:
  75. raise web.HTTPForbidden(text='{"error": "Invalid token"}', headers=self._headers)
  76. try:
  77. user_id = request.query["user_id"]
  78. except KeyError:
  79. raise web.HTTPBadRequest(
  80. text='{"error": "Missing user_id query param"}', headers=self._headers
  81. )
  82. if not self.bridge.signal.is_connected:
  83. await self.bridge.signal.wait_for_connected()
  84. return await u.User.get_by_mxid(UserID(user_id))
  85. async def status(self, request: web.Request) -> web.Response:
  86. user = await self.check_token(request)
  87. data = {
  88. "permissions": user.permission_level,
  89. "mxid": user.mxid,
  90. "signal": None,
  91. }
  92. if await user.is_logged_in():
  93. try:
  94. profile = await self.bridge.signal.get_profile(
  95. username=user.username, address=Address(number=user.username)
  96. )
  97. except Exception as e:
  98. self.log.exception(f"Failed to get {user.username}'s profile for whoami")
  99. data["signal"] = {
  100. "number": user.username,
  101. "ok": False,
  102. "error": str(e),
  103. }
  104. else:
  105. addr = profile.address if profile else None
  106. number = addr.number if addr else None
  107. uuid = addr.uuid if addr else None
  108. data["signal"] = {
  109. "number": number or user.username,
  110. "uuid": str(uuid or user.uuid or ""),
  111. "name": profile.name if profile else None,
  112. "ok": True,
  113. }
  114. return web.json_response(data, headers=self._acao_headers)
  115. async def link(self, request: web.Request) -> web.Response:
  116. user = await self.check_token(request)
  117. if await user.is_logged_in():
  118. raise web.HTTPConflict(
  119. text="""{"error": "You're already logged in"}""", headers=self._headers
  120. )
  121. try:
  122. data = await request.json()
  123. except json.JSONDecodeError:
  124. raise web.HTTPBadRequest(text='{"error": "Malformed JSON"}', headers=self._headers)
  125. device_name = data.get("device_name", "Mautrix-Signal bridge")
  126. sess = await self.bridge.signal.start_link()
  127. user.command_status = {
  128. "action": "Link",
  129. "session_id": sess.session_id,
  130. "device_name": device_name,
  131. }
  132. self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
  133. return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
  134. async def _shielded_link(self, user: "u.User", session_id: str, device_name: str) -> Account:
  135. try:
  136. self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
  137. account = await self.bridge.signal.finish_link(
  138. session_id=session_id, overwrite=True, device_name=device_name
  139. )
  140. except TimeoutException:
  141. self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
  142. raise
  143. except Exception:
  144. self.log.exception(
  145. "Fatal error while waiting for linking to finish " f"(session {session_id})"
  146. )
  147. raise
  148. else:
  149. await user.on_signin(account)
  150. return account
  151. async def link_wait(self, request: web.Request) -> web.Response:
  152. user = await self.check_token(request)
  153. if not user.command_status or user.command_status["action"] != "Link":
  154. raise web.HTTPBadRequest(
  155. text='{"error": "No Signal linking started"}', headers=self._headers
  156. )
  157. session_id = user.command_status["session_id"]
  158. device_name = user.command_status["device_name"]
  159. try:
  160. account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
  161. except asyncio.CancelledError:
  162. self.log.warning(
  163. f"Client cancelled link wait request ({session_id})" " before it finished"
  164. )
  165. except TimeoutException:
  166. raise web.HTTPBadRequest(
  167. text='{"error": "Signal linking timed out"}', headers=self._headers
  168. )
  169. except InternalError as ie:
  170. if "java.io.IOException" in ie.exceptions:
  171. raise web.HTTPBadRequest(
  172. text='{"error": "Signald websocket disconnected before linking finished"}',
  173. headers=self._headers,
  174. )
  175. raise web.HTTPInternalServerError(
  176. text='{"error": "Fatal error in Signal linking"}', headers=self._headers
  177. )
  178. except Exception:
  179. raise web.HTTPInternalServerError(
  180. text='{"error": "Fatal error in Signal linking"}', headers=self._headers
  181. )
  182. else:
  183. return web.json_response(account.address.serialize())
  184. async def logout(self, request: web.Request) -> web.Response:
  185. user = await self.check_token(request)
  186. if not await user.is_logged_in():
  187. raise web.HTTPNotFound(
  188. text="""{"error": "You're not logged in"}""", headers=self._headers
  189. )
  190. await user.logout()
  191. return web.json_response({}, headers=self._acao_headers)