浏览代码

Fix portals not being stored in by_chat_id cache correctly

Tulir Asokan 3 年之前
父节点
当前提交
d7c81ea81a
共有 1 个文件被更改,包括 4 次插入3 次删除
  1. 4 3
      mautrix_signal/portal.py

+ 4 - 3
mautrix_signal/portal.py

@@ -30,7 +30,7 @@ import os
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
 from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
                               Attachment, GroupID, GroupV2ID, GroupV2, Mention, Sticker,
                               Attachment, GroupID, GroupV2ID, GroupV2, Mention, Sticker,
                               GroupAccessControl, AccessControlMode, GroupMemberRole)
                               GroupAccessControl, AccessControlMode, GroupMemberRole)
-from mausignald.errors import AuthorizationFailedException, RPCError, ResponseError
+from mausignald.errors import RPCError, ResponseError
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.appservice import AppService, IntentAPI
 from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.bridge import BasePortal, async_getter_lock
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, Format,
 from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType, Format,
@@ -44,6 +44,7 @@ from mautrix.errors import MatrixError, MForbidden, IntentError
 from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
 from .db import Portal as DBPortal, Message as DBMessage, Reaction as DBReaction
 from .config import Config
 from .config import Config
 from .formatter import matrix_to_signal, signal_to_matrix
 from .formatter import matrix_to_signal, signal_to_matrix
+from .util import id_to_str
 from . import user as u, puppet as p, matrix as m, signal as s
 from . import user as u, puppet as p, matrix as m, signal as s
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -1185,7 +1186,7 @@ class Portal(DBPortal, BasePortal):
     # region Database getters
     # region Database getters
 
 
     async def _postinit(self) -> None:
     async def _postinit(self) -> None:
-        self.by_chat_id[(self.chat_id, self.receiver)] = self
+        self.by_chat_id[(self.chat_id_str, self.receiver)] = self
         if self.mxid:
         if self.mxid:
             self.by_mxid[self.mxid] = self
             self.by_mxid[self.mxid] = self
         if self.is_direct:
         if self.is_direct:
@@ -1247,7 +1248,7 @@ class Portal(DBPortal, BasePortal):
             raise ValueError(f"Invalid chat ID type {type(chat_id)}")
             raise ValueError(f"Invalid chat ID type {type(chat_id)}")
         elif not receiver:
         elif not receiver:
             raise ValueError("Direct chats must have a receiver")
             raise ValueError("Direct chats must have a receiver")
-        best_id = chat_id.best_identifier if isinstance(chat_id, Address) else chat_id
+        best_id = id_to_str(chat_id)
         portal = await cls._get_by_chat_id(best_id, receiver, create=create, chat_id=chat_id)
         portal = await cls._get_by_chat_id(best_id, receiver, create=create, chat_id=chat_id)
         if portal:
         if portal:
             portal.log.debug(f"get_by_chat_id({chat_id}, {receiver}) -> {hex(id(portal))}")
             portal.log.debug(f"get_by_chat_id({chat_id}, {receiver}) -> {hex(id(portal))}")