signal.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, List, TYPE_CHECKING
  17. import asyncio
  18. import logging
  19. from mausignald import SignaldClient
  20. from mausignald.types import (
  21. Message,
  22. MessageData,
  23. Address,
  24. TypingNotification,
  25. TypingAction,
  26. OwnReadReceipt,
  27. Receipt,
  28. ReceiptType,
  29. WebsocketConnectionStateChangeEvent,
  30. )
  31. from mautrix.util.logging import TraceLogger
  32. from .db import Message as DBMessage
  33. from . import user as u, portal as po, puppet as pu
  34. if TYPE_CHECKING:
  35. from .__main__ import SignalBridge
  36. # Typing notifications seem to get resent every 10 seconds and the timeout is around 15 seconds
  37. SIGNAL_TYPING_TIMEOUT = 15000
  38. class SignalHandler(SignaldClient):
  39. log: TraceLogger = logging.getLogger("mau.signal")
  40. loop: asyncio.AbstractEventLoop
  41. data_dir: str
  42. delete_unknown_accounts: bool
  43. def __init__(self, bridge: "SignalBridge") -> None:
  44. super().__init__(bridge.config["signal.socket_path"], loop=bridge.loop)
  45. self.data_dir = bridge.config["signal.data_dir"]
  46. self.delete_unknown_accounts = bridge.config["signal.delete_unknown_accounts_on_start"]
  47. self.add_event_handler(Message, self.on_message)
  48. self.add_event_handler(
  49. WebsocketConnectionStateChangeEvent, self.on_websocket_connection_state_change
  50. )
  51. async def on_message(self, evt: Message) -> None:
  52. sender = await pu.Puppet.get_by_address(evt.source)
  53. user = await u.User.get_by_username(evt.username)
  54. # TODO add lots of logging
  55. if evt.data_message:
  56. await self.handle_message(user, sender, evt.data_message)
  57. if evt.typing:
  58. await self.handle_typing(user, sender, evt.typing)
  59. if evt.receipt:
  60. await self.handle_receipt(sender, evt.receipt)
  61. if evt.sync_message:
  62. if evt.sync_message.read_messages:
  63. await self.handle_own_receipts(sender, evt.sync_message.read_messages)
  64. if evt.sync_message.sent:
  65. await self.handle_message(
  66. user,
  67. sender,
  68. evt.sync_message.sent.message,
  69. addr_override=evt.sync_message.sent.destination,
  70. )
  71. if evt.sync_message.typing:
  72. # Typing notification from own device
  73. pass
  74. if evt.sync_message.contacts or evt.sync_message.contacts_complete:
  75. self.log.debug("Sync message includes contacts meta, syncing contacts...")
  76. await user.sync_contacts()
  77. if evt.sync_message.groups:
  78. self.log.debug("Sync message includes groups meta, syncing groups...")
  79. await user.sync_groups()
  80. @staticmethod
  81. async def on_websocket_connection_state_change(
  82. evt: WebsocketConnectionStateChangeEvent,
  83. ) -> None:
  84. user = await u.User.get_by_username(evt.account)
  85. user.on_websocket_connection_state_change(evt)
  86. async def handle_message(
  87. self,
  88. user: "u.User",
  89. sender: "pu.Puppet",
  90. msg: MessageData,
  91. addr_override: Optional[Address] = None,
  92. ) -> None:
  93. if msg.profile_key_update:
  94. self.log.debug("Ignoring profile key update")
  95. return
  96. if msg.group_v2:
  97. portal = await po.Portal.get_by_chat_id(msg.group_v2.id, create=True)
  98. elif msg.group:
  99. portal = await po.Portal.get_by_chat_id(msg.group.group_id, create=True)
  100. else:
  101. portal = await po.Portal.get_by_chat_id(
  102. addr_override or sender.address, receiver=user.username, create=True
  103. )
  104. if addr_override and not sender.is_real_user:
  105. portal.log.debug(
  106. f"Ignoring own message {msg.timestamp} as user doesn't have double puppeting "
  107. "enabled"
  108. )
  109. return
  110. if not portal.mxid:
  111. await portal.create_matrix_room(
  112. user, (msg.group_v2 or msg.group or addr_override or sender.address)
  113. )
  114. if not portal.mxid:
  115. user.log.debug(
  116. f"Failed to create room for incoming message {msg.timestamp}, dropping message"
  117. )
  118. return
  119. elif msg.group_v2 and msg.group_v2.revision > portal.revision:
  120. self.log.debug(f"Got new revision of {msg.group_v2.id}, updating info")
  121. await portal.update_info(user, msg.group_v2, sender)
  122. if msg.reaction:
  123. await portal.handle_signal_reaction(sender, msg.reaction, msg.timestamp)
  124. if msg.body or msg.attachments or msg.sticker:
  125. await portal.handle_signal_message(user, sender, msg)
  126. if msg.group and msg.group.type == "UPDATE":
  127. await portal.update_info(user, msg.group)
  128. if msg.remote_delete:
  129. await portal.handle_signal_delete(sender, msg.remote_delete.target_sent_timestamp)
  130. if msg.expires_in_seconds is not None:
  131. await portal.update_expires_in_seconds(sender, msg.expires_in_seconds)
  132. @staticmethod
  133. async def handle_own_receipts(sender: "pu.Puppet", receipts: List[OwnReadReceipt]) -> None:
  134. for receipt in receipts:
  135. puppet = await pu.Puppet.get_by_address(receipt.sender, create=False)
  136. if not puppet:
  137. continue
  138. message = await DBMessage.find_by_sender_timestamp(puppet.address, receipt.timestamp)
  139. if not message:
  140. continue
  141. portal = await po.Portal.get_by_mxid(message.mx_room)
  142. if not portal or (portal.is_direct and not sender.is_real_user):
  143. continue
  144. await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
  145. @staticmethod
  146. async def handle_typing(
  147. user: "u.User", sender: "pu.Puppet", typing: TypingNotification
  148. ) -> None:
  149. if typing.group_id:
  150. portal = await po.Portal.get_by_chat_id(typing.group_id)
  151. else:
  152. portal = await po.Portal.get_by_chat_id(sender.address, receiver=user.username)
  153. if not portal or not portal.mxid:
  154. return
  155. is_typing = typing.action == TypingAction.STARTED
  156. await sender.intent_for(portal).set_typing(
  157. portal.mxid, is_typing, ignore_cache=True, timeout=SIGNAL_TYPING_TIMEOUT
  158. )
  159. @staticmethod
  160. async def handle_receipt(sender: "pu.Puppet", receipt: Receipt) -> None:
  161. if receipt.type != ReceiptType.READ:
  162. return
  163. messages = await DBMessage.find_by_timestamps(receipt.timestamps)
  164. for message in messages:
  165. portal = await po.Portal.get_by_mxid(message.mx_room)
  166. await sender.intent_for(portal).mark_read(portal.mxid, message.mxid)
  167. async def start(self) -> None:
  168. await self.connect()
  169. known_usernames = set()
  170. async for user in u.User.all_logged_in():
  171. # TODO report errors to user?
  172. known_usernames.add(user.username)
  173. if await self.subscribe(user.username):
  174. asyncio.create_task(user.sync())
  175. if self.delete_unknown_accounts:
  176. self.log.debug("Checking for unknown accounts to delete")
  177. for account in await self.list_accounts():
  178. if account.account_id not in known_usernames:
  179. self.log.warning(f"Unknown account ID {account.account_id}, deleting...")
  180. await self.delete_account(account.account_id)
  181. async def stop(self) -> None:
  182. await self.disconnect()