Jelajahi Sumber

Move to future annotations and clean up things

Tulir Asokan 3 tahun lalu
induk
melakukan
1884f00170

+ 1 - 2
mautrix_signal/__main__.py

@@ -23,8 +23,7 @@ from mautrix.types import RoomID, UserID
 
 from . import commands
 from .config import Config
-from .db import init as init_db
-from .db import upgrade_table
+from .db import init as init_db, upgrade_table
 from .matrix import MatrixHandler
 from .portal import Portal
 from .puppet import Puppet

+ 1 - 1
mautrix_signal/commands/auth.py

@@ -85,7 +85,7 @@ async def link(evt: CommandEvent) -> None:
     except Exception:
         evt.log.exception("Fatal error while waiting for linking to finish")
         await evt.reply(
-            "Fatal error while waiting for linking to finish " "(see logs for more details)"
+            "Fatal error while waiting for linking to finish (see logs for more details)"
         )
     else:
         await evt.sender.on_signin(account)

+ 3 - 4
mautrix_signal/commands/signal.py

@@ -23,8 +23,7 @@ from mautrix.types import EventID
 from mausignald.errors import UnknownIdentityKey
 from mausignald.types import Address
 
-from .. import portal as po
-from .. import puppet as pu
+from .. import portal as po, puppet as pu
 from .auth import make_qr, remove_extra_chars
 from .typehint import CommandEvent
 
@@ -182,7 +181,7 @@ async def set_profile_name(evt: CommandEvent) -> None:
 async def mark_trusted(evt: CommandEvent) -> EventID:
     if len(evt.args) < 2:
         return await evt.reply(
-            "**Usage:** `$cmdprefix+sp mark-trusted <recipient phone> " "<safety number>`"
+            "**Usage:** `$cmdprefix+sp mark-trusted <recipient phone> <safety number>`"
         )
     number = evt.args[0].translate(remove_extra_chars)
     safety_num = "".join(evt.args[1:]).replace("\n", "")
@@ -253,5 +252,5 @@ async def raw(evt: CommandEvent) -> None:
             await evt.reply(f"Got reply `{resp_type}` with no content")
         else:
             await evt.reply(
-                f"Got reply `{resp_type}`:\n\n" f"```json\n{json.dumps(resp_data, indent=2)}\n```"
+                f"Got reply `{resp_type}`:\n\n```json\n{json.dumps(resp_data, indent=2)}\n```"
             )

+ 6 - 4
mautrix_signal/db/disappearing_message.py

@@ -13,6 +13,8 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from __future__ import annotations
+
 from typing import TYPE_CHECKING, ClassVar, List, Optional
 
 from attr import dataclass
@@ -60,11 +62,11 @@ class DisappearingMessage:
         await cls.db.execute(q, room_id, event_id)
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> "DisappearingMessage":
+    def _from_row(cls, row: asyncpg.Record) -> DisappearingMessage:
         return cls(**row)
 
     @classmethod
-    async def get(cls, room_id: RoomID, event_id: EventID) -> Optional["DisappearingMessage"]:
+    async def get(cls, room_id: RoomID, event_id: EventID) -> Optional[DisappearingMessage]:
         q = """
         SELECT room_id, mxid, expiration_seconds, expiration_ts
           FROM disappearing_message
@@ -77,12 +79,12 @@ class DisappearingMessage:
             return None
 
     @classmethod
-    async def get_all(cls) -> List["DisappearingMessage"]:
+    async def get_all(cls) -> List[DisappearingMessage]:
         q = "SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message"
         return [cls._from_row(r) for r in await cls.db.fetch(q)]
 
     @classmethod
-    async def get_all_for_room(cls, room_id: RoomID) -> List["DisappearingMessage"]:
+    async def get_all_for_room(cls, room_id: RoomID) -> List[DisappearingMessage]:
         q = """
         SELECT room_id, mxid, expiration_seconds, expiration_ts
           FROM disappearing_message

+ 21 - 26
mautrix_signal/db/message.py

@@ -13,8 +13,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
-from uuid import UUID
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
 
 from attr import dataclass
 from mautrix.types import EventID, RoomID
@@ -36,7 +37,7 @@ class Message:
     mx_room: RoomID
     sender: Address
     timestamp: int
-    signal_chat_id: Union[GroupID, Address]
+    signal_chat_id: GroupID | Address
     signal_receiver: str
 
     async def insert(self) -> None:
@@ -57,7 +58,7 @@ class Message:
     async def delete(self) -> None:
         q = """
         DELETE FROM message
-         WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
+        WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
         """
         await self.db.execute(
             q,
@@ -72,7 +73,7 @@ class Message:
         await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> "Message":
+    def _from_row(cls, row: asyncpg.Record) -> Message:
         data = {**row}
         chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
@@ -81,10 +82,10 @@ class Message:
         return cls(signal_chat_id=chat_id, sender=sender, **data)
 
     @classmethod
-    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional["Message"]:
+    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None:
         q = """
-        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
-          FROM message WHERE mxid=$1 AND mx_room=$2
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
+        WHERE mxid=$1 AND mx_room=$2
         """
         row = await cls.db.fetchrow(q, mxid, mx_room)
         if not row:
@@ -96,13 +97,12 @@ class Message:
         cls,
         sender: Address,
         timestamp: int,
-        signal_chat_id: Union[GroupID, Address],
+        signal_chat_id: GroupID | Address,
         signal_receiver: str = "",
-    ) -> Optional["Message"]:
+    ) -> Message | None:
         q = """
-        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
-          FROM message
-         WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
+        WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
         """
         row = await cls.db.fetchrow(
             q, sender.best_identifier, timestamp, id_to_str(signal_chat_id), signal_receiver
@@ -112,32 +112,27 @@ class Message:
         return cls._from_row(row)
 
     @classmethod
-    async def find_by_timestamps(cls, timestamps: List[int]) -> List["Message"]:
+    async def find_by_timestamps(cls, timestamps: list[int]) -> list[Message]:
         if cls.db.scheme == "postgres":
             q = """
-            SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
-              FROM message
-             WHERE timestamp=ANY($1)
+            SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
+            WHERE timestamp=ANY($1)
             """
             rows = await cls.db.fetch(q, timestamps)
         else:
             placeholders = ", ".join("?" for _ in range(len(timestamps)))
             q = f"""
-            SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
-              FROM message
-             WHERE timestamp IN ({placeholders})
+            SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
+            WHERE timestamp IN ({placeholders})
             """
             rows = await cls.db.fetch(q, *timestamps)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
-    async def find_by_sender_timestamp(
-        cls, sender: Address, timestamp: int
-    ) -> Optional["Message"]:
+    async def find_by_sender_timestamp(cls, sender: Address, timestamp: int) -> Message | None:
         q = """
-        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver
-          FROM message
-         WHERE sender=$1 AND timestamp=$2
+        SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
+        WHERE sender=$1 AND timestamp=$2
         """
         row = await cls.db.fetchrow(q, sender.best_identifier, timestamp)
         if not row:

+ 26 - 31
mautrix_signal/db/portal.py

@@ -13,7 +13,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
 
 from attr import dataclass
 from mautrix.types import ContentURI, RoomID, UserID
@@ -31,18 +33,18 @@ fake_db = Database.create("") if TYPE_CHECKING else None
 class Portal:
     db: ClassVar[Database] = fake_db
 
-    chat_id: Union[GroupID, Address]
+    chat_id: GroupID | Address
     receiver: str
-    mxid: Optional[RoomID]
-    name: Optional[str]
-    avatar_hash: Optional[str]
-    avatar_url: Optional[ContentURI]
+    mxid: RoomID | None
+    name: str | None
+    avatar_hash: str | None
+    avatar_url: ContentURI | None
     name_set: bool
     avatar_set: bool
     revision: int
     encrypted: bool
-    relay_user_id: Optional[UserID]
-    expiration_time: Optional[int]
+    relay_user_id: UserID | None
+    expiration_time: int | None
 
     @property
     def chat_id_str(self) -> str:
@@ -94,7 +96,7 @@ class Portal:
         )
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> "Portal":
+    def _from_row(cls, row: asyncpg.Record) -> Portal:
         data = {**row}
         chat_id = data.pop("chat_id")
         if data["receiver"]:
