Explorar el Código

Use new wrapper for creating background tasks

Tulir Asokan hace 2 años
padre
commit
fb10822a7f

+ 3 - 2
mausignald/rpc.py

@@ -11,6 +11,7 @@ import asyncio
 import json
 import logging
 
+from mautrix.util import background_task
 from mautrix.util.logging import TraceLogger
 
 from .errors import NotConnected, UnexpectedError, UnexpectedResponse, make_response_error
@@ -90,7 +91,7 @@ class SignaldRPCClient:
 
         read_loop = asyncio.create_task(self._try_read_loop())
         self.is_connected = True
-        asyncio.create_task(self._run_rpc_handler(CONNECT_EVENT, {}))
+        background_task.create(self._run_rpc_handler(CONNECT_EVENT, {}))
         self._connect_future.set_result(True)
 
         await read_loop
@@ -164,7 +165,7 @@ class SignaldRPCClient:
 
         req_id = req.get("id")
         if req_id is None:
-            asyncio.create_task(self._run_rpc_handler(req_type, req))
+            background_task.create(self._run_rpc_handler(req_type, req))
         else:
             self._run_response_handlers(UUID(req_id), req_type, req)
 

+ 2 - 2
mautrix_signal/commands/signal.py

@@ -16,7 +16,6 @@
 from __future__ import annotations
 
 from typing import Awaitable
-import asyncio
 import base64
 import json
 
@@ -33,6 +32,7 @@ from mautrix.types import (
     PowerLevelStateEventContent,
     RoomID,
 )
+from mautrix.util import background_task
 
 from .. import portal as po, puppet as pu
 from ..util import normalize_number, user_has_power_level
@@ -533,7 +533,7 @@ async def _locked_confirm_bridge(
     await portal.save()
     await portal.update_bridge_info()
 
-    asyncio.create_task(portal.update_matrix_room(evt.sender, group))
+    background_task.create(portal.update_matrix_room(evt.sender, group))
 
     await warn_missing_power(levels, evt)
 

+ 14 - 14
mautrix_signal/portal.py

@@ -83,7 +83,7 @@ from mautrix.types import (
     UserID,
     VideoInfo,
 )
-from mautrix.util import ffmpeg, variation_selector
+from mautrix.util import background_task, ffmpeg, variation_selector
 from mautrix.util.format_duration import format_duration
 from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus
 
@@ -320,7 +320,7 @@ class Portal(DBPortal, BasePortal):
             )
             await sender.handle_auth_failure(e)
             await self._send_error_notice("message", e)
-            asyncio.create_task(self._send_message_status(event_id, e))
+            background_task.create(self._send_message_status(event_id, e))
 
     async def _send_error_notice(self, type_name: str, err: Exception) -> None:
         if not self.config["bridge.delivery_error_reports"]:
@@ -513,7 +513,7 @@ class Portal(DBPortal, BasePortal):
             dm = DisappearingMessage(self.mxid, event_id, self.expiration_time)
             dm.start_timer()
             await dm.insert()
-            asyncio.create_task(self._disappear_event(dm))
+            background_task.create(self._disappear_event(dm))
 
         sender.send_remote_checkpoint(
             MessageSendCheckpointStatus.SUCCESS,
@@ -524,7 +524,7 @@ class Portal(DBPortal, BasePortal):
             retry_num=retry_count,
         )
         await self._send_delivery_receipt(event_id)
-        asyncio.create_task(self._send_message_status(event_id, err=None))
+        background_task.create(self._send_message_status(event_id, err=None))
 
     async def _signal_send_with_retries(
         self,
@@ -607,7 +607,7 @@ class Portal(DBPortal, BasePortal):
             )
             await self._send_error_notice("reaction", e)
             await sender.handle_auth_failure(e)
-            asyncio.create_task(self._send_message_status(event_id, e))
+            background_task.create(self._send_message_status(event_id, e))
         else:
             sender.send_remote_checkpoint(
                 MessageSendCheckpointStatus.SUCCESS,
@@ -617,7 +617,7 @@ class Portal(DBPortal, BasePortal):
                 retry_num=retry_count,
             )
             await self._send_delivery_receipt(event_id)
