浏览代码

Add locks in cached async get methods

Tulir Asokan 4 年之前
父节点
当前提交
fe4657a269
共有 5 个文件被更改,包括 17 次插入11 次删除
  1. 1 1
      mautrix_instagram/db/portal.py
  2. 7 5
      mautrix_instagram/portal.py
  3. 4 2
      mautrix_instagram/puppet.py
  4. 4 2
      mautrix_instagram/user.py
  5. 1 1
      requirements.txt

+ 1 - 1
mautrix_instagram/db/portal.py

@@ -83,7 +83,7 @@ class Portal:
     @classmethod
     @classmethod
     async def find_private_chats_with(cls, other_user: int) -> List['Portal']:
     async def find_private_chats_with(cls, other_user: int) -> List['Portal']:
         q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, encrypted FROM portal "
         q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, encrypted FROM portal "
-             "WHERE other_user=$1")
+             "WHERE other_user_pk=$1")
         rows = await cls.db.fetch(q, other_user)
         rows = await cls.db.fetch(q, other_user)
         return [cls._from_row(row) for row in rows]
         return [cls._from_row(row) for row in rows]
 
 

+ 7 - 5
mautrix_instagram/portal.py

@@ -28,7 +28,7 @@ from mauigpapi.types import (Thread, ThreadUser, ThreadItem, RegularMediaItem, M
                              VoiceMediaItem, ExpiredMediaItem, MessageSyncMessage, ReelShareType,
                              VoiceMediaItem, ExpiredMediaItem, MessageSyncMessage, ReelShareType,
                              TypingStatus)
                              TypingStatus)
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.appservice import AppService, IntentAPI
-from mautrix.bridge import BasePortal, NotificationDisabler
+from mautrix.bridge import BasePortal, NotificationDisabler, async_getter_lock
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, ImageInfo,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, ImageInfo,
                            VideoInfo, MediaMessageEventContent, TextMessageEventContent, AudioInfo,
                            VideoInfo, MediaMessageEventContent, TextMessageEventContent, AudioInfo,
                            ContentURI, EncryptedFile, LocationMessageEventContent, Format, UserID)
                            ContentURI, EncryptedFile, LocationMessageEventContent, Format, UserID)
@@ -899,6 +899,7 @@ class Portal(DBPortal, BasePortal):
                 yield portal
                 yield portal
 
 
     @classmethod
     @classmethod
+    @async_getter_lock
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
         try:
         try:
             return cls.by_mxid[mxid]
             return cls.by_mxid[mxid]
@@ -913,7 +914,8 @@ class Portal(DBPortal, BasePortal):
         return None
         return None
 
 
     @classmethod
     @classmethod
-    async def get_by_thread_id(cls, thread_id: str, receiver: int,
+    @async_getter_lock
+    async def get_by_thread_id(cls, thread_id: str, *, receiver: int,
                                is_group: Optional[bool] = None,
                                is_group: Optional[bool] = None,
                                other_user_pk: Optional[int] = None) -> Optional['Portal']:
                                other_user_pk: Optional[int] = None) -> Optional['Portal']:
         if is_group and receiver != 0:
         if is_group and receiver != 0:
@@ -928,7 +930,7 @@ class Portal(DBPortal, BasePortal):
             except KeyError:
             except KeyError:
                 pass
                 pass
 
 
-        portal = cast(cls, await super().get_by_thread_id(thread_id, receiver,
+        portal = cast(cls, await super().get_by_thread_id(thread_id, receiver=receiver,
                                                           rec_must_match=is_group is not None))
                                                           rec_must_match=is_group is not None))
         if portal is not None:
         if portal is not None:
             await portal.postinit()
             await portal.postinit()
@@ -952,6 +954,6 @@ class Portal(DBPortal, BasePortal):
                 other_user_pk = receiver
                 other_user_pk = receiver
             else:
             else:
                 other_user_pk = thread.users[0].pk
                 other_user_pk = thread.users[0].pk
-        return await cls.get_by_thread_id(thread.thread_id, receiver, is_group=thread.is_group,
-                                          other_user_pk=other_user_pk)
+        return await cls.get_by_thread_id(thread.thread_id, receiver=receiver,
+                                          is_group=thread.is_group, other_user_pk=other_user_pk)
     # endregion
     # endregion

+ 4 - 2
mautrix_instagram/puppet.py

