Преглед изворни кода

provisioning api: add start new PM endpoint

This endpoint allows you to create PMs with a Signal-registered phone
number.

Example usage:

  curl -X POST \
    -H 'Authorization: Bearer PROVISIONING_TOKEN_HERE' \
    'http://your.matrix.server.url/_matrix/provision/v2/pm/+11234567890?user_id=@test:localhost'
Sumner Evans пре 3 година
родитељ
комит
cb2a8b964d

+ 5 - 0
mausignald/errors.py

@@ -102,6 +102,10 @@ class AttachmentTooLargeError(ResponseError):
         super().__init__(data, message_override="File is over the 100MB limit.")
 
 
+class UnregisteredUserError(ResponseError):
+    pass
+
+
 response_error_types = {
     "invalid_request": RequestValidationFailure,
     "TimeoutException": TimeoutException,
@@ -114,6 +118,7 @@ response_error_types = {
     "AuthorizationFailedError": AuthorizationFailedError,
     "ScanTimeoutError": ScanTimeoutError,
     "OwnProfileKeyDoesNotExistError": OwnProfileKeyDoesNotExistError,
+    "UnregisteredUserError": UnregisteredUserError,
     # TODO add rest from https://gitlab.com/signald/signald/-/tree/main/src/main/java/io/finn/signald/clientprotocol/v1/exceptions
 }
 

+ 5 - 4
mautrix_signal/commands/auth.py

@@ -22,6 +22,7 @@ from mautrix.bridge.commands import HelpSection, command_handler
 from mautrix.types import EventID, ImageInfo, MediaMessageEventContent, MessageType
 
 from .. import puppet as pu
+from ..util import normalize_number
 from .typehint import CommandEvent
 
 try:
@@ -31,7 +32,6 @@ except ImportError:
     qrcode = None
 
 SECTION_AUTH = HelpSection("Authentication", 10, "")
-remove_extra_chars = str.maketrans("", "", " .,-()")
 
 
 async def make_qr(
@@ -127,9 +127,10 @@ async def register(evt: CommandEvent) -> None:
                 evt.args = evt.args[2:]
         else:
             break
-    phone = evt.args[0].translate(remove_extra_chars)
-    if not phone.startswith("+") or not phone[1:].isdecimal():
-        await evt.reply(f"Please enter the phone number in international format (E.164)")
+    try:
+        phone = normalize_number(evt.args[0])
+    except Exception:
+        await evt.reply("Please enter the phone number in international format (E.164)")
         return
     username = await evt.bridge.signal.register(phone, voice=voice, captcha=captcha)
     evt.sender.command_status = {

+ 7 - 10
mautrix_signal/commands/signal.py

@@ -24,7 +24,8 @@ from mautrix.bridge.commands import SECTION_ADMIN, HelpSection, command_handler
 from mautrix.types import EventID
 
 from .. import portal as po, puppet as pu
-from .auth import make_qr, remove_extra_chars
+from ..util import normalize_number
+from .auth import make_qr
 from .typehint import CommandEvent
 
 try:
@@ -37,19 +38,15 @@ SECTION_SIGNAL = HelpSection("Signal actions", 20, "")
 
 
 async def _get_puppet_from_cmd(evt: CommandEvent) -> pu.Puppet | None:
-    if len(evt.args) == 0 or not evt.args[0].startswith("+"):
-        await evt.reply(
-            f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
-            "(enter phone number in international format)"
-        )
-        return None
-    phone = "".join(evt.args).translate(remove_extra_chars)
-    if not phone[1:].isdecimal():
+    try:
+        phone = normalize_number("".join(evt.args))
+    except Exception:
         await evt.reply(
             f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
             "(enter phone number in international format)"
         )
         return None
+
     puppet: pu.Puppet = await pu.Puppet.get_by_address(Address(number=phone))
     if not puppet.uuid and evt.sender.username:
         uuid = await evt.bridge.signal.find_uuid(evt.sender.username, puppet.number)
@@ -214,7 +211,7 @@ async def mark_trusted(evt: CommandEvent) -> EventID:
         return await evt.reply(
             "**Usage:** `$cmdprefix+sp mark-trusted <recipient phone> [level] <safety number>`"
         )
-    number = evt.args[0].translate(remove_extra_chars)
+    number = normalize_number(evt.args[0])
     remaining_args = evt.args[1:]
     trust_level = TrustLevel.TRUSTED_VERIFIED
     if len(evt.args) > 2 and evt.args[1].upper() in _trust_levels:

+ 1 - 0
mautrix_signal/util/__init__.py

@@ -1,2 +1,3 @@
 from .color_log import ColorFormatter
 from .id_to_str import id_to_str
+from .normalize_number import normalize_number

+ 24 - 0
mautrix_signal/util/normalize_number.py

@@ -0,0 +1,24 @@
+# mautrix-signal - A Matrix-Signal puppeting bridge
+# Copyright (C) 2022 Tulir Asokan, Sumner Evans
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+remove_extra_chars = str.maketrans("", "", " .,-()")
+
+
+def normalize_number(number: str) -> str:
+    phone = number.translate(remove_extra_chars)
+    if not number.startswith("+") or not phone[1:].isdecimal():
+        raise Exception("Phone number must be entered in international format")
+    return phone

+ 67 - 7
mautrix_signal/web/provisioning_api.py

@@ -22,12 +22,18 @@ import logging
 
 from aiohttp import web
 
-from mausignald.errors import InternalError, ScanTimeoutError, TimeoutException
+from mausignald.errors import (
+    InternalError,
+    ScanTimeoutError,
+    TimeoutException,
+    UnregisteredUserError,
+)
 from mausignald.types import Account, Address
 from mautrix.types import UserID
 from mautrix.util.logging import TraceLogger
 
-from .. import user as u
+from .. import portal as po, puppet as pu, user as u
+from ..util import normalize_number
 from .segment_analytics import init as init_segment, track
 
 if TYPE_CHECKING:
@@ -73,6 +79,9 @@ class ProvisioningAPI:
         self.app.router.add_post("/v2/link/wait/scan", self.link_wait_for_scan)
         self.app.router.add_post("/v2/link/wait/account", self.link_wait_for_account)
 
+        # Start new chat API
+        self.app.router.add_post("/v2/pm/{number}", self.start_pm)
+
     @property
     def _acao_headers(self) -> dict[str, str]:
         return {
@@ -117,6 +126,13 @@ class ProvisioningAPI:
 
         return await u.User.get_by_mxid(UserID(user_id))
 
+    async def check_token_and_logged_in(self, request: web.Request) -> "u.User":
+        user = await self.check_token(request)
+        if not await user.is_logged_in():
+            error = {"error": "You're not logged in"}
+            raise web.HTTPNotFound(text=json.dumps(error), headers=self._headers)
+        return user
+
     async def status(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         data = {
@@ -332,11 +348,55 @@ class ProvisioningAPI:
             track(user, "$wait_for_account_failed", error)
             raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
 
+    # endregion
+
+    # region Logout
+
     async def logout(self, request: web.Request) -> web.Response:
-        user = await self.check_token(request)
-        if not await user.is_logged_in():
-            raise web.HTTPNotFound(
-                text="""{"error": "You're not logged in"}""", headers=self._headers
-            )
+        user = await self.check_token_and_logged_in(request)
         await user.logout()
         return web.json_response({}, headers=self._acao_headers)
+
+    # endregion
+
+    # region Start new chat API
+
+    async def start_pm(self, request: web.Request) -> web.Response:
+        user = await self.check_token_and_logged_in(request)
+        number = normalize_number(request.match_info.get("number"))
+
+        puppet: pu.Puppet = await pu.Puppet.get_by_address(Address(number=number))
+        if not puppet.uuid and user.username:
+            try:
+                uuid = await self.bridge.signal.find_uuid(user.username, puppet.number)
+                if uuid:
+                    await puppet.handle_uuid_receive(uuid)
+            except UnregisteredUserError:
+                error = {"error": f"The phone number {number} is not a registered Signal account"}
+                raise web.HTTPNotFound(text=json.dumps(error), headers=self._headers)
+            except Exception as e:
+                raise web.HTTPBadRequest(reason=str(e), headers=self._headers)
+
+        if not puppet:
+            error = {"error": f"No puppet was found for {number}"}
+            raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
+
+        portal = await po.Portal.get_by_chat_id(
+            puppet.address, receiver=user.username, create=True
+        )
+        if not portal:
+            error = {"error": f"Failed finding a portal for {puppet.address}"}
+            raise web.HTTPBadRequest(text=json.dumps(error), headers=self._headers)
+
+        if portal.mxid:
+            await portal.main_intent.invite_user(portal.mxid, user.mxid)
+            error = {
+                "error": f"You already have a PM with {number}",
+                "room_id": f"{portal.mxid}",
+            }
+            raise web.HTTPConflict(text=json.dumps(error), headers=self._headers)
+
+        room_id = await portal.create_matrix_room(user, puppet.address)
+        return web.json_response({"room_id": room_id}, headers=self._acao_headers)
+
+    # endregion