-            asyncio.create_task(self._send_message_status(event_id, err=None))
+            background_task.create(self._send_message_status(event_id, err=None))
 
     async def _handle_matrix_reaction(
         self,
@@ -683,8 +683,8 @@ class Portal(DBPortal, BasePortal):
                     error=e,
                 )
                 await sender.handle_auth_failure(e)
-                asyncio.create_task(self._send_error_notice("message deletion", e))
-                asyncio.create_task(self._send_message_status(event_id, e))
+                background_task.create(self._send_error_notice("message deletion", e))
+                background_task.create(self._send_message_status(event_id, e))
             else:
                 self.log.trace(f"Removed {message} after Matrix redaction")
                 sender.send_remote_checkpoint(
@@ -694,7 +694,7 @@ class Portal(DBPortal, BasePortal):
                     EventType.ROOM_REDACTION,
                 )
                 await self._send_delivery_receipt(redaction_event_id)
-                asyncio.create_task(self._send_message_status(redaction_event_id, err=None))
+                background_task.create(self._send_message_status(redaction_event_id, err=None))
             return
 
         reaction = await DBReaction.get_by_mxid(event_id, self.mxid)
@@ -723,8 +723,8 @@ class Portal(DBPortal, BasePortal):
                     error=e,
                 )
                 await sender.handle_auth_failure(e)
-                asyncio.create_task(self._send_error_notice("reaction deletion", e))
-                asyncio.create_task(self._send_message_status(event_id, e))
+                background_task.create(self._send_error_notice("reaction deletion", e))
+                background_task.create(self._send_message_status(event_id, e))
             else:
                 self.log.trace(f"Removed {reaction} after Matrix redaction")
                 sender.send_remote_checkpoint(
@@ -734,7 +734,7 @@ class Portal(DBPortal, BasePortal):
                     EventType.ROOM_REDACTION,
                 )
                 await self._send_delivery_receipt(redaction_event_id)
-                asyncio.create_task(self._send_message_status(redaction_event_id, err=None))
+                background_task.create(self._send_message_status(redaction_event_id, err=None))
             return
 
         sender.send_remote_checkpoint(
@@ -745,7 +745,7 @@ class Portal(DBPortal, BasePortal):
             error="No message or reaction found for redaction",
         )
         status_err = UnknownReactionTarget("No message or reaction found for redaction")
-        asyncio.create_task(self._send_message_status(redaction_event_id, err=status_err))
+        background_task.create(self._send_message_status(redaction_event_id, err=status_err))
 
     async def handle_matrix_join(self, user: u.User) -> None:
         if self.is_direct or not await user.is_logged_in():
@@ -1314,7 +1314,7 @@ class Portal(DBPortal, BasePortal):
                 if sender.uuid == source.uuid:
                     dm.start_timer()
                     await dm.insert()
-                    asyncio.create_task(self._disappear_event(dm))
+                    background_task.create(self._disappear_event(dm))
                     self.log.debug(
                         f"{event_id} set to be redacted in {message.expires_in_seconds} seconds"
                     )

+ 2 - 1
mautrix_signal/puppet.py

@@ -36,6 +36,7 @@ from mautrix.types import (
     SyncToken,
     UserID,
 )
+from mautrix.util import background_task
 from mautrix.util.simple_template import SimpleTemplate
 
 from . import portal as p, signal, user as u
@@ -204,7 +205,7 @@ class Puppet(DBPuppet, BasePuppet):
                 update = await self._update_avatar(f"contact-{self.number}") or update
             if update:
                 await self.update()
-                asyncio.create_task(self._try_update_portal_meta())
+                background_task.create(self._try_update_portal_meta())
 
     @staticmethod
     def fmt_phone(number: str) -> str:

+ 3 - 2
mautrix_signal/signal.py

@@ -37,6 +37,7 @@ from mausignald.types import (
     WebsocketConnectionStateChangeEvent,
 )
 from mautrix.types import EventID, EventType, Format, MessageType, TextMessageEventContent
+from mautrix.util import background_task
 from mautrix.util.logging import TraceLogger
 from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus
 
