Sfoglia il codice sorgente

formatting: use black and add CI job to enforce styling

Sumner Evans 3 anni fa
parent
commit
4f31ce41e0

+ 1 - 1
.editorconfig

@@ -17,5 +17,5 @@ trim_trailing_whitespace = false
 [*.{yaml,yml,py,md}]
 [*.{yaml,yml,py,md}]
 indent_style = space
 indent_style = space
 
 
-[{.gitlab-ci.yml,*.md}]
+[{.gitlab-ci.yml,*.md,.github/workflows/*.yml}]
 indent_size = 2
 indent_size = 2

+ 18 - 0
.github/workflows/python-lint.yml

@@ -0,0 +1,18 @@
+name: Python lint
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+    - uses: actions/checkout@v2
+    - uses: actions/setup-python@v2
+      with:
+        python-version: "3.10"
+    - uses: isort/isort-action@master
+      with:
+        sortPaths: "./mausignald ./mautrix_signal"
+    - uses: psf/black@21.12b0
+      with:
+        src: "./mausignald ./mautrix_signal"

+ 6 - 2
mausignald/errors.py

@@ -26,8 +26,12 @@ class NotConnected(RPCError):
 
 
 
 
 class ResponseError(RPCError):
 class ResponseError(RPCError):
-    def __init__(self, data: Dict[str, Any], error_type: Optional[str] = None,
-                 message_override: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        data: Dict[str, Any],
+        error_type: Optional[str] = None,
+        message_override: Optional[str] = None,
+    ) -> None:
         self.data = data
         self.data = data
         msg = message_override or data["message"]
         msg = message_override or data["message"]
         if error_type:
         if error_type:

+ 16 - 8
mausignald/rpc.py

@@ -36,8 +36,12 @@ class SignaldRPCClient:
     _response_waiters: Dict[UUID, asyncio.Future]
     _response_waiters: Dict[UUID, asyncio.Future]
     _rpc_event_handlers: Dict[str, List[EventHandler]]
     _rpc_event_handlers: Dict[str, List[EventHandler]]
 
 
-    def __init__(self, socket_path: str, log: Optional[TraceLogger] = None,
-                 loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+    def __init__(
+        self,
+        socket_path: str,
+        log: Optional[TraceLogger] = None,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ) -> None:
         self.socket_path = socket_path
         self.socket_path = socket_path
         self.log = log or logging.getLogger("mausignald")
         self.log = log or logging.getLogger("mausignald")
         self.loop = loop or asyncio.get_event_loop()
         self.loop = loop or asyncio.get_event_loop()
@@ -67,7 +71,8 @@ class SignaldRPCClient:
         while True:
         while True:
             try:
             try:
                 self._reader, self._writer = await asyncio.open_unix_connection(
                 self._reader, self._writer = await asyncio.open_unix_connection(
-                    self.socket_path, limit=_SOCKET_LIMIT)
+                    self.socket_path, limit=_SOCKET_LIMIT
+                )
             except OSError as e:
             except OSError as e:
                 self.log.error(f"Connection to {self.socket_path} failed: {e}")
                 self.log.error(f"Connection to {self.socket_path} failed: {e}")
                 await asyncio.sleep(5)
                 await asyncio.sleep(5)
@@ -177,8 +182,9 @@ class SignaldRPCClient:
         self._reader = None
         self._reader = None
         self._writer = None
         self._writer = None
 
 
-    def _create_request(self, command: str, req_id: Optional[UUID] = None, **data: Any
-                        ) -> Tuple[asyncio.Future, Dict[str, Any]]:
+    def _create_request(
+        self, command: str, req_id: Optional[UUID] = None, **data: Any
+    ) -> Tuple[asyncio.Future, Dict[str, Any]]:
         req_id = req_id or uuid4()
         req_id = req_id or uuid4()
         req = {"id": str(req_id), "type": command, **data}
         req = {"id": str(req_id), "type": command, **data}
         self.log.trace("Request %s: %s %s", req_id, command, data)
         self.log.trace("Request %s: %s %s", req_id, command, data)
@@ -196,7 +202,8 @@ class SignaldRPCClient:
             if not waiter.done():
             if not waiter.done():
                 self.log.trace(f"Abandoning response for {req_id}")
                 self.log.trace(f"Abandoning response for {req_id}")
                 waiter.set_exception(
                 waiter.set_exception(
-                    NotConnected("Disconnected from signald before RPC completed"))
+                    NotConnected("Disconnected from signald before RPC completed")
+                )
 
 
     async def _send_request(self, data: Dict[str, Any]) -> None:
     async def _send_request(self, data: Dict[str, Any]) -> None:
         if self._writer is None:
         if self._writer is None:
@@ -207,8 +214,9 @@ class SignaldRPCClient:
         await self._writer.drain()
         await self._writer.drain()
         self.log.trace("Sent data to server server: %s", data)
         self.log.trace("Sent data to server server: %s", data)
 
 
-    async def _raw_request(self, command: str, req_id: Optional[UUID] = None, **data: Any
-                           ) -> Tuple[str, Dict[str, Any]]:
+    async def _raw_request(
+        self, command: str, req_id: Optional[UUID] = None, **data: Any
+    ) -> Tuple[str, Dict[str, Any]]:
         future, data = self._create_request(command, req_id, **data)
         future, data = self._create_request(command, req_id, **data)
         await self._send_request(data)
         await self._send_request(data)
         return await asyncio.shield(future)
         return await asyncio.shield(future)

+ 156 - 77
mausignald/signald.py

@@ -10,11 +10,26 @@ from mautrix.util.logging import TraceLogger
 
 
 from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
 from .rpc import CONNECT_EVENT, DISCONNECT_EVENT, SignaldRPCClient
 from .errors import UnexpectedError, UnexpectedResponse
 from .errors import UnexpectedError, UnexpectedResponse
-from .types import (Address, Quote, Attachment, Reaction, Account, Message, DeviceInfo, Group,
-                    Profile, GroupID, GetIdentitiesResponse, GroupV2, Mention, LinkSession,
-                    WebsocketConnectionState, WebsocketConnectionStateChangeEvent)
-
-T = TypeVar('T')
+from .types import (
+    Address,
+    Quote,
+    Attachment,
+    Reaction,
+    Account,
+    Message,
+    DeviceInfo,
+    Group,
+    Profile,
+    GroupID,
+    GetIdentitiesResponse,
+    GroupV2,
+    Mention,
+    LinkSession,
+    WebsocketConnectionState,
+    WebsocketConnectionStateChangeEvent,
+)
+
+T = TypeVar("T")
 EventHandler = Callable[[T], Awaitable[None]]
 EventHandler = Callable[[T], Awaitable[None]]
 
 
 
 
@@ -22,15 +37,19 @@ class SignaldClient(SignaldRPCClient):
     _event_handlers: Dict[Type[T], List[EventHandler]]
     _event_handlers: Dict[Type[T], List[EventHandler]]
     _subscriptions: Set[str]
     _subscriptions: Set[str]
 
 
-    def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
-                 log: Optional[TraceLogger] = None,
-                 loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+    def __init__(
+        self,
+        socket_path: str = "/var/run/signald/signald.sock",
+        log: Optional[TraceLogger] = None,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ) -> None:
         super().__init__(socket_path, log, loop)
         super().__init__(socket_path, log, loop)
         self._event_handlers = {}
         self._event_handlers = {}
         self._subscriptions = set()
         self._subscriptions = set()
         self.add_rpc_handler("message", self._parse_message)
         self.add_rpc_handler("message", self._parse_message)
-        self.add_rpc_handler("websocket_connection_state_change",
-                             self._websocket_connection_state_change)
+        self.add_rpc_handler(
+            "websocket_connection_state_change", self._websocket_connection_state_change
+        )
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler("version", self._log_version)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
         self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
         self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
         self.add_rpc_handler(DISCONNECT_EVENT, self._on_disconnect)
@@ -111,12 +130,13 @@ class SignaldClient(SignaldRPCClient):
                 evt = WebsocketConnectionStateChangeEvent(
                 evt = WebsocketConnectionStateChangeEvent(
                     state=WebsocketConnectionState.SOCKET_DISCONNECTED,
                     state=WebsocketConnectionState.SOCKET_DISCONNECTED,
                     account=username,
                     account=username,
-                    exception="Disconnected from signald"
+                    exception="Disconnected from signald",
                 )
                 )
                 await self._run_event_handler(evt)
                 await self._run_event_handler(evt)
 
 
-    async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
-                       ) -> str:
+    async def register(
+        self, phone: str, voice: bool = False, captcha: Optional[str] = None
+    ) -> str:
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
         return resp["account_id"]
         return resp["account_id"]
 
 
@@ -127,15 +147,18 @@ class SignaldClient(SignaldRPCClient):
     async def start_link(self) -> LinkSession:
     async def start_link(self) -> LinkSession:
         return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
         return LinkSession.deserialize(await self.request_v1("generate_linking_uri"))
 
 
-    async def finish_link(self, session_id: str, device_name: str = "mausignald",
-                          overwrite: bool = False) -> Account:
-        resp = await self.request_v1("finish_link", device_name=device_name, session_id=session_id,
-                                     overwrite=overwrite)
+    async def finish_link(
+        self, session_id: str, device_name: str = "mausignald", overwrite: bool = False
+    ) -> Account:
+        resp = await self.request_v1(
+            "finish_link", device_name=device_name, session_id=session_id, overwrite=overwrite
+        )
         return Account.deserialize(resp)
         return Account.deserialize(resp)
 
 
     @staticmethod
     @staticmethod
-    def _recipient_to_args(recipient: Union[Address, GroupID], simple_name: bool = False
-                           ) -> Dict[str, Any]:
+    def _recipient_to_args(
+        recipient: Union[Address, GroupID], simple_name: bool = False
+    ) -> Dict[str, Any]:
         if isinstance(recipient, Address):
         if isinstance(recipient, Address):
             recipient = recipient.serialize()
             recipient = recipient.serialize()
             field_name = "address" if simple_name else "recipientAddress"
             field_name = "address" if simple_name else "recipientAddress"
@@ -143,27 +166,49 @@ class SignaldClient(SignaldRPCClient):
             field_name = "group" if simple_name else "recipientGroupId"
             field_name = "group" if simple_name else "recipientGroupId"
         return {field_name: recipient}
         return {field_name: recipient}
 
 
-    async def react(self, username: str, recipient: Union[Address, GroupID],
-                    reaction: Reaction) -> None:
-        await self.request_v1("react", username=username, reaction=reaction.serialize(),
-                              **self._recipient_to_args(recipient))
-
-    async def remote_delete(self, username: str, recipient: Union[Address, GroupID], timestamp: int
-                            ) -> None:
-        await self.request_v1("remote_delete", account=username, timestamp=timestamp,
-                              **self._recipient_to_args(recipient, simple_name=True))
-
-    async def send(self, username: str, recipient: Union[Address, GroupID], body: str,
-                   quote: Optional[Quote] = None, attachments: Optional[List[Attachment]] = None,
-                   mentions: Optional[List[Mention]] = None, timestamp: Optional[int] = None
-                   ) -> None:
+    async def react(
+        self, username: str, recipient: Union[Address, GroupID], reaction: Reaction
+    ) -> None:
+        await self.request_v1(
+            "react",
+            username=username,
+            reaction=reaction.serialize(),
+            **self._recipient_to_args(recipient),
+        )
+
+    async def remote_delete(
+        self, username: str, recipient: Union[Address, GroupID], timestamp: int
+    ) -> None:
+        await self.request_v1(
+            "remote_delete",
+            account=username,
+            timestamp=timestamp,
+            **self._recipient_to_args(recipient, simple_name=True),
+        )
+
+    async def send(
+        self,
+        username: str,
+        recipient: Union[Address, GroupID],
+        body: str,
+        quote: Optional[Quote] = None,
+        attachments: Optional[List[Attachment]] = None,
+        mentions: Optional[List[Mention]] = None,
+        timestamp: Optional[int] = None,
+    ) -> None:
         serialized_quote = quote.serialize() if quote else None
         serialized_quote = quote.serialize() if quote else None
         serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
         serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
         serialized_mentions = [mention.serialize() for mention in (mentions or [])]
         serialized_mentions = [mention.serialize() for mention in (mentions or [])]
-        resp = await self.request_v1("send", username=username, messageBody=body,
-                                     attachments=serialized_attachments, quote=serialized_quote,
-                                     mentions=serialized_mentions, timestamp=timestamp,
-                                     **self._recipient_to_args(recipient))
+        resp = await self.request_v1(
+            "send",
+            username=username,
+            messageBody=body,
+            attachments=serialized_attachments,
+            quote=serialized_quote,
+            mentions=serialized_mentions,
+            timestamp=timestamp,
+            **self._recipient_to_args(recipient),
+        )
         errors = []
         errors = []
 
 
         # We handle unregisteredFailure a little differently than other errors. If there are no
         # We handle unregisteredFailure a little differently than other errors. If there are no
@@ -173,9 +218,8 @@ class SignaldClient(SignaldRPCClient):
         successful_send_count = 0
         successful_send_count = 0
         results = resp.get("results", [])
         results = resp.get("results", [])
         for result in results:
         for result in results:
-            number = (
-                result.get("address", {}).get("number") or result.get("address", {}).get("uuid")
-            )
+            address = result.get("addres", {})
+            number = address.get("number") or address.get("uuid")
             proof_required_failure = result.get("proof_required_failure")
             proof_required_failure = result.get("proof_required_failure")
             if result.get("networkFailure", False):
             if result.get("networkFailure", False):
                 errors.append(f"Network failure occurred while sending message to {number}.")
                 errors.append(f"Network failure occurred while sending message to {number}.")
@@ -186,9 +230,10 @@ class SignaldClient(SignaldRPCClient):
             elif result.get("identityFailure", ""):
             elif result.get("identityFailure", ""):
                 errors.append(
                 errors.append(
                     f"Identity failure occurred while sending message to {number}. New identity: "
                     f"Identity failure occurred while sending message to {number}. New identity: "
-                    f"{result['identityFailure']}")
+                    f"{result['identityFailure']}"
+                )
             elif proof_required_failure:
             elif proof_required_failure:
-                options = proof_required_failure.get('options')
+                options = proof_required_failure.get("options")
                 self.log.warning(
                 self.log.warning(
                     f"Proof Required Failure {options}. "
                     f"Proof Required Failure {options}. "
                     f"Retry after: {proof_required_failure.get('retry_after')}. "
                     f"Retry after: {proof_required_failure.get('retry_after')}. "
@@ -206,17 +251,26 @@ class SignaldClient(SignaldRPCClient):
                     await self.request_v1("submit_challenge")
                     await self.request_v1("submit_challenge")
             else:
             else:
                 successful_send_count += 1
                 successful_send_count += 1
-        self.log.info(f"Successfully sent message to {successful_send_count}/{len(results)} users in {recipient}")
+        self.log.info(
+            f"Successfully sent message to {successful_send_count}/{len(results)} users in {recipient}"
+        )
         if errors or successful_send_count == 0:
         if errors or successful_send_count == 0:
             raise Exception("\n".join(errors + unregistered_failures))
             raise Exception("\n".join(errors + unregistered_failures))
 
 
-    async def send_receipt(self, username: str, sender: Address, timestamps: List[int],
-                           when: Optional[int] = None, read: bool = False) -> None:
+    async def send_receipt(
+        self,
+        username: str,
+        sender: Address,
+        timestamps: List[int],
+        when: Optional[int] = None,
+        read: bool = False,
+    ) -> None:
         if not read:
         if not read:
             # TODO implement
             # TODO implement
             return
             return
-        await self.request_v1("mark_read", account=username, timestamps=timestamps, when=when,
-                              to=sender.serialize())
+        await self.request_v1(
+            "mark_read", account=username, timestamps=timestamps, when=when, to=sender.serialize()
+        )
 
 
     async def list_accounts(self) -> List[Account]:
     async def list_accounts(self) -> List[Account]:
         resp = await self.request_v1("list_accounts")
         resp = await self.request_v1("list_accounts")
@@ -242,19 +296,28 @@ class SignaldClient(SignaldRPCClient):
         v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
         v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
         return legacy + v2
         return legacy + v2
 
 
-    async def update_group(self, username: str, group_id: GroupID, title: Optional[str] = None,
-                           avatar_path: Optional[str] = None,
-                           add_members: Optional[List[Address]] = None,
-                           remove_members: Optional[List[Address]] = None
-                           ) -> Union[Group, GroupV2, None]:
-        update_params = {key: value for key, value in {
-            "groupID": group_id,
-            "avatar": avatar_path,
-            "title": title,
-            "addMembers": [addr.serialize() for addr in add_members] if add_members else None,
-            "removeMembers": ([addr.serialize() for addr in remove_members]
-                              if remove_members else None),
-        }.items() if value is not None}
+    async def update_group(
+        self,
+        username: str,
+        group_id: GroupID,
+        title: Optional[str] = None,
+        avatar_path: Optional[str] = None,
+        add_members: Optional[List[Address]] = None,
+        remove_members: Optional[List[Address]] = None,
+    ) -> Union[Group, GroupV2, None]:
+        update_params = {
+            key: value
+            for key, value in {
+                "groupID": group_id,
+                "avatar": avatar_path,
+                "title": title,
+                "addMembers": [addr.serialize() for addr in add_members] if add_members else None,
+                "removeMembers": (
+                    [addr.serialize() for addr in remove_members] if remove_members else None
+                ),
+            }.items()
+            if value is not None
+        }
         resp = await self.request_v1("update_group", account=username, **update_params)
         resp = await self.request_v1("update_group", account=username, **update_params)
         if "v1" in resp:
         if "v1" in resp:
             return Group.deserialize(resp["v1"])
             return Group.deserialize(resp["v1"])
@@ -267,21 +330,25 @@ class SignaldClient(SignaldRPCClient):
         resp = await self.request_v1("accept_invitation", account=username, groupID=group_id)
         resp = await self.request_v1("accept_invitation", account=username, groupID=group_id)
         return GroupV2.deserialize(resp)
         return GroupV2.deserialize(resp)
 
 
-    async def get_group(self, username: str, group_id: GroupID, revision: int = -1
-                        ) -> Optional[GroupV2]:
-        resp = await self.request_v1("get_group", account=username, groupID=group_id,
-                                     revision=revision)
+    async def get_group(
+        self, username: str, group_id: GroupID, revision: int = -1
+    ) -> Optional[GroupV2]:
+        resp = await self.request_v1(
+            "get_group", account=username, groupID=group_id, revision=revision
+        )
         if "id" not in resp:
         if "id" not in resp:
             return None
             return None
         return GroupV2.deserialize(resp)
         return GroupV2.deserialize(resp)
 
 
-    async def get_profile(self, username: str, address: Address, use_cache: bool = False
-                          ) -> Optional[Profile]:
+    async def get_profile(
+        self, username: str, address: Address, use_cache: bool = False
+    ) -> Optional[Profile]:
         try:
         try:
             # async is a reserved keyword, so can't pass it as a normal parameter
             # async is a reserved keyword, so can't pass it as a normal parameter
             kwargs = {"async": use_cache}
             kwargs = {"async": use_cache}
-            resp = await self.request_v1("get_profile", account=username,
-                                         address=address.serialize(), **kwargs)
+            resp = await self.request_v1(
+                "get_profile", account=username, address=address.serialize(), **kwargs
+            )
         except UnexpectedResponse as e:
         except UnexpectedResponse as e:
             if e.resp_type == "profile_not_available":
             if e.resp_type == "profile_not_available":
                 return None
                 return None
@@ -289,12 +356,14 @@ class SignaldClient(SignaldRPCClient):
         return Profile.deserialize(resp)
         return Profile.deserialize(resp)
 
 
     async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
     async def get_identities(self, username: str, address: Address) -> GetIdentitiesResponse:
-        resp = await self.request_v1("get_identities", account=username,
-                                     address=address.serialize())
+        resp = await self.request_v1(
+            "get_identities", account=username, address=address.serialize()
+        )
         return GetIdentitiesResponse.deserialize(resp)
         return GetIdentitiesResponse.deserialize(resp)
 
 
-    async def set_profile(self, username: str, name: Optional[str] = None,
-                          avatar_path: Optional[str] = None) -> None:
+    async def set_profile(
+        self, username: str, name: Optional[str] = None, avatar_path: Optional[str] = None
+    ) -> None:
         args = {}
         args = {}
         if name is not None:
         if name is not None:
             args["name"] = name
             args["name"] = name
@@ -302,9 +371,14 @@ class SignaldClient(SignaldRPCClient):
             args["avatarFile"] = avatar_path
             args["avatarFile"] = avatar_path
         await self.request_v1("set_profile", account=username, **args)
         await self.request_v1("set_profile", account=username, **args)
 
 
-    async def trust(self, username: str, recipient: Address, trust_level: str,
-                    safety_number: Optional[str] = None, qr_code_data: Optional[str] = None
-                    ) -> None:
+    async def trust(
+        self,
+        username: str,
+        recipient: Address,
+        trust_level: str,
+        safety_number: Optional[str] = None,
+        qr_code_data: Optional[str] = None,
+    ) -> None:
         args = {}
         args = {}
         if safety_number:
         if safety_number:
             if qr_code_data:
             if qr_code_data:
@@ -314,5 +388,10 @@ class SignaldClient(SignaldRPCClient):
             args["qr_code_data"] = qr_code_data
             args["qr_code_data"] = qr_code_data
         else:
         else:
             raise ValueError("safety_number or qr_code_data is required")
             raise ValueError("safety_number or qr_code_data is required")
-        await self.request_v1("trust", account=username, **args, trust_level=trust_level,
-                              address=recipient.serialize())
+        await self.request_v1(
+            "trust",
+            account=username,
+            **args,
+            trust_level=trust_level,
+            address=recipient.serialize(),
+        )

+ 19 - 13
mausignald/types.py

@@ -11,7 +11,7 @@ from attr import dataclass
 
 
 from mautrix.types import SerializableAttrs, SerializableEnum, ExtensibleEnum, field
 from mautrix.types import SerializableAttrs, SerializableEnum, ExtensibleEnum, field
 
 
-GroupID = NewType('GroupID', str)
+GroupID = NewType("GroupID", str)
 
 
 
 
 @dataclass(frozen=True, eq=False)
 @dataclass(frozen=True, eq=False)
@@ -27,7 +27,7 @@ class Address(SerializableAttrs):
     def best_identifier(self) -> str:
     def best_identifier(self) -> str:
         return str(self.uuid) if self.uuid else self.number
         return str(self.uuid) if self.uuid else self.number
 
 
-    def __eq__(self, other: 'Address') -> bool:
+    def __eq__(self, other: "Address") -> bool:
         if not isinstance(other, Address):
         if not isinstance(other, Address):
             return False
             return False
         if self.uuid and other.uuid:
         if self.uuid and other.uuid:
@@ -42,7 +42,7 @@ class Address(SerializableAttrs):
         return hash(self.number)
         return hash(self.number)
 
 
     @classmethod
     @classmethod
-    def parse(cls, value: str) -> 'Address':
+    def parse(cls, value: str) -> "Address":
         return Address(number=value) if value.startswith("+") else Address(uuid=UUID(value))
         return Address(number=value) if value.startswith("+") else Address(uuid=UUID(value))
 
 
 
 
@@ -202,13 +202,15 @@ class GroupV2(GroupV2ID, SerializableAttrs):
     timer: Optional[int] = None
     timer: Optional[int] = None
     master_key: Optional[str] = field(default=None, json="masterKey")
     master_key: Optional[str] = field(default=None, json="masterKey")
     invite_link: Optional[str] = field(default=None, json="inviteLink")
     invite_link: Optional[str] = field(default=None, json="inviteLink")
-    access_control: GroupAccessControl = field(factory=lambda: GroupAccessControl(),
-                                               json="accessControl")
+    access_control: GroupAccessControl = field(
+        factory=lambda: GroupAccessControl(), json="accessControl"
+    )
     members: List[Address]
     members: List[Address]
     member_detail: List[GroupMember] = field(factory=lambda: [], json="memberDetail")
     member_detail: List[GroupMember] = field(factory=lambda: [], json="memberDetail")
     pending_members: List[Address] = field(factory=lambda: [], json="pendingMembers")
     pending_members: List[Address] = field(factory=lambda: [], json="pendingMembers")
-    pending_member_detail: List[GroupMember] = field(factory=lambda: [],
-                                                     json="pendingMemberDetail")
+    pending_member_detail: List[GroupMember] = field(
+        factory=lambda: [], json="pendingMemberDetail"
+    )
     requesting_members: List[Address] = field(factory=lambda: [], json="requestingMembers")
     requesting_members: List[Address] = field(factory=lambda: [], json="requestingMembers")
 
 
 
 
@@ -294,8 +296,9 @@ class MessageData(SerializableAttrs):
 class SentSyncMessage(SerializableAttrs):
 class SentSyncMessage(SerializableAttrs):
     message: MessageData
     message: MessageData
     timestamp: int
     timestamp: int
-    expiration_start_timestamp: Optional[int] = field(default=None,
-                                                      json="expirationStartTimestamp")
+    expiration_start_timestamp: Optional[int] = field(
+        default=None, json="expirationStartTimestamp"
+    )
     is_recipient_update: bool = field(default=False, json="isRecipientUpdate")
     is_recipient_update: bool = field(default=False, json="isRecipientUpdate")
     unidentified_status: Dict[str, bool] = field(factory=lambda: {})
     unidentified_status: Dict[str, bool] = field(factory=lambda: {})
     destination: Optional[Address] = None
     destination: Optional[Address] = None
@@ -347,11 +350,13 @@ class ConfigItem(SerializableAttrs):
 @dataclass
 @dataclass
 class ClientConfiguration(SerializableAttrs):
 class ClientConfiguration(SerializableAttrs):
     read_receipts: Optional[ConfigItem] = field(factory=lambda: ConfigItem(), json="readReceipts")
     read_receipts: Optional[ConfigItem] = field(factory=lambda: ConfigItem(), json="readReceipts")
-    typing_indicators: Optional[ConfigItem] = field(factory=lambda: ConfigItem(),
-                                                    json="typingIndicators")
+    typing_indicators: Optional[ConfigItem] = field(
+        factory=lambda: ConfigItem(), json="typingIndicators"
+    )
     link_previews: Optional[ConfigItem] = field(factory=lambda: ConfigItem(), json="linkPreviews")
     link_previews: Optional[ConfigItem] = field(factory=lambda: ConfigItem(), json="linkPreviews")
     unidentified_delivery_indicators: Optional[ConfigItem] = field(
     unidentified_delivery_indicators: Optional[ConfigItem] = field(
-        factory=lambda: ConfigItem(), json="unidentifiedDeliveryIndicators")
+        factory=lambda: ConfigItem(), json="unidentifiedDeliveryIndicators"
+    )
 
 
 
 
 class StickerPackOperation(ExtensibleEnum):
 class StickerPackOperation(ExtensibleEnum):
@@ -376,7 +381,8 @@ class SyncMessage(SerializableAttrs):
     configuration: Optional[ClientConfiguration] = None
     configuration: Optional[ClientConfiguration] = None
     # blocked_list: Optional[???] = field(default=None, json="blockedList")
     # blocked_list: Optional[???] = field(default=None, json="blockedList")
     sticker_pack_operations: Optional[List[StickerPackOperations]] = field(
     sticker_pack_operations: Optional[List[StickerPackOperations]] = field(
-        default=None, json="stickerPackOperations")
+        default=None, json="stickerPackOperations"
+    )
     contacts_complete: bool = field(default=False, json="contactsComplete")
     contacts_complete: bool = field(default=False, json="contactsComplete")
 
 
 
 

+ 74 - 35
mautrix_signal/commands/auth.py

@@ -34,8 +34,9 @@ SECTION_AUTH = HelpSection("Authentication", 10, "")
 remove_extra_chars = str.maketrans("", "", " .,-()")
 remove_extra_chars = str.maketrans("", "", " .,-()")
 
 
 
 
-async def make_qr(intent: IntentAPI, data: Union[str, bytes], body: str = None
-                  ) -> MediaMessageEventContent:
+async def make_qr(
+    intent: IntentAPI, data: Union[str, bytes], body: str = None
+) -> MediaMessageEventContent:
     # TODO always encrypt QR codes?
     # TODO always encrypt QR codes?
     buffer = io.BytesIO()
     buffer = io.BytesIO()
     image = qrcode.make(data)
     image = qrcode.make(data)
@@ -43,20 +44,30 @@ async def make_qr(intent: IntentAPI, data: Union[str, bytes], body: str = None
     image.save(buffer, "PNG")
     image.save(buffer, "PNG")
     qr = buffer.getvalue()
     qr = buffer.getvalue()
     mxc = await intent.upload_media(qr, "image/png", "qr.png", len(qr))
     mxc = await intent.upload_media(qr, "image/png", "qr.png", len(qr))
-    return MediaMessageEventContent(body=body or data, url=mxc, msgtype=MessageType.IMAGE,
-                                    info=ImageInfo(mimetype="image/png", size=len(qr),
-                                                   width=size, height=size))
-
-
-@command_handler(needs_auth=False, management_only=True, help_section=SECTION_AUTH,
-                 help_text="Link the bridge as a secondary device", help_args="[device name]")
+    return MediaMessageEventContent(
+        body=body or data,
+        url=mxc,
+        msgtype=MessageType.IMAGE,
+        info=ImageInfo(mimetype="image/png", size=len(qr), width=size, height=size),
+    )
+
+
+@command_handler(
+    needs_auth=False,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="Link the bridge as a secondary device",
+    help_args="[device name]",
+)
 async def link(evt: CommandEvent) -> None:
 async def link(evt: CommandEvent) -> None:
     if qrcode is None:
     if qrcode is None:
         await evt.reply("Can't generate QR code: qrcode and/or PIL not installed")
         await evt.reply("Can't generate QR code: qrcode and/or PIL not installed")
         return
         return
     if await evt.sender.is_logged_in():
     if await evt.sender.is_logged_in():
-        await evt.reply("You're already logged in. "
-                        "If you want to relink, log out with `$cmdprefix+sp logout` first.")
+        await evt.reply(
+            "You're already logged in. "
+            "If you want to relink, log out with `$cmdprefix+sp logout` first."
+        )
         return
         return
     # TODO make default device name configurable
     # TODO make default device name configurable
     device_name = " ".join(evt.args) or "Mautrix-Signal bridge"
     device_name = " ".join(evt.args) or "Mautrix-Signal bridge"
@@ -65,14 +76,16 @@ async def link(evt: CommandEvent) -> None:
     content = await make_qr(evt.az.intent, sess.uri)
     content = await make_qr(evt.az.intent, sess.uri)
     event_id = await evt.az.intent.send_message(evt.room_id, content)
     event_id = await evt.az.intent.send_message(evt.room_id, content)
     try:
     try:
-        account = await evt.bridge.signal.finish_link(session_id=sess.session_id, overwrite=True,
-                                                      device_name=device_name)
+        account = await evt.bridge.signal.finish_link(
+            session_id=sess.session_id, overwrite=True, device_name=device_name
+        )
     except TimeoutException:
     except TimeoutException:
         await evt.reply("Linking timed out, please try again.")
         await evt.reply("Linking timed out, please try again.")
     except Exception:
     except Exception:
         evt.log.exception("Fatal error while waiting for linking to finish")
         evt.log.exception("Fatal error while waiting for linking to finish")
-        await evt.reply("Fatal error while waiting for linking to finish "
-                        "(see logs for more details)")
+        await evt.reply(
+            "Fatal error while waiting for linking to finish " "(see logs for more details)"
+        )
     else:
     else:
         await evt.sender.on_signin(account)
         await evt.sender.on_signin(account)
         await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
         await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
@@ -80,16 +93,23 @@ async def link(evt: CommandEvent) -> None:
         await evt.main_intent.redact(evt.room_id, event_id)
         await evt.main_intent.redact(evt.room_id, event_id)
 
 
 
 
-@command_handler(needs_auth=False, management_only=True, help_section=SECTION_AUTH,
-                 is_enabled_for=lambda evt: evt.config["signal.registration_enabled"],
-                 help_text="Sign into Signal as the primary device", help_args="<phone>")
+@command_handler(
+    needs_auth=False,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    is_enabled_for=lambda evt: evt.config["signal.registration_enabled"],
+    help_text="Sign into Signal as the primary device",
+    help_args="<phone>",
+)
 async def register(evt: CommandEvent) -> None:
 async def register(evt: CommandEvent) -> None:
     if len(evt.args) == 0:
     if len(evt.args) == 0:
         await evt.reply("**Usage**: $cmdprefix+sp register [--voice] [--captcha <token>] <phone>")
         await evt.reply("**Usage**: $cmdprefix+sp register [--voice] [--captcha <token>] <phone>")
         return
         return
     if await evt.sender.is_logged_in():
     if await evt.sender.is_logged_in():
-        await evt.reply("You're already logged in. "
-                        "If you want to re-register, log out with `$cmdprefix+sp logout` first.")
+        await evt.reply(
+            "You're already logged in. "
+            "If you want to re-register, log out with `$cmdprefix+sp logout` first."
+        )
         return
         return
     voice = False
     voice = False
     captcha = None
     captcha = None
@@ -132,13 +152,19 @@ async def enter_register_code(evt: CommandEvent) -> None:
             raise
             raise
     else:
     else:
         await evt.sender.on_signin(account)
         await evt.sender.on_signin(account)
-        await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}."
-                        f"\n\n**N.B.** You must set a Signal profile name with `$cmdprefix+sp "
-                        f"set-profile-name <name>` before you can participate in new groups.")
-
-
-@command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,
-                 help_text="Remove all local data about your Signal link")
+        await evt.reply(
+            f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}."
+            f"\n\n**N.B.** You must set a Signal profile name with `$cmdprefix+sp "
+            f"set-profile-name <name>` before you can participate in new groups."
+        )
+
+
+@command_handler(
+    needs_auth=True,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="Remove all local data about your Signal link",
+)
 async def logout(evt: CommandEvent) -> None:
 async def logout(evt: CommandEvent) -> None:
     if not evt.sender.username:
     if not evt.sender.username:
         await evt.reply("You're not logged in")
         await evt.reply("You're not logged in")
@@ -147,16 +173,29 @@ async def logout(evt: CommandEvent) -> None:
     await evt.reply("Successfully logged out")
     await evt.reply("Successfully logged out")
 
 
 
 
-@command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,
-                 help_text="List devices linked to your Signal account")
+@command_handler(
+    needs_auth=True,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="List devices linked to your Signal account",
+)
 async def list_devices(evt: CommandEvent) -> None:
 async def list_devices(evt: CommandEvent) -> None:
     devices = await evt.bridge.signal.get_linked_devices(evt.sender.username)
     devices = await evt.bridge.signal.get_linked_devices(evt.sender.username)
-    await evt.reply("\n".join(f"* #{dev.id}: {dev.name_with_default} (created {dev.created_fmt}, "
-                              f"last seen {dev.last_seen_fmt})" for dev in devices))
-
-
-@command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,
-                 help_text="Remove a linked device")
+    await evt.reply(
+        "\n".join(
+            f"* #{dev.id}: {dev.name_with_default} (created {dev.created_fmt}, last seen "
+            f"{dev.last_seen_fmt})"
+            for dev in devices
+        )
+    )
+
+
+@command_handler(
+    needs_auth=True,
+    management_only=True,
+    help_section=SECTION_AUTH,
+    help_text="Remove a linked device",
+)
 async def remove_linked_device(evt: CommandEvent) -> EventID:
 async def remove_linked_device(evt: CommandEvent) -> EventID:
     if len(evt.args) == 0:
     if len(evt.args) == 0:
         return await evt.reply("**Usage:** `$cmdprefix+sp remove-linked-device <device ID>`")
         return await evt.reply("**Usage:** `$cmdprefix+sp remove-linked-device <device ID>`")

+ 23 - 8
mautrix_signal/commands/conn.py

@@ -20,28 +20,42 @@ from .typehint import CommandEvent
 SECTION_CONNECTION = HelpSection("Connection management", 15, "")
 SECTION_CONNECTION = HelpSection("Connection management", 15, "")
 
 
 
 
-@command_handler(needs_auth=False, management_only=True, help_section=SECTION_CONNECTION,
-                 help_text="Mark this room as your bridge notice room.")
+@command_handler(
+    needs_auth=False,
+    management_only=True,
+    help_section=SECTION_CONNECTION,
+    help_text="Mark this room as your bridge notice room.",
+)
 async def set_notice_room(evt: CommandEvent) -> None:
 async def set_notice_room(evt: CommandEvent) -> None:
     evt.sender.notice_room = evt.room_id
     evt.sender.notice_room = evt.room_id
     await evt.sender.update()
     await evt.sender.update()
     await evt.reply("This room has been marked as your bridge notice room")
     await evt.reply("This room has been marked as your bridge notice room")
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
-                 help_text="Relay messages in this room through your Signal account.")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_CONNECTION,
+    help_text="Relay messages in this room through your Signal account.",
+)
 async def set_relay(evt: CommandEvent) -> EventID:
 async def set_relay(evt: CommandEvent) -> EventID:
     if not evt.config["bridge.relay.enabled"]:
     if not evt.config["bridge.relay.enabled"]:
         return await evt.reply("Relay mode is not enabled in this instance of the bridge.")
         return await evt.reply("Relay mode is not enabled in this instance of the bridge.")
     elif not evt.is_portal:
     elif not evt.is_portal:
         return await evt.reply("This is not a portal room.")
         return await evt.reply("This is not a portal room.")
     await evt.portal.set_relay_user(evt.sender)
     await evt.portal.set_relay_user(evt.sender)
-    return await evt.reply("Messages from non-logged-in users in this room will now be bridged "
-                           "through your Signal account.")
+    return await evt.reply(
+        "Messages from non-logged-in users in this room will now be bridged "
+        "through your Signal account."
+    )
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_CONNECTION,
-                 help_text="Stop relaying messages in this room.")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_CONNECTION,
+    help_text="Stop relaying messages in this room.",
+)
 async def unset_relay(evt: CommandEvent) -> EventID:
 async def unset_relay(evt: CommandEvent) -> EventID:
     if not evt.config["bridge.relay.enabled"]:
     if not evt.config["bridge.relay.enabled"]:
         return await evt.reply("Relay mode is not enabled in this instance of the bridge.")
         return await evt.reply("Relay mode is not enabled in this instance of the bridge.")
@@ -52,6 +66,7 @@ async def unset_relay(evt: CommandEvent) -> EventID:
     await evt.portal.set_relay_user(None)
     await evt.portal.set_relay_user(None)
     return await evt.reply("Messages from non-logged-in users will no longer be bridged.")
     return await evt.reply("Messages from non-logged-in users will no longer be bridged.")
 
 
+
 # @command_handler(needs_auth=False, management_only=True, help_section=SECTION_CONNECTION,
 # @command_handler(needs_auth=False, management_only=True, help_section=SECTION_CONNECTION,
 #                  help_text="Check if you're logged into Twitter")
 #                  help_text="Check if you're logged into Twitter")
 # async def ping(evt: CommandEvent) -> None:
 # async def ping(evt: CommandEvent) -> None:

+ 97 - 46
mautrix_signal/commands/signal.py

@@ -35,15 +35,19 @@ except ImportError:
 SECTION_SIGNAL = HelpSection("Signal actions", 20, "")
 SECTION_SIGNAL = HelpSection("Signal actions", 20, "")
 
 
 
 
-async def _get_puppet_from_cmd(evt: CommandEvent) -> Optional['pu.Puppet']:
+async def _get_puppet_from_cmd(evt: CommandEvent) -> Optional["pu.Puppet"]:
     if len(evt.args) == 0 or not evt.args[0].startswith("+"):
     if len(evt.args) == 0 or not evt.args[0].startswith("+"):
-        await evt.reply(f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
-                        "(enter phone number in international format)")
+        await evt.reply(
+            f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
+            "(enter phone number in international format)"
+        )
         return None
         return None
     phone = "".join(evt.args).translate(remove_extra_chars)
     phone = "".join(evt.args).translate(remove_extra_chars)
     if not phone[1:].isdecimal():
     if not phone[1:].isdecimal():
-        await evt.reply(f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
-                        "(enter phone number in international format)")
+        await evt.reply(
+            f"**Usage:** `$cmdprefix+sp {evt.command} <phone>` "
+            "(enter phone number in international format)"
+        )
         return None
         return None
     return await pu.Puppet.get_by_address(Address(number=phone))
     return await pu.Puppet.get_by_address(Address(number=phone))
 
 
@@ -51,40 +55,58 @@ async def _get_puppet_from_cmd(evt: CommandEvent) -> Optional['pu.Puppet']:
 def _format_safety_number(number: str) -> str:
 def _format_safety_number(number: str) -> str:
     line_size = 20
     line_size = 20
     chunk_size = 5
     chunk_size = 5
-    return "\n".join(" ".join([number[chunk:chunk + chunk_size]
-                               for chunk in range(line, line + line_size, chunk_size)])
-                     for line in range(0, len(number), line_size))
-
-
-def _pill(puppet: 'pu.Puppet') -> str:
+    return "\n".join(
+        " ".join(
+            [
+                number[chunk : chunk + chunk_size]
+                for chunk in range(line, line + line_size, chunk_size)
+            ]
+        )
+        for line in range(0, len(number), line_size)
+    )
+
+
+def _pill(puppet: "pu.Puppet") -> str:
     return f"[{puppet.name}](https://matrix.to/#/{puppet.mxid})"
     return f"[{puppet.name}](https://matrix.to/#/{puppet.mxid})"
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Open a private chat portal with a specific phone number",
-                 help_args="<_phone_>")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Open a private chat portal with a specific phone number",
+    help_args="<_phone_>",
+)
 async def pm(evt: CommandEvent) -> None:
 async def pm(evt: CommandEvent) -> None:
     puppet = await _get_puppet_from_cmd(evt)
     puppet = await _get_puppet_from_cmd(evt)
     if not puppet:
     if not puppet:
         return
         return
-    portal = await po.Portal.get_by_chat_id(puppet.address, receiver=evt.sender.username,
-                                            create=True)
+    portal = await po.Portal.get_by_chat_id(
+        puppet.address, receiver=evt.sender.username, create=True
+    )
     if portal.mxid:
     if portal.mxid:
-        await evt.reply(f"You already have a private chat with {puppet.name}: "
-                        f"[{portal.mxid}](https://matrix.to/#/{portal.mxid})")
+        await evt.reply(
+            f"You already have a private chat with {puppet.name}: "
+            f"[{portal.mxid}](https://matrix.to/#/{portal.mxid})"
+        )
         await portal.main_intent.invite_user(portal.mxid, evt.sender.mxid)
         await portal.main_intent.invite_user(portal.mxid, evt.sender.mxid)
         return
         return
     await portal.create_matrix_room(evt.sender, puppet.address)
     await portal.create_matrix_room(evt.sender, puppet.address)
     await evt.reply(f"Created a portal room with {_pill(puppet)} and invited you to it")
     await evt.reply(f"Created a portal room with {_pill(puppet)} and invited you to it")
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Get the invite link to the current group")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Get the invite link to the current group",
+)
 async def invite_link(evt: CommandEvent) -> EventID:
 async def invite_link(evt: CommandEvent) -> EventID:
     if not evt.is_portal:
     if not evt.is_portal:
         return await evt.reply("This is not a portal room.")
         return await evt.reply("This is not a portal room.")
-    group = await evt.bridge.signal.get_group(evt.sender.username, evt.portal.chat_id,
-                                              evt.portal.revision)
+    group = await evt.bridge.signal.get_group(
+        evt.sender.username, evt.portal.chat_id, evt.portal.revision
+    )
     if not group:
     if not group:
         await evt.reply("Failed to get group info")
         await evt.reply("Failed to get group info")
     elif not group.invite_link:
     elif not group.invite_link:
@@ -93,9 +115,13 @@ async def invite_link(evt: CommandEvent) -> EventID:
         await evt.reply(group.invite_link)
         await evt.reply(group.invite_link)
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="View the safety number of a specific user",
-                 help_args="[--qr] [_phone_]")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="View the safety number of a specific user",
+    help_args="[--qr] [_phone_]",
+)
 async def safety_number(evt: CommandEvent) -> None:
 async def safety_number(evt: CommandEvent) -> None:
     show_qr = evt.args and evt.args[0].lower() == "--qr"
     show_qr = evt.args and evt.args[0].lower() == "--qr"
     if show_qr:
     if show_qr:
@@ -119,53 +145,77 @@ async def safety_number(evt: CommandEvent) -> None:
         if identity.added > most_recent.added:
         if identity.added > most_recent.added:
             most_recent = identity
             most_recent = identity
     uuid = resp.address.uuid or "unknown"
     uuid = resp.address.uuid or "unknown"
-    await evt.reply(f"### {puppet.name}\n\n"
-                    f"**UUID:** {uuid}  \n"
-                    f"**Trust level:** {most_recent.trust_level}  \n"
-                    f"**Safety number:**\n"
-                    f"```\n{_format_safety_number(most_recent.safety_number)}\n```")
+    await evt.reply(
+        f"### {puppet.name}\n\n"
+        f"**UUID:** {uuid}  \n"
+        f"**Trust level:** {most_recent.trust_level}  \n"
+        f"**Safety number:**\n"
+        f"```\n{_format_safety_number(most_recent.safety_number)}\n```"
+    )
     if show_qr and most_recent.qr_code_data:
     if show_qr and most_recent.qr_code_data:
         data = base64.b64decode(most_recent.qr_code_data)
         data = base64.b64decode(most_recent.qr_code_data)
         content = await make_qr(evt.main_intent, data, "verification-qr.png")
         content = await make_qr(evt.main_intent, data, "verification-qr.png")
         await evt.main_intent.send_message(evt.room_id, content)
         await evt.main_intent.send_message(evt.room_id, content)
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Set your Signal profile name", help_args="<_name_>")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Set your Signal profile name",
+    help_args="<_name_>",
+)
 async def set_profile_name(evt: CommandEvent) -> None:
 async def set_profile_name(evt: CommandEvent) -> None:
     await evt.bridge.signal.set_profile(evt.sender.username, name=" ".join(evt.args))
     await evt.bridge.signal.set_profile(evt.sender.username, name=" ".join(evt.args))
     await evt.reply("Successfully updated profile name")
     await evt.reply("Successfully updated profile name")
 
 
 
 
-@command_handler(needs_auth=True, management_only=False, help_section=SECTION_SIGNAL,
-                 help_text="Mark another user's safety number as trusted",
-                 help_args="<_recipient phone_> <_safety number_>")
+@command_handler(
+    needs_auth=True,
+    management_only=False,
+    help_section=SECTION_SIGNAL,
+    help_text="Mark another user's safety number as trusted",
+    help_args="<_recipient phone_> <_safety number_>",
+)
 async def mark_trusted(evt: CommandEvent) -> EventID:
 async def mark_trusted(evt: CommandEvent) -> EventID:
     if len(evt.args) < 2:
     if len(evt.args) < 2:
-        return await evt.reply("**Usage:** `$cmdprefix+sp mark-trusted <recipient phone> "
-                               "<safety number>`")
+        return await evt.reply(
+            "**Usage:** `$cmdprefix+sp mark-trusted <recipient phone> " "<safety number>`"
+        )
     number = evt.args[0].translate(remove_extra_chars)
     number = evt.args[0].translate(remove_extra_chars)
     safety_num = "".join(evt.args[1:]).replace("\n", "")
     safety_num = "".join(evt.args[1:]).replace("\n", "")
     if len(safety_num) != 60 or not safety_num.isdecimal():
     if len(safety_num) != 60 or not safety_num.isdecimal():
         return await evt.reply("That doesn't look like a valid safety number")
         return await evt.reply("That doesn't look like a valid safety number")
     try:
     try:
-        await evt.bridge.signal.trust(evt.sender.username, Address(number=number),
-                                      safety_number=safety_num, trust_level="TRUSTED_VERIFIED")
+        await evt.bridge.signal.trust(
+            evt.sender.username,
+            Address(number=number),
+            safety_number=safety_num,
+            trust_level="TRUSTED_VERIFIED",
+        )
     except UnknownIdentityKey as e:
     except UnknownIdentityKey as e:
         return await evt.reply(f"Failed to mark {number} as trusted: {e}")
         return await evt.reply(f"Failed to mark {number} as trusted: {e}")
     return await evt.reply(f"Successfully marked {number} as trusted")
     return await evt.reply(f"Successfully marked {number} as trusted")
 
 
 
 
-@command_handler(needs_admin=False, needs_auth=True, help_section=SECTION_SIGNAL,
-                 help_text="Sync data from Signal")
+@command_handler(
+    needs_admin=False,
+    needs_auth=True,
+    help_section=SECTION_SIGNAL,
+    help_text="Sync data from Signal",
+)
 async def sync(evt: CommandEvent) -> None:
 async def sync(evt: CommandEvent) -> None:
     await evt.sender.sync()
     await evt.sender.sync()
     await evt.reply("Sync complete")
     await evt.reply("Sync complete")
 
 
 
 
-@command_handler(needs_admin=True, needs_auth=False, help_section=SECTION_ADMIN,
-                 help_text="Send raw requests to signald",
-                 help_args="[--user] <type> <_json_>")
+@command_handler(
+    needs_admin=True,
+    needs_auth=False,
+    help_section=SECTION_ADMIN,
+    help_text="Send raw requests to signald",
+    help_args="[--user] <type> <_json_>",
+)
 async def raw(evt: CommandEvent) -> None:
 async def raw(evt: CommandEvent) -> None:
     add_username = False
     add_username = False
     while True:
     while True:
@@ -200,5 +250,6 @@ async def raw(evt: CommandEvent) -> None:
         if resp_data is None:
         if resp_data is None:
             await evt.reply(f"Got reply `{resp_type}` with no content")
             await evt.reply(f"Got reply `{resp_type}` with no content")
         else:
         else:
-            await evt.reply(f"Got reply `{resp_type}`:\n\n"
-                            f"```json\n{json.dumps(resp_data, indent=2)}\n```")
+            await evt.reply(
+                f"Got reply `{resp_type}`:\n\n" f"```json\n{json.dumps(resp_data, indent=2)}\n```"
+            )

+ 3 - 3
mautrix_signal/commands/typehint.py

@@ -9,6 +9,6 @@ if TYPE_CHECKING:
 
 
 
 
 class CommandEvent(BaseCommandEvent):
 class CommandEvent(BaseCommandEvent):
-    bridge: 'SignalBridge'
-    sender: 'User'
-    portal: 'Portal'
+    bridge: "SignalBridge"
+    sender: "User"
+    portal: "Portal"

+ 13 - 3
mautrix_signal/db/__init__.py

@@ -18,7 +18,17 @@ def init(db: Database) -> None:
 
 
 # TODO should this be in mautrix-python?
 # TODO should this be in mautrix-python?
 sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
 sqlite3.register_adapter(uuid.UUID, lambda u: str(u))
-sqlite3.register_converter("UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b))
+sqlite3.register_converter(
+    "UUID", lambda b: uuid.UUID(b.decode("utf-8") if isinstance(b, bytes) else b)
+)
 
 
-__all__ = ["upgrade_table", "init", "User", "Puppet", "Portal", "Message", "Reaction",
-           "DisappearingMessage"]
+__all__ = [
+    "upgrade_table",
+    "init",
+    "User",
+    "Puppet",
+    "Portal",
+    "Message",
+    "Reaction",
+    "DisappearingMessage",
+]

+ 6 - 4
mautrix_signal/db/disappearing_message.py

@@ -38,8 +38,9 @@ class DisappearingMessage:
         INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts)
         INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts)
         VALUES ($1, $2, $3, $4)
         VALUES ($1, $2, $3, $4)
         """
         """
-        await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
-                              self.expiration_ts)
+        await self.db.execute(
+            q, self.room_id, self.mxid, self.expiration_seconds, self.expiration_ts
+        )
 
 
     async def update(self) -> None:
     async def update(self) -> None:
         q = """
         q = """
@@ -48,8 +49,9 @@ class DisappearingMessage:
         WHERE room_id=$1 AND mxid=$2
         WHERE room_id=$1 AND mxid=$2
         """
         """
         try:
         try:
-            await self.db.execute(q, self.room_id, self.mxid, self.expiration_seconds,
-                              self.expiration_ts)
+            await self.db.execute(
+                q, self.room_id, self.mxid, self.expiration_seconds, self.expiration_ts
+            )
         except Exception as e:
         except Exception as e:
             print(e)
             print(e)
 
 

+ 65 - 29
mautrix_signal/db/message.py

@@ -40,23 +40,39 @@ class Message:
     signal_receiver: str
     signal_receiver: str
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ("INSERT INTO message (mxid, mx_room, sender, timestamp, signal_chat_id,"
-             "                     signal_receiver) VALUES ($1, $2, $3, $4, $5, $6)")
-        await self.db.execute(q, self.mxid, self.mx_room, self.sender.best_identifier,
-                              self.timestamp, id_to_str(self.signal_chat_id), self.signal_receiver)
+        q = """
+        INSERT INTO message (mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver)
+        VALUES ($1, $2, $3, $4, $5, $6)
+        """
+        await self.db.execute(
+            q,
+            self.mxid,
+            self.mx_room,
+            self.sender.best_identifier,
+            self.timestamp,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+        )
 
 
     async def delete(self) -> None:
     async def delete(self) -> None:
-        q = ("DELETE FROM message WHERE sender=$1 AND timestamp=$2"
-             "                          AND signal_chat_id=$3 AND signal_receiver=$4")
-        await self.db.execute(q, self.sender.best_identifier, self.timestamp,
-                              id_to_str(self.signal_chat_id), self.signal_receiver)
+        q = """
+        DELETE FROM message
+         WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
+        """
+        await self.db.execute(
+            q,
+            self.sender.best_identifier,
+            self.timestamp,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+        )
 
 
     @classmethod
     @classmethod
     async def delete_all(cls, room_id: RoomID) -> None:
     async def delete_all(cls, room_id: RoomID) -> None:
         await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)
         await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)
 
 
     @classmethod
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Message':
+    def _from_row(cls, row: asyncpg.Record) -> "Message":
         data = {**row}
         data = {**row}
         chat_id = data.pop("signal_chat_id")
         chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
         if data["signal_receiver"]:
@@ -65,44 +81,64 @@ class Message:
         return cls(signal_chat_id=chat_id, sender=sender, **data)
         return cls(signal_chat_id=chat_id, sender=sender, **data)
 
 
     @classmethod
     @classmethod
-    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Message']:
-        q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-             "FROM message WHERE mxid=$1 AND mx_room=$2")
+    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional["Message"]:
+        q = """
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+          FROM message WHERE mxid=$1 AND mx_room=$2
+        """
         row = await cls.db.fetchrow(q, mxid, mx_room)
         row = await cls.db.fetchrow(q, mxid, mx_room)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def get_by_signal_id(cls, sender: Address, timestamp: int,
-                               signal_chat_id: Union[GroupID, Address], signal_receiver: str = ""
-                               ) -> Optional['Message']:
-        q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-             "FROM message WHERE sender=$1 AND timestamp=$2"
-             "                   AND signal_chat_id=$3 AND signal_receiver=$4")
-        row = await cls.db.fetchrow(q, sender.best_identifier, timestamp,
-                                    id_to_str(signal_chat_id), signal_receiver)
+    async def get_by_signal_id(
+        cls,
+        sender: Address,
+        timestamp: int,
+        signal_chat_id: Union[GroupID, Address],
+        signal_receiver: str = "",
+    ) -> Optional["Message"]:
+        q = """
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+          FROM message
+         WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
+        """
+        row = await cls.db.fetchrow(
+            q, sender.best_identifier, timestamp, id_to_str(signal_chat_id), signal_receiver
+        )
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def find_by_timestamps(cls, timestamps: List[int]) -> List['Message']:
+    async def find_by_timestamps(cls, timestamps: List[int]) -> List["Message"]:
         if cls.db.scheme == "postgres":
         if cls.db.scheme == "postgres":
-            q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-                 "FROM message WHERE timestamp=ANY($1)")
+            q = """
+            SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+              FROM message
+             WHERE timestamp=ANY($1)
+            """
             rows = await cls.db.fetch(q, timestamps)
             rows = await cls.db.fetch(q, timestamps)
         else:
         else:
-            placeholders = ", ".join(f"?" for _ in range(len(timestamps)))
-            q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-                 f"FROM message WHERE timestamp IN ({placeholders})")
+            placeholders = ", ".join("?" for _ in range(len(timestamps)))
+            q = f"""
+            SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+              FROM message
+             WHERE timestamp IN ({placeholders})
+            """
             rows = await cls.db.fetch(q, *timestamps)
             rows = await cls.db.fetch(q, *timestamps)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 
     @classmethod
     @classmethod
-    async def find_by_sender_timestamp(cls, sender: Address, timestamp: int) -> Optional['Message']:
-        q = ("SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver "
-             "FROM message WHERE sender=$1 AND timestamp=$2")
+    async def find_by_sender_timestamp(
+        cls, sender: Address, timestamp: int
+    ) -> Optional["Message"]:
+        q = """
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
+          FROM message
+         WHERE sender=$1 AND timestamp=$2
+        """
         row = await cls.db.fetchrow(q, sender.best_identifier, timestamp)
         row = await cls.db.fetchrow(q, sender.best_identifier, timestamp)
         if not row:
         if not row:
             return None
             return None

+ 79 - 38
mautrix_signal/db/portal.py

@@ -49,27 +49,52 @@ class Portal:
         return id_to_str(self.chat_id)
         return id_to_str(self.chat_id)
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ("INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, "
-             "                    name_set, avatar_set, revision, encrypted, relay_user_id, "
-             "                    expiration_time) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)")
-        await self.db.execute(q, self.chat_id_str, self.receiver, self.mxid, self.name,
-                              self.avatar_hash, self.avatar_url, self.name_set, self.avatar_set,
-                              self.revision, self.encrypted, self.relay_user_id,
-                              self.expiration_time)
+        q = """
+        INSERT INTO portal (chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set,
+                            avatar_set, revision, encrypted, relay_user_id, expiration_time)
+        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+        """
+        await self.db.execute(
+            q,
+            self.chat_id_str,
+            self.receiver,
+            self.mxid,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.revision,
+            self.encrypted,
+            self.relay_user_id,
+            self.expiration_time,
+        )
 
 
     async def update(self) -> None:
     async def update(self) -> None:
-        q = ("UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5, "
-             "                  avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9, "
-             "                  expiration_time=$10"
-             "WHERE chat_id=$11 AND receiver=$12")
-        await self.db.execute(q, self.mxid, self.name, self.avatar_hash, self.avatar_url,
-                              self.name_set, self.avatar_set, self.revision, self.encrypted,
-                              self.relay_user_id, self.expiration_time, self.chat_id_str,
-                              self.receiver)
+        q = """
+        UPDATE portal SET mxid=$1, name=$2, avatar_hash=$3, avatar_url=$4, name_set=$5,
+                          avatar_set=$6, revision=$7, encrypted=$8, relay_user_id=$9,
+                          expiration_time=$10
+        WHERE chat_id=$11 AND receiver=$12
+        """
+        await self.db.execute(
+            q,
+            self.mxid,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.revision,
+            self.encrypted,
+            self.relay_user_id,
+            self.expiration_time,
+            self.chat_id_str,
+            self.receiver,
+        )
 
 
     @classmethod
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Portal':
+    def _from_row(cls, row: asyncpg.Record) -> "Portal":
         data = {**row}
         data = {**row}
         chat_id = data.pop("chat_id")
         chat_id = data.pop("chat_id")
         if data["receiver"]:
         if data["receiver"]:
@@ -77,46 +102,62 @@ class Portal:
         return cls(chat_id=chat_id, **data)
         return cls(chat_id=chat_id, **data)
 
 
     @classmethod
     @classmethod
-    async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE mxid=$1")
+    async def get_by_mxid(cls, mxid: RoomID) -> Optional["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE mxid=$1
+        """
         row = await cls.db.fetchrow(q, mxid)
         row = await cls.db.fetchrow(q, mxid)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = ""
-                             ) -> Optional['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE chat_id=$1 AND receiver=$2")
+    async def get_by_chat_id(
+        cls, chat_id: Union[GroupID, Address], receiver: str = ""
+    ) -> Optional["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE chat_id=$1 AND receiver=$2
+        """
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def find_private_chats_of(cls, receiver: str) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE receiver=$1")
+    async def find_private_chats_of(cls, receiver: str) -> List["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE receiver=$1
+        """
         rows = await cls.db.fetch(q, receiver)
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 
     @classmethod
     @classmethod
-    async def find_private_chats_with(cls, other_user: Address) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE chat_id=$1 AND receiver<>''")
+    async def find_private_chats_with(cls, other_user: Address) -> List["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE chat_id=$1 AND receiver<>''
+        """
         rows = await cls.db.fetch(q, other_user.best_identifier)
         rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 
     @classmethod
     @classmethod
-    async def all_with_room(cls) -> List['Portal']:
-        q = ("SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,"
-             "       revision, encrypted, relay_user_id, expiration_time "
-             "FROM portal WHERE mxid IS NOT NULL")
+    async def all_with_room(cls) -> List["Portal"]:
+        q = """
+        SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
+               revision, encrypted, relay_user_id, expiration_time
+          FROM portal
+         WHERE mxid IS NOT NULL
+        """
         rows = await cls.db.fetch(q)
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]

+ 68 - 33
mautrix_signal/db/puppet.py

@@ -52,36 +52,54 @@ class Puppet:
         return str(self.base_url) if self.base_url else None
         return str(self.base_url) if self.base_url else None
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ("INSERT INTO puppet (uuid, number, name, avatar_hash, avatar_url, name_set, "
-             "                    avatar_set, uuid_registered, number_registered, "
-             "                    custom_mxid, access_token, next_batch, base_url) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)")
-        await self.db.execute(q, self.uuid, self.number, self.name, self.avatar_hash,
-                              self.avatar_url, self.name_set, self.avatar_set,
-                              self.uuid_registered, self.number_registered, self.custom_mxid,
-                              self.access_token, self.next_batch, self._base_url_str)
+        q = (
+            "INSERT INTO puppet (uuid, number, name, avatar_hash, avatar_url, name_set, "
+            "                    avatar_set, uuid_registered, number_registered, "
+            "                    custom_mxid, access_token, next_batch, base_url) "
+            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)"
+        )
+        await self.db.execute(
+            q,
+            self.uuid,
+            self.number,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.uuid_registered,
+            self.number_registered,
+            self.custom_mxid,
+            self.access_token,
+            self.next_batch,
+            self._base_url_str,
+        )
 
 
     async def _set_uuid(self, uuid: UUID) -> None:
     async def _set_uuid(self, uuid: UUID) -> None:
         async with self.db.acquire() as conn, conn.transaction():
         async with self.db.acquire() as conn, conn.transaction():
-            await conn.execute("DELETE FROM puppet WHERE uuid=$1 AND number<>$2",
-                               uuid, self.number)
+            await conn.execute(
+                "DELETE FROM puppet WHERE uuid=$1 AND number<>$2", uuid, self.number
+            )
             await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
             await conn.execute("UPDATE puppet SET uuid=$1 WHERE number=$2", uuid, self.number)
             await self._update_number_to_uuid(conn, self.number, str(uuid))
             await self._update_number_to_uuid(conn, self.number, str(uuid))
 
 
     async def _set_number(self, number: str) -> None:
     async def _set_number(self, number: str) -> None:
         async with self.db.acquire() as conn, conn.transaction():
         async with self.db.acquire() as conn, conn.transaction():
-            await conn.execute("DELETE FROM puppet WHERE number=$1 AND uuid<>$2",
-                               number, self.uuid)
+            await conn.execute(
+                "DELETE FROM puppet WHERE number=$1 AND uuid<>$2", number, self.uuid
+            )
             await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", number, self.uuid)
             await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", number, self.uuid)
             await self._update_number_to_uuid(conn, number, str(self.uuid))
             await self._update_number_to_uuid(conn, number, str(self.uuid))
 
 
     @staticmethod
     @staticmethod
-    async def _update_number_to_uuid(conn: asyncpg.Connection, old_number: str, new_uuid: str
-                                     ) -> None:
+    async def _update_number_to_uuid(
+        conn: asyncpg.Connection, old_number: str, new_uuid: str
+    ) -> None:
         try:
         try:
             async with conn.transaction():
             async with conn.transaction():
-                await conn.execute("UPDATE portal SET chat_id=$1 WHERE chat_id=$2",
-                                   new_uuid, old_number)
+                await conn.execute(
+                    "UPDATE portal SET chat_id=$1 WHERE chat_id=$2", new_uuid, old_number
+                )
         except asyncpg.UniqueViolationError:
         except asyncpg.UniqueViolationError:
             await conn.execute("DELETE FROM portal WHERE chat_id=$1", old_number)
             await conn.execute("DELETE FROM portal WHERE chat_id=$1", old_number)
         await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2", new_uuid, old_number)
         await conn.execute("UPDATE message SET sender=$1 WHERE sender=$2", new_uuid, old_number)
@@ -93,32 +111,49 @@ class Puppet:
             "uuid_registered=$8, number_registered=$9, "
             "uuid_registered=$8, number_registered=$9, "
             "custom_mxid=$10, access_token=$11, next_batch=$12, base_url=$13"
             "custom_mxid=$10, access_token=$11, next_batch=$12, base_url=$13"
         )
         )
-        q = (f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
-             if self.uuid is None
-             else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1")
-        await self.db.execute(q,self.uuid, self.number, self.name, self.avatar_hash,
-                              self.avatar_url, self.name_set, self.avatar_set,
-                              self.uuid_registered, self.number_registered, self.custom_mxid,
-                              self.access_token, self.next_batch, self._base_url_str)
+        q = (
+            f"UPDATE puppet SET uuid=$1, {set_columns} WHERE number=$2"
+            if self.uuid is None
+            else f"UPDATE puppet SET number=$2, {set_columns} WHERE uuid=$1"
+        )
+        await self.db.execute(
+            q,
+            self.uuid,
+            self.number,
+            self.name,
+            self.avatar_hash,
+            self.avatar_url,
+            self.name_set,
+            self.avatar_set,
+            self.uuid_registered,
+            self.number_registered,
+            self.custom_mxid,
+            self.access_token,
+            self.next_batch,
+            self._base_url_str,
+        )
 
 
     @classmethod
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Puppet':
+    def _from_row(cls, row: asyncpg.Record) -> "Puppet":
         data = {**row}
         data = {**row}
         base_url_str = data.pop("base_url")
         base_url_str = data.pop("base_url")
         base_url = URL(base_url_str) if base_url_str is not None else None
         base_url = URL(base_url_str) if base_url_str is not None else None
         return cls(base_url=base_url, **data)
         return cls(base_url=base_url, **data)
 
 
-    _select_base = ("SELECT uuid, number, name, avatar_hash, avatar_url, name_set, avatar_set, "
-                    "       uuid_registered, number_registered, custom_mxid, access_token, "
-                    "       next_batch, base_url "
-                    "FROM puppet")
+    _select_base = (
+        "SELECT uuid, number, name, avatar_hash, avatar_url, name_set, avatar_set, "
+        "       uuid_registered, number_registered, custom_mxid, access_token, "
+        "       next_batch, base_url "
+        "FROM puppet"
+    )
 
 
     @classmethod
     @classmethod
-    async def get_by_address(cls, address: Address) -> Optional['Puppet']:
+    async def get_by_address(cls, address: Address) -> Optional["Puppet"]:
         if address.uuid:
         if address.uuid:
             if address.number:
             if address.number:
-                row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1 OR number=$2",
-                                            address.uuid, address.number)
+                row = await cls.db.fetchrow(
+                    f"{cls._select_base} WHERE uuid=$1 OR number=$2", address.uuid, address.number
+                )
             else:
             else:
                 row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
                 row = await cls.db.fetchrow(f"{cls._select_base} WHERE uuid=$1", address.uuid)
         elif address.number:
         elif address.number:
@@ -130,13 +165,13 @@ class Puppet:
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
+    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional["Puppet"]:
         row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
         row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def all_with_custom_mxid(cls) -> List['Puppet']:
+    async def all_with_custom_mxid(cls) -> List["Puppet"]:
         rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
         rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]

+ 70 - 31
mautrix_signal/db/reaction.py

@@ -42,30 +42,54 @@ class Reaction:
     emoji: str
     emoji: str
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ("INSERT INTO reaction (mxid, mx_room, signal_chat_id, signal_receiver, msg_author,"
-             "                      msg_timestamp, author, emoji) "
-             "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)")
-        await self.db.execute(q, self.mxid, self.mx_room, id_to_str(self.signal_chat_id),
-                              self.signal_receiver, self.msg_author.best_identifier,
-                              self.msg_timestamp, self.author.best_identifier, self.emoji)
+        q = (
+            "INSERT INTO reaction (mxid, mx_room, signal_chat_id, signal_receiver, msg_author,"
+            "                      msg_timestamp, author, emoji) "
+            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
+        )
+        await self.db.execute(
+            q,
+            self.mxid,
+            self.mx_room,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+            self.msg_author.best_identifier,
+            self.msg_timestamp,
+            self.author.best_identifier,
+            self.emoji,
+        )
 
 
     async def edit(self, mx_room: RoomID, mxid: EventID, emoji: str) -> None:
     async def edit(self, mx_room: RoomID, mxid: EventID, emoji: str) -> None:
-        await self.db.execute("UPDATE reaction SET mxid=$1, mx_room=$2, emoji=$3 "
-                              "WHERE signal_chat_id=$4 AND signal_receiver=$5"
-                              "      AND msg_author=$6 AND msg_timestamp=$7 AND author=$8",
-                              mxid, mx_room, emoji, id_to_str(self.signal_chat_id),
-                              self.signal_receiver, self.msg_author.best_identifier,
-                              self.msg_timestamp, self.author.best_identifier)
+        await self.db.execute(
+            "UPDATE reaction SET mxid=$1, mx_room=$2, emoji=$3 "
+            "WHERE signal_chat_id=$4 AND signal_receiver=$5"
+            "      AND msg_author=$6 AND msg_timestamp=$7 AND author=$8",
+            mxid,
+            mx_room,
+            emoji,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+            self.msg_author.best_identifier,
+            self.msg_timestamp,
+            self.author.best_identifier,
+        )
 
 
     async def delete(self) -> None:
     async def delete(self) -> None:
-        q = ("DELETE FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
-             "                           AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
-        await self.db.execute(q, id_to_str(self.signal_chat_id), self.signal_receiver,
-                              self.msg_author.best_identifier, self.msg_timestamp,
-                              self.author.best_identifier)
+        q = (
+            "DELETE FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
+            "                           AND msg_author=$3 AND msg_timestamp=$4 AND author=$5"
+        )
+        await self.db.execute(
+            q,
+            id_to_str(self.signal_chat_id),
+            self.signal_receiver,
+            self.msg_author.best_identifier,
+            self.msg_timestamp,
+            self.author.best_identifier,
+        )
 
 
     @classmethod
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> 'Reaction':
+    def _from_row(cls, row: asyncpg.Record) -> "Reaction":
         data = {**row}
         data = {**row}
         chat_id = data.pop("signal_chat_id")
         chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
         if data["signal_receiver"]:
@@ -75,25 +99,40 @@ class Reaction:
         return cls(signal_chat_id=chat_id, msg_author=msg_author, author=author, **data)
         return cls(signal_chat_id=chat_id, msg_author=msg_author, author=author, **data)
 
 
     @classmethod
     @classmethod
-    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional['Reaction']:
-        q = ("SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
-             "       msg_author, msg_timestamp, author, emoji "
-             "FROM reaction WHERE mxid=$1 AND mx_room=$2")
+    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional["Reaction"]:
+        q = (
+            "SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
+            "       msg_author, msg_timestamp, author, emoji "
+            "FROM reaction WHERE mxid=$1 AND mx_room=$2"
+        )
         row = await cls.db.fetchrow(q, mxid, mx_room)
         row = await cls.db.fetchrow(q, mxid, mx_room)
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)
 
 
     @classmethod
     @classmethod
-    async def get_by_signal_id(cls, chat_id: Union[GroupID, Address], receiver: str,
-                               msg_author: Address, msg_timestamp: int, author: Address
-                               ) -> Optional['Reaction']:
-        q = ("SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
-             "       msg_author, msg_timestamp, author, emoji "
-             "FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
-             "                    AND msg_author=$3 AND msg_timestamp=$4 AND author=$5")
-        row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver, msg_author.best_identifier,
-                                    msg_timestamp, author.best_identifier)
+    async def get_by_signal_id(
+        cls,
+        chat_id: Union[GroupID, Address],
+        receiver: str,
+        msg_author: Address,
+        msg_timestamp: int,
+        author: Address,
+    ) -> Optional["Reaction"]:
+        q = (
+            "SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
+            "       msg_author, msg_timestamp, author, emoji "
+            "FROM reaction WHERE signal_chat_id=$1 AND signal_receiver=$2"
+            "                    AND msg_author=$3 AND msg_timestamp=$4 AND author=$5"
+        )
+        row = await cls.db.fetchrow(
+            q,
+            id_to_str(chat_id),
+            receiver,
+            msg_author.best_identifier,
+            msg_timestamp,
+            author.best_identifier,
+        )
         if not row:
         if not row:
             return None
             return None
         return cls._from_row(row)
         return cls._from_row(row)

+ 148 - 108
mautrix_signal/db/upgrade.py

@@ -22,97 +22,64 @@ upgrade_table = UpgradeTable()
 
 
 @upgrade_table.register(description="Initial revision")
 @upgrade_table.register(description="Initial revision")
 async def upgrade_v1(conn: Connection) -> None:
 async def upgrade_v1(conn: Connection) -> None:
-    await conn.execute("""CREATE TABLE portal (
-        chat_id     TEXT,
-        receiver    TEXT,
-        mxid        TEXT,
-        name        TEXT,
-        encrypted   BOOLEAN NOT NULL DEFAULT false,
-
-        PRIMARY KEY (chat_id, receiver)
-    )""")
-    await conn.execute("""CREATE TABLE "user" (
-        mxid        TEXT PRIMARY KEY,
-        username    TEXT,
-        uuid        UUID,
-        notice_room TEXT
-    )""")
-    await conn.execute("""CREATE TABLE puppet (
-        uuid      UUID UNIQUE,
-        number    TEXT UNIQUE,
-        name      TEXT,
-
-        uuid_registered   BOOLEAN NOT NULL DEFAULT false,
-        number_registered BOOLEAN NOT NULL DEFAULT false,
-
-        custom_mxid  TEXT,
-        access_token TEXT,
-        next_batch   TEXT
-    )""")
-    await conn.execute("""CREATE TABLE user_portal (
-        "user"          TEXT,
-        portal          TEXT,
-        portal_receiver TEXT,
-        in_community    BOOLEAN NOT NULL DEFAULT false,
-
-        FOREIGN KEY (portal, portal_receiver) REFERENCES portal(chat_id, receiver)
-            ON UPDATE CASCADE ON DELETE CASCADE
-    )""")
-    await conn.execute("""CREATE TABLE message (
-        mxid    TEXT NOT NULL,
-        mx_room TEXT NOT NULL,
-        sender          UUID,
-        timestamp       BIGINT,
-        signal_chat_id  TEXT,
-        signal_receiver TEXT,
-
-        PRIMARY KEY (sender, timestamp, signal_chat_id, signal_receiver),
-        FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
-            ON UPDATE CASCADE ON DELETE CASCADE,
-        UNIQUE (mxid, mx_room)
-    )""")
-    await conn.execute("""CREATE TABLE reaction (
-        mxid    TEXT NOT NULL,
-        mx_room TEXT NOT NULL,
-
-        signal_chat_id  TEXT   NOT NULL,
-        signal_receiver TEXT   NOT NULL,
-        msg_author      UUID   NOT NULL,
-        msg_timestamp   BIGINT NOT NULL,
-        author          UUID   NOT NULL,
-
-        emoji TEXT NOT NULL,
-
-        PRIMARY KEY (signal_chat_id, signal_receiver, msg_author, msg_timestamp, author),
-        FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver)
-            REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
-            ON DELETE CASCADE ON UPDATE CASCADE,
-        UNIQUE (mxid, mx_room)
-    )""")
-
-
-@upgrade_table.register(description="Add avatar info to portal table")
-async def upgrade_v2(conn: Connection) -> None:
-    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_hash TEXT")
-    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_url TEXT")
-
-
-@upgrade_table.register(description="Add double-puppeting base_url to puppe table")
-async def upgrade_v3(conn: Connection) -> None:
-    await conn.execute("ALTER TABLE puppet ADD COLUMN base_url TEXT")
-
-
-@upgrade_table.register(description="Allow phone numbers as message sender identifiers")
-async def upgrade_v4(conn: Connection, scheme: str) -> None:
-    if scheme == "sqlite":
-        # SQLite doesn't have anything in the tables yet,
-        # so just recreate them without migrating data
-        await conn.execute("DROP TABLE message")
-        await conn.execute("DROP TABLE reaction")
-        await conn.execute("""CREATE TABLE message (
+    await conn.execute(
+        """
+        CREATE TABLE portal (
+            chat_id     TEXT,
+            receiver    TEXT,
+            mxid        TEXT,
+            name        TEXT,
+            encrypted   BOOLEAN NOT NULL DEFAULT false,
+
+            PRIMARY KEY (chat_id, receiver)
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE "user" (
+            mxid        TEXT PRIMARY KEY,
+            username    TEXT,
+            uuid        UUID,
+            notice_room TEXT
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE puppet (
+            uuid      UUID UNIQUE,
+            number    TEXT UNIQUE,
+            name      TEXT,
+
+            uuid_registered   BOOLEAN NOT NULL DEFAULT false,
+            number_registered BOOLEAN NOT NULL DEFAULT false,
+
+            custom_mxid  TEXT,
+            access_token TEXT,
+            next_batch   TEXT
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE user_portal (
+            "user"          TEXT,
+            portal          TEXT,
+            portal_receiver TEXT,
+            in_community    BOOLEAN NOT NULL DEFAULT false,
+
+            FOREIGN KEY (portal, portal_receiver) REFERENCES portal(chat_id, receiver)
+                ON UPDATE CASCADE ON DELETE CASCADE
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE message (
             mxid    TEXT NOT NULL,
             mxid    TEXT NOT NULL,
             mx_room TEXT NOT NULL,
             mx_room TEXT NOT NULL,
-            sender          TEXT,
+            sender          UUID,
             timestamp       BIGINT,
             timestamp       BIGINT,
             signal_chat_id  TEXT,
             signal_chat_id  TEXT,
             signal_receiver TEXT,
             signal_receiver TEXT,
@@ -121,16 +88,20 @@ async def upgrade_v4(conn: Connection, scheme: str) -> None:
             FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
             FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
                 ON UPDATE CASCADE ON DELETE CASCADE,
                 ON UPDATE CASCADE ON DELETE CASCADE,
             UNIQUE (mxid, mx_room)
             UNIQUE (mxid, mx_room)
-        )""")
-        await conn.execute("""CREATE TABLE reaction (
+        )
+        """
+    )
+    await conn.execute(
+        """
+        CREATE TABLE reaction (
             mxid    TEXT NOT NULL,
             mxid    TEXT NOT NULL,
             mx_room TEXT NOT NULL,
             mx_room TEXT NOT NULL,
 
 
             signal_chat_id  TEXT   NOT NULL,
             signal_chat_id  TEXT   NOT NULL,
             signal_receiver TEXT   NOT NULL,
             signal_receiver TEXT   NOT NULL,
-            msg_author      TEXT   NOT NULL,
+            msg_author      UUID   NOT NULL,
             msg_timestamp   BIGINT NOT NULL,
             msg_timestamp   BIGINT NOT NULL,
-            author          TEXT   NOT NULL,
+            author          UUID   NOT NULL,
 
 
             emoji TEXT NOT NULL,
             emoji TEXT NOT NULL,
 
 
@@ -139,19 +110,84 @@ async def upgrade_v4(conn: Connection, scheme: str) -> None:
                 REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
                 REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
                 ON DELETE CASCADE ON UPDATE CASCADE,
                 ON DELETE CASCADE ON UPDATE CASCADE,
             UNIQUE (mxid, mx_room)
             UNIQUE (mxid, mx_room)
-        )""")
+        )
+        """
+    )
+
+
+@upgrade_table.register(description="Add avatar info to portal table")
+async def upgrade_v2(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_hash TEXT")
+    await conn.execute("ALTER TABLE portal ADD COLUMN avatar_url TEXT")
+
+
+@upgrade_table.register(description="Add double-puppeting base_url to puppe table")
+async def upgrade_v3(conn: Connection) -> None:
+    await conn.execute("ALTER TABLE puppet ADD COLUMN base_url TEXT")
+
+
+@upgrade_table.register(description="Allow phone numbers as message sender identifiers")
+async def upgrade_v4(conn: Connection, scheme: str) -> None:
+    if scheme == "sqlite":
+        # SQLite doesn't have anything in the tables yet,
+        # so just recreate them without migrating data
+        await conn.execute("DROP TABLE message")
+        await conn.execute("DROP TABLE reaction")
+        await conn.execute(
+            """
+            CREATE TABLE message (
+                mxid    TEXT NOT NULL,
+                mx_room TEXT NOT NULL,
+                sender          TEXT,
+                timestamp       BIGINT,
+                signal_chat_id  TEXT,
+                signal_receiver TEXT,
+
+                PRIMARY KEY (sender, timestamp, signal_chat_id, signal_receiver),
+                FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver)
+                    ON UPDATE CASCADE ON DELETE CASCADE,
+                UNIQUE (mxid, mx_room)
+            )
+            """
+        )
+        await conn.execute(
+            """
+            CREATE TABLE reaction (
+                mxid    TEXT NOT NULL,
+                mx_room TEXT NOT NULL,
+
+                signal_chat_id  TEXT   NOT NULL,
+                signal_receiver TEXT   NOT NULL,
+                msg_author      TEXT   NOT NULL,
+                msg_timestamp   BIGINT NOT NULL,
+                author          TEXT   NOT NULL,
+
+                emoji TEXT NOT NULL,
+
+                PRIMARY KEY (signal_chat_id, signal_receiver, msg_author, msg_timestamp, author),
+                FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver)
+                    REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver)
+                    ON DELETE CASCADE ON UPDATE CASCADE,
+                UNIQUE (mxid, mx_room)
+            )
+            """
+        )
         return
         return
 
 
-    cname = await conn.fetchval("SELECT constraint_name FROM information_schema.table_constraints "
-                                "WHERE table_name='reaction' AND constraint_name LIKE '%_fkey'")
+    cname = await conn.fetchval(
+        "SELECT constraint_name FROM information_schema.table_constraints "
+        "WHERE table_name='reaction' AND constraint_name LIKE '%_fkey'"
+    )
     await conn.execute(f"ALTER TABLE reaction DROP CONSTRAINT {cname}")
     await conn.execute(f"ALTER TABLE reaction DROP CONSTRAINT {cname}")
     await conn.execute("ALTER TABLE reaction ALTER COLUMN msg_author SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE reaction ALTER COLUMN msg_author SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE reaction ALTER COLUMN author SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE reaction ALTER COLUMN author SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE message ALTER COLUMN sender SET DATA TYPE TEXT")
     await conn.execute("ALTER TABLE message ALTER COLUMN sender SET DATA TYPE TEXT")
-    await conn.execute(f"ALTER TABLE reaction ADD CONSTRAINT {cname} "
-                       "FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) "
-                       "  REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) "
-                       "  ON DELETE CASCADE ON UPDATE CASCADE")
+    await conn.execute(
+        f"ALTER TABLE reaction ADD CONSTRAINT {cname} "
+        "FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) "
+        "  REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) "
+        "  ON DELETE CASCADE ON UPDATE CASCADE"
+    )
 
 
 
 
 @upgrade_table.register(description="Add avatar info to puppet table")
 @upgrade_table.register(description="Add avatar info to puppet table")
@@ -179,12 +215,16 @@ async def upgrade_v7(conn: Connection) -> None:
 
 
 @upgrade_table.register(description="Add support for disappearing messages")
 @upgrade_table.register(description="Add support for disappearing messages")
 async def upgrade_v8(conn: Connection) -> None:
 async def upgrade_v8(conn: Connection) -> None:
-    await conn.execute("""CREATE TABLE disappearing_message (
-        room_id             TEXT,
-        mxid                TEXT,
-        expiration_seconds  BIGINT,
-        expiration_ts       BIGINT,
-
-        PRIMARY KEY (room_id, mxid)
-    )""")
+    await conn.execute(
+        """
+        CREATE TABLE disappearing_message (
+            room_id             TEXT,
+            mxid                TEXT,
+            expiration_seconds  BIGINT,
+            expiration_ts       BIGINT,
+
+            PRIMARY KEY (room_id, mxid)
+        )
+        """
+    )
     await conn.execute("ALTER TABLE portal ADD COLUMN expiration_time BIGINT")
     await conn.execute("ALTER TABLE portal ADD COLUMN expiration_time BIGINT")

+ 5 - 6
mautrix_signal/db/user.py

@@ -34,8 +34,7 @@ class User:
     notice_room: Optional[RoomID]
     notice_room: Optional[RoomID]
 
 
     async def insert(self) -> None:
     async def insert(self) -> None:
-        q = ('INSERT INTO "user" (mxid, username, uuid, notice_room) '
-             'VALUES ($1, $2, $3, $4)')
+        q = 'INSERT INTO "user" (mxid, username, uuid, notice_room) ' "VALUES ($1, $2, $3, $4)"
         await self.db.execute(q, self.mxid, self.username, self.uuid, self.notice_room)
         await self.db.execute(q, self.mxid, self.username, self.uuid, self.notice_room)
 
 
     async def update(self) -> None:
     async def update(self) -> None:
@@ -43,7 +42,7 @@ class User:
         await self.db.execute(q, self.username, self.uuid, self.notice_room, self.mxid)
         await self.db.execute(q, self.username, self.uuid, self.notice_room, self.mxid)
 
 
     @classmethod
     @classmethod
-    async def get_by_mxid(cls, mxid: UserID) -> Optional['User']:
+    async def get_by_mxid(cls, mxid: UserID) -> Optional["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE mxid=$1'
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE mxid=$1'
         row = await cls.db.fetchrow(q, mxid)
         row = await cls.db.fetchrow(q, mxid)
         if not row:
         if not row:
@@ -51,7 +50,7 @@ class User:
         return cls(**row)
         return cls(**row)
 
 
     @classmethod
     @classmethod
-    async def get_by_username(cls, username: str) -> Optional['User']:
+    async def get_by_username(cls, username: str) -> Optional["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username=$1'
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username=$1'
         row = await cls.db.fetchrow(q, username)
         row = await cls.db.fetchrow(q, username)
         if not row:
         if not row:
@@ -59,7 +58,7 @@ class User:
         return cls(**row)
         return cls(**row)
 
 
     @classmethod
     @classmethod
-    async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
+    async def get_by_uuid(cls, uuid: UUID) -> Optional["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE uuid=$1'
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE uuid=$1'
         row = await cls.db.fetchrow(q, uuid)
         row = await cls.db.fetchrow(q, uuid)
         if not row:
         if not row:
@@ -67,7 +66,7 @@ class User:
         return cls(**row)
         return cls(**row)
 
 
     @classmethod
     @classmethod
-    async def all_logged_in(cls) -> List['User']:
+    async def all_logged_in(cls) -> List["User"]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username IS NOT NULL'
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username IS NOT NULL'
         rows = await cls.db.fetch(q)
         rows = await cls.db.fetch(q)
         return [cls(**row) for row in rows]
         return [cls(**row) for row in rows]

+ 24 - 11
mautrix_signal/formatter.py

@@ -19,8 +19,13 @@ import struct
 
 
 from mausignald.types import MessageData, Address, Mention
 from mausignald.types import MessageData, Address, Mention
 from mautrix.types import TextMessageEventContent, MessageType, Format
 from mautrix.types import TextMessageEventContent, MessageType, Format
-from mautrix.util.formatter import (MatrixParser as BaseMatrixParser, EntityString, SimpleEntity,
-                                    EntityType, MarkdownString)
+from mautrix.util.formatter import (
+    MatrixParser as BaseMatrixParser,
+    EntityString,
+    SimpleEntity,
+    EntityType,
+    MarkdownString,
+)
 
 
 from . import puppet as pu, user as u
 from . import puppet as pu, user as u
 
 
@@ -29,14 +34,16 @@ from . import puppet as pu, user as u
 # I don't know if this is how Signal actually calculates lengths, but it seems
 # I don't know if this is how Signal actually calculates lengths, but it seems
 # to work better than plain len()
 # to work better than plain len()
 def add_surrogate(text: str) -> str:
 def add_surrogate(text: str) -> str:
-    return ''.join(
-        ''.join(chr(y) for y in struct.unpack('<HH', x.encode('utf-16le')))
-        if (0x10000 <= ord(x) <= 0x10FFFF) else x for x in text
+    return "".join(
+        "".join(chr(y) for y in struct.unpack("<HH", x.encode("utf-16le")))
+        if (0x10000 <= ord(x) <= 0x10FFFF)
+        else x
+        for x in text
     )
     )
 
 
 
 
 def del_surrogate(text: str) -> str:
 def del_surrogate(text: str) -> str:
-    return text.encode('utf-16', 'surrogatepass').decode('utf-16')
+    return text.encode("utf-16", "surrogatepass").decode("utf-16")
 
 
 
 
 async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
 async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
@@ -47,7 +54,7 @@ async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
         html_chunks = []
         html_chunks = []
         last_offset = 0
         last_offset = 0
         for mention in message.mentions:
         for mention in message.mentions:
-            before = surrogated_text[last_offset:mention.start]
+            before = surrogated_text[last_offset : mention.start]
             last_offset = mention.start + mention.length
             last_offset = mention.start + mention.length
 
 
             text_chunks.append(before)
             text_chunks.append(before)
@@ -67,11 +74,17 @@ async def signal_to_matrix(message: MessageData) -> TextMessageEventContent:
 
 
 # TODO this has a lot of duplication with mautrix-facebook, maybe move to mautrix-python
 # TODO this has a lot of duplication with mautrix-facebook, maybe move to mautrix-python
 class SignalFormatString(EntityString[SimpleEntity, EntityType], MarkdownString):
 class SignalFormatString(EntityString[SimpleEntity, EntityType], MarkdownString):
-    def format(self, entity_type: EntityType, **kwargs) -> 'SignalFormatString':
+    def format(self, entity_type: EntityType, **kwargs) -> "SignalFormatString":
         prefix = suffix = ""
         prefix = suffix = ""
         if entity_type == EntityType.USER_MENTION:
         if entity_type == EntityType.USER_MENTION:
-            self.entities.append(SimpleEntity(type=entity_type, offset=0, length=len(self.text),
-                                              extra_info={"user_id": kwargs["user_id"]}))
+            self.entities.append(
+                SimpleEntity(
+                    type=entity_type,
+                    offset=0,
+                    length=len(self.text),
+                    extra_info={"user_id": kwargs["user_id"]},
+                )
+            )
             return self
             return self
         elif entity_type == EntityType.BOLD:
         elif entity_type == EntityType.BOLD:
             prefix = suffix = "**"
             prefix = suffix = "**"
@@ -80,7 +93,7 @@ class SignalFormatString(EntityString[SimpleEntity, EntityType], MarkdownString)
         elif entity_type == EntityType.STRIKETHROUGH:
         elif entity_type == EntityType.STRIKETHROUGH:
             prefix = suffix = "~~"
             prefix = suffix = "~~"
         elif entity_type == EntityType.URL:
         elif entity_type == EntityType.URL:
-            if kwargs['url'] != self.text:
+            if kwargs["url"] != self.text:
                 suffix = f" ({kwargs['url']})"
                 suffix = f" ({kwargs['url']})"
         elif entity_type == EntityType.PREFORMATTED:
         elif entity_type == EntityType.PREFORMATTED:
             prefix = f"```{kwargs['language']}\n"
             prefix = f"```{kwargs['language']}\n"

+ 1 - 2
mautrix_signal/get_version.py

@@ -34,8 +34,7 @@ else:
     git_revision_url = None
     git_revision_url = None
     git_tag = None
     git_tag = None
 
 
-git_tag_url = (f"https://github.com/mautrix/signal/releases/tag/{git_tag}"
-               if git_tag else None)
+git_tag_url = f"https://github.com/mautrix/signal/releases/tag/{git_tag}" if git_tag else None
 
 
 if git_tag and __version__ == git_tag[1:].replace("-", ""):
 if git_tag and __version__ == git_tag[1:].replace("-", ""):
     version = __version__
     version = __version__

+ 50 - 24
mautrix_signal/matrix.py

@@ -16,9 +16,22 @@
 from typing import List, Union, TYPE_CHECKING
 from typing import List, Union, TYPE_CHECKING
 
 
 from mautrix.bridge import BaseMatrixHandler
 from mautrix.bridge import BaseMatrixHandler
-from mautrix.types import (Event, ReactionEvent, StateEvent, RoomID, EventID, UserID, TypingEvent,
-                           ReactionEventContent, RelationType, EventType, ReceiptEvent,
-                           PresenceEvent, RedactionEvent, SingleReceiptEventContent)
+from mautrix.types import (
+    Event,
+    ReactionEvent,
+    StateEvent,
+    RoomID,
+    EventID,
+    UserID,
+    TypingEvent,
+    ReactionEventContent,
+    RelationType,
+    EventType,
+    ReceiptEvent,
+    PresenceEvent,
+    RedactionEvent,
+    SingleReceiptEventContent,
+)
 
 
 from mautrix_signal.db.disappearing_message import DisappearingMessage
 from mautrix_signal.db.disappearing_message import DisappearingMessage
 
 
@@ -30,9 +43,9 @@ if TYPE_CHECKING:
 
 
 
 
 class MatrixHandler(BaseMatrixHandler):
 class MatrixHandler(BaseMatrixHandler):
-    signal: 's.SignalHandler'
+    signal: "s.SignalHandler"
 
 
-    def __init__(self, bridge: 'SignalBridge') -> None:
+    def __init__(self, bridge: "SignalBridge") -> None:
         prefix, suffix = bridge.config["bridge.username_template"].format(userid=":").split(":")
         prefix, suffix = bridge.config["bridge.username_template"].format(userid=":").split(":")
         homeserver = bridge.config["homeserver.domain"]
         homeserver = bridge.config["homeserver.domain"]
         self.user_id_prefix = f"@{prefix}"
         self.user_id_prefix = f"@{prefix}"
@@ -41,13 +54,14 @@ class MatrixHandler(BaseMatrixHandler):
 
 
         super().__init__(bridge=bridge)
         super().__init__(bridge=bridge)
 
 
-    async def send_welcome_message(self, room_id: RoomID, inviter: 'u.User') -> None:
+    async def send_welcome_message(self, room_id: RoomID, inviter: "u.User") -> None:
         await super().send_welcome_message(room_id, inviter)
         await super().send_welcome_message(room_id, inviter)
         if not inviter.notice_room:
         if not inviter.notice_room:
             inviter.notice_room = room_id
             inviter.notice_room = room_id
             await inviter.update()
             await inviter.update()
-            await self.az.intent.send_notice(room_id, "This room has been marked as your "
-                                                      "Signal bridge notice room.")
+            await self.az.intent.send_notice(
+                room_id, "This room has been marked as your Signal bridge notice room."
+            )
 
 
     async def handle_leave(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None:
     async def handle_leave(self, room_id: RoomID, user_id: UserID, event_id: EventID) -> None:
         portal = await po.Portal.get_by_mxid(room_id)
         portal = await po.Portal.get_by_mxid(room_id)
@@ -72,11 +86,14 @@ class MatrixHandler(BaseMatrixHandler):
         await portal.handle_matrix_join(user)
         await portal.handle_matrix_join(user)
 
 
     @classmethod
     @classmethod
-    async def handle_reaction(cls, room_id: RoomID, user_id: UserID, event_id: EventID,
-                              content: ReactionEventContent) -> None:
+    async def handle_reaction(
+        cls, room_id: RoomID, user_id: UserID, event_id: EventID, content: ReactionEventContent
+    ) -> None:
         if content.relates_to.rel_type != RelationType.ANNOTATION:
         if content.relates_to.rel_type != RelationType.ANNOTATION:
-            cls.log.debug(f"Ignoring m.reaction event in {room_id} from {user_id} with unexpected "
-                          f"relation type {content.relates_to.rel_type}")
+            cls.log.debug(
+                f"Ignoring m.reaction event in {room_id} from {user_id} with unexpected "
+                f"relation type {content.relates_to.rel_type}"
+            )
             return
             return
         user = await u.User.get_by_mxid(user_id)
         user = await u.User.get_by_mxid(user_id)
         if not user:
         if not user:
@@ -86,12 +103,14 @@ class MatrixHandler(BaseMatrixHandler):
         if not portal:
         if not portal:
             return
             return
 
 
-        await portal.handle_matrix_reaction(user, event_id, content.relates_to.event_id,
-                                            content.relates_to.key)
+        await portal.handle_matrix_reaction(
+            user, event_id, content.relates_to.event_id, content.relates_to.key
+        )
 
 
     @staticmethod
     @staticmethod
-    async def handle_redaction(room_id: RoomID, user_id: UserID, event_id: EventID,
-                               redaction_event_id: EventID) -> None:
+    async def handle_redaction(
+        room_id: RoomID, user_id: UserID, event_id: EventID, redaction_event_id: EventID
+    ) -> None:
         user = await u.User.get_by_mxid(user_id)
         user = await u.User.get_by_mxid(user_id)
         if not user:
         if not user:
             return
             return
@@ -102,8 +121,13 @@ class MatrixHandler(BaseMatrixHandler):
 
 
         await portal.handle_matrix_redaction(user, event_id, redaction_event_id)
         await portal.handle_matrix_redaction(user, event_id, redaction_event_id)
 
 
-    async def handle_read_receipt(self, user: 'u.User', portal: 'po.Portal', event_id: EventID,
-                                  data: SingleReceiptEventContent) -> None:
+    async def handle_read_receipt(
+        self,
+        user: "u.User",
+        portal: "po.Portal",
+        event_id: EventID,
+        data: SingleReceiptEventContent,
+    ) -> None:
         await portal.handle_read_receipt(event_id, data)
         await portal.handle_read_receipt(event_id, data)
 
 
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
         message = await DBMessage.get_by_mxid(event_id, portal.mxid)
@@ -111,8 +135,9 @@ class MatrixHandler(BaseMatrixHandler):
             return
             return
 
 
         user.log.trace(f"Sending read receipt for {message.timestamp} to {message.sender}")
         user.log.trace(f"Sending read receipt for {message.timestamp} to {message.sender}")
-        await self.signal.send_receipt(user.username, message.sender,
-                                       timestamps=[message.timestamp], when=data.ts, read=True)
+        await self.signal.send_receipt(
+            user.username, message.sender, timestamps=[message.timestamp], when=data.ts, read=True
+        )
 
 
     async def handle_typing(self, room_id: RoomID, typing: List[UserID]) -> None:
     async def handle_typing(self, room_id: RoomID, typing: List[UserID]) -> None:
         pass
         pass
@@ -134,8 +159,9 @@ class MatrixHandler(BaseMatrixHandler):
             evt: RedactionEvent
             evt: RedactionEvent
             await self.handle_redaction(evt.room_id, evt.sender, evt.redacts, evt.event_id)
             await self.handle_redaction(evt.room_id, evt.sender, evt.redacts, evt.event_id)
 
 
-    async def handle_ephemeral_event(self, evt: Union[ReceiptEvent, PresenceEvent, TypingEvent]
-                                     ) -> None:
+    async def handle_ephemeral_event(
+        self, evt: Union[ReceiptEvent, PresenceEvent, TypingEvent]
+    ) -> None:
         if evt.type == EventType.TYPING:
         if evt.type == EventType.TYPING:
             await self.handle_typing(evt.room_id, evt.content.user_ids)
             await self.handle_typing(evt.room_id, evt.content.user_ids)
         else:
         else:
@@ -157,8 +183,8 @@ class MatrixHandler(BaseMatrixHandler):
         elif evt.type == EventType.ROOM_AVATAR:
         elif evt.type == EventType.ROOM_AVATAR:
             await portal.handle_matrix_avatar(user, evt.content.url)
             await portal.handle_matrix_avatar(user, evt.content.url)
 
 
-    async def allow_message(self, user: 'u.User') -> bool:
+    async def allow_message(self, user: "u.User") -> bool:
         return user.relay_whitelisted
         return user.relay_whitelisted
 
 
-    async def allow_bridging_message(self, user: 'u.User', portal: 'po.Portal') -> bool:
+    async def allow_bridging_message(self, user: "u.User", portal: "po.Portal") -> bool:
         return portal.has_relay or await user.is_logged_in()
         return portal.has_relay or await user.is_logged_in()

File diff suppressed because it is too large
+ 374 - 199
mautrix_signal/portal.py


+ 88 - 38
mautrix_signal/puppet.py

@@ -13,8 +13,17 @@
 #
 #
 # You should have received a copy of the GNU Affero General Public License
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import (Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, Union, Tuple,
-                    TYPE_CHECKING, cast)
+from typing import (
+    Optional,
+    Dict,
+    AsyncIterable,
+    Awaitable,
+    AsyncGenerator,
+    Union,
+    Tuple,
+    TYPE_CHECKING,
+    cast,
+)
 from uuid import UUID
 from uuid import UUID
 import hashlib
 import hashlib
 import asyncio
 import asyncio
@@ -25,8 +34,14 @@ from yarl import URL
 from mausignald.types import Address, Contact, Profile
 from mausignald.types import Address, Contact, Profile
 from mautrix.bridge import BasePuppet, async_getter_lock
 from mautrix.bridge import BasePuppet, async_getter_lock
 from mautrix.appservice import IntentAPI
 from mautrix.appservice import IntentAPI
-from mautrix.types import (UserID, SyncToken, RoomID, ContentURI, EventType,
-                           PowerLevelStateEventContent)
+from mautrix.types import (
+    UserID,
+    SyncToken,
+    RoomID,
+    ContentURI,
+    EventType,
+    PowerLevelStateEventContent,
+)
 from mautrix.errors import MForbidden
 from mautrix.errors import MForbidden
 from mautrix.util.simple_template import SimpleTemplate
 from mautrix.util.simple_template import SimpleTemplate
 
 
@@ -44,9 +59,9 @@ except ImportError:
 
 
 
 
 class Puppet(DBPuppet, BasePuppet):
 class Puppet(DBPuppet, BasePuppet):
-    by_uuid: Dict[UUID, 'Puppet'] = {}
-    by_number: Dict[str, 'Puppet'] = {}
-    by_custom_mxid: Dict[UserID, 'Puppet'] = {}
+    by_uuid: Dict[UUID, "Puppet"] = {}
+    by_number: Dict[str, "Puppet"] = {}
+    by_custom_mxid: Dict[UserID, "Puppet"] = {}
     hs_domain: str
     hs_domain: str
     mxid_template: SimpleTemplate[str]
     mxid_template: SimpleTemplate[str]
 
 
@@ -58,17 +73,37 @@ class Puppet(DBPuppet, BasePuppet):
     _uuid_lock: asyncio.Lock
     _uuid_lock: asyncio.Lock
     _update_info_lock: asyncio.Lock
     _update_info_lock: asyncio.Lock
 
 
-    def __init__(self, uuid: Optional[UUID], number: Optional[str], name: Optional[str] = None,
-                 avatar_url: Optional[ContentURI] = None, avatar_hash: Optional[str] = None,
-                 name_set: bool = False, avatar_set: bool = False, uuid_registered: bool = False,
-                 number_registered: bool = False, custom_mxid: Optional[UserID] = None,
-                 access_token: Optional[str] = None, next_batch: Optional[SyncToken] = None,
-                 base_url: Optional[URL] = None) -> None:
-        super().__init__(uuid=uuid, number=number, name=name, avatar_url=avatar_url,
-                         avatar_hash=avatar_hash, name_set=name_set, avatar_set=avatar_set,
-                         uuid_registered=uuid_registered, number_registered=number_registered,
-                         custom_mxid=custom_mxid, access_token=access_token, next_batch=next_batch,
-                         base_url=base_url)
+    def __init__(
+        self,
+        uuid: Optional[UUID],
+        number: Optional[str],
+        name: Optional[str] = None,
+        avatar_url: Optional[ContentURI] = None,
+        avatar_hash: Optional[str] = None,
+        name_set: bool = False,
+        avatar_set: bool = False,
+        uuid_registered: bool = False,
+        number_registered: bool = False,
+        custom_mxid: Optional[UserID] = None,
+        access_token: Optional[str] = None,
+        next_batch: Optional[SyncToken] = None,
+        base_url: Optional[URL] = None,
+    ) -> None:
+        super().__init__(
+            uuid=uuid,
+            number=number,
+            name=name,
+            avatar_url=avatar_url,
+            avatar_hash=avatar_hash,
+            name_set=name_set,
+            avatar_set=avatar_set,
+            uuid_registered=uuid_registered,
+            number_registered=number_registered,
+            custom_mxid=custom_mxid,
+            access_token=access_token,
+            next_batch=next_batch,
+            base_url=base_url,
+        )
         self.log = self.log.getChild(str(uuid) if uuid else number)
         self.log = self.log.getChild(str(uuid) if uuid else number)
 
 
         self.default_mxid = self.get_mxid_from_id(self.address)
         self.default_mxid = self.get_mxid_from_id(self.address)
@@ -79,25 +114,34 @@ class Puppet(DBPuppet, BasePuppet):
         self._update_info_lock = asyncio.Lock()
         self._update_info_lock = asyncio.Lock()
 
 
     @classmethod
     @classmethod
-    def init_cls(cls, bridge: 'SignalBridge') -> AsyncIterable[Awaitable[None]]:
+    def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
         cls.config = bridge.config
         cls.config = bridge.config
         cls.loop = bridge.loop
         cls.loop = bridge.loop
         cls.mx = bridge.matrix
         cls.mx = bridge.matrix
         cls.az = bridge.az
         cls.az = bridge.az
         cls.hs_domain = cls.config["homeserver.domain"]
         cls.hs_domain = cls.config["homeserver.domain"]
-        cls.mxid_template = SimpleTemplate(cls.config["bridge.username_template"], "userid",
-                                           prefix="@", suffix=f":{cls.hs_domain}", type=str)
+        cls.mxid_template = SimpleTemplate(
+            cls.config["bridge.username_template"],
+            "userid",
+            prefix="@",
+            suffix=f":{cls.hs_domain}",
+            type=str,
+        )
         cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
         cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
 
 
-        cls.homeserver_url_map = {server: URL(url) for server, url
-                                  in cls.config["bridge.double_puppet_server_map"].items()}
+        cls.homeserver_url_map = {
+            server: URL(url)
+            for server, url in cls.config["bridge.double_puppet_server_map"].items()
+        }
         cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
         cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
-        cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
-                                       in cls.config["bridge.login_shared_secret_map"].items()}
+        cls.login_shared_secret_map = {
+            server: secret.encode("utf-8")
+            for server, secret in cls.config["bridge.login_shared_secret_map"].items()
+        }
         cls.login_device_name = "Signal Bridge"
         cls.login_device_name = "Signal Bridge"
         return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
         return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
 
 
-    def intent_for(self, portal: 'p.Portal') -> IntentAPI:
+    def intent_for(self, portal: "p.Portal") -> IntentAPI:
         if portal.chat_id == self.address:
         if portal.chat_id == self.address:
             return self.default_mxid_intent
             return self.default_mxid_intent
         return self.intent
         return self.intent
@@ -167,8 +211,10 @@ class Puppet(DBPuppet, BasePuppet):
         try:
         try:
             joined_rooms = await prev_intent.get_joined_rooms()
             joined_rooms = await prev_intent.get_joined_rooms()
         except MForbidden as e:
         except MForbidden as e:
-            self.log.debug(f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
-                           "assuming there are no rooms to rejoin")
+            self.log.debug(
+                f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
+                "assuming there are no rooms to rejoin"
+            )
             return
             return
         for room_id in joined_rooms:
         for room_id in joined_rooms:
             await prev_intent.invite_user(room_id, self.default_mxid)
             await prev_intent.invite_user(room_id, self.default_mxid)
@@ -176,8 +222,9 @@ class Puppet(DBPuppet, BasePuppet):
             await prev_intent.leave_room(room_id)
             await prev_intent.leave_room(room_id)
             await new_intent.join_room_by_id(room_id)
             await new_intent.join_room_by_id(room_id)
 
 
-    async def _migrate_powers(self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
-                              ) -> None:
+    async def _migrate_powers(
+        self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
+    ) -> None:
         try:
         try:
             powers: PowerLevelStateEventContent
             powers: PowerLevelStateEventContent
             powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
             powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
@@ -260,8 +307,11 @@ class Puppet(DBPuppet, BasePuppet):
         return False
         return False
 
 
     @staticmethod
     @staticmethod
-    async def upload_avatar(self: Union['Puppet', 'p.Portal'], path: str, intent: IntentAPI,
-                            ) -> Union[bool, Tuple[str, ContentURI]]:
+    async def upload_avatar(
+        self: Union["Puppet", "p.Portal"],
+        path: str,
+        intent: IntentAPI,
+    ) -> Union[bool, Tuple[str, ContentURI]]:
         if not path:
         if not path:
             return False
             return False
         if not path.startswith("/"):
         if not path.startswith("/"):
@@ -321,7 +371,7 @@ class Puppet(DBPuppet, BasePuppet):
         await self.update()
         await self.update()
 
 
     @classmethod
     @classmethod
-    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
+    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional["Puppet"]:
         address = cls.get_id_from_mxid(mxid)
         address = cls.get_id_from_mxid(mxid)
         if not address:
         if not address:
             return None
             return None
@@ -329,7 +379,7 @@ class Puppet(DBPuppet, BasePuppet):
 
 
     @classmethod
     @classmethod
     @async_getter_lock
     @async_getter_lock
-    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
+    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional["Puppet"]:
         try:
         try:
             return cls.by_custom_mxid[mxid]
             return cls.by_custom_mxid[mxid]
         except KeyError:
         except KeyError:
@@ -348,7 +398,7 @@ class Puppet(DBPuppet, BasePuppet):
         if not identifier:
         if not identifier:
             return None
             return None
         if identifier.startswith("phone_"):
         if identifier.startswith("phone_"):
-            return Address(number="+" + identifier[len("phone_"):])
+            return Address(number="+" + identifier[len("phone_") :])
         else:
         else:
             try:
             try:
                 return Address(uuid=UUID(identifier.upper()))
                 return Address(uuid=UUID(identifier.upper()))
@@ -367,7 +417,7 @@ class Puppet(DBPuppet, BasePuppet):
 
 
     @classmethod
     @classmethod
     @async_getter_lock
     @async_getter_lock
-    async def get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
+    async def get_by_address(cls, address: Address, create: bool = True) -> Optional["Puppet"]:
         puppet = await cls._get_by_address(address, create)
         puppet = await cls._get_by_address(address, create)
         if puppet and address.uuid and not puppet.uuid:
         if puppet and address.uuid and not puppet.uuid:
             # We found a UUID for this user, store it ASAP
             # We found a UUID for this user, store it ASAP
@@ -375,7 +425,7 @@ class Puppet(DBPuppet, BasePuppet):
         return puppet
         return puppet
 
 
     @classmethod
     @classmethod
-    async def _get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
+    async def _get_by_address(cls, address: Address, create: bool = True) -> Optional["Puppet"]:
         if not address.is_valid:
         if not address.is_valid:
             raise ValueError("Empty address")
             raise ValueError("Empty address")
         if address.uuid:
         if address.uuid:
@@ -403,7 +453,7 @@ class Puppet(DBPuppet, BasePuppet):
         return None
         return None
 
 
     @classmethod
     @classmethod
-    async def all_with_custom_mxid(cls) -> AsyncGenerator['Puppet', None]:
+    async def all_with_custom_mxid(cls) -> AsyncGenerator["Puppet", None]:
         puppets = await super().all_with_custom_mxid()
         puppets = await super().all_with_custom_mxid()
         puppet: cls
         puppet: cls
         for index, puppet in enumerate(puppets):
         for index, puppet in enumerate(puppets):

+ 52 - 25
mautrix_signal/signal.py

@@ -18,9 +18,17 @@ import asyncio
 import logging
 import logging
 
 
 from mausignald import SignaldClient
 from mausignald import SignaldClient
-from mausignald.types import (Message, MessageData, Address, TypingNotification, TypingAction,
-                              OwnReadReceipt, Receipt, ReceiptType,
-                              WebsocketConnectionStateChangeEvent)
+from mausignald.types import (
+    Message,
+    MessageData,
+    Address,
+    TypingNotification,
+    TypingAction,
+    OwnReadReceipt,
+    Receipt,
+    ReceiptType,
+    WebsocketConnectionStateChangeEvent,
+)
 from mautrix.util.logging import TraceLogger
 from mautrix.util.logging import TraceLogger
 
 
 from .db import Message as DBMessage
 from .db import Message as DBMessage
@@ -39,13 +47,14 @@ class SignalHandler(SignaldClient):
     data_dir: str
     data_dir: str
     delete_unknown_accounts: bool
     delete_unknown_accounts: bool
 
 
-    def __init__(self, bridge: 'SignalBridge') -> None:
+    def __init__(self, bridge: "SignalBridge") -> None:
         super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
         super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
         self.data_dir = bridge.config["signal.data_dir"]
         self.data_dir = bridge.config["signal.data_dir"]
         self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
         self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
         self.add_event_handler(Message, self.on_message)
         self.add_event_handler(Message, self.on_message)
-        self.add_event_handler(WebsocketConnectionStateChangeEvent,
-                               self.on_websocket_connection_state_change)
+        self.add_event_handler(
+            WebsocketConnectionStateChangeEvent, self.on_websocket_connection_state_change
+        )
 
 
     async def on_message(self, evt: Message) -> None:
     async def on_message(self, evt: Message) -> None:
         sender = await pu.Puppet.get_by_address(evt.source)
         sender = await pu.Puppet.get_by_address(evt.source)
@@ -62,8 +71,12 @@ class SignalHandler(SignaldClient):
             if evt.sync_message.read_messages:
             if evt.sync_message.read_messages:
                 await self.handle_own_receipts(sender, evt.sync_message.read_messages)
                 await self.handle_own_receipts(sender, evt.sync_message.read_messages)
             if evt.sync_message.sent:
             if evt.sync_message.sent:
-                await self.handle_message(user, sender, evt.sync_message.sent.message,
-                                          addr_override=evt.sync_message.sent.destination)
+                await self.handle_message(
+                    user,
+                    sender,
+                    evt.sync_message.sent.message,
+                    addr_override=evt.sync_message.sent.destination,
+                )
             if evt.sync_message.typing:
             if evt.sync_message.typing:
                 # Typing notification from own device
                 # Typing notification from own device
                 pass
                 pass
@@ -75,12 +88,19 @@ class SignalHandler(SignaldClient):
                 await user.sync_groups()
                 await user.sync_groups()
 
 
     @staticmethod
     @staticmethod
-    async def on_websocket_connection_state_change(evt: WebsocketConnectionStateChangeEvent) -> None:
+    async def on_websocket_connection_state_change(
+        evt: WebsocketConnectionStateChangeEvent,
+    ) -> None:
         user = await u.User.get_by_username(evt.account)
         user = await u.User.get_by_username(evt.account)
         user.on_websocket_connection_state_change(evt)
         user.on_websocket_connection_state_change(evt)
 
 
-    async def handle_message(self, user: 'u.User', sender: 'pu.Puppet', msg: MessageData,
-                             addr_override: Optional[Address] = None) -> None:
+    async def handle_message(
+        self,
+        user: "u.User",
+        sender: "pu.Puppet",
+        msg: MessageData,
+        addr_override: Optional[Address] = None,
+    ) -> None:
         if msg.profile_key_update:
         if msg.profile_key_update:
             self.log.debug("Ignoring profile key update")
             self.log.debug("Ignoring profile key update")
             return
             return
@@ -89,18 +109,23 @@ class SignalHandler(SignaldClient):
         elif msg.group:
         elif msg.group:
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
             portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
         else:
         else:
-            portal = await po.Portal.get_by_chat_id(addr_override or sender.address,
-                                                    receiver=user.username, create=True)
+            portal = await po.Portal.get_by_chat_id(
+                addr_override or sender.address, receiver=user.username, create=True
+            )
             if addr_override and not sender.is_real_user:
             if addr_override and not sender.is_real_user:
-                portal.log.debug(f"Ignoring own message {msg.timestamp} as user doesn't have"
-                                 " double puppeting enabled")
+                portal.log.debug(
+                    f"Ignoring own message {msg.timestamp} as user doesn't have double puppeting "
+                    "enabled"
+                )
                 return
                 return
         if not portal.mxid:
         if not portal.mxid:
-            await portal.create_matrix_room(user, (msg.group_v2 or msg.group
-                                                   or addr_override or sender.address))
+            await portal.create_matrix_room(
+                user, (msg.group_v2 or msg.group or addr_override or sender.address)
+            )
             if not portal.mxid:
             if not portal.mxid:
-                user.log.debug(f"Failed to create room for incoming message {msg.timestamp},"
-                               " dropping message")
+                user.log.debug(
+                    f"Failed to create room for incoming message {msg.timestamp}, dropping message"
+                )
                 return
                 return
         elif msg.group_v2 and msg.group_v2.revision > portal.revision:
         elif msg.group_v2 and msg.group_v2.revision > portal.revision:
             self.log.debug(f"Got new revision of {msg.group_v2.id}, updating info")
             self.log.debug(f"Got new revision of {msg.group_v2.id}, updating info")
@@ -117,7 +142,7 @@ class SignalHandler(SignaldClient):
             await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
             await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
 
 
     @staticmethod
     @staticmethod
-    async def handle_own_receipts(sender: 'pu.Puppet', receipts: List[OwnReadReceipt]) -> None:
+    async def handle_own_receipts(sender: "pu.Puppet", receipts: List[OwnReadReceipt]) -> None:
         for receipt in receipts:
         for receipt in receipts:
             puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
             puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
             if not puppet:
             if not puppet:
@@ -131,8 +156,9 @@ class SignalHandler(SignaldClient):
             await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
             await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
 
 
     @staticmethod
     @staticmethod
-    async def handle_typing(user: 'u.User', sender: 'pu.Puppet',
-                            typing: TypingNotification) -> None:
+    async def handle_typing(
+        user: "u.User", sender: "pu.Puppet", typing: TypingNotification
+    ) -> None:
         if typing.group_id:
         if typing.group_id:
             portal = await po.Portal.get_by_chat_id(typing.group_id)
             portal = await po.Portal.get_by_chat_id(typing.group_id)
         else:
         else:
@@ -140,11 +166,12 @@ class SignalHandler(SignaldClient):
         if not portal or not portal.mxid:
         if not portal or not portal.mxid:
             return
             return
         is_typing = typing.action == TypingAction.STARTED
         is_typing = typing.action == TypingAction.STARTED
-        await sender.intent_for(portal).set_typing(portal.mxid, is_typing, ignore_cache=True,
-                                                   timeout=SIGNAL_TYPING_TIMEOUT)
+        await sender.intent_for(portal).set_typing(
+            portal.mxid, is_typing, ignore_cache=True, timeout=SIGNAL_TYPING_TIMEOUT
+        )
 
 
     @staticmethod
     @staticmethod
-    async def handle_receipt(sender: 'pu.Puppet', receipt: Receipt) -> None:
+    async def handle_receipt(sender: "pu.Puppet", receipt: Receipt) -> None:
         if receipt.type != ReceiptType.READ:
         if receipt.type != ReceiptType.READ:
             return
             return
         messages = await DBMessage.find_by_timestamps(receipt.timestamps)
         messages = await DBMessage.find_by_timestamps(receipt.timestamps)

+ 48 - 29
mautrix_signal/user.py

@@ -19,8 +19,15 @@ from typing import Union, Dict, Optional, AsyncGenerator, List, TYPE_CHECKING, c
 from uuid import UUID
 from uuid import UUID
 import asyncio
 import asyncio
 
 
-from mausignald.types import (Account, Address, Profile, Group, GroupV2, WebsocketConnectionState,
-                              WebsocketConnectionStateChangeEvent)
+from mausignald.types import (
+    Account,
+    Address,
+    Profile,
+    Group,
+    GroupV2,
+    WebsocketConnectionState,
+    WebsocketConnectionStateChangeEvent,
+)
 from mautrix.bridge import BaseUser, AutologinError, async_getter_lock
 from mautrix.bridge import BaseUser, AutologinError, async_getter_lock
 from mautrix.types import UserID, RoomID
 from mautrix.types import UserID, RoomID
 from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
 from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
@@ -34,23 +41,25 @@ from . import puppet as pu, portal as po
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .__main__ import SignalBridge
     from .__main__ import SignalBridge
 
 
-METRIC_CONNECTED = Gauge('bridge_connected', 'Bridge users connected to Signal')
-METRIC_LOGGED_IN = Gauge('bridge_logged_in', 'Bridge users logged into Signal')
+METRIC_CONNECTED = Gauge("bridge_connected", "Bridge users connected to Signal")
+METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Bridge users logged into Signal")
 
 
-BridgeState.human_readable_errors.update({
-    "logged-out": "You're not logged into Signal",
-    "signal-not-connected": None,
-})
+BridgeState.human_readable_errors.update(
+    {
+        "logged-out": "You're not logged into Signal",
+        "signal-not-connected": None,
+    }
+)
 
 
 
 
 class User(DBUser, BaseUser):
 class User(DBUser, BaseUser):
-    by_mxid: Dict[UserID, 'User'] = {}
-    by_username: Dict[str, 'User'] = {}
-    by_uuid: Dict[UUID, 'User'] = {}
+    by_mxid: Dict[UserID, "User"] = {}
+    by_username: Dict[str, "User"] = {}
+    by_uuid: Dict[UUID, "User"] = {}
     config: Config
     config: Config
     az: AppService
     az: AppService
     loop: asyncio.AbstractEventLoop
     loop: asyncio.AbstractEventLoop
-    bridge: 'SignalBridge'
+    bridge: "SignalBridge"
 
 
     relay_whitelisted: bool
     relay_whitelisted: bool
     is_admin: bool
     is_admin: bool
@@ -62,8 +71,13 @@ class User(DBUser, BaseUser):
     _websocket_connection_state: Optional[WebsocketConnectionState]
     _websocket_connection_state: Optional[WebsocketConnectionState]
     _latest_non_transient_disconnect_state: Optional[datetime]
     _latest_non_transient_disconnect_state: Optional[datetime]
 
 
-    def __init__(self, mxid: UserID, username: Optional[str] = None, uuid: Optional[UUID] = None,
-                 notice_room: Optional[RoomID] = None) -> None:
+    def __init__(
+        self,
+        mxid: UserID,
+        username: Optional[str] = None,
+        uuid: Optional[UUID] = None,
+        notice_room: Optional[RoomID] = None,
+    ) -> None:
         super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
         super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
         BaseUser.__init__(self)
         BaseUser.__init__(self)
         self._notice_room_lock = asyncio.Lock()
         self._notice_room_lock = asyncio.Lock()
@@ -74,7 +88,7 @@ class User(DBUser, BaseUser):
         self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
         self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
 
 
     @classmethod
     @classmethod
-    def init_cls(cls, bridge: 'SignalBridge') -> None:
+    def init_cls(cls, bridge: "SignalBridge") -> None:
         cls.bridge = bridge
         cls.bridge = bridge
         cls.config = bridge.config
         cls.config = bridge.config
         cls.az = bridge.az
         cls.az = bridge.az
@@ -125,7 +139,7 @@ class User(DBUser, BaseUser):
             state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
             state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
         return [state]
         return [state]
 
 
-    async def get_puppet(self) -> Optional['pu.Puppet']:
+    async def get_puppet(self) -> Optional["pu.Puppet"]:
         if not self.address:
         if not self.address:
             return None
             return None
         return await pu.Puppet.get_by_address(self.address)
         return await pu.Puppet.get_by_address(self.address)
@@ -139,7 +153,9 @@ class User(DBUser, BaseUser):
         asyncio.create_task(self.sync())
         asyncio.create_task(self.sync())
         self._track_metric(METRIC_LOGGED_IN, True)
         self._track_metric(METRIC_LOGGED_IN, True)
 
 
-    def on_websocket_connection_state_change(self, evt: WebsocketConnectionStateChangeEvent) -> None:
+    def on_websocket_connection_state_change(
+        self, evt: WebsocketConnectionStateChangeEvent
+    ) -> None:
         if evt.state == WebsocketConnectionState.CONNECTED:
         if evt.state == WebsocketConnectionState.CONNECTED:
             self.log.info("Connected to Signal")
             self.log.info("Connected to Signal")
             self._track_metric(METRIC_CONNECTED, True)
             self._track_metric(METRIC_CONNECTED, True)
@@ -147,14 +163,14 @@ class User(DBUser, BaseUser):
             self._connected = True
             self._connected = True
         else:
         else:
             self.log.warning(
             self.log.warning(
-                f"New websocket state from signald: {evt.state}. Error: {evt.exception}")
+                f"New websocket state from signald: {evt.state}. Error: {evt.exception}"
+            )
             self._track_metric(METRIC_CONNECTED, False)
             self._track_metric(METRIC_CONNECTED, False)
             self._connected = False
             self._connected = False
 
 
         bridge_state = {
         bridge_state = {
             # Signald disconnected
             # Signald disconnected
             WebsocketConnectionState.SOCKET_DISCONNECTED: BridgeStateEvent.TRANSIENT_DISCONNECT,
             WebsocketConnectionState.SOCKET_DISCONNECTED: BridgeStateEvent.TRANSIENT_DISCONNECT,
-
             # Websocket state reported by signald
             # Websocket state reported by signald
             WebsocketConnectionState.DISCONNECTED: (
             WebsocketConnectionState.DISCONNECTED: (
                 None
                 None
@@ -174,6 +190,7 @@ class User(DBUser, BaseUser):
 
 
         now = datetime.now()
         now = datetime.now()
         if bridge_state == BridgeStateEvent.TRANSIENT_DISCONNECT:
         if bridge_state == BridgeStateEvent.TRANSIENT_DISCONNECT:
+
             async def wait_report_transient_disconnect():
             async def wait_report_transient_disconnect():
                 # Wait for 10 seconds (that should be enough for the bridge to get connected)
                 # Wait for 10 seconds (that should be enough for the bridge to get connected)
                 # before sending a TRANSIENT_DISCONNECT.
                 # before sending a TRANSIENT_DISCONNECT.
@@ -212,7 +229,7 @@ class User(DBUser, BaseUser):
             self.uuid = puppet.uuid
             self.uuid = puppet.uuid
             self.by_uuid[self.uuid] = self
             self.by_uuid[self.uuid] = self
         if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
         if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
-            self.log.info(f"Automatically enabling custom puppet")
+            self.log.info("Automatically enabling custom puppet")
             try:
             try:
                 await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
                 await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
             except AutologinError as e:
             except AutologinError as e:
@@ -244,8 +261,9 @@ class User(DBUser, BaseUser):
         except Exception:
         except Exception:
             self.log.exception("Error while syncing groups")
             self.log.exception("Error while syncing groups")
 
 
-    async def sync_contact(self, contact: Union[Profile, Address], create_portals: bool = False
-                           ) -> None:
+    async def sync_contact(
+        self, contact: Union[Profile, Address], create_portals: bool = False
+    ) -> None:
         self.log.trace("Syncing contact %s", contact)
         self.log.trace("Syncing contact %s", contact)
         if isinstance(contact, Address):
         if isinstance(contact, Address):
             address = contact
             address = contact
@@ -258,8 +276,9 @@ class User(DBUser, BaseUser):
         puppet = await pu.Puppet.get_by_address(address)
         puppet = await pu.Puppet.get_by_address(address)
         await puppet.update_info(profile)
         await puppet.update_info(profile)
         if create_portals:
         if create_portals:
-            portal = await po.Portal.get_by_chat_id(puppet.address, receiver=self.username,
-                                                    create=True)
+            portal = await po.Portal.get_by_chat_id(
+                puppet.address, receiver=self.username, create=True
+            )
             await portal.create_matrix_room(self, profile)
             await portal.create_matrix_room(self, profile)
 
 
     async def _sync_group(self, group: Group, create_portals: bool) -> None:
     async def _sync_group(self, group: Group, create_portals: bool) -> None:
@@ -311,7 +330,7 @@ class User(DBUser, BaseUser):
 
 
     @classmethod
     @classmethod
     @async_getter_lock
     @async_getter_lock
-    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
+    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional["User"]:
         # Never allow ghosts to be users
         # Never allow ghosts to be users
         if pu.Puppet.get_id_from_mxid(mxid):
         if pu.Puppet.get_id_from_mxid(mxid):
             return None
             return None
@@ -335,7 +354,7 @@ class User(DBUser, BaseUser):
 
 
     @classmethod
     @classmethod
     @async_getter_lock
     @async_getter_lock
-    async def get_by_username(cls, username: str) -> Optional['User']:
+    async def get_by_username(cls, username: str) -> Optional["User"]:
         try:
         try:
             return cls.by_username[username]
             return cls.by_username[username]
         except KeyError:
         except KeyError:
@@ -350,7 +369,7 @@ class User(DBUser, BaseUser):
 
 
     @classmethod
     @classmethod
     @async_getter_lock
     @async_getter_lock
-    async def get_by_uuid(cls, uuid: UUID) -> Optional['User']:
+    async def get_by_uuid(cls, uuid: UUID) -> Optional["User"]:
         try:
         try:
             return cls.by_uuid[uuid]
             return cls.by_uuid[uuid]
         except KeyError:
         except KeyError:
@@ -364,7 +383,7 @@ class User(DBUser, BaseUser):
         return None
         return None
 
 
     @classmethod
     @classmethod
-    async def get_by_address(cls, address: Address) -> Optional['User']:
+    async def get_by_address(cls, address: Address) -> Optional["User"]:
         if address.uuid:
         if address.uuid:
             return await cls.get_by_uuid(address.uuid)
             return await cls.get_by_uuid(address.uuid)
         elif address.number:
         elif address.number:
@@ -373,7 +392,7 @@ class User(DBUser, BaseUser):
             raise ValueError("Given address is blank")
             raise ValueError("Given address is blank")
 
 
     @classmethod
     @classmethod
-    async def all_logged_in(cls) -> AsyncGenerator['User', None]:
+    async def all_logged_in(cls) -> AsyncGenerator["User", None]:
         users = await super().all_logged_in()
         users = await super().all_logged_in()
         user: cls
         user: cls
         for user in users:
         for user in users:

+ 43 - 30
mautrix_signal/web/provisioning_api.py

@@ -34,9 +34,9 @@ if TYPE_CHECKING:
 class ProvisioningAPI:
 class ProvisioningAPI:
     log: TraceLogger = logging.getLogger("mau.web.provisioning")
     log: TraceLogger = logging.getLogger("mau.web.provisioning")
     app: web.Application
     app: web.Application
-    bridge: 'SignalBridge'
+    bridge: "SignalBridge"
 
 
-    def __init__(self, bridge: 'SignalBridge', shared_secret: str) -> None:
+    def __init__(self, bridge: "SignalBridge", shared_secret: str) -> None:
         self.bridge = bridge
         self.bridge = bridge
         self.app = web.Application()
         self.app = web.Application()
         self.shared_secret = shared_secret
         self.shared_secret = shared_secret
@@ -70,23 +70,26 @@ class ProvisioningAPI:
     async def login_options(self, _: web.Request) -> web.Response:
     async def login_options(self, _: web.Request) -> web.Response:
         return web.Response(status=200, headers=self._headers)
         return web.Response(status=200, headers=self._headers)
 
 
-    async def check_token(self, request: web.Request) -> 'u.User':
+    async def check_token(self, request: web.Request) -> "u.User":
         try:
         try:
             token = request.headers["Authorization"]
             token = request.headers["Authorization"]
-            token = token[len("Bearer "):]
+            token = token[len("Bearer ") :]
         except KeyError:
         except KeyError:
-            raise web.HTTPBadRequest(text='{"error": "Missing Authorization header"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Missing Authorization header"}', headers=self._headers
+            )
         except IndexError:
         except IndexError:
-            raise web.HTTPBadRequest(text='{"error": "Malformed Authorization header"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Malformed Authorization header"}', headers=self._headers
+            )
         if token != self.shared_secret:
         if token != self.shared_secret:
             raise web.HTTPForbidden(text='{"error": "Invalid token"}', headers=self._headers)
             raise web.HTTPForbidden(text='{"error": "Invalid token"}', headers=self._headers)
         try:
         try:
             user_id = request.query["user_id"]
             user_id = request.query["user_id"]
         except KeyError:
         except KeyError:
-            raise web.HTTPBadRequest(text='{"error": "Missing user_id query param"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Missing user_id query param"}', headers=self._headers
+            )
 
 
         if not self.bridge.signal.is_connected:
         if not self.bridge.signal.is_connected:
             await self.bridge.signal.wait_for_connected()
             await self.bridge.signal.wait_for_connected()
@@ -103,7 +106,8 @@ class ProvisioningAPI:
         if await user.is_logged_in():
         if await user.is_logged_in():
             try:
             try:
                 profile = await self.bridge.signal.get_profile(
                 profile = await self.bridge.signal.get_profile(
-                    username=user.username, address=Address(number=user.username))
+                    username=user.username, address=Address(number=user.username)
+                )
             except Exception as e:
             except Exception as e:
                 self.log.exception(f"Failed to get {user.username}'s profile for whoami")
                 self.log.exception(f"Failed to get {user.username}'s profile for whoami")
                 data["signal"] = {
                 data["signal"] = {
@@ -127,8 +131,9 @@ class ProvisioningAPI:
         user = await self.check_token(request)
         user = await self.check_token(request)
 
 
         if await user.is_logged_in():
         if await user.is_logged_in():
-            raise web.HTTPConflict(text='''{"error": "You're already logged in"}''',
-                                   headers=self._headers)
+            raise web.HTTPConflict(
+                text="""{"error": "You're already logged in"}""", headers=self._headers
+            )
 
 
         try:
         try:
             data = await request.json()
             data = await request.json()
@@ -147,17 +152,19 @@ class ProvisioningAPI:
         self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
         self.log.debug(f"Returning linking URI for {user.mxid} / {sess.session_id}")
         return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
         return web.json_response({"uri": sess.uri}, headers=self._acao_headers)
 
 
-    async def _shielded_link(self, user: 'u.User', session_id: str, device_name: str) -> Account:
+    async def _shielded_link(self, user: "u.User", session_id: str, device_name: str) -> Account:
         try:
         try:
             self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
             self.log.debug(f"Starting finish link request for {user.mxid} / {session_id}")
-            account = await self.bridge.signal.finish_link(session_id=session_id, overwrite=True,
-                                                           device_name=device_name)
+            account = await self.bridge.signal.finish_link(
+                session_id=session_id, overwrite=True, device_name=device_name
+            )
         except TimeoutException:
         except TimeoutException:
             self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
             self.log.warning(f"Timed out waiting for linking to finish (session {session_id})")
             raise
             raise
         except Exception:
         except Exception:
-            self.log.exception("Fatal error while waiting for linking to finish "
-                               f"(session {session_id})")
+            self.log.exception(
+                "Fatal error while waiting for linking to finish " f"(session {session_id})"
+            )
             raise
             raise
         else:
         else:
             await user.on_signin(account)
             await user.on_signin(account)
@@ -166,36 +173,42 @@ class ProvisioningAPI:
     async def link_wait(self, request: web.Request) -> web.Response:
     async def link_wait(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         user = await self.check_token(request)
         if not user.command_status or user.command_status["action"] != "Link":
         if not user.command_status or user.command_status["action"] != "Link":
-            raise web.HTTPBadRequest(text='{"error": "No Signal linking started"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "No Signal linking started"}', headers=self._headers
+            )
         session_id = user.command_status["session_id"]
         session_id = user.command_status["session_id"]
         device_name = user.command_status["device_name"]
         device_name = user.command_status["device_name"]
         try:
         try:
             account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
             account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
         except asyncio.CancelledError:
         except asyncio.CancelledError:
-            self.log.warning(f"Client cancelled link wait request ({session_id})"
-                             " before it finished")
+            self.log.warning(
+                f"Client cancelled link wait request ({session_id})" " before it finished"
+            )
         except TimeoutException:
         except TimeoutException:
-            raise web.HTTPBadRequest(text='{"error": "Signal linking timed out"}',
-                                     headers=self._headers)
+            raise web.HTTPBadRequest(
+                text='{"error": "Signal linking timed out"}', headers=self._headers
+            )
         except InternalError as ie:
         except InternalError as ie:
             if "java.io.IOException" in ie.exceptions:
             if "java.io.IOException" in ie.exceptions:
                 raise web.HTTPBadRequest(
                 raise web.HTTPBadRequest(
                     text='{"error": "Signald websocket disconnected before linking finished"}',
                     text='{"error": "Signald websocket disconnected before linking finished"}',
                     headers=self._headers,
                     headers=self._headers,
                 )
                 )
-            raise web.HTTPInternalServerError(text='{"error": "Fatal error in Signal linking"}',
-                                              headers=self._headers)
+            raise web.HTTPInternalServerError(
+                text='{"error": "Fatal error in Signal linking"}', headers=self._headers
+            )
         except Exception:
         except Exception:
-            raise web.HTTPInternalServerError(text='{"error": "Fatal error in Signal linking"}',
-                                              headers=self._headers)
+            raise web.HTTPInternalServerError(
+                text='{"error": "Fatal error in Signal linking"}', headers=self._headers
+            )
         else:
         else:
             return web.json_response(account.address.serialize())
             return web.json_response(account.address.serialize())
 
 
     async def logout(self, request: web.Request) -> web.Response:
     async def logout(self, request: web.Request) -> web.Response:
         user = await self.check_token(request)
         user = await self.check_token(request)
         if not await user.is_logged_in():
         if not await user.is_logged_in():
-            raise web.HTTPNotFound(text='''{"error": "You're not logged in"}''',
-                                   headers=self._headers)
+            raise web.HTTPNotFound(
+                text="""{"error": "You're not logged in"}""", headers=self._headers
+            )
         await user.logout()
         await user.logout()
         return web.json_response({}, headers=self._acao_headers)
         return web.json_response({}, headers=self._acao_headers)

+ 10 - 0
pyproject.toml

@@ -0,0 +1,10 @@
+[tool.isort]
+profile = "black"
+force_to_top = "typing"
+from_first = true
+line_length = 99
+
+[tool.black]
+line-length = 99
+target-version = ["py38"]
+required-version = "21.12b0"

+ 5 - 0
setup.cfg

@@ -0,0 +1,5 @@
+[flake8]
+max-line-length = 99
+extend-ignore =
+    # See https://github.com/PyCQA/pycodestyle/issues/373
+    E203,

Some files were not shown because too many files changed in this diff