Browse Source

Update mausignald to use new type hints

Tulir Asokan 3 years ago
parent
commit
e19b56889f
3 changed files with 67 additions and 64 deletions
  1. 11 9
      mausignald/errors.py
  2. 18 16
      mausignald/rpc.py
  3. 38 39
      mausignald/signald.py

+ 11 - 9
mausignald/errors.py

@@ -1,9 +1,11 @@
-# Copyright (c) 2020 Tulir Asokan
+# Copyright (c) 2022 Tulir Asokan
 #
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Any, Dict, Optional
+from __future__ import annotations
+
+from typing import Any
 
 
 class RPCError(Exception):
@@ -28,9 +30,9 @@ class NotConnected(RPCError):
 class ResponseError(RPCError):
     def __init__(
         self,
-        data: Dict[str, Any],
-        error_type: Optional[str] = None,
-        message_override: Optional[str] = None,
+        data: dict[str, Any],
+        error_type: str | None = None,
+        message_override: str | None = None,
     ) -> None:
         self.data = data
         msg = message_override or data["message"]
@@ -56,19 +58,19 @@ class AuthorizationFailedException(ResponseError):
 
 
 class UserAlreadyExistsError(ResponseError):
-    def __init__(self, data: Dict[str, Any]) -> None:
+    def __init__(self, data: dict[str, Any]) -> None:
         super().__init__(data, message_override="You're already logged in")
 
 
 class RequestValidationFailure(ResponseError):
-    def __init__(self, data: Dict[str, Any]) -> None:
+    def __init__(self, data: dict[str, Any]) -> None:
         results = data["validationResults"]
         result_str = ", ".join(results) if isinstance(results, list) else str(results)
         super().__init__(data, message_override=result_str)
 
 
 class InternalError(ResponseError):
-    def __init__(self, data: Dict[str, Any]) -> None:
+    def __init__(self, data: dict[str, Any]) -> None:
         exceptions = data.get("exceptions", [])
         self.exceptions = exceptions
         message = data.get("message")
@@ -88,7 +90,7 @@ response_error_types = {
 }
 
 
-def make_response_error(data: Dict[str, Any]) -> ResponseError:
+def make_response_error(data: dict[str, Any]) -> ResponseError:
     error_data = data["error"]
     if isinstance(error_data, str):
         error_data = {"message": error_data}

+ 18 - 16
mausignald/rpc.py

@@ -3,7 +3,9 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
+from __future__ import annotations
+
+from typing import Any, Awaitable, Callable, Dict
 from uuid import UUID, uuid4
 import asyncio
 import json
@@ -27,20 +29,20 @@ class SignaldRPCClient:
     log: TraceLogger
 
     socket_path: str
-    _reader: Optional[asyncio.StreamReader]
-    _writer: Optional[asyncio.StreamWriter]
+    _reader: asyncio.StreamReader | None
+    _writer: asyncio.StreamWriter | None
     is_connected: bool
     _connect_future: asyncio.Future
-    _communicate_task: Optional[asyncio.Task]
+    _communicate_task: asyncio.Task | None
 
-    _response_waiters: Dict[UUID, asyncio.Future]
-    _rpc_event_handlers: Dict[str, List[EventHandler]]
+    _response_waiters: dict[UUID, asyncio.Future]
+    _rpc_event_handlers: dict[str, list[EventHandler]]
 
     def __init__(
         self,
         socket_path: str,
-        log: Optional[TraceLogger] = None,
-        loop: Optional[asyncio.AbstractEventLoop] = None,
+        log: TraceLogger | None = None,
+        loop: asyncio.AbstractEventLoop | None = None,
     ) -> None:
         self.socket_path = socket_path
         self.log = log or logging.getLogger("mausignald")
@@ -54,7 +56,7 @@ class SignaldRPCClient:
         self._rpc_event_handlers = {CONNECT_EVENT: [], DISCONNECT_EVENT: []}
         self.add_rpc_handler(DISCONNECT_EVENT, self._abandon_responses)
 
