Browse Source

Add locks in cached async get methods. Might fix #48

Tulir Asokan 4 years ago
parent
commit
c4e8463212
4 changed files with 13 additions and 6 deletions
  1. 4 2
      mautrix_signal/portal.py
  2. 3 1
      mautrix_signal/puppet.py
  3. 5 2
      mautrix_signal/user.py
  4. 1 1
      requirements.txt

+ 4 - 2
mautrix_signal/portal.py

@@ -27,7 +27,7 @@ import os
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
                               Attachment, GroupID, GroupV2ID, GroupV2, Mention, Sticker)
                               Attachment, GroupID, GroupV2ID, GroupV2, Mention, Sticker)
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.appservice import AppService, IntentAPI
-from mautrix.bridge import BasePortal
+from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
                            MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
                            MessageEvent, EncryptedEvent, ContentURI, MediaMessageEventContent,
                            ImageInfo, VideoInfo, FileInfo, AudioInfo)
                            ImageInfo, VideoInfo, FileInfo, AudioInfo)
@@ -929,6 +929,7 @@ class Portal(DBPortal, BasePortal):
                 yield portal
                 yield portal
 
 
     @classmethod
     @classmethod
+    @async_getter_lock
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
     async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
         try:
         try:
             return cls.by_mxid[mxid]
             return cls.by_mxid[mxid]
@@ -943,7 +944,8 @@ class Portal(DBPortal, BasePortal):
         return None
         return None
 
 
     @classmethod
     @classmethod
-    async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = "",
+    @async_getter_lock
+    async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], *, receiver: str = "",
                              create: bool = False) -> Optional['Portal']:
                              create: bool = False) -> Optional['Portal']:
         if isinstance(chat_id, str):
         if isinstance(chat_id, str):
             receiver = ""
             receiver = ""

+ 3 - 1
mautrix_signal/puppet.py

@@ -23,7 +23,7 @@ import os.path
 from yarl import URL
 from yarl import URL
 
 
 from mausignald.types import Address, Contact, Profile
 from mausignald.types import Address, Contact, Profile
-from mautrix.bridge import BasePuppet
+from mautrix.bridge import BasePuppet, async_getter_lock
 from mautrix.appservice import IntentAPI
 from mautrix.appservice import IntentAPI
 from mautrix.types import UserID, SyncToken, RoomID, ContentURI
 from mautrix.types import UserID, SyncToken, RoomID, ContentURI
 from mautrix.errors import MForbidden
 from mautrix.errors import MForbidden
@@ -286,6 +286,7 @@ class Puppet(DBPuppet, BasePuppet):
         return await cls.get_by_address(address, create)
         return await cls.get_by_address(address, create)
 
 
     @classmethod
     @classmethod
+    @async_getter_lock
     async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
     async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
         try:
         try:
             return cls.by_custom_mxid[mxid]
             return cls.by_custom_mxid[mxid]
@@ -323,6 +324,7 @@ class Puppet(DBPuppet, BasePuppet):
         return UserID(cls.mxid_template.format_full(identifier))
         return UserID(cls.mxid_template.format_full(identifier))
 
 
     @classmethod
     @classmethod
+    @async_getter_lock
     async def get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
     async def get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
         puppet = await cls._get_by_address(address, create)
         puppet = await cls._get_by_address(address, create)
         if puppet and address.uuid and not puppet.uuid:
         if puppet and address.uuid and not puppet.uuid:

+ 5 - 2
mautrix_signal/user.py

@@ -21,7 +21,7 @@ import os.path
 import shutil
 import shutil
 
 
 from mausignald.types import Account, Address, Contact, Group, GroupV2, ListenEvent, ListenAction
 from mausignald.types import Account, Address, Contact, Group, GroupV2, ListenEvent, ListenAction
-from mautrix.bridge import BaseUser
+from mautrix.bridge import BaseUser, async_getter_lock
 from mautrix.types import UserID, RoomID
 from mautrix.types import UserID, RoomID
 from mautrix.appservice import AppService
 from mautrix.appservice import AppService
 from mautrix.util.opt_prometheus import Gauge
 from mautrix.util.opt_prometheus import Gauge
@@ -148,7 +148,8 @@ class User(DBUser, BaseUser):
             profile = None
             profile = None
         await puppet.update_info(profile or contact)
         await puppet.update_info(profile or contact)
         if create_portals:
         if create_portals:
-            portal = await po.Portal.get_by_chat_id(puppet.address, self.username, create=True)
+            portal = await po.Portal.get_by_chat_id(puppet.address, receiver=self.username,
+                                                    create=True)
             await portal.create_matrix_room(self, profile or contact)
             await portal.create_matrix_room(self, profile or contact)
 
 
     async def _sync_group(self, group: Group, create_portals: bool) -> None:
     async def _sync_group(self, group: Group, create_portals: bool) -> None:
@@ -197,6 +198,7 @@ class User(DBUser, BaseUser):
             self.by_username[self.username] = self
             self.by_username[self.username] = self
 
 
     @classmethod
     @classmethod
+    @async_getter_lock
     async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
     async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
         # Never allow ghosts to be users
         # Never allow ghosts to be users
         if pu.Puppet.get_id_from_mxid(mxid):
         if pu.Puppet.get_id_from_mxid(mxid):
@@ -220,6 +222,7 @@ class User(DBUser, BaseUser):
         return None
         return None
 
 
     @classmethod
     @classmethod
+    @async_getter_lock
     async def get_by_username(cls, username: str) -> Optional['User']:
     async def get_by_username(cls, username: str) -> Optional['User']:
         try:
         try:
             return cls.by_username[username]
             return cls.by_username[username]

+ 1 - 1
requirements.txt

@@ -4,5 +4,5 @@ commonmark>=0.8,<0.10
 aiohttp>=3,<4
 aiohttp>=3,<4
 yarl>=1,<2
 yarl>=1,<2
 attrs>=19.1
 attrs>=19.1
-mautrix>=0.8.11,<0.9
+mautrix>=0.8.13,<0.9
 asyncpg>=0.20,<0.22
 asyncpg>=0.20,<0.22