puppet.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
  18. from uuid import UUID
  19. import asyncio
  20. import hashlib
  21. import os.path
  22. from yarl import URL
  23. from mausignald.types import Address, Profile
  24. from mautrix.appservice import IntentAPI
  25. from mautrix.bridge import BasePuppet, async_getter_lock
  26. from mautrix.errors import MForbidden
  27. from mautrix.types import (
  28. ContentURI,
  29. EventType,
  30. PowerLevelStateEventContent,
  31. RoomID,
  32. SyncToken,
  33. UserID,
  34. )
  35. from mautrix.util.simple_template import SimpleTemplate
  36. from . import portal as p, user as u
  37. from .config import Config
  38. from .db import Puppet as DBPuppet
  39. if TYPE_CHECKING:
  40. from .__main__ import SignalBridge
  41. try:
  42. import phonenumbers
  43. except ImportError:
  44. phonenumbers = None
  45. class Puppet(DBPuppet, BasePuppet):
  46. by_uuid: dict[UUID, Puppet] = {}
  47. by_number: dict[str, Puppet] = {}
  48. by_custom_mxid: dict[UserID, Puppet] = {}
  49. hs_domain: str
  50. mxid_template: SimpleTemplate[str]
  51. config: Config
  52. default_mxid_intent: IntentAPI
  53. default_mxid: UserID
  54. _uuid_lock: asyncio.Lock
  55. _update_info_lock: asyncio.Lock
  56. def __init__(
  57. self,
  58. uuid: UUID | None,
  59. number: str | None,
  60. name: str | None = None,
  61. name_quality: int = 0,
  62. avatar_url: ContentURI | None = None,
  63. avatar_hash: str | None = None,
  64. name_set: bool = False,
  65. avatar_set: bool = False,
  66. uuid_registered: bool = False,
  67. number_registered: bool = False,
  68. custom_mxid: UserID | None = None,
  69. access_token: str | None = None,
  70. next_batch: SyncToken | None = None,
  71. base_url: URL | None = None,
  72. ) -> None:
  73. super().__init__(
  74. uuid=uuid,
  75. number=number,
  76. name=name,
  77. name_quality=name_quality,
  78. avatar_url=avatar_url,
  79. avatar_hash=avatar_hash,
  80. name_set=name_set,
  81. avatar_set=avatar_set,
  82. uuid_registered=uuid_registered,
  83. number_registered=number_registered,
  84. custom_mxid=custom_mxid,
  85. access_token=access_token,
  86. next_batch=next_batch,
  87. base_url=base_url,
  88. )
  89. self.log = self.log.getChild(str(uuid) if uuid else number)
  90. self.default_mxid = self.get_mxid_from_id(self.address)
  91. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  92. self.intent = self._fresh_intent()
  93. self._uuid_lock = asyncio.Lock()
  94. self._update_info_lock = asyncio.Lock()
  95. @classmethod
  96. def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
  97. cls.config = bridge.config
  98. cls.loop = bridge.loop
  99. cls.mx = bridge.matrix
  100. cls.az = bridge.az
  101. cls.hs_domain = cls.config["homeserver.domain"]
  102. cls.mxid_template = SimpleTemplate(
  103. cls.config["bridge.username_template"],
  104. "userid",
  105. prefix="@",
  106. suffix=f":{cls.hs_domain}",
  107. type=str,
  108. )
  109. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  110. cls.homeserver_url_map = {
  111. server: URL(url)
  112. for server, url in cls.config["bridge.double_puppet_server_map"].items()
  113. }
  114. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  115. cls.login_shared_secret_map = {
  116. server: secret.encode("utf-8")
  117. for server, secret in cls.config["bridge.login_shared_secret_map"].items()
  118. }
  119. cls.login_device_name = "Signal Bridge"
  120. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  121. def intent_for(self, portal: p.Portal) -> IntentAPI:
  122. if portal.chat_id == self.address:
  123. return self.default_mxid_intent
  124. return self.intent
  125. @property
  126. def is_registered(self) -> bool:
  127. return self.uuid_registered if self.uuid is not None else self.number_registered
  128. @is_registered.setter
  129. def is_registered(self, value: bool) -> None:
  130. if self.uuid is not None:
  131. self.uuid_registered = value
  132. else:
  133. self.number_registered = value
  134. @property
  135. def address(self) -> Address:
  136. return Address(uuid=self.uuid, number=self.number)
  137. async def handle_uuid_receive(self, uuid: UUID) -> None:
  138. async with self._uuid_lock:
  139. if self.uuid:
  140. # Received UUID was handled while this call was waiting
  141. return
  142. await self._handle_uuid_receive(uuid)
  143. async def handle_number_receive(self, number: str) -> None:
  144. async with self._uuid_lock:
  145. if self.number == number:
  146. return
  147. self.number = number
  148. self.by_number[self.number] = self
  149. await self._set_number(number)
  150. async for portal in p.Portal.find_private_chats_with(Address(number=number)):
  151. self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
  152. portal.handle_uuid_receive(self.uuid)
  153. prev_mxid = self.get_mxid_from_id(Address(number=number))
  154. if await self.az.state_store.is_registered(prev_mxid):
  155. prev_intent = self.az.intent.user(prev_mxid)
  156. await self._migrate_memberships(prev_intent, self.default_mxid_intent)
  157. async def _handle_uuid_receive(self, uuid: UUID) -> None:
  158. self.log.debug(f"Found UUID for user: {uuid}")
  159. user = await u.User.get_by_username(self.number)
  160. if user and not user.uuid:
  161. user.uuid = self.uuid
  162. user.by_uuid[user.uuid] = user
  163. await user.update()
  164. self.uuid = uuid
  165. self.by_uuid[self.uuid] = self
  166. await self._set_uuid(uuid)
  167. async for portal in p.Portal.find_private_chats_with(Address(number=self.number)):
  168. self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
  169. portal.handle_uuid_receive(self.uuid)
  170. prev_intent = self.default_mxid_intent
  171. self.default_mxid = self.get_mxid_from_id(self.address)
  172. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  173. self.intent = self._fresh_intent()
  174. await self.default_mxid_intent.ensure_registered()
  175. if self.name:
  176. await self.default_mxid_intent.set_displayname(self.name)
  177. self.log = Puppet.log.getChild(str(uuid))
  178. await self._migrate_memberships(prev_intent, self.default_mxid_intent)
  179. async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
  180. self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
  181. try:
  182. joined_rooms = await prev_intent.get_joined_rooms()
  183. except MForbidden as e:
  184. self.log.debug(
  185. f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
  186. "assuming there are no rooms to rejoin"
  187. )
  188. return
  189. for room_id in joined_rooms:
  190. await prev_intent.invite_user(room_id, self.default_mxid)
  191. await self._migrate_powers(prev_intent, new_intent, room_id)
  192. await prev_intent.leave_room(room_id)
  193. await new_intent.join_room_by_id(room_id)
  194. async def _migrate_powers(
  195. self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
  196. ) -> None:
  197. try:
  198. powers: PowerLevelStateEventContent
  199. powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
  200. user_level = powers.get_user_level(prev_intent.mxid)
  201. pl_state_level = powers.get_event_level(EventType.ROOM_POWER_LEVELS)
  202. if user_level >= pl_state_level > powers.users_default:
  203. powers.ensure_user_level(new_intent.mxid, user_level)
  204. await prev_intent.send_state_event(room_id, EventType.ROOM_POWER_LEVELS, powers)
  205. except Exception:
  206. self.log.warning("Failed to migrate power levels", exc_info=True)
  207. async def update_info(self, info: Profile | Address, source: u.User) -> None:
  208. update = False
  209. address = info.address if isinstance(info, Profile) else info
  210. if address.uuid and not self.uuid:
  211. await self.handle_uuid_receive(address.uuid)
  212. if address.number and address.number != self.number:
  213. await self.handle_number_receive(address.number)
  214. update = True
  215. self.log.debug("Updating info with %s (source: %s)", info, source.mxid)
  216. async with self._update_info_lock:
  217. if isinstance(info, Profile) or self.name is None:
  218. update = await self._update_name(info) or update
  219. if isinstance(info, Profile):
  220. update = await self._update_avatar(info.avatar) or update
  221. elif self.config["bridge.contact_list_names"] != "disallow" and self.number:
  222. # Try to use a contact list avatar
  223. update = await self._update_avatar(f"contact-{self.number}") or update
  224. if update:
  225. await self.update()
  226. asyncio.create_task(self._update_portal_meta())
  227. @staticmethod
  228. def fmt_phone(number: str) -> str:
  229. if phonenumbers is None:
  230. return number
  231. parsed = phonenumbers.parse(number)
  232. fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
  233. return phonenumbers.format_number(parsed, fmt)
  234. @classmethod
  235. def _get_displayname(cls, info: Profile | Address) -> tuple[str, int]:
  236. quality = 10
  237. if isinstance(info, Profile):
  238. address = info.address
  239. name = None
  240. contact_names = cls.config["bridge.contact_list_names"]
  241. if info.profile_name:
  242. name = info.profile_name
  243. quality = 90 if contact_names == "prefer" else 100
  244. if info.contact_name:
  245. if contact_names == "prefer":
  246. quality = 100
  247. name = info.contact_name
  248. elif contact_names == "allow" and not name:
  249. quality = 50
  250. name = info.contact_name
  251. names = name.split("\x00") if name else []
  252. else:
  253. address = info
  254. names = []
  255. data = {
  256. "first_name": names[0] if len(names) > 0 else "",
  257. "last_name": names[-1] if len(names) > 1 else "",
  258. "full_name": " ".join(names),
  259. "phone": cls.fmt_phone(address.number) if address.number else None,
  260. "uuid": str(address.uuid) if address.uuid else None,
  261. "displayname": "Unknown user",
  262. }
  263. for pref in cls.config["bridge.displayname_preference"]:
  264. value = data.get(pref.replace(" ", "_"))
  265. if value:
  266. data["displayname"] = value
  267. break
  268. return cls.config["bridge.displayname_template"].format(**data), quality
  269. async def _update_name(self, info: Profile | Address) -> bool:
  270. name, quality = self._get_displayname(info)
  271. if quality >= self.name_quality and (name != self.name or not self.name_set):
  272. self.log.debug(
  273. "Updating name from '%s' to '%s' (quality: %d)", self.name, name, quality
  274. )
  275. self.name = name
  276. self.name_quality = quality
  277. try:
  278. await self.default_mxid_intent.set_displayname(self.name)
  279. self.name_set = True
  280. except Exception:
  281. self.log.exception("Error setting displayname")
  282. self.name_set = False
  283. return True
  284. elif name != self.name or not self.name_set:
  285. self.log.debug(
  286. "Not updating name from '%s' to '%s', new quality (%d) is lower than old (%d)",
  287. self.name,
  288. name,
  289. quality,
  290. self.name_quality,
  291. )
  292. elif self.name_quality == 0:
  293. # Name matches, but quality is not stored in database - store it now
  294. self.name_quality = quality
  295. return True
  296. return False
  297. @staticmethod
  298. async def upload_avatar(
  299. self: Puppet | p.Portal, path: str, intent: IntentAPI
  300. ) -> bool | tuple[str, ContentURI]:
  301. if not path:
  302. return False
  303. if not path.startswith("/"):
  304. path = os.path.join(self.config["signal.avatar_dir"], path)
  305. try:
  306. with open(path, "rb") as file:
  307. data = file.read()
  308. except FileNotFoundError:
  309. return False
  310. if not data:
  311. return False
  312. new_hash = hashlib.sha256(data).hexdigest()
  313. if self.avatar_set and new_hash == self.avatar_hash:
  314. return False
  315. mxc = await intent.upload_media(data, async_upload=self.config["homeserver.async_media"])
  316. return new_hash, mxc
  317. async def _update_avatar(self, path: str) -> bool:
  318. res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
  319. if res is False:
  320. return False
  321. self.avatar_hash, self.avatar_url = res
  322. try:
  323. await self.default_mxid_intent.set_avatar_url(self.avatar_url)
  324. self.avatar_set = True
  325. except Exception:
  326. self.log.exception("Error setting avatar")
  327. self.avatar_set = False
  328. return True
  329. async def _update_portal_meta(self) -> None:
  330. async for portal in p.Portal.find_private_chats_with(self.address):
  331. if portal.receiver == self.number:
  332. # This is a note to self chat, don't change the name
  333. continue
  334. try:
  335. await portal.update_puppet_name(self.name)
  336. await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
  337. if self.number:
  338. await portal.update_puppet_number(self.fmt_phone(self.number))
  339. except Exception:
  340. self.log.exception(f"Error updating portal meta for {portal.receiver}")
  341. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  342. portal: p.Portal = await p.Portal.get_by_mxid(room_id)
  343. if not portal or not portal.is_direct:
  344. return True
  345. elif portal.chat_id.uuid and self.uuid:
  346. return portal.chat_id.uuid != self.uuid
  347. elif portal.chat_id.number and self.number:
  348. return portal.chat_id.number != self.number
  349. else:
  350. return True
  351. # region Database getters
  352. def _add_to_cache(self) -> None:
  353. if self.uuid:
  354. self.by_uuid[self.uuid] = self
  355. if self.number:
  356. self.by_number[self.number] = self
  357. if self.custom_mxid:
  358. self.by_custom_mxid[self.custom_mxid] = self
  359. async def save(self) -> None:
  360. await self.update()
  361. @classmethod
  362. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
  363. address = cls.get_id_from_mxid(mxid)
  364. if not address:
  365. return None
  366. return await cls.get_by_address(address, create)
  367. @classmethod
  368. @async_getter_lock
  369. async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
  370. try:
  371. return cls.by_custom_mxid[mxid]
  372. except KeyError:
  373. pass
  374. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  375. if puppet:
  376. puppet._add_to_cache()
  377. return puppet
  378. return None
  379. @classmethod
  380. def get_id_from_mxid(cls, mxid: UserID) -> Address | None:
  381. identifier = cls.mxid_template.parse(mxid)
  382. if not identifier:
  383. return None
  384. if identifier.startswith("phone_"):
  385. return Address(number="+" + identifier[len("phone_") :])
  386. else:
  387. try:
  388. return Address(uuid=UUID(identifier.upper()))
  389. except ValueError:
  390. return None
  391. @classmethod
  392. def get_mxid_from_id(cls, address: Address) -> UserID:
  393. if address.uuid:
  394. identifier = str(address.uuid).lower()
  395. elif address.number:
  396. identifier = f"phone_{address.number.lstrip('+')}"
  397. else:
  398. raise ValueError("Empty address")
  399. return UserID(cls.mxid_template.format_full(identifier))
  400. @classmethod
  401. @async_getter_lock
  402. async def get_by_address(cls, address: Address, create: bool = True) -> Puppet | None:
  403. puppet = await cls._get_by_address(address, create)
  404. if puppet and address.uuid and not puppet.uuid:
  405. # We found a UUID for this user, store it ASAP
  406. await puppet.handle_uuid_receive(address.uuid)
  407. return puppet
  408. @classmethod
  409. async def _get_by_address(cls, address: Address, create: bool = True) -> Puppet | None:
  410. if not address.is_valid:
  411. raise ValueError("Empty address")
  412. if address.uuid:
  413. try:
  414. return cls.by_uuid[address.uuid]
  415. except KeyError:
  416. pass
  417. if address.number:
  418. try:
  419. return cls.by_number[address.number]
  420. except KeyError:
  421. pass
  422. puppet = cast(cls, await super().get_by_address(address))
  423. if puppet is not None:
  424. puppet._add_to_cache()
  425. return puppet
  426. if create:
  427. puppet = cls(address.uuid, address.number)
  428. await puppet.insert()
  429. puppet._add_to_cache()
  430. return puppet
  431. return None
  432. @classmethod
  433. async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
  434. puppets = await super().all_with_custom_mxid()
  435. puppet: cls
  436. for index, puppet in enumerate(puppets):
  437. try:
  438. yield cls.by_uuid[puppet.uuid]
  439. except KeyError:
  440. try:
  441. yield cls.by_number[puppet.number]
  442. except KeyError:
  443. puppet._add_to_cache()
  444. yield puppet
  445. # endregion