|
@@ -28,7 +28,7 @@ from mauigpapi.types import (Thread, ThreadUser, ThreadItem, RegularMediaItem, M
|
|
VoiceMediaItem, ExpiredMediaItem, MessageSyncMessage, ReelShareType,
|
|
VoiceMediaItem, ExpiredMediaItem, MessageSyncMessage, ReelShareType,
|
|
TypingStatus)
|
|
TypingStatus)
|
|
from mautrix.appservice import AppService, IntentAPI
|
|
from mautrix.appservice import AppService, IntentAPI
|
|
-from mautrix.bridge import BasePortal, NotificationDisabler
|
|
|
|
|
|
+from mautrix.bridge import BasePortal, NotificationDisabler, async_getter_lock
|
|
from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, ImageInfo,
|
|
from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, ImageInfo,
|
|
VideoInfo, MediaMessageEventContent, TextMessageEventContent, AudioInfo,
|
|
VideoInfo, MediaMessageEventContent, TextMessageEventContent, AudioInfo,
|
|
ContentURI, EncryptedFile, LocationMessageEventContent, Format, UserID)
|
|
ContentURI, EncryptedFile, LocationMessageEventContent, Format, UserID)
|
|
@@ -899,6 +899,7 @@ class Portal(DBPortal, BasePortal):
|
|
yield portal
|
|
yield portal
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
|
|
+ @async_getter_lock
|
|
async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
|
|
async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
|
|
try:
|
|
try:
|
|
return cls.by_mxid[mxid]
|
|
return cls.by_mxid[mxid]
|
|
@@ -913,7 +914,8 @@ class Portal(DBPortal, BasePortal):
|
|
return None
|
|
return None
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
- async def get_by_thread_id(cls, thread_id: str, receiver: int,
|
|
|
|
|
|
+ @async_getter_lock
|
|
|
|
+ async def get_by_thread_id(cls, thread_id: str, *, receiver: int,
|
|
is_group: Optional[bool] = None,
|
|
is_group: Optional[bool] = None,
|
|
other_user_pk: Optional[int] = None) -> Optional['Portal']:
|
|
other_user_pk: Optional[int] = None) -> Optional['Portal']:
|
|
if is_group and receiver != 0:
|
|
if is_group and receiver != 0:
|
|
@@ -928,7 +930,7 @@ class Portal(DBPortal, BasePortal):
|
|
except KeyError:
|
|
except KeyError:
|
|
pass
|
|
pass
|
|
|
|
|
|
- portal = cast(cls, await super().get_by_thread_id(thread_id, receiver,
|
|
|
|
|
|
+ portal = cast(cls, await super().get_by_thread_id(thread_id, receiver=receiver,
|
|
rec_must_match=is_group is not None))
|
|
rec_must_match=is_group is not None))
|
|
if portal is not None:
|
|
if portal is not None:
|
|
await portal.postinit()
|
|
await portal.postinit()
|
|
@@ -952,6 +954,6 @@ class Portal(DBPortal, BasePortal):
|
|
other_user_pk = receiver
|
|
other_user_pk = receiver
|
|
else:
|
|
else:
|
|
other_user_pk = thread.users[0].pk
|
|
other_user_pk = thread.users[0].pk
|
|
- return await cls.get_by_thread_id(thread.thread_id, receiver, is_group=thread.is_group,
|
|
|
|
- other_user_pk=other_user_pk)
|
|
|
|
|
|
+ return await cls.get_by_thread_id(thread.thread_id, receiver=receiver,
|
|
|
|
+ is_group=thread.is_group, other_user_pk=other_user_pk)
|
|
# endregion
|
|
# endregion
|