-    async def wait_for_connected(self, timeout: Optional[int] = None) -> bool:
+    async def wait_for_connected(self, timeout: int | None = None) -> bool:
         if self.is_connected:
             return True
         await asyncio.wait_for(asyncio.shield(self._connect_future), timeout)
@@ -106,7 +108,7 @@ class SignaldRPCClient:
     def remove_rpc_handler(self, method: str, handler: EventHandler) -> None:
         self._rpc_event_handlers.setdefault(method, []).remove(handler)
 
-    async def _run_rpc_handler(self, command: str, req: Dict[str, Any]) -> None:
+    async def _run_rpc_handler(self, command: str, req: dict[str, Any]) -> None:
         try:
             handlers = self._rpc_event_handlers[command]
         except KeyError:
@@ -183,8 +185,8 @@ class SignaldRPCClient:
         self._writer = None
 
     def _create_request(
-        self, command: str, req_id: Optional[UUID] = None, **data: Any
-    ) -> Tuple[asyncio.Future, Dict[str, Any]]:
+        self, command: str, req_id: UUID | None = None, **data: Any
+    ) -> tuple[asyncio.Future, dict[str, Any]]:
         req_id = req_id or uuid4()
         req = {"id": str(req_id), "type": command, **data}
         self.log.trace("Request %s: %s %s", req_id, command, data)
@@ -197,7 +199,7 @@ class SignaldRPCClient:
             future = self._response_waiters[req_id] = self.loop.create_future()
         return future
 
-    async def _abandon_responses(self, unused_data: Dict[str, Any]) -> None:
+    async def _abandon_responses(self, unused_data: dict[str, Any]) -> None:
         for req_id, waiter in self._response_waiters.items():
             if not waiter.done():
                 self.log.trace(f"Abandoning response for {req_id}")
@@ -205,7 +207,7 @@ class SignaldRPCClient:
                     NotConnected("Disconnected from signald before RPC completed")
                 )
 
-    async def _send_request(self, data: Dict[str, Any]) -> None:
+    async def _send_request(self, data: dict[str, Any]) -> None:
         if self._writer is None:
             raise NotConnected("Not connected to signald")
 
@@ -215,8 +217,8 @@ class SignaldRPCClient:
         self.log.trace("Sent data to server server: %s", data)
 
     async def _raw_request(
-        self, command: str, req_id: Optional[UUID] = None, **data: Any
-    ) -> Tuple[str, Dict[str, Any]]:
+        self, command: str, req_id: UUID | None = None, **data: Any
+    ) -> tuple[str, dict[str, Any]]:
         future, data = self._create_request(command, req_id, **data)
         await self._send_request(data)
         return await asyncio.shield(future)

+ 38 - 39
mausignald/signald.py

@@ -3,7 +3,10 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Type, TypeVar, Union
+from __future__ import annotations
+
+from typing import Any, Awaitable, Callable, Type, TypeVar
+from uuid import UUID
 import asyncio
 
 from mautrix.util.logging import TraceLogger
@@ -34,14 +37,14 @@ EventHandler = Callable[[T], Awaitable[None]]
 
 
 class SignaldClient(SignaldRPCClient):
-    _event_handlers: Dict[Type[T], List[EventHandler]]
-    _subscriptions: Set[str]
+    _event_handlers: dict[Type[T], list[EventHandler]]
+    _subscriptions: set[str]
 
     def __init__(
         self,
         socket_path: str = "/var/run/signald/signald.sock",
-        log: Optional[TraceLogger] = None,
-        loop: Optional[asyncio.AbstractEventLoop] = None,
+        log: TraceLogger | None = None,
+        loop: asyncio.AbstractEventLoop | None = None,
     ) -> None:
         super().__init__(socket_path, log, loop)
         self._event_handlers = {}
@@ -70,7 +73,7 @@ class SignaldClient(SignaldRPCClient):
                 except Exception:
                     self.log.exception("Exception in event handler")
 
