Kaynağa Gözat

formatting: use black and add CI job to enforce styling

Sumner Evans 3 yıl önce
ebeveyn
işleme
4f31ce41e0

+ 1 - 1
.editorconfig

@@ -17,5 +17,5 @@ trim_trailing_whitespace = false
 [*.{yaml,yml,py,md}]
 indent_style = space
 
-[{.gitlab-ci.yml,*.md}]
+[{.gitlab-ci.yml,*.md,.github/workflows/*.yml}]
 indent_size = 2

+ 18 - 0
.github/workflows/python-lint.yml

@@ -0,0 +1,18 @@
+name: Python lint
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+    - uses: actions/checkout@v2
+    - uses: actions/setup-python@v2
+      with:
+        python-version: "3.10"
+    - uses: isort/isort-action@master
+      with:
+        sortPaths: "./mausignald ./mautrix_signal"
+    - uses: psf/black@21.12b0
+      with:
+        src: "./mausignald ./mautrix_signal"

+ 6 - 2
mausignald/errors.py

@@ -26,8 +26,12 @@ class NotConnected(RPCError):
 
 
 class ResponseError(RPCError):
-    def __init__(self, data: Dict[str, Any], error_type: Optional[str] = None,
-                 message_override: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        data: Dict[str, Any],
+        error_type: Optional[str] = None,
+        message_override: Optional[str] = None,
+    ) -> None:
         self.data = data
         msg = message_override or data["message"]
         if error_type:

+ 16 - 8
mausignald/rpc.py

@@ -36,8 +36,12 @@ class SignaldRPCClient:
     _response_waiters: Dict[UUID, asyncio.Future]
     _rpc_event_handlers: Dict[str, List[EventHandler]]
 
-    def __init__(self, socket_path: str, log: Optional[TraceLogger] = None,
-                 loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+    def __init__(
+        self,
+        socket_path: str,
+        log: Optional[TraceLogger] = None,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ) -> None:
         self.socket_path = socket_path
         self.log = log or logging.getLogger("mausignald")
         self.loop = loop or asyncio.get_event_loop()
@@ -67,7 +71,8 @@ class SignaldRPCClient:
         while True:
             try:
                 self._reader, self._writer = await asyncio.open_unix_connection(
-                    self.socket_path, limit=_SOCKET_LIMIT)
+                    self.socket_path, limit=_SOCKET_LIMIT
+                )
             except OSError as e:
                 self.log.error(f"Connection to {self.socket_path} failed: {e}")
                 await asyncio.sleep(5)
@@ -177,8 +182,9 @@ class SignaldRPCClient:
         self._reader = None
         self._writer = None
 
-    def _create_request(self, command: str, req_id: Optional[UUID] = None, **data: Any
-                        ) -> Tuple[asyncio.Future, Dict[str, Any]]:
+    def _create_request(
+        self, command: str, req_id: Optional[UUID] = None, **data: Any
+    ) -> Tuple[asyncio.Future, Dict[str, Any]]:
         req_id = req_id or uuid4()
         req = {"id": str(req_id), "type": command, **data}
         self.log.trace("Request %s: %s %s", req_id, command, data)
@@ -196,7 +202,8 @@ class SignaldRPCClient:
             if not waiter.done():
                 self.log.trace(f"Abandoning response for {req_id}")
                 waiter.set_exception(
-                    NotConnected("Disconnected from signald before RPC completed"))
+                    NotConnected("Disconnected from signald before RPC completed")
+                )
 
     async def _send_request(self, data: Dict[str, Any]) -> None:
         if self._writer is None:
@@ -207,8 +214,9 @@ class SignaldRPCClient:
         await self._writer.drain()
         self.log.trace("Sent data to server server: %s", data)
 
-    async def _raw_request(self, command: str, req_id: Optional[UUID] = None, **data: Any
-                           ) -> Tuple[str, Dict[str, Any]]:
+    async def _raw_request(
+        self, command: str, req_id: Optional[UUID] = None, **data: Any
+    ) -> Tuple[str, Dict[str, Any]]:
         future, data = self._create_request(command, req_id, **data)
         await self._send_request(data)
         return await asyncio.shield(future)

+ 156 - 77
mausignald/signald.py

@@ -10,11 +10,26 @@ from mautrix.util.logging import TraceLogger
 
 from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse
-from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
-                    Profile, GroupID, GetIdentitiesResponse, GroupV2, Mention, LinkSession,
-                    WebsocketConnectionState, WebsocketConnectionStateChangeEvent)
-
-T = TypeVar('T')
+from .types import (
+    Address,
+    Quote,
+    Attachment,
+    Reaction,
+    Account,
+    Message,
+    DeviceInfo,
+    Group,
+    Profile,
+    GroupID,
+    GetIdentitiesResponse,
+    GroupV2,
+    Mention,
+    LinkSession,
+    WebsocketConnectionState,
+    WebsocketConnectionStateChangeEvent,
+)
+
+T = TypeVar("T")
 EventHandler = Callable[[T], Awaitable[None]]
 
 
@@ -22,15 +37,19 @@ class SignaldClient(SignaldRPCClient):
     _event_handlers: Dict[Type[T], List[EventHandler]]
     _subscriptions: Set[str]
 
-    def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
-                 log: Optional[TraceLogger] = None,
-                 loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+    def __init__(
+        self,
+        socket_path: str = "/var/run/signald/signald.sock",
+        log: Optional[TraceLogger] = None,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ) -> None:
         super().__init__(socket_path, log, loop)
         self._event_handlers = {}
         self._subscriptions = set()
         self.add_rpc_handler("message", self._parse_message)
-        self.add_rpc_handler("websocket_connection_state_change",
-                             self._websocket_connection_state_change)
+        self.add_rpc_handler(
+            "websocket_connection_state_change", self._websocket_connection_state_change
+        )
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
         self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
@@ -111,12 +130,13 @@ class SignaldClient(SignaldRPCClient):
                 evt = WebsocketConnectionStateChangeEvent(
                     state=WebsocketConnectionState.SOCKET_DISCONNECTED,
                     account=username,
-                    exception="Disconnected from signald"
+                    exception="Disconnected from signald",
                 )
                 await self._run_event_handler(evt)
 
-    async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
-                       ) -> str:
+    async def register(
+        self, phone: str, voice: bool = False, captcha: Optional[str] = None
+    ) -> str:
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
         return resp["account_id"]
 
@@ -127,15 +147,18 @@ class SignaldClient(SignaldRPCClient):
     async def start_link(self) -> LinkSession:
         return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
 
-    async def finish_link(self, session_id: str, device_name: str = "mausignald",
-                          overwrite: bool = False) -> Account:
-        resp = await self.request_v1("finish_link", device_name=device_name, session_id=session_id,
-                                     overwrite=overwrite)
+    async def finish_link(
+        self, session_id: str, device_name: str = "mausignald", overwrite: bool = False
+    ) -> Account:
+        resp = await self.request_v1(
+            "finish_link", device_name=device_name, session_id=session_id, overwrite=overwrite
+        )
         return Account.deserialize(resp)
 
     @staticmethod
-    def _recipient_to_args(recipient: Union[Address, GroupID], simple_name: bool = False
-                           ) -> Dict[str, Any]:
+    def _recipient_to_args(
+        recipient: Union[Address, GroupID], simple_name: bool = False
+    ) -> Dict[str, Any]:
         if isinstance(recipient, Address):
             recipient = recipient.serialize()
             field_name = "address" if simple_name else "recipientAddress"
@@ -143,27 +166,49 @@ class SignaldClient(SignaldRPCClient):
             field_name = "group" if simple_name else "recipientGroupId"
         return {field_name: recipient}
 
-    async def react(self, username: str, recipient: Union[Address, GroupID],
-                    reaction: Reaction) -> None:
-        await self.request_v1("react", username=username, reaction=reaction.serialize(),
-                              **self._recipient_to_args(recipient))
-
-    async def remote_delete(self, username: str, recipient: Union[Address, GroupID], timestamp: int
-                            ) -> None:
-        await self.request_v1("remote_delete", account=username, timestamp=timestamp,
-                              **self._recipient_to_args(recipient, simple_name=True))
-
-    async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
-                   quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
-                   mentions: Optional[List[Mention]] = None, timestamp: Optional[int] = None
-                   ) -> None:
+    async def react(
+        self, username: str, recipient: Union[Address, GroupID], reaction: Reaction
+    ) -> None:
+        await self.request_v1(
+            "react",
+            username=username,
+            reaction=reaction.serialize(),
+            **self._recipient_to_args(recipient),
+        )
+
+    async def remote_delete(
+        self, username: str, recipient: Union[Address, GroupID], timestamp: int
+    ) -> None:
+        await self.request_v1(
+            "remote_delete",
+            account=username,
+            timestamp=timestamp,
+            **self._recipient_to_args(recipient, simple_name=True),
+        )
+
+    async def send(
+        self,
+        username: str,
+        recipient: Union[Address, GroupID],
+        body: str,
+        quote: Optional[Quote] = None,
+        attachments: Optional[List[Attachment]] = None,
+        mentions: Optional[List[Mention]] = None,
+        timestamp: Optional[int] = None,
+    ) -> None:
         serialized_quote = quote.serialize() if quote else None
         serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
         serialized_mentions = [mention.serialize() for mention in (mentions or [])]
-        resp = await self.request_v1("send", username=username, messageBody=body,
-                                     attachments=serialized_attachments, quote=serialized_quote,
-                                     mentions=serialized_mentions, timestamp=timestamp,
-                                     **self._recipient_to_args(recipient))
+        resp = await self.request_v1(
+            "send",
+            username=username,
+            messageBody=body,
+            attachments=serialized_attachments,
+            quote=serialized_quote,
+            mentions=serialized_mentions,
+            timestamp=timestamp,
+            **self._recipient_to_args(recipient),
+        )
         errors = []
 
         # We handle unregisteredFailure a little differently than other errors. If there are no
@@ -173,9 +218,8 @@ class SignaldClient(SignaldRPCClient):
         successful_send_count = 0
         results = resp.get("results", [])
         for result in results:
-            number = (
-                result.get("address", {}).get("number") or result.get("address", {}).get("uuid")
-            )
+            address = result.get("addres", {})
+            number = address.get("number") or address.get("uuid")
             proof_required_failure = result.get("proof_required_failure")
             if result.get("networkFailure", False):
                 errors.append(f"Network failure occurred while sending message to {number}.")
@@ -186,9 +230,10 @@ class SignaldClient(SignaldRPCClient):
             elif result.get("identityFailure", ""):
                 errors.append(
                     f"Identity failure occurred while sending message to {number}. New identity: "
-                    f"{result['identityFailure']}")
+                    f"{result['identityFailure']}"
+                )
             elif proof_required_failure:
-                options = proof_required_failure.get('options')
+                options = proof_required_failure.get("options")
                 self.log.warning(
                     f"Proof Required Failure {options}. "
                     f"Retry after: {proof_required_failure.get('retry_after')}. "
@@ -206,17 +251,26 @@ class SignaldClient(SignaldRPCClient):
                     await self.request_v1("submit_challenge")
             else:
                 successful_send_count += 1
-        self.log.info(f"Successfully sent message to {successful_send_count}/{len(results)} users in {recipient}")
+        self.log.info(
+            f"Successfully sent message to {successful_send_count}/{len(results)} users in {recipient}"
+        )
         if errors or successful_send_count == 0:
             raise Exception("\n".join(errors + unregistered_failures))
 
-    async def send_receipt(self, username: str, sender: Address, timestamps: List[int],
-                           when: Optional[int] = None, read: bool = False) -> None:
+    async def send_receipt(
+        self,
+        username: str,
+        sender: Address,
+        timestamps: List[int],
+        when: Optional[int] = None,
+        read: bool = False,
+    ) -> None:
         if not read:
             # TODO implement
             return
-        await self.request_v1("mark_read", account=username, timestamps=timestamps, when=when,
-                              to=sender.serialize())
+        await self.request_v1(
+            "mark_read", account=username, timestamps=timestamps, when=when, to=sender.serialize()
+        )
 
     async def list_accounts(self) -> List[Account]:
         resp = await self.request_v1("list_accounts")
@@ -242,19 +296,28 @@ class SignaldClient(SignaldRPCClient):
         v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
         return legacy + v2
 
-    async def update_group(self, username: str, group_id: GroupID, title: Optional[str] = None,
-                           avatar_path: Optional[str] = None,
-                           add_members: Optional[List[Address]] = None,
-                           remove_members: Optional[List[Address]] = None
-                           ) -> Union[Group, GroupV2, None]:
-        update_params = {key: value for key, value in {
-            "groupID": group_id,
-            "avatar": avatar_path,
-            "title": title,
-            "addMembers": [addr.serialize() for addr in add_members] if add_members else None,
-            "removeMembers": ([addr.serialize() for addr in remove_members]
-                              if remove_members else None),
-        }.items() if value is not None}
+    async def update_group(
+        self,
+        username: str,
+        group_id: GroupID,
+        title: Optional[str] = None,
+        avatar_path: Optional[str] = None,
+        add_members: Optional[List[Address]] = None,
+        remove_members: Optional[List[Address]] = None,
+    ) -> Union[Group, GroupV2, None]:
+        update_params = {
+            key: value
+            for key, value in {
+                "groupID": group_id,
+                "avatar": avatar_path,
+                "title": title,
+                "addMembers": [addr.serialize() for addr in add_members] if add_members else None,
+                "removeMembers": (
+                    [addr.serialize() for addr in remove_members] if remove_members else None
+                ),
+            }.items()
+            if value is not None
+        }
         resp = await self.request_v1("update_group", account=username, **update_params)
         if "v1" in resp:
             return Group.deserialize(resp["v1"])
@@ -267,21 +330,25 @@ class SignaldClient(SignaldRPCClient):
         resp = await self.request_v1("accept_invitation", account=username, groupID=group_id)
         return GroupV2.deserialize(resp)
 
-    async def get_group(self, username: str, group_id: GroupID, revision: int = -1
-                        ) -> Optional[GroupV2]:
-        resp = await self.request_v1("get_group", account=username, groupID=group_id,
-                                     revision=revision)
+    async def get_group(
+        self, username: str, group_id: GroupID, revision: int = -1
+    ) -> Optional[GroupV2]:
+        resp = await self.request_v1(
+            "get_group", account=username, groupID=group_id, revision=revision
+        )
         if "id" not in resp:
             return None
         return GroupV2.deserialize(resp)
 
-    async def get_profile(self, username: str, address: Address, use_cache: bool = False
-                          ) -> Optional[Profile]:
+    async def get_profile(
+        self, username: str, address: Address, use_cache: bool = False
+    ) -> Optional[Profile]:
         try:
             # async is a reserved keyword, so can't pass it as a normal parameter
             kwargs = {"async": use_cache}
-            resp = await self.request_v1("get_profile", account=username,
-                                         address=address.serialize(), **kwargs)
+            resp = await self.request_v1(
+                "get_profile", account=username, address=address.serialize(), **kwargs
+            )
         except UnexpectedResponse as e:
             if e.resp_type == "profile_not_available":
                 return None
@@ -289,12 +356,14 @@ class SignaldClient(SignaldRPCClient):
         return Profile.deserialize(resp)
 
     async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
-        resp = await self.request_v1("get_identities", account=username,
-                                     address=address.serialize())
+        resp = await self.request_v1(
+            "get_identities", account=username, address=address.serialize()
+        )
         return GetIdentitiesResponse.deserialize(resp)
 
-    async def set_profile(self, username: str, name: Optional[str] = None,
-                          avatar_path: Optional[str] = None) -> None:
+    async def set_profile(
+        self, username: str, name: Optional[str] = None, avatar_path: Optional[str] = None
+    ) -> None:
         args = {}
         if name is not None:
             args["name"] = name
@@ -302,9 +371,14 @@ class SignaldClient(SignaldRPCClient):
             args["avatarFile"] = avatar_path
         await self.request_v1("set_profile", account=username, **args)
 
-    async def trust(self, username: str, recipient: Address, trust_level: str,
-                    safety_number: Optional[str] = None, qr_code_data: Optional[str] = None
-                    ) -> None:
+    async def trust(
+        self,
+        username: str,
+        recipient: Address,
+        trust_level: str,
+        safety_number: Optional[str] = None,
+        qr_code_data: Optional[str] = None,
+    ) -> None:
         args = {}
         if safety_number:
             if qr_code_data:
@@ -314,5 +388,10 @@ class SignaldClient(SignaldRPCClient):
             args["qr_code_data"] = qr_code_data
         else:
             raise ValueError("safety_number or qr_code_data is required")
-        await self.request_v1("trust", account=username, **args, trust_level=trust_level,
-                              address=recipient.serialize())
+        await self.request_v1(
+            "trust",
+            account=username,
+            **args,
+            trust_level=trust_level,
+            address=recipient.serialize(),
+        )

+ 19 - 13
mausignald/types.py

@@ -11,7 +11,7 @@ from attr import dataclass
 
 from mautrix.types import SerializableAttrs, SerializableEnum, ExtensibleEnum, field
 
-GroupID = NewType('GroupID', str)
+GroupID = NewType("GroupID", str)
 
 
 @dataclass(frozen=True, eq=False)
@@ -27,7 +27,7 @@ class Address(SerializableAttrs):
     def best_identifier(self) -> str:
         return str(self.uuid) if self.uuid else self.number
 
-    def __eq__(self, other: 'Address') -> bool:
+    def __eq__(self, other: "Address") -> bool:
         if not isinstance(other, Address):
             return False
         if self.uuid and other.uuid:
@@ -42,7 +42,7 @@ class Address(SerializableAttrs):
         return hash(self.number)
 
     @classmethod
-    def parse(cls, value: str) -> 'Address':
+    def parse(cls, value: str) -> "Address":
         return Address(number=value) if value.startswith("+") else Address(uuid=UUID(value))
 
 
@@ -202,13 +202,15 @@ class GroupV2(GroupV2ID, SerializableAttrs):
     timer: Optional[int] = None
     master_key: Optional[str] = field(default=None, json="masterKey")
     invite_link: Optional[str] = field(default=None, json="inviteLink")
-    access_control: GroupAccessControl = field(factory=lambda: GroupAccessControl(),
-                                               json="accessControl")
+    access_control: GroupAccessControl = field(
+        factory=lambda: GroupAccessControl(), json="accessControl"
+    )
     members: List[Address]
     member_detail: List[GroupMember] = field(factory=lambda: [], json="memberDetail")
     pending_members: List[Address] = field(factory=lambda: [], json="pendingMembers")
-    pending_member_detail: List[GroupMember] = field(factory=lambda: [],
-                                                     json="pendingMemberDetail")
+    pending_member_detail: List[GroupMember] = field(
+        factory=lambda: [], json="pendingMemberDetail"
+    )
     requesting_members: List[Address] = field(factory=lambda: [], json="requestingMembers")
 
 
@@ -294,8 +296,9 @@ class MessageData(SerializableAttrs):
 class SentSyncMessage(SerializableAttrs):
     message: MessageData
     timestamp: int
-    expiration_start_timestamp: Optional[int] = field(default=None,
-                                                      json="expirationStartTimestamp")
+    expiration_start_timestamp: Optional[int] = field(
+        default=None, json="expirationStartTimestamp"
+    )
     is_recipient_update: bool = field(default=False, json="isRecipientUpdate")
     unidentified_status: Dict[str, bool] = field(factory=lambda: {})
     destination: Optional[Address] = None
@@ -347,11 +350,13 @@ class ConfigItem(SerializableAttrs):
 @dataclass
 class ClientConfiguration(SerializableAttrs):
     read_receipts: Optional[ConfigItem] = field(factory=lambda: ConfigItem(), json="readReceipts")
-    typing_indicators: Optional[ConfigItem] = field(factory=lambda: ConfigItem(),
-                                                    json="typingIndicators")
+    typing_indicators: Optional[ConfigItem] = field(
+        factory=lambda: ConfigItem(), json="typingIndicators"
+    )
     link_previews: Optional[ConfigItem] = field(factory=lambda: ConfigItem(), json="linkPreviews")
     unidentified_delivery_indicators: Optional[ConfigItem] = field(
-        factory=lambda: ConfigItem(), json="unidentifiedDeliveryIndicators")
+        factory=lambda: ConfigItem(), json="unidentifiedDeliveryIndicators"
+    )
 
 
 class StickerPackOperation(ExtensibleEnum):
@@ -376,7 +381,8 @@ class SyncMessage(SerializableAttrs):
     configuration: Optional[ClientConfiguration] = None
     # blocked_list: Optional[???] = field(default=None, json="blockedList")
     sticker_pack_operations: Optional[List[StickerPackOperations]] = field(
-        default=None, json="stickerPackOperations")
+        default=None, json="stickerPackOperations"
+    )
     contacts_complete: bool = field(default=False, json="contactsComplete")
 
 

+ 74 - 35
mautrix_signal/commands/auth.py

@@ -34,8 +34,9 @@ SECTION_AUTH = HelpSection("Authentication", 10, "")
 remove_extra_chars = str.maketrans("", "", " .,-()")
 
 
-async def make_qr(intent: IntentAPI, data: Union[str, bytes], body: str = None
-                  ) -> MediaMessageEventContent:
+async def make_qr(
+    intent: IntentAPI, data: Union[str, bytes], body: str = None
+) -> MediaMessageEventContent:
     # TODO always encrypt QR codes?
     buffer = io.BytesIO()
     image = qrcode.make(data)
@@ -43,20 +44,30 @@ async def make_qr(intent: IntentAPI, data: Union[str, bytes], body: str = None
     image.save(buffer, "PNG")
     qr = buffer.getvalue()
     mxc = await intent.upload_media(qr, "image/png", "qr.png", len(qr))
-    return MediaMessageEventContent(body=body or data, url=mxc, msgtype=MessageType.IMAGE,
-                                    info=ImageInfo(mimetype="image/png", size=len(qr),
-                                                   width=size, height=size))
-
-
-@command_handler(needs_auth=False, management_only=True, help_section=SECTION_AUTH,
-                 help_text="Link the bridge as a secondary device", help_args="[device name]")
+    return MediaMessageEventContent(
+        body=body or data,
+        url=mxc,
+        msgtype=MessageType.IMAGE,
+        info=ImageInfo(mimetype="image/png", size=len(qr), width=size, height=size),
+    )
+
+
+@command_handler(
+    needs_auth=False,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="Link the bridge as a secondary device",
+    help_args="[device name]",
+)
 async def link(evt: CommandEvent) -> None:
     if qrcode is None:
         await evt.reply("Can't generate QR code: qrcode and/or PIL not installed")
         return
     if await evt.sender.is_logged_in():
-        await evt.reply("You're already logged in. "
-                        "If you want to relink, log out with `$cmdprefix+sp logout` first.")
+        await evt.reply(
+            "You're already logged in. "
+            "If you want to relink, log out with `$cmdprefix+sp logout` first."
+        )
         return
     # TODO make default device name configurable
     device_name = " ".join(evt.args) or "Mautrix-Signal bridge"
@@ -65,14 +76,16 @@ async def link(evt: CommandEvent) -> None:
     content = await make_qr(evt.az.intent, sess.uri)
     event_id = await evt.az.intent.send_message(evt.room_id, content)
     try:
-        account = await evt.bridge.signal.finish_link(session_id=sess.session_id, overwrite=True,
-                                                      device_name=device_name)
+        account = await evt.bridge.signal.finish_link(
+            session_id=sess.session_id, overwrite=True, device_name=device_name
+        )
     except TimeoutException:
         await evt.reply("Linking timed out, please try again.")
     except Exception:
         evt.log.exception("Fatal error while waiting for linking to finish")
-        await evt.reply("Fatal error while waiting for linking to finish "
-                        "(see logs for more details)")
+        await evt.reply(
+            "Fatal error while waiting for linking to finish " "(see logs for more details)"
+        )
     else:
         await evt.sender.on_signin(account)
         await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
@@ -80,16 +93,23 @@ async def link(evt: CommandEvent) -> None:
         await evt.main_intent.redact(evt.room_id, event_id)
 
 
-@command_handler(needs_auth=False, management_only=True, help_section=SECTION_AUTH,
-                 is_enabled_for=lambda evt: evt.config["signal.registration_enabled"],
-                 help_text="Sign into Signal as the primary device", help_args="<phone>")
+@command_handler(
+    needs_auth=False,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    is_enabled_for=lambda evt: evt.config["signal.registration_enabled"],
+    help_text="Sign into Signal as the primary device",
+    help_args="<phone>",
+)
 async def register(evt: CommandEvent) -> None:
     if len(evt.args) == 0:
         await evt.reply("**Usage**: $cmdprefix+sp register [--voice] [--captcha <token>] <phone>")
         return
     if await evt.sender.is_logged_in():
-        await evt.reply("You're already logged in. "
-                        "If you want to re-register, log out with `$cmdprefix+sp logout` first.")
+        await evt.reply(
+            "You're already logged in. "
+            "If you want to re-register, log out with `$cmdprefix+sp logout` first."
+        )
         return
     voice = False
     captcha = None
@@ -132,13 +152,19 @@ async def enter_register_code(evt: CommandEvent) -> None:
             raise
     else:
         await evt.sender.on_signin(account)
-        await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}."
-                        f"\n\n**N.B.** You must set a Signal profile name with `$cmdprefix+sp "
-                        f"set-profile-name <name>` before you can participate in new groups.")
-
-
-@command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,
-                 help_text="Remove all local data about your Signal link")
+        await evt.reply(
+            f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}."
+            f"\n\n**N.B.** You must set a Signal profile name with `$cmdprefix+sp "
+            f"set-profile-name <name>` before you can participate in new groups."
+        )
+
+
+@command_handler(
+    needs_auth=True,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="Remove all local data about your Signal link",
+)
 async def logout(evt: CommandEvent) -> None:
     if not evt.sender.username:
         await evt.reply("You're not logged in")
@@ -147,16 +173,29 @@ async def logout(evt: CommandEvent) -> None:
     await evt.reply("Successfully logged out")
 
 
-@command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,
-                 help_text="List devices linked to your Signal account")
+@command_handler(
+    needs_auth=True,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="List devices linked to your Signal account",
+)
 async def list_devices(evt: CommandEvent) -> None:
     devices = await evt.bridge.signal.get_linked_devices(evt.sender.username)
-    await evt.reply("\n".join(f"* #{dev.id}: {dev.name_with_default} (created {dev.created_fmt}, "
-                              f"last seen {dev.last_seen_fmt})" for dev in devices))
-
-
-@command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,
-                 help_text="Remove a linked device")
+    await evt.reply(
+        "\n".join(
+            f"* #{dev.id}: {dev.name_with_default} (created {dev.created_fmt}, last seen "
+            f"{dev.last_seen_fmt})"
+            for dev in devices
+        )
+    )
+
+
+@command_handler(
+    needs_auth=True,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="Remove a linked device",
+)
 async def remove_linked_device(evt: CommandEvent) -> EventID:
     if len(evt.args) == 0:
         return await evt.reply("**Usage:** `$cmdprefix+sp remove-linked-device <device ID>`")

+ 23 - 8
mautrix_signal/commands/conn.py

@@ -20,28 +20,42 @@ from .typehint import CommandEvent
 SECTION_CONNECTION = HelpSection("Connection management", 15, "")
 
 
-@command_handler(needs_auth=False, management_only=True, help_section=SECTION_CONNECTION,
-                 help_text="Mark this room as your bridge notice room.")
+@command_handler(
+    needs_auth=False,
+    management_only=True,
+    help_section=SECTION_CONNECTION,
+    help_text="Mark this room as your bridge notice room.",
+)
 async def set_notice_room(evt: CommandEvent) -> None:
     evt.sender.notice_room = evt.room_id
     await evt.sender.update()
     await evt.reply("This room has been marked as your bridge notice room")
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
-                 help_text="Relay messages in this room through your Signal account.")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_CONNECTION,
+    help_text="Relay messages in this room through your Signal account.",
+)
 async def set_relay(evt: CommandEvent) -> EventID:
     if not evt.config["bridge.relay.enabled"]:
         return await evt.reply("Relay mode is not enabled in this instance of the bridge.")
     elif not evt.is_portal:
         return await evt.reply("This is not a portal room.")
     await evt.portal.set_relay_user(evt.sender)
-    return await evt.reply("Messages from non-logged-in users in this room will now be bridged "
-                           "through your Signal account.")
+    return await evt.reply(
+        "Messages from non-logged-in users in this room will now be bridged "
+        "through your Signal account."
+    )
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
-                 help_text="Stop relaying messages in this room.")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_CONNECTION,
+    help_text="Stop relaying messages in this room.",
+)
 async def unset_relay(evt: CommandEvent) -> EventID:
     if not evt.config["bridge.relay.enabled"]:
         return await evt.reply("Relay mode is not enabled in this instance of the bridge.")
@@ -52,6 +66,7 @@ async def unset_relay(evt: CommandEvent) -> EventID:
     await evt.portal.set_relay_user(None)
     return await evt.reply("Messages from non-logged-in users will no longer be bridged.")
 
+
 # @command_handler(needs_auth=False, management_only=True, help_section=SECTION_CONNECTION,
 #                  help_text="Check if you're logged into Twitter")
 # async def ping(evt: CommandEvent) -> None:

+ 97 - 46
mautrix_signal/commands/signal.py

@@ -35,15 +35,19 @@ except ImportError:
 SECTION_SIGNAL = HelpSection("Signal actions", 20, "")
 
 
-async def _get_puppet_from_cmd(evt: CommandEvent) -> Optional['pu.Puppet']:
+async def _get_puppet_from_cmd(evt: CommandEvent) -> Optional["pu.Puppet"]:
     if len(evt.args) == 0 or not evt.args[0].startswith("+"):
-        await evt.reply(f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
-                        "(enter phone number in international format)")
+        await evt.reply(
+            f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
+            "(enter phone number in international format)"
+        )
         return None
     phone = "".join(evt.args).translate(remove_extra_chars)
     if not phone[1:].isdecimal():
-        await evt.reply(f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
-                        "(enter phone number in international format)")
+        await evt.reply(
+            f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
+            "(enter phone number in international format)"
+        )
         return None
     return await pu.Puppet.get_by_address(Address(number=phone))
 
@@ -51,40 +55,58 @@ async def _get_puppet_from_cmd(evt: CommandEvent) -> Optional['pu.Puppet']:
 def _format_safety_number(number: str) -> str:
     line_size = 20
     chunk_size = 5
-    return "\n".join(" ".join([number[chunk:chunk + chunk_size]
-                               for chunk in range(line, line + line_size, chunk_size)])
-                     for line in range(0, len(number), line_size))
-
-
-def _pill(puppet: 'pu.Puppet') -> str:
+    return "\n".join(
+        " ".join(
+            [
+                number[chunk : chunk + chunk_size]
+                for chunk in range(line, line + line_size, chunk_size)
+            ]
+        )
+        for line in range(0, len(number), line_size)
+    )
+
+
+def _pill(puppet: "pu.Puppet") -> str:
     return f"[{puppet.name}](https://matrix.to/#/{puppet.mxid})"
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Open a private chat portal with a specific phone number",
-                 help_args="<_phone_>")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Open a private chat portal with a specific phone number",
+    help_args="<_phone_>",
+)
 async def pm(evt: CommandEvent) -> None:
     puppet = await _get_puppet_from_cmd(evt)
     if not puppet:
         return
-    portal = await po.Portal.get_by_chat_id(puppet.address, receiver=evt.sender.username,
-                                            create=True)
+    portal = await po.Portal.get_by_chat_id(
+        puppet.address, receiver=evt.sender.username, create=True
+    )
     if portal.mxid:
-        await evt.reply(f"You already have a private chat with {puppet.name}: "
-                        f"[{portal.mxid}](https://matrix.to/#/{portal.mxid})")
+        await evt.reply(
+            f"You already have a private chat with {puppet.name}: "
+            f"[{portal.mxid}](https://matrix.to/#/{portal.mxid})"
+        )
         await portal.main_intent.invite_user(portal.mxid, evt.sender.mxid)
         return
     await portal.create_matrix_room(evt.sender, puppet.address)
     await evt.reply(f"Created a portal room with {_pill(puppet)} and invited you to it")
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Get the invite link to the current group")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Get the invite link to the current group",
+)
 async def invite_link(evt: CommandEvent) -> EventID:
     if not evt.is_portal:
         return await evt.reply("This is not a portal room.")
-    group = await evt.bridge.signal.get_group(evt.sender.username, evt.portal.chat_id,
-                                              evt.portal.revision)
+    group = await evt.bridge.signal.get_group(
+        evt.sender.username, evt.portal.chat_id, evt.portal.revision
+    )
     if not group:
         await evt.reply("Failed to get group info")
     elif not group.invite_link:
@@ -93,9 +115,13 @@ async def invite_link(evt: CommandEvent) -> EventID:
         await evt.reply(group.invite_link)
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="View the safety number of a specific user",
-                 help_args="[--qr] [_phone_]")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="View the safety number of a specific user",
+    help_args="[--qr] [_phone_]",
+)
 async def safety_number(evt: CommandEvent) -> None:
     show_qr = evt.args and evt.args[0].lower() == "--qr"
     if show_qr:
@@ -119,53 +145,77 @@ async def safety_number(evt: CommandEvent) -> None:
         if identity.added > most_recent.added:
             most_recent = identity
     uuid = resp.address.uuid or "unknown"
-    await evt.reply(f"### {puppet.name}\n\n"
-                    f"**UUID:** {uuid}  \n"
-                    f"**Trust level:** {most_recent.trust_level}  \n"
-                    f"**Safety number:**\n"
-                    f"```\n{_format_safety_number(most_recent.safety_number)}\n```")
+    await evt.reply(
+        f"### {puppet.name}\n\n"
+        f"**UUID:** {uuid}  \n"
+        f"**Trust level:** {most_recent.trust_level}  \n"
+        f"**Safety number:**\n"
+        f"```\n{_format_safety_number(most_recent.safety_number)}\n```"
+    )
     if show_qr and most_recent.qr_code_data:
         data = base64.b64decode(most_recent.qr_code_data)
         content = await make_qr(evt.main_intent, data, "verification-qr.png")
         await evt.main_intent.send_message(evt.room_id, content)
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Set your Signal profile name", help_args="<_name_>")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Set your Signal profile name",
+    help_args="<_name_>",
+)
 async def set_profile_name(evt: CommandEvent) -> None:
     await evt.bridge.signal.set_profile(evt.sender.username, name=" ".join(evt.args))
     await evt.reply("Successfully updated profile name")
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Mark another user's safety number as trusted",
-                 help_args="<_recipient phone_> <_safety number_>")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Mark another user's safety number as trusted",
+    help_args="<_recipient phone_> <_safety number_>",
+)
 async def mark_trusted(evt: CommandEvent) -> EventID:
     if len(evt.args) < 2:
-        return await evt.reply("**Usage:** `$cmdprefix+sp mark-trusted <recipient phone> "
-                               "<safety number>`")
+        return await evt.reply(
+            "**Usage:** `$cmdprefix+sp mark-trusted <recipient phone> " "<safety number>`"
+        )
     number = evt.args[0].translate(remove_extra_chars)
     safety_num = "".join(evt.args[1:]).replace("\n", "")
     if len(safety_num) != 60 or not safety_num.isdecimal():
         return await evt.reply("That doesn't look like a valid safety number")
     try:
-        await evt.bridge.signal.trust(evt.sender.username, Address(number=number),
-                                      safety_number=safety_num, trust_level="TRUSTED_VERIFIED")
+        await evt.bridge.signal.trust(
+            evt.sender.username,
+            Address(number=number),
+            safety_number=safety_num,
+            trust_level="TRUSTED_VERIFIED",
+        )
     except UnknownIdentityKey as e:
         return await evt.reply(f"Failed to mark {number} as trusted: {e}")
     return await evt.reply(f"Successfully marked {number} as trusted")
 
 
-@command_handler(needs_admin=False, needs_auth=True, help_section=SECTION_SIGNAL,
-                 help_text="Sync data from Signal")
+@command_handler(
+    needs_admin=False,
+    needs_auth=True,
+    help_section=SECTION_SIGNAL,
+    help_text="Sync data from Signal",
+)
 async def sync(evt: CommandEvent) -> None:
     await evt.sender.sync()
     await evt.reply("Sync complete")
 
 
-@command_handler(needs_admin=True, needs_auth=False, help_section=SECTION_ADMIN,
-                 help_text="Send raw requests to signald",
-                 help_args="[--user] <type> <_json_>")
+@command_handler(
+    needs_admin=True,
+    needs_auth=False,
+    help_section=SECTION_ADMIN,
+    help_text="Send raw requests to signald",
+    help_args="[--user] <type> <_json_>",
+)
 async def raw(evt: CommandEvent) -> None:
     add_username = False
     while True:
@@ -200,5 +250,6 @@ async def raw(evt: CommandEvent) -> None:
         if resp_data is None:
             await evt.reply(f"Got reply `{resp_type}` with no content")
         else:
-            await evt.reply(f"Got reply `{resp_type}`:\n\n"
-                            f"```json\n{json.dumps(resp_data, indent=2)}\n```")
+            await evt.reply(
+                f"Got reply `{resp_type}`:\n\n" f"```json\n{json.dumps(resp_data, indent=2)}\n```"
+            )

+ 3 - 3
mautrix_signal/commands/typehint.py

@@ -9,6 +9,6 @@ if TYPE_CHECKING:
 
 
 class CommandEvent(BaseCommandEvent):
-    bridge: 'SignalBridge'
-    sender: 'User'
-    portal: 'Portal'
+    bridge: "SignalBridge"
+    sender: "User"
+    portal: "Portal"

+ 13 - 3
mautrix_signal/db/__init__.py

@@ -18,7 +18,17 @@ def init(db: Database) -> None:
 
 # TODO should this be in mautrix-python?
 sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
-sqlite3.register_converter("UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b))
+sqlite3.register_converter(
+    "UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b)
+)
 
-__all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction",
-           "DisappearingMessage"]
+__all__ = [
+    "upgrade_table",
+    "init",
+    "User",
+    "Puppet",
+    "Portal",
+    "Message",
+    "Reaction",
+    "DisappearingMessage",
+]

+ 6 - 4
mautrix_signal/db/disappearing_message.py

@@ -38,8 +38,9 @@ class DisappearingMessage:
         INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts)
         VALUES ($1, $2, $3, $4)
         """
-        await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
-                              self.expiration_ts)
+        await self.db.execute(
+            q, self.room_id, self.mxid, self.expiration_seconds, self.expiration_ts
+        )
 
     async def update(self) -> None:
         q = """
@@ -48,8 +49,9 @@ class DisappearingMessage:
         WHERE room_id=$1 AND mxid=$2
         """
         try:
-            await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
-                              self.expiration_ts)
+            await self.db.execute(
+                q, self.room_id, self.mxid, self.expiration_seconds, self.expiration_ts
+            )
         except Exception as e:
             print(e)
 

+ 65 - 29
mautrix_signal/db/message.py

@@ -40,23 +40,39 @@ class Message:
     signal_receiver: str
 
     async def insert(self) -> None:
-        q = ("INSERT INTO message (mxid, mx_room, sender, timestamp, signal_chat_id,"
-             "                     signal_receiver) VALUES ($1, $2, $3, $4, $5, $6)")
-        await self.db.execute(q, self.mxid, self.mx_room, self.sender.best_identifier,
-                              self.timestamp, id_to_str(self.signal_chat_id), self.signal_receiver)
+        q = """
+        INSERT INTO message (mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver)
+        VALUES ($1, $2, $3, $4, $5, $6)
+        """
+        await self.db.execute(
+            q,
+            self.mxid,
+            self.mx_room,
+            self.sender.best_identifier,
+            self.timestamp,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+        )
 
     async def delete(self) -> None:
-        q = ("DELETE FROM message WHERE sender=$1 AND timestamp=$2"
-             "                          AND signal_chat_id=$3 AND signal_receiver=$4")
-        await self.db.execute(q, self.sender.best_identifier, self.timestamp,
-                              id_to_str(self.signal_chat_id), self.signal_receiver)
+        q = """
+        DELETE FROM message
+         WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
+        """
+        await self.db.execute(
+            q,
+            self.sender.best_identifier,
+            self.timestamp,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+        )
 
     @classmethod
     async def delete_all(cls, room_id: RoomID) -> None:
         await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Message':
+    def _from_row(cls, row: asyncpg.Record) -> "Message":
         data = {**row}
         chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
@@ -65,44 +81,64 @@ class Message:
         return cls(signal_chat_id=chat_id, sender=sender, **data)
 
     @classmethod
-    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Message']:
-        q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-             "FROM message WHERE mxid=$1 AND mx_room=$2")
+    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional["Message"]:
+        q = """
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+          FROM message WHERE mxid=$1 AND mx_room=$2
+        """
         row = await cls.db.fetchrow(q, mxid, mx_room)
         if not row:
             return None
         return cls._from_row(row)
 
     @classmethod
-    async def get_by_signal_id(cls, sender: Address, timestamp: int,
-                               signal_chat_id: Union[GroupID, Address], signal_receiver: str = ""
-                               ) -> Optional['Message']:
-        q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-             "FROM message WHERE sender=$1 AND timestamp=$2"
-             "                   AND signal_chat_id=$3 AND signal_receiver=$4")
-        row = await cls.db.fetchrow(q, sender.best_identifier, timestamp,
-                                    id_to_str(signal_chat_id), signal_receiver)
+    async def get_by_signal_id(
+        cls,
+        sender: Address,
+        timestamp: int,
+        signal_chat_id: Union[GroupID, Address],
+        signal_receiver: str = "",
+    ) -> Optional["Message"]:
+        q = """
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+          FROM message
+         WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
+        """
+        row = await cls.db.fetchrow(
+            q, sender.best_identifier, timestamp, id_to_str(signal_chat_id), signal_receiver
+        )
         if not row:
             return None
         return cls._from_row(row)
 
     @classmethod
-    async def find_by_timestamps(cls, timestamps: List[int]) -> List['Message']:
+    async def find_by_timestamps(cls, timestamps: List[int]) -> List["Message"]:
         if cls.db.scheme == "postgres":
-            q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-                 "FROM message WHERE timestamp=ANY($1)")
+            q = """
+            SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+              FROM message
+             WHERE timestamp=ANY($1)
+            """
             rows = await cls.db.fetch(q, timestamps)
         else:
-            placeholders = ", ".join(f"?" for _ in range(len(timestamps)))
-            q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-                 f"FROM message WHERE timestamp IN ({placeholders})")
+            placeholders = ", ".join("?" for _ in range(len(timestamps)))
+            q = f"""
+            SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+              FROM message
+             WHERE timestamp IN ({placeholders})
+            """
             rows = await cls.db.fetch(q, *timestamps)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
-    async def find_by_sender_timestamp(cls, sender: Address, timestamp: int) -> Optional['Message']:
-        q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-             "FROM message WHERE sender=$1 AND timestamp=$2")
+    async def find_by_sender_timestamp(
+        cls, sender: Address, timestamp: int
+    ) -> Optional["Message"]:
+        q = """
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+          FROM message
+         WHERE sender=$1 AND timestamp=$2
+        """
         row = await cls.db.fetchrow(q, sender.best_identifier, timestamp)
         if not row:
             return None

+ 79 - 38
mautrix_signal/db/portal.py

@@ -49,27 +49,52 @@ class Portal:
         return id_to_str(self.chat_id)
 
     async def insert(self) -> None:
-        q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
-             "                    name_set, avatar_set, revision, encrypted, relay_user_id, "
-             "                    expiration_time) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)")
-        await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
-                              self.avatar_hash, self.avatar_url, self.name_set, self.avatar_set,
-                              self.revision, self.encrypted, self.relay_user_id,
-                              self.expiration_time)
+        q = """
+        INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set,
+                            avatar_set, revision, encrypted, relay_user_id, expiration_time)
+        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+        """
+        await self.db.execute(
+            q,
+            self.chat_id_str,
+            self.receiver,
+            self.mxid,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.revision,
+            self.encrypted,
+            self.relay_user_id,
+            self.expiration_time,
+        )
 
     async def update(self) -> None:
-        q = ("UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5, "
-             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9, "
-             "                  expiration_time=$10"
-             "WHERE chat_id=$11 AND receiver=$12")
-        await self.db.execute(q, self.mxid, self.name, self.avatar_hash, self.avatar_url,
-                              self.name_set, self.avatar_set, self.revision, self.encrypted,
-                              self.relay_user_id, self.expiration_time, self.chat_id_str,
-                              self.receiver)
+        q = """
+        UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5,
+                          avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9,
+                          expiration_time=$10
+        WHERE chat_id=$11 AND receiver=$12
+        """
+        await self.db.execute(
+            q,
+            self.mxid,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.revision,
+            self.encrypted,
+            self.relay_user_id,
+            self.expiration_time,
+            self.chat_id_str,
+            self.receiver,
+        )
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Portal':
+    def _from_row(cls, row: asyncpg.Record) -> "Portal":
         data = {**row}
         chat_id = data.pop("chat_id")
         if data["receiver"]:
@@ -77,46 +102,62 @@ class Portal:
         return cls(chat_id=chat_id, **data)
 
     @classmethod
-    async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE mxid=$1")
+    async def get_by_mxid(cls, mxid: RoomID) -> Optional["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE mxid=$1
+        """
         row = await cls.db.fetchrow(q, mxid)
         if not row:
             return None
         return cls._from_row(row)
 
     @classmethod
-    async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
-                             ) -> Optional['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE chat_id=$1 AND receiver=$2")
+    async def get_by_chat_id(
+        cls, chat_id: Union[GroupID, Address], receiver: str = ""
+    ) -> Optional["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE chat_id=$1 AND receiver=$2
+        """
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
             return None
         return cls._from_row(row)
 
     @classmethod
-    async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE receiver=$1")
+    async def find_private_chats_of(cls, receiver: str) -> List["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE receiver=$1
+        """
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
-    async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE chat_id=$1 AND receiver<>''")
+    async def find_private_chats_with(cls, other_user: Address) -> List["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE chat_id=$1 AND receiver<>''
+        """
         rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
-    async def all_with_room(cls) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE mxid IS NOT NULL")
+    async def all_with_room(cls) -> List["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE mxid IS NOT NULL
+        """
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]

+ 68 - 33
mautrix_signal/db/puppet.py

@@ -52,36 +52,54 @@ class Puppet:
         return str(self.base_url) if self.base_url else None
 
     async def insert(self) -> None:
-        q = ("INSERT INTO puppet (uuid, number, name, avatar_hash, avatar_url, name_set, "
-             "                    avatar_set, uuid_registered, number_registered, "
-             "                    custom_mxid, access_token, next_batch, base_url) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)")
-        await self.db.execute(q, self.uuid, self.number, self.name, self.avatar_hash,
-                              self.avatar_url, self.name_set, self.avatar_set,
-                              self.uuid_registered, self.number_registered, self.custom_mxid,
-                              self.access_token, self.next_batch, self._base_url_str)
+        q = (
+            "INSERT INTO puppet (uuid, number, name, avatar_hash, avatar_url, name_set, "
+            "                    avatar_set, uuid_registered, number_registered, "
+            "                    custom_mxid, access_token, next_batch, base_url) "
+            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)"
+        )
+        await self.db.execute(
+            q,
+            self.uuid,
+            self.number,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.uuid_registered,
+            self.number_registered,
+            self.custom_mxid,
+            self.access_token,
+            self.next_batch,
+            self._base_url_str,
+        )
 
     async def _set_uuid(self, uuid: UUID) -> None:
         async with self.db.acquire() as conn, conn.transaction():
-            await conn.execute("DELETE FROM puppet WHERE uuid=$1 AND number<>$2",
-                               uuid, self.number)
+            await conn.execute(
+                "DELETE FROM puppet WHERE uuid=$1 AND number<>$2", uuid, self.number
+            )
             await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
             await self._update_number_to_uuid(conn, self.number, str(uuid))
 
     async def _set_number(self, number: str) -> None:
         async with self.db.acquire() as conn, conn.transaction():
-            await conn.execute("DELETE FROM puppet WHERE number=$1 AND uuid<>$2",
-                               number, self.uuid)
+            await conn.execute(
+                "DELETE FROM puppet WHERE number=$1 AND uuid<>$2", number, self.uuid
+            )
             await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", number, self.uuid)
             await self._update_number_to_uuid(conn, number, str(self.uuid))
 
     @staticmethod
-    async def _update_number_to_uuid(conn: asyncpg.Connection, old_number: str, new_uuid: str
-                                     ) -> None:
+    async def _update_number_to_uuid(
+        conn: asyncpg.Connection, old_number: str, new_uuid: str
+    ) -> None:
         try:
             async with conn.transaction():
-                await conn.execute("UPDATE portal SET chat_id=$1 WHERE chat_id=$2",
-                                   new_uuid, old_number)
+                await conn.execute(
+                    "UPDATE portal SET chat_id=$1 WHERE chat_id=$2", new_uuid, old_number
+                )
         except asyncpg.UniqueViolationError:
             await conn.execute("DELETE FROM portal WHERE chat_id=$1", old_number)
         await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2", new_uuid, old_number)
@@ -93,32 +111,49 @@ class Puppet:
             "uuid_registered=$8, number_registered=$9, "
             "custom_mxid=$10, access_token=$11, next_batch=$12, base_url=$13"
         )
-        q = (f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
-             if self.uuid is None
-             else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1")
-        await self.db.execute(q,self.uuid, self.number, self.name, self.avatar_hash,
-                              self.avatar_url, self.name_set, self.avatar_set,
-                              self.uuid_registered, self.number_registered, self.custom_mxid,
-                              self.access_token, self.next_batch, self._base_url_str)
+        q = (
+            f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
+            if self.uuid is None
+            else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1"
+        )
+        await self.db.execute(
+            q,
+            self.uuid,
+            self.number,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.uuid_registered,
+            self.number_registered,
+            self.custom_mxid,
+            self.access_token,
+            self.next_batch,
+            self._base_url_str,
+        )
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Puppet':
+    def _from_row(cls, row: asyncpg.Record) -> "Puppet":
         data = {**row}
         base_url_str = data.pop("base_url")
         base_url = URL(base_url_str) if base_url_str is not None else None
         return cls(base_url=base_url, **data)
 
-    _select_base = ("SELECT uuid, number, name, avatar_hash, avatar_url, name_set, avatar_set, "
-                    "       uuid_registered, number_registered, custom_mxid, access_token, "
-                    "       next_batch, base_url "
-                    "FROM puppet")
+    _select_base = (
+        "SELECT uuid, number, name, avatar_hash, avatar_url, name_set, avatar_set, "
+        "       uuid_registered, number_registered, custom_mxid, access_token, "
+        "       next_batch, base_url "
+        "FROM puppet"
+    )
 
     @classmethod
-    async def get_by_address(cls, address: Address) -> Optional['Puppet']:
+    async def get_by_address(cls, address: Address) -> Optional["Puppet"]:
         if address.uuid:
             if address.number:
-                row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1 OR number=$2",
-                                            address.uuid, address.number)
+                row = await cls.db.fetchrow(
+                    f"{cls._select_base} WHERE uuid=$1 OR number=$2", address.uuid, address.number
+                )
             else:
                 row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
         elif address.number:
@@ -130,13 +165,13 @@ class Puppet:
         return cls._from_row(row)
 
     @classmethod
-    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
+    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional["Puppet"]:
         row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
         if not row:
             return None
         return cls._from_row(row)
 
     @classmethod
-    async def all_with_custom_mxid(cls) -> List['Puppet']:
+    async def all_with_custom_mxid(cls) -> List["Puppet"]:
         rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
         return [cls._from_row(row) for row in rows]

+ 70 - 31
mautrix_signal/db/reaction.py

@@ -42,30 +42,54 @@ class Reaction:
     emoji: str
 
     async def insert(self) -> None:
-        q = ("INSERT INTO reaction (mxid, mx_room, signal_chat_id, signal_receiver, msg_author,"
-             "                      msg_timestamp, author, emoji) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)")
-        await self.db.execute(q, self.mxid, self.mx_room, id_to_str(self.signal_chat_id),
-                              self.signal_receiver, self.msg_author.best_identifier,
-                              self.msg_timestamp, self.author.best_identifier, self.emoji)
+        q = (
+            "INSERT INTO reaction (mxid, mx_room, signal_chat_id, signal_receiver, msg_author,"
+            "                      msg_timestamp, author, emoji) "
+            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
+        )
+        await self.db.execute(
+            q,
+            self.mxid,
+            self.mx_room,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+            self.msg_author.best_identifier,
+            self.msg_timestamp,
+            self.author.best_identifier,
+            self.emoji,
+        )
 
     async def edit(self, mx_room: RoomID, mxid: EventID, emoji: str) -> None:
-        await self.db.execute("UPDATE reaction SET mxid=$1, mx_room=$2, emoji=$3 "
-                              "WHERE signal_chat_id=$4 AND signal_receiver=$5"
-                              "      AND msg_author=$6 AND msg_timestamp=$7 AND author=$8",
-                              mxid, mx_room, emoji, id_to_str(self.signal_chat_id),
-                              self.signal_receiver, self.msg_author.best_identifier,
-                              self.msg_timestamp, self.author.best_identifier)
+        await self.db.execute(
+            "UPDATE reaction SET mxid=$1, mx_room=$2, emoji=$3 "
+            "WHERE signal_chat_id=$4 AND signal_receiver=$5"
+            "      AND msg_author=$6 AND msg_timestamp=$7 AND author=$8",
+            mxid,
+            mx_room,
+            emoji,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+            self.msg_author.best_identifier,
+            self.msg_timestamp,
+            self.author.best_identifier,
+        )
 
     async def delete(self) -> None:
-        q = ("DELETE FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
-             "                           AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
-        await self.db.execute(q, id_to_str(self.signal_chat_id), self.signal_receiver,
-                              self.msg_author.best_identifier, self.msg_timestamp,
-                              self.author.best_identifier)
+        q = (
+            "DELETE FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
+            "                           AND msg_author=$3 AND msg_timestamp=$4 AND author=$5"
+        )
+        await self.db.execute(
+            q,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+            self.msg_author.best_identifier,
+            self.msg_timestamp,
+            self.author.best_identifier,
+        )
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Reaction':
+    def _from_row(cls, row: asyncpg.Record) -> "Reaction":
         data = {**row}
         chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
@@ -75,25 +99,40 @@ class Reaction:
         return cls(signal_chat_id=chat_id, msg_author=msg_author, author=author, **data)
 
     @classmethod
-    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Reaction']:
-        q = ("SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
-             "       msg_author, msg_timestamp, author, emoji "
-             "FROM reaction WHERE mxid=$1 AND mx_room=$2")
+    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional["Reaction"]:
+        q = (
+            "SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
+            "       msg_author, msg_timestamp, author, emoji "
+            "FROM reaction WHERE mxid=$1 AND mx_room=$2"
+        )
         row = await cls.db.fetchrow(q, mxid, mx_room)
         if not row:
             return None
         return cls._from_row(row)
 
     @classmethod
-    async def get_by_signal_id(cls, chat_id: Union[GroupID, Address], receiver: str,
-                               msg_author: Address, msg_timestamp: int, author: Address
-                               ) -> Optional['Reaction']:
-        q = ("SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
-             "       msg_author, msg_timestamp, author, emoji "
-             "FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
-             "                    AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
-        row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver, msg_author.best_identifier,
-                                    msg_timestamp, author.best_identifier)
+    async def get_by_signal_id(
+        cls,
+        chat_id: Union[GroupID, Address],
+        receiver: str,
+        msg_author: Address,
+        msg_timestamp: int,
+        author: Address,
+    ) -> Optional["Reaction"]:
+        q = (
+            "SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
+            "       msg_author, msg_timestamp, author, emoji "
+            "FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
+            "                    AND msg_author=$3 AND msg_timestamp=$4 AND author=$5"
+        )
+        row = await cls.db.fetchrow(
+            q,
+            id_to_str(chat_id),
+            receiver,
+            msg_author.best_identifier,
+            msg_timestamp,
+            author.best_identifier,
+        )
         if not row:
             return None
         return cls._from_row(row)

+ 148 - 108
mautrix_signal/db/upgrade.py

@@ -22,97 +22,64 @@ upgrade_table = UpgradeTable()
 
 @upgrade_table.register(description="Initial revision")
 async def upgrade_v1(conn: Connection) -> None:
-    await conn.execute("""CREATE TABLE portal (
-        chat_id     TEXT,
-        receiver    TEXT,
-        mxid        TEXT,
-        name        TEXT,
-        encrypted   BOOLEAN NOT NULL DEFAULT false,
-
-        PRIMARY KEY (chat_id, receiver)
-    )""")
-    await conn.execute("""CREATE TABLE "user" (
-        mxid        TEXT PRIMARY KEY,
-        username    TEXT,
-        uuid        UUID,
-        notice_room TEXT
-    )""")
-    await conn.execute("""CREATE TABLE puppet (
-        uuid      UUID UNIQUE,
-        number    TEXT UNIQUE,
-        name      TEXT,
-
-        uuid_registered   BOOLEAN NOT NULL DEFAULT false,
-        number_registered BOOLEAN NOT NULL DEFAULT false,
-
-        custom_mxid  TEXT,
-        access_token TEXT,
-        next_batch   TEXT
-    )""")
-    await conn.execute("""CREATE TABLE user_portal (
-        "user"          TEXT,
-        portal          TEXT,
-        portal_receiver TEXT,
-        in_community    BOOLEAN NOT NULL DEFAULT false,
-
-        FOREIGN KEY (portal, portal_receiver) REFERENCES portal(chat_id, receiver)
-            ON UPDATE CASCADE ON DELETE CASCADE
-    )""")
-    await conn.execute("""CREATE TABLE message (
-        mxid    TEXT NOT NULL,
-        mx_room TEXT NOT NULL,
-        sender          UUID,
-        timestamp       BIGINT,
-        signal_chat_id  TEXT,
-        signal_receiver TEXT,
-
-        PRIMARY KEY (sender, timestamp, signal_chat_id, signal_receiver),
-        FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
-            ON UPDATE CASCADE ON DELETE CASCADE,
-        UNIQUE (mxid, mx_room)
-    )""")
-    await conn.execute("""CREATE TABLE reaction (
-        mxid    TEXT NOT NULL,
-        mx_room TEXT NOT NULL,
-
-        signal_chat_id  TEXT   NOT NULL,
-        signal_receiver TEXT   NOT NULL,
-        msg_author      UUID   NOT NULL,
-        msg_timestamp   BIGINT NOT NULL,
-        author          UUID   NOT NULL,
-
-        emoji TEXT NOT NULL,
-
-        PRIMARY KEY (signal_chat_id, signal_receiver, msg_author, msg_timestamp, author),
-        FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver)
-            REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
-            ON DELETE CASCADE ON UPDATE CASCADE,
-        UNIQUE (mxid, mx_room)
-    )""")
-
-
-@upgrade_table.register(description="Add avatar info to portal table")
-async def upgrade_v2(conn: Connection) -> None:
-    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_hash TEXT")
-    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_url TEXT")
-
-
-@upgrade_table.register(description="Add double-puppeting base_url to puppe table")
-async def upgrade_v3(conn: Connection) -> None:
-    await conn.execute("ALTER TABLE puppet ADD COLUMN base_url TEXT")
-
-
-@upgrade_table.register(description="Allow phone numbers as message sender identifiers")
-async def upgrade_v4(conn: Connection, scheme: str) -> None:
-    if scheme == "sqlite":
-        # SQLite doesn't have anything in the tables yet,
-        # so just recreate them without migrating data
-        await conn.execute("DROP TABLE message")
-        await conn.execute("DROP TABLE reaction")
-        await conn.execute("""CREATE TABLE message (
+    await conn.execute(
+        """
+        CREATE TABLE portal (
+            chat_id     TEXT,
+            receiver    TEXT,
+            mxid        TEXT,
+            name        TEXT,
+            encrypted   BOOLEAN NOT NULL DEFAULT false,
+
+            PRIMARY KEY (chat_id, receiver)
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE "user" (
+            mxid        TEXT PRIMARY KEY,
+            username    TEXT,
+            uuid        UUID,
+            notice_room TEXT
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE puppet (
+            uuid      UUID UNIQUE,
+            number    TEXT UNIQUE,
+            name      TEXT,
+
+            uuid_registered   BOOLEAN NOT NULL DEFAULT false,
+            number_registered BOOLEAN NOT NULL DEFAULT false,
+
+            custom_mxid  TEXT,
+            access_token TEXT,
+            next_batch   TEXT
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE user_portal (
+            "user"          TEXT,
+            portal          TEXT,
+            portal_receiver TEXT,
+            in_community    BOOLEAN NOT NULL DEFAULT false,
+
+            FOREIGN KEY (portal, portal_receiver) REFERENCES portal(chat_id, receiver)
+                ON UPDATE CASCADE ON DELETE CASCADE
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE message (
             mxid    TEXT NOT NULL,
             mx_room TEXT NOT NULL,
-            sender          TEXT,
+            sender          UUID,
             timestamp       BIGINT,
             signal_chat_id  TEXT,
             signal_receiver TEXT,
@@ -121,16 +88,20 @@ async def upgrade_v4(conn: Connection, scheme: str) -> None:
             FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
                 ON UPDATE CASCADE ON DELETE CASCADE,
             UNIQUE (mxid, mx_room)
-        )""")
-        await conn.execute("""CREATE TABLE reaction (
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE reaction (
             mxid    TEXT NOT NULL,
             mx_room TEXT NOT NULL,
 
             signal_chat_id  TEXT   NOT NULL,
             signal_receiver TEXT   NOT NULL,
-            msg_author      TEXT   NOT NULL,
+            msg_author      UUID   NOT NULL,
             msg_timestamp   BIGINT NOT NULL,
-            author          TEXT   NOT NULL,
+            author          UUID   NOT NULL,
 
             emoji TEXT NOT NULL,
 
@@ -139,19 +110,84 @@ async def upgrade_v4(conn: Connection, scheme: str) -> None:
                 REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
                 ON DELETE CASCADE ON UPDATE CASCADE,
             UNIQUE (mxid, mx_room)
-        )""")
+        )
+        """
+    )
+
+
+@upgrade_table.register(description="Add avatar info to portal table")
+async def upgrade_v2(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_hash TEXT")
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_url TEXT")
+
+
+@upgrade_table.register(description="Add double-puppeting base_url to puppe table")
+async def upgrade_v3(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE puppet ADD COLUMN base_url TEXT")
+
+
+@upgrade_table.register(description="Allow phone numbers as message sender identifiers")
+async def upgrade_v4(conn: Connection, scheme: str) -> None:
+    if scheme == "sqlite":
+        # SQLite doesn't have anything in the tables yet,
+        # so just recreate them without migrating data
+        await conn.execute("DROP TABLE message")
+        await conn.execute("DROP TABLE reaction")
+        await conn.execute(
+            """
+            CREATE TABLE message (
+                mxid    TEXT NOT NULL,
+                mx_room TEXT NOT NULL,
+                sender          TEXT,
+                timestamp       BIGINT,
+                signal_chat_id  TEXT,
+                signal_receiver TEXT,
+
+                PRIMARY KEY (sender, timestamp, signal_chat_id, signal_receiver),
+                FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
+                    ON UPDATE CASCADE ON DELETE CASCADE,
+                UNIQUE (mxid, mx_room)
+            )
+            """
+        )
+        await conn.execute(
+            """
+            CREATE TABLE reaction (
+                mxid    TEXT NOT NULL,
+                mx_room TEXT NOT NULL,
+
+                signal_chat_id  TEXT   NOT NULL,
+                signal_receiver TEXT   NOT NULL,
+                msg_author      TEXT   NOT NULL,
+                msg_timestamp   BIGINT NOT NULL,
+                author          TEXT   NOT NULL,
+
+                emoji TEXT NOT NULL,
+
+                PRIMARY KEY (signal_chat_id, signal_receiver, msg_author, msg_timestamp, author),
+                FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver)
+                    REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
+                    ON DELETE CASCADE ON UPDATE CASCADE,
+                UNIQUE (mxid, mx_room)
+            )
+            """
+        )
         return
 
-    cname = await conn.fetchval("SELECT constraint_name FROM information_schema.table_constraints "
-                                "WHERE table_name='reaction' AND constraint_name LIKE '%_fkey'")
+    cname = await conn.fetchval(
+        "SELECT constraint_name FROM information_schema.table_constraints "
+        "WHERE table_name='reaction' AND constraint_name LIKE '%_fkey'"
+    )
     await conn.execute(f"ALTER TABLE reaction DROP CONSTRAINT {cname}")
     await conn.execute("ALTER TABLE reaction ALTER COLUMN msg_author SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE reaction ALTER COLUMN author SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE message ALTER COLUMN sender SET DATA TYPE TEXT")
-    await conn.execute(f"ALTER TABLE reaction ADD CONSTRAINT {cname} "
-                       "FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) "
-                       "  REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) "
-                       "  ON DELETE CASCADE ON UPDATE CASCADE")
+    await conn.execute(
+        f"ALTER TABLE reaction ADD CONSTRAINT {cname} "
+        "FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) "
+        "  REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) "
+        "  ON DELETE CASCADE ON UPDATE CASCADE"
+    )
 
 
 @upgrade_table.register(description="Add avatar info to puppet table")
@@ -179,12 +215,16 @@ async def upgrade_v7(conn: Connection) -> None:
 
 @upgrade_table.register(description="Add support for disappearing messages")
 async def upgrade_v8(conn: Connection) -> None:
-    await conn.execute("""CREATE TABLE disappearing_message (
-        room_id             TEXT,
-        mxid                TEXT,
-        expiration_seconds  BIGINT,
-        expiration_ts       BIGINT,
-
-        PRIMARY KEY (room_id, mxid)
-    )""")
+    await conn.execute(
+        """
+        CREATE TABLE disappearing_message (
+            room_id             TEXT,
+            mxid                TEXT,
+            expiration_seconds  BIGINT,
+            expiration_ts       BIGINT,
+
+            PRIMARY KEY (room_id, mxid)
+        )
+        """
+    )
     await conn.execute("ALTER TABLE portal ADD COLUMN expiration_time BIGINT")

+ 5 - 6
mautrix_signal/db/user.py

@@ -34,8 +34,7 @@ class User:
     notice_room: Optional[RoomID]
 
     async def insert(self) -> None:
-        q = ('INSERT INTO "user" (mxid, username, uuid, notice_room) '
-             'VALUES ($1, $2, $3, $4)')
+        q = 'INSERT INTO "user" (mxid, username, uuid, notice_room) ' "VALUES ($1, $2, $3, $4)"
         await self.db.execute(q, self.mxid, self.username, self.uuid, self.notice_room)
 
     async def update(self) -> None:
@@ -43,7 +42,7 @@ class User:
         await self.db.execute(q, self.username, self.uuid, self.notice_room, self.mxid)
 
     @classmethod
-    async def get_by_mxid(cls, mxid: UserID) -> Optional['User']:
+    async def get_by_mxid(cls, mxid: UserID) -> Optional["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE mxid=$1'
         row = await cls.db.fetchrow(q, mxid)
         if not row:
@@ -51,7 +50,7 @@ class User:
         return cls(**row)
 
     @classmethod
-    async def get_by_username(cls, username: str) -> Optional['User']:
+    async def get_by_username(cls, username: str) -> Optional["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username=$1'
         row = await cls.db.fetchrow(q, username)
         if not row:
@@ -59,7 +58,7 @@ class User:
         return cls(**row)
 
     @classmethod
-    async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
+    async def get_by_uuid(cls, uuid: UUID) -> Optional["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE uuid=$1'
         row = await cls.db.fetchrow(q, uuid)
         if not row:
@@ -67,7 +66,7 @@ class User:
         return cls(**row)
 
     @classmethod
-    async def all_logged_in(cls) -> List['User']:
+    async def all_logged_in(cls) -> List["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username IS NOT NULL'
         rows = await cls.db.fetch(q)
         return [cls(**row) for row in rows]

+ 24 - 11
mautrix_signal/formatter.py

@@ -19,8 +19,13 @@ import struct
 
 from mausignald.types import MessageData, Address, Mention
 from mautrix.types import TextMessageEventContent, MessageType, Format
-from mautrix.util.formatter import (MatrixParser as BaseMatrixParser, EntityString, SimpleEntity,
-                                    EntityType, MarkdownString)
+from mautrix.util.formatter import (
+    MatrixParser as BaseMatrixParser,
+    EntityString,
+    SimpleEntity,
+    EntityType,
+    MarkdownString,
+)
 
 from . import puppet as pu, user as u
 
@@ -29,14 +34,16 @@ from . import puppet as pu, user as u
 # I don't know if this is how Signal actually calculates lengths, but it seems
 # to work better than plain len()
 def add_surrogate(text: str) -> str:
-    return ''.join(
-        ''.join(chr(y) for y in struct.unpack('<HH', x.encode('utf-16le')))
-        if (0x10000 <= ord(x) <= 0x10FFFF) else x for x in text
+    return "".join(
+        "".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
+        if (0x10000 <= ord(x) <= 0x10FFFF)
+        else x
+        for x in text
     )
 
 
 def del_surrogate(text: str) -> str:
-    return text.encode('utf-16', 'surrogatepass').decode('utf-16')
+    return text.encode("utf-16", "surrogatepass").decode("utf-16")
 
 
 async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
@@ -47,7 +54,7 @@ async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
         html_chunks = []
         last_offset = 0
         for mention in message.mentions:
-            before = surrogated_text[last_offset:mention.start]
+            before = surrogated_text[last_offset : mention.start]
             last_offset = mention.start + mention.length
 
             text_chunks.append(before)
@@ -67,11 +74,17 @@ async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
 
 # TODO this has a lot of duplication with mautrix-facebook, maybe move to mautrix-python
 class SignalFormatString(EntityString[SimpleEntity, EntityType], MarkdownString):
-    def format(self, entity_type: EntityType, **kwargs) -> 'SignalFormatString':
+    def format(self, entity_type: EntityType, **kwargs) -> "SignalFormatString":
         prefix = suffix = ""
         if entity_type == EntityType.USER_MENTION:
-            self.entities.append(SimpleEntity(type=entity_type, offset=0, length=len(self.text),
-                                              extra_info={"user_id": kwargs["user_id"]}))
+            self.entities.append(
+                SimpleEntity(
+                    type=entity_type,
+                    offset=0,
+                    length=len(self.text),
+                    extra_info={"user_id": kwargs["user_id"]},
+                )
+            )
             return self
         elif entity_type == EntityType.BOLD:
             prefix = suffix = "**"
@@ -80,7 +93,7 @@ class SignalFormatString(EntityString[SimpleEntity, EntityType], MarkdownString)
         elif entity_type == EntityType.STRIKETHROUGH:
             prefix = suffix = "~~"
         elif entity_type == EntityType.URL:
-            if kwargs['url'] != self.text:
+            if kwargs["url"] != self.text:
                 suffix = f" ({kwargs['url']})"
         elif entity_type == EntityType.PREFORMATTED:
             prefix = f"```{kwargs['language']}\n"

+ 1 - 2
mautrix_signal/get_version.py

@@ -34,8 +34,7 @@ else:
     git_revision_url = None
     git_tag = None
 
-git_tag_url = (f"https://github.com/mautrix/signal/releases/tag/{git_tag}"
-               if git_tag else None)
+git_tag_url = f"https://github.com/mautrix/signal/releases/tag/{git_tag}" if git_tag else None
 
 if git_tag and __version__ == git_tag[1:].replace("-", ""):
     version = __version__

+ 50 - 24
mautrix_signal/matrix.py

@@ -16,9 +16,22 @@
 from typing import List, Union, TYPE_CHECKING
 
 from mautrix.bridge import BaseMatrixHandler
-from mautrix.types import (Event, ReactionEvent, StateEvent, RoomID, EventID, UserID, TypingEvent,
-                           ReactionEventContent, RelationType, EventType, ReceiptEvent,
-                           PresenceEvent, RedactionEvent, SingleReceiptEventContent)
+from mautrix.types import (
+    Event,
+    ReactionEvent,
+    StateEvent,
+    RoomID,
+    EventID,
+    UserID,
+    TypingEvent,
+    ReactionEventContent,
+    RelationType,
+    EventType,
+    ReceiptEvent,
+    PresenceEvent,
+    RedactionEvent,
+    SingleReceiptEventContent,
+)
 
 from mautrix_signal.db.disappearing_message import DisappearingMessage
 
@@ -30,9 +43,9 @@ if TYPE_CHECKING:
 
 
 class MatrixHandler(BaseMatrixHandler):
-    signal: 's.SignalHandler'
+    signal: "s.SignalHandler"
 
-    def __init__(self, bridge: 'SignalBridge') -> None:
+    def __init__(self, bridge: "SignalBridge") -> None:
         prefix, suffix = bridge.config["bridge.username_template"].format(userid=":").split(":")
         homeserver = bridge.config["homeserver.domain"]
         self.user_id_prefix = f"@{prefix}"
@@ -41,13 +54,14 @@ class MatrixHandler(BaseMatrixHandler):
 
         super().__init__(bridge=bridge)
 
-    async def send_welcome_message(self, room_id: RoomID, inviter: 'u.User') -> None:
+    async def send_welcome_message(self, room_id: RoomID, inviter: "u.User") -> None:
         await super().send_welcome_message(room_id, inviter)
         if not inviter.notice_room:
             inviter.notice_room = room_id
             await inviter.update()
-            await self.az.intent.send_notice(room_id, "This room has been marked as your "
-                                                      "Signal bridge notice room.")
+            await self.az.intent.send_notice(
+                room_id, "This room has been marked as your Signal bridge notice room."
+            )
 
     async def handle_leave(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None:
         portal = await po.Portal.get_by_mxid(room_id)
@@ -72,11 +86,14 @@ class MatrixHandler(BaseMatrixHandler):
         await portal.handle_matrix_join(user)
 
     @classmethod
-    async def handle_reaction(cls, room_id: RoomID, user_id: UserID, event_id: EventID,
-                              content: ReactionEventContent) -> None:
+    async def handle_reaction(
+        cls, room_id: RoomID, user_id: UserID, event_id: EventID, content: ReactionEventContent
+    ) -> None:
         if content.relates_to.rel_type != RelationType.ANNOTATION:
-            cls.log.debug(f"Ignoring m.reaction event in {room_id} from {user_id} with unexpected "
-                          f"relation type {content.relates_to.rel_type}")
+            cls.log.debug(
+                f"Ignoring m.reaction event in {room_id} from {user_id} with unexpected "
+                f"relation type {content.relates_to.rel_type}"
+            )
             return
         user = await u.User.get_by_mxid(user_id)
         if not user:
@@ -86,12 +103,14 @@ class MatrixHandler(BaseMatrixHandler):
         if not portal:
             return
 
-        await portal.handle_matrix_reaction(user, event_id, content.relates_to.event_id,
-                                            content.relates_to.key)
+        await portal.handle_matrix_reaction(
+            user, event_id, content.relates_to.event_id, content.relates_to.key
+        )
 
     @staticmethod
-    async def handle_redaction(room_id: RoomID, user_id: UserID, event_id: EventID,
-                               redaction_event_id: EventID) -> None:
+    async def handle_redaction(
+        room_id: RoomID, user_id: UserID, event_id: EventID, redaction_event_id: EventID
+    ) -> None:
         user = await u.User.get_by_mxid(user_id)
         if not user:
             return
@@ -102,8 +121,13 @@ class MatrixHandler(BaseMatrixHandler):
 
         await portal.handle_matrix_redaction(user, event_id, redaction_event_id)
 
-    async def handle_read_receipt(self, user: 'u.User', portal: 'po.Portal', event_id: EventID,
-                                  data: SingleReceiptEventContent) -> None:
+    async def handle_read_receipt(
+        self,
+        user: "u.User",
+        portal: "po.Portal",
+        event_id: EventID,
+        data: SingleReceiptEventContent,
+    ) -> None:
         await portal.handle_read_receipt(event_id, data)
 
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
@@ -111,8 +135,9 @@ class MatrixHandler(BaseMatrixHandler):
             return
 
         user.log.trace(f"Sending read receipt for {message.timestamp} to {message.sender}")
-        await self.signal.send_receipt(user.username, message.sender,
-                                       timestamps=[message.timestamp], when=data.ts, read=True)
+        await self.signal.send_receipt(
+            user.username, message.sender, timestamps=[message.timestamp], when=data.ts, read=True
+        )
 
     async def handle_typing(self, room_id: RoomID, typing: List[UserID]) -> None:
         pass
@@ -134,8 +159,9 @@ class MatrixHandler(BaseMatrixHandler):
             evt: RedactionEvent
             await self.handle_redaction(evt.room_id, evt.sender, evt.redacts, evt.event_id)
 
-    async def handle_ephemeral_event(self, evt: Union[ReceiptEvent, PresenceEvent, TypingEvent]
-                                     ) -> None:
+    async def handle_ephemeral_event(
+        self, evt: Union[ReceiptEvent, PresenceEvent, TypingEvent]
+    ) -> None:
         if evt.type == EventType.TYPING:
             await self.handle_typing(evt.room_id, evt.content.user_ids)
         else:
@@ -157,8 +183,8 @@ class MatrixHandler(BaseMatrixHandler):
         elif evt.type == EventType.ROOM_AVATAR:
             await portal.handle_matrix_avatar(user, evt.content.url)
 
-    async def allow_message(self, user: 'u.User') -> bool:
+    async def allow_message(self, user: "u.User") -> bool:
         return user.relay_whitelisted
 
-    async def allow_bridging_message(self, user: 'u.User', portal: 'po.Portal') -> bool:
+    async def allow_bridging_message(self, user: "u.User", portal: "po.Portal") -> bool:
         return portal.has_relay or await user.is_logged_in()

Dosya farkı çok büyük olduğundan ihmal edildi
+ 374 - 199
mautrix_signal/portal.py


+ 88 - 38
mautrix_signal/puppet.py

@@ -13,8 +13,17 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import (Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, Union, Tuple,
-                    TYPE_CHECKING, cast)
+from typing import (
+    Optional,
+    Dict,
+    AsyncIterable,
+    Awaitable,
+    AsyncGenerator,
+    Union,
+    Tuple,
+    TYPE_CHECKING,
+    cast,
+)
 from uuid import UUID
 import hashlib
 import asyncio
@@ -25,8 +34,14 @@ from yarl import URL
 from mausignald.types import Address, Contact, Profile
 from mautrix.bridge import BasePuppet, async_getter_lock
 from mautrix.appservice import IntentAPI
-from mautrix.types import (UserID, SyncToken, RoomID, ContentURI, EventType,
-                           PowerLevelStateEventContent)
+from mautrix.types import (
+    UserID,
+    SyncToken,
+    RoomID,
+    ContentURI,
+    EventType,
+    PowerLevelStateEventContent,
+)
 from mautrix.errors import MForbidden
 from mautrix.util.simple_template import SimpleTemplate
 
@@ -44,9 +59,9 @@ except ImportError:
 
 
 class Puppet(DBPuppet, BasePuppet):
-    by_uuid: Dict[UUID, 'Puppet'] = {}
-    by_number: Dict[str, 'Puppet'] = {}
-    by_custom_mxid: Dict[UserID, 'Puppet'] = {}
+    by_uuid: Dict[UUID, "Puppet"] = {}
+    by_number: Dict[str, "Puppet"] = {}
+    by_custom_mxid: Dict[UserID, "Puppet"] = {}
     hs_domain: str
     mxid_template: SimpleTemplate[str]
 
@@ -58,17 +73,37 @@ class Puppet(DBPuppet, BasePuppet):
     _uuid_lock: asyncio.Lock
     _update_info_lock: asyncio.Lock
 
-    def __init__(self, uuid: Optional[UUID], number: Optional[str], name: Optional[str] = None,
-                 avatar_url: Optional[ContentURI] = None, avatar_hash: Optional[str] = None,
-                 name_set: bool = False, avatar_set: bool = False, uuid_registered: bool = False,
-                 number_registered: bool = False, custom_mxid: Optional[UserID] = None,
-                 access_token: Optional[str] = None, next_batch: Optional[SyncToken] = None,
-                 base_url: Optional[URL] = None) -> None:
-        super().__init__(uuid=uuid, number=number, name=name, avatar_url=avatar_url,
-                         avatar_hash=avatar_hash, name_set=name_set, avatar_set=avatar_set,
-                         uuid_registered=uuid_registered, number_registered=number_registered,
-                         custom_mxid=custom_mxid, access_token=access_token, next_batch=next_batch,
-                         base_url=base_url)
+    def __init__(
+        self,
+        uuid: Optional[UUID],
+        number: Optional[str],
+        name: Optional[str] = None,
+        avatar_url: Optional[ContentURI] = None,
+        avatar_hash: Optional[str] = None,
+        name_set: bool = False,
+        avatar_set: bool = False,
+        uuid_registered: bool = False,
+        number_registered: bool = False,
+        custom_mxid: Optional[UserID] = None,
+        access_token: Optional[str] = None,
+        next_batch: Optional[SyncToken] = None,
+        base_url: Optional[URL] = None,
+    ) -> None:
+        super().__init__(
+            uuid=uuid,
+            number=number,
+            name=name,
+            avatar_url=avatar_url,
+            avatar_hash=avatar_hash,
+            name_set=name_set,
+            avatar_set=avatar_set,
+            uuid_registered=uuid_registered,
+            number_registered=number_registered,
+            custom_mxid=custom_mxid,
+            access_token=access_token,
+            next_batch=next_batch,
+            base_url=base_url,
+        )
         self.log = self.log.getChild(str(uuid) if uuid else number)
 
         self.default_mxid = self.get_mxid_from_id(self.address)
@@ -79,25 +114,34 @@ class Puppet(DBPuppet, BasePuppet):
         self._update_info_lock = asyncio.Lock()
 
     @classmethod
-    def init_cls(cls, bridge: 'SignalBridge') -> AsyncIterable[Awaitable[None]]:
+    def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
         cls.config = bridge.config
         cls.loop = bridge.loop
         cls.mx = bridge.matrix
         cls.az = bridge.az
         cls.hs_domain = cls.config["homeserver.domain"]
-        cls.mxid_template = SimpleTemplate(cls.config["bridge.username_template"], "userid",
-                                           prefix="@", suffix=f":{cls.hs_domain}", type=str)
+        cls.mxid_template = SimpleTemplate(
+            cls.config["bridge.username_template"],
+            "userid",
+            prefix="@",
+            suffix=f":{cls.hs_domain}",
+            type=str,
+        )
         cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
 
-        cls.homeserver_url_map = {server: URL(url) for server, url
-                                  in cls.config["bridge.double_puppet_server_map"].items()}
+        cls.homeserver_url_map = {
+            server: URL(url)
+            for server, url in cls.config["bridge.double_puppet_server_map"].items()
+        }
         cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
-        cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
-                                       in cls.config["bridge.login_shared_secret_map"].items()}
+        cls.login_shared_secret_map = {
+            server: secret.encode("utf-8")
+            for server, secret in cls.config["bridge.login_shared_secret_map"].items()
+        }
         cls.login_device_name = "Signal Bridge"
         return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
 
-    def intent_for(self, portal: 'p.Portal') -> IntentAPI:
+    def intent_for(self, portal: "p.Portal") -> IntentAPI:
         if portal.chat_id == self.address:
             return self.default_mxid_intent
         return self.intent
@@ -167,8 +211,10 @@ class Puppet(DBPuppet, BasePuppet):
         try:
             joined_rooms = await prev_intent.get_joined_rooms()
         except MForbidden as e:
-            self.log.debug(f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
-                           "assuming there are no rooms to rejoin")
+            self.log.debug(
+                f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
+                "assuming there are no rooms to rejoin"
+            )
             return
         for room_id in joined_rooms:
             await prev_intent.invite_user(room_id, self.default_mxid)
@@ -176,8 +222,9 @@ class Puppet(DBPuppet, BasePuppet):
             await prev_intent.leave_room(room_id)
             await new_intent.join_room_by_id(room_id)
 
-    async def _migrate_powers(self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
-                              ) -> None:
+    async def _migrate_powers(
+        self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
+    ) -> None:
         try:
             powers: PowerLevelStateEventContent
             powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
@@ -260,8 +307,11 @@ class Puppet(DBPuppet, BasePuppet):
         return False
 
     @staticmethod
-    async def upload_avatar(self: Union['Puppet', 'p.Portal'], path: str, intent: IntentAPI,
-                            ) -> Union[bool, Tuple[str, ContentURI]]:
+    async def upload_avatar(
+        self: Union["Puppet", "p.Portal"],
+        path: str,
+        intent: IntentAPI,
+    ) -> Union[bool, Tuple[str, ContentURI]]:
         if not path:
             return False
         if not path.startswith("/"):
@@ -321,7 +371,7 @@ class Puppet(DBPuppet, BasePuppet):
         await self.update()
 
     @classmethod
-    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
+    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional["Puppet"]:
         address = cls.get_id_from_mxid(mxid)
         if not address:
             return None
@@ -329,7 +379,7 @@ class Puppet(DBPuppet, BasePuppet):
 
     @classmethod
     @async_getter_lock
-    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
+    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional["Puppet"]:
         try:
             return cls.by_custom_mxid[mxid]
         except KeyError:
@@ -348,7 +398,7 @@ class Puppet(DBPuppet, BasePuppet):
         if not identifier:
             return None
         if identifier.startswith("phone_"):
-            return Address(number="+" + identifier[len("phone_"):])
+            return Address(number="+" + identifier[len("phone_") :])
         else:
             try:
                 return Address(uuid=UUID(identifier.upper()))
@@ -367,7 +417,7 @@ class Puppet(DBPuppet, BasePuppet):
 
     @classmethod
     @async_getter_lock
-    async def get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
+    async def get_by_address(cls, address: Address, create: bool = True) -> Optional["Puppet"]:
         puppet = await cls._get_by_address(address, create)
         if puppet and address.uuid and not puppet.uuid:
             # We found a UUID for this user, store it ASAP
@@ -375,7 +425,7 @@ class Puppet(DBPuppet, BasePuppet):
         return puppet
 
     @classmethod
-    async def _get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
+    async def _get_by_address(cls, address: Address, create: bool = True) -> Optional["Puppet"]:
         if not address.is_valid:
             raise ValueError("Empty address")
         if address.uuid:
@@ -403,7 +453,7 @@ class Puppet(DBPuppet, BasePuppet):
         return None
 
     @classmethod
-    async def all_with_custom_mxid(cls) -> AsyncGenerator['Puppet', None]:
+    async def all_with_custom_mxid(cls) -> AsyncGenerator["Puppet", None]:
         puppets = await super().all_with_custom_mxid()
         puppet: cls
         for index, puppet in enumerate(puppets):

+ 52 - 25
mautrix_signal/signal.py

@@ -18,9 +18,17 @@ import asyncio
 import logging
 
 from mausignald import SignaldClient
-from mausignald.types import (Message, MessageData, Address, TypingNotification, TypingAction,
-                              OwnReadReceipt, Receipt, ReceiptType,
-                              WebsocketConnectionStateChangeEvent)
+from mausignald.types import (
+    Message,
+    MessageData,
+    Address,
+    TypingNotification,
+    TypingAction,
+    OwnReadReceipt,
+    Receipt,
+    ReceiptType,
+    WebsocketConnectionStateChangeEvent,
+)
 from mautrix.util.logging import TraceLogger
 
 from .db import Message as DBMessage
@@ -39,13 +47,14 @@ class SignalHandler(SignaldClient):
     data_dir: str
     delete_unknown_accounts: bool
 
-    def __init__(self, bridge: 'SignalBridge') -> None:
+    def __init__(self, bridge: "SignalBridge") -> None:
         super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
         self.data_dir = bridge.config["signal.data_dir"]
         self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
         self.add_event_handler(Message, self.on_message)
-        self.add_event_handler(WebsocketConnectionStateChangeEvent,
-                               self.on_websocket_connection_state_change)
+        self.add_event_handler(
+            WebsocketConnectionStateChangeEvent, self.on_websocket_connection_state_change
+        )
 
     async def on_message(self, evt: Message) -> None:
         sender = await pu.Puppet.get_by_address(evt.source)
@@ -62,8 +71,12 @@ class SignalHandler(SignaldClient):
             if evt.sync_message.read_messages:
                 await self.handle_own_receipts(sender, evt.sync_message.read_messages)
             if evt.sync_message.sent:
-                await self.handle_message(user, sender, evt.sync_message.sent.message,
-                                          addr_override=evt.sync_message.sent.destination)
+                await self.handle_message(
+                    user,
+                    sender,
+                    evt.sync_message.sent.message,
+                    addr_override=evt.sync_message.sent.destination,
+                )
             if evt.sync_message.typing:
                 # Typing notification from own device
                 pass
@@ -75,12 +88,19 @@ class SignalHandler(SignaldClient):
                 await user.sync_groups()
 
     @staticmethod
-    async def on_websocket_connection_state_change(evt: WebsocketConnectionStateChangeEvent) -> None:
+    async def on_websocket_connection_state_change(
+        evt: WebsocketConnectionStateChangeEvent,
+    ) -> None:
         user = await u.User.get_by_username(evt.account)
         user.on_websocket_connection_state_change(evt)
 
-    async def handle_message(self, user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
-                             addr_override: Optional[Address] = None) -> None:
+    async def handle_message(
+        self,
+        user: "u.User",
+        sender: "pu.Puppet",
+        msg: MessageData,
+        addr_override: Optional[Address] = None,
+    ) -> None:
         if msg.profile_key_update:
             self.log.debug("Ignoring profile key update")
             return
@@ -89,18 +109,23 @@ class SignalHandler(SignaldClient):
         elif msg.group:
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
         else:
-            portal = await po.Portal.get_by_chat_id(addr_override or sender.address,
-                                                    receiver=user.username, create=True)
+            portal = await po.Portal.get_by_chat_id(
+                addr_override or sender.address, receiver=user.username, create=True
+            )
             if addr_override and not sender.is_real_user:
-                portal.log.debug(f"Ignoring own message {msg.timestamp} as user doesn't have"
-                                 " double puppeting enabled")
+                portal.log.debug(
+                    f"Ignoring own message {msg.timestamp} as user doesn't have double puppeting "
+                    "enabled"
+                )
                 return
         if not portal.mxid:
-            await portal.create_matrix_room(user, (msg.group_v2 or msg.group
-                                                   or addr_override or sender.address))
+            await portal.create_matrix_room(
+                user, (msg.group_v2 or msg.group or addr_override or sender.address)
+            )
             if not portal.mxid:
-                user.log.debug(f"Failed to create room for incoming message {msg.timestamp},"
-                               " dropping message")
+                user.log.debug(
+                    f"Failed to create room for incoming message {msg.timestamp}, dropping message"
+                )
                 return
         elif msg.group_v2 and msg.group_v2.revision > portal.revision:
             self.log.debug(f"Got new revision of {msg.group_v2.id}, updating info")
@@ -117,7 +142,7 @@ class SignalHandler(SignaldClient):
             await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
 
     @staticmethod
-    async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None:
+    async def handle_own_receipts(sender: "pu.Puppet", receipts: List[OwnReadReceipt]) -> None:
         for receipt in receipts:
             puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
             if not puppet:
@@ -131,8 +156,9 @@ class SignalHandler(SignaldClient):
             await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
 
     @staticmethod
-    async def handle_typing(user: 'u.User', sender: 'pu.Puppet',
-                            typing: TypingNotification) -> None:
+    async def handle_typing(
+        user: "u.User", sender: "pu.Puppet", typing: TypingNotification
+    ) -> None:
         if typing.group_id:
             portal = await po.Portal.get_by_chat_id(typing.group_id)
         else:
@@ -140,11 +166,12 @@ class SignalHandler(SignaldClient):
         if not portal or not portal.mxid:
             return
         is_typing = typing.action == TypingAction.STARTED
-        await sender.intent_for(portal).set_typing(portal.mxid, is_typing, ignore_cache=True,
-                                                   timeout=SIGNAL_TYPING_TIMEOUT)
+        await sender.intent_for(portal).set_typing(
+            portal.mxid, is_typing, ignore_cache=True, timeout=SIGNAL_TYPING_TIMEOUT
+        )
 
     @staticmethod
-    async def handle_receipt(sender: 'pu.Puppet', receipt: Receipt) -> None:
+    async def handle_receipt(sender: "pu.Puppet", receipt: Receipt) -> None:
         if receipt.type != ReceiptType.READ:
             return
         messages = await DBMessage.find_by_timestamps(receipt.timestamps)

+ 48 - 29
mautrix_signal/user.py

@@ -19,8 +19,15 @@ from typing import Union, Dict, Optional, AsyncGenerator, List, TYPE_CHECKING, c
 from uuid import UUID
 import asyncio
 
-from mausignald.types import (Account, Address, Profile, Group, GroupV2, WebsocketConnectionState,
-                              WebsocketConnectionStateChangeEvent)
+from mausignald.types import (
+    Account,
+    Address,
+    Profile,
+    Group,
+    GroupV2,
+    WebsocketConnectionState,
+    WebsocketConnectionStateChangeEvent,
+)
 from mautrix.bridge import BaseUser, AutologinError, async_getter_lock
 from mautrix.types import UserID, RoomID
 from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
@@ -34,23 +41,25 @@ from . import puppet as pu, portal as po
 if TYPE_CHECKING:
     from .__main__ import SignalBridge
 
-METRIC_CONNECTED = Gauge('bridge_connected', 'Bridge users connected to Signal')
-METRIC_LOGGED_IN = Gauge('bridge_logged_in', 'Bridge users logged into Signal')
+METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to Signal")
+METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Bridge users logged into Signal")
 
-BridgeState.human_readable_errors.update({
-    "logged-out": "You're not logged into Signal",
-    "signal-not-connected": None,
-})
+BridgeState.human_readable_errors.update(
+    {
+        "logged-out": "You're not logged into Signal",
+        "signal-not-connected": None,
+    }
+)
 
 
 class User(DBUser, BaseUser):
-    by_mxid: Dict[UserID, 'User'] = {}
-    by_username: Dict[str, 'User'] = {}
-    by_uuid: Dict[UUID, 'User'] = {}
+    by_mxid: Dict[UserID, "User"] = {}
+    by_username: Dict[str, "User"] = {}
+    by_uuid: Dict[UUID, "User"] = {}
     config: Config
     az: AppService
     loop: asyncio.AbstractEventLoop
-    bridge: 'SignalBridge'
+    bridge: "SignalBridge"
 
     relay_whitelisted: bool
     is_admin: bool
@@ -62,8 +71,13 @@ class User(DBUser, BaseUser):
     _websocket_connection_state: Optional[WebsocketConnectionState]
     _latest_non_transient_disconnect_state: Optional[datetime]
 
-    def __init__(self, mxid: UserID, username: Optional[str] = None, uuid: Optional[UUID] = None,
-                 notice_room: Optional[RoomID] = None) -> None:
+    def __init__(
+        self,
+        mxid: UserID,
+        username: Optional[str] = None,
+        uuid: Optional[UUID] = None,
+        notice_room: Optional[RoomID] = None,
+    ) -> None:
         super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
         BaseUser.__init__(self)
         self._notice_room_lock = asyncio.Lock()
@@ -74,7 +88,7 @@ class User(DBUser, BaseUser):
         self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
 
     @classmethod
-    def init_cls(cls, bridge: 'SignalBridge') -> None:
+    def init_cls(cls, bridge: "SignalBridge") -> None:
         cls.bridge = bridge
         cls.config = bridge.config
         cls.az = bridge.az
@@ -125,7 +139,7 @@ class User(DBUser, BaseUser):
             state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
         return [state]
 
-    async def get_puppet(self) -> Optional['pu.Puppet']:
+    async def get_puppet(self) -> Optional["pu.Puppet"]:
         if not self.address:
             return None
         return await pu.Puppet.get_by_address(self.address)
@@ -139,7 +153,9 @@ class User(DBUser, BaseUser):
         asyncio.create_task(self.sync())
         self._track_metric(METRIC_LOGGED_IN, True)
 
-    def on_websocket_connection_state_change(self, evt: WebsocketConnectionStateChangeEvent) -> None:
+    def on_websocket_connection_state_change(
+        self, evt: WebsocketConnectionStateChangeEvent
+    ) -> None:
         if evt.state == WebsocketConnectionState.CONNECTED:
             self.log.info("Connected to Signal")
             self._track_metric(METRIC_CONNECTED, True)
@@ -147,14 +163,14 @@ class User(DBUser, BaseUser):
             self._connected = True
         else:
             self.log.warning(
-                f"New websocket state from signald: {evt.state}. Error: {evt.exception}")
+                f"New websocket state from signald: {evt.state}. Error: {evt.exception}"
+            )
             self._track_metric(METRIC_CONNECTED, False)
             self._connected = False
 
         bridge_state = {
             # Signald disconnected
             WebsocketConnectionState.SOCKET_DISCONNECTED: BridgeStateEvent.TRANSIENT_DISCONNECT,
-
             # Websocket state reported by signald
             WebsocketConnectionState.DISCONNECTED: (
                 None
@@ -174,6 +190,7 @@ class User(DBUser, BaseUser):
 
         now = datetime.now()
         if bridge_state == BridgeStateEvent.TRANSIENT_DISCONNECT:
+
             async def wait_report_transient_disconnect():
                 # Wait for 10 seconds (that should be enough for the bridge to get connected)
                 # before sending a TRANSIENT_DISCONNECT.
@@ -212,7 +229,7 @@ class User(DBUser, BaseUser):
             self.uuid = puppet.uuid
             self.by_uuid[self.uuid] = self
         if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
-            self.log.info(f"Automatically enabling custom puppet")
+            self.log.info("Automatically enabling custom puppet")
             try:
                 await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
             except AutologinError as e:
@@ -244,8 +261,9 @@ class User(DBUser, BaseUser):
         except Exception:
             self.log.exception("Error while syncing groups")
 
-    async def sync_contact(self, contact: Union[Profile, Address], create_portals: bool = False
-                           ) -> None:
+    async def sync_contact(
+        self, contact: Union[Profile, Address], create_portals: bool = False
+    ) -> None:
         self.log.trace("Syncing contact %s", contact)
         if isinstance(contact, Address):
             address = contact
@@ -258,8 +276,9 @@ class User(DBUser, BaseUser):
         puppet = await pu.Puppet.get_by_address(address)
         await puppet.update_info(profile)
         if create_portals:
-            portal = await po.Portal.get_by_chat_id(puppet.address, receiver=self.username,
-                                                    create=True)
+            portal = await po.Portal.get_by_chat_id(
+                puppet.address, receiver=self.username, create=True
+            )
             await portal.create_matrix_room(self, profile)
 
     async def _sync_group(self, group: Group, create_portals: bool) -> None:
@@ -311,7 +330,7 @@ class User(DBUser, BaseUser):
 
     @classmethod
     @async_getter_lock
-    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
+    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional["User"]:
         # Never allow ghosts to be users
         if pu.Puppet.get_id_from_mxid(mxid):
             return None
@@ -335,7 +354,7 @@ class User(DBUser, BaseUser):
 
     @classmethod
     @async_getter_lock
-    async def get_by_username(cls, username: str) -> Optional['User']:
+    async def get_by_username(cls, username: str) -> Optional["User"]:
         try:
             return cls.by_username[username]
         except KeyError:
@@ -350,7 +369,7 @@ class User(DBUser, BaseUser):
 
     @classmethod
     @async_getter_lock
-    async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
+    async def get_by_uuid(cls, uuid: UUID) -> Optional["User"]:
         try:
             return cls.by_uuid[uuid]
         except KeyError:
@@ -364,7 +383,7 @@ class User(DBUser, BaseUser):
         return None
 
     @classmethod
-    async def get_by_address(cls, address: Address) -> Optional['User']:
+    async def get_by_address(cls, address: Address) -> Optional["User"]:
         if address.uuid:
             return await cls.get_by_uuid(address.uuid)
         elif address.number:
@@ -373,7 +392,7 @@ class User(DBUser, BaseUser):
             raise ValueError("Given address is blank")
 
     @classmethod
-    async def all_logged_in(cls) -> AsyncGenerator['User', None]:
+    async def all_logged_in(cls) -> AsyncGenerator["User", None]:
         users = await super().all_logged_in()
         user: cls
         for user in users:

+ 43 - 30
mautrix_signal/web/provisioning_api.py

@@ -34,9 +34,9 @@ if TYPE_CHECKING:
 class ProvisioningAPI:
     log: TraceLogger = logging.getLogger("mau.web.provisioning")
     app: web.Application
-    bridge: 'SignalBridge'
+    bridge: "SignalBridge"
 
-    def __init__(self, bridge: 'SignalBridge', shared_secret: str) -> None:
+    def __init__(self, bridge: "SignalBridge", shared_secret: str) -> None:
         self.bridge = bridge
         self.app = web.Application()
         self.shared_secret = shared_secret
@@ -70,23 +70,26 @@ class ProvisioningAPI:
     async def login_options(self, _: web.Request) -> web.Response:
         return web.Response(status=200, headers=self._headers)
 
-    async def check_token(self, request: web.Request) -> 'u.User':
+    async def check_token(self, request: web.Request) -> "u.User":
         try:
             token = request.headers["Authorization"]
-            token = token[len("Bearer "):]
+            token = token[len("Bearer ") :]
         except KeyError:
-            raise web.HTTPBadRequest(text='{"error": "Missing Authorization header"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Missing Authorization header"}', headers=self._headers
+            )
         except IndexError:
-            raise web.HTTPBadRequest(text='{"error": "Malformed Authorization header"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Malformed Authorization header"}', headers=self._headers
+            )
         if token != self.shared_secret:
             raise web.HTTPForbidden(text='{"error": "Invalid token"}', headers=self._headers)
         try:
             user_id = request.query["user_id"]
         except KeyError:
-            raise web.HTTPBadRequest(text='{"error": "Missing user_id query param"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Missing user_id query param"}', headers=self._headers
+            )
 
         if not self.bridge.signal.is_connected:
             await self.bridge.signal.wait_for_connected()
@@ -103,7 +106,8 @@ class ProvisioningAPI:
         if await user.is_logged_in():
             try:
                 profile = await self.bridge.signal.get_profile(
-                    username=user.username, address=Address(number=user.username))
+                    username=user.username, address=Address(number=user.username)
+                )
             except Exception as e:
                 self.log.exception(f"Failed to get {user.username}'s profile for whoami")
                 data["signal"] = {
@@ -127,8 +131,9 @@ class ProvisioningAPI:
         user = await self.check_token(request)
 
         if await user.is_logged_in():
-            raise web.HTTPConflict(text='''{"error": "You're already logged in"}''',
-                                   headers=self._headers)
+            raise web.HTTPConflict(
+                text="""{"error": "You're already logged in"}""", headers=self._headers
+            )
 
         try:
             data = await request.json()
@@ -147,17 +152,19 @@ class ProvisioningAPI:
         self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
         return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
 
-    async def _shielded_link(self, user: 'u.User', session_id: str, device_name: str) -> Account:
+    async def _shielded_link(self, user: "u.User", session_id: str, device_name: str) -> Account:
         try:
             self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
-            account = await self.bridge.signal.finish_link(session_id=session_id, overwrite=True,
-                                                           device_name=device_name)
+            account = await self.bridge.signal.finish_link(
+                session_id=session_id, overwrite=True, device_name=device_name
+            )
         except TimeoutException:
             self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
             raise
         except Exception:
-            self.log.exception("Fatal error while waiting for linking to finish "
-                               f"(session {session_id})")
+            self.log.exception(
+                "Fatal error while waiting for linking to finish " f"(session {session_id})"
+            )
             raise
         else:
             await user.on_signin(account)
@@ -166,36 +173,42 @@ class ProvisioningAPI:
     async def link_wait(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         if not user.command_status or user.command_status["action"] != "Link":
-            raise web.HTTPBadRequest(text='{"error": "No Signal linking started"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "No Signal linking started"}', headers=self._headers
+            )
         session_id = user.command_status["session_id"]
         device_name = user.command_status["device_name"]
         try:
             account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
         except asyncio.CancelledError:
-            self.log.warning(f"Client cancelled link wait request ({session_id})"
-                             " before it finished")
+            self.log.warning(
+                f"Client cancelled link wait request ({session_id})" " before it finished"
+            )
         except TimeoutException:
-            raise web.HTTPBadRequest(text='{"error": "Signal linking timed out"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Signal linking timed out"}', headers=self._headers
+            )
         except InternalError as ie:
             if "java.io.IOException" in ie.exceptions:
                 raise web.HTTPBadRequest(
                     text='{"error": "Signald websocket disconnected before linking finished"}',
                     headers=self._headers,
                 )
-            raise web.HTTPInternalServerError(text='{"error": "Fatal error in Signal linking"}',
-                                              headers=self._headers)
+            raise web.HTTPInternalServerError(
+                text='{"error": "Fatal error in Signal linking"}', headers=self._headers
+            )
         except Exception:
-            raise web.HTTPInternalServerError(text='{"error": "Fatal error in Signal linking"}',
-                                              headers=self._headers)
+            raise web.HTTPInternalServerError(
+                text='{"error": "Fatal error in Signal linking"}', headers=self._headers
+            )
         else:
             return web.json_response(account.address.serialize())
 
     async def logout(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         if not await user.is_logged_in():
-            raise web.HTTPNotFound(text='''{"error": "You're not logged in"}''',
-                                   headers=self._headers)
+            raise web.HTTPNotFound(
+                text="""{"error": "You're not logged in"}""", headers=self._headers
+            )
         await user.logout()
         return web.json_response({}, headers=self._acao_headers)

+ 10 - 0
pyproject.toml

@@ -0,0 +1,10 @@
+[tool.isort]
+profile = "black"
+force_to_top = "typing"
+from_first = true
+line_length = 99
+
+[tool.black]
+line-length = 99
+target-version = ["py38"]
+required-version = "21.12b0"

+ 5 - 0
setup.cfg

@@ -0,0 +1,5 @@
+[flake8]
+max-line-length = 99
+extend-ignore =
+    # See https://github.com/PyCQA/pycodestyle/issues/373
+    E203,

Bu fark içinde çok fazla dosya değişikliği olduğu için bazı dosyalar gösterilmiyor