@@ -102,12 +104,11 @@ class Portal:
         return cls(chat_id=chat_id, **data)
 
     @classmethod
-    async def get_by_mxid(cls, mxid: RoomID) -> Optional["Portal"]:
+    async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
         q = """
         SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
-               revision, encrypted, relay_user_id, expiration_time
-          FROM portal
-         WHERE mxid=$1
+               revision, encrypted, relay_user_id, expiration_time FROM portal
+        WHERE mxid=$1
         """
         row = await cls.db.fetchrow(q, mxid)
         if not row:
@@ -115,14 +116,11 @@ class Portal:
         return cls._from_row(row)
 
     @classmethod
-    async def get_by_chat_id(
-        cls, chat_id: Union[GroupID, Address], receiver: str = ""
-    ) -> Optional["Portal"]:
+    async def get_by_chat_id(cls, chat_id: GroupID | Address, receiver: str = "") -> Portal | None:
         q = """
         SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
-               revision, encrypted, relay_user_id, expiration_time
-          FROM portal
-         WHERE chat_id=$1 AND receiver=$2
+               revision, encrypted, relay_user_id, expiration_time FROM portal
+        WHERE chat_id=$1 AND receiver=$2
         """
         row = await cls.db.fetchrow(q, id_to_str(chat_id), receiver)
         if not row:
@@ -130,34 +128,31 @@ class Portal:
         return cls._from_row(row)
 
     @classmethod
-    async def find_private_chats_of(cls, receiver: str) -> List["Portal"]:
+    async def find_private_chats_of(cls, receiver: str) -> list[Portal]:
         q = """
         SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
-               revision, encrypted, relay_user_id, expiration_time
-          FROM portal
-         WHERE receiver=$1
+               revision, encrypted, relay_user_id, expiration_time FROM portal
+        WHERE receiver=$1
         """
         rows = await cls.db.fetch(q, receiver)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
-    async def find_private_chats_with(cls, other_user: Address) -> List["Portal"]:
+    async def find_private_chats_with(cls, other_user: Address) -> list[Portal]:
         q = """
         SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
-               revision, encrypted, relay_user_id, expiration_time
-          FROM portal
-         WHERE chat_id=$1 AND receiver<>''
+               revision, encrypted, relay_user_id, expiration_time FROM portal
+        WHERE chat_id=$1 AND receiver<>''
         """
         rows = await cls.db.fetch(q, other_user.best_identifier)
         return [cls._from_row(row) for row in rows]
 
     @classmethod
-    async def all_with_room(cls) -> List["Portal"]:
+    async def all_with_room(cls) -> list[Portal]:
         q = """
         SELECT chat_id, receiver, mxid, name, avatar_hash, avatar_url, name_set, avatar_set,
-               revision, encrypted, relay_user_id, expiration_time
-          FROM portal
-         WHERE mxid IS NOT NULL
+               revision, encrypted, relay_user_id, expiration_time FROM portal
+        WHERE mxid IS NOT NULL
         """
         rows = await cls.db.fetch(q)
         return [cls._from_row(row) for row in rows]

+ 17 - 15
mautrix_signal/db/puppet.py

@@ -13,7 +13,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import TYPE_CHECKING, ClassVar, List, Optional
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
 from uuid import UUID
 
 from attr import dataclass
@@ -31,24 +33,24 @@ fake_db = Database.create("") if TYPE_CHECKING else None
 class Puppet:
     db: ClassVar[Database] = fake_db
 
-    uuid: Optional[UUID]
-    number: Optional[str]
-    name: Optional[str]
-    avatar_hash: Optional[str]
-    avatar_url: Optional[ContentURI]
+    uuid: UUID | None
+    number: str | None
+    name: str | None
+    avatar_hash: str | None
+    avatar_url: ContentURI | None
     name_set: bool
     avatar_set: bool
 
     uuid_registered: bool
     number_registered: bool
 
-    custom_mxid: Optional[UserID]
-    access_token: Optional[str]
-    next_batch: Optional[SyncToken]
-    base_url: Optional[URL]
+    custom_mxid: UserID | None
+    access_token: str | None
+    next_batch: SyncToken | None
+    base_url: URL | None
 
     @property
-    def _base_url_str(self) -> Optional[str]:
+    def _base_url_str(self) -> str | None:
         return str(self.base_url) if self.base_url else None
 
     async def insert(self) -> None:
@@ -134,7 +136,7 @@ class Puppet:
         )
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> "Puppet":
+    def _from_row(cls, row: asyncpg.Record) -> Puppet:
         data = {**row}
         base_url_str = data.pop("base_url")
         base_url = URL(base_url_str) if base_url_str is not None else None
@@ -148,7 +150,7 @@ class Puppet:
     )
 
     @classmethod
-    async def get_by_address(cls, address: Address) -> Optional["Puppet"]:
+    async def get_by_address(cls, address: Address) -> Puppet | None:
         if address.uuid:
             if address.number:
                 row = await cls.db.fetchrow(
@@ -165,13 +167,13 @@ class Puppet:
         return cls._from_row(row)
 
     @classmethod
-    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional["Puppet"]:
+    async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
         row = await cls.db.fetchrow(f"{cls._select_base} WHERE custom_mxid=$1", mxid)
         if not row:
             return None
         return cls._from_row(row)
 
     @classmethod
-    async def all_with_custom_mxid(cls) -> List["Puppet"]:
+    async def all_with_custom_mxid(cls) -> list[Puppet]:
         rows = await cls.db.fetch(f"{cls._select_base} WHERE custom_mxid IS NOT NULL")
         return [cls._from_row(row) for row in rows]

+ 8 - 7
mautrix_signal/db/reaction.py

@@ -13,8 +13,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import TYPE_CHECKING, ClassVar, Optional, Union
-from uuid import UUID
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
 
 from attr import dataclass
 from mautrix.types import EventID, RoomID
@@ -34,7 +35,7 @@ class Reaction:
 
     mxid: EventID
     mx_room: RoomID
-    signal_chat_id: Union[GroupID, Address]
+    signal_chat_id: GroupID | Address
     signal_receiver: str
     msg_author: Address
     msg_timestamp: int
@@ -89,7 +90,7 @@ class Reaction:
         )
 
     @classmethod
