Просмотр исходного кода

Merge remote-tracking branch 'origin/blacken-everything'

Tulir Asokan 3 лет назад
Родитель
Сommit
414904fd30

+ 1 - 1
.editorconfig

@@ -17,5 +17,5 @@ trim_trailing_whitespace = false
 [*.{yaml,yml,py,md}]
 [*.{yaml,yml,py,md}]
 indent_style = space
 indent_style = space
 
 
-[{.gitlab-ci.yml,*.md}]
+[{.gitlab-ci.yml,*.md,.github/workflows/*.yml}]
 indent_size = 2
 indent_size = 2

+ 18 - 0
.github/workflows/python-lint.yml

@@ -0,0 +1,18 @@
+name: Python lint
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+    - uses: actions/checkout@v2
+    - uses: actions/setup-python@v2
+      with:
+        python-version: "3.10"
+    - uses: isort/isort-action@master
+      with:
+        sortPaths: "./mausignald ./mautrix_signal"
+    - uses: psf/black@21.12b0
+      with:
+        src: "./mausignald ./mautrix_signal"

+ 6 - 2
mausignald/errors.py

@@ -26,8 +26,12 @@ class NotConnected(RPCError):
 
 
 
 
 class ResponseError(RPCError):
 class ResponseError(RPCError):
-    def __init__(self, data: Dict[str, Any], error_type: Optional[str] = None,
-                 message_override: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        data: Dict[str, Any],
+        error_type: Optional[str] = None,
+        message_override: Optional[str] = None,
+    ) -> None:
         self.data = data
         self.data = data
         msg = message_override or data["message"]
         msg = message_override or data["message"]
         if error_type:
         if error_type:

+ 18 - 10
mausignald/rpc.py

@@ -3,11 +3,11 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Optional, Dict, List, Callable, Awaitable, Any, Tuple
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
 from uuid import UUID, uuid4
 from uuid import UUID, uuid4
 import asyncio
 import asyncio
-import logging
 import json
 import json
+import logging
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
@@ -36,8 +36,12 @@ class SignaldRPCClient:
     _response_waiters: Dict[UUID, asyncio.Future]
     _response_waiters: Dict[UUID, asyncio.Future]
     _rpc_event_handlers: Dict[str, List[EventHandler]]
     _rpc_event_handlers: Dict[str, List[EventHandler]]
 
 
-    def __init__(self, socket_path: str, log: Optional[TraceLogger] = None,
-                 loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+    def __init__(
+        self,
+        socket_path: str,
+        log: Optional[TraceLogger] = None,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ) -> None:
         self.socket_path = socket_path
         self.socket_path = socket_path
         self.log = log or logging.getLogger("mausignald")
         self.log = log or logging.getLogger("mausignald")
         self.loop = loop or asyncio.get_event_loop()
         self.loop = loop or asyncio.get_event_loop()
@@ -67,7 +71,8 @@ class SignaldRPCClient:
         while True:
         while True:
             try:
             try:
                 self._reader, self._writer = await asyncio.open_unix_connection(
                 self._reader, self._writer = await asyncio.open_unix_connection(
-                    self.socket_path, limit=_SOCKET_LIMIT)
+                    self.socket_path, limit=_SOCKET_LIMIT
+                )
             except OSError as e:
             except OSError as e:
                 self.log.error(f"Connection to {self.socket_path} failed: {e}")
                 self.log.error(f"Connection to {self.socket_path} failed: {e}")
                 await asyncio.sleep(5)
                 await asyncio.sleep(5)
@@ -177,8 +182,9 @@ class SignaldRPCClient:
         self._reader = None
         self._reader = None
         self._writer = None
         self._writer = None
 
 
-    def _create_request(self, command: str, req_id: Optional[UUID] = None, **data: Any
-                        ) -> Tuple[asyncio.Future, Dict[str, Any]]:
+    def _create_request(
+        self, command: str, req_id: Optional[UUID] = None, **data: Any
+    ) -> Tuple[asyncio.Future, Dict[str, Any]]:
         req_id = req_id or uuid4()
         req_id = req_id or uuid4()
         req = {"id": str(req_id), "type": command, **data}
         req = {"id": str(req_id), "type": command, **data}
         self.log.trace("Request %s: %s %s", req_id, command, data)
         self.log.trace("Request %s: %s %s", req_id, command, data)
@@ -196,7 +202,8 @@ class SignaldRPCClient:
             if not waiter.done():
             if not waiter.done():
                 self.log.trace(f"Abandoning response for {req_id}")
                 self.log.trace(f"Abandoning response for {req_id}")
                 waiter.set_exception(
                 waiter.set_exception(
-                    NotConnected("Disconnected from signald before RPC completed"))
+                    NotConnected("Disconnected from signald before RPC completed")
+                )
 
 
     async def _send_request(self, data: Dict[str, Any]) -> None:
     async def _send_request(self, data: Dict[str, Any]) -> None:
         if self._writer is None:
         if self._writer is None:
@@ -207,8 +214,9 @@ class SignaldRPCClient:
         await self._writer.drain()
         await self._writer.drain()
         self.log.trace("Sent data to server server: %s", data)
         self.log.trace("Sent data to server server: %s", data)
 
 
-    async def _raw_request(self, command: str, req_id: Optional[UUID] = None, **data: Any
-                           ) -> Tuple[str, Dict[str, Any]]:
+    async def _raw_request(
+        self, command: str, req_id: Optional[UUID] = None, **data: Any
+    ) -> Tuple[str, Dict[str, Any]]:
         future, data = self._create_request(command, req_id, **data)
         future, data = self._create_request(command, req_id, **data)
         await self._send_request(data)
         await self._send_request(data)
         return await asyncio.shield(future)
         return await asyncio.shield(future)

+ 158 - 79
mausignald/signald.py

@@ -3,18 +3,33 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Type, TypeVar, Union
 import asyncio
 import asyncio
 
 
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
-from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse
 from .errors import UnexpectedError, UnexpectedResponse
-from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
-                    Profile, GroupID, GetIdentitiesResponse, GroupV2, Mention, LinkSession,
-                    WebsocketConnectionState, WebsocketConnectionStateChangeEvent)
-
-T = TypeVar('T')
+from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
+from .types import (
+    Account,
+    Address,
+    Attachment,
+    DeviceInfo,
+    GetIdentitiesResponse,
+    Group,
+    GroupID,
+    GroupV2,
+    LinkSession,
+    Mention,
+    Message,
+    Profile,
+    Quote,
+    Reaction,
+    WebsocketConnectionState,
+    WebsocketConnectionStateChangeEvent,
+)
+
+T = TypeVar("T")
 EventHandler = Callable[[T], Awaitable[None]]
 EventHandler = Callable[[T], Awaitable[None]]
 
 
 
 
@@ -22,15 +37,19 @@ class SignaldClient(SignaldRPCClient):
     _event_handlers: Dict[Type[T], List[EventHandler]]
     _event_handlers: Dict[Type[T], List[EventHandler]]
     _subscriptions: Set[str]
     _subscriptions: Set[str]
 
 
-    def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
-                 log: Optional[TraceLogger] = None,
-                 loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+    def __init__(
+        self,
+        socket_path: str = "/var/run/signald/signald.sock",
+        log: Optional[TraceLogger] = None,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ) -> None:
         super().__init__(socket_path, log, loop)
         super().__init__(socket_path, log, loop)
         self._event_handlers = {}
         self._event_handlers = {}
         self._subscriptions = set()
         self._subscriptions = set()
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("message", self._parse_message)
-        self.add_rpc_handler("websocket_connection_state_change",
-                             self._websocket_connection_state_change)
+        self.add_rpc_handler(
+            "websocket_connection_state_change", self._websocket_connection_state_change
+        )
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
         self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
         self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
@@ -111,12 +130,13 @@ class SignaldClient(SignaldRPCClient):
                 evt = WebsocketConnectionStateChangeEvent(
                 evt = WebsocketConnectionStateChangeEvent(
                     state=WebsocketConnectionState.SOCKET_DISCONNECTED,
                     state=WebsocketConnectionState.SOCKET_DISCONNECTED,
                     account=username,
                     account=username,
-                    exception="Disconnected from signald"
+                    exception="Disconnected from signald",
                 )
                 )
                 await self._run_event_handler(evt)
                 await self._run_event_handler(evt)
 
 
-    async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
-                       ) -> str:
+    async def register(
+        self, phone: str, voice: bool = False, captcha: Optional[str] = None
+    ) -> str:
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
         return resp["account_id"]
         return resp["account_id"]
 
 
@@ -127,15 +147,18 @@ class SignaldClient(SignaldRPCClient):
     async def start_link(self) -> LinkSession:
     async def start_link(self) -> LinkSession:
         return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
         return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
 
 
-    async def finish_link(self, session_id: str, device_name: str = "mausignald",
-                          overwrite: bool = False) -> Account:
-        resp = await self.request_v1("finish_link", device_name=device_name, session_id=session_id,
-                                     overwrite=overwrite)
+    async def finish_link(
+        self, session_id: str, device_name: str = "mausignald", overwrite: bool = False
+    ) -> Account:
+        resp = await self.request_v1(
+            "finish_link", device_name=device_name, session_id=session_id, overwrite=overwrite
+        )
         return Account.deserialize(resp)
         return Account.deserialize(resp)
 
 
     @staticmethod
     @staticmethod
-    def _recipient_to_args(recipient: Union[Address, GroupID], simple_name: bool = False
-                           ) -> Dict[str, Any]:
+    def _recipient_to_args(
+        recipient: Union[Address, GroupID], simple_name: bool = False
+    ) -> Dict[str, Any]:
         if isinstance(recipient, Address):
         if isinstance(recipient, Address):
             recipient = recipient.serialize()
             recipient = recipient.serialize()
             field_name = "address" if simple_name else "recipientAddress"
             field_name = "address" if simple_name else "recipientAddress"
@@ -143,27 +166,49 @@ class SignaldClient(SignaldRPCClient):
             field_name = "group" if simple_name else "recipientGroupId"
             field_name = "group" if simple_name else "recipientGroupId"
         return {field_name: recipient}
         return {field_name: recipient}
 
 
-    async def react(self, username: str, recipient: Union[Address, GroupID],
-                    reaction: Reaction) -> None:
-        await self.request_v1("react", username=username, reaction=reaction.serialize(),
-                              **self._recipient_to_args(recipient))
-
-    async def remote_delete(self, username: str, recipient: Union[Address, GroupID], timestamp: int
-                            ) -> None:
-        await self.request_v1("remote_delete", account=username, timestamp=timestamp,
-                              **self._recipient_to_args(recipient, simple_name=True))
-
-    async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
-                   quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
-                   mentions: Optional[List[Mention]] = None, timestamp: Optional[int] = None
-                   ) -> None:
+    async def react(
+        self, username: str, recipient: Union[Address, GroupID], reaction: Reaction
+    ) -> None:
+        await self.request_v1(
+            "react",
+            username=username,
+            reaction=reaction.serialize(),
+            **self._recipient_to_args(recipient),
+        )
+
+    async def remote_delete(
+        self, username: str, recipient: Union[Address, GroupID], timestamp: int
+    ) -> None:
+        await self.request_v1(
+            "remote_delete",
+            account=username,
+            timestamp=timestamp,
+            **self._recipient_to_args(recipient, simple_name=True),
+        )
+
+    async def send(
+        self,
+        username: str,
+        recipient: Union[Address, GroupID],
+        body: str,
+        quote: Optional[Quote] = None,
+        attachments: Optional[List[Attachment]] = None,
+        mentions: Optional[List[Mention]] = None,
+        timestamp: Optional[int] = None,
+    ) -> None:
         serialized_quote = quote.serialize() if quote else None
         serialized_quote = quote.serialize() if quote else None
         serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
         serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
         serialized_mentions = [mention.serialize() for mention in (mentions or [])]
         serialized_mentions = [mention.serialize() for mention in (mentions or [])]
-        resp = await self.request_v1("send", username=username, messageBody=body,
-                                     attachments=serialized_attachments, quote=serialized_quote,
-                                     mentions=serialized_mentions, timestamp=timestamp,
-                                     **self._recipient_to_args(recipient))
+        resp = await self.request_v1(
+            "send",
+            username=username,
+            messageBody=body,
+            attachments=serialized_attachments,
+            quote=serialized_quote,
+            mentions=serialized_mentions,
+            timestamp=timestamp,
+            **self._recipient_to_args(recipient),
+        )
         errors = []
         errors = []
 
 
         # We handle unregisteredFailure a little differently than other errors. If there are no
         # We handle unregisteredFailure a little differently than other errors. If there are no
@@ -173,9 +218,8 @@ class SignaldClient(SignaldRPCClient):
         successful_send_count = 0
         successful_send_count = 0
         results = resp.get("results", [])
         results = resp.get("results", [])
         for result in results:
         for result in results:
-            number = (
-                result.get("address", {}).get("number") or result.get("address", {}).get("uuid")
-            )
+            address = result.get("addres", {})
+            number = address.get("number") or address.get("uuid")
             proof_required_failure = result.get("proof_required_failure")
             proof_required_failure = result.get("proof_required_failure")
             if result.get("networkFailure", False):
             if result.get("networkFailure", False):
                 errors.append(f"Network failure occurred while sending message to {number}.")
                 errors.append(f"Network failure occurred while sending message to {number}.")
@@ -186,9 +230,10 @@ class SignaldClient(SignaldRPCClient):
             elif result.get("identityFailure", ""):
             elif result.get("identityFailure", ""):
                 errors.append(
                 errors.append(
                     f"Identity failure occurred while sending message to {number}. New identity: "
                     f"Identity failure occurred while sending message to {number}. New identity: "
-                    f"{result['identityFailure']}")
+                    f"{result['identityFailure']}"
+                )
             elif proof_required_failure:
             elif proof_required_failure:
-                options = proof_required_failure.get('options')
+                options = proof_required_failure.get("options")
                 self.log.warning(
                 self.log.warning(
                     f"Proof Required Failure {options}. "
                     f"Proof Required Failure {options}. "
                     f"Retry after: {proof_required_failure.get('retry_after')}. "
                     f"Retry after: {proof_required_failure.get('retry_after')}. "
@@ -206,17 +251,26 @@ class SignaldClient(SignaldRPCClient):
                     await self.request_v1("submit_challenge")
                     await self.request_v1("submit_challenge")
             else:
             else:
                 successful_send_count += 1
                 successful_send_count += 1
-        self.log.info(f"Successfully sent message to {successful_send_count}/{len(results)} users in {recipient}")
+        self.log.info(
+            f"Successfully sent message to {successful_send_count}/{len(results)} users in {recipient}"
+        )
         if errors or successful_send_count == 0:
         if errors or successful_send_count == 0:
             raise Exception("\n".join(errors + unregistered_failures))
             raise Exception("\n".join(errors + unregistered_failures))
 
 
-    async def send_receipt(self, username: str, sender: Address, timestamps: List[int],
-                           when: Optional[int] = None, read: bool = False) -> None:
+    async def send_receipt(
+        self,
+        username: str,
+        sender: Address,
+        timestamps: List[int],
+        when: Optional[int] = None,
+        read: bool = False,
+    ) -> None:
         if not read:
         if not read:
             # TODO implement
             # TODO implement
             return
             return
-        await self.request_v1("mark_read", account=username, timestamps=timestamps, when=when,
-                              to=sender.serialize())
+        await self.request_v1(
+            "mark_read", account=username, timestamps=timestamps, when=when, to=sender.serialize()
+        )
 
 
     async def list_accounts(self) -> List[Account]:
     async def list_accounts(self) -> List[Account]:
         resp = await self.request_v1("list_accounts")
         resp = await self.request_v1("list_accounts")
@@ -242,19 +296,28 @@ class SignaldClient(SignaldRPCClient):
         v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
         v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
         return legacy + v2
         return legacy + v2
 
 
-    async def update_group(self, username: str, group_id: GroupID, title: Optional[str] = None,
-                           avatar_path: Optional[str] = None,
-                           add_members: Optional[List[Address]] = None,
-                           remove_members: Optional[List[Address]] = None
-                           ) -> Union[Group, GroupV2, None]:
-        update_params = {key: value for key, value in {
-            "groupID": group_id,
-            "avatar": avatar_path,
-            "title": title,
-            "addMembers": [addr.serialize() for addr in add_members] if add_members else None,
-            "removeMembers": ([addr.serialize() for addr in remove_members]
-                              if remove_members else None),
-        }.items() if value is not None}
+    async def update_group(
+        self,
+        username: str,
+        group_id: GroupID,
+        title: Optional[str] = None,
+        avatar_path: Optional[str] = None,
+        add_members: Optional[List[Address]] = None,
+        remove_members: Optional[List[Address]] = None,
+    ) -> Union[Group, GroupV2, None]:
+        update_params = {
+            key: value
+            for key, value in {
+                "groupID": group_id,
+                "avatar": avatar_path,
+                "title": title,
+                "addMembers": [addr.serialize() for addr in add_members] if add_members else None,
+                "removeMembers": (
+                    [addr.serialize() for addr in remove_members] if remove_members else None
+                ),
+            }.items()
+            if value is not None
+        }
         resp = await self.request_v1("update_group", account=username, **update_params)
         resp = await self.request_v1("update_group", account=username, **update_params)
         if "v1" in resp:
         if "v1" in resp:
             return Group.deserialize(resp["v1"])
             return Group.deserialize(resp["v1"])
@@ -267,21 +330,25 @@ class SignaldClient(SignaldRPCClient):
         resp = await self.request_v1("accept_invitation", account=username, groupID=group_id)
         resp = await self.request_v1("accept_invitation", account=username, groupID=group_id)
         return GroupV2.deserialize(resp)
         return GroupV2.deserialize(resp)
 
 
-    async def get_group(self, username: str, group_id: GroupID, revision: int = -1
-                        ) -> Optional[GroupV2]:
-        resp = await self.request_v1("get_group", account=username, groupID=group_id,
-                                     revision=revision)
+    async def get_group(
+        self, username: str, group_id: GroupID, revision: int = -1
+    ) -> Optional[GroupV2]:
+        resp = await self.request_v1(
+            "get_group", account=username, groupID=group_id, revision=revision
+        )
         if "id" not in resp:
         if "id" not in resp:
             return None
             return None
         return GroupV2.deserialize(resp)
         return GroupV2.deserialize(resp)
 
 
-    async def get_profile(self, username: str, address: Address, use_cache: bool = False
-                          ) -> Optional[Profile]:
+    async def get_profile(
+        self, username: str, address: Address, use_cache: bool = False
+    ) -> Optional[Profile]:
         try:
         try:
             # async is a reserved keyword, so can't pass it as a normal parameter
             # async is a reserved keyword, so can't pass it as a normal parameter
             kwargs = {"async": use_cache}
             kwargs = {"async": use_cache}
-            resp = await self.request_v1("get_profile", account=username,
-                                         address=address.serialize(), **kwargs)
+            resp = await self.request_v1(
+                "get_profile", account=username, address=address.serialize(), **kwargs
+            )
         except UnexpectedResponse as e:
         except UnexpectedResponse as e:
             if e.resp_type == "profile_not_available":
             if e.resp_type == "profile_not_available":
                 return None
                 return None
@@ -289,12 +356,14 @@ class SignaldClient(SignaldRPCClient):
         return Profile.deserialize(resp)
         return Profile.deserialize(resp)
 
 
     async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
     async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
-        resp = await self.request_v1("get_identities", account=username,
-                                     address=address.serialize())
+        resp = await self.request_v1(
+            "get_identities", account=username, address=address.serialize()
+        )
         return GetIdentitiesResponse.deserialize(resp)
         return GetIdentitiesResponse.deserialize(resp)
 
 
-    async def set_profile(self, username: str, name: Optional[str] = None,
-                          avatar_path: Optional[str] = None) -> None:
+    async def set_profile(
+        self, username: str, name: Optional[str] = None, avatar_path: Optional[str] = None
+    ) -> None:
         args = {}
         args = {}
         if name is not None:
         if name is not None:
             args["name"] = name
             args["name"] = name
@@ -302,9 +371,14 @@ class SignaldClient(SignaldRPCClient):
             args["avatarFile"] = avatar_path
             args["avatarFile"] = avatar_path
         await self.request_v1("set_profile", account=username, **args)
         await self.request_v1("set_profile", account=username, **args)
 
 
-    async def trust(self, username: str, recipient: Address, trust_level: str,
-                    safety_number: Optional[str] = None, qr_code_data: Optional[str] = None
-                    ) -> None:
+    async def trust(
+        self,
+        username: str,
+        recipient: Address,
+        trust_level: str,
+        safety_number: Optional[str] = None,
+        qr_code_data: Optional[str] = None,
+    ) -> None:
         args = {}
         args = {}
         if safety_number:
         if safety_number:
             if qr_code_data:
             if qr_code_data:
@@ -314,5 +388,10 @@ class SignaldClient(SignaldRPCClient):
             args["qr_code_data"] = qr_code_data
             args["qr_code_data"] = qr_code_data
         else:
         else:
             raise ValueError("safety_number or qr_code_data is required")
             raise ValueError("safety_number or qr_code_data is required")
-        await self.request_v1("trust", account=username, **args, trust_level=trust_level,
-                              address=recipient.serialize())
+        await self.request_v1(
+            "trust",
+            account=username,
+            **args,
+            trust_level=trust_level,
+            address=recipient.serialize(),
+        )

+ 21 - 16
mausignald/types.py

@@ -3,15 +3,14 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Optional, Dict, List, NewType
+from typing import Dict, List, NewType, Optional
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from uuid import UUID
 from uuid import UUID
 
 
 from attr import dataclass
 from attr import dataclass
+from mautrix.types import ExtensibleEnum, SerializableAttrs, SerializableEnum, field
 
 
-from mautrix.types import SerializableAttrs, SerializableEnum, ExtensibleEnum, field
-
-GroupID = NewType('GroupID', str)
+GroupID = NewType("GroupID", str)
 
 
 
 
 @dataclass(frozen=True, eq=False)
 @dataclass(frozen=True, eq=False)
@@ -27,7 +26,7 @@ class Address(SerializableAttrs):
     def best_identifier(self) -> str:
     def best_identifier(self) -> str:
         return str(self.uuid) if self.uuid else self.number
         return str(self.uuid) if self.uuid else self.number
 
 
-    def __eq__(self, other: 'Address') -> bool:
+    def __eq__(self, other: "Address") -> bool:
         if not isinstance(other, Address):
         if not isinstance(other, Address):
             return False
             return False
         if self.uuid and other.uuid:
         if self.uuid and other.uuid:
@@ -42,7 +41,7 @@ class Address(SerializableAttrs):
         return hash(self.number)
         return hash(self.number)
 
 
     @classmethod
     @classmethod
-    def parse(cls, value: str) -> 'Address':
+    def parse(cls, value: str) -> "Address":
         return Address(number=value) if value.startswith("+") else Address(uuid=UUID(value))
         return Address(number=value) if value.startswith("+") else Address(uuid=UUID(value))
 
 
 
 
@@ -202,13 +201,15 @@ class GroupV2(GroupV2ID, SerializableAttrs):
     timer: Optional[int] = None
     timer: Optional[int] = None
     master_key: Optional[str] = field(default=None, json="masterKey")
     master_key: Optional[str] = field(default=None, json="masterKey")
     invite_link: Optional[str] = field(default=None, json="inviteLink")
     invite_link: Optional[str] = field(default=None, json="inviteLink")
-    access_control: GroupAccessControl = field(factory=lambda: GroupAccessControl(),
-                                               json="accessControl")
+    access_control: GroupAccessControl = field(
+        factory=lambda: GroupAccessControl(), json="accessControl"
+    )
     members: List[Address]
     members: List[Address]
     member_detail: List[GroupMember] = field(factory=lambda: [], json="memberDetail")
     member_detail: List[GroupMember] = field(factory=lambda: [], json="memberDetail")
     pending_members: List[Address] = field(factory=lambda: [], json="pendingMembers")
     pending_members: List[Address] = field(factory=lambda: [], json="pendingMembers")
-    pending_member_detail: List[GroupMember] = field(factory=lambda: [],
-                                                     json="pendingMemberDetail")
+    pending_member_detail: List[GroupMember] = field(
+        factory=lambda: [], json="pendingMemberDetail"
+    )
     requesting_members: List[Address] = field(factory=lambda: [], json="requestingMembers")
     requesting_members: List[Address] = field(factory=lambda: [], json="requestingMembers")
 
 
 
 
@@ -294,8 +295,9 @@ class MessageData(SerializableAttrs):
 class SentSyncMessage(SerializableAttrs):
 class SentSyncMessage(SerializableAttrs):
     message: MessageData
     message: MessageData
     timestamp: int
     timestamp: int
-    expiration_start_timestamp: Optional[int] = field(default=None,
-                                                      json="expirationStartTimestamp")
+    expiration_start_timestamp: Optional[int] = field(
+        default=None, json="expirationStartTimestamp"
+    )
     is_recipient_update: bool = field(default=False, json="isRecipientUpdate")
     is_recipient_update: bool = field(default=False, json="isRecipientUpdate")
     unidentified_status: Dict[str, bool] = field(factory=lambda: {})
     unidentified_status: Dict[str, bool] = field(factory=lambda: {})
     destination: Optional[Address] = None
     destination: Optional[Address] = None
@@ -347,11 +349,13 @@ class ConfigItem(SerializableAttrs):
 @dataclass
 @dataclass
 class ClientConfiguration(SerializableAttrs):
 class ClientConfiguration(SerializableAttrs):
     read_receipts: Optional[ConfigItem] = field(factory=lambda: ConfigItem(), json="readReceipts")
     read_receipts: Optional[ConfigItem] = field(factory=lambda: ConfigItem(), json="readReceipts")
-    typing_indicators: Optional[ConfigItem] = field(factory=lambda: ConfigItem(),
-                                                    json="typingIndicators")
+    typing_indicators: Optional[ConfigItem] = field(
+        factory=lambda: ConfigItem(), json="typingIndicators"
+    )
     link_previews: Optional[ConfigItem] = field(factory=lambda: ConfigItem(), json="linkPreviews")
     link_previews: Optional[ConfigItem] = field(factory=lambda: ConfigItem(), json="linkPreviews")
     unidentified_delivery_indicators: Optional[ConfigItem] = field(
     unidentified_delivery_indicators: Optional[ConfigItem] = field(
-        factory=lambda: ConfigItem(), json="unidentifiedDeliveryIndicators")
+        factory=lambda: ConfigItem(), json="unidentifiedDeliveryIndicators"
+    )
 
 
 
 
 class StickerPackOperation(ExtensibleEnum):
 class StickerPackOperation(ExtensibleEnum):
@@ -376,7 +380,8 @@ class SyncMessage(SerializableAttrs):
     configuration: Optional[ClientConfiguration] = None
     configuration: Optional[ClientConfiguration] = None
     # blocked_list: Optional[???] = field(default=None, json="blockedList")
     # blocked_list: Optional[???] = field(default=None, json="blockedList")
     sticker_pack_operations: Optional[List[StickerPackOperations]] = field(
     sticker_pack_operations: Optional[List[StickerPackOperations]] = field(
-        default=None, json="stickerPackOperations")
+        default=None, json="stickerPackOperations"
+    )
     contacts_complete: bool = field(default=False, json="contactsComplete")
     contacts_complete: bool = field(default=False, json="contactsComplete")
 
 
 
 

+ 7 - 6
mautrix_signal/__main__.py

@@ -13,7 +13,7 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Dict, Any
+from typing import Any, Dict
 from random import uniform
 from random import uniform
 import asyncio
 import asyncio
 import logging
 import logging
@@ -21,16 +21,17 @@ import logging
 from mautrix.bridge import Bridge
 from mautrix.bridge import Bridge
 from mautrix.types import RoomID, UserID
 from mautrix.types import RoomID, UserID
 
 
-from .version import version, linkified_version
+from . import commands
 from .config import Config
 from .config import Config
-from .db import upgrade_table, init as init_db
+from .db import init as init_db
+from .db import upgrade_table
 from .matrix import MatrixHandler
 from .matrix import MatrixHandler
-from .signal import SignalHandler
-from .user import User
 from .portal import Portal
 from .portal import Portal
 from .puppet import Puppet
 from .puppet import Puppet
+from .signal import SignalHandler
+from .user import User
+from .version import linkified_version, version
 from .web import ProvisioningAPI
 from .web import ProvisioningAPI
-from . import commands
 
 
 SYNC_JITTER = 10
 SYNC_JITTER = 10
 
 

+ 78 - 38
mautrix_signal/commands/auth.py

@@ -16,17 +16,18 @@
 from typing import Union
 from typing import Union
 import io
 import io
 
 
-from mausignald.errors import UnexpectedResponse, TimeoutException, AuthorizationFailedException
 from mautrix.appservice import IntentAPI
 from mautrix.appservice import IntentAPI
-from mautrix.types import MediaMessageEventContent, MessageType, ImageInfo, EventID
 from mautrix.bridge.commands import HelpSection, command_handler
 from mautrix.bridge.commands import HelpSection, command_handler
+from mautrix.types import EventID, ImageInfo, MediaMessageEventContent, MessageType
+
+from mausignald.errors import AuthorizationFailedException, TimeoutException, UnexpectedResponse
 
 
 from .. import puppet as pu
 from .. import puppet as pu
 from .typehint import CommandEvent
 from .typehint import CommandEvent
 
 
 try:
 try:
-    import qrcode
     import PIL as _
     import PIL as _
+    import qrcode
 except ImportError:
 except ImportError:
     qrcode = None
     qrcode = None
 
 
@@ -34,8 +35,9 @@ SECTION_AUTH = HelpSection("Authentication", 10, "")
 remove_extra_chars = str.maketrans("", "", " .,-()")
 remove_extra_chars = str.maketrans("", "", " .,-()")
 
 
 
 
-async def make_qr(intent: IntentAPI, data: Union[str, bytes], body: str = None
-                  ) -> MediaMessageEventContent:
+async def make_qr(
+    intent: IntentAPI, data: Union[str, bytes], body: str = None
+) -> MediaMessageEventContent:
     # TODO always encrypt QR codes?
     # TODO always encrypt QR codes?
     buffer = io.BytesIO()
     buffer = io.BytesIO()
     image = qrcode.make(data)
     image = qrcode.make(data)
@@ -43,20 +45,30 @@ async def make_qr(intent: IntentAPI, data: Union[str, bytes], body: str = None
     image.save(buffer, "PNG")
     image.save(buffer, "PNG")
     qr = buffer.getvalue()
     qr = buffer.getvalue()
     mxc = await intent.upload_media(qr, "image/png", "qr.png", len(qr))
     mxc = await intent.upload_media(qr, "image/png", "qr.png", len(qr))
-    return MediaMessageEventContent(body=body or data, url=mxc, msgtype=MessageType.IMAGE,
-                                    info=ImageInfo(mimetype="image/png", size=len(qr),
-                                                   width=size, height=size))
-
-
-@command_handler(needs_auth=False, management_only=True, help_section=SECTION_AUTH,
-                 help_text="Link the bridge as a secondary device", help_args="[device name]")
+    return MediaMessageEventContent(
+        body=body or data,
+        url=mxc,
+        msgtype=MessageType.IMAGE,
+        info=ImageInfo(mimetype="image/png", size=len(qr), width=size, height=size),
+    )
+
+
+@command_handler(
+    needs_auth=False,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="Link the bridge as a secondary device",
+    help_args="[device name]",
+)
 async def link(evt: CommandEvent) -> None:
 async def link(evt: CommandEvent) -> None:
     if qrcode is None:
     if qrcode is None:
         await evt.reply("Can't generate QR code: qrcode and/or PIL not installed")
         await evt.reply("Can't generate QR code: qrcode and/or PIL not installed")
         return
         return
     if await evt.sender.is_logged_in():
     if await evt.sender.is_logged_in():
-        await evt.reply("You're already logged in. "
-                        "If you want to relink, log out with `$cmdprefix+sp logout` first.")
+        await evt.reply(
+            "You're already logged in. "
+            "If you want to relink, log out with `$cmdprefix+sp logout` first."
+        )
         return
         return
     # TODO make default device name configurable
     # TODO make default device name configurable
     device_name = " ".join(evt.args) or "Mautrix-Signal bridge"
     device_name = " ".join(evt.args) or "Mautrix-Signal bridge"
@@ -65,14 +77,16 @@ async def link(evt: CommandEvent) -> None:
     content = await make_qr(evt.az.intent, sess.uri)
     content = await make_qr(evt.az.intent, sess.uri)
     event_id = await evt.az.intent.send_message(evt.room_id, content)
     event_id = await evt.az.intent.send_message(evt.room_id, content)
     try:
     try:
-        account = await evt.bridge.signal.finish_link(session_id=sess.session_id, overwrite=True,
-                                                      device_name=device_name)
+        account = await evt.bridge.signal.finish_link(
+            session_id=sess.session_id, overwrite=True, device_name=device_name
+        )
     except TimeoutException:
     except TimeoutException:
         await evt.reply("Linking timed out, please try again.")
         await evt.reply("Linking timed out, please try again.")
     except Exception:
     except Exception:
         evt.log.exception("Fatal error while waiting for linking to finish")
         evt.log.exception("Fatal error while waiting for linking to finish")
-        await evt.reply("Fatal error while waiting for linking to finish "
-                        "(see logs for more details)")
+        await evt.reply(
+            "Fatal error while waiting for linking to finish " "(see logs for more details)"
+        )
     else:
     else:
         await evt.sender.on_signin(account)
         await evt.sender.on_signin(account)
         await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
         await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
@@ -80,16 +94,23 @@ async def link(evt: CommandEvent) -> None:
         await evt.main_intent.redact(evt.room_id, event_id)
         await evt.main_intent.redact(evt.room_id, event_id)
 
 
 
 
-@command_handler(needs_auth=False, management_only=True, help_section=SECTION_AUTH,
-                 is_enabled_for=lambda evt: evt.config["signal.registration_enabled"],
-                 help_text="Sign into Signal as the primary device", help_args="<phone>")
+@command_handler(
+    needs_auth=False,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    is_enabled_for=lambda evt: evt.config["signal.registration_enabled"],
+    help_text="Sign into Signal as the primary device",
+    help_args="<phone>",
+)
 async def register(evt: CommandEvent) -> None:
 async def register(evt: CommandEvent) -> None:
     if len(evt.args) == 0:
     if len(evt.args) == 0:
         await evt.reply("**Usage**: $cmdprefix+sp register [--voice] [--captcha <token>] <phone>")
         await evt.reply("**Usage**: $cmdprefix+sp register [--voice] [--captcha <token>] <phone>")
         return
         return
     if await evt.sender.is_logged_in():
     if await evt.sender.is_logged_in():
-        await evt.reply("You're already logged in. "
-                        "If you want to re-register, log out with `$cmdprefix+sp logout` first.")
+        await evt.reply(
+            "You're already logged in. "
+            "If you want to re-register, log out with `$cmdprefix+sp logout` first."
+        )
         return
         return
     voice = False
     voice = False
     captcha = None
     captcha = None
@@ -132,13 +153,19 @@ async def enter_register_code(evt: CommandEvent) -> None:
             raise
             raise
     else:
     else:
         await evt.sender.on_signin(account)
         await evt.sender.on_signin(account)
-        await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}."
-                        f"\n\n**N.B.** You must set a Signal profile name with `$cmdprefix+sp "
-                        f"set-profile-name <name>` before you can participate in new groups.")
-
-
-@command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,
-                 help_text="Remove all local data about your Signal link")
+        await evt.reply(
+            f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}."
+            f"\n\n**N.B.** You must set a Signal profile name with `$cmdprefix+sp "
+            f"set-profile-name <name>` before you can participate in new groups."
+        )
+
+
+@command_handler(
+    needs_auth=True,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="Remove all local data about your Signal link",
+)
 async def logout(evt: CommandEvent) -> None:
 async def logout(evt: CommandEvent) -> None:
     if not evt.sender.username:
     if not evt.sender.username:
         await evt.reply("You're not logged in")
         await evt.reply("You're not logged in")
@@ -147,16 +174,29 @@ async def logout(evt: CommandEvent) -> None:
     await evt.reply("Successfully logged out")
     await evt.reply("Successfully logged out")
 
 
 
 
-@command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,
-                 help_text="List devices linked to your Signal account")
+@command_handler(
+    needs_auth=True,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="List devices linked to your Signal account",
+)
 async def list_devices(evt: CommandEvent) -> None:
 async def list_devices(evt: CommandEvent) -> None:
     devices = await evt.bridge.signal.get_linked_devices(evt.sender.username)
     devices = await evt.bridge.signal.get_linked_devices(evt.sender.username)
-    await evt.reply("\n".join(f"* #{dev.id}: {dev.name_with_default} (created {dev.created_fmt}, "
-                              f"last seen {dev.last_seen_fmt})" for dev in devices))
-
-
-@command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,
-                 help_text="Remove a linked device")
+    await evt.reply(
+        "\n".join(
+            f"* #{dev.id}: {dev.name_with_default} (created {dev.created_fmt}, last seen "
+            f"{dev.last_seen_fmt})"
+            for dev in devices
+        )
+    )
+
+
+@command_handler(
+    needs_auth=True,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="Remove a linked device",
+)
 async def remove_linked_device(evt: CommandEvent) -> EventID:
 async def remove_linked_device(evt: CommandEvent) -> EventID:
     if len(evt.args) == 0:
     if len(evt.args) == 0:
         return await evt.reply("**Usage:** `$cmdprefix+sp remove-linked-device <device ID>`")
         return await evt.reply("**Usage:** `$cmdprefix+sp remove-linked-device <device ID>`")

+ 25 - 9
mautrix_signal/commands/conn.py

@@ -13,35 +13,50 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from mautrix.types import EventID
 from mautrix.bridge.commands import HelpSection, command_handler
 from mautrix.bridge.commands import HelpSection, command_handler
+from mautrix.types import EventID
+
 from .typehint import CommandEvent
 from .typehint import CommandEvent
 
 
 SECTION_CONNECTION = HelpSection("Connection management", 15, "")
 SECTION_CONNECTION = HelpSection("Connection management", 15, "")
 
 
 
 
-@command_handler(needs_auth=False, management_only=True, help_section=SECTION_CONNECTION,
-                 help_text="Mark this room as your bridge notice room.")
+@command_handler(
+    needs_auth=False,
+    management_only=True,
+    help_section=SECTION_CONNECTION,
+    help_text="Mark this room as your bridge notice room.",
+)
 async def set_notice_room(evt: CommandEvent) -> None:
 async def set_notice_room(evt: CommandEvent) -> None:
     evt.sender.notice_room = evt.room_id
     evt.sender.notice_room = evt.room_id
     await evt.sender.update()
     await evt.sender.update()
     await evt.reply("This room has been marked as your bridge notice room")
     await evt.reply("This room has been marked as your bridge notice room")
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
-                 help_text="Relay messages in this room through your Signal account.")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_CONNECTION,
+    help_text="Relay messages in this room through your Signal account.",
+)
 async def set_relay(evt: CommandEvent) -> EventID:
 async def set_relay(evt: CommandEvent) -> EventID:
     if not evt.config["bridge.relay.enabled"]:
     if not evt.config["bridge.relay.enabled"]:
         return await evt.reply("Relay mode is not enabled in this instance of the bridge.")
         return await evt.reply("Relay mode is not enabled in this instance of the bridge.")
     elif not evt.is_portal:
     elif not evt.is_portal:
         return await evt.reply("This is not a portal room.")
         return await evt.reply("This is not a portal room.")
     await evt.portal.set_relay_user(evt.sender)
     await evt.portal.set_relay_user(evt.sender)
-    return await evt.reply("Messages from non-logged-in users in this room will now be bridged "
-                           "through your Signal account.")
+    return await evt.reply(
+        "Messages from non-logged-in users in this room will now be bridged "
+        "through your Signal account."
+    )
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
-                 help_text="Stop relaying messages in this room.")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_CONNECTION,
+    help_text="Stop relaying messages in this room.",
+)
 async def unset_relay(evt: CommandEvent) -> EventID:
 async def unset_relay(evt: CommandEvent) -> EventID:
     if not evt.config["bridge.relay.enabled"]:
     if not evt.config["bridge.relay.enabled"]:
         return await evt.reply("Relay mode is not enabled in this instance of the bridge.")
         return await evt.reply("Relay mode is not enabled in this instance of the bridge.")
@@ -52,6 +67,7 @@ async def unset_relay(evt: CommandEvent) -> EventID:
     await evt.portal.set_relay_user(None)
     await evt.portal.set_relay_user(None)
     return await evt.reply("Messages from non-logged-in users will no longer be bridged.")
     return await evt.reply("Messages from non-logged-in users will no longer be bridged.")
 
 
+
 # @command_handler(needs_auth=False, management_only=True, help_section=SECTION_CONNECTION,
 # @command_handler(needs_auth=False, management_only=True, help_section=SECTION_CONNECTION,
 #                  help_text="Check if you're logged into Twitter")
 #                  help_text="Check if you're logged into Twitter")
 # async def ping(evt: CommandEvent) -> None:
 # async def ping(evt: CommandEvent) -> None:

+ 101 - 48
mautrix_signal/commands/signal.py

@@ -17,33 +17,39 @@ from typing import Optional
 import base64
 import base64
 import json
 import json
 
 
-from mautrix.bridge.commands import HelpSection, command_handler, SECTION_ADMIN
+from mautrix.bridge.commands import SECTION_ADMIN, HelpSection, command_handler
 from mautrix.types import EventID
 from mautrix.types import EventID
-from mausignald.types import Address
+
 from mausignald.errors import UnknownIdentityKey
 from mausignald.errors import UnknownIdentityKey
+from mausignald.types import Address
 
 
-from .. import puppet as pu, portal as po
+from .. import portal as po
+from .. import puppet as pu
 from .auth import make_qr, remove_extra_chars
 from .auth import make_qr, remove_extra_chars
 from .typehint import CommandEvent
 from .typehint import CommandEvent
 
 
 try:
 try:
-    import qrcode
     import PIL as _
     import PIL as _
+    import qrcode
 except ImportError:
 except ImportError:
     qrcode = None
     qrcode = None
 
 
 SECTION_SIGNAL = HelpSection("Signal actions", 20, "")
 SECTION_SIGNAL = HelpSection("Signal actions", 20, "")
 
 
 
 
-async def _get_puppet_from_cmd(evt: CommandEvent) -> Optional['pu.Puppet']:
+async def _get_puppet_from_cmd(evt: CommandEvent) -> Optional["pu.Puppet"]:
     if len(evt.args) == 0 or not evt.args[0].startswith("+"):
     if len(evt.args) == 0 or not evt.args[0].startswith("+"):
-        await evt.reply(f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
-                        "(enter phone number in international format)")
+        await evt.reply(
+            f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
+            "(enter phone number in international format)"
+        )
         return None
         return None
     phone = "".join(evt.args).translate(remove_extra_chars)
     phone = "".join(evt.args).translate(remove_extra_chars)
     if not phone[1:].isdecimal():
     if not phone[1:].isdecimal():
-        await evt.reply(f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
-                        "(enter phone number in international format)")
+        await evt.reply(
+            f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
+            "(enter phone number in international format)"
+        )
         return None
         return None
     return await pu.Puppet.get_by_address(Address(number=phone))
     return await pu.Puppet.get_by_address(Address(number=phone))
 
 
@@ -51,40 +57,58 @@ async def _get_puppet_from_cmd(evt: CommandEvent) -> Optional['pu.Puppet']:
 def _format_safety_number(number: str) -> str:
 def _format_safety_number(number: str) -> str:
     line_size = 20
     line_size = 20
     chunk_size = 5
     chunk_size = 5
-    return "\n".join(" ".join([number[chunk:chunk + chunk_size]
-                               for chunk in range(line, line + line_size, chunk_size)])
-                     for line in range(0, len(number), line_size))
+    return "\n".join(
+        " ".join(
+            [
+                number[chunk : chunk + chunk_size]
+                for chunk in range(line, line + line_size, chunk_size)
+            ]
+        )
+        for line in range(0, len(number), line_size)
+    )
 
 
 
 
-def _pill(puppet: 'pu.Puppet') -> str:
+def _pill(puppet: "pu.Puppet") -> str:
     return f"[{puppet.name}](https://matrix.to/#/{puppet.mxid})"
     return f"[{puppet.name}](https://matrix.to/#/{puppet.mxid})"
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Open a private chat portal with a specific phone number",
-                 help_args="<_phone_>")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Open a private chat portal with a specific phone number",
+    help_args="<_phone_>",
+)
 async def pm(evt: CommandEvent) -> None:
 async def pm(evt: CommandEvent) -> None:
     puppet = await _get_puppet_from_cmd(evt)
     puppet = await _get_puppet_from_cmd(evt)
     if not puppet:
     if not puppet:
         return
         return
-    portal = await po.Portal.get_by_chat_id(puppet.address, receiver=evt.sender.username,
-                                            create=True)
+    portal = await po.Portal.get_by_chat_id(
+        puppet.address, receiver=evt.sender.username, create=True
+    )
     if portal.mxid:
     if portal.mxid:
-        await evt.reply(f"You already have a private chat with {puppet.name}: "
-                        f"[{portal.mxid}](https://matrix.to/#/{portal.mxid})")
+        await evt.reply(
+            f"You already have a private chat with {puppet.name}: "
+            f"[{portal.mxid}](https://matrix.to/#/{portal.mxid})"
+        )
         await portal.main_intent.invite_user(portal.mxid, evt.sender.mxid)
         await portal.main_intent.invite_user(portal.mxid, evt.sender.mxid)
         return
         return
     await portal.create_matrix_room(evt.sender, puppet.address)
     await portal.create_matrix_room(evt.sender, puppet.address)
     await evt.reply(f"Created a portal room with {_pill(puppet)} and invited you to it")
     await evt.reply(f"Created a portal room with {_pill(puppet)} and invited you to it")
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Get the invite link to the current group")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Get the invite link to the current group",
+)
 async def invite_link(evt: CommandEvent) -> EventID:
 async def invite_link(evt: CommandEvent) -> EventID:
     if not evt.is_portal:
     if not evt.is_portal:
         return await evt.reply("This is not a portal room.")
         return await evt.reply("This is not a portal room.")
-    group = await evt.bridge.signal.get_group(evt.sender.username, evt.portal.chat_id,
-                                              evt.portal.revision)
+    group = await evt.bridge.signal.get_group(
+        evt.sender.username, evt.portal.chat_id, evt.portal.revision
+    )
     if not group:
     if not group:
         await evt.reply("Failed to get group info")
         await evt.reply("Failed to get group info")
     elif not group.invite_link:
     elif not group.invite_link:
@@ -93,9 +117,13 @@ async def invite_link(evt: CommandEvent) -> EventID:
         await evt.reply(group.invite_link)
         await evt.reply(group.invite_link)
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="View the safety number of a specific user",
-                 help_args="[--qr] [_phone_]")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="View the safety number of a specific user",
+    help_args="[--qr] [_phone_]",
+)
 async def safety_number(evt: CommandEvent) -> None:
 async def safety_number(evt: CommandEvent) -> None:
     show_qr = evt.args and evt.args[0].lower() == "--qr"
     show_qr = evt.args and evt.args[0].lower() == "--qr"
     if show_qr:
     if show_qr:
@@ -119,53 +147,77 @@ async def safety_number(evt: CommandEvent) -> None:
         if identity.added > most_recent.added:
         if identity.added > most_recent.added:
             most_recent = identity
             most_recent = identity
     uuid = resp.address.uuid or "unknown"
     uuid = resp.address.uuid or "unknown"
-    await evt.reply(f"### {puppet.name}\n\n"
-                    f"**UUID:** {uuid}  \n"
-                    f"**Trust level:** {most_recent.trust_level}  \n"
-                    f"**Safety number:**\n"
-                    f"```\n{_format_safety_number(most_recent.safety_number)}\n```")
+    await evt.reply(
+        f"### {puppet.name}\n\n"
+        f"**UUID:** {uuid}  \n"
+        f"**Trust level:** {most_recent.trust_level}  \n"
+        f"**Safety number:**\n"
+        f"```\n{_format_safety_number(most_recent.safety_number)}\n```"
+    )
     if show_qr and most_recent.qr_code_data:
     if show_qr and most_recent.qr_code_data:
         data = base64.b64decode(most_recent.qr_code_data)
         data = base64.b64decode(most_recent.qr_code_data)
         content = await make_qr(evt.main_intent, data, "verification-qr.png")
         content = await make_qr(evt.main_intent, data, "verification-qr.png")
         await evt.main_intent.send_message(evt.room_id, content)
         await evt.main_intent.send_message(evt.room_id, content)
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Set your Signal profile name", help_args="<_name_>")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Set your Signal profile name",
+    help_args="<_name_>",
+)
 async def set_profile_name(evt: CommandEvent) -> None:
 async def set_profile_name(evt: CommandEvent) -> None:
     await evt.bridge.signal.set_profile(evt.sender.username, name=" ".join(evt.args))
     await evt.bridge.signal.set_profile(evt.sender.username, name=" ".join(evt.args))
     await evt.reply("Successfully updated profile name")
     await evt.reply("Successfully updated profile name")
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Mark another user's safety number as trusted",
-                 help_args="<_recipient phone_> <_safety number_>")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Mark another user's safety number as trusted",
+    help_args="<_recipient phone_> <_safety number_>",
+)
 async def mark_trusted(evt: CommandEvent) -> EventID:
 async def mark_trusted(evt: CommandEvent) -> EventID:
     if len(evt.args) < 2:
     if len(evt.args) < 2:
-        return await evt.reply("**Usage:** `$cmdprefix+sp mark-trusted <recipient phone> "
-                               "<safety number>`")
+        return await evt.reply(
+            "**Usage:** `$cmdprefix+sp mark-trusted <recipient phone> " "<safety number>`"
+        )
     number = evt.args[0].translate(remove_extra_chars)
     number = evt.args[0].translate(remove_extra_chars)
     safety_num = "".join(evt.args[1:]).replace("\n", "")
     safety_num = "".join(evt.args[1:]).replace("\n", "")
     if len(safety_num) != 60 or not safety_num.isdecimal():
     if len(safety_num) != 60 or not safety_num.isdecimal():
         return await evt.reply("That doesn't look like a valid safety number")
         return await evt.reply("That doesn't look like a valid safety number")
     try:
     try:
-        await evt.bridge.signal.trust(evt.sender.username, Address(number=number),
-                                      safety_number=safety_num, trust_level="TRUSTED_VERIFIED")
+        await evt.bridge.signal.trust(
+            evt.sender.username,
+            Address(number=number),
+            safety_number=safety_num,
+            trust_level="TRUSTED_VERIFIED",
+        )
     except UnknownIdentityKey as e:
     except UnknownIdentityKey as e:
         return await evt.reply(f"Failed to mark {number} as trusted: {e}")
         return await evt.reply(f"Failed to mark {number} as trusted: {e}")
     return await evt.reply(f"Successfully marked {number} as trusted")
     return await evt.reply(f"Successfully marked {number} as trusted")
 
 
 
 
-@command_handler(needs_admin=False, needs_auth=True, help_section=SECTION_SIGNAL,
-                 help_text="Sync data from Signal")
+@command_handler(
+    needs_admin=False,
+    needs_auth=True,
+    help_section=SECTION_SIGNAL,
+    help_text="Sync data from Signal",
+)
 async def sync(evt: CommandEvent) -> None:
 async def sync(evt: CommandEvent) -> None:
     await evt.sender.sync()
     await evt.sender.sync()
     await evt.reply("Sync complete")
     await evt.reply("Sync complete")
 
 
 
 
-@command_handler(needs_admin=True, needs_auth=False, help_section=SECTION_ADMIN,
-                 help_text="Send raw requests to signald",
-                 help_args="[--user] <type> <_json_>")
+@command_handler(
+    needs_admin=True,
+    needs_auth=False,
+    help_section=SECTION_ADMIN,
+    help_text="Send raw requests to signald",
+    help_args="[--user] <type> <_json_>",
+)
 async def raw(evt: CommandEvent) -> None:
 async def raw(evt: CommandEvent) -> None:
     add_username = False
     add_username = False
     while True:
     while True:
@@ -200,5 +252,6 @@ async def raw(evt: CommandEvent) -> None:
         if resp_data is None:
         if resp_data is None:
             await evt.reply(f"Got reply `{resp_type}` with no content")
             await evt.reply(f"Got reply `{resp_type}` with no content")
         else:
         else:
-            await evt.reply(f"Got reply `{resp_type}`:\n\n"
-                            f"```json\n{json.dumps(resp_data, indent=2)}\n```")
+            await evt.reply(
+                f"Got reply `{resp_type}`:\n\n" f"```json\n{json.dumps(resp_data, indent=2)}\n```"
+            )

+ 4 - 4
mautrix_signal/commands/typehint.py

@@ -4,11 +4,11 @@ from mautrix.bridge.commands import CommandEvent as BaseCommandEvent
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from ..__main__ import SignalBridge
     from ..__main__ import SignalBridge
-    from ..user import User
     from ..portal import Portal
     from ..portal import Portal
+    from ..user import User
 
 
 
 
 class CommandEvent(BaseCommandEvent):
 class CommandEvent(BaseCommandEvent):
-    bridge: 'SignalBridge'
-    sender: 'User'
-    portal: 'Portal'
+    bridge: "SignalBridge"
+    sender: "User"
+    portal: "Portal"

+ 3 - 3
mautrix_signal/config.py

@@ -16,10 +16,10 @@
 from typing import Any, List, NamedTuple
 from typing import Any, List, NamedTuple
 import os
 import os
 
 
-from mautrix.types import UserID
-from mautrix.client import Client
-from mautrix.util.config import ConfigUpdateHelper, ForbiddenKey, ForbiddenDefault
 from mautrix.bridge.config import BaseBridgeConfig
 from mautrix.bridge.config import BaseBridgeConfig
+from mautrix.client import Client
+from mautrix.types import UserID
+from mautrix.util.config import ConfigUpdateHelper, ForbiddenDefault, ForbiddenKey
 
 
 Permissions = NamedTuple("Permissions", relay=bool, user=bool, admin=bool, level=str)
 Permissions = NamedTuple("Permissions", relay=bool, user=bool, admin=bool, level=str)
 
 

+ 19 - 8
mautrix_signal/db/__init__.py

@@ -1,14 +1,15 @@
-from mautrix.util.async_db import Database
 import sqlite3
 import sqlite3
 import uuid
 import uuid
 
 
-from .upgrade import upgrade_table
+from mautrix.util.async_db import Database
+
 from .disappearing_message import DisappearingMessage
 from .disappearing_message import DisappearingMessage
-from .user import User
-from .puppet import Puppet
-from .portal import Portal
 from .message import Message
 from .message import Message
+from .portal import Portal
+from .puppet import Puppet
 from .reaction import Reaction
 from .reaction import Reaction
+from .upgrade import upgrade_table
+from .user import User
 
 
 
 
 def init(db: Database) -> None:
 def init(db: Database) -> None:
@@ -18,7 +19,17 @@ def init(db: Database) -> None:
 
 
 # TODO should this be in mautrix-python?
 # TODO should this be in mautrix-python?
 sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
 sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
-sqlite3.register_converter("UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b))
+sqlite3.register_converter(
+    "UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b)
+)
 
 
-__all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction",
-           "DisappearingMessage"]
+__all__ = [
+    "upgrade_table",
+    "init",
+    "User",
+    "Puppet",
+    "Portal",
+    "Message",
+    "Reaction",
+    "DisappearingMessage",
+]

+ 9 - 8
mautrix_signal/db/disappearing_message.py

@@ -13,13 +13,12 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import ClassVar, List, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, ClassVar, List, Optional
 
 
 from attr import dataclass
 from attr import dataclass
-import asyncpg
-
-from mautrix.types import RoomID, EventID
+from mautrix.types import EventID, RoomID
 from mautrix.util.async_db import Database
 from mautrix.util.async_db import Database
+import asyncpg
 
 
 fake_db = Database.create("") if TYPE_CHECKING else None
 fake_db = Database.create("") if TYPE_CHECKING else None
 
 
@@ -38,8 +37,9 @@ class DisappearingMessage:
         INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts)
         INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts)
         VALUES ($1, $2, $3, $4)
         VALUES ($1, $2, $3, $4)
         """
         """
-        await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
-                              self.expiration_ts)
+        await self.db.execute(
+            q, self.room_id, self.mxid, self.expiration_seconds, self.expiration_ts
+        )
 
 
     async def update(self) -> None:
     async def update(self) -> None:
         q = """
         q = """
@@ -48,8 +48,9 @@ class DisappearingMessage:
         WHERE room_id=$1 AND mxid=$2
         WHERE room_id=$1 AND mxid=$2
         """
         """
         try:
         try:
-            await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
-                              self.expiration_ts)
+            await self.db.execute(
+                q, self.room_id, self.mxid, self.expiration_seconds, self.expiration_ts
+            )
         except Exception as e:
         except Exception as e:
             print(e)
             print(e)
 
 

+ 68 - 32
mautrix_signal/db/message.py

@@ -13,15 +13,15 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Optional, ClassVar, Union, List, TYPE_CHECKING
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
 from uuid import UUID
 from uuid import UUID
 
 
 from attr import dataclass
 from attr import dataclass
+from mautrix.types import EventID, RoomID
+from mautrix.util.async_db import Database
 import asyncpg
 import asyncpg
 
 
 from mausignald.types import Address, GroupID
 from mausignald.types import Address, GroupID
-from mautrix.types import RoomID, EventID
-from mautrix.util.async_db import Database
 
 
 from ..util import id_to_str
 from ..util import id_to_str
 
 
@@ -40,23 +40,39 @@ class Message:
     signal_receiver: str
     signal_receiver: str
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ("INSERT INTO message (mxid, mx_room, sender, timestamp, signal_chat_id,"
-             "                     signal_receiver) VALUES ($1, $2, $3, $4, $5, $6)")
-        await self.db.execute(q, self.mxid, self.mx_room, self.sender.best_identifier,
-                              self.timestamp, id_to_str(self.signal_chat_id), self.signal_receiver)
+        q = """
+        INSERT INTO message (mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver)
+        VALUES ($1, $2, $3, $4, $5, $6)
+        """
+        await self.db.execute(
+            q,
+            self.mxid,
+            self.mx_room,
+            self.sender.best_identifier,
+            self.timestamp,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+        )
 
 
     async def delete(self) -> None:
     async def delete(self) -> None:
-        q = ("DELETE FROM message WHERE sender=$1 AND timestamp=$2"
-             "                          AND signal_chat_id=$3 AND signal_receiver=$4")
-        await self.db.execute(q, self.sender.best_identifier, self.timestamp,
-                              id_to_str(self.signal_chat_id), self.signal_receiver)
+        q = """
+        DELETE FROM message
+         WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
+        """
+        await self.db.execute(
+            q,
+            self.sender.best_identifier,
+            self.timestamp,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+        )
 
 
     @classmethod
     @classmethod
     async def delete_all(cls, room_id: RoomID) -> None:
     async def delete_all(cls, room_id: RoomID) -> None:
         await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)
         await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)
 
 
     @classmethod
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Message':
+    def _from_row(cls, row: asyncpg.Record) -> "Message":
         data = {**row}
         data = {**row}
         chat_id = data.pop("signal_chat_id")
         chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
         if data["signal_receiver"]:
@@ -65,44 +81,64 @@ class Message:
         return cls(signal_chat_id=chat_id, sender=sender, **data)
         return cls(signal_chat_id=chat_id, sender=sender, **data)
 
 
     @classmethod
     @classmethod
-    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Message']:
-        q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-             "FROM message WHERE mxid=$1 AND mx_room=$2")
+    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional["Message"]:
+        q = """
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+          FROM message WHERE mxid=$1 AND mx_room=$2
+        """
         row = await cls.db.fetchrow(q, mxid, mx_room)
         row = await cls.db.fetchrow(q, mxid, mx_room)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def get_by_signal_id(cls, sender: Address, timestamp: int,
-                               signal_chat_id: Union[GroupID, Address], signal_receiver: str = ""
-                               ) -> Optional['Message']:
-        q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-             "FROM message WHERE sender=$1 AND timestamp=$2"
-             "                   AND signal_chat_id=$3 AND signal_receiver=$4")
-        row = await cls.db.fetchrow(q, sender.best_identifier, timestamp,
-                                    id_to_str(signal_chat_id), signal_receiver)
+    async def get_by_signal_id(
+        cls,
+        sender: Address,
+        timestamp: int,
+        signal_chat_id: Union[GroupID, Address],
+        signal_receiver: str = "",
+    ) -> Optional["Message"]:
+        q = """
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+          FROM message
+         WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
+        """
+        row = await cls.db.fetchrow(
+            q, sender.best_identifier, timestamp, id_to_str(signal_chat_id), signal_receiver
+        )
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def find_by_timestamps(cls, timestamps: List[int]) -> List['Message']:
+    async def find_by_timestamps(cls, timestamps: List[int]) -> List["Message"]:
         if cls.db.scheme == "postgres":
         if cls.db.scheme == "postgres":
-            q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-                 "FROM message WHERE timestamp=ANY($1)")
+            q = """
+            SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+              FROM message
+             WHERE timestamp=ANY($1)
+            """
             rows = await cls.db.fetch(q, timestamps)
             rows = await cls.db.fetch(q, timestamps)
         else:
         else:
-            placeholders = ", ".join(f"?" for _ in range(len(timestamps)))
-            q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-                 f"FROM message WHERE timestamp IN ({placeholders})")
+            placeholders = ", ".join("?" for _ in range(len(timestamps)))
+            q = f"""
+            SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+              FROM message
+             WHERE timestamp IN ({placeholders})
+            """
             rows = await cls.db.fetch(q, *timestamps)
             rows = await cls.db.fetch(q, *timestamps)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 
     @classmethod
     @classmethod
-    async def find_by_sender_timestamp(cls, sender: Address, timestamp: int) -> Optional['Message']:
-        q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-             "FROM message WHERE sender=$1 AND timestamp=$2")
+    async def find_by_sender_timestamp(
+        cls, sender: Address, timestamp: int
+    ) -> Optional["Message"]:
+        q = """
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+          FROM message
+         WHERE sender=$1 AND timestamp=$2
+        """
         row = await cls.db.fetchrow(q, sender.best_identifier, timestamp)
         row = await cls.db.fetchrow(q, sender.best_identifier, timestamp)
         if not row:
         if not row:
             return None
             return None

+ 82 - 41
mautrix_signal/db/portal.py

@@ -13,14 +13,14 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Optional, ClassVar, List, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
 
 
 from attr import dataclass
 from attr import dataclass
+from mautrix.types import ContentURI, RoomID, UserID
+from mautrix.util.async_db import Database
 import asyncpg
 import asyncpg
 
 
 from mausignald.types import Address, GroupID
 from mausignald.types import Address, GroupID
-from mautrix.types import RoomID, ContentURI, UserID
-from mautrix.util.async_db import Database
 
 
 from ..util import id_to_str
 from ..util import id_to_str
 
 
@@ -49,27 +49,52 @@ class Portal:
         return id_to_str(self.chat_id)
         return id_to_str(self.chat_id)
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
-             "                    name_set, avatar_set, revision, encrypted, relay_user_id, "
-             "                    expiration_time) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)")
-        await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
-                              self.avatar_hash, self.avatar_url, self.name_set, self.avatar_set,
-                              self.revision, self.encrypted, self.relay_user_id,
-                              self.expiration_time)
+        q = """
+        INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set,
+                            avatar_set, revision, encrypted, relay_user_id, expiration_time)
+        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+        """
+        await self.db.execute(
+            q,
+            self.chat_id_str,
+            self.receiver,
+            self.mxid,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.revision,
+            self.encrypted,
+            self.relay_user_id,
+            self.expiration_time,
+        )
 
 
     async def update(self) -> None:
     async def update(self) -> None:
-        q = ("UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5, "
-             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9, "
-             "                  expiration_time=$10"
-             "WHERE chat_id=$11 AND receiver=$12")
-        await self.db.execute(q, self.mxid, self.name, self.avatar_hash, self.avatar_url,
-                              self.name_set, self.avatar_set, self.revision, self.encrypted,
-                              self.relay_user_id, self.expiration_time, self.chat_id_str,
-                              self.receiver)
+        q = """
+        UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5,
+                          avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9,
+                          expiration_time=$10
+        WHERE chat_id=$11 AND receiver=$12
+        """
+        await self.db.execute(
+            q,
+            self.mxid,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.revision,
+            self.encrypted,
+            self.relay_user_id,
+            self.expiration_time,
+            self.chat_id_str,
+            self.receiver,
+        )
 
 
     @classmethod
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Portal':
+    def _from_row(cls, row: asyncpg.Record) -> "Portal":
         data = {**row}
         data = {**row}
         chat_id = data.pop("chat_id")
         chat_id = data.pop("chat_id")
         if data["receiver"]:
         if data["receiver"]:
@@ -77,46 +102,62 @@ class Portal:
         return cls(chat_id=chat_id, **data)
         return cls(chat_id=chat_id, **data)
 
 
     @classmethod
     @classmethod
-    async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE mxid=$1")
+    async def get_by_mxid(cls, mxid: RoomID) -> Optional["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE mxid=$1
+        """
         row = await cls.db.fetchrow(q, mxid)
         row = await cls.db.fetchrow(q, mxid)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
-                             ) -> Optional['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE chat_id=$1 AND receiver=$2")
+    async def get_by_chat_id(
+        cls, chat_id: Union[GroupID, Address], receiver: str = ""
+    ) -> Optional["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE chat_id=$1 AND receiver=$2
+        """
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE receiver=$1")
+    async def find_private_chats_of(cls, receiver: str) -> List["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE receiver=$1
+        """
         rows = await cls.db.fetch(q, receiver)
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 
     @classmethod
     @classmethod
-    async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE chat_id=$1 AND receiver<>''")
+    async def find_private_chats_with(cls, other_user: Address) -> List["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE chat_id=$1 AND receiver<>''
+        """
         rows = await cls.db.fetch(q, other_user.best_identifier)
         rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 
     @classmethod
     @classmethod
-    async def all_with_room(cls) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE mxid IS NOT NULL")
+    async def all_with_room(cls) -> List["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE mxid IS NOT NULL
+        """
         rows = await cls.db.fetch(q)
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]

+ 71 - 36
mautrix_signal/db/puppet.py

@@ -13,16 +13,16 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Optional, ClassVar, List, TYPE_CHECKING
+from typing import TYPE_CHECKING, ClassVar, List, Optional
 from uuid import UUID
 from uuid import UUID
 
 
 from attr import dataclass
 from attr import dataclass
+from mautrix.types import ContentURI, SyncToken, UserID
+from mautrix.util.async_db import Database
 from yarl import URL
 from yarl import URL
 import asyncpg
 import asyncpg
 
 
 from mausignald.types import Address
 from mausignald.types import Address
-from mautrix.types import UserID, SyncToken, ContentURI
-from mautrix.util.async_db import Database
 
 
 fake_db = Database.create("") if TYPE_CHECKING else None
 fake_db = Database.create("") if TYPE_CHECKING else None
 
 
@@ -52,36 +52,54 @@ class Puppet:
         return str(self.base_url) if self.base_url else None
         return str(self.base_url) if self.base_url else None
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ("INSERT INTO puppet (uuid, number, name, avatar_hash, avatar_url, name_set, "
-             "                    avatar_set, uuid_registered, number_registered, "
-             "                    custom_mxid, access_token, next_batch, base_url) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)")
-        await self.db.execute(q, self.uuid, self.number, self.name, self.avatar_hash,
-                              self.avatar_url, self.name_set, self.avatar_set,
-                              self.uuid_registered, self.number_registered, self.custom_mxid,
-                              self.access_token, self.next_batch, self._base_url_str)
+        q = (
+            "INSERT INTO puppet (uuid, number, name, avatar_hash, avatar_url, name_set, "
+            "                    avatar_set, uuid_registered, number_registered, "
+            "                    custom_mxid, access_token, next_batch, base_url) "
+            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)"
+        )
+        await self.db.execute(
+            q,
+            self.uuid,
+            self.number,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.uuid_registered,
+            self.number_registered,
+            self.custom_mxid,
+            self.access_token,
+            self.next_batch,
+            self._base_url_str,
+        )
 
 
     async def _set_uuid(self, uuid: UUID) -> None:
     async def _set_uuid(self, uuid: UUID) -> None:
         async with self.db.acquire() as conn, conn.transaction():
         async with self.db.acquire() as conn, conn.transaction():
-            await conn.execute("DELETE FROM puppet WHERE uuid=$1 AND number<>$2",
-                               uuid, self.number)
+            await conn.execute(
+                "DELETE FROM puppet WHERE uuid=$1 AND number<>$2", uuid, self.number
+            )
             await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
             await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
             await self._update_number_to_uuid(conn, self.number, str(uuid))
             await self._update_number_to_uuid(conn, self.number, str(uuid))
 
 
     async def _set_number(self, number: str) -> None:
     async def _set_number(self, number: str) -> None:
         async with self.db.acquire() as conn, conn.transaction():
         async with self.db.acquire() as conn, conn.transaction():
-            await conn.execute("DELETE FROM puppet WHERE number=$1 AND uuid<>$2",
-                               number, self.uuid)
+            await conn.execute(
+                "DELETE FROM puppet WHERE number=$1 AND uuid<>$2", number, self.uuid
+            )
             await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", number, self.uuid)
             await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", number, self.uuid)
             await self._update_number_to_uuid(conn, number, str(self.uuid))
             await self._update_number_to_uuid(conn, number, str(self.uuid))
 
 
     @staticmethod
     @staticmethod
-    async def _update_number_to_uuid(conn: asyncpg.Connection, old_number: str, new_uuid: str
-                                     ) -> None:
+    async def _update_number_to_uuid(
+        conn: asyncpg.Connection, old_number: str, new_uuid: str
+    ) -> None:
         try:
         try:
             async with conn.transaction():
             async with conn.transaction():
-                await conn.execute("UPDATE portal SET chat_id=$1 WHERE chat_id=$2",
-                                   new_uuid, old_number)
+                await conn.execute(
+                    "UPDATE portal SET chat_id=$1 WHERE chat_id=$2", new_uuid, old_number
+                )
         except asyncpg.UniqueViolationError:
         except asyncpg.UniqueViolationError:
             await conn.execute("DELETE FROM portal WHERE chat_id=$1", old_number)
             await conn.execute("DELETE FROM portal WHERE chat_id=$1", old_number)
         await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2", new_uuid, old_number)
         await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2", new_uuid, old_number)
@@ -93,32 +111,49 @@ class Puppet:
             "uuid_registered=$8, number_registered=$9, "
             "uuid_registered=$8, number_registered=$9, "
             "custom_mxid=$10, access_token=$11, next_batch=$12, base_url=$13"
             "custom_mxid=$10, access_token=$11, next_batch=$12, base_url=$13"
         )
         )
-        q = (f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
-             if self.uuid is None
-             else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1")
-        await self.db.execute(q,self.uuid, self.number, self.name, self.avatar_hash,
-                              self.avatar_url, self.name_set, self.avatar_set,
-                              self.uuid_registered, self.number_registered, self.custom_mxid,
-                              self.access_token, self.next_batch, self._base_url_str)
+        q = (
+            f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
+            if self.uuid is None
+            else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1"
+        )
+        await self.db.execute(
+            q,
+            self.uuid,
+            self.number,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.uuid_registered,
+            self.number_registered,
+            self.custom_mxid,
+            self.access_token,
+            self.next_batch,
+            self._base_url_str,
+        )
 
 
     @classmethod
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Puppet':
+    def _from_row(cls, row: asyncpg.Record) -> "Puppet":
         data = {**row}
         data = {**row}
         base_url_str = data.pop("base_url")
         base_url_str = data.pop("base_url")
         base_url = URL(base_url_str) if base_url_str is not None else None
         base_url = URL(base_url_str) if base_url_str is not None else None
         return cls(base_url=base_url, **data)
         return cls(base_url=base_url, **data)
 
 
-    _select_base = ("SELECT uuid, number, name, avatar_hash, avatar_url, name_set, avatar_set, "
-                    "       uuid_registered, number_registered, custom_mxid, access_token, "
-                    "       next_batch, base_url "
-                    "FROM puppet")
+    _select_base = (
+        "SELECT uuid, number, name, avatar_hash, avatar_url, name_set, avatar_set, "
+        "       uuid_registered, number_registered, custom_mxid, access_token, "
+        "       next_batch, base_url "
+        "FROM puppet"
+    )
 
 
     @classmethod
     @classmethod
-    async def get_by_address(cls, address: Address) -> Optional['Puppet']:
+    async def get_by_address(cls, address: Address) -> Optional["Puppet"]:
         if address.uuid:
         if address.uuid:
             if address.number:
             if address.number:
-                row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1 OR number=$2",
-                                            address.uuid, address.number)
+                row = await cls.db.fetchrow(
+                    f"{cls._select_base} WHERE uuid=$1 OR number=$2", address.uuid, address.number
+                )
             else:
             else:
                 row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
                 row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
         elif address.number:
         elif address.number:
@@ -130,13 +165,13 @@ class Puppet:
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
+    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional["Puppet"]:
         row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
         row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def all_with_custom_mxid(cls) -> List['Puppet']:
+    async def all_with_custom_mxid(cls) -> List["Puppet"]:
         rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
         rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]

+ 73 - 34
mautrix_signal/db/reaction.py

@@ -13,15 +13,15 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Optional, ClassVar, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, ClassVar, Optional, Union
 from uuid import UUID
 from uuid import UUID
 
 
 from attr import dataclass
 from attr import dataclass
+from mautrix.types import EventID, RoomID
+from mautrix.util.async_db import Database
 import asyncpg
 import asyncpg
 
 
 from mausignald.types import Address, GroupID
 from mausignald.types import Address, GroupID
-from mautrix.types import RoomID, EventID
-from mautrix.util.async_db import Database
 
 
 from ..util import id_to_str
 from ..util import id_to_str
 
 
@@ -42,30 +42,54 @@ class Reaction:
     emoji: str
     emoji: str
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ("INSERT INTO reaction (mxid, mx_room, signal_chat_id, signal_receiver, msg_author,"
-             "                      msg_timestamp, author, emoji) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)")
-        await self.db.execute(q, self.mxid, self.mx_room, id_to_str(self.signal_chat_id),
-                              self.signal_receiver, self.msg_author.best_identifier,
-                              self.msg_timestamp, self.author.best_identifier, self.emoji)
+        q = (
+            "INSERT INTO reaction (mxid, mx_room, signal_chat_id, signal_receiver, msg_author,"
+            "                      msg_timestamp, author, emoji) "
+            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
+        )
+        await self.db.execute(
+            q,
+            self.mxid,
+            self.mx_room,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+            self.msg_author.best_identifier,
+            self.msg_timestamp,
+            self.author.best_identifier,
+            self.emoji,
+        )
 
 
     async def edit(self, mx_room: RoomID, mxid: EventID, emoji: str) -> None:
     async def edit(self, mx_room: RoomID, mxid: EventID, emoji: str) -> None:
-        await self.db.execute("UPDATE reaction SET mxid=$1, mx_room=$2, emoji=$3 "
-                              "WHERE signal_chat_id=$4 AND signal_receiver=$5"
-                              "      AND msg_author=$6 AND msg_timestamp=$7 AND author=$8",
-                              mxid, mx_room, emoji, id_to_str(self.signal_chat_id),
-                              self.signal_receiver, self.msg_author.best_identifier,
-                              self.msg_timestamp, self.author.best_identifier)
+        await self.db.execute(
+            "UPDATE reaction SET mxid=$1, mx_room=$2, emoji=$3 "
+            "WHERE signal_chat_id=$4 AND signal_receiver=$5"
+            "      AND msg_author=$6 AND msg_timestamp=$7 AND author=$8",
+            mxid,
+            mx_room,
+            emoji,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+            self.msg_author.best_identifier,
+            self.msg_timestamp,
+            self.author.best_identifier,
+        )
 
 
     async def delete(self) -> None:
     async def delete(self) -> None:
-        q = ("DELETE FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
-             "                           AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
-        await self.db.execute(q, id_to_str(self.signal_chat_id), self.signal_receiver,
-                              self.msg_author.best_identifier, self.msg_timestamp,
-                              self.author.best_identifier)
+        q = (
+            "DELETE FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
+            "                           AND msg_author=$3 AND msg_timestamp=$4 AND author=$5"
+        )
+        await self.db.execute(
+            q,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+            self.msg_author.best_identifier,
+            self.msg_timestamp,
+            self.author.best_identifier,
+        )
 
 
     @classmethod
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Reaction':
+    def _from_row(cls, row: asyncpg.Record) -> "Reaction":
         data = {**row}
         data = {**row}
         chat_id = data.pop("signal_chat_id")
         chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
         if data["signal_receiver"]:
@@ -75,25 +99,40 @@ class Reaction:
         return cls(signal_chat_id=chat_id, msg_author=msg_author, author=author, **data)
         return cls(signal_chat_id=chat_id, msg_author=msg_author, author=author, **data)
 
 
     @classmethod
     @classmethod
-    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Reaction']:
-        q = ("SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
-             "       msg_author, msg_timestamp, author, emoji "
-             "FROM reaction WHERE mxid=$1 AND mx_room=$2")
+    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional["Reaction"]:
+        q = (
+            "SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
+            "       msg_author, msg_timestamp, author, emoji "
+            "FROM reaction WHERE mxid=$1 AND mx_room=$2"
+        )
         row = await cls.db.fetchrow(q, mxid, mx_room)
         row = await cls.db.fetchrow(q, mxid, mx_room)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def get_by_signal_id(cls, chat_id: Union[GroupID, Address], receiver: str,
-                               msg_author: Address, msg_timestamp: int, author: Address
-                               ) -> Optional['Reaction']:
-        q = ("SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
-             "       msg_author, msg_timestamp, author, emoji "
-             "FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
-             "                    AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
-        row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver, msg_author.best_identifier,
-                                    msg_timestamp, author.best_identifier)
+    async def get_by_signal_id(
+        cls,
+        chat_id: Union[GroupID, Address],
+        receiver: str,
+        msg_author: Address,
+        msg_timestamp: int,
+        author: Address,
+    ) -> Optional["Reaction"]:
+        q = (
+            "SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
+            "       msg_author, msg_timestamp, author, emoji "
+            "FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
+            "                    AND msg_author=$3 AND msg_timestamp=$4 AND author=$5"
+        )
+        row = await cls.db.fetchrow(
+            q,
+            id_to_str(chat_id),
+            receiver,
+            msg_author.best_identifier,
+            msg_timestamp,
+            author.best_identifier,
+        )
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)

+ 148 - 109
mautrix_signal/db/upgrade.py

@@ -14,7 +14,6 @@
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from asyncpg import Connection
 from asyncpg import Connection
-
 from mautrix.util.async_db import UpgradeTable
 from mautrix.util.async_db import UpgradeTable
 
 
 upgrade_table = UpgradeTable()
 upgrade_table = UpgradeTable()
@@ -22,97 +21,64 @@ upgrade_table = UpgradeTable()
 
 
 @upgrade_table.register(description="Initial revision")
 @upgrade_table.register(description="Initial revision")
 async def upgrade_v1(conn: Connection) -> None:
 async def upgrade_v1(conn: Connection) -> None:
-    await conn.execute("""CREATE TABLE portal (
-        chat_id     TEXT,
-        receiver    TEXT,
-        mxid        TEXT,
-        name        TEXT,
-        encrypted   BOOLEAN NOT NULL DEFAULT false,
-
-        PRIMARY KEY (chat_id, receiver)
-    )""")
-    await conn.execute("""CREATE TABLE "user" (
-        mxid        TEXT PRIMARY KEY,
-        username    TEXT,
-        uuid        UUID,
-        notice_room TEXT
-    )""")
-    await conn.execute("""CREATE TABLE puppet (
-        uuid      UUID UNIQUE,
-        number    TEXT UNIQUE,
-        name      TEXT,
-
-        uuid_registered   BOOLEAN NOT NULL DEFAULT false,
-        number_registered BOOLEAN NOT NULL DEFAULT false,
-
-        custom_mxid  TEXT,
-        access_token TEXT,
-        next_batch   TEXT
-    )""")
-    await conn.execute("""CREATE TABLE user_portal (
-        "user"          TEXT,
-        portal          TEXT,
-        portal_receiver TEXT,
-        in_community    BOOLEAN NOT NULL DEFAULT false,
-
-        FOREIGN KEY (portal, portal_receiver) REFERENCES portal(chat_id, receiver)
-            ON UPDATE CASCADE ON DELETE CASCADE
-    )""")
-    await conn.execute("""CREATE TABLE message (
-        mxid    TEXT NOT NULL,
-        mx_room TEXT NOT NULL,
-        sender          UUID,
-        timestamp       BIGINT,
-        signal_chat_id  TEXT,
-        signal_receiver TEXT,
-
-        PRIMARY KEY (sender, timestamp, signal_chat_id, signal_receiver),
-        FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
-            ON UPDATE CASCADE ON DELETE CASCADE,
-        UNIQUE (mxid, mx_room)
-    )""")
-    await conn.execute("""CREATE TABLE reaction (
-        mxid    TEXT NOT NULL,
-        mx_room TEXT NOT NULL,
-
-        signal_chat_id  TEXT   NOT NULL,
-        signal_receiver TEXT   NOT NULL,
-        msg_author      UUID   NOT NULL,
-        msg_timestamp   BIGINT NOT NULL,
-        author          UUID   NOT NULL,
-
-        emoji TEXT NOT NULL,
-
-        PRIMARY KEY (signal_chat_id, signal_receiver, msg_author, msg_timestamp, author),
-        FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver)
-            REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
-            ON DELETE CASCADE ON UPDATE CASCADE,
-        UNIQUE (mxid, mx_room)
-    )""")
-
-
-@upgrade_table.register(description="Add avatar info to portal table")
-async def upgrade_v2(conn: Connection) -> None:
-    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_hash TEXT")
-    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_url TEXT")
-
-
-@upgrade_table.register(description="Add double-puppeting base_url to puppe table")
-async def upgrade_v3(conn: Connection) -> None:
-    await conn.execute("ALTER TABLE puppet ADD COLUMN base_url TEXT")
-
-
-@upgrade_table.register(description="Allow phone numbers as message sender identifiers")
-async def upgrade_v4(conn: Connection, scheme: str) -> None:
-    if scheme == "sqlite":
-        # SQLite doesn't have anything in the tables yet,
-        # so just recreate them without migrating data
-        await conn.execute("DROP TABLE message")
-        await conn.execute("DROP TABLE reaction")
-        await conn.execute("""CREATE TABLE message (
+    await conn.execute(
+        """
+        CREATE TABLE portal (
+            chat_id     TEXT,
+            receiver    TEXT,
+            mxid        TEXT,
+            name        TEXT,
+            encrypted   BOOLEAN NOT NULL DEFAULT false,
+
+            PRIMARY KEY (chat_id, receiver)
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE "user" (
+            mxid        TEXT PRIMARY KEY,
+            username    TEXT,
+            uuid        UUID,
+            notice_room TEXT
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE puppet (
+            uuid      UUID UNIQUE,
+            number    TEXT UNIQUE,
+            name      TEXT,
+
+            uuid_registered   BOOLEAN NOT NULL DEFAULT false,
+            number_registered BOOLEAN NOT NULL DEFAULT false,
+
+            custom_mxid  TEXT,
+            access_token TEXT,
+            next_batch   TEXT
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE user_portal (
+            "user"          TEXT,
+            portal          TEXT,
+            portal_receiver TEXT,
+            in_community    BOOLEAN NOT NULL DEFAULT false,
+
+            FOREIGN KEY (portal, portal_receiver) REFERENCES portal(chat_id, receiver)
+                ON UPDATE CASCADE ON DELETE CASCADE
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE message (
             mxid    TEXT NOT NULL,
             mxid    TEXT NOT NULL,
             mx_room TEXT NOT NULL,
             mx_room TEXT NOT NULL,
-            sender          TEXT,
+            sender          UUID,
             timestamp       BIGINT,
             timestamp       BIGINT,
             signal_chat_id  TEXT,
             signal_chat_id  TEXT,
             signal_receiver TEXT,
             signal_receiver TEXT,
@@ -121,16 +87,20 @@ async def upgrade_v4(conn: Connection, scheme: str) -> None:
             FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
             FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
                 ON UPDATE CASCADE ON DELETE CASCADE,
                 ON UPDATE CASCADE ON DELETE CASCADE,
             UNIQUE (mxid, mx_room)
             UNIQUE (mxid, mx_room)
-        )""")
-        await conn.execute("""CREATE TABLE reaction (
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE reaction (
             mxid    TEXT NOT NULL,
             mxid    TEXT NOT NULL,
             mx_room TEXT NOT NULL,
             mx_room TEXT NOT NULL,
 
 
             signal_chat_id  TEXT   NOT NULL,
             signal_chat_id  TEXT   NOT NULL,
             signal_receiver TEXT   NOT NULL,
             signal_receiver TEXT   NOT NULL,
-            msg_author      TEXT   NOT NULL,
+            msg_author      UUID   NOT NULL,
             msg_timestamp   BIGINT NOT NULL,
             msg_timestamp   BIGINT NOT NULL,
-            author          TEXT   NOT NULL,
+            author          UUID   NOT NULL,
 
 
             emoji TEXT NOT NULL,
             emoji TEXT NOT NULL,
 
 
@@ -139,19 +109,84 @@ async def upgrade_v4(conn: Connection, scheme: str) -> None:
                 REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
                 REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
                 ON DELETE CASCADE ON UPDATE CASCADE,
                 ON DELETE CASCADE ON UPDATE CASCADE,
             UNIQUE (mxid, mx_room)
             UNIQUE (mxid, mx_room)
-        )""")
+        )
+        """
+    )
+
+
+@upgrade_table.register(description="Add avatar info to portal table")
+async def upgrade_v2(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_hash TEXT")
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_url TEXT")
+
+
+@upgrade_table.register(description="Add double-puppeting base_url to puppe table")
+async def upgrade_v3(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE puppet ADD COLUMN base_url TEXT")
+
+
+@upgrade_table.register(description="Allow phone numbers as message sender identifiers")
+async def upgrade_v4(conn: Connection, scheme: str) -> None:
+    if scheme == "sqlite":
+        # SQLite doesn't have anything in the tables yet,
+        # so just recreate them without migrating data
+        await conn.execute("DROP TABLE message")
+        await conn.execute("DROP TABLE reaction")
+        await conn.execute(
+            """
+            CREATE TABLE message (
+                mxid    TEXT NOT NULL,
+                mx_room TEXT NOT NULL,
+                sender          TEXT,
+                timestamp       BIGINT,
+                signal_chat_id  TEXT,
+                signal_receiver TEXT,
+
+                PRIMARY KEY (sender, timestamp, signal_chat_id, signal_receiver),
+                FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
+                    ON UPDATE CASCADE ON DELETE CASCADE,
+                UNIQUE (mxid, mx_room)
+            )
+            """
+        )
+        await conn.execute(
+            """
+            CREATE TABLE reaction (
+                mxid    TEXT NOT NULL,
+                mx_room TEXT NOT NULL,
+
+                signal_chat_id  TEXT   NOT NULL,
+                signal_receiver TEXT   NOT NULL,
+                msg_author      TEXT   NOT NULL,
+                msg_timestamp   BIGINT NOT NULL,
+                author          TEXT   NOT NULL,
+
+                emoji TEXT NOT NULL,
+
+                PRIMARY KEY (signal_chat_id, signal_receiver, msg_author, msg_timestamp, author),
+                FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver)
+                    REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
+                    ON DELETE CASCADE ON UPDATE CASCADE,
+                UNIQUE (mxid, mx_room)
+            )
+            """
+        )
         return
         return
 
 
-    cname = await conn.fetchval("SELECT constraint_name FROM information_schema.table_constraints "
-                                "WHERE table_name='reaction' AND constraint_name LIKE '%_fkey'")
+    cname = await conn.fetchval(
+        "SELECT constraint_name FROM information_schema.table_constraints "
+        "WHERE table_name='reaction' AND constraint_name LIKE '%_fkey'"
+    )
     await conn.execute(f"ALTER TABLE reaction DROP CONSTRAINT {cname}")
     await conn.execute(f"ALTER TABLE reaction DROP CONSTRAINT {cname}")
     await conn.execute("ALTER TABLE reaction ALTER COLUMN msg_author SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE reaction ALTER COLUMN msg_author SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE reaction ALTER COLUMN author SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE reaction ALTER COLUMN author SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE message ALTER COLUMN sender SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE message ALTER COLUMN sender SET DATA TYPE TEXT")
-    await conn.execute(f"ALTER TABLE reaction ADD CONSTRAINT {cname} "
-                       "FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) "
-                       "  REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) "
-                       "  ON DELETE CASCADE ON UPDATE CASCADE")
+    await conn.execute(
+        f"ALTER TABLE reaction ADD CONSTRAINT {cname} "
+        "FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) "
+        "  REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) "
+        "  ON DELETE CASCADE ON UPDATE CASCADE"
+    )
 
 
 
 
 @upgrade_table.register(description="Add avatar info to puppet table")
 @upgrade_table.register(description="Add avatar info to puppet table")
@@ -179,12 +214,16 @@ async def upgrade_v7(conn: Connection) -> None:
 
 
 @upgrade_table.register(description="Add support for disappearing messages")
 @upgrade_table.register(description="Add support for disappearing messages")
 async def upgrade_v8(conn: Connection) -> None:
 async def upgrade_v8(conn: Connection) -> None:
-    await conn.execute("""CREATE TABLE disappearing_message (
-        room_id             TEXT,
-        mxid                TEXT,
-        expiration_seconds  BIGINT,
-        expiration_ts       BIGINT,
-
-        PRIMARY KEY (room_id, mxid)
-    )""")
+    await conn.execute(
+        """
+        CREATE TABLE disappearing_message (
+            room_id             TEXT,
+            mxid                TEXT,
+            expiration_seconds  BIGINT,
+            expiration_ts       BIGINT,
+
+            PRIMARY KEY (room_id, mxid)
+        )
+        """
+    )
     await conn.execute("ALTER TABLE portal ADD COLUMN expiration_time BIGINT")
     await conn.execute("ALTER TABLE portal ADD COLUMN expiration_time BIGINT")

+ 7 - 9
mautrix_signal/db/user.py

@@ -13,12 +13,11 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Optional, ClassVar, List, TYPE_CHECKING
+from typing import TYPE_CHECKING, ClassVar, List, Optional
 from uuid import UUID
 from uuid import UUID
 
 
 from attr import dataclass
 from attr import dataclass
-
-from mautrix.types import UserID, RoomID
+from mautrix.types import RoomID, UserID
 from mautrix.util.async_db import Database
 from mautrix.util.async_db import Database
 
 
 fake_db = Database.create("") if TYPE_CHECKING else None
 fake_db = Database.create("") if TYPE_CHECKING else None
@@ -34,8 +33,7 @@ class User:
     notice_room: Optional[RoomID]
     notice_room: Optional[RoomID]
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ('INSERT INTO "user" (mxid, username, uuid, notice_room) '
-             'VALUES ($1, $2, $3, $4)')
+        q = 'INSERT INTO "user" (mxid, username, uuid, notice_room) ' "VALUES ($1, $2, $3, $4)"
         await self.db.execute(q, self.mxid, self.username, self.uuid, self.notice_room)
         await self.db.execute(q, self.mxid, self.username, self.uuid, self.notice_room)
 
 
     async def update(self) -> None:
     async def update(self) -> None:
@@ -43,7 +41,7 @@ class User:
         await self.db.execute(q, self.username, self.uuid, self.notice_room, self.mxid)
         await self.db.execute(q, self.username, self.uuid, self.notice_room, self.mxid)
 
 
     @classmethod
     @classmethod
-    async def get_by_mxid(cls, mxid: UserID) -> Optional['User']:
+    async def get_by_mxid(cls, mxid: UserID) -> Optional["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE mxid=$1'
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE mxid=$1'
         row = await cls.db.fetchrow(q, mxid)
         row = await cls.db.fetchrow(q, mxid)
         if not row:
         if not row:
@@ -51,7 +49,7 @@ class User:
         return cls(**row)
         return cls(**row)
 
 
     @classmethod
     @classmethod
-    async def get_by_username(cls, username: str) -> Optional['User']:
+    async def get_by_username(cls, username: str) -> Optional["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username=$1'
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username=$1'
         row = await cls.db.fetchrow(q, username)
         row = await cls.db.fetchrow(q, username)
         if not row:
         if not row:
@@ -59,7 +57,7 @@ class User:
         return cls(**row)
         return cls(**row)
 
 
     @classmethod
     @classmethod
-    async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
+    async def get_by_uuid(cls, uuid: UUID) -> Optional["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE uuid=$1'
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE uuid=$1'
         row = await cls.db.fetchrow(q, uuid)
         row = await cls.db.fetchrow(q, uuid)
         if not row:
         if not row:
@@ -67,7 +65,7 @@ class User:
         return cls(**row)
         return cls(**row)
 
 
     @classmethod
     @classmethod
-    async def all_logged_in(cls) -> List['User']:
+    async def all_logged_in(cls) -> List["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username IS NOT NULL'
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username IS NOT NULL'
         rows = await cls.db.fetch(q)
         rows = await cls.db.fetch(q)
         return [cls(**row) for row in rows]
         return [cls(**row) for row in rows]

+ 26 - 15
mautrix_signal/formatter.py

@@ -13,30 +13,35 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Tuple, List, cast
+from typing import List, Tuple, cast
 from html import escape
 from html import escape
 import struct
 import struct
 
 
-from mausignald.types import MessageData, Address, Mention
-from mautrix.types import TextMessageEventContent, MessageType, Format
-from mautrix.util.formatter import (MatrixParser as BaseMatrixParser, EntityString, SimpleEntity,
-                                    EntityType, MarkdownString)
+from mautrix.types import Format, MessageType, TextMessageEventContent
+from mautrix.util.formatter import EntityString, EntityType, MarkdownString
+from mautrix.util.formatter import MatrixParser as BaseMatrixParser
+from mautrix.util.formatter import SimpleEntity
 
 
-from . import puppet as pu, user as u
+from mausignald.types import Address, Mention, MessageData
+
+from . import puppet as pu
+from . import user as u
 
 
 
 
 # Helper methods from rom https://github.com/LonamiWebs/Telethon/blob/master/telethon/helpers.py
 # Helper methods from rom https://github.com/LonamiWebs/Telethon/blob/master/telethon/helpers.py
 # I don't know if this is how Signal actually calculates lengths, but it seems
 # I don't know if this is how Signal actually calculates lengths, but it seems
 # to work better than plain len()
 # to work better than plain len()
 def add_surrogate(text: str) -> str:
 def add_surrogate(text: str) -> str:
-    return ''.join(
-        ''.join(chr(y) for y in struct.unpack('<HH', x.encode('utf-16le')))
-        if (0x10000 <= ord(x) <= 0x10FFFF) else x for x in text
+    return "".join(
+        "".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
+        if (0x10000 <= ord(x) <= 0x10FFFF)
+        else x
+        for x in text
     )
     )
 
 
 
 
 def del_surrogate(text: str) -> str:
 def del_surrogate(text: str) -> str:
-    return text.encode('utf-16', 'surrogatepass').decode('utf-16')
+    return text.encode("utf-16", "surrogatepass").decode("utf-16")
 
 
 
 
 async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
 async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
@@ -47,7 +52,7 @@ async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
         html_chunks = []
         html_chunks = []
         last_offset = 0
         last_offset = 0
         for mention in message.mentions:
         for mention in message.mentions:
-            before = surrogated_text[last_offset:mention.start]
+            before = surrogated_text[last_offset : mention.start]
             last_offset = mention.start + mention.length
             last_offset = mention.start + mention.length
 
 
             text_chunks.append(before)
             text_chunks.append(before)
@@ -67,11 +72,17 @@ async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
 
 
 # TODO this has a lot of duplication with mautrix-facebook, maybe move to mautrix-python
 # TODO this has a lot of duplication with mautrix-facebook, maybe move to mautrix-python
 class SignalFormatString(EntityString[SimpleEntity, EntityType], MarkdownString):
 class SignalFormatString(EntityString[SimpleEntity, EntityType], MarkdownString):
-    def format(self, entity_type: EntityType, **kwargs) -> 'SignalFormatString':
+    def format(self, entity_type: EntityType, **kwargs) -> "SignalFormatString":
         prefix = suffix = ""
         prefix = suffix = ""
         if entity_type == EntityType.USER_MENTION:
         if entity_type == EntityType.USER_MENTION:
-            self.entities.append(SimpleEntity(type=entity_type, offset=0, length=len(self.text),
-                                              extra_info={"user_id": kwargs["user_id"]}))
+            self.entities.append(
+                SimpleEntity(
+                    type=entity_type,
+                    offset=0,
+                    length=len(self.text),
+                    extra_info={"user_id": kwargs["user_id"]},
+                )
+            )
             return self
             return self
         elif entity_type == EntityType.BOLD:
         elif entity_type == EntityType.BOLD:
             prefix = suffix = "**"
             prefix = suffix = "**"
@@ -80,7 +91,7 @@ class SignalFormatString(EntityString[SimpleEntity, EntityType], MarkdownString)
         elif entity_type == EntityType.STRIKETHROUGH:
         elif entity_type == EntityType.STRIKETHROUGH:
             prefix = suffix = "~~"
             prefix = suffix = "~~"
         elif entity_type == EntityType.URL:
         elif entity_type == EntityType.URL:
-            if kwargs['url'] != self.text:
+            if kwargs["url"] != self.text:
                 suffix = f" ({kwargs['url']})"
                 suffix = f" ({kwargs['url']})"
         elif entity_type == EntityType.PREFORMATTED:
         elif entity_type == EntityType.PREFORMATTED:
             prefix = f"```{kwargs['language']}\n"
             prefix = f"```{kwargs['language']}\n"

+ 3 - 4
mautrix_signal/get_version.py

@@ -1,6 +1,6 @@
-import subprocess
-import shutil
 import os
 import os
+import shutil
+import subprocess
 
 
 from . import __version__
 from . import __version__
 
 
@@ -34,8 +34,7 @@ else:
     git_revision_url = None
     git_revision_url = None
     git_tag = None
     git_tag = None
 
 
-git_tag_url = (f"https://github.com/mautrix/signal/releases/tag/{git_tag}"
-               if git_tag else None)
+git_tag_url = f"https://github.com/mautrix/signal/releases/tag/{git_tag}" if git_tag else None
 
 
 if git_tag and __version__ == git_tag[1:].replace("-", ""):
 if git_tag and __version__ == git_tag[1:].replace("-", ""):
     version = __version__
     version = __version__

+ 54 - 26
mautrix_signal/matrix.py

@@ -13,26 +13,41 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import List, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Union
 
 
 from mautrix.bridge import BaseMatrixHandler
 from mautrix.bridge import BaseMatrixHandler
-from mautrix.types import (Event, ReactionEvent, StateEvent, RoomID, EventID, UserID, TypingEvent,
-                           ReactionEventContent, RelationType, EventType, ReceiptEvent,
-                           PresenceEvent, RedactionEvent, SingleReceiptEventContent)
+from mautrix.types import (
+    Event,
+    EventID,
+    EventType,
+    PresenceEvent,
+    ReactionEvent,
+    ReactionEventContent,
+    ReceiptEvent,
+    RedactionEvent,
+    RelationType,
+    RoomID,
+    SingleReceiptEventContent,
+    StateEvent,
+    TypingEvent,
+    UserID,
+)
 
 
 from mautrix_signal.db.disappearing_message import DisappearingMessage
 from mautrix_signal.db.disappearing_message import DisappearingMessage
 
 
+from . import portal as po
+from . import signal as s
+from . import user as u
 from .db import Message as DBMessage
 from .db import Message as DBMessage
-from . import portal as po, user as u, signal as s
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .__main__ import SignalBridge
     from .__main__ import SignalBridge
 
 
 
 
 class MatrixHandler(BaseMatrixHandler):
 class MatrixHandler(BaseMatrixHandler):
-    signal: 's.SignalHandler'
+    signal: "s.SignalHandler"
 
 
-    def __init__(self, bridge: 'SignalBridge') -> None:
+    def __init__(self, bridge: "SignalBridge") -> None:
         prefix, suffix = bridge.config["bridge.username_template"].format(userid=":").split(":")
         prefix, suffix = bridge.config["bridge.username_template"].format(userid=":").split(":")
         homeserver = bridge.config["homeserver.domain"]
         homeserver = bridge.config["homeserver.domain"]
         self.user_id_prefix = f"@{prefix}"
         self.user_id_prefix = f"@{prefix}"
@@ -41,13 +56,14 @@ class MatrixHandler(BaseMatrixHandler):
 
 
         super().__init__(bridge=bridge)
         super().__init__(bridge=bridge)
 
 
-    async def send_welcome_message(self, room_id: RoomID, inviter: 'u.User') -> None:
+    async def send_welcome_message(self, room_id: RoomID, inviter: "u.User") -> None:
         await super().send_welcome_message(room_id, inviter)
         await super().send_welcome_message(room_id, inviter)
         if not inviter.notice_room:
         if not inviter.notice_room:
             inviter.notice_room = room_id
             inviter.notice_room = room_id
             await inviter.update()
             await inviter.update()
-            await self.az.intent.send_notice(room_id, "This room has been marked as your "
-                                                      "Signal bridge notice room.")
+            await self.az.intent.send_notice(
+                room_id, "This room has been marked as your Signal bridge notice room."
+            )
 
 
     async def handle_leave(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None:
     async def handle_leave(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None:
         portal = await po.Portal.get_by_mxid(room_id)
         portal = await po.Portal.get_by_mxid(room_id)
@@ -72,11 +88,14 @@ class MatrixHandler(BaseMatrixHandler):
         await portal.handle_matrix_join(user)
         await portal.handle_matrix_join(user)
 
 
     @classmethod
     @classmethod
-    async def handle_reaction(cls, room_id: RoomID, user_id: UserID, event_id: EventID,
-                              content: ReactionEventContent) -> None:
+    async def handle_reaction(
+        cls, room_id: RoomID, user_id: UserID, event_id: EventID, content: ReactionEventContent
+    ) -> None:
         if content.relates_to.rel_type != RelationType.ANNOTATION:
         if content.relates_to.rel_type != RelationType.ANNOTATION:
-            cls.log.debug(f"Ignoring m.reaction event in {room_id} from {user_id} with unexpected "
-                          f"relation type {content.relates_to.rel_type}")
+            cls.log.debug(
+                f"Ignoring m.reaction event in {room_id} from {user_id} with unexpected "
+                f"relation type {content.relates_to.rel_type}"
+            )
             return
             return
         user = await u.User.get_by_mxid(user_id)
         user = await u.User.get_by_mxid(user_id)
         if not user:
         if not user:
@@ -86,12 +105,14 @@ class MatrixHandler(BaseMatrixHandler):
         if not portal:
         if not portal:
             return
             return
 
 
-        await portal.handle_matrix_reaction(user, event_id, content.relates_to.event_id,
-                                            content.relates_to.key)
+        await portal.handle_matrix_reaction(
+            user, event_id, content.relates_to.event_id, content.relates_to.key
+        )
 
 
     @staticmethod
     @staticmethod
-    async def handle_redaction(room_id: RoomID, user_id: UserID, event_id: EventID,
-                               redaction_event_id: EventID) -> None:
+    async def handle_redaction(
+        room_id: RoomID, user_id: UserID, event_id: EventID, redaction_event_id: EventID
+    ) -> None:
         user = await u.User.get_by_mxid(user_id)
         user = await u.User.get_by_mxid(user_id)
         if not user:
         if not user:
             return
             return
@@ -102,8 +123,13 @@ class MatrixHandler(BaseMatrixHandler):
 
 
         await portal.handle_matrix_redaction(user, event_id, redaction_event_id)
         await portal.handle_matrix_redaction(user, event_id, redaction_event_id)
 
 
-    async def handle_read_receipt(self, user: 'u.User', portal: 'po.Portal', event_id: EventID,
-                                  data: SingleReceiptEventContent) -> None:
+    async def handle_read_receipt(
+        self,
+        user: "u.User",
+        portal: "po.Portal",
+        event_id: EventID,
+        data: SingleReceiptEventContent,
+    ) -> None:
         await portal.handle_read_receipt(event_id, data)
         await portal.handle_read_receipt(event_id, data)
 
 
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
@@ -111,8 +137,9 @@ class MatrixHandler(BaseMatrixHandler):
             return
             return
 
 
         user.log.trace(f"Sending read receipt for {message.timestamp} to {message.sender}")
         user.log.trace(f"Sending read receipt for {message.timestamp} to {message.sender}")
-        await self.signal.send_receipt(user.username, message.sender,
-                                       timestamps=[message.timestamp], when=data.ts, read=True)
+        await self.signal.send_receipt(
+            user.username, message.sender, timestamps=[message.timestamp], when=data.ts, read=True
+        )
 
 
     async def handle_typing(self, room_id: RoomID, typing: List[UserID]) -> None:
     async def handle_typing(self, room_id: RoomID, typing: List[UserID]) -> None:
         pass
         pass
@@ -134,8 +161,9 @@ class MatrixHandler(BaseMatrixHandler):
             evt: RedactionEvent
             evt: RedactionEvent
             await self.handle_redaction(evt.room_id, evt.sender, evt.redacts, evt.event_id)
             await self.handle_redaction(evt.room_id, evt.sender, evt.redacts, evt.event_id)
 
 
-    async def handle_ephemeral_event(self, evt: Union[ReceiptEvent, PresenceEvent, TypingEvent]
-                                     ) -> None:
+    async def handle_ephemeral_event(
+        self, evt: Union[ReceiptEvent, PresenceEvent, TypingEvent]
+    ) -> None:
         if evt.type == EventType.TYPING:
         if evt.type == EventType.TYPING:
             await self.handle_typing(evt.room_id, evt.content.user_ids)
             await self.handle_typing(evt.room_id, evt.content.user_ids)
         else:
         else:
@@ -157,8 +185,8 @@ class MatrixHandler(BaseMatrixHandler):
         elif evt.type == EventType.ROOM_AVATAR:
         elif evt.type == EventType.ROOM_AVATAR:
             await portal.handle_matrix_avatar(user, evt.content.url)
             await portal.handle_matrix_avatar(user, evt.content.url)
 
 
-    async def allow_message(self, user: 'u.User') -> bool:
+    async def allow_message(self, user: "u.User") -> bool:
         return user.relay_whitelisted
         return user.relay_whitelisted
 
 
-    async def allow_bridging_message(self, user: 'u.User', portal: 'po.Portal') -> bool:
+    async def allow_bridging_message(self, user: "u.User", portal: "po.Portal") -> bool:
         return portal.has_relay or await user.is_logged_in()
         return portal.has_relay or await user.is_logged_in()

Разница между файлами не показана из-за своего большого размера
+ 371 - 200
mautrix_signal/portal.py


+ 96 - 45
mautrix_signal/puppet.py

@@ -13,26 +13,42 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import (Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, Union, Tuple,
-                    TYPE_CHECKING, cast)
+from typing import (
+    TYPE_CHECKING,
+    AsyncGenerator,
+    AsyncIterable,
+    Awaitable,
+    Dict,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 from uuid import UUID
 from uuid import UUID
-import hashlib
 import asyncio
 import asyncio
+import hashlib
 import os.path
 import os.path
 
 
-from yarl import URL
-
-from mausignald.types import Address, Contact, Profile
-from mautrix.bridge import BasePuppet, async_getter_lock
 from mautrix.appservice import IntentAPI
 from mautrix.appservice import IntentAPI
-from mautrix.types import (UserID, SyncToken, RoomID, ContentURI, EventType,
-                           PowerLevelStateEventContent)
+from mautrix.bridge import BasePuppet, async_getter_lock
 from mautrix.errors import MForbidden
 from mautrix.errors import MForbidden
+from mautrix.types import (
+    ContentURI,
+    EventType,
+    PowerLevelStateEventContent,
+    RoomID,
+    SyncToken,
+    UserID,
+)
 from mautrix.util.simple_template import SimpleTemplate
 from mautrix.util.simple_template import SimpleTemplate
+from yarl import URL
 
 
-from .db import Puppet as DBPuppet
+from mausignald.types import Address, Contact, Profile
+
+from . import portal as p
+from . import user as u
 from .config import Config
 from .config import Config
-from . import portal as p, user as u
+from .db import Puppet as DBPuppet
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .__main__ import SignalBridge
     from .__main__ import SignalBridge
@@ -44,9 +60,9 @@ except ImportError:
 
 
 
 
 class Puppet(DBPuppet, BasePuppet):
 class Puppet(DBPuppet, BasePuppet):
-    by_uuid: Dict[UUID, 'Puppet'] = {}
-    by_number: Dict[str, 'Puppet'] = {}
-    by_custom_mxid: Dict[UserID, 'Puppet'] = {}
+    by_uuid: Dict[UUID, "Puppet"] = {}
+    by_number: Dict[str, "Puppet"] = {}
+    by_custom_mxid: Dict[UserID, "Puppet"] = {}
     hs_domain: str
     hs_domain: str
     mxid_template: SimpleTemplate[str]
     mxid_template: SimpleTemplate[str]
 
 
@@ -58,17 +74,37 @@ class Puppet(DBPuppet, BasePuppet):
     _uuid_lock: asyncio.Lock
     _uuid_lock: asyncio.Lock
     _update_info_lock: asyncio.Lock
     _update_info_lock: asyncio.Lock
 
 
-    def __init__(self, uuid: Optional[UUID], number: Optional[str], name: Optional[str] = None,
-                 avatar_url: Optional[ContentURI] = None, avatar_hash: Optional[str] = None,
-                 name_set: bool = False, avatar_set: bool = False, uuid_registered: bool = False,
-                 number_registered: bool = False, custom_mxid: Optional[UserID] = None,
-                 access_token: Optional[str] = None, next_batch: Optional[SyncToken] = None,
-                 base_url: Optional[URL] = None) -> None:
-        super().__init__(uuid=uuid, number=number, name=name, avatar_url=avatar_url,
-                         avatar_hash=avatar_hash, name_set=name_set, avatar_set=avatar_set,
-                         uuid_registered=uuid_registered, number_registered=number_registered,
-                         custom_mxid=custom_mxid, access_token=access_token, next_batch=next_batch,
-                         base_url=base_url)
+    def __init__(
+        self,
+        uuid: Optional[UUID],
+        number: Optional[str],
+        name: Optional[str] = None,
+        avatar_url: Optional[ContentURI] = None,
+        avatar_hash: Optional[str] = None,
+        name_set: bool = False,
+        avatar_set: bool = False,
+        uuid_registered: bool = False,
+        number_registered: bool = False,
+        custom_mxid: Optional[UserID] = None,
+        access_token: Optional[str] = None,
+        next_batch: Optional[SyncToken] = None,
+        base_url: Optional[URL] = None,
+    ) -> None:
+        super().__init__(
+            uuid=uuid,
+            number=number,
+            name=name,
+            avatar_url=avatar_url,
+            avatar_hash=avatar_hash,
+            name_set=name_set,
+            avatar_set=avatar_set,
+            uuid_registered=uuid_registered,
+            number_registered=number_registered,
+            custom_mxid=custom_mxid,
+            access_token=access_token,
+            next_batch=next_batch,
+            base_url=base_url,
+        )
         self.log = self.log.getChild(str(uuid) if uuid else number)
         self.log = self.log.getChild(str(uuid) if uuid else number)
 
 
         self.default_mxid = self.get_mxid_from_id(self.address)
         self.default_mxid = self.get_mxid_from_id(self.address)
@@ -79,25 +115,34 @@ class Puppet(DBPuppet, BasePuppet):
         self._update_info_lock = asyncio.Lock()
         self._update_info_lock = asyncio.Lock()
 
 
     @classmethod
     @classmethod
-    def init_cls(cls, bridge: 'SignalBridge') -> AsyncIterable[Awaitable[None]]:
+    def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
         cls.config = bridge.config
         cls.config = bridge.config
         cls.loop = bridge.loop
         cls.loop = bridge.loop
         cls.mx = bridge.matrix
         cls.mx = bridge.matrix
         cls.az = bridge.az
         cls.az = bridge.az
         cls.hs_domain = cls.config["homeserver.domain"]
         cls.hs_domain = cls.config["homeserver.domain"]
-        cls.mxid_template = SimpleTemplate(cls.config["bridge.username_template"], "userid",
-                                           prefix="@", suffix=f":{cls.hs_domain}", type=str)
+        cls.mxid_template = SimpleTemplate(
+            cls.config["bridge.username_template"],
+            "userid",
+            prefix="@",
+            suffix=f":{cls.hs_domain}",
+            type=str,
+        )
         cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
         cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
 
 
-        cls.homeserver_url_map = {server: URL(url) for server, url
-                                  in cls.config["bridge.double_puppet_server_map"].items()}
+        cls.homeserver_url_map = {
+            server: URL(url)
+            for server, url in cls.config["bridge.double_puppet_server_map"].items()
+        }
         cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
         cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
-        cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
-                                       in cls.config["bridge.login_shared_secret_map"].items()}
+        cls.login_shared_secret_map = {
+            server: secret.encode("utf-8")
+            for server, secret in cls.config["bridge.login_shared_secret_map"].items()
+        }
         cls.login_device_name = "Signal Bridge"
         cls.login_device_name = "Signal Bridge"
         return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
         return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
 
 
-    def intent_for(self, portal: 'p.Portal') -> IntentAPI:
+    def intent_for(self, portal: "p.Portal") -> IntentAPI:
         if portal.chat_id == self.address:
         if portal.chat_id == self.address:
             return self.default_mxid_intent
             return self.default_mxid_intent
         return self.intent
         return self.intent
@@ -167,8 +212,10 @@ class Puppet(DBPuppet, BasePuppet):
         try:
         try:
             joined_rooms = await prev_intent.get_joined_rooms()
             joined_rooms = await prev_intent.get_joined_rooms()
         except MForbidden as e:
         except MForbidden as e:
-            self.log.debug(f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
-                           "assuming there are no rooms to rejoin")
+            self.log.debug(
+                f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
+                "assuming there are no rooms to rejoin"
+            )
             return
             return
         for room_id in joined_rooms:
         for room_id in joined_rooms:
             await prev_intent.invite_user(room_id, self.default_mxid)
             await prev_intent.invite_user(room_id, self.default_mxid)
@@ -176,8 +223,9 @@ class Puppet(DBPuppet, BasePuppet):
             await prev_intent.leave_room(room_id)
             await prev_intent.leave_room(room_id)
             await new_intent.join_room_by_id(room_id)
             await new_intent.join_room_by_id(room_id)
 
 
-    async def _migrate_powers(self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
-                              ) -> None:
+    async def _migrate_powers(
+        self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
+    ) -> None:
         try:
         try:
             powers: PowerLevelStateEventContent
             powers: PowerLevelStateEventContent
             powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
             powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
@@ -260,8 +308,11 @@ class Puppet(DBPuppet, BasePuppet):
         return False
         return False
 
 
     @staticmethod
     @staticmethod
-    async def upload_avatar(self: Union['Puppet', 'p.Portal'], path: str, intent: IntentAPI,
-                            ) -> Union[bool, Tuple[str, ContentURI]]:
+    async def upload_avatar(
+        self: Union["Puppet", "p.Portal"],
+        path: str,
+        intent: IntentAPI,
+    ) -> Union[bool, Tuple[str, ContentURI]]:
         if not path:
         if not path:
             return False
             return False
         if not path.startswith("/"):
         if not path.startswith("/"):
@@ -321,7 +372,7 @@ class Puppet(DBPuppet, BasePuppet):
         await self.update()
         await self.update()
 
 
     @classmethod
     @classmethod
-    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
+    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional["Puppet"]:
         address = cls.get_id_from_mxid(mxid)
         address = cls.get_id_from_mxid(mxid)
         if not address:
         if not address:
             return None
             return None
@@ -329,7 +380,7 @@ class Puppet(DBPuppet, BasePuppet):
 
 
     @classmethod
     @classmethod
     @async_getter_lock
     @async_getter_lock
-    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
+    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional["Puppet"]:
         try:
         try:
             return cls.by_custom_mxid[mxid]
             return cls.by_custom_mxid[mxid]
         except KeyError:
         except KeyError:
@@ -348,7 +399,7 @@ class Puppet(DBPuppet, BasePuppet):
         if not identifier:
         if not identifier:
             return None
             return None
         if identifier.startswith("phone_"):
         if identifier.startswith("phone_"):
-            return Address(number="+" + identifier[len("phone_"):])
+            return Address(number="+" + identifier[len("phone_") :])
         else:
         else:
             try:
             try:
                 return Address(uuid=UUID(identifier.upper()))
                 return Address(uuid=UUID(identifier.upper()))
@@ -367,7 +418,7 @@ class Puppet(DBPuppet, BasePuppet):
 
 
     @classmethod
     @classmethod
     @async_getter_lock
     @async_getter_lock
-    async def get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
+    async def get_by_address(cls, address: Address, create: bool = True) -> Optional["Puppet"]:
         puppet = await cls._get_by_address(address, create)
         puppet = await cls._get_by_address(address, create)
         if puppet and address.uuid and not puppet.uuid:
         if puppet and address.uuid and not puppet.uuid:
             # We found a UUID for this user, store it ASAP
             # We found a UUID for this user, store it ASAP
@@ -375,7 +426,7 @@ class Puppet(DBPuppet, BasePuppet):
         return puppet
         return puppet
 
 
     @classmethod
     @classmethod
-    async def _get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
+    async def _get_by_address(cls, address: Address, create: bool = True) -> Optional["Puppet"]:
         if not address.is_valid:
         if not address.is_valid:
             raise ValueError("Empty address")
             raise ValueError("Empty address")
         if address.uuid:
         if address.uuid:
@@ -403,7 +454,7 @@ class Puppet(DBPuppet, BasePuppet):
         return None
         return None
 
 
     @classmethod
     @classmethod
-    async def all_with_custom_mxid(cls) -> AsyncGenerator['Puppet', None]:
+    async def all_with_custom_mxid(cls) -> AsyncGenerator["Puppet", None]:
         puppets = await super().all_with_custom_mxid()
         puppets = await super().all_with_custom_mxid()
         puppet: cls
         puppet: cls
         for index, puppet in enumerate(puppets):
         for index, puppet in enumerate(puppets):

+ 58 - 28
mautrix_signal/signal.py

@@ -13,18 +13,29 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Optional, List, TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional
 import asyncio
 import asyncio
 import logging
 import logging
 
 
-from mausignald import SignaldClient
-from mausignald.types import (Message, MessageData, Address, TypingNotification, TypingAction,
-                              OwnReadReceipt, Receipt, ReceiptType,
-                              WebsocketConnectionStateChangeEvent)
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
+from mausignald import SignaldClient
+from mausignald.types import (
+    Address,
+    Message,
+    MessageData,
+    OwnReadReceipt,
+    Receipt,
+    ReceiptType,
+    TypingAction,
+    TypingNotification,
+    WebsocketConnectionStateChangeEvent,
+)
+
+from . import portal as po
+from . import puppet as pu
+from . import user as u
 from .db import Message as DBMessage
 from .db import Message as DBMessage
-from . import user as u, portal as po, puppet as pu
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .__main__ import SignalBridge
     from .__main__ import SignalBridge
@@ -39,13 +50,14 @@ class SignalHandler(SignaldClient):
     data_dir: str
     data_dir: str
     delete_unknown_accounts: bool
     delete_unknown_accounts: bool
 
 
-    def __init__(self, bridge: 'SignalBridge') -> None:
+    def __init__(self, bridge: "SignalBridge") -> None:
         super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
         super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
         self.data_dir = bridge.config["signal.data_dir"]
         self.data_dir = bridge.config["signal.data_dir"]
         self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
         self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
         self.add_event_handler(Message, self.on_message)
         self.add_event_handler(Message, self.on_message)
-        self.add_event_handler(WebsocketConnectionStateChangeEvent,
-                               self.on_websocket_connection_state_change)
+        self.add_event_handler(
+            WebsocketConnectionStateChangeEvent, self.on_websocket_connection_state_change
+        )
 
 
     async def on_message(self, evt: Message) -> None:
     async def on_message(self, evt: Message) -> None:
         sender = await pu.Puppet.get_by_address(evt.source)
         sender = await pu.Puppet.get_by_address(evt.source)
@@ -62,8 +74,12 @@ class SignalHandler(SignaldClient):
             if evt.sync_message.read_messages:
             if evt.sync_message.read_messages:
                 await self.handle_own_receipts(sender, evt.sync_message.read_messages)
                 await self.handle_own_receipts(sender, evt.sync_message.read_messages)
             if evt.sync_message.sent:
             if evt.sync_message.sent:
-                await self.handle_message(user, sender, evt.sync_message.sent.message,
-                                          addr_override=evt.sync_message.sent.destination)
+                await self.handle_message(
+                    user,
+                    sender,
+                    evt.sync_message.sent.message,
+                    addr_override=evt.sync_message.sent.destination,
+                )
             if evt.sync_message.typing:
             if evt.sync_message.typing:
                 # Typing notification from own device
                 # Typing notification from own device
                 pass
                 pass
@@ -75,12 +91,19 @@ class SignalHandler(SignaldClient):
                 await user.sync_groups()
                 await user.sync_groups()
 
 
     @staticmethod
     @staticmethod
-    async def on_websocket_connection_state_change(evt: WebsocketConnectionStateChangeEvent) -> None:
+    async def on_websocket_connection_state_change(
+        evt: WebsocketConnectionStateChangeEvent,
+    ) -> None:
         user = await u.User.get_by_username(evt.account)
         user = await u.User.get_by_username(evt.account)
         user.on_websocket_connection_state_change(evt)
         user.on_websocket_connection_state_change(evt)
 
 
-    async def handle_message(self, user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
-                             addr_override: Optional[Address] = None) -> None:
+    async def handle_message(
+        self,
+        user: "u.User",
+        sender: "pu.Puppet",
+        msg: MessageData,
+        addr_override: Optional[Address] = None,
+    ) -> None:
         if msg.profile_key_update:
         if msg.profile_key_update:
             self.log.debug("Ignoring profile key update")
             self.log.debug("Ignoring profile key update")
             return
             return
@@ -89,18 +112,23 @@ class SignalHandler(SignaldClient):
         elif msg.group:
         elif msg.group:
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
         else:
         else:
-            portal = await po.Portal.get_by_chat_id(addr_override or sender.address,
-                                                    receiver=user.username, create=True)
+            portal = await po.Portal.get_by_chat_id(
+                addr_override or sender.address, receiver=user.username, create=True
+            )
             if addr_override and not sender.is_real_user:
             if addr_override and not sender.is_real_user:
-                portal.log.debug(f"Ignoring own message {msg.timestamp} as user doesn't have"
-                                 " double puppeting enabled")
+                portal.log.debug(
+                    f"Ignoring own message {msg.timestamp} as user doesn't have double puppeting "
+                    "enabled"
+                )
                 return
                 return
         if not portal.mxid:
         if not portal.mxid:
-            await portal.create_matrix_room(user, (msg.group_v2 or msg.group
-                                                   or addr_override or sender.address))
+            await portal.create_matrix_room(
+                user, (msg.group_v2 or msg.group or addr_override or sender.address)
+            )
             if not portal.mxid:
             if not portal.mxid:
-                user.log.debug(f"Failed to create room for incoming message {msg.timestamp},"
-                               " dropping message")
+                user.log.debug(
+                    f"Failed to create room for incoming message {msg.timestamp}, dropping message"
+                )
                 return
                 return
         elif msg.group_v2 and msg.group_v2.revision > portal.revision:
         elif msg.group_v2 and msg.group_v2.revision > portal.revision:
             self.log.debug(f"Got new revision of {msg.group_v2.id}, updating info")
             self.log.debug(f"Got new revision of {msg.group_v2.id}, updating info")
@@ -117,7 +145,7 @@ class SignalHandler(SignaldClient):
             await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
             await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
 
 
     @staticmethod
     @staticmethod
-    async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None:
+    async def handle_own_receipts(sender: "pu.Puppet", receipts: List[OwnReadReceipt]) -> None:
         for receipt in receipts:
         for receipt in receipts:
             puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
             puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
             if not puppet:
             if not puppet:
@@ -131,8 +159,9 @@ class SignalHandler(SignaldClient):
             await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
             await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
 
 
     @staticmethod
     @staticmethod
-    async def handle_typing(user: 'u.User', sender: 'pu.Puppet',
-                            typing: TypingNotification) -> None:
+    async def handle_typing(
+        user: "u.User", sender: "pu.Puppet", typing: TypingNotification
+    ) -> None:
         if typing.group_id:
         if typing.group_id:
             portal = await po.Portal.get_by_chat_id(typing.group_id)
             portal = await po.Portal.get_by_chat_id(typing.group_id)
         else:
         else:
@@ -140,11 +169,12 @@ class SignalHandler(SignaldClient):
         if not portal or not portal.mxid:
         if not portal or not portal.mxid:
             return
             return
         is_typing = typing.action == TypingAction.STARTED
         is_typing = typing.action == TypingAction.STARTED
-        await sender.intent_for(portal).set_typing(portal.mxid, is_typing, ignore_cache=True,
-                                                   timeout=SIGNAL_TYPING_TIMEOUT)
+        await sender.intent_for(portal).set_typing(
+            portal.mxid, is_typing, ignore_cache=True, timeout=SIGNAL_TYPING_TIMEOUT
+        )
 
 
     @staticmethod
     @staticmethod
-    async def handle_receipt(sender: 'pu.Puppet', receipt: Receipt) -> None:
+    async def handle_receipt(sender: "pu.Puppet", receipt: Receipt) -> None:
         if receipt.type != ReceiptType.READ:
         if receipt.type != ReceiptType.READ:
             return
             return
         messages = await DBMessage.find_by_timestamps(receipt.timestamps)
         messages = await DBMessage.find_by_timestamps(receipt.timestamps)

+ 60 - 38
mautrix_signal/user.py

@@ -13,44 +13,55 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Union, cast
 from asyncio.tasks import sleep
 from asyncio.tasks import sleep
 from datetime import datetime
 from datetime import datetime
-from typing import Union, Dict, Optional, AsyncGenerator, List, TYPE_CHECKING, cast
 from uuid import UUID
 from uuid import UUID
 import asyncio
 import asyncio
 
 
-from mausignald.types import (Account, Address, Profile, Group, GroupV2, WebsocketConnectionState,
-                              WebsocketConnectionStateChangeEvent)
-from mautrix.bridge import BaseUser, AutologinError, async_getter_lock
-from mautrix.types import UserID, RoomID
-from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
 from mautrix.appservice import AppService
 from mautrix.appservice import AppService
+from mautrix.bridge import AutologinError, BaseUser, async_getter_lock
+from mautrix.types import RoomID, UserID
+from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
 from mautrix.util.opt_prometheus import Gauge
 from mautrix.util.opt_prometheus import Gauge
 
 
-from .db import User as DBUser
+from mausignald.types import (
+    Account,
+    Address,
+    Group,
+    GroupV2,
+    Profile,
+    WebsocketConnectionState,
+    WebsocketConnectionStateChangeEvent,
+)
+
+from . import portal as po
+from . import puppet as pu
 from .config import Config
 from .config import Config
-from . import puppet as pu, portal as po
+from .db import User as DBUser
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .__main__ import SignalBridge
     from .__main__ import SignalBridge
 
 
-METRIC_CONNECTED = Gauge('bridge_connected', 'Bridge users connected to Signal')
-METRIC_LOGGED_IN = Gauge('bridge_logged_in', 'Bridge users logged into Signal')
+METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to Signal")
+METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Bridge users logged into Signal")
 
 
-BridgeState.human_readable_errors.update({
-    "logged-out": "You're not logged into Signal",
-    "signal-not-connected": None,
-})
+BridgeState.human_readable_errors.update(
+    {
+        "logged-out": "You're not logged into Signal",
+        "signal-not-connected": None,
+    }
+)
 
 
 
 
 class User(DBUser, BaseUser):
 class User(DBUser, BaseUser):
-    by_mxid: Dict[UserID, 'User'] = {}
-    by_username: Dict[str, 'User'] = {}
-    by_uuid: Dict[UUID, 'User'] = {}
+    by_mxid: Dict[UserID, "User"] = {}
+    by_username: Dict[str, "User"] = {}
+    by_uuid: Dict[UUID, "User"] = {}
     config: Config
     config: Config
     az: AppService
     az: AppService
     loop: asyncio.AbstractEventLoop
     loop: asyncio.AbstractEventLoop
-    bridge: 'SignalBridge'
+    bridge: "SignalBridge"
 
 
     relay_whitelisted: bool
     relay_whitelisted: bool
     is_admin: bool
     is_admin: bool
@@ -62,8 +73,13 @@ class User(DBUser, BaseUser):
     _websocket_connection_state: Optional[BridgeStateEvent]
     _websocket_connection_state: Optional[BridgeStateEvent]
     _latest_non_transient_disconnect_state: Optional[datetime]
     _latest_non_transient_disconnect_state: Optional[datetime]
 
 
-    def __init__(self, mxid: UserID, username: Optional[str] = None, uuid: Optional[UUID] = None,
-                 notice_room: Optional[RoomID] = None) -> None:
+    def __init__(
+        self,
+        mxid: UserID,
+        username: Optional[str] = None,
+        uuid: Optional[UUID] = None,
+        notice_room: Optional[RoomID] = None,
+    ) -> None:
         super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
         super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
         BaseUser.__init__(self)
         BaseUser.__init__(self)
         self._notice_room_lock = asyncio.Lock()
         self._notice_room_lock = asyncio.Lock()
@@ -74,7 +90,7 @@ class User(DBUser, BaseUser):
         self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
         self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
 
 
     @classmethod
     @classmethod
-    def init_cls(cls, bridge: 'SignalBridge') -> None:
+    def init_cls(cls, bridge: "SignalBridge") -> None:
         cls.bridge = bridge
         cls.bridge = bridge
         cls.config = bridge.config
         cls.config = bridge.config
         cls.az = bridge.az
         cls.az = bridge.az
@@ -89,9 +105,10 @@ class User(DBUser, BaseUser):
     async def is_logged_in(self) -> bool:
     async def is_logged_in(self) -> bool:
         return bool(self.username)
         return bool(self.username)
 
 
-    async def needs_relay(self, portal: 'po.Portal') -> bool:
-        return not await self.is_logged_in() or (portal.is_direct
-                                                 and portal.receiver != self.username)
+    async def needs_relay(self, portal: "po.Portal") -> bool:
+        return not await self.is_logged_in() or (
+            portal.is_direct and portal.receiver != self.username
+        )
 
 
     async def logout(self) -> None:
     async def logout(self) -> None:
         if not self.username:
         if not self.username:
@@ -129,7 +146,7 @@ class User(DBUser, BaseUser):
             state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
             state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
         return [state]
         return [state]
 
 
-    async def get_puppet(self) -> Optional['pu.Puppet']:
+    async def get_puppet(self) -> Optional["pu.Puppet"]:
         if not self.address:
         if not self.address:
             return None
             return None
         return await pu.Puppet.get_by_address(self.address)
         return await pu.Puppet.get_by_address(self.address)
@@ -143,7 +160,9 @@ class User(DBUser, BaseUser):
         asyncio.create_task(self.sync())
         asyncio.create_task(self.sync())
         self._track_metric(METRIC_LOGGED_IN, True)
         self._track_metric(METRIC_LOGGED_IN, True)
 
 
-    def on_websocket_connection_state_change(self, evt: WebsocketConnectionStateChangeEvent) -> None:
+    def on_websocket_connection_state_change(
+        self, evt: WebsocketConnectionStateChangeEvent
+    ) -> None:
         if evt.state == WebsocketConnectionState.CONNECTED:
         if evt.state == WebsocketConnectionState.CONNECTED:
             self.log.info("Connected to Signal")
             self.log.info("Connected to Signal")
             self._track_metric(METRIC_CONNECTED, True)
             self._track_metric(METRIC_CONNECTED, True)
@@ -151,14 +170,14 @@ class User(DBUser, BaseUser):
             self._connected = True
             self._connected = True
         else:
         else:
             self.log.warning(
             self.log.warning(
-                f"New websocket state from signald: {evt.state}. Error: {evt.exception}")
+                f"New websocket state from signald: {evt.state}. Error: {evt.exception}"
+            )
             self._track_metric(METRIC_CONNECTED, False)
             self._track_metric(METRIC_CONNECTED, False)
             self._connected = False
             self._connected = False
 
 
         bridge_state = {
         bridge_state = {
             # Signald disconnected
             # Signald disconnected
             WebsocketConnectionState.SOCKET_DISCONNECTED: BridgeStateEvent.TRANSIENT_DISCONNECT,
             WebsocketConnectionState.SOCKET_DISCONNECTED: BridgeStateEvent.TRANSIENT_DISCONNECT,
-
             # Websocket state reported by signald
             # Websocket state reported by signald
             WebsocketConnectionState.DISCONNECTED: (
             WebsocketConnectionState.DISCONNECTED: (
                 None
                 None
@@ -178,6 +197,7 @@ class User(DBUser, BaseUser):
 
 
         now = datetime.now()
         now = datetime.now()
         if bridge_state == BridgeStateEvent.TRANSIENT_DISCONNECT:
         if bridge_state == BridgeStateEvent.TRANSIENT_DISCONNECT:
+
             async def wait_report_transient_disconnect():
             async def wait_report_transient_disconnect():
                 # Wait for 10 seconds (that should be enough for the bridge to get connected)
                 # Wait for 10 seconds (that should be enough for the bridge to get connected)
                 # before sending a TRANSIENT_DISCONNECT.
                 # before sending a TRANSIENT_DISCONNECT.
@@ -216,7 +236,7 @@ class User(DBUser, BaseUser):
             self.uuid = puppet.uuid
             self.uuid = puppet.uuid
             self.by_uuid[self.uuid] = self
             self.by_uuid[self.uuid] = self
         if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
         if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
-            self.log.info(f"Automatically enabling custom puppet")
+            self.log.info("Automatically enabling custom puppet")
             try:
             try:
                 await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
                 await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
             except AutologinError as e:
             except AutologinError as e:
@@ -248,8 +268,9 @@ class User(DBUser, BaseUser):
         except Exception:
         except Exception:
             self.log.exception("Error while syncing groups")
             self.log.exception("Error while syncing groups")
 
 
-    async def sync_contact(self, contact: Union[Profile, Address], create_portals: bool = False
-                           ) -> None:
+    async def sync_contact(
+        self, contact: Union[Profile, Address], create_portals: bool = False
+    ) -> None:
         self.log.trace("Syncing contact %s", contact)
         self.log.trace("Syncing contact %s", contact)
         if isinstance(contact, Address):
         if isinstance(contact, Address):
             address = contact
             address = contact
@@ -262,8 +283,9 @@ class User(DBUser, BaseUser):
         puppet = await pu.Puppet.get_by_address(address)
         puppet = await pu.Puppet.get_by_address(address)
         await puppet.update_info(profile)
         await puppet.update_info(profile)
         if create_portals:
         if create_portals:
-            portal = await po.Portal.get_by_chat_id(puppet.address, receiver=self.username,
-                                                    create=True)
+            portal = await po.Portal.get_by_chat_id(
+                puppet.address, receiver=self.username, create=True
+            )
             await portal.create_matrix_room(self, profile)
             await portal.create_matrix_room(self, profile)
 
 
     async def _sync_group(self, group: Group, create_portals: bool) -> None:
     async def _sync_group(self, group: Group, create_portals: bool) -> None:
@@ -315,7 +337,7 @@ class User(DBUser, BaseUser):
 
 
     @classmethod
     @classmethod
     @async_getter_lock
     @async_getter_lock
-    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
+    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional["User"]:
         # Never allow ghosts to be users
         # Never allow ghosts to be users
         if pu.Puppet.get_id_from_mxid(mxid):
         if pu.Puppet.get_id_from_mxid(mxid):
             return None
             return None
@@ -339,7 +361,7 @@ class User(DBUser, BaseUser):
 
 
     @classmethod
     @classmethod
     @async_getter_lock
     @async_getter_lock
-    async def get_by_username(cls, username: str) -> Optional['User']:
+    async def get_by_username(cls, username: str) -> Optional["User"]:
         try:
         try:
             return cls.by_username[username]
             return cls.by_username[username]
         except KeyError:
         except KeyError:
@@ -354,7 +376,7 @@ class User(DBUser, BaseUser):
 
 
     @classmethod
     @classmethod
     @async_getter_lock
     @async_getter_lock
-    async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
+    async def get_by_uuid(cls, uuid: UUID) -> Optional["User"]:
         try:
         try:
             return cls.by_uuid[uuid]
             return cls.by_uuid[uuid]
         except KeyError:
         except KeyError:
@@ -368,7 +390,7 @@ class User(DBUser, BaseUser):
         return None
         return None
 
 
     @classmethod
     @classmethod
-    async def get_by_address(cls, address: Address) -> Optional['User']:
+    async def get_by_address(cls, address: Address) -> Optional["User"]:
         if address.uuid:
         if address.uuid:
             return await cls.get_by_uuid(address.uuid)
             return await cls.get_by_uuid(address.uuid)
         elif address.number:
         elif address.number:
@@ -377,7 +399,7 @@ class User(DBUser, BaseUser):
             raise ValueError("Given address is blank")
             raise ValueError("Given address is blank")
 
 
     @classmethod
     @classmethod
-    async def all_logged_in(cls) -> AsyncGenerator['User', None]:
+    async def all_logged_in(cls) -> AsyncGenerator["User", None]:
         users = await super().all_logged_in()
         users = await super().all_logged_in()
         user: cls
         user: cls
         for user in users:
         for user in users:

+ 2 - 1
mautrix_signal/util/color_log.py

@@ -13,7 +13,8 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from mautrix.util.logging.color import ColorFormatter as BaseColorFormatter, PREFIX, RESET
+from mautrix.util.logging.color import PREFIX, RESET
+from mautrix.util.logging.color import ColorFormatter as BaseColorFormatter
 
 
 MAUSIGNALD_COLOR = PREFIX + "35;1m"  # magenta
 MAUSIGNALD_COLOR = PREFIX + "35;1m"  # magenta
 
 

+ 1 - 1
mautrix_signal/version.py

@@ -1 +1 @@
-from .get_version import git_tag, git_revision, version, linkified_version
+from .get_version import git_revision, git_tag, linkified_version, version

+ 48 - 35
mautrix_signal/web/provisioning_api.py

@@ -13,18 +13,18 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Awaitable, Dict, TYPE_CHECKING
-import logging
+from typing import TYPE_CHECKING, Awaitable, Dict
 import asyncio
 import asyncio
 import json
 import json
+import logging
 
 
 from aiohttp import web
 from aiohttp import web
-
-from mausignald.types import Address, Account
-from mausignald.errors import InternalError, TimeoutException
 from mautrix.types import UserID
 from mautrix.types import UserID
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
+from mausignald.errors import InternalError, TimeoutException
+from mausignald.types import Account, Address
+
 from .. import user as u
 from .. import user as u
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -34,9 +34,9 @@ if TYPE_CHECKING:
 class ProvisioningAPI:
 class ProvisioningAPI:
     log: TraceLogger = logging.getLogger("mau.web.provisioning")
     log: TraceLogger = logging.getLogger("mau.web.provisioning")
     app: web.Application
     app: web.Application
-    bridge: 'SignalBridge'
+    bridge: "SignalBridge"
 
 
-    def __init__(self, bridge: 'SignalBridge', shared_secret: str) -> None:
+    def __init__(self, bridge: "SignalBridge", shared_secret: str) -> None:
         self.bridge = bridge
         self.bridge = bridge
         self.app = web.Application()
         self.app = web.Application()
         self.shared_secret = shared_secret
         self.shared_secret = shared_secret
@@ -70,23 +70,26 @@ class ProvisioningAPI:
     async def login_options(self, _: web.Request) -> web.Response:
     async def login_options(self, _: web.Request) -> web.Response:
         return web.Response(status=200, headers=self._headers)
         return web.Response(status=200, headers=self._headers)
 
 
-    async def check_token(self, request: web.Request) -> 'u.User':
+    async def check_token(self, request: web.Request) -> "u.User":
         try:
         try:
             token = request.headers["Authorization"]
             token = request.headers["Authorization"]
-            token = token[len("Bearer "):]
+            token = token[len("Bearer ") :]
         except KeyError:
         except KeyError:
-            raise web.HTTPBadRequest(text='{"error": "Missing Authorization header"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Missing Authorization header"}', headers=self._headers
+            )
         except IndexError:
         except IndexError:
-            raise web.HTTPBadRequest(text='{"error": "Malformed Authorization header"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Malformed Authorization header"}', headers=self._headers
+            )
         if token != self.shared_secret:
         if token != self.shared_secret:
             raise web.HTTPForbidden(text='{"error": "Invalid token"}', headers=self._headers)
             raise web.HTTPForbidden(text='{"error": "Invalid token"}', headers=self._headers)
         try:
         try:
             user_id = request.query["user_id"]
             user_id = request.query["user_id"]
         except KeyError:
         except KeyError:
-            raise web.HTTPBadRequest(text='{"error": "Missing user_id query param"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Missing user_id query param"}', headers=self._headers
+            )
 
 
         if not self.bridge.signal.is_connected:
         if not self.bridge.signal.is_connected:
             await self.bridge.signal.wait_for_connected()
             await self.bridge.signal.wait_for_connected()
@@ -103,7 +106,8 @@ class ProvisioningAPI:
         if await user.is_logged_in():
         if await user.is_logged_in():
             try:
             try:
                 profile = await self.bridge.signal.get_profile(
                 profile = await self.bridge.signal.get_profile(
-                    username=user.username, address=Address(number=user.username))
+                    username=user.username, address=Address(number=user.username)
+                )
             except Exception as e:
             except Exception as e:
                 self.log.exception(f"Failed to get {user.username}'s profile for whoami")
                 self.log.exception(f"Failed to get {user.username}'s profile for whoami")
                 data["signal"] = {
                 data["signal"] = {
@@ -127,8 +131,9 @@ class ProvisioningAPI:
         user = await self.check_token(request)
         user = await self.check_token(request)
 
 
         if await user.is_logged_in():
         if await user.is_logged_in():
-            raise web.HTTPConflict(text='''{"error": "You're already logged in"}''',
-                                   headers=self._headers)
+            raise web.HTTPConflict(
+                text="""{"error": "You're already logged in"}""", headers=self._headers
+            )
 
 
         try:
         try:
             data = await request.json()
             data = await request.json()
@@ -147,17 +152,19 @@ class ProvisioningAPI:
         self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
         self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
         return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
         return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
 
 
-    async def _shielded_link(self, user: 'u.User', session_id: str, device_name: str) -> Account:
+    async def _shielded_link(self, user: "u.User", session_id: str, device_name: str) -> Account:
         try:
         try:
             self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
             self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
-            account = await self.bridge.signal.finish_link(session_id=session_id, overwrite=True,
-                                                           device_name=device_name)
+            account = await self.bridge.signal.finish_link(
+                session_id=session_id, overwrite=True, device_name=device_name
+            )
         except TimeoutException:
         except TimeoutException:
             self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
             self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
             raise
             raise
         except Exception:
         except Exception:
-            self.log.exception("Fatal error while waiting for linking to finish "
-                               f"(session {session_id})")
+            self.log.exception(
+                "Fatal error while waiting for linking to finish " f"(session {session_id})"
+            )
             raise
             raise
         else:
         else:
             await user.on_signin(account)
             await user.on_signin(account)
@@ -166,36 +173,42 @@ class ProvisioningAPI:
     async def link_wait(self, request: web.Request) -> web.Response:
     async def link_wait(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         user = await self.check_token(request)
         if not user.command_status or user.command_status["action"] != "Link":
         if not user.command_status or user.command_status["action"] != "Link":
-            raise web.HTTPBadRequest(text='{"error": "No Signal linking started"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "No Signal linking started"}', headers=self._headers
+            )
         session_id = user.command_status["session_id"]
         session_id = user.command_status["session_id"]
         device_name = user.command_status["device_name"]
         device_name = user.command_status["device_name"]
         try:
         try:
             account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
             account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
         except asyncio.CancelledError:
         except asyncio.CancelledError:
-            self.log.warning(f"Client cancelled link wait request ({session_id})"
-                             " before it finished")
+            self.log.warning(
+                f"Client cancelled link wait request ({session_id})" " before it finished"
+            )
         except TimeoutException:
         except TimeoutException:
-            raise web.HTTPBadRequest(text='{"error": "Signal linking timed out"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Signal linking timed out"}', headers=self._headers
+            )
         except InternalError as ie:
         except InternalError as ie:
             if "java.io.IOException" in ie.exceptions:
             if "java.io.IOException" in ie.exceptions:
                 raise web.HTTPBadRequest(
                 raise web.HTTPBadRequest(
                     text='{"error": "Signald websocket disconnected before linking finished"}',
                     text='{"error": "Signald websocket disconnected before linking finished"}',
                     headers=self._headers,
                     headers=self._headers,
                 )
                 )
-            raise web.HTTPInternalServerError(text='{"error": "Fatal error in Signal linking"}',
-                                              headers=self._headers)
+            raise web.HTTPInternalServerError(
+                text='{"error": "Fatal error in Signal linking"}', headers=self._headers
+            )
         except Exception:
         except Exception:
-            raise web.HTTPInternalServerError(text='{"error": "Fatal error in Signal linking"}',
-                                              headers=self._headers)
+            raise web.HTTPInternalServerError(
+                text='{"error": "Fatal error in Signal linking"}', headers=self._headers
+            )
         else:
         else:
             return web.json_response(account.address.serialize())
             return web.json_response(account.address.serialize())
 
 
     async def logout(self, request: web.Request) -> web.Response:
     async def logout(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         user = await self.check_token(request)
         if not await user.is_logged_in():
         if not await user.is_logged_in():
-            raise web.HTTPNotFound(text='''{"error": "You're not logged in"}''',
-                                   headers=self._headers)
+            raise web.HTTPNotFound(
+                text="""{"error": "You're not logged in"}""", headers=self._headers
+            )
         await user.logout()
         await user.logout()
         return web.json_response({}, headers=self._acao_headers)
         return web.json_response({}, headers=self._acao_headers)

+ 10 - 0
pyproject.toml

@@ -0,0 +1,10 @@
+[tool.isort]
+profile = "black"
+force_to_top = "typing"
+from_first = true
+line_length = 99
+
+[tool.black]
+line-length = 99
+target-version = ["py38"]
+required-version = "21.12b0"

+ 5 - 0
setup.cfg

@@ -0,0 +1,5 @@
+[flake8]
+max-line-length = 99
+extend-ignore =
+    # See https://github.com/PyCQA/pycodestyle/issues/373
+    E203,

Некоторые файлы не были показаны из-за большого количества измененных файлов