puppet.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
  18. from uuid import UUID
  19. import asyncio
  20. import hashlib
  21. import os.path
  22. from mautrix.appservice import IntentAPI
  23. from mautrix.bridge import BasePuppet, async_getter_lock
  24. from mautrix.errors import MForbidden
  25. from mautrix.types import (
  26. ContentURI,
  27. EventType,
  28. PowerLevelStateEventContent,
  29. RoomID,
  30. SyncToken,
  31. UserID,
  32. )
  33. from mautrix.util.simple_template import SimpleTemplate
  34. from yarl import URL
  35. from mausignald.types import Address, Contact, Profile
  36. from . import portal as p, user as u
  37. from .config import Config
  38. from .db import Puppet as DBPuppet
  39. if TYPE_CHECKING:
  40. from .__main__ import SignalBridge
  41. try:
  42. import phonenumbers
  43. except ImportError:
  44. phonenumbers = None
  45. class Puppet(DBPuppet, BasePuppet):
  46. by_uuid: dict[UUID, Puppet] = {}
  47. by_number: dict[str, Puppet] = {}
  48. by_custom_mxid: dict[UserID, Puppet] = {}
  49. hs_domain: str
  50. mxid_template: SimpleTemplate[str]
  51. config: Config
  52. default_mxid_intent: IntentAPI
  53. default_mxid: UserID
  54. _uuid_lock: asyncio.Lock
  55. _update_info_lock: asyncio.Lock
  56. def __init__(
  57. self,
  58. uuid: UUID | None,
  59. number: str | None,
  60. name: str | None = None,
  61. avatar_url: ContentURI | None = None,
  62. avatar_hash: str | None = None,
  63. name_set: bool = False,
  64. avatar_set: bool = False,
  65. uuid_registered: bool = False,
  66. number_registered: bool = False,
  67. custom_mxid: UserID | None = None,
  68. access_token: str | None = None,
  69. next_batch: SyncToken | None = None,
  70. base_url: URL | None = None,
  71. ) -> None:
  72. super().__init__(
  73. uuid=uuid,
  74. number=number,
  75. name=name,
  76. avatar_url=avatar_url,
  77. avatar_hash=avatar_hash,
  78. name_set=name_set,
  79. avatar_set=avatar_set,
  80. uuid_registered=uuid_registered,
  81. number_registered=number_registered,
  82. custom_mxid=custom_mxid,
  83. access_token=access_token,
  84. next_batch=next_batch,
  85. base_url=base_url,
  86. )
  87. self.log = self.log.getChild(str(uuid) if uuid else number)
  88. self.default_mxid = self.get_mxid_from_id(self.address)
  89. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  90. self.intent = self._fresh_intent()
  91. self._uuid_lock = asyncio.Lock()
  92. self._update_info_lock = asyncio.Lock()
  93. @classmethod
  94. def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
  95. cls.config = bridge.config
  96. cls.loop = bridge.loop
  97. cls.mx = bridge.matrix
  98. cls.az = bridge.az
  99. cls.hs_domain = cls.config["homeserver.domain"]
  100. cls.mxid_template = SimpleTemplate(
  101. cls.config["bridge.username_template"],
  102. "userid",
  103. prefix="@",
  104. suffix=f":{cls.hs_domain}",
  105. type=str,
  106. )
  107. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  108. cls.homeserver_url_map = {
  109. server: URL(url)
  110. for server, url in cls.config["bridge.double_puppet_server_map"].items()
  111. }
  112. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  113. cls.login_shared_secret_map = {
  114. server: secret.encode("utf-8")
  115. for server, secret in cls.config["bridge.login_shared_secret_map"].items()
  116. }
  117. cls.login_device_name = "Signal Bridge"
  118. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  119. def intent_for(self, portal: p.Portal) -> IntentAPI:
  120. if portal.chat_id == self.address:
  121. return self.default_mxid_intent
  122. return self.intent
  123. @property
  124. def is_registered(self) -> bool:
  125. return self.uuid_registered if self.uuid is not None else self.number_registered
  126. @is_registered.setter
  127. def is_registered(self, value: bool) -> None:
  128. if self.uuid is not None:
  129. self.uuid_registered = value
  130. else:
  131. self.number_registered = value
  132. @property
  133. def address(self) -> Address:
  134. return Address(uuid=self.uuid, number=self.number)
  135. async def handle_uuid_receive(self, uuid: UUID) -> None:
  136. async with self._uuid_lock:
  137. if self.uuid:
  138. # Received UUID was handled while this call was waiting
  139. return
  140. await self._handle_uuid_receive(uuid)
  141. async def handle_number_receive(self, number: str) -> None:
  142. async with self._uuid_lock:
  143. if self.number:
  144. return
  145. self.number = number
  146. self.by_number[self.number] = self
  147. await self._set_number(number)
  148. async for portal in p.Portal.find_private_chats_with(Address(number=number)):
  149. self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
  150. portal.handle_uuid_receive(self.uuid)
  151. prev_mxid = self.get_mxid_from_id(Address(number=number))
  152. if await self.az.state_store.is_registered(prev_mxid):
  153. prev_intent = self.az.intent.user(prev_mxid)
  154. await self._migrate_memberships(prev_intent, self.default_mxid_intent)
  155. async def _handle_uuid_receive(self, uuid: UUID) -> None:
  156. self.log.debug(f"Found UUID for user: {uuid}")
  157. user = await u.User.get_by_username(self.number)
  158. if user and not user.uuid:
  159. user.uuid = self.uuid
  160. user.by_uuid[user.uuid] = user
  161. await user.update()
  162. self.uuid = uuid
  163. self.by_uuid[self.uuid] = self
  164. await self._set_uuid(uuid)
  165. async for portal in p.Portal.find_private_chats_with(Address(number=self.number)):
  166. self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
  167. portal.handle_uuid_receive(self.uuid)
  168. prev_intent = self.default_mxid_intent
  169. self.default_mxid = self.get_mxid_from_id(self.address)
  170. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  171. self.intent = self._fresh_intent()
  172. await self.default_mxid_intent.ensure_registered()
  173. if self.name:
  174. await self.default_mxid_intent.set_displayname(self.name)
  175. self.log = Puppet.log.getChild(str(uuid))
  176. await self._migrate_memberships(prev_intent, self.default_mxid_intent)
  177. async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
  178. self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
  179. try:
  180. joined_rooms = await prev_intent.get_joined_rooms()
  181. except MForbidden as e:
  182. self.log.debug(
  183. f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
  184. "assuming there are no rooms to rejoin"
  185. )
  186. return
  187. for room_id in joined_rooms:
  188. await prev_intent.invite_user(room_id, self.default_mxid)
  189. await self._migrate_powers(prev_intent, new_intent, room_id)
  190. await prev_intent.leave_room(room_id)
  191. await new_intent.join_room_by_id(room_id)
  192. async def _migrate_powers(
  193. self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
  194. ) -> None:
  195. try:
  196. powers: PowerLevelStateEventContent
  197. powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
  198. user_level = powers.get_user_level(prev_intent.mxid)
  199. pl_state_level = powers.get_event_level(EventType.ROOM_POWER_LEVELS)
  200. if user_level >= pl_state_level > powers.users_default:
  201. powers.ensure_user_level(new_intent.mxid, user_level)
  202. await prev_intent.send_state_event(room_id, EventType.ROOM_POWER_LEVELS, powers)
  203. except Exception:
  204. self.log.warning("Failed to migrate power levels", exc_info=True)
  205. async def update_info(self, info: Profile | Contact | Address) -> None:
  206. address = info.address if isinstance(info, (Contact, Profile)) else info
  207. if address.uuid and not self.uuid:
  208. await self.handle_uuid_receive(address.uuid)
  209. if address.number and not self.number:
  210. await self.handle_number_receive(address.number)
  211. contact_names = self.config["bridge.contact_list_names"]
  212. if isinstance(info, Profile) and contact_names != "prefer" and info.profile_name:
  213. name = info.profile_name
  214. elif isinstance(info, (Contact, Profile)) and contact_names != "disallow":
  215. name = info.name
  216. if not name and isinstance(info, Profile) and info.profile_name:
  217. # Contact list name is preferred, but was not found, fall back to profile
  218. name = info.profile_name
  219. else:
  220. name = None
  221. async with self._update_info_lock:
  222. update = False
  223. if name is not None or self.name is None:
  224. update = await self._update_name(name) or update
  225. if isinstance(info, Profile):
  226. update = await self._update_avatar(info.avatar) or update
  227. elif contact_names != "disallow" and self.number:
  228. update = await self._update_avatar(f"contact-{self.number}") or update
  229. if update:
  230. await self.update()
  231. asyncio.create_task(self._update_portal_meta())
  232. @staticmethod
  233. def fmt_phone(number: str) -> str:
  234. if phonenumbers is None:
  235. return number
  236. parsed = phonenumbers.parse(number)
  237. fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
  238. return phonenumbers.format_number(parsed, fmt)
  239. @classmethod
  240. def _get_displayname(cls, address: Address, name: str | None) -> str:
  241. names = name.split("\x00") if name else []
  242. data = {
  243. "first_name": names[0] if len(names) > 0 else "",
  244. "last_name": names[-1] if len(names) > 1 else "",
  245. "full_name": " ".join(names),
  246. "phone": cls.fmt_phone(address.number) if address.number else None,
  247. "uuid": str(address.uuid) if address.uuid else None,
  248. "displayname": "Unknown user",
  249. }
  250. for pref in cls.config["bridge.displayname_preference"]:
  251. value = data.get(pref.replace(" ", "_"))
  252. if value:
  253. data["displayname"] = value
  254. break
  255. return cls.config["bridge.displayname_template"].format(**data)
  256. async def _update_name(self, name: str | None) -> bool:
  257. name = self._get_displayname(self.address, name)
  258. if name != self.name or not self.name_set:
  259. self.name = name
  260. try:
  261. await self.default_mxid_intent.set_displayname(self.name)
  262. self.name_set = True
  263. except Exception:
  264. self.log.exception("Error setting displayname")
  265. self.name_set = False
  266. return True
  267. return False
  268. @staticmethod
  269. async def upload_avatar(
  270. self: Puppet | p.Portal, path: str, intent: IntentAPI
  271. ) -> bool | tuple[str, ContentURI]:
  272. if not path:
  273. return False
  274. if not path.startswith("/"):
  275. path = os.path.join(self.config["signal.avatar_dir"], path)
  276. try:
  277. with open(path, "rb") as file:
  278. data = file.read()
  279. except FileNotFoundError:
  280. return False
  281. if not data:
  282. return False
  283. new_hash = hashlib.sha256(data).hexdigest()
  284. if self.avatar_set and new_hash == self.avatar_hash:
  285. return False
  286. mxc = await intent.upload_media(data)
  287. return new_hash, mxc
  288. async def _update_avatar(self, path: str) -> bool:
  289. res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
  290. if res is False:
  291. return False
  292. self.avatar_hash, self.avatar_url = res
  293. try:
  294. await self.default_mxid_intent.set_avatar_url(self.avatar_url)
  295. self.avatar_set = True
  296. except Exception:
  297. self.log.exception("Error setting avatar")
  298. self.avatar_set = False
  299. return True
  300. async def _update_portal_meta(self) -> None:
  301. async for portal in p.Portal.find_private_chats_with(self.address):
  302. if portal.receiver == self.number:
  303. # This is a note to self chat, don't change the name
  304. continue
  305. try:
  306. await portal.update_puppet_name(self.name)
  307. await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
  308. except Exception:
  309. self.log.exception(f"Error updating portal meta for {portal.receiver}")
  310. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  311. portal = await p.Portal.get_by_mxid(room_id)
  312. return portal and portal.chat_id != self.uuid
  313. # region Database getters
  314. def _add_to_cache(self) -> None:
  315. if self.uuid:
  316. self.by_uuid[self.uuid] = self
  317. if self.number:
  318. self.by_number[self.number] = self
  319. if self.custom_mxid:
  320. self.by_custom_mxid[self.custom_mxid] = self
  321. async def save(self) -> None:
  322. await self.update()
  323. @classmethod
  324. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
  325. address = cls.get_id_from_mxid(mxid)
  326. if not address:
  327. return None
  328. return await cls.get_by_address(address, create)
  329. @classmethod
  330. @async_getter_lock
  331. async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
  332. try:
  333. return cls.by_custom_mxid[mxid]
  334. except KeyError:
  335. pass
  336. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  337. if puppet:
  338. puppet._add_to_cache()
  339. return puppet
  340. return None
  341. @classmethod
  342. def get_id_from_mxid(cls, mxid: UserID) -> Address | None:
  343. identifier = cls.mxid_template.parse(mxid)
  344. if not identifier:
  345. return None
  346. if identifier.startswith("phone_"):
  347. return Address(number="+" + identifier[len("phone_") :])
  348. else:
  349. try:
  350. return Address(uuid=UUID(identifier.upper()))
  351. except ValueError:
  352. return None
  353. @classmethod
  354. def get_mxid_from_id(cls, address: Address) -> UserID:
  355. if address.uuid:
  356. identifier = str(address.uuid).lower()
  357. elif address.number:
  358. identifier = f"phone_{address.number.lstrip('+')}"
  359. else:
  360. raise ValueError("Empty address")
  361. return UserID(cls.mxid_template.format_full(identifier))
  362. @classmethod
  363. @async_getter_lock
  364. async def get_by_address(cls, address: Address, create: bool = True) -> Puppet | None:
  365. puppet = await cls._get_by_address(address, create)
  366. if puppet and address.uuid and not puppet.uuid:
  367. # We found a UUID for this user, store it ASAP
  368. await puppet.handle_uuid_receive(address.uuid)
  369. return puppet
  370. @classmethod
  371. async def _get_by_address(cls, address: Address, create: bool = True) -> Puppet | None:
  372. if not address.is_valid:
  373. raise ValueError("Empty address")
  374. if address.uuid:
  375. try:
  376. return cls.by_uuid[address.uuid]
  377. except KeyError:
  378. pass
  379. if address.number:
  380. try:
  381. return cls.by_number[address.number]
  382. except KeyError:
  383. pass
  384. puppet = cast(cls, await super().get_by_address(address))
  385. if puppet is not None:
  386. puppet._add_to_cache()
  387. return puppet
  388. if create:
  389. puppet = cls(address.uuid, address.number)
  390. await puppet.insert()
  391. puppet._add_to_cache()
  392. return puppet
  393. return None
  394. @classmethod
  395. async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
  396. puppets = await super().all_with_custom_mxid()
  397. puppet: cls
  398. for index, puppet in enumerate(puppets):
  399. try:
  400. yield cls.by_uuid[puppet.uuid]
  401. except KeyError:
  402. try:
  403. yield cls.by_number[puppet.number]
  404. except KeyError:
  405. puppet._add_to_cache()
  406. yield puppet
  407. # endregion