Ver código fonte

Add locks in cached async get methods

Tulir Asokan 4 anos atrás
pai
commit
fe4657a269

+ 1 - 1
mautrix_instagram/db/portal.py

@@ -83,7 +83,7 @@ class Portal:
     @classmethod
     async def find_private_chats_with(cls, other_user: int) -> List['Portal']:
         q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, encrypted FROM portal "
-             "WHERE other_user=$1")
+             "WHERE other_user_pk=$1")
         rows = await cls.db.fetch(q, other_user)
         return [cls._from_row(row) for row in rows]
 

+ 7 - 5
mautrix_instagram/portal.py

@@ -28,7 +28,7 @@ from mauigpapi.types import (Thread, ThreadUser, ThreadItem, RegularMediaItem, M
                              VoiceMediaItem, ExpiredMediaItem, MessageSyncMessage, ReelShareType,
                              TypingStatus)
 from mautrix.appservice import AppService, IntentAPI
-from mautrix.bridge import BasePortal, NotificationDisabler
+from mautrix.bridge import BasePortal, NotificationDisabler, async_getter_lock
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, ImageInfo,
                            VideoInfo, MediaMessageEventContent, TextMessageEventContent, AudioInfo,
                            ContentURI, EncryptedFile, LocationMessageEventContent, Format, UserID)
@@ -899,6 +899,7 @@ class Portal(DBPortal, BasePortal):
                 yield portal
 
     @classmethod
+    @async_getter_lock
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
         try:
             return cls.by_mxid[mxid]
@@ -913,7 +914,8 @@ class Portal(DBPortal, BasePortal):
         return None
 
     @classmethod
-    async def get_by_thread_id(cls, thread_id: str, receiver: int,
+    @async_getter_lock
+    async def get_by_thread_id(cls, thread_id: str, *, receiver: int,
                                is_group: Optional[bool] = None,
                                other_user_pk: Optional[int] = None) -> Optional['Portal']:
         if is_group and receiver != 0:
@@ -928,7 +930,7 @@ class Portal(DBPortal, BasePortal):
             except KeyError:
                 pass
 
-        portal = cast(cls, await super().get_by_thread_id(thread_id, receiver,
+        portal = cast(cls, await super().get_by_thread_id(thread_id, receiver=receiver,
                                                           rec_must_match=is_group is not None))
         if portal is not None:
             await portal.postinit()
@@ -952,6 +954,6 @@ class Portal(DBPortal, BasePortal):
                 other_user_pk = receiver
             else:
                 other_user_pk = thread.users[0].pk
-        return await cls.get_by_thread_id(thread.thread_id, receiver, is_group=thread.is_group,
-                                          other_user_pk=other_user_pk)
+        return await cls.get_by_thread_id(thread.thread_id, receiver=receiver,
+                                          is_group=thread.is_group, other_user_pk=other_user_pk)
     # endregion

+ 4 - 2
mautrix_instagram/puppet.py

@@ -19,7 +19,7 @@ import os.path
 from yarl import URL
 
 from mauigpapi.types import BaseResponseUser
-from mautrix.bridge import BasePuppet
+from mautrix.bridge import BasePuppet, async_getter_lock
 from mautrix.appservice import IntentAPI
 from mautrix.types import ContentURI, UserID, SyncToken, RoomID
 from mautrix.util.simple_template import SimpleTemplate
@@ -160,6 +160,7 @@ class Puppet(DBPuppet, BasePuppet):
         return None
 
     @classmethod
+    @async_getter_lock
     async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
         try:
             return cls.by_custom_mxid[mxid]
@@ -182,7 +183,8 @@ class Puppet(DBPuppet, BasePuppet):
         return UserID(cls.mxid_template.format_full(twid))
 
     @classmethod
-    async def get_by_pk(cls, pk: int, create: bool = True) -> Optional['Puppet']:
+    @async_getter_lock
+    async def get_by_pk(cls, pk: int, *, create: bool = True) -> Optional['Puppet']:
         try:
             return cls.by_pk[pk]
         except KeyError:

+ 4 - 2
mautrix_instagram/user.py

@@ -25,7 +25,7 @@ from mauigpapi.mqtt import Connect, Disconnect, GraphQLSubscription, SkywalkerSu
 from mauigpapi.types import (CurrentUser, MessageSyncEvent, Operation, RealtimeDirectEvent,
                              ActivityIndicatorData, TypingStatus, ThreadSyncEvent)
 from mauigpapi.errors import IGNotLoggedInError, MQTTNotLoggedIn, MQTTNotConnected
-from mautrix.bridge import BaseUser
+from mautrix.bridge import BaseUser, async_getter_lock
 from mautrix.types import UserID, RoomID, EventID, TextMessageEventContent, MessageType
 from mautrix.appservice import AppService
 from mautrix.util.opt_prometheus import Summary, Gauge, async_time
@@ -388,7 +388,8 @@ class User(DBUser, BaseUser):
             self.by_igpk[self.igpk] = self
 
     @classmethod
-    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
+    @async_getter_lock
+    async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> Optional['User']:
         # Never allow ghosts to be users
         if pu.Puppet.get_id_from_mxid(mxid):
             return None
@@ -411,6 +412,7 @@ class User(DBUser, BaseUser):
         return None
 
     @classmethod
+    @async_getter_lock
     async def get_by_igpk(cls, igpk: int) -> Optional['User']:
         try:
             return cls.by_igpk[igpk]

+ 1 - 1
requirements.txt

@@ -4,7 +4,7 @@ commonmark>=0.8,<0.10
 aiohttp>=3,<4
 yarl>=1,<2
 attrs>=19.1
-mautrix>=0.8.11,<0.9
+mautrix>=0.8.13,<0.9
 asyncpg>=0.20,<0.22
 pycryptodome>=3,<4
 paho-mqtt>=1.5,<2