puppet.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
  18. from uuid import UUID
  19. import asyncio
  20. import hashlib
  21. import os.path
  22. from yarl import URL
  23. from mausignald.errors import UnregisteredUserError
  24. from mausignald.types import Address, Profile
  25. from mautrix.appservice import IntentAPI
  26. from mautrix.bridge import BasePuppet, async_getter_lock
  27. from mautrix.errors import MForbidden
  28. from mautrix.types import (
  29. ContentURI,
  30. EventType,
  31. PowerLevelStateEventContent,
  32. RoomID,
  33. SyncToken,
  34. UserID,
  35. )
  36. from mautrix.util import background_task
  37. from mautrix.util.simple_template import SimpleTemplate
  38. from . import portal as p, signal, user as u
  39. from .config import Config
  40. from .db import Puppet as DBPuppet
  try:
      import phonenumbers
  except ImportError:
      # phonenumbers is optional: when it is missing, Puppet.fmt_phone returns
      # numbers unformatted instead of failing.
      phonenumbers = None

  if TYPE_CHECKING:
      # Only needed for type annotations; skipped at runtime.
      from .__main__ import SignalBridge
  class Puppet(DBPuppet, BasePuppet):
      """Matrix ghost user representing one Signal account, identified by its UUID.

      Loaded instances are cached in the class-level ``by_uuid``, ``by_number`` and
      ``by_custom_mxid`` dicts, so repeated lookups return the same object.
      """

      by_uuid: dict[UUID, Puppet] = {}
      by_number: dict[str, Puppet] = {}
      by_custom_mxid: dict[UserID, Puppet] = {}
      hs_domain: str
      mxid_template: SimpleTemplate[str]

      bridge: SignalBridge
      config: Config
      signal: signal.SignalHandler
      default_mxid_intent: IntentAPI
      default_mxid: UserID

      # Serializes phone-number changes in handle_number_receive.
      _uuid_lock: asyncio.Lock
      # Serializes the profile/avatar/contact-info update sequence in update_info.
      _update_info_lock: asyncio.Lock

      def __init__(
          self,
          uuid: UUID,
          number: str | None,
          name: str | None = None,
          name_quality: int = 0,
          avatar_url: ContentURI | None = None,
          avatar_hash: str | None = None,
          name_set: bool = False,
          avatar_set: bool = False,
          contact_info_set: bool = False,
          is_registered: bool = False,
          custom_mxid: UserID | None = None,
          access_token: str | None = None,
          next_batch: SyncToken | None = None,
          base_url: URL | None = None,
      ) -> None:
          """Create a puppet; ``uuid`` is mandatory and must be a ``UUID`` instance."""
          assert uuid, "UUID must be set for ghosts"
          assert isinstance(uuid, UUID)
          super().__init__(
              uuid=uuid,
              number=number,
              name=name,
              name_quality=name_quality,
              avatar_url=avatar_url,
              avatar_hash=avatar_hash,
              name_set=name_set,
              avatar_set=avatar_set,
              contact_info_set=contact_info_set,
              is_registered=is_registered,
              custom_mxid=custom_mxid,
              access_token=access_token,
              next_batch=next_batch,
              base_url=base_url,
          )
          self.log = self.log.getChild(str(uuid) if uuid else number)
          # The ghost MXID is always derived from the UUID, never from the number.
          self.default_mxid = self.get_mxid_from_id(self.uuid)
          self.default_mxid_intent = self.az.intent.user(self.default_mxid)
          self.intent = self._fresh_intent()
          self._uuid_lock = asyncio.Lock()
          self._update_info_lock = asyncio.Lock()

      @classmethod
      def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
          """Bind bridge-wide state and config to the class.

          Returns an async iterable yielding one ``try_start()`` awaitable per
          puppet that has a custom MXID.
          """
          cls.bridge = bridge
          cls.config = bridge.config
          cls.loop = bridge.loop
          cls.signal = bridge.signal
          cls.mx = bridge.matrix
          cls.az = bridge.az
          cls.hs_domain = cls.config["homeserver.domain"]
          cls.mxid_template = SimpleTemplate(
              cls.config["bridge.username_template"],
              "userid",
              prefix="@",
              suffix=f":{cls.hs_domain}",
              type=str,
          )
          cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
          cls.homeserver_url_map = {
              server: URL(url)
              for server, url in cls.config["bridge.double_puppet_server_map"].items()
          }
          cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
          cls.login_shared_secret_map = {
              server: secret.encode("utf-8")
              for server, secret in cls.config["bridge.login_shared_secret_map"].items()
          }
          cls.login_device_name = "Signal Bridge"
          return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())

      def intent_for(self, portal: p.Portal) -> IntentAPI:
          """Return the intent this puppet should act with in ``portal``.

          In the portal whose chat_id is this puppet's own UUID, the default ghost
          intent is used. Everywhere else, ``self.intent`` is used.
          """
          if portal.chat_id == self.uuid:
              return self.default_mxid_intent
          return self.intent

      @property
      def address(self) -> Address:
          """Signal address built from this puppet's UUID and (optional) number."""
          return Address(uuid=self.uuid, number=self.number)

      async def handle_number_receive(self, number: str) -> None:
          """Record a newly learned phone number for this puppet.

          Drops the old number from the ``by_number`` cache and registers the new
          one. Then calls ``_update_number`` (defined outside this class;
          presumably it persists the change — verify).
          """
          async with self._uuid_lock:
              if self.number == number:
                  return
              if self.number:
                  self.by_number.pop(self.number, None)
              self.number = number
              self._add_number_to_cache()
              await self._update_number()

      async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
          """Move every room membership of ``prev_intent``'s user to ``new_intent``'s user.

          For each room joined by the old user: invite the new user, copy the old
          user's power level, leave as the old user, then join as the new user.
          Returns early if the old user's joined rooms can't be listed (MForbidden).
          """
          self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
          try:
              joined_rooms = await prev_intent.get_joined_rooms()
          except MForbidden as e:
              self.log.debug(
                  f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
                  "assuming there are no rooms to rejoin"
              )
              return
          for room_id in joined_rooms:
              # Invite the user that is about to join. Inviting anyone else would
              # make the join below fail in invite-only rooms.
              await prev_intent.invite_user(room_id, new_intent.mxid)
              await self._migrate_powers(prev_intent, new_intent, room_id)
              await prev_intent.leave_room(room_id)
              await new_intent.join_room_by_id(room_id)

      async def _migrate_powers(
          self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
      ) -> None:
          """Best-effort copy of ``prev_intent``'s power level in ``room_id`` to ``new_intent``.

          The level is copied only if the old user may edit power levels and that
          permission is above the users_default. Failures are logged, not raised.
          """
          try:
              powers: PowerLevelStateEventContent
              powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
              user_level = powers.get_user_level(prev_intent.mxid)
              pl_state_level = powers.get_event_level(EventType.ROOM_POWER_LEVELS)
              if user_level >= pl_state_level > powers.users_default:
                  powers.ensure_user_level(new_intent.mxid, user_level)
                  await prev_intent.send_state_event(room_id, EventType.ROOM_POWER_LEVELS, powers)
          except Exception:
              self.log.warning("Failed to migrate power levels", exc_info=True)

      async def update_info(self, info: Profile | Address, source: u.User) -> None:
          """Update number, contact info, name and avatar from a Signal profile or address.

          If anything changed, saves the puppet and schedules a background update of
          the private-chat portals with this puppet.
          """
          update = False
          address = info.address if isinstance(info, Profile) else info
          number_updated = False
          if address.number and address.number != self.number:
              await self.handle_number_receive(address.number)
              update = True
              number_updated = True
          self.log.debug("Updating info with %s (source: %s)", info, source.mxid)
          async with self._update_info_lock:
              # A changed number adds a new tel: identifier, so force a contact info refresh.
              update = await self._update_contact_info(force=number_updated) or update
              if isinstance(info, Profile) or self.name is None:
                  update = await self._update_name(info) or update
              if isinstance(info, Profile):
                  update = await self._update_avatar(info.avatar) or update
              elif self.config["bridge.contact_list_names"] != "disallow" and self.number:
                  # Try to use a contact list avatar
                  update = await self._update_avatar(f"contact-{self.number}") or update
              if update:
                  await self.update()
                  background_task.create(self._try_update_portal_meta())

      async def _update_contact_info(self, force: bool = False) -> bool:
          """Push bridge identifiers to the ghost's profile (hungry homeservers only).

          Returns False when skipped. Returns True after an attempt; whether the
          attempt succeeded is recorded in ``contact_info_set``.
          """
          if not self.bridge.homeserver_software.is_hungry:
              return False
          if self.contact_info_set and not force:
              return False
          try:
              identifiers = [f"signal:{self.uuid}"]
              if self.number:
                  identifiers.append(f"tel:{self.number}")
              await self.default_mxid_intent.beeper_update_profile(
                  {
                      "com.beeper.bridge.identifiers": identifiers,
                      "com.beeper.bridge.remote_id": str(self.uuid),
                      "com.beeper.bridge.service": "signal",
                      "com.beeper.bridge.network": "signal",
                  }
              )
              self.contact_info_set = True
          except Exception:
              self.log.exception("Error updating contact info")
              self.contact_info_set = False
          return True

      @staticmethod
      def fmt_phone(number: str) -> str:
          """Format ``number`` in international style, or return it unchanged.

          The number is returned as-is when the optional ``phonenumbers`` package is
          unavailable, or when the number cannot be parsed.
          """
          if phonenumbers is None:
              return number
          try:
              parsed = phonenumbers.parse(number)
          except phonenumbers.NumberParseException:
              # Without a leading "+<country code>" there is no default region to
              # parse against. Show the raw number rather than failing the whole
              # displayname / portal update.
              return number
          fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
          return phonenumbers.format_number(parsed, fmt)

      @classmethod
      def _get_displayname(cls, info: Profile | Address) -> tuple[str, int]:
          """Compute the ghost displayname and a quality score for it.

          Quality values:
          - 100: profile name, or contact name with contact_list_names "prefer".
          - 90: profile name when contact names are preferred.
          - 50: contact name used as fallback under "allow".
          - 10: anything else.
          The name is rendered via bridge.displayname_template, using the first
          non-empty field from bridge.displayname_preference.
          """
          quality = 10
          if isinstance(info, Profile):
              address = info.address
              name = None
              contact_names = cls.config["bridge.contact_list_names"]
              if info.profile_name:
                  name = info.profile_name
                  quality = 90 if contact_names == "prefer" else 100
              if info.contact_name:
                  if contact_names == "prefer":
                      quality = 100
                      name = info.contact_name
                  elif contact_names == "allow" and not name:
                      quality = 50
                      name = info.contact_name
              # Signal profile names separate given and family name with a NUL byte.
              names = name.split("\x00") if name else []
          else:
              address = info
              names = []
          data = {
              "first_name": names[0] if len(names) > 0 else "",
              "last_name": names[-1] if len(names) > 1 else "",
              "full_name": " ".join(names),
              "phone": cls.fmt_phone(address.number) if address.number else None,
              "uuid": str(address.uuid) if address.uuid else None,
              "displayname": "Unknown user",
          }
          for pref in cls.config["bridge.displayname_preference"]:
              value = data.get(pref.replace(" ", "_"))
              if value:
                  data["displayname"] = value
                  break
          return cls.config["bridge.displayname_template"].format(**data), quality

      async def _update_name(self, info: Profile | Address) -> bool:
          """Set the ghost displayname if the new name's quality is at least the stored one.

          Returns True when puppet state changed and needs saving.
          """
          name, quality = self._get_displayname(info)
          if quality >= self.name_quality and (name != self.name or not self.name_set):
              self.log.debug(
                  "Updating name from '%s' to '%s' (quality: %d)", self.name, name, quality
              )
              self.name = name
              self.name_quality = quality
              try:
                  await self.default_mxid_intent.set_displayname(self.name)
                  self.name_set = True
              except Exception:
                  self.log.exception("Error setting displayname")
                  self.name_set = False
              return True
          elif name != self.name or not self.name_set:
              self.log.debug(
                  "Not updating name from '%s' to '%s', new quality (%d) is lower than old (%d)",
                  self.name,
                  name,
                  quality,
                  self.name_quality,
              )
          elif self.name_quality == 0:
              # Name matches, but quality is not stored in database - store it now
              self.name_quality = quality
              return True
          return False

      @staticmethod
      async def upload_avatar(
          self: Puppet | p.Portal, path: str, intent: IntentAPI
      ) -> bool | tuple[str, ContentURI]:
          """Read an avatar file and upload it via ``intent`` if its content changed.

          This is a staticmethod taking an explicit ``self``, so it can be shared
          with ``Portal``. Relative paths are resolved against signal.avatar_dir.
          Returns False when there is nothing to upload (empty path, missing or
          empty file, unchanged hash). Otherwise returns ``(sha256_hex, mxc_uri)``.
          """
          if not path:
              return False
          if not path.startswith("/"):
              path = os.path.join(self.config["signal.avatar_dir"], path)
          try:
              with open(path, "rb") as file:
                  data = file.read()
          except FileNotFoundError:
              return False
          if not data:
              return False
          new_hash = hashlib.sha256(data).hexdigest()
          if self.avatar_set and new_hash == self.avatar_hash:
              return False
          mxc = await intent.upload_media(data, async_upload=self.config["homeserver.async_media"])
          return new_hash, mxc

      async def _update_avatar(self, path: str) -> bool:
          """Upload and set the ghost avatar from ``path``; return True if state changed."""
          res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
          if res is False:
              return False
          self.avatar_hash, self.avatar_url = res
          try:
              await self.default_mxid_intent.set_avatar_url(self.avatar_url)
              self.avatar_set = True
          except Exception:
              self.log.exception("Error setting avatar")
              self.avatar_set = False
          return True

      async def _try_update_portal_meta(self) -> None:
          """Run _update_portal_meta, logging any error instead of raising it."""
          try:
              await self._update_portal_meta()
          except Exception:
              self.log.exception("Error updating portal meta")

      async def _update_portal_meta(self) -> None:
          """Propagate name, avatar and number to private-chat portals with this puppet.

          Note-to-self chats (receiver equals this puppet's number) are skipped.
          """
          async for portal in p.Portal.find_private_chats_with(self.uuid):
              if portal.receiver == self.number:
                  # This is a note to self chat, don't change the name
                  continue
              try:
                  await portal.update_puppet_name(self.name)
                  await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
                  if self.number:
                      await portal.update_puppet_number(self.fmt_phone(self.number))
              except Exception:
                  self.log.exception(f"Error updating portal meta for {portal.receiver}")

      async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
          """Return True unless ``room_id`` is the direct portal whose chat_id is this puppet."""
          portal: p.Portal = await p.Portal.get_by_mxid(room_id)
          # Leave all portals except the notes to self room
          return not (portal and portal.is_direct and portal.chat_id == self.uuid)

      # region Database getters

      def _add_number_to_cache(self) -> None:
          """Register this puppet under its number, clearing the number on any other holder."""
          if self.number:
              try:
                  existing = self.by_number[self.number]
                  if existing and existing.uuid != self.uuid and existing != self:
                      existing.number = None
              except KeyError:
                  pass
              self.by_number[self.number] = self

      def _add_to_cache(self) -> None:
          """Register this puppet in every applicable class-level cache."""
          self.by_uuid[self.uuid] = self
          self._add_number_to_cache()
          if self.custom_mxid:
              self.by_custom_mxid[self.custom_mxid] = self

      async def save(self) -> None:
          """Persist this puppet (alias for ``update``)."""
          await self.update()

      @classmethod
      async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
          """Look up a puppet by its ghost MXID; return None if the MXID isn't a ghost."""
          uuid = cls.get_id_from_mxid(mxid)
          if not uuid:
              return None
          return await cls.get_by_uuid(uuid, create=create)

      @classmethod
      @async_getter_lock
      async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
          """Look up a puppet by custom MXID, checking the cache before the database."""
          try:
              return cls.by_custom_mxid[mxid]
          except KeyError:
              pass

          puppet = cast(cls, await super().get_by_custom_mxid(mxid))
          if puppet:
              puppet._add_to_cache()
              return puppet

          return None

      @classmethod
      def get_id_from_mxid(cls, mxid: UserID) -> UUID | None:
          """Extract the Signal UUID from a ghost MXID, or None if it doesn't match."""
          identifier = cls.mxid_template.parse(mxid)
          if not identifier:
              return None
          try:
              return UUID(identifier.upper())
          except ValueError:
              return None

      @classmethod
      def get_mxid_from_id(cls, uuid: UUID) -> UserID:
          """Build the ghost MXID for ``uuid``, using its lowercase string form."""
          return UserID(cls.mxid_template.format_full(str(uuid).lower()))

      @classmethod
      @async_getter_lock
      async def get_by_number(
          cls, number: str, /, *, resolve_via: str | None = None, raise_resolve: bool = False
      ) -> Puppet | None:
          """Look up a puppet by phone number in the cache, then the database.

          If it is still not found and ``resolve_via`` is given, the number is
          resolved to a UUID through signald using that account. Resolution errors
          are re-raised when ``raise_resolve`` is set; otherwise they are logged and
          None is returned.
          """
          try:
              return cls.by_number[number]
          except KeyError:
              pass

          puppet = cast(cls, await super().get_by_number(number))
          if puppet is not None:
              puppet._add_to_cache()
              return puppet

          if resolve_via:
              cls.log.debug(
                  f"Couldn't find puppet with number {number}, resolving UUID via {resolve_via}"
              )
              try:
                  uuid = await cls.signal.find_uuid(resolve_via, number)
              except UnregisteredUserError:
                  if raise_resolve:
                      raise
                  cls.log.debug(f"Resolving {number} via {resolve_via} threw UnregisteredUserError")
                  return None
              except Exception:
                  if raise_resolve:
                      raise
                  cls.log.exception(f"Failed to resolve {number} via {resolve_via}")
                  return None
              if uuid:
                  cls.log.debug(f"Found {uuid} for {number} after resolving via {resolve_via}")
                  return await cls.get_by_uuid(uuid, number=number)
              else:
                  cls.log.debug(f"Didn't find UUID for {number} via {resolve_via}")
          return None

      @classmethod
      async def get_by_address(
          cls,
          address: Address,
          create: bool = True,
          resolve_via: str | None = None,
          raise_resolve: bool = False,
      ) -> Puppet | None:
          """Look up a puppet by UUID when the address has one, otherwise by number.

          ``create`` only applies to the UUID path. ``resolve_via`` and
          ``raise_resolve`` only apply to the number path.
          """
          if not address.uuid:
              return await cls.get_by_number(
                  address.number, resolve_via=resolve_via, raise_resolve=raise_resolve
              )
          else:
              return await cls.get_by_uuid(address.uuid, create=create, number=address.number)

      @classmethod
      @async_getter_lock
      async def get_by_uuid(
          cls, uuid: UUID, /, *, create: bool = True, number: str | None = None
      ) -> Puppet | None:
          """Look up a puppet by UUID (cache, then database).

          If not found and ``create`` is set, a new puppet is inserted with ``number``.
          """
          try:
              return cls.by_uuid[uuid]
          except KeyError:
              pass

          puppet = cast(cls, await super().get_by_uuid(uuid))
          if puppet is not None:
              puppet._add_to_cache()
              return puppet

          if create:
              puppet = cls(uuid, number)
              await puppet.insert()
              puppet._add_to_cache()
              return puppet

          return None

      @classmethod
      async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
          """Yield every puppet with a custom MXID, preferring already-cached instances."""
          puppets = await super().all_with_custom_mxid()
          puppet: cls
          for puppet in puppets:
              # Look up without try/except so nothing thrown into the generator at the
              # yield point can be mistaken for a cache miss.
              cached = cls.by_uuid.get(puppet.uuid)
              if cached is not None:
                  yield cached
              else:
                  puppet._add_to_cache()
                  yield puppet

      # endregion