puppet.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
# mautrix-signal - A Matrix-Signal puppeting bridge
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
from uuid import UUID
import asyncio
import hashlib
import os.path

from yarl import URL

from mausignald.errors import UnregisteredUserError
from mausignald.types import Address, Profile
from mautrix.appservice import IntentAPI
from mautrix.bridge import BasePuppet, async_getter_lock
from mautrix.errors import MForbidden
from mautrix.types import (
    ContentURI,
    EventType,
    PowerLevelStateEventContent,
    RoomID,
    SyncToken,
    UserID,
)
from mautrix.util import background_task
from mautrix.util.simple_template import SimpleTemplate

from . import portal as p, signal, user as u
from .config import Config
from .db import Puppet as DBPuppet

if TYPE_CHECKING:
    # Only evaluated by static type checkers; SignalBridge is used purely in
    # annotations below. Presumably guarded to avoid an import cycle with the
    # bridge entry point - confirm.
    from .__main__ import SignalBridge

try:
    import phonenumbers
except ImportError:
    # phonenumbers is optional: when it cannot be imported, Puppet.fmt_phone
    # returns phone numbers unformatted.
    phonenumbers = None
  47. class Puppet(DBPuppet, BasePuppet):
  48. by_uuid: dict[UUID, Puppet] = {}
  49. by_number: dict[str, Puppet] = {}
  50. by_custom_mxid: dict[UserID, Puppet] = {}
  51. hs_domain: str
  52. mxid_template: SimpleTemplate[str]
  53. bridge: SignalBridge
  54. config: Config
  55. signal: signal.SignalHandler
  56. default_mxid_intent: IntentAPI
  57. default_mxid: UserID
  58. _uuid_lock: asyncio.Lock
  59. _update_info_lock: asyncio.Lock
  60. def __init__(
  61. self,
  62. uuid: UUID,
  63. number: str | None,
  64. name: str | None = None,
  65. name_quality: int = 0,
  66. avatar_url: ContentURI | None = None,
  67. avatar_hash: str | None = None,
  68. name_set: bool = False,
  69. avatar_set: bool = False,
  70. contact_info_set: bool = False,
  71. is_registered: bool = False,
  72. custom_mxid: UserID | None = None,
  73. access_token: str | None = None,
  74. next_batch: SyncToken | None = None,
  75. base_url: URL | None = None,
  76. ) -> None:
  77. assert uuid, "UUID must be set for ghosts"
  78. assert isinstance(uuid, UUID)
  79. super().__init__(
  80. uuid=uuid,
  81. number=number,
  82. name=name,
  83. name_quality=name_quality,
  84. avatar_url=avatar_url,
  85. avatar_hash=avatar_hash,
  86. name_set=name_set,
  87. avatar_set=avatar_set,
  88. contact_info_set=contact_info_set,
  89. is_registered=is_registered,
  90. custom_mxid=custom_mxid,
  91. access_token=access_token,
  92. next_batch=next_batch,
  93. base_url=base_url,
  94. )
  95. self.log = self.log.getChild(str(uuid) if uuid else number)
  96. self.default_mxid = self.get_mxid_from_id(self.uuid)
  97. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  98. self.intent = self._fresh_intent()
  99. self._uuid_lock = asyncio.Lock()
  100. self._update_info_lock = asyncio.Lock()
  101. @classmethod
  102. def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
  103. cls.bridge = bridge
  104. cls.config = bridge.config
  105. cls.loop = bridge.loop
  106. cls.signal = bridge.signal
  107. cls.mx = bridge.matrix
  108. cls.az = bridge.az
  109. cls.hs_domain = cls.config["homeserver.domain"]
  110. cls.mxid_template = SimpleTemplate(
  111. cls.config["bridge.username_template"],
  112. "userid",
  113. prefix="@",
  114. suffix=f":{cls.hs_domain}",
  115. type=str,
  116. )
  117. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  118. cls.homeserver_url_map = {
  119. server: URL(url)
  120. for server, url in cls.config["bridge.double_puppet_server_map"].items()
  121. }
  122. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  123. cls.login_shared_secret_map = {
  124. server: secret.encode("utf-8")
  125. for server, secret in cls.config["bridge.login_shared_secret_map"].items()
  126. }
  127. cls.login_device_name = "Signal Bridge"
  128. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  129. def intent_for(self, portal: p.Portal) -> IntentAPI:
  130. if portal.chat_id == self.uuid:
  131. return self.default_mxid_intent
  132. return self.intent
  133. @property
  134. def address(self) -> Address:
  135. return Address(uuid=self.uuid, number=self.number)
  136. async def handle_number_receive(self, number: str) -> None:
  137. async with self._uuid_lock:
  138. if self.number == number:
  139. return
  140. if self.number:
  141. self.by_number.pop(self.number, None)
  142. self.number = number
  143. self._add_number_to_cache()
  144. await self._update_number()
  145. async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
  146. self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
  147. try:
  148. joined_rooms = await prev_intent.get_joined_rooms()
  149. except MForbidden as e:
  150. self.log.debug(
  151. f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
  152. "assuming there are no rooms to rejoin"
  153. )
  154. return
  155. for room_id in joined_rooms:
  156. await prev_intent.invite_user(room_id, self.default_mxid)
  157. await self._migrate_powers(prev_intent, new_intent, room_id)
  158. await prev_intent.leave_room(room_id)
  159. await new_intent.join_room_by_id(room_id)
  160. async def _migrate_powers(
  161. self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
  162. ) -> None:
  163. try:
  164. powers: PowerLevelStateEventContent
  165. powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
  166. user_level = powers.get_user_level(prev_intent.mxid)
  167. pl_state_level = powers.get_event_level(EventType.ROOM_POWER_LEVELS)
  168. if user_level >= pl_state_level > powers.users_default:
  169. powers.ensure_user_level(new_intent.mxid, user_level)
  170. await prev_intent.send_state_event(room_id, EventType.ROOM_POWER_LEVELS, powers)
  171. except Exception:
  172. self.log.warning("Failed to migrate power levels", exc_info=True)
  173. async def update_info(self, info: Profile | Address, source: u.User) -> None:
  174. update = False
  175. address = info.address if isinstance(info, Profile) else info
  176. number_updated = False
  177. if address.number and address.number != self.number:
  178. await self.handle_number_receive(address.number)
  179. update = True
  180. number_updated = True
  181. self.log.debug("Updating info with %s (source: %s)", info, source.mxid)
  182. async with self._update_info_lock:
  183. update = await self._update_contact_info(force=number_updated) or update
  184. if isinstance(info, Profile) or self.name is None:
  185. update = await self._update_name(info) or update
  186. if isinstance(info, Profile):
  187. update = await self._update_avatar(info.avatar) or update
  188. elif self.config["bridge.contact_list_names"] != "disallow" and self.number:
  189. # Try to use a contact list avatar
  190. update = await self._update_avatar(f"contact-{self.number}") or update
  191. if update:
  192. await self.update()
  193. background_task.create(self._try_update_portal_meta())
  194. async def _update_contact_info(self, force: bool = False) -> bool:
  195. if not self.bridge.homeserver_software.is_hungry:
  196. return False
  197. if self.contact_info_set and not force:
  198. return False
  199. try:
  200. identifiers = [f"signal:{self.uuid}"]
  201. if self.number:
  202. identifiers.append(f"tel:{self.number}")
  203. await self.default_mxid_intent.beeper_update_profile(
  204. {
  205. "com.beeper.bridge.identifiers": identifiers,
  206. "com.beeper.bridge.remote_id": str(self.uuid),
  207. "com.beeper.bridge.service": "signal",
  208. "com.beeper.bridge.network": "signal",
  209. "com.beeper.bridge.is_bridge_bot": False,
  210. "com.beeper.bridge.is_bot": False,
  211. }
  212. )
  213. self.contact_info_set = True
  214. except Exception:
  215. self.log.exception("Error updating contact info")
  216. self.contact_info_set = False
  217. return True
  218. @staticmethod
  219. def fmt_phone(number: str) -> str:
  220. if phonenumbers is None:
  221. return number
  222. parsed = phonenumbers.parse(number)
  223. fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
  224. return phonenumbers.format_number(parsed, fmt)
  225. @classmethod
  226. def _get_displayname(cls, info: Profile | Address) -> tuple[str, int]:
  227. quality = 10
  228. if isinstance(info, Profile):
  229. address = info.address
  230. name = None
  231. contact_names = cls.config["bridge.contact_list_names"]
  232. if info.profile_name:
  233. name = info.profile_name
  234. quality = 90 if contact_names == "prefer" else 100
  235. if info.contact_name:
  236. if contact_names == "prefer":
  237. quality = 100
  238. name = info.contact_name
  239. elif contact_names == "allow" and not name:
  240. quality = 50
  241. name = info.contact_name
  242. names = name.split("\x00") if name else []
  243. else:
  244. address = info
  245. names = []
  246. data = {
  247. "first_name": names[0] if len(names) > 0 else "",
  248. "last_name": names[-1] if len(names) > 1 else "",
  249. "full_name": " ".join(names),
  250. "phone": cls.fmt_phone(address.number) if address.number else None,
  251. "uuid": str(address.uuid) if address.uuid else None,
  252. "displayname": "Unknown user",
  253. }
  254. for pref in cls.config["bridge.displayname_preference"]:
  255. value = data.get(pref.replace(" ", "_"))
  256. if value:
  257. data["displayname"] = value
  258. break
  259. return cls.config["bridge.displayname_template"].format(**data), quality
  260. async def _update_name(self, info: Profile | Address) -> bool:
  261. name, quality = self._get_displayname(info)
  262. if quality >= self.name_quality and (name != self.name or not self.name_set):
  263. self.log.debug(
  264. "Updating name from '%s' to '%s' (quality: %d)", self.name, name, quality
  265. )
  266. self.name = name
  267. self.name_quality = quality
  268. try:
  269. await self.default_mxid_intent.set_displayname(self.name)
  270. self.name_set = True
  271. except Exception:
  272. self.log.exception("Error setting displayname")
  273. self.name_set = False
  274. return True
  275. elif name != self.name or not self.name_set:
  276. self.log.debug(
  277. "Not updating name from '%s' to '%s', new quality (%d) is lower than old (%d)",
  278. self.name,
  279. name,
  280. quality,
  281. self.name_quality,
  282. )
  283. elif self.name_quality == 0:
  284. # Name matches, but quality is not stored in database - store it now
  285. self.name_quality = quality
  286. return True
  287. return False
  288. @staticmethod
  289. async def upload_avatar(
  290. self: Puppet | p.Portal, path: str, intent: IntentAPI
  291. ) -> bool | tuple[str, ContentURI]:
  292. if not path:
  293. return False
  294. if not path.startswith("/"):
  295. path = os.path.join(self.config["signal.avatar_dir"], path)
  296. try:
  297. with open(path, "rb") as file:
  298. data = file.read()
  299. except FileNotFoundError:
  300. return False
  301. if not data:
  302. return False
  303. new_hash = hashlib.sha256(data).hexdigest()
  304. if self.avatar_set and new_hash == self.avatar_hash:
  305. return False
  306. mxc = await intent.upload_media(data, async_upload=self.config["homeserver.async_media"])
  307. return new_hash, mxc
  308. async def _update_avatar(self, path: str) -> bool:
  309. res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
  310. if res is False:
  311. return False
  312. self.avatar_hash, self.avatar_url = res
  313. try:
  314. await self.default_mxid_intent.set_avatar_url(self.avatar_url)
  315. self.avatar_set = True
  316. except Exception:
  317. self.log.exception("Error setting avatar")
  318. self.avatar_set = False
  319. return True
  320. async def _try_update_portal_meta(self) -> None:
  321. try:
  322. await self._update_portal_meta()
  323. except Exception:
  324. self.log.exception("Error updating portal meta")
  325. async def _update_portal_meta(self) -> None:
  326. async for portal in p.Portal.find_private_chats_with(self.uuid):
  327. if portal.receiver == self.number:
  328. # This is a note to self chat, don't change the name
  329. continue
  330. try:
  331. await portal.update_puppet_name(self.name)
  332. await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
  333. if self.number:
  334. await portal.update_puppet_number(self.fmt_phone(self.number))
  335. except Exception:
  336. self.log.exception(f"Error updating portal meta for {portal.receiver}")
  337. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  338. portal: p.Portal = await p.Portal.get_by_mxid(room_id)
  339. # Leave all portals except the notes to self room
  340. return not (portal and portal.is_direct and portal.chat_id == self.uuid)
  341. # region Database getters
  342. def _add_number_to_cache(self) -> None:
  343. if self.number:
  344. try:
  345. existing = self.by_number[self.number]
  346. if existing and existing.uuid != self.uuid and existing != self:
  347. existing.number = None
  348. except KeyError:
  349. pass
  350. self.by_number[self.number] = self
  351. def _add_to_cache(self) -> None:
  352. self.by_uuid[self.uuid] = self
  353. self._add_number_to_cache()
  354. if self.custom_mxid:
  355. self.by_custom_mxid[self.custom_mxid] = self
  356. async def save(self) -> None:
  357. await self.update()
  358. @classmethod
  359. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
  360. uuid = cls.get_id_from_mxid(mxid)
  361. if not uuid:
  362. return None
  363. return await cls.get_by_uuid(uuid, create=create)
  364. @classmethod
  365. @async_getter_lock
  366. async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
  367. try:
  368. return cls.by_custom_mxid[mxid]
  369. except KeyError:
  370. pass
  371. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  372. if puppet:
  373. puppet._add_to_cache()
  374. return puppet
  375. return None
  376. @classmethod
  377. def get_id_from_mxid(cls, mxid: UserID) -> UUID | None:
  378. identifier = cls.mxid_template.parse(mxid)
  379. if not identifier:
  380. return None
  381. try:
  382. return UUID(identifier.upper())
  383. except ValueError:
  384. return None
  385. @classmethod
  386. def get_mxid_from_id(cls, uuid: UUID) -> UserID:
  387. return UserID(cls.mxid_template.format_full(str(uuid).lower()))
  388. @classmethod
  389. @async_getter_lock
  390. async def get_by_number(
  391. cls, number: str, /, *, resolve_via: str | None = None, raise_resolve: bool = False
  392. ) -> Puppet | None:
  393. try:
  394. return cls.by_number[number]
  395. except KeyError:
  396. pass
  397. puppet = cast(cls, await super().get_by_number(number))
  398. if puppet is not None:
  399. puppet._add_to_cache()
  400. return puppet
  401. if resolve_via:
  402. cls.log.debug(
  403. f"Couldn't find puppet with number {number}, resolving UUID via {resolve_via}"
  404. )
  405. try:
  406. uuid = await cls.signal.find_uuid(resolve_via, number)
  407. except UnregisteredUserError:
  408. if raise_resolve:
  409. raise
  410. cls.log.debug(f"Resolving {number} via {resolve_via} threw UnregisteredUserError")
  411. return None
  412. except Exception:
  413. if raise_resolve:
  414. raise
  415. cls.log.exception(f"Failed to resolve {number} via {resolve_via}")
  416. return None
  417. if uuid:
  418. cls.log.debug(f"Found {uuid} for {number} after resolving via {resolve_via}")
  419. return await cls.get_by_uuid(uuid, number=number)
  420. else:
  421. cls.log.debug(f"Didn't find UUID for {number} via {resolve_via}")
  422. return None
  423. @classmethod
  424. async def get_by_address(
  425. cls,
  426. address: Address,
  427. create: bool = True,
  428. resolve_via: str | None = None,
  429. raise_resolve: bool = False,
  430. ) -> Puppet | None:
  431. if not address.uuid:
  432. return await cls.get_by_number(
  433. address.number, resolve_via=resolve_via, raise_resolve=raise_resolve
  434. )
  435. else:
  436. return await cls.get_by_uuid(address.uuid, create=create, number=address.number)
  437. @classmethod
  438. @async_getter_lock
  439. async def get_by_uuid(
  440. cls, uuid: UUID, /, *, create: bool = True, number: str | None = None
  441. ) -> Puppet | None:
  442. try:
  443. return cls.by_uuid[uuid]
  444. except KeyError:
  445. pass
  446. puppet = cast(cls, await super().get_by_uuid(uuid))
  447. if puppet is not None:
  448. puppet._add_to_cache()
  449. return puppet
  450. if create:
  451. puppet = cls(uuid, number)
  452. await puppet.insert()
  453. puppet._add_to_cache()
  454. return puppet
  455. return None
  456. @classmethod
  457. async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
  458. puppets = await super().all_with_custom_mxid()
  459. puppet: cls
  460. for index, puppet in enumerate(puppets):
  461. try:
  462. yield cls.by_uuid[puppet.uuid]
  463. except KeyError:
  464. puppet._add_to_cache()
  465. yield puppet
  466. # endregion