|
@@ -13,8 +13,8 @@
|
|
|
#
|
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
-from typing import (Dict, Tuple, Optional, List, Deque, Set, Any, Union, AsyncGenerator,
|
|
|
- Awaitable, TYPE_CHECKING, cast)
|
|
|
+from typing import (Dict, Tuple, Optional, List, Deque, Any, Union, AsyncGenerator, Awaitable,
|
|
|
+ TYPE_CHECKING, cast)
|
|
|
from collections import deque
|
|
|
from uuid import UUID, uuid4
|
|
|
import mimetypes
|
|
@@ -25,7 +25,7 @@ import time
|
|
|
import os
|
|
|
|
|
|
from mausignald.types import (Address, MessageData, Reaction, Quote, Group, Contact, Profile,
|
|
|
- Attachment)
|
|
|
+ Attachment, GroupID)
|
|
|
from mautrix.appservice import AppService, IntentAPI
|
|
|
from mautrix.bridge import BasePortal
|
|
|
from mautrix.types import (EventID, MessageEventContent, RoomID, EventType, MessageType,
|
|
@@ -57,7 +57,7 @@ ChatInfo = Union[Group, Contact, Profile, Address]
|
|
|
|
|
|
class Portal(DBPortal, BasePortal):
|
|
|
by_mxid: Dict[RoomID, 'Portal'] = {}
|
|
|
- by_chat_id: Dict[Tuple[Union[str, UUID], str], 'Portal'] = {}
|
|
|
+ by_chat_id: Dict[Tuple[str, str], 'Portal'] = {}
|
|
|
config: Config
|
|
|
matrix: 'm.MatrixHandler'
|
|
|
signal: 's.SignalHandler'
|
|
@@ -66,16 +66,16 @@ class Portal(DBPortal, BasePortal):
|
|
|
|
|
|
_main_intent: Optional[IntentAPI]
|
|
|
_create_room_lock: asyncio.Lock
|
|
|
- _msgts_dedup: Deque[Tuple[UUID, int]]
|
|
|
- _reaction_dedup: Deque[Tuple[UUID, int, str]]
|
|
|
+ _msgts_dedup: Deque[Tuple[Address, int]]
|
|
|
+ _reaction_dedup: Deque[Tuple[Address, int, str]]
|
|
|
_reaction_lock: asyncio.Lock
|
|
|
|
|
|
- def __init__(self, chat_id: Union[str, UUID], receiver: str, mxid: Optional[RoomID] = None,
|
|
|
+ def __init__(self, chat_id: Union[GroupID, Address], receiver: str, mxid: Optional[RoomID] = None,
|
|
|
name: Optional[str] = None, avatar_hash: Optional[str] = None,
|
|
|
avatar_url: Optional[ContentURI] = None, encrypted: bool = False) -> None:
|
|
|
super().__init__(chat_id, receiver, mxid, name, avatar_hash, avatar_url, encrypted)
|
|
|
self._create_room_lock = asyncio.Lock()
|
|
|
- self.log = self.log.getChild(str(chat_id))
|
|
|
+ self.log = self.log.getChild(self.chat_id_str)
|
|
|
self._main_intent = None
|
|
|
self._msgts_dedup = deque(maxlen=100)
|
|
|
self._reaction_dedup = deque(maxlen=100)
|
|
@@ -90,14 +90,15 @@ class Portal(DBPortal, BasePortal):
|
|
|
|
|
|
@property
|
|
|
def is_direct(self) -> bool:
|
|
|
- return isinstance(self.chat_id, UUID)
|
|
|
+ return isinstance(self.chat_id, Address)
|
|
|
|
|
|
- @property
|
|
|
- def recipient(self) -> Union[str, Address]:
|
|
|
- if self.is_direct:
|
|
|
- return Address(uuid=self.chat_id)
|
|
|
- else:
|
|
|
- return self.chat_id
|
|
|
+ def handle_uuid_receive(self, uuid: UUID) -> None:
|
|
|
+ if not self.is_direct or self.chat_id.uuid:
|
|
|
+ raise ValueError("handle_uuid_receive can only be used for private chat portals with "
|
|
|
+ "a phone number chat_id")
|
|
|
+ del self.by_chat_id[(self.chat_id_str, self.receiver)]
|
|
|
+ self.chat_id = Address(uuid=uuid)
|
|
|
+ self.by_chat_id[(self.chat_id_str, self.receiver)] = self
|
|
|
|
|
|
@classmethod
|
|
|
def init_cls(cls, bridge: 'SignalBridge') -> None:
|
|
@@ -131,9 +132,10 @@ class Portal(DBPortal, BasePortal):
|
|
|
await existing.edit(emoji=emoji, mxid=mxid, mx_room=message.mx_room)
|
|
|
else:
|
|
|
self.log.debug(f"_upsert_reaction inserting {mxid} (message: {message.mxid})")
|
|
|
- await DBReaction(mxid=mxid, mx_room=message.mx_room, emoji=emoji, author=sender.uuid,
|
|
|
+ await DBReaction(mxid=mxid, mx_room=message.mx_room, emoji=emoji,
|
|
|
signal_chat_id=self.chat_id, signal_receiver=self.receiver,
|
|
|
- msg_author=message.sender, msg_timestamp=message.timestamp).insert()
|
|
|
+ msg_author=message.sender, msg_timestamp=message.timestamp,
|
|
|
+ author=sender.address).insert()
|
|
|
|
|
|
# endregion
|
|
|
# region Matrix event handling
|
|
@@ -168,14 +170,14 @@ class Portal(DBPortal, BasePortal):
|
|
|
self.log.debug(f"Ignoring puppet-sent message by confirmed puppet user {sender.mxid}")
|
|
|
return
|
|
|
request_id = int(time.time() * 1000)
|
|
|
- self._msgts_dedup.appendleft((sender.uuid, request_id))
|
|
|
+ self._msgts_dedup.appendleft((sender.address, request_id))
|
|
|
|
|
|
quote = None
|
|
|
if message.get_reply_to():
|
|
|
reply = await DBMessage.get_by_mxid(message.get_reply_to(), self.mxid)
|
|
|
# TODO include actual text? either store in db or fetch event from homeserver
|
|
|
if reply is not None:
|
|
|
- quote = Quote(id=reply.timestamp, author=Address(uuid=reply.sender), text="")
|
|
|
+ quote = Quote(id=reply.timestamp, author=reply.sender, text="")
|
|
|
|
|
|
text = message.body
|
|
|
attachments: Optional[List[Attachment]] = None
|
|
@@ -188,9 +190,9 @@ class Portal(DBPortal, BasePortal):
|
|
|
attachments = [attachment]
|
|
|
text = None
|
|
|
self.log.trace("Formed outgoing attachment %s", attachment)
|
|
|
- await self.signal.send(username=sender.username, recipient=self.recipient, body=text,
|
|
|
+ await self.signal.send(username=sender.username, recipient=self.chat_id, body=text,
|
|
|
quote=quote, attachments=attachments, timestamp=request_id)
|
|
|
- msg = DBMessage(mxid=event_id, mx_room=self.mxid, sender=sender.uuid, timestamp=request_id,
|
|
|
+ msg = DBMessage(mxid=event_id, mx_room=self.mxid, sender=sender.address, timestamp=request_id,
|
|
|
signal_chat_id=self.chat_id, signal_receiver=self.receiver)
|
|
|
await msg.insert()
|
|
|
await self._send_delivery_receipt(event_id)
|
|
@@ -212,7 +214,7 @@ class Portal(DBPortal, BasePortal):
|
|
|
return
|
|
|
|
|
|
existing = await DBReaction.get_by_signal_id(self.chat_id, self.receiver, message.sender,
|
|
|
- message.timestamp, sender.uuid)
|
|
|
+ message.timestamp, sender.address)
|
|
|
if existing and existing.emoji == emoji:
|
|
|
return
|
|
|
|
|
@@ -220,9 +222,9 @@ class Portal(DBPortal, BasePortal):
|
|
|
self._reaction_dedup.appendleft(dedup_id)
|
|
|
async with self._reaction_lock:
|
|
|
reaction = Reaction(emoji=emoji, remove=False,
|
|
|
- target_author=Address(uuid=message.sender),
|
|
|
+ target_author=message.sender,
|
|
|
target_sent_timestamp=message.timestamp)
|
|
|
- await self.signal.react(username=sender.username, recipient=self.recipient,
|
|
|
+ await self.signal.react(username=sender.username, recipient=self.chat_id,
|
|
|
reaction=reaction)
|
|
|
await self._upsert_reaction(existing, self.main_intent, event_id, sender, message,
|
|
|
emoji)
|
|
@@ -239,9 +241,9 @@ class Portal(DBPortal, BasePortal):
|
|
|
try:
|
|
|
await reaction.delete()
|
|
|
remove_reaction = Reaction(emoji=reaction.emoji, remove=True,
|
|
|
- target_author=Address(uuid=reaction.msg_author),
|
|
|
+ target_author=reaction.msg_author,
|
|
|
target_sent_timestamp=reaction.msg_timestamp)
|
|
|
- await self.signal.react(username=sender.username, recipient=self.recipient,
|
|
|
+ await self.signal.react(username=sender.username, recipient=self.chat_id,
|
|
|
reaction=remove_reaction)
|
|
|
await self._send_delivery_receipt(redaction_event_id)
|
|
|
self.log.trace(f"Removed {reaction} after Matrix redaction")
|
|
@@ -263,21 +265,17 @@ class Portal(DBPortal, BasePortal):
|
|
|
# region Signal event handling
|
|
|
|
|
|
@staticmethod
|
|
|
- async def _find_address_uuid(address: Address) -> Optional[UUID]:
|
|
|
- if address.uuid:
|
|
|
- return address.uuid
|
|
|
+ async def _resolve_address(address: Address) -> Address:
|
|
|
puppet = await p.Puppet.get_by_address(address, create=False)
|
|
|
- if puppet and puppet.uuid:
|
|
|
- return puppet.uuid
|
|
|
- return None
|
|
|
+ return puppet.address
|
|
|
|
|
|
async def _find_quote_event_id(self, quote: Optional[Quote]
|
|
|
) -> Optional[Union[MessageEvent, EventID]]:
|
|
|
if not quote:
|
|
|
return None
|
|
|
|
|
|
- author_uuid = await self._find_address_uuid(quote.author)
|
|
|
- reply_msg = await DBMessage.get_by_signal_id(author_uuid, quote.id,
|
|
|
+ author_address = await self._resolve_address(quote.author)
|
|
|
+ reply_msg = await DBMessage.get_by_signal_id(author_address, quote.id,
|
|
|
self.chat_id, self.receiver)
|
|
|
if not reply_msg:
|
|
|
return None
|
|
@@ -291,13 +289,13 @@ class Portal(DBPortal, BasePortal):
|
|
|
|
|
|
async def handle_signal_message(self, source: 'u.User', sender: 'p.Puppet',
|
|
|
message: MessageData) -> None:
|
|
|
- if (sender.uuid, message.timestamp) in self._msgts_dedup:
|
|
|
+ if (sender.address, message.timestamp) in self._msgts_dedup:
|
|
|
self.log.debug(f"Ignoring message {message.timestamp} by {sender.uuid}"
|
|
|
" as it was already handled (message.timestamp in dedup queue)")
|
|
|
await self.signal.send_receipt(source.username, sender.address,
|
|
|
timestamps=[message.timestamp])
|
|
|
return
|
|
|
- old_message = await DBMessage.get_by_signal_id(sender.uuid, message.timestamp,
|
|
|
+ old_message = await DBMessage.get_by_signal_id(sender.address, message.timestamp,
|
|
|
self.chat_id, self.receiver)
|
|
|
if old_message is not None:
|
|
|
self.log.debug(f"Ignoring message {message.timestamp} by {sender.uuid}"
|
|
@@ -307,7 +305,7 @@ class Portal(DBPortal, BasePortal):
|
|
|
return
|
|
|
self.log.debug(f"Started handling message {message.timestamp} by {sender.uuid}")
|
|
|
self.log.trace(f"Message content: {message}")
|
|
|
- self._msgts_dedup.appendleft((sender.uuid, message.timestamp))
|
|
|
+ self._msgts_dedup.appendleft((sender.address, message.timestamp))
|
|
|
intent = sender.intent_for(self)
|
|
|
await intent.set_typing(self.mxid, False)
|
|
|
event_id = None
|
|
@@ -345,7 +343,7 @@ class Portal(DBPortal, BasePortal):
|
|
|
|
|
|
if event_id:
|
|
|
msg = DBMessage(mxid=event_id, mx_room=self.mxid,
|
|
|
- sender=sender.uuid, timestamp=message.timestamp,
|
|
|
+ sender=sender.address, timestamp=message.timestamp,
|
|
|
signal_chat_id=self.chat_id, signal_receiver=self.receiver)
|
|
|
await msg.insert()
|
|
|
await self.signal.send_receipt(source.username, sender.address,
|
|
@@ -404,20 +402,16 @@ class Portal(DBPortal, BasePortal):
|
|
|
return content
|
|
|
|
|
|
async def handle_signal_reaction(self, sender: 'p.Puppet', reaction: Reaction) -> None:
|
|
|
- author_uuid = await self._find_address_uuid(reaction.target_author)
|
|
|
+ author_address = await self._resolve_address(reaction.target_author)
|
|
|
target_id = reaction.target_sent_timestamp
|
|
|
- if author_uuid is None:
|
|
|
- self.log.warning(f"Failed to handle reaction from {sender.uuid}: "
|
|
|
- f"couldn't find UUID of {reaction.target_author}")
|
|
|
- return
|
|
|
async with self._reaction_lock:
|
|
|
- dedup_id = (author_uuid, target_id, reaction.emoji)
|
|
|
+ dedup_id = (author_address, target_id, reaction.emoji)
|
|
|
if dedup_id in self._reaction_dedup:
|
|
|
return
|
|
|
self._reaction_dedup.appendleft(dedup_id)
|
|
|
|
|
|
existing = await DBReaction.get_by_signal_id(self.chat_id, self.receiver,
|
|
|
- author_uuid, target_id, sender.uuid)
|
|
|
+ author_address, target_id, sender.address)
|
|
|
|
|
|
if reaction.remove:
|
|
|
if existing:
|
|
@@ -431,7 +425,7 @@ class Portal(DBPortal, BasePortal):
|
|
|
elif existing and existing.emoji == reaction.emoji:
|
|
|
return
|
|
|
|
|
|
- message = await DBMessage.get_by_signal_id(author_uuid, target_id,
|
|
|
+ message = await DBMessage.get_by_signal_id(author_address, target_id,
|
|
|
self.chat_id, self.receiver)
|
|
|
if not message:
|
|
|
self.log.debug(f"Ignoring reaction to unknown message {target_id}")
|
|
@@ -440,7 +434,7 @@ class Portal(DBPortal, BasePortal):
|
|
|
intent = sender.intent_for(self)
|
|
|
# TODO add variation selectors to emoji before sending to Matrix
|
|
|
mxid = await intent.react(message.mx_room, message.mxid, reaction.emoji)
|
|
|
- self.log.debug(f"{sender.uuid} reacted to {message.mxid} -> {mxid}")
|
|
|
+ self.log.debug(f"{sender.address} reacted to {message.mxid} -> {mxid}")
|
|
|
await self._upsert_reaction(existing, intent, mxid, sender, message, reaction.emoji)
|
|
|
|
|
|
# endregion
|
|
@@ -451,7 +445,7 @@ class Portal(DBPortal, BasePortal):
|
|
|
if not isinstance(info, (Contact, Profile, Address)):
|
|
|
raise ValueError(f"Unexpected type for direct chat update_info: {type(info)}")
|
|
|
if not self.name:
|
|
|
- puppet = await p.Puppet.get_by_address(Address(uuid=self.chat_id))
|
|
|
+ puppet = await p.Puppet.get_by_address(self.chat_id)
|
|
|
if not puppet.name:
|
|
|
await puppet.update_info(info)
|
|
|
self.name = puppet.name
|
|
@@ -619,18 +613,18 @@ class Portal(DBPortal, BasePortal):
|
|
|
if self.config["bridge.encryption.default"] and self.matrix.e2ee:
|
|
|
self.encrypted = True
|
|
|
initial_state.append({
|
|
|
- "type": "m.room.encryption",
|
|
|
+ "type": str(EventType.ROOM_ENCRYPTION),
|
|
|
"content": {"algorithm": "m.megolm.v1.aes-sha2"},
|
|
|
})
|
|
|
if self.is_direct:
|
|
|
invites.append(self.az.bot_mxid)
|
|
|
- if source.uuid == self.chat_id:
|
|
|
+ if self.is_direct and source.address == self.chat_id:
|
|
|
name = self.name = "Signal Note to Self"
|
|
|
elif self.encrypted or self.private_chat_portal_meta or not self.is_direct:
|
|
|
name = self.name
|
|
|
if self.avatar_url:
|
|
|
initial_state.append({
|
|
|
- "type": "m.room.avatar",
|
|
|
+ "type": str(EventType.ROOM_AVATAR),
|
|
|
"content": {"url": self.avatar_url},
|
|
|
})
|
|
|
if self.config["appservice.community_id"]:
|
|
@@ -638,10 +632,9 @@ class Portal(DBPortal, BasePortal):
|
|
|
"type": "m.room.related_groups",
|
|
|
"content": {"groups": [self.config["appservice.community_id"]]},
|
|
|
})
|
|
|
- #Allow chaning of room avatar and name in direct chats
|
|
|
if self.is_direct:
|
|
|
initial_state.append({
|
|
|
- "type": "m.room.power_levels",
|
|
|
+ "type": str(EventType.ROOM_POWER_LEVELS),
|
|
|
"content": {"users": {self.main_intent.mxid: 100},
|
|
|
"events": {"m.room.avatar": 0, "m.room.name": 0}}
|
|
|
})
|
|
@@ -689,7 +682,7 @@ class Portal(DBPortal, BasePortal):
|
|
|
if self.mxid:
|
|
|
self.by_mxid[self.mxid] = self
|
|
|
if self.is_direct:
|
|
|
- puppet = await p.Puppet.get_by_address(Address(uuid=self.chat_id))
|
|
|
+ puppet = await p.Puppet.get_by_address(self.chat_id)
|
|
|
self._main_intent = puppet.default_mxid_intent
|
|
|
elif not self.is_direct:
|
|
|
self._main_intent = self.az.intent
|
|
@@ -709,7 +702,7 @@ class Portal(DBPortal, BasePortal):
|
|
|
return cls._db_to_portals(super().all_with_room())
|
|
|
|
|
|
@classmethod
|
|
|
- def find_private_chats_with(cls, other_user: UUID) -> AsyncGenerator['Portal', None]:
|
|
|
+ def find_private_chats_with(cls, other_user: Address) -> AsyncGenerator['Portal', None]:
|
|
|
return cls._db_to_portals(super().find_private_chats_with(other_user))
|
|
|
|
|
|
@classmethod
|
|
@@ -718,7 +711,7 @@ class Portal(DBPortal, BasePortal):
|
|
|
portals = await query
|
|
|
for index, portal in enumerate(portals):
|
|
|
try:
|
|
|
- yield cls.by_chat_id[(portal.chat_id, portal.receiver)]
|
|
|
+ yield cls.by_chat_id[(portal.chat_id_str, portal.receiver)]
|
|
|
except KeyError:
|
|
|
await portal._postinit()
|
|
|
yield portal
|
|
@@ -738,14 +731,17 @@ class Portal(DBPortal, BasePortal):
|
|
|
return None
|
|
|
|
|
|
@classmethod
|
|
|
- async def get_by_chat_id(cls, chat_id: Union[UUID, str], receiver: str = "",
|
|
|
+ async def get_by_chat_id(cls, chat_id: Union[GroupID, Address], receiver: str = "",
|
|
|
create: bool = False) -> Optional['Portal']:
|
|
|
if isinstance(chat_id, str):
|
|
|
receiver = ""
|
|
|
+ elif not isinstance(chat_id, Address):
|
|
|
+ raise ValueError(f"Invalid chat ID type {type(chat_id)}")
|
|
|
elif not receiver:
|
|
|
raise ValueError("Direct chats must have a receiver")
|
|
|
try:
|
|
|
- return cls.by_chat_id[(chat_id, receiver)]
|
|
|
+ best_id = chat_id.best_identifier if isinstance(chat_id, Address) else chat_id
|
|
|
+ return cls.by_chat_id[(best_id, receiver)]
|
|
|
except KeyError:
|
|
|
pass
|
|
|
|