-    async def _parse_message(self, data: Dict[str, Any]) -> None:
+    async def _parse_message(self, data: dict[str, Any]) -> None:
         event_type = data["type"]
         event_data = data["data"]
         event_class = {
@@ -79,12 +82,12 @@ class SignaldClient(SignaldRPCClient):
         event = event_class.deserialize(event_data)
         await self._run_event_handler(event)
 
-    async def _log_version(self, data: Dict[str, Any]) -> None:
+    async def _log_version(self, data: dict[str, Any]) -> None:
         name = data["data"]["name"]
         version = data["data"]["version"]
         self.log.info(f"Connected to {name} v{version}")
 
-    async def _websocket_connection_state_change(self, change_event: Dict[str, Any]) -> None:
+    async def _websocket_connection_state_change(self, change_event: dict[str, Any]) -> None:
         evt = WebsocketConnectionStateChangeEvent.deserialize(
             {
                 "account": change_event["account"],
@@ -120,7 +123,7 @@ class SignaldClient(SignaldRPCClient):
             self.log.debug("Failed to unsubscribe from %s: %s", username, e)
             return False
 
-    async def _resubscribe(self, unused_data: Dict[str, Any]) -> None:
+    async def _resubscribe(self, unused_data: dict[str, Any]) -> None:
         if self._subscriptions:
             self.log.debug("Resubscribing to users")
             for username in list(self._subscriptions):
@@ -137,9 +140,7 @@ class SignaldClient(SignaldRPCClient):
                 )
                 await self._run_event_handler(evt)
 
-    async def register(
-        self, phone: str, voice: bool = False, captcha: Optional[str] = None
-    ) -> str:
+    async def register(self, phone: str, voice: bool = False, captcha: str | None = None) -> str:
         resp = await self.request_v1("register", account=phone, voice=voice, captcha=captcha)
         return resp["account_id"]
 
@@ -163,8 +164,8 @@ class SignaldClient(SignaldRPCClient):
 
     @staticmethod
     def _recipient_to_args(
-        recipient: Union[Address, GroupID], simple_name: bool = False
-    ) -> Dict[str, Any]:
+        recipient: Address | GroupID, simple_name: bool = False
+    ) -> dict[str, Any]:
         if isinstance(recipient, Address):
             recipient = recipient.serialize()
             field_name = "address" if simple_name else "recipientAddress"
@@ -172,9 +173,7 @@ class SignaldClient(SignaldRPCClient):
             field_name = "group" if simple_name else "recipientGroupId"
         return {field_name: recipient}
 
-    async def react(
-        self, username: str, recipient: Union[Address, GroupID], reaction: Reaction
-    ) -> None:
+    async def react(self, username: str, recipient: Address | GroupID, reaction: Reaction) -> None:
         await self.request_v1(
             "react",
             username=username,
@@ -183,7 +182,7 @@ class SignaldClient(SignaldRPCClient):
         )
 
     async def remote_delete(
-        self, username: str, recipient: Union[Address, GroupID], timestamp: int
+        self, username: str, recipient: Address | GroupID, timestamp: int
     ) -> None:
         await self.request_v1(
             "remote_delete",
@@ -195,12 +194,12 @@ class SignaldClient(SignaldRPCClient):
     async def send(
         self,
         username: str,
-        recipient: Union[Address, GroupID],
+        recipient: Address | GroupID,
         body: str,
-        quote: Optional[Quote] = None,
-        attachments: Optional[List[Attachment]] = None,
-        mentions: Optional[List[Mention]] = None,
-        timestamp: Optional[int] = None,
+        quote: Quote | None = None,
+        attachments: list[Attachment] | None = None,
+        mentions: list[Mention] | None = None,
+        timestamp: int | None = None,
     ) -> None:
         serialized_quote = quote.serialize() if quote else None
         serialized_attachments = [attachment.serialize() for attachment in (attachments or [])]
@@ -270,8 +269,8 @@ class SignaldClient(SignaldRPCClient):
         self,
         username: str,
         sender: Address,
-        timestamps: List[int],
-        when: Optional[int] = None,
+        timestamps: list[int],
+        when: int | None = None,
         read: bool = False,
     ) -> None:
         if not read:
@@ -281,25 +280,25 @@ class SignaldClient(SignaldRPCClient):
             "mark_read", account=username, timestamps=timestamps, when=when, to=sender.serialize()
         )
 
