|
@@ -15,11 +15,12 @@
|
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
from typing import Awaitable, Dict, TYPE_CHECKING
|
|
|
import logging
|
|
|
+import asyncio
|
|
|
import json
|
|
|
|
|
|
from aiohttp import web
|
|
|
|
|
|
-from mausignald.types import Address
|
|
|
+from mausignald.types import Address, Account
|
|
|
from mausignald.errors import TimeoutException
|
|
|
from mautrix.types import UserID
|
|
|
from mautrix.util.logging import TraceLogger
|
|
@@ -128,6 +129,19 @@ class ProvisioningAPI:
|
|
|
|
|
|
return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
|
|
|
|
|
|
+ async def _shielded_link(self, user: 'u.User', session_id: str, device_name: str) -> Account:
|
|
|
+ try:
|
|
|
+ account = await self.bridge.signal.finish_link(session_id=session_id,
|
|
|
+ device_name=device_name)
|
|
|
+ except TimeoutException:
|
|
|
+ raise
|
|
|
+ except Exception:
|
|
|
+ self.log.exception("Fatal error while waiting for linking to finish")
|
|
|
+ raise
|
|
|
+ else:
|
|
|
+ await user.on_signin(account)
|
|
|
+ return account
|
|
|
+
|
|
|
async def link_wait(self, request: web.Request) -> web.Response:
|
|
|
user = await self.check_token(request)
|
|
|
if not user.command_status or user.command_status["action"] != "Link":
|
|
@@ -136,17 +150,16 @@ class ProvisioningAPI:
|
|
|
session_id = user.command_status["session_id"]
|
|
|
device_name = user.command_status["device_name"]
|
|
|
try:
|
|
|
- account = await self.bridge.signal.finish_link(session_id=session_id,
|
|
|
- device_name=device_name)
|
|
|
+ account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
|
|
|
+ except asyncio.CancelledError:
|
|
|
+ self.log.warning("Client cancelled link wait request before it finished")
|
|
|
except TimeoutException:
|
|
|
raise web.HTTPBadRequest(text='{"error": "Signal linking timed out"}',
|
|
|
headers=self._headers)
|
|
|
except Exception:
|
|
|
- self.log.exception("Fatal error while waiting for linking to finish")
|
|
|
raise web.HTTPInternalServerError(text='{"error": "Fatal error in Signal linking"}',
|
|
|
headers=self._headers)
|
|
|
else:
|
|
|
- await user.on_signin(account)
|
|
|
return web.json_response(account.address.serialize())
|
|
|
|
|
|
async def logout(self, request: web.Request) -> web.Response:
|