@@ -216,7 +217,7 @@ class SignalHandler(SignaldClient):
         addr_override: Address | None = None,
     ) -> None:
         if msg.profile_key_update:
-            asyncio.create_task(user.sync_contact(sender.address, use_cache=False))
+            background_task.create(user.sync_contact(sender.address, use_cache=False))
             return
         if msg.group_v2:
             portal = await po.Portal.get_by_chat_id(msg.group_v2.id, create=True)
@@ -413,7 +414,7 @@ class SignalHandler(SignaldClient):
                 self.log.info(
                     f"Successfully subscribed {user.username}, running sync in background"
                 )
-                asyncio.create_task(user.sync())
+                background_task.create(user.sync())
         if self.delete_unknown_accounts:
             self.log.debug("Checking for unknown accounts to delete")
             for account in await self.list_accounts():

+ 6 - 5
mautrix_signal/user.py

@@ -34,6 +34,7 @@ from mausignald.types import (
 from mautrix.appservice import AppService
 from mautrix.bridge import AutologinError, BaseUser, async_getter_lock
 from mautrix.types import EventType, RoomID, UserID
+from mautrix.util import background_task
 from mautrix.util.bridge_state import BridgeState, BridgeStateEvent
 from mautrix.util.message_send_checkpoint import MessageSendCheckpointStatus
 from mautrix.util.opt_prometheus import Gauge
@@ -168,7 +169,7 @@ class User(DBUser, BaseUser):
         self._add_to_cache()
         await self.update()
         await self.bridge.signal.subscribe(self.username)
-        asyncio.create_task(self.sync())
+        background_task.create(self.sync())
         self._track_metric(METRIC_LOGGED_IN, True)
 
     def on_websocket_connection_state_change(
@@ -224,7 +225,7 @@ class User(DBUser, BaseUser):
                     self._latest_non_transient_bridge_state
                     and now > self._latest_non_transient_bridge_state
                 ):
-                    asyncio.create_task(self.push_bridge_state(bridge_state))
+                    background_task.create(self.push_bridge_state(bridge_state))
 
                 self._websocket_connection_state = bridge_state
 
@@ -236,7 +237,7 @@ class User(DBUser, BaseUser):
                     self._latest_non_transient_bridge_state
                     and now > self._latest_non_transient_bridge_state
                 ):
-                    asyncio.create_task(
+                    background_task.create(
                         self.push_bridge_state(
                             BridgeStateEvent.UNKNOWN_ERROR,
                             message="Failed to restore connection to Signal",
@@ -249,12 +250,12 @@ class User(DBUser, BaseUser):
                         "not transitioning to UNKNOWN_ERROR."
                     )
 
-            asyncio.create_task(wait_report_bridge_state())
+            background_task.create(wait_report_bridge_state())
         elif self._websocket_connection_state == bridge_state:
             self.log.info("Websocket state unchanged, not reporting new bridge state")
             self._latest_non_transient_bridge_state = now
         else:
-            asyncio.create_task(self.push_bridge_state(bridge_state))
+            background_task.create(self.push_bridge_state(bridge_state))
             self._latest_non_transient_bridge_state = now
             self._websocket_connection_state = bridge_state
 

+ 3 - 2
mautrix_signal/web/segment_analytics.py

@@ -1,11 +1,12 @@
 from __future__ import annotations
 
-import asyncio
 import logging
 
 from yarl import URL
 import aiohttp
 
+from mautrix.util import background_task
+
 from .. import user as u
 
 log = logging.getLogger("mau.web.public.analytics")
@@ -30,7 +31,7 @@ async def _track(user: u.User, event: str, properties: dict) -> None:
 
 def track(user: u.User, event: str, properties: dict | None = None):
     if segment_key:
-        asyncio.create_task(_track(user, event, properties or {}))
+        background_task.create(_track(user, event, properties or {}))
 
 
 def init(key, user_id: str | None = None):

+ 1 - 1
requirements.txt

@@ -4,5 +4,5 @@ commonmark>=0.8,<0.10
 aiohttp>=3,<4
 yarl>=1,<2
 attrs>=19.1
-mautrix>=0.19.3,<0.20
+mautrix>=0.19.4,<0.20
 asyncpg>=0.20,<0.28