-    async def list_accounts(self) -> List[Account]:
+    async def list_accounts(self) -> list[Account]:
         resp = await self.request_v1("list_accounts")
         return [Account.deserialize(acc) for acc in resp.get("accounts", [])]
 
     async def delete_account(self, username: str, server: bool = False) -> None:
         await self.request_v1("delete_account", account=username, server=server)
 
-    async def get_linked_devices(self, username: str) -> List[DeviceInfo]:
+    async def get_linked_devices(self, username: str) -> list[DeviceInfo]:
         resp = await self.request_v1("get_linked_devices", account=username)
         return [DeviceInfo.deserialize(dev) for dev in resp.get("devices", [])]
 
     async def remove_linked_device(self, username: str, device_id: int) -> None:
         await self.request_v1("remove_linked_device", account=username, deviceId=device_id)
 
-    async def list_contacts(self, username: str) -> List[Profile]:
+    async def list_contacts(self, username: str) -> list[Profile]:
         resp = await self.request_v1("list_contacts", account=username)
         return [Profile.deserialize(contact) for contact in resp["profiles"]]
 
-    async def list_groups(self, username: str) -> List[Union[Group, GroupV2]]:
+    async def list_groups(self, username: str) -> list[Group | GroupV2]:
         resp = await self.request_v1("list_groups", account=username)
         legacy = [Group.deserialize(group) for group in resp.get("legacyGroups", [])]
         v2 = [GroupV2.deserialize(group) for group in resp.get("groups", [])]
@@ -309,11 +308,11 @@ class SignaldClient(SignaldRPCClient):
         self,
         username: str,
         group_id: GroupID,
-        title: Optional[str] = None,
-        avatar_path: Optional[str] = None,
-        add_members: Optional[List[Address]] = None,
-        remove_members: Optional[List[Address]] = None,
-    ) -> Union[Group, GroupV2, None]:
+        title: str | None = None,
+        avatar_path: str | None = None,
+        add_members: list[Address] | None = None,
+        remove_members: list[Address] | None = None,
+    ) -> Group | GroupV2 | None:
         update_params = {
             key: value
             for key, value in {
@@ -341,7 +340,7 @@ class SignaldClient(SignaldRPCClient):
 
     async def get_group(
         self, username: str, group_id: GroupID, revision: int = -1
-    ) -> Optional[GroupV2]:
+    ) -> GroupV2 | None:
         resp = await self.request_v1(
             "get_group", account=username, groupID=group_id, revision=revision
         )
@@ -351,7 +350,7 @@ class SignaldClient(SignaldRPCClient):
 
     async def get_profile(
         self, username: str, address: Address, use_cache: bool = False
-    ) -> Optional[Profile]:
+    ) -> Profile | None:
         try:
             # async is a reserved keyword, so can't pass it as a normal parameter
             kwargs = {"async": use_cache}
@@ -371,7 +370,7 @@ class SignaldClient(SignaldRPCClient):
         return GetIdentitiesResponse.deserialize(resp)
 
     async def set_profile(
-        self, username: str, name: Optional[str] = None, avatar_path: Optional[str] = None
+        self, username: str, name: str | None = None, avatar_path: str | None = None
     ) -> None:
         args = {}
         if name is not None:
@@ -385,8 +384,8 @@ class SignaldClient(SignaldRPCClient):
         username: str,
         recipient: Address,
         trust_level: str,
-        safety_number: Optional[str] = None,
-        qr_code_data: Optional[str] = None,
+        safety_number: str | None = None,
+        qr_code_data: str | None = None,
     ) -> None:
         args = {}
         if safety_number: