puppet.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
  18. from uuid import UUID
  19. import asyncio
  20. import hashlib
  21. import os.path
  22. from yarl import URL
  23. from mausignald.errors import UnregisteredUserError
  24. from mausignald.types import Address, Profile
  25. from mautrix.appservice import IntentAPI
  26. from mautrix.bridge import BasePuppet, async_getter_lock
  27. from mautrix.errors import MForbidden
  28. from mautrix.types import (
  29. ContentURI,
  30. EventType,
  31. PowerLevelStateEventContent,
  32. RoomID,
  33. SyncToken,
  34. UserID,
  35. )
  36. from mautrix.util import background_task
  37. from mautrix.util.simple_template import SimpleTemplate
  38. from . import portal as p, signal, user as u
  39. from .config import Config
  40. from .db import Puppet as DBPuppet
  41. if TYPE_CHECKING:
  42. from .__main__ import SignalBridge
  43. try:
  44. import phonenumbers
  45. except ImportError:
  46. phonenumbers = None
  47. class Puppet(DBPuppet, BasePuppet):
  48. by_uuid: dict[UUID, Puppet] = {}
  49. by_number: dict[str, Puppet] = {}
  50. by_custom_mxid: dict[UserID, Puppet] = {}
  51. hs_domain: str
  52. mxid_template: SimpleTemplate[str]
  53. config: Config
  54. signal: signal.SignalHandler
  55. default_mxid_intent: IntentAPI
  56. default_mxid: UserID
  57. _uuid_lock: asyncio.Lock
  58. _update_info_lock: asyncio.Lock
  59. def __init__(
  60. self,
  61. uuid: UUID,
  62. number: str | None,
  63. name: str | None = None,
  64. name_quality: int = 0,
  65. avatar_url: ContentURI | None = None,
  66. avatar_hash: str | None = None,
  67. name_set: bool = False,
  68. avatar_set: bool = False,
  69. contact_info_set: bool = False,
  70. is_registered: bool = False,
  71. custom_mxid: UserID | None = None,
  72. access_token: str | None = None,
  73. next_batch: SyncToken | None = None,
  74. base_url: URL | None = None,
  75. ) -> None:
  76. assert uuid, "UUID must be set for ghosts"
  77. assert isinstance(uuid, UUID)
  78. super().__init__(
  79. uuid=uuid,
  80. number=number,
  81. name=name,
  82. name_quality=name_quality,
  83. avatar_url=avatar_url,
  84. avatar_hash=avatar_hash,
  85. name_set=name_set,
  86. avatar_set=avatar_set,
  87. contact_info_set=contact_info_set,
  88. is_registered=is_registered,
  89. custom_mxid=custom_mxid,
  90. access_token=access_token,
  91. next_batch=next_batch,
  92. base_url=base_url,
  93. )
  94. self.log = self.log.getChild(str(uuid) if uuid else number)
  95. self.default_mxid = self.get_mxid_from_id(self.uuid)
  96. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  97. self.intent = self._fresh_intent()
  98. self._uuid_lock = asyncio.Lock()
  99. self._update_info_lock = asyncio.Lock()
  100. @classmethod
  101. def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
  102. cls.config = bridge.config
  103. cls.loop = bridge.loop
  104. cls.signal = bridge.signal
  105. cls.mx = bridge.matrix
  106. cls.az = bridge.az
  107. cls.hs_domain = cls.config["homeserver.domain"]
  108. cls.mxid_template = SimpleTemplate(
  109. cls.config["bridge.username_template"],
  110. "userid",
  111. prefix="@",
  112. suffix=f":{cls.hs_domain}",
  113. type=str,
  114. )
  115. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  116. cls.homeserver_url_map = {
  117. server: URL(url)
  118. for server, url in cls.config["bridge.double_puppet_server_map"].items()
  119. }
  120. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  121. cls.login_shared_secret_map = {
  122. server: secret.encode("utf-8")
  123. for server, secret in cls.config["bridge.login_shared_secret_map"].items()
  124. }
  125. cls.login_device_name = "Signal Bridge"
  126. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  127. def intent_for(self, portal: p.Portal) -> IntentAPI:
  128. if portal.chat_id == self.uuid:
  129. return self.default_mxid_intent
  130. return self.intent
  131. @property
  132. def address(self) -> Address:
  133. return Address(uuid=self.uuid, number=self.number)
  134. async def handle_number_receive(self, number: str) -> None:
  135. async with self._uuid_lock:
  136. if self.number == number:
  137. return
  138. if self.number:
  139. self.by_number.pop(self.number, None)
  140. self.number = number
  141. self._add_number_to_cache()
  142. await self._update_number()
  143. async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
  144. self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
  145. try:
  146. joined_rooms = await prev_intent.get_joined_rooms()
  147. except MForbidden as e:
  148. self.log.debug(
  149. f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
  150. "assuming there are no rooms to rejoin"
  151. )
  152. return
  153. for room_id in joined_rooms:
  154. await prev_intent.invite_user(room_id, self.default_mxid)
  155. await self._migrate_powers(prev_intent, new_intent, room_id)
  156. await prev_intent.leave_room(room_id)
  157. await new_intent.join_room_by_id(room_id)
  158. async def _migrate_powers(
  159. self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
  160. ) -> None:
  161. try:
  162. powers: PowerLevelStateEventContent
  163. powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
  164. user_level = powers.get_user_level(prev_intent.mxid)
  165. pl_state_level = powers.get_event_level(EventType.ROOM_POWER_LEVELS)
  166. if user_level >= pl_state_level > powers.users_default:
  167. powers.ensure_user_level(new_intent.mxid, user_level)
  168. await prev_intent.send_state_event(room_id, EventType.ROOM_POWER_LEVELS, powers)
  169. except Exception:
  170. self.log.warning("Failed to migrate power levels", exc_info=True)
  171. async def update_info(self, info: Profile | Address, source: u.User) -> None:
  172. update = False
  173. address = info.address if isinstance(info, Profile) else info
  174. if address.number and address.number != self.number:
  175. await self.handle_number_receive(address.number)
  176. update = True
  177. self.log.debug("Updating info with %s (source: %s)", info, source.mxid)
  178. async with self._update_info_lock:
  179. if isinstance(info, Profile) or self.name is None:
  180. update = await self._update_name(info) or update
  181. if isinstance(info, Profile):
  182. update = await self._update_avatar(info.avatar) or update
  183. elif self.config["bridge.contact_list_names"] != "disallow" and self.number:
  184. # Try to use a contact list avatar
  185. update = await self._update_avatar(f"contact-{self.number}") or update
  186. if update:
  187. await self.update()
  188. background_task.create(self._try_update_portal_meta())
  189. @staticmethod
  190. def fmt_phone(number: str) -> str:
  191. if phonenumbers is None:
  192. return number
  193. parsed = phonenumbers.parse(number)
  194. fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
  195. return phonenumbers.format_number(parsed, fmt)
  196. @classmethod
  197. def _get_displayname(cls, info: Profile | Address) -> tuple[str, int]:
  198. quality = 10
  199. if isinstance(info, Profile):
  200. address = info.address
  201. name = None
  202. contact_names = cls.config["bridge.contact_list_names"]
  203. if info.profile_name:
  204. name = info.profile_name
  205. quality = 90 if contact_names == "prefer" else 100
  206. if info.contact_name:
  207. if contact_names == "prefer":
  208. quality = 100
  209. name = info.contact_name
  210. elif contact_names == "allow" and not name:
  211. quality = 50
  212. name = info.contact_name
  213. names = name.split("\x00") if name else []
  214. else:
  215. address = info
  216. names = []
  217. data = {
  218. "first_name": names[0] if len(names) > 0 else "",
  219. "last_name": names[-1] if len(names) > 1 else "",
  220. "full_name": " ".join(names),
  221. "phone": cls.fmt_phone(address.number) if address.number else None,
  222. "uuid": str(address.uuid) if address.uuid else None,
  223. "displayname": "Unknown user",
  224. }
  225. for pref in cls.config["bridge.displayname_preference"]:
  226. value = data.get(pref.replace(" ", "_"))
  227. if value:
  228. data["displayname"] = value
  229. break
  230. return cls.config["bridge.displayname_template"].format(**data), quality
  231. async def _update_name(self, info: Profile | Address) -> bool:
  232. name, quality = self._get_displayname(info)
  233. if quality >= self.name_quality and (name != self.name or not self.name_set):
  234. self.log.debug(
  235. "Updating name from '%s' to '%s' (quality: %d)", self.name, name, quality
  236. )
  237. self.name = name
  238. self.name_quality = quality
  239. try:
  240. await self.default_mxid_intent.set_displayname(self.name)
  241. self.name_set = True
  242. except Exception:
  243. self.log.exception("Error setting displayname")
  244. self.name_set = False
  245. return True
  246. elif name != self.name or not self.name_set:
  247. self.log.debug(
  248. "Not updating name from '%s' to '%s', new quality (%d) is lower than old (%d)",
  249. self.name,
  250. name,
  251. quality,
  252. self.name_quality,
  253. )
  254. elif self.name_quality == 0:
  255. # Name matches, but quality is not stored in database - store it now
  256. self.name_quality = quality
  257. return True
  258. return False
  259. @staticmethod
  260. async def upload_avatar(
  261. self: Puppet | p.Portal, path: str, intent: IntentAPI
  262. ) -> bool | tuple[str, ContentURI]:
  263. if not path:
  264. return False
  265. if not path.startswith("/"):
  266. path = os.path.join(self.config["signal.avatar_dir"], path)
  267. try:
  268. with open(path, "rb") as file:
  269. data = file.read()
  270. except FileNotFoundError:
  271. return False
  272. if not data:
  273. return False
  274. new_hash = hashlib.sha256(data).hexdigest()
  275. if self.avatar_set and new_hash == self.avatar_hash:
  276. return False
  277. mxc = await intent.upload_media(data, async_upload=self.config["homeserver.async_media"])
  278. return new_hash, mxc
  279. async def _update_avatar(self, path: str) -> bool:
  280. res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
  281. if res is False:
  282. return False
  283. self.avatar_hash, self.avatar_url = res
  284. try:
  285. await self.default_mxid_intent.set_avatar_url(self.avatar_url)
  286. self.avatar_set = True
  287. except Exception:
  288. self.log.exception("Error setting avatar")
  289. self.avatar_set = False
  290. return True
  291. async def _try_update_portal_meta(self) -> None:
  292. try:
  293. await self._update_portal_meta()
  294. except Exception:
  295. self.log.exception("Error updating portal meta")
  296. async def _update_portal_meta(self) -> None:
  297. async for portal in p.Portal.find_private_chats_with(self.uuid):
  298. if portal.receiver == self.number:
  299. # This is a note to self chat, don't change the name
  300. continue
  301. try:
  302. await portal.update_puppet_name(self.name)
  303. await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
  304. if self.number:
  305. await portal.update_puppet_number(self.fmt_phone(self.number))
  306. except Exception:
  307. self.log.exception(f"Error updating portal meta for {portal.receiver}")
  308. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  309. portal: p.Portal = await p.Portal.get_by_mxid(room_id)
  310. # Leave all portals except the notes to self room
  311. return not (portal and portal.is_direct and portal.chat_id == self.uuid)
  312. # region Database getters
  313. def _add_number_to_cache(self) -> None:
  314. if self.number:
  315. try:
  316. existing = self.by_number[self.number]
  317. if existing and existing.uuid != self.uuid and existing != self:
  318. existing.number = None
  319. except KeyError:
  320. pass
  321. self.by_number[self.number] = self
  322. def _add_to_cache(self) -> None:
  323. self.by_uuid[self.uuid] = self
  324. self._add_number_to_cache()
  325. if self.custom_mxid:
  326. self.by_custom_mxid[self.custom_mxid] = self
  327. async def save(self) -> None:
  328. await self.update()
  329. @classmethod
  330. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
  331. uuid = cls.get_id_from_mxid(mxid)
  332. if not uuid:
  333. return None
  334. return await cls.get_by_uuid(uuid, create=create)
  335. @classmethod
  336. @async_getter_lock
  337. async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
  338. try:
  339. return cls.by_custom_mxid[mxid]
  340. except KeyError:
  341. pass
  342. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  343. if puppet:
  344. puppet._add_to_cache()
  345. return puppet
  346. return None
  347. @classmethod
  348. def get_id_from_mxid(cls, mxid: UserID) -> UUID | None:
  349. identifier = cls.mxid_template.parse(mxid)
  350. if not identifier:
  351. return None
  352. try:
  353. return UUID(identifier.upper())
  354. except ValueError:
  355. return None
  356. @classmethod
  357. def get_mxid_from_id(cls, uuid: UUID) -> UserID:
  358. return UserID(cls.mxid_template.format_full(str(uuid).lower()))
  359. @classmethod
  360. @async_getter_lock
  361. async def get_by_number(
  362. cls, number: str, /, *, resolve_via: str | None = None, raise_resolve: bool = False
  363. ) -> Puppet | None:
  364. try:
  365. return cls.by_number[number]
  366. except KeyError:
  367. pass
  368. puppet = cast(cls, await super().get_by_number(number))
  369. if puppet is not None:
  370. puppet._add_to_cache()
  371. return puppet
  372. if resolve_via:
  373. cls.log.debug(
  374. f"Couldn't find puppet with number {number}, resolving UUID via {resolve_via}"
  375. )
  376. try:
  377. uuid = await cls.signal.find_uuid(resolve_via, number)
  378. except UnregisteredUserError:
  379. if raise_resolve:
  380. raise
  381. cls.log.debug(f"Resolving {number} via {resolve_via} threw UnregisteredUserError")
  382. return None
  383. except Exception:
  384. if raise_resolve:
  385. raise
  386. cls.log.exception(f"Failed to resolve {number} via {resolve_via}")
  387. return None
  388. if uuid:
  389. cls.log.debug(f"Found {uuid} for {number} after resolving via {resolve_via}")
  390. return await cls.get_by_uuid(uuid, number=number)
  391. else:
  392. cls.log.debug(f"Didn't find UUID for {number} via {resolve_via}")
  393. return None
  394. @classmethod
  395. async def get_by_address(
  396. cls,
  397. address: Address,
  398. create: bool = True,
  399. resolve_via: str | None = None,
  400. raise_resolve: bool = False,
  401. ) -> Puppet | None:
  402. if not address.uuid:
  403. return await cls.get_by_number(
  404. address.number, resolve_via=resolve_via, raise_resolve=raise_resolve
  405. )
  406. else:
  407. return await cls.get_by_uuid(address.uuid, create=create, number=address.number)
  408. @classmethod
  409. @async_getter_lock
  410. async def get_by_uuid(
  411. cls, uuid: UUID, /, *, create: bool = True, number: str | None = None
  412. ) -> Puppet | None:
  413. try:
  414. return cls.by_uuid[uuid]
  415. except KeyError:
  416. pass
  417. puppet = cast(cls, await super().get_by_uuid(uuid))
  418. if puppet is not None:
  419. puppet._add_to_cache()
  420. return puppet
  421. if create:
  422. puppet = cls(uuid, number)
  423. await puppet.insert()
  424. puppet._add_to_cache()
  425. return puppet
  426. return None
  427. @classmethod
  428. async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
  429. puppets = await super().all_with_custom_mxid()
  430. puppet: cls
  431. for index, puppet in enumerate(puppets):
  432. try:
  433. yield cls.by_uuid[puppet.uuid]
  434. except KeyError:
  435. puppet._add_to_cache()
  436. yield puppet
  437. # endregion