@@ -19,7 +19,7 @@ import os.path
 from yarl import URL
 from yarl import URL
 
 
 from mauigpapi.types import BaseResponseUser
 from mauigpapi.types import BaseResponseUser
-from mautrix.bridge import BasePuppet
+from mautrix.bridge import BasePuppet, async_getter_lock
 from mautrix.appservice import IntentAPI
 from mautrix.appservice import IntentAPI
 from mautrix.types import ContentURI, UserID, SyncToken, RoomID
 from mautrix.types import ContentURI, UserID, SyncToken, RoomID
 from mautrix.util.simple_template import SimpleTemplate
 from mautrix.util.simple_template import SimpleTemplate
@@ -160,6 +160,7 @@ class Puppet(DBPuppet, BasePuppet):
         return None
         return None
 
 
     @classmethod
     @classmethod
+    @async_getter_lock
     async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
     async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
         try:
         try:
             return cls.by_custom_mxid[mxid]
             return cls.by_custom_mxid[mxid]
@@ -182,7 +183,8 @@ class Puppet(DBPuppet, BasePuppet):
         return UserID(cls.mxid_template.format_full(twid))
         return UserID(cls.mxid_template.format_full(twid))
 
 
     @classmethod
     @classmethod
-    async def get_by_pk(cls, pk: int, create: bool = True) -> Optional['Puppet']:
+    @async_getter_lock
+    async def get_by_pk(cls, pk: int, *, create: bool = True) -> Optional['Puppet']:
         try:
         try:
             return cls.by_pk[pk]
             return cls.by_pk[pk]
         except KeyError:
         except KeyError:

+ 4 - 2
mautrix_instagram/user.py

@@ -25,7 +25,7 @@ from mauigpapi.mqtt import Connect, Disconnect, GraphQLSubscription, SkywalkerSu
 from mauigpapi.types import (CurrentUser, MessageSyncEvent, Operation, RealtimeDirectEvent,
 from mauigpapi.types import (CurrentUser, MessageSyncEvent, Operation, RealtimeDirectEvent,
                              ActivityIndicatorData, TypingStatus, ThreadSyncEvent)
                              ActivityIndicatorData, TypingStatus, ThreadSyncEvent)
 from mauigpapi.errors import IGNotLoggedInError, MQTTNotLoggedIn, MQTTNotConnected
 from mauigpapi.errors import IGNotLoggedInError, MQTTNotLoggedIn, MQTTNotConnected
-from mautrix.bridge import BaseUser
+from mautrix.bridge import BaseUser, async_getter_lock
 from mautrix.types import UserID, RoomID, EventID, TextMessageEventContent, MessageType
 from mautrix.types import UserID, RoomID, EventID, TextMessageEventContent, MessageType
 from mautrix.appservice import AppService
 from mautrix.appservice import AppService
 from mautrix.util.opt_prometheus import Summary, Gauge, async_time
 from mautrix.util.opt_prometheus import Summary, Gauge, async_time
@@ -388,7 +388,8 @@ class User(DBUser, BaseUser):
             self.by_igpk[self.igpk] = self
             self.by_igpk[self.igpk] = self
 
 
     @classmethod
     @classmethod
-    async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
+    @async_getter_lock
+    async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> Optional['User']:
         # Never allow ghosts to be users
         # Never allow ghosts to be users
         if pu.Puppet.get_id_from_mxid(mxid):
         if pu.Puppet.get_id_from_mxid(mxid):
             return None
             return None
@@ -411,6 +412,7 @@ class User(DBUser, BaseUser):
         return None
         return None
 
 
     @classmethod
     @classmethod
+    @async_getter_lock
     async def get_by_igpk(cls, igpk: int) -> Optional['User']:
     async def get_by_igpk(cls, igpk: int) -> Optional['User']:
         try:
         try:
             return cls.by_igpk[igpk]
             return cls.by_igpk[igpk]

+ 1 - 1
requirements.txt

@@ -4,7 +4,7 @@ commonmark>=0.8,<0.10
 aiohttp>=3,<4
 aiohttp>=3,<4
 yarl>=1,<2
 yarl>=1,<2
 attrs>=19.1
 attrs>=19.1
-mautrix>=0.8.11,<0.9
+mautrix>=0.8.13,<0.9
 asyncpg>=0.20,<0.22
 asyncpg>=0.20,<0.22
 pycryptodome>=3,<4
 pycryptodome>=3,<4
 paho-mqtt>=1.5,<2
 paho-mqtt>=1.5,<2