-    def _from_row(cls, row: asyncpg.Record) -> "Reaction":
+    def _from_row(cls, row: asyncpg.Record) -> Reaction:
         data = {**row}
         chat_id = data.pop("signal_chat_id")
         if data["signal_receiver"]:
@@ -99,7 +100,7 @@ class Reaction:
         return cls(signal_chat_id=chat_id, msg_author=msg_author, author=author, **data)
 
     @classmethod
-    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Optional["Reaction"]:
+    async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Reaction | None:
         q = (
             "SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
             "       msg_author, msg_timestamp, author, emoji "
@@ -113,12 +114,12 @@ class Reaction:
     @classmethod
     async def get_by_signal_id(
         cls,
-        chat_id: Union[GroupID, Address],
+        chat_id: GroupID | Address,
         receiver: str,
         msg_author: Address,
         msg_timestamp: int,
         author: Address,
-    ) -> Optional["Reaction"]:
+    ) -> Reaction | None:
         q = (
             "SELECT mxid, mx_room, signal_chat_id, signal_receiver,"
             "       msg_author, msg_timestamp, author, emoji "

+ 10 - 8
mautrix_signal/db/user.py

@@ -13,7 +13,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import TYPE_CHECKING, ClassVar, List, Optional
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
 from uuid import UUID
 
 from attr import dataclass
@@ -28,9 +30,9 @@ class User:
     db: ClassVar[Database] = fake_db
 
     mxid: UserID
-    username: Optional[str]
-    uuid: Optional[UUID]
-    notice_room: Optional[RoomID]
+    username: str | None
+    uuid: UUID | None
+    notice_room: RoomID | None
 
     async def insert(self) -> None:
         q = 'INSERT INTO "user" (mxid, username, uuid, notice_room) ' "VALUES ($1, $2, $3, $4)"
@@ -41,7 +43,7 @@ class User:
         await self.db.execute(q, self.username, self.uuid, self.notice_room, self.mxid)
 
     @classmethod
-    async def get_by_mxid(cls, mxid: UserID) -> Optional["User"]:
+    async def get_by_mxid(cls, mxid: UserID) -> User | None:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE mxid=$1'
         row = await cls.db.fetchrow(q, mxid)
         if not row:
@@ -49,7 +51,7 @@ class User:
         return cls(**row)
 
     @classmethod
-    async def get_by_username(cls, username: str) -> Optional["User"]:
+    async def get_by_username(cls, username: str) -> User | None:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username=$1'
         row = await cls.db.fetchrow(q, username)
         if not row:
@@ -57,7 +59,7 @@ class User:
         return cls(**row)
 
     @classmethod
-    async def get_by_uuid(cls, uuid: UUID) -> Optional["User"]:
+    async def get_by_uuid(cls, uuid: UUID) -> User | None:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE uuid=$1'
         row = await cls.db.fetchrow(q, uuid)
         if not row:
@@ -65,7 +67,7 @@ class User:
         return cls(**row)
 
     @classmethod
-    async def all_logged_in(cls) -> List["User"]:
+    async def all_logged_in(cls) -> list[User]:
         q = 'SELECT mxid, username, uuid, notice_room FROM "user" WHERE username IS NOT NULL'
         rows = await cls.db.fetch(q)
         return [cls(**row) for row in rows]

+ 8 - 5
mautrix_signal/formatter.py

@@ -18,14 +18,17 @@ from html import escape
 import struct
 
 from mautrix.types import Format, MessageType, TextMessageEventContent
-from mautrix.util.formatter import EntityString, EntityType, MarkdownString
-from mautrix.util.formatter import MatrixParser as BaseMatrixParser
-from mautrix.util.formatter import SimpleEntity
+from mautrix.util.formatter import (
+    EntityString,
+    EntityType,
+    MarkdownString,
+    MatrixParser as BaseMatrixParser,
+    SimpleEntity,
+)
 
 from mausignald.types import Address, Mention, MessageData
 
-from . import puppet as pu
-from . import user as u
+from . import puppet as pu, user as u
 
 
 # Helper methods from rom https://github.com/LonamiWebs/Telethon/blob/master/telethon/helpers.py

+ 1 - 3
mautrix_signal/matrix.py

@@ -35,9 +35,7 @@ from mautrix.types import (
 
 from mautrix_signal.db.disappearing_message import DisappearingMessage
 
-from . import portal as po
-from . import signal as s
-from . import user as u
+from . import portal as po, signal as s, user as u
 from .db import Message as DBMessage
 
 if TYPE_CHECKING:

+ 74 - 95
mautrix_signal/portal.py

@@ -13,24 +13,10 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    AsyncGenerator,
-    Awaitable,
-    Callable,
-    Deque,
-    Dict,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Union,
-    cast,
-)
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Union, cast
 from collections import deque
-from html import escape as escape_html
-from string import Template
 from uuid import UUID, uuid4
 import asyncio
 import hashlib
@@ -50,7 +36,6 @@ from mautrix.types import (
     EventID,
     EventType,
     FileInfo,
-    Format,
     ImageInfo,
     MediaMessageEventContent,
     MessageEvent,
@@ -87,15 +72,14 @@ from mausignald.types import (
     Sticker,
 )
 
-from . import matrix as m
-from . import puppet as p
-from . import signal as s
-from . import user as u
+from . import matrix as m, puppet as p, signal as s, user as u
 from .config import Config
-from .db import DisappearingMessage
-from .db import Message as DBMessage
-from .db import Portal as DBPortal
-from .db import Reaction as DBReaction
+from .db import (
+    DisappearingMessage,
+    Message as DBMessage,
+    Portal as DBPortal,
+    Reaction as DBReaction,
+)
 from .formatter import matrix_to_signal, signal_to_matrix
 from .util import id_to_str
 
@@ -124,38 +108,38 @@ ChatInfo = Union[Group, GroupV2, GroupV2ID, Contact, Profile, Address]
 
 
 class Portal(DBPortal, BasePortal):
-    by_mxid: Dict[RoomID, "Portal"] = {}
-    by_chat_id: Dict[Tuple[str, str], "Portal"] = {}
-    _sticker_meta_cache: Dict[str, StickerPack] = {}
+    by_mxid: dict[RoomID, Portal] = {}
+    by_chat_id: dict[tuple[str, str], Portal] = {}
+    _sticker_meta_cache: dict[str, StickerPack] = {}
     config: Config
-    matrix: "m.MatrixHandler"
-    signal: "s.SignalHandler"
+    matrix: m.MatrixHandler
+    signal: s.SignalHandler
     az: AppService
     private_chat_portal_meta: bool
-    expiration_time: Optional[int]
+    expiration_time: int | None
 
-    _main_intent: Optional[IntentAPI]
+    _main_intent: IntentAPI | None
     _create_room_lock: asyncio.Lock
-    _msgts_dedup: Deque[Tuple[Address, int]]
-    _reaction_dedup: Deque[Tuple[Address, int, str]]
+    _msgts_dedup: deque[tuple[Address, int]]
+    _reaction_dedup: deque[tuple[Address, int, str]]
     _reaction_lock: asyncio.Lock
-    _pending_members: Optional[Set[UUID]]
+    _pending_members: set[UUID] | None
     _expiration_lock: asyncio.Lock
 
     def __init__(
         self,
-        chat_id: Union[GroupID, Address],
+        chat_id: GroupID | Address,
         receiver: str,
-        mxid: Optional[RoomID] = None,
-        name: Optional[str] = None,
-        avatar_hash: Optional[str] = None,
-        avatar_url: Optional[ContentURI] = None,
+        mxid: RoomID | None = None,
+        name: str | None = None,
+        avatar_hash: str | None = None,
+        avatar_url: ContentURI | None = None,
         name_set: bool = False,
         avatar_set: bool = False,
         revision: int = 0,
         encrypted: bool = False,
-        relay_user_id: Optional[UserID] = None,
-        expiration_time: Optional[int] = None,
+        relay_user_id: UserID | None = None,
+        expiration_time: int | None = None,
     ) -> None:
         super().__init__(
             chat_id,
@@ -236,7 +220,7 @@ class Portal(DBPortal, BasePortal):
         existing: DBReaction,
         intent: IntentAPI,
         mxid: EventID,
-        sender: Union["p.Puppet", "u.User"],
+        sender: p.Puppet | u.User,
         message: DBMessage,
         emoji: str,
     ) -> None:
@@ -302,7 +286,7 @@ class Portal(DBPortal, BasePortal):
         return self._write_outgoing_file(data)
 
     async def handle_matrix_message(
-        self, sender: "u.User", message: MessageEventContent, event_id: EventID
+        self, sender: u.User, message: MessageEventContent, event_id: EventID
     ) -> None:
         orig_sender = sender
         sender, is_relay = await self.get_relay_sender(sender, f"message {event_id}")
@@ -329,9 +313,9 @@ class Portal(DBPortal, BasePortal):
             if reply is not None:
                 quote = Quote(id=reply.timestamp, author=reply.sender, text="")
 
-        attachments: Optional[List[Attachment]] = None
-        attachment_path: Optional[str] = None
-        mentions: Optional[List[Mention]] = None
+        attachments: list[Attachment] | None = None
+        attachment_path: str | None = None
+        mentions: list[Mention] | None = None
         if message.msgtype.is_text:
             text, mentions = await matrix_to_signal(message)
         elif message.msgtype.is_media:
@@ -409,7 +393,7 @@ class Portal(DBPortal, BasePortal):
                 await Portal._expire_event(dm.room_id, dm.mxid)
 
     async def handle_matrix_reaction(
-        self, sender: "u.User", event_id: EventID, reacting_to: EventID, emoji: str
+        self, sender: u.User, event_id: EventID, reacting_to: EventID, emoji: str
     ) -> None:
         if not await sender.is_logged_in():
             self.log.trace(f"Ignoring reaction by non-logged-in user {sender.mxid}")
@@ -464,7 +448,7 @@ class Portal(DBPortal, BasePortal):
                 await self._send_delivery_receipt(event_id)
 
     async def handle_matrix_redaction(
-        self, sender: "u.User", event_id: EventID, redaction_event_id: EventID
+        self, sender: u.User, event_id: EventID, redaction_event_id: EventID
     ) -> None:
         if not await sender.is_logged_in():
             return
@@ -537,7 +521,7 @@ class Portal(DBPortal, BasePortal):
             error=f"No message or reaction found for redaction",
         )
 
-    async def handle_matrix_join(self, user: "u.User") -> None:
+    async def handle_matrix_join(self, user: u.User) -> None:
         if self.is_direct or not await user.is_logged_in():
             return
         if self._pending_members is None:
@@ -558,12 +542,12 @@ class Portal(DBPortal, BasePortal):
                 self._pending_members.remove(user.uuid)
             except RPCError as e:
                 await self.main_intent.send_notice(
-                    self.mxid, "\u26a0 Failed to accept invite " f"on Signal: {e}"
+                    self.mxid, f"\u26a0 Failed to accept invite on Signal: {e}"
                 )
             else:
                 await self.update_info(user, resp)
 
-    async def handle_matrix_leave(self, user: "u.User") -> None:
+    async def handle_matrix_leave(self, user: u.User) -> None:
         if not await user.is_logged_in():
             return
         if self.is_direct:
@@ -577,15 +561,15 @@ class Portal(DBPortal, BasePortal):
             self.log.debug(f"{user.mxid} left portal to {self.chat_id}")
             # TODO cleanup if empty
 
-    async def handle_matrix_name(self, user: "u.User", name: str) -> None:
+    async def handle_matrix_name(self, user: u.User, name: str) -> None:
         if self.name == name or self.is_direct or not name:
             return
-        sender, is_relay = await self._get_relay_sender(user, "name change")
+        sender, is_relay = await self.get_relay_sender(user, "name change")
         if not sender:
             return
         self.name = name
         self.log.debug(
-            f"{user.mxid} changed the group name, " f"sending to Signal through {sender.username}"
+            f"{user.mxid} changed the group name, sending to Signal through {sender.username}"
         )
         try:
             await self.signal.update_group(sender.username, self.chat_id, title=name)
@@ -593,10 +577,10 @@ class Portal(DBPortal, BasePortal):
             self.log.exception("Failed to update Signal group name")
             self.name = None
 
-    async def handle_matrix_avatar(self, user: "u.User", url: ContentURI) -> None:
+    async def handle_matrix_avatar(self, user: u.User, url: ContentURI) -> None:
         if self.is_direct or not url:
             return
-        sender, is_relay = await self._get_relay_sender(user, "avatar change")
+        sender, is_relay = await self.get_relay_sender(user, "avatar change")
         if not sender:
             return
 
@@ -609,8 +593,7 @@ class Portal(DBPortal, BasePortal):
         self.avatar_hash = new_hash
         path = self._write_outgoing_file(data)
         self.log.debug(
-            f"{user.mxid} changed the group avatar, "
-            f"sending to Signal through {sender.username}"
+            f"{user.mxid} changed the group avatar, sending to Signal through {sender.username}"
         )
         try:
             await self.signal.update_group(sender.username, self.chat_id, avatar_path=path)
@@ -678,7 +661,7 @@ class Portal(DBPortal, BasePortal):
                 await disappearing_message.update()
 
     @classmethod
-    async def _expire_event_task(cls, portal: "Portal", event_id: EventID, wait: float):
+    async def _expire_event_task(cls, portal: Portal, event_id: EventID, wait: float):
         portal.log.debug(f"Redacting {event_id} in {wait} seconds")
         await asyncio.sleep(wait)
 
@@ -716,9 +699,7 @@ class Portal(DBPortal, BasePortal):
         puppet = await p.Puppet.get_by_address(address, create=False)
         return puppet.address
 
-    async def _find_quote_event_id(
-        self, quote: Optional[Quote]
-    ) -> Optional[Union[MessageEvent, EventID]]:
+    async def _find_quote_event_id(self, quote: Quote | None) -> MessageEvent | EventID | None:
         if not quote:
             return None
 
@@ -737,7 +718,7 @@ class Portal(DBPortal, BasePortal):
             return reply_msg.mxid
 
     async def handle_signal_message(
-        self, source: "u.User", sender: "p.Puppet", message: MessageData
+        self, source: u.User, sender: p.Puppet, message: MessageData
     ) -> None:
         if (sender.address, message.timestamp) in self._msgts_dedup:
             self.log.debug(
@@ -753,7 +734,7 @@ class Portal(DBPortal, BasePortal):
         )
         if old_message is not None:
             self.log.debug(
-                f"Ignoring message {message.timestamp} by {sender.uuid} as it was already handled"
+                f"Ignoring message {message.timestamp} by {sender.uuid} as it was already handled "
                 "(message.id found in database)"
             )
             await self.signal.send_receipt(
@@ -944,7 +925,7 @@ class Portal(DBPortal, BasePortal):
 
     async def _handle_signal_sticker(
         self, intent: IntentAPI, sticker: Sticker
-    ) -> Optional[MediaMessageEventContent]:
+    ) -> MediaMessageEventContent | None:
         try:
             self.log.debug(f"Fetching sticker {sticker.pack_id}#{sticker.sticker_id}")
             async with StickersClient() as client:
@@ -992,7 +973,7 @@ class Portal(DBPortal, BasePortal):
                 content.info.thumbnail_url = content.url
 
     async def handle_signal_reaction(
-        self, sender: "p.Puppet", reaction: Reaction, timestamp: int
+        self, sender: p.Puppet, reaction: Reaction, timestamp: int
     ) -> None:
         author_address = await self._resolve_address(reaction.target_author)
         target_id = reaction.target_sent_timestamp
@@ -1033,7 +1014,7 @@ class Portal(DBPortal, BasePortal):
         self.log.debug(f"{sender.address} reacted to {message.mxid} -> {mxid}")
         await self._upsert_reaction(existing, intent, mxid, sender, message, reaction.emoji)
 
-    async def handle_signal_delete(self, sender: "p.Puppet", message_ts: int) -> None:
+    async def handle_signal_delete(self, sender: p.Puppet, message_ts: int) -> None:
         message = await DBMessage.get_by_signal_id(
             sender.address, message_ts, self.chat_id, self.receiver
         )
@@ -1049,7 +1030,7 @@ class Portal(DBPortal, BasePortal):
     # region Updating portal info
 
     async def update_info(
-        self, source: "u.User", info: ChatInfo, sender: Optional["p.Puppet"] = None
+        self, source: u.User, info: ChatInfo, sender: p.Puppet | None = None
     ) -> None:
         if self.is_direct:
             if not isinstance(info, (Contact, Profile, Address)):
@@ -1098,7 +1079,7 @@ class Portal(DBPortal, BasePortal):
             await self.update_bridge_info()
             await self.update()
 
-    async def update_expires_in_seconds(self, sender: "p.Puppet", expires_in_seconds: int) -> None:
+    async def update_expires_in_seconds(self, sender: p.Puppet, expires_in_seconds: int) -> None:
         if expires_in_seconds == 0:
             expires_in_seconds = None
         if self.expiration_time == expires_in_seconds:
@@ -1142,7 +1123,7 @@ class Portal(DBPortal, BasePortal):
             await self.update_bridge_info()
             await self.update()
 
-    async def _update_name(self, name: str, sender: Optional["p.Puppet"] = None) -> bool:
+    async def _update_name(self, name: str, sender: p.Puppet | None = None) -> bool:
         if self.name != name or not self.name_set:
             self.name = name
             if self.mxid:
@@ -1158,7 +1139,7 @@ class Portal(DBPortal, BasePortal):
         return False
 
     async def _try_with_puppet(
-        self, action: Callable[[IntentAPI], Awaitable[Any]], puppet: Optional["p.Puppet"] = None
+        self, action: Callable[[IntentAPI], Awaitable[Any]], puppet: p.Puppet | None = None
     ) -> None:
         if puppet:
             try:
@@ -1168,7 +1149,7 @@ class Portal(DBPortal, BasePortal):
         else:
             await action(self.main_intent)
 
-    async def _update_avatar(self, info: ChatInfo, sender: Optional["p.Puppet"] = None) -> bool:
+    async def _update_avatar(self, info: ChatInfo, sender: p.Puppet | None = None) -> bool:
         path = None
         if isinstance(info, GroupV2):
             path = info.avatar
@@ -1191,7 +1172,7 @@ class Portal(DBPortal, BasePortal):
             self.avatar_set = False
         return True
 
-    async def _update_participants(self, source: "u.User", info: ChatInfo) -> None:
+    async def _update_participants(self, source: u.User, info: ChatInfo) -> None:
         if not self.mxid or not isinstance(info, (Group, GroupV2)):
             return
 
@@ -1232,7 +1213,7 @@ class Portal(DBPortal, BasePortal):
         return f"net.maunium.signal://signal/{self.chat_id}"
 
     @property
-    def bridge_info(self) -> Dict[str, Any]:
+    def bridge_info(self) -> dict[str, Any]:
         return {
             "bridgebot": self.az.bot_mxid,
             "creator": self.main_intent.mxid,
@@ -1267,7 +1248,7 @@ class Portal(DBPortal, BasePortal):
     # endregion
     # region Creating Matrix rooms
 
-    async def update_matrix_room(self, source: "u.User", info: ChatInfo) -> None:
+    async def update_matrix_room(self, source: u.User, info: ChatInfo) -> None:
         if not self.is_direct and not isinstance(info, (Group, GroupV2, GroupV2ID)):
             raise ValueError(f"Unexpected type for updating group portal: {type(info)}")
         elif self.is_direct and not isinstance(info, (Contact, Profile, Address)):
@@ -1277,7 +1258,7 @@ class Portal(DBPortal, BasePortal):
         except Exception:
             self.log.exception("Failed to update portal")
 
-    async def create_matrix_room(self, source: "u.User", info: ChatInfo) -> Optional[RoomID]:
+    async def create_matrix_room(self, source: u.User, info: ChatInfo) -> RoomID | None:
         if not self.is_direct and not isinstance(info, (Group, GroupV2, GroupV2ID)):
             raise ValueError(f"Unexpected type for creating group portal: {type(info)}")
         elif self.is_direct and not isinstance(info, (Contact, Profile, Address)):
@@ -1303,7 +1284,7 @@ class Portal(DBPortal, BasePortal):
         async with self._create_room_lock:
             return await self._create_matrix_room(source, info)
 
-    def _get_invite_content(self, double_puppet: Optional["p.Puppet"]) -> Dict[str, Any]:
+    def _get_invite_content(self, double_puppet: p.Puppet | None) -> dict[str, Any]:
         invite_content = {}
         if double_puppet:
             invite_content["fi.mau.will_auto_accept"] = True
@@ -1311,7 +1292,7 @@ class Portal(DBPortal, BasePortal):
             invite_content["is_direct"] = True
         return invite_content
 
-    async def _update_matrix_room(self, source: "u.User", info: ChatInfo) -> None:
+    async def _update_matrix_room(self, source: u.User, info: ChatInfo) -> None:
         puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
         await self.main_intent.invite_user(
             self.mxid,
@@ -1328,8 +1309,8 @@ class Portal(DBPortal, BasePortal):
 
     async def _get_power_levels(
         self,
-        levels: Optional[PowerLevelStateEventContent] = None,
-        info: Optional[ChatInfo] = None,
+        levels: PowerLevelStateEventContent | None = None,
+        info: ChatInfo | None = None,
         is_initial: bool = False,
     ) -> PowerLevelStateEventContent:
         levels = levels or PowerLevelStateEventContent()
@@ -1366,13 +1347,13 @@ class Portal(DBPortal, BasePortal):
             levels.users[self.main_intent.mxid] = 9001 if is_initial else 100
         return levels
 
-    async def _create_matrix_room(self, source: "u.User", info: ChatInfo) -> Optional[RoomID]:
+    async def _create_matrix_room(self, source: u.User, info: ChatInfo) -> RoomID | None:
         if self.mxid:
             await self._update_matrix_room(source, info)
             return self.mxid
         await self.update_info(source, info)
         self.log.debug("Creating Matrix room")
-        name: Optional[str] = None
+        name: str | None = None
         power_levels = await self._get_power_levels(info=info, is_initial=True)
         initial_state = [
             {
@@ -1440,7 +1421,7 @@ class Portal(DBPortal, BasePortal):
             try:
                 await self.az.intent.ensure_joined(self.mxid)
             except Exception:
-                self.log.warning("Failed to add bridge bot " f"to new private chat {self.mxid}")
+                self.log.warning("Failed to add bridge bot to new private chat {self.mxid}")
 
         puppet = await p.Puppet.get_by_custom_mxid(source.mxid)
         await self.main_intent.invite_user(
@@ -1492,17 +1473,15 @@ class Portal(DBPortal, BasePortal):
         await self.update()
 
     @classmethod
-    def all_with_room(cls) -> AsyncGenerator["Portal", None]:
+    def all_with_room(cls) -> AsyncGenerator[Portal, None]:
         return cls._db_to_portals(super().all_with_room())
 
     @classmethod
-    def find_private_chats_with(cls, other_user: Address) -> AsyncGenerator["Portal", None]:
+    def find_private_chats_with(cls, other_user: Address) -> AsyncGenerator[Portal, None]:
         return cls._db_to_portals(super().find_private_chats_with(other_user))
 
     @classmethod
-    async def _db_to_portals(
-        cls, query: Awaitable[List["Portal"]]
-    ) -> AsyncGenerator["Portal", None]:
+    async def _db_to_portals(cls, query: Awaitable[list[Portal]]) -> AsyncGenerator[Portal, None]:
         portals = await query
         for index, portal in enumerate(portals):
             try:
@@ -1513,7 +1492,7 @@ class Portal(DBPortal, BasePortal):
 
     @classmethod
     @async_getter_lock
-    async def get_by_mxid(cls, mxid: RoomID) -> Optional["Portal"]:
+    async def get_by_mxid(cls, mxid: RoomID) -> Portal | None:
         try:
             return cls.by_mxid[mxid]
         except KeyError:
@@ -1528,8 +1507,8 @@ class Portal(DBPortal, BasePortal):
 
     @classmethod
     async def get_by_chat_id(
-        cls, chat_id: Union[GroupID, Address], *, receiver: str = "", create: bool = False
-    ) -> Optional["Portal"]:
+        cls, chat_id: GroupID | Address, *, receiver: str = "", create: bool = False
+    ) -> Portal | None:
         if isinstance(chat_id, str):
             receiver = ""
         elif not isinstance(chat_id, Address):
@@ -1545,8 +1524,8 @@ class Portal(DBPortal, BasePortal):
     @classmethod
     @async_getter_lock
     async def _get_by_chat_id(
-        cls, best_id: str, receiver: str, *, create: bool, chat_id: Union[GroupID, Address]
-    ) -> Optional["Portal"]:
+        cls, best_id: str, receiver: str, *, create: bool, chat_id: GroupID | Address
+    ) -> Portal | None:
         try:
             return cls.by_chat_id[(best_id, receiver)]
         except KeyError:

+ 29 - 40
mautrix_signal/puppet.py

@@ -1,5 +1,5 @@
 # mautrix-signal - A Matrix-Signal puppeting bridge
-# Copyright (C) 2020 Tulir Asokan
+# Copyright (C) 2021 Tulir Asokan
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License as published by
@@ -13,17 +13,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import (
-    TYPE_CHECKING,
-    AsyncGenerator,
-    AsyncIterable,
-    Awaitable,
-    Dict,
-    Optional,
-    Tuple,
-    Union,
-    cast,
-)
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
 from uuid import UUID
 import asyncio
 import hashlib
@@ -45,8 +37,7 @@ from yarl import URL
 
 from mausignald.types import Address, Contact, Profile
 
-from . import portal as p
-from . import user as u
+from . import portal as p, user as u
 from .config import Config
 from .db import Puppet as DBPuppet
 
@@ -60,9 +51,9 @@ except ImportError:
 
 
 class Puppet(DBPuppet, BasePuppet):
-    by_uuid: Dict[UUID, "Puppet"] = {}
-    by_number: Dict[str, "Puppet"] = {}
-    by_custom_mxid: Dict[UserID, "Puppet"] = {}
+    by_uuid: dict[UUID, Puppet] = {}
+    by_number: dict[str, Puppet] = {}
+    by_custom_mxid: dict[UserID, Puppet] = {}
     hs_domain: str
     mxid_template: SimpleTemplate[str]
 
@@ -76,19 +67,19 @@ class Puppet(DBPuppet, BasePuppet):
 
     def __init__(
         self,
-        uuid: Optional[UUID],
-        number: Optional[str],
-        name: Optional[str] = None,
-        avatar_url: Optional[ContentURI] = None,
-        avatar_hash: Optional[str] = None,
+        uuid: UUID | None,
+        number: str | None,
+        name: str | None = None,
+        avatar_url: ContentURI | None = None,
+        avatar_hash: str | None = None,
         name_set: bool = False,
         avatar_set: bool = False,
         uuid_registered: bool = False,
         number_registered: bool = False,
-        custom_mxid: Optional[UserID] = None,
-        access_token: Optional[str] = None,
-        next_batch: Optional[SyncToken] = None,
-        base_url: Optional[URL] = None,
+        custom_mxid: UserID | None = None,
+        access_token: str | None = None,
+        next_batch: SyncToken | None = None,
+        base_url: URL | None = None,
     ) -> None:
         super().__init__(
             uuid=uuid,
@@ -142,7 +133,7 @@ class Puppet(DBPuppet, BasePuppet):
         cls.login_device_name = "Signal Bridge"
         return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
 
-    def intent_for(self, portal: "p.Portal") -> IntentAPI:
+    def intent_for(self, portal: p.Portal) -> IntentAPI:
         if portal.chat_id == self.address:
             return self.default_mxid_intent
         return self.intent
@@ -237,7 +228,7 @@ class Puppet(DBPuppet, BasePuppet):
         except Exception:
             self.log.warning("Failed to migrate power levels", exc_info=True)
 
-    async def update_info(self, info: Union[Profile, Contact, Address]) -> None:
+    async def update_info(self, info: Profile | Contact | Address) -> None:
         address = info.address if isinstance(info, (Contact, Profile)) else info
         if address.uuid and not self.uuid:
             await self.handle_uuid_receive(address.uuid)
@@ -276,7 +267,7 @@ class Puppet(DBPuppet, BasePuppet):
         return phonenumbers.format_number(parsed, fmt)
 
     @classmethod
-    def _get_displayname(cls, address: Address, name: Optional[str]) -> str:
+    def _get_displayname(cls, address: Address, name: str | None) -> str:
         names = name.split("\x00") if name else []
         data = {
             "first_name": names[0] if len(names) > 0 else "",
@@ -294,7 +285,7 @@ class Puppet(DBPuppet, BasePuppet):
 
         return cls.config["bridge.displayname_template"].format(**data)
 
-    async def _update_name(self, name: Optional[str]) -> bool:
+    async def _update_name(self, name: str | None) -> bool:
         name = self._get_displayname(self.address, name)
         if name != self.name or not self.name_set:
             self.name = name
@@ -309,10 +300,8 @@ class Puppet(DBPuppet, BasePuppet):
 
     @staticmethod
     async def upload_avatar(
-        self: Union["Puppet", "p.Portal"],
-        path: str,
-        intent: IntentAPI,
-    ) -> Union[bool, Tuple[str, ContentURI]]:
+        self: Puppet | p.Portal, path: str, intent: IntentAPI
+    ) -> bool | tuple[str, ContentURI]:
         if not path:
             return False
         if not path.startswith("/"):
@@ -372,7 +361,7 @@ class Puppet(DBPuppet, BasePuppet):
         await self.update()
 
     @classmethod
-    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional["Puppet"]:
+    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
         address = cls.get_id_from_mxid(mxid)
         if not address:
             return None
@@ -380,7 +369,7 @@ class Puppet(DBPuppet, BasePuppet):
 
     @classmethod
     @async_getter_lock
-    async def get_by_custom_mxid(cls, mxid: UserID) -> Optional["Puppet"]:
+    async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
         try:
             return cls.by_custom_mxid[mxid]
         except KeyError:
@@ -394,7 +383,7 @@ class Puppet(DBPuppet, BasePuppet):
         return None
 
     @classmethod
-    def get_id_from_mxid(cls, mxid: UserID) -> Optional[Address]:
+    def get_id_from_mxid(cls, mxid: UserID) -> Address | None:
         identifier = cls.mxid_template.parse(mxid)
         if not identifier:
             return None
@@ -418,7 +407,7 @@ class Puppet(DBPuppet, BasePuppet):
 
     @classmethod
     @async_getter_lock
-    async def get_by_address(cls, address: Address, create: bool = True) -> Optional["Puppet"]:
+    async def get_by_address(cls, address: Address, create: bool = True) -> Puppet | None:
         puppet = await cls._get_by_address(address, create)
         if puppet and address.uuid and not puppet.uuid:
             # We found a UUID for this user, store it ASAP
@@ -426,7 +415,7 @@ class Puppet(DBPuppet, BasePuppet):
         return puppet
 
     @classmethod
-    async def _get_by_address(cls, address: Address, create: bool = True) -> Optional["Puppet"]:
+    async def _get_by_address(cls, address: Address, create: bool = True) -> Puppet | None:
         if not address.is_valid:
             raise ValueError("Empty address")
         if address.uuid:
@@ -454,7 +443,7 @@ class Puppet(DBPuppet, BasePuppet):
         return None
 
     @classmethod
-    async def all_with_custom_mxid(cls) -> AsyncGenerator["Puppet", None]:
+    async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
         puppets = await super().all_with_custom_mxid()
         puppet: cls
         for index, puppet in enumerate(puppets):

+ 12 - 14
mautrix_signal/signal.py

@@ -1,5 +1,5 @@
 # mautrix-signal - A Matrix-Signal puppeting bridge
-# Copyright (C) 2020 Tulir Asokan
+# Copyright (C) 2021 Tulir Asokan
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License as published by
@@ -13,7 +13,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import TYPE_CHECKING, List, Optional
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
 import asyncio
 import logging
 
@@ -32,9 +34,7 @@ from mausignald.types import (
     WebsocketConnectionStateChangeEvent,
 )
 
-from . import portal as po
-from . import puppet as pu
-from . import user as u
+from . import portal as po, puppet as pu, user as u
 from .db import Message as DBMessage
 
 if TYPE_CHECKING:
@@ -99,10 +99,10 @@ class SignalHandler(SignaldClient):
 
     async def handle_message(
         self,
-        user: "u.User",
-        sender: "pu.Puppet",
+        user: u.User,
+        sender: pu.Puppet,
         msg: MessageData,
-        addr_override: Optional[Address] = None,
+        addr_override: Address | None = None,
     ) -> None:
         if msg.profile_key_update:
             self.log.debug("Ignoring profile key update")
@@ -123,7 +123,7 @@ class SignalHandler(SignaldClient):
                 return
         if not portal.mxid:
             await portal.create_matrix_room(
-                user, (msg.group_v2 or msg.group or addr_override or sender.address)
+                user, msg.group_v2 or msg.group or addr_override or sender.address
             )
             if not portal.mxid:
                 user.log.debug(
@@ -145,7 +145,7 @@ class SignalHandler(SignaldClient):
             await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
 
     @staticmethod
-    async def handle_own_receipts(sender: "pu.Puppet", receipts: List[OwnReadReceipt]) -> None:
+    async def handle_own_receipts(sender: pu.Puppet, receipts: list[OwnReadReceipt]) -> None:
         for receipt in receipts:
             puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
             if not puppet:
@@ -159,9 +159,7 @@ class SignalHandler(SignaldClient):
             await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
 
     @staticmethod
-    async def handle_typing(
-        user: "u.User", sender: "pu.Puppet", typing: TypingNotification
-    ) -> None:
+    async def handle_typing(user: u.User, sender: pu.Puppet, typing: TypingNotification) -> None:
         if typing.group_id:
             portal = await po.Portal.get_by_chat_id(typing.group_id)
         else:
@@ -174,7 +172,7 @@ class SignalHandler(SignaldClient):
         )
 
     @staticmethod
-    async def handle_receipt(sender: "pu.Puppet", receipt: Receipt) -> None:
+    async def handle_receipt(sender: pu.Puppet, receipt: Receipt) -> None:
         if receipt.type != ReceiptType.READ:
             return
         messages = await DBMessage.find_by_timestamps(receipt.timestamps)

+ 22 - 23
mautrix_signal/user.py

@@ -13,7 +13,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Union, cast
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, AsyncGenerator, cast
 from asyncio.tasks import sleep
 from datetime import datetime
 from uuid import UUID
@@ -35,8 +37,7 @@ from mausignald.types import (
     WebsocketConnectionStateChangeEvent,
 )
 
-from . import portal as po
-from . import puppet as pu
+from . import portal as po, puppet as pu
 from .config import Config
 from .db import User as DBUser
 
@@ -55,9 +56,9 @@ BridgeState.human_readable_errors.update(
 
 
 class User(DBUser, BaseUser):
-    by_mxid: Dict[UserID, "User"] = {}
-    by_username: Dict[str, "User"] = {}
-    by_uuid: Dict[UUID, "User"] = {}
+    by_mxid: dict[UserID, User] = {}
+    by_username: dict[str, User] = {}
+    by_uuid: dict[UUID, User] = {}
     config: Config
     az: AppService
     loop: asyncio.AbstractEventLoop
@@ -70,15 +71,15 @@ class User(DBUser, BaseUser):
     _sync_lock: asyncio.Lock
     _notice_room_lock: asyncio.Lock
     _connected: bool
-    _websocket_connection_state: Optional[BridgeStateEvent]
-    _latest_non_transient_disconnect_state: Optional[datetime]
+    _websocket_connection_state: BridgeStateEvent | None
+    _latest_non_transient_disconnect_state: datetime | None
 
     def __init__(
         self,
         mxid: UserID,
-        username: Optional[str] = None,
-        uuid: Optional[UUID] = None,
-        notice_room: Optional[RoomID] = None,
+        username: str | None = None,
+        uuid: UUID | None = None,
+        notice_room: RoomID | None = None,
     ) -> None:
         super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
         BaseUser.__init__(self)
@@ -97,7 +98,7 @@ class User(DBUser, BaseUser):
         cls.loop = bridge.loop
 
     @property
-    def address(self) -> Optional[Address]:
+    def address(self) -> Address | None:
         if not self.username:
             return None
         return Address(uuid=self.uuid, number=self.username)
@@ -105,7 +106,7 @@ class User(DBUser, BaseUser):
     async def is_logged_in(self) -> bool:
         return bool(self.username)
 
-    async def needs_relay(self, portal: "po.Portal") -> bool:
+    async def needs_relay(self, portal: po.Portal) -> bool:
         return not await self.is_logged_in() or (
             portal.is_direct and portal.receiver != self.username
         )
@@ -136,7 +137,7 @@ class User(DBUser, BaseUser):
             puppet = await self.get_puppet()
             state.remote_name = puppet.name or self.username
 
-    async def get_bridge_states(self) -> List[BridgeState]:
+    async def get_bridge_states(self) -> list[BridgeState]:
         if not self.username:
             return []
         state = BridgeState(state_event=BridgeStateEvent.UNKNOWN_ERROR)
@@ -146,7 +147,7 @@ class User(DBUser, BaseUser):
             state.state_event = BridgeStateEvent.TRANSIENT_DISCONNECT
         return [state]
 
-    async def get_puppet(self) -> Optional["pu.Puppet"]:
+    async def get_puppet(self) -> pu.Puppet | None:
         if not self.address:
             return None
         return await pu.Puppet.get_by_address(self.address)
@@ -268,9 +269,7 @@ class User(DBUser, BaseUser):
         except Exception:
             self.log.exception("Error while syncing groups")
 
-    async def sync_contact(
-        self, contact: Union[Profile, Address], create_portals: bool = False
-    ) -> None:
+    async def sync_contact(self, contact: Profile | Address, create_portals: bool = False) -> None:
         self.log.trace("Syncing contact %s", contact)
         if isinstance(contact, Address):
             address = contact
@@ -337,7 +336,7 @@ class User(DBUser, BaseUser):
 
     @classmethod
     @async_getter_lock
-    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional["User"]:
+    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> User | None:
         # Never allow ghosts to be users
         if pu.Puppet.get_id_from_mxid(mxid):
             return None
@@ -361,7 +360,7 @@ class User(DBUser, BaseUser):
 
     @classmethod
     @async_getter_lock
-    async def get_by_username(cls, username: str) -> Optional["User"]:
+    async def get_by_username(cls, username: str) -> User | None:
         try:
             return cls.by_username[username]
         except KeyError:
@@ -376,7 +375,7 @@ class User(DBUser, BaseUser):
 
     @classmethod
     @async_getter_lock
-    async def get_by_uuid(cls, uuid: UUID) -> Optional["User"]:
+    async def get_by_uuid(cls, uuid: UUID) -> User | None:
         try:
             return cls.by_uuid[uuid]
         except KeyError:
@@ -390,7 +389,7 @@ class User(DBUser, BaseUser):
         return None
 
     @classmethod
-    async def get_by_address(cls, address: Address) -> Optional["User"]:
+    async def get_by_address(cls, address: Address) -> User | None:
         if address.uuid:
             return await cls.get_by_uuid(address.uuid)
         elif address.number:
@@ -399,7 +398,7 @@ class User(DBUser, BaseUser):
             raise ValueError("Given address is blank")
 
     @classmethod
-    async def all_logged_in(cls) -> AsyncGenerator["User", None]:
+    async def all_logged_in(cls) -> AsyncGenerator[User, None]:
         users = await super().all_logged_in()
         user: cls
         for user in users:

+ 1 - 2
mautrix_signal/util/color_log.py

@@ -13,8 +13,7 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from mautrix.util.logging.color import PREFIX, RESET
-from mautrix.util.logging.color import ColorFormatter as BaseColorFormatter
+from mautrix.util.logging.color import PREFIX, RESET, ColorFormatter as BaseColorFormatter
 
 MAUSIGNALD_COLOR = PREFIX + "35;1m"  # magenta
 

+ 7 - 5
mautrix_signal/web/provisioning_api.py

@@ -13,7 +13,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import TYPE_CHECKING, Awaitable, Dict
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
 import asyncio
 import json
 import logging
@@ -53,7 +55,7 @@ class ProvisioningAPI:
         self.app.router.add_post("/api/logout", self.logout)
 
     @property
-    def _acao_headers(self) -> Dict[str, str]:
+    def _acao_headers(self) -> dict[str, str]:
         return {
             "Access-Control-Allow-Origin": "*",
             "Access-Control-Allow-Headers": "Authorization, Content-Type",
@@ -61,7 +63,7 @@ class ProvisioningAPI:
         }
 
     @property
-    def _headers(self) -> Dict[str, str]:
+    def _headers(self) -> dict[str, str]:
         return {
             **self._acao_headers,
             "Content-Type": "application/json",
@@ -163,7 +165,7 @@ class ProvisioningAPI:
             raise
         except Exception:
             self.log.exception(
-                "Fatal error while waiting for linking to finish " f"(session {session_id})"
+                "Fatal error while waiting for linking to finish (session {session_id})"
             )
             raise
         else:
@@ -182,7 +184,7 @@ class ProvisioningAPI:
             account = await asyncio.shield(self._shielded_link(user, session_id, device_name))
         except asyncio.CancelledError:
             self.log.warning(
-                f"Client cancelled link wait request ({session_id})" " before it finished"
+                f"Client cancelled link wait request ({session_id}) before it finished"
             )
         except TimeoutException:
             raise web.HTTPBadRequest(

+ 1 - 0
pyproject.toml

@@ -2,6 +2,7 @@
 profile = "black"
 force_to_top = "typing"
 from_first = true
+combine_as_imports = true
 line_length = 99
 
 [tool.black]