puppet.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
  18. from uuid import UUID
  19. import asyncio
  20. import hashlib
  21. import os.path
  22. from yarl import URL
  23. from mausignald.errors import UnregisteredUserError
  24. from mausignald.types import Address, Profile
  25. from mautrix.appservice import IntentAPI
  26. from mautrix.bridge import BasePuppet, async_getter_lock
  27. from mautrix.errors import MForbidden
  28. from mautrix.types import (
  29. ContentURI,
  30. EventType,
  31. PowerLevelStateEventContent,
  32. RoomID,
  33. SyncToken,
  34. UserID,
  35. )
  36. from mautrix.util import background_task
  37. from mautrix.util.simple_template import SimpleTemplate
  38. from . import portal as p, signal, user as u
  39. from .config import Config
  40. from .db import Puppet as DBPuppet
  41. if TYPE_CHECKING:
  42. from .__main__ import SignalBridge
  43. try:
  44. import phonenumbers
  45. except ImportError:
  46. phonenumbers = None
  class Puppet(DBPuppet, BasePuppet):
      """Matrix ghost user representing one Signal user, identified by UUID.

      Loaded/created puppets are kept in the class-level ``by_uuid``, ``by_number``
      and ``by_custom_mxid`` caches so that repeated lookups return the same object.
      """

      # Class-level lookup caches, populated by _add_to_cache() / _add_number_to_cache().
      by_uuid: dict[UUID, Puppet] = {}
      by_number: dict[str, Puppet] = {}
      by_custom_mxid: dict[UserID, Puppet] = {}

      # Class-level configuration, set in init_cls().
      hs_domain: str
      mxid_template: SimpleTemplate[str]
      config: Config
      signal: signal.SignalHandler

      # The bridge-managed ghost mxid for this puppet and an intent acting as it.
      default_mxid_intent: IntentAPI
      default_mxid: UserID

      # Per-instance locks: number changes and profile updates are serialized separately.
      _uuid_lock: asyncio.Lock
      _update_info_lock: asyncio.Lock

      def __init__(
          self,
          uuid: UUID,
          number: str | None,
          name: str | None = None,
          name_quality: int = 0,
          avatar_url: ContentURI | None = None,
          avatar_hash: str | None = None,
          name_set: bool = False,
          avatar_set: bool = False,
          is_registered: bool = False,
          custom_mxid: UserID | None = None,
          access_token: str | None = None,
          next_batch: SyncToken | None = None,
          base_url: URL | None = None,
      ) -> None:
          """Create a puppet for the Signal user ``uuid``.

          ``uuid`` must be a non-empty :class:`UUID` (checked with ``assert``). The
          remaining arguments are passed through to the database model unchanged.
          """
          assert uuid, "UUID must be set for ghosts"
          assert isinstance(uuid, UUID)
          super().__init__(
              uuid=uuid,
              number=number,
              name=name,
              name_quality=name_quality,
              avatar_url=avatar_url,
              avatar_hash=avatar_hash,
              name_set=name_set,
              avatar_set=avatar_set,
              is_registered=is_registered,
              custom_mxid=custom_mxid,
              access_token=access_token,
              next_batch=next_batch,
              base_url=base_url,
          )
          self.log = self.log.getChild(str(uuid) if uuid else number)
          self.default_mxid = self.get_mxid_from_id(self.uuid)
          self.default_mxid_intent = self.az.intent.user(self.default_mxid)
          self.intent = self._fresh_intent()
          self._uuid_lock = asyncio.Lock()
          self._update_info_lock = asyncio.Lock()

      @classmethod
      def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
          """Copy bridge-wide settings onto the class.

          Returns an async iterable that yields ``try_start()`` awaitables for
          every puppet that has a custom (double-puppeting) mxid.
          """
          cls.config = bridge.config
          cls.loop = bridge.loop
          cls.signal = bridge.signal
          cls.mx = bridge.matrix
          cls.az = bridge.az
          cls.hs_domain = cls.config["homeserver.domain"]
          cls.mxid_template = SimpleTemplate(
              cls.config["bridge.username_template"],
              "userid",
              prefix="@",
              suffix=f":{cls.hs_domain}",
              type=str,
          )
          cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
          cls.homeserver_url_map = {
              server: URL(url)
              for server, url in cls.config["bridge.double_puppet_server_map"].items()
          }
          cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
          cls.login_shared_secret_map = {
              server: secret.encode("utf-8")
              for server, secret in cls.config["bridge.login_shared_secret_map"].items()
          }
          cls.login_device_name = "Signal Bridge"
          return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())

      def intent_for(self, portal: p.Portal) -> IntentAPI:
          """Return the intent this puppet should use when acting in ``portal``.

          In the portal whose chat_id is this puppet's own UUID, the default ghost
          intent is used. Everywhere else, ``self.intent`` is used.
          """
          if portal.chat_id == self.uuid:
              return self.default_mxid_intent
          return self.intent

      @property
      def address(self) -> Address:
          """Signal address (UUID plus phone number, if known) of this puppet."""
          return Address(uuid=self.uuid, number=self.number)

      async def handle_number_receive(self, number: str) -> None:
          """Record ``number`` as this puppet's phone number.

          Does nothing if the number is unchanged. Otherwise it moves the
          ``by_number`` cache entry and persists the change via ``_update_number()``.
          """
          async with self._uuid_lock:
              if self.number == number:
                  return
              if self.number:
                  self.by_number.pop(self.number, None)
              self.number = number
              self._add_number_to_cache()
              await self._update_number()

      async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
          """Move every room membership of ``prev_intent``'s user to ``new_intent``'s user.

          For each room the old user has joined, the new user is invited, the old
          user's power level is copied to the new user, the old user leaves, and the
          new user joins. If the joined-room list is forbidden, nothing is migrated.
          """
          self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
          try:
              joined_rooms = await prev_intent.get_joined_rooms()
          except MForbidden as e:
              self.log.debug(
                  f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
                  "assuming there are no rooms to rejoin"
              )
              return
          for room_id in joined_rooms:
              # Invite the user that actually joins below (and receives the migrated
              # power level); this need not be the default ghost mxid.
              await prev_intent.invite_user(room_id, new_intent.mxid)
              await self._migrate_powers(prev_intent, new_intent, room_id)
              await prev_intent.leave_room(room_id)
              await new_intent.join_room_by_id(room_id)

      async def _migrate_powers(
          self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
      ) -> None:
          """Give ``new_intent``'s user the power level ``prev_intent``'s user had.

          The copy happens only when the old user may edit power levels and that
          level is above ``users_default``. Any failure is logged and swallowed
          (best effort).
          """
          try:
              powers: PowerLevelStateEventContent
              powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
              user_level = powers.get_user_level(prev_intent.mxid)
              pl_state_level = powers.get_event_level(EventType.ROOM_POWER_LEVELS)
              if user_level >= pl_state_level > powers.users_default:
                  powers.ensure_user_level(new_intent.mxid, user_level)
                  await prev_intent.send_state_event(room_id, EventType.ROOM_POWER_LEVELS, powers)
          except Exception:
              self.log.warning("Failed to migrate power levels", exc_info=True)

      async def update_info(self, info: Profile | Address, source: u.User) -> None:
          """Sync phone number, displayname and avatar from ``info``.

          ``source`` is the bridge user the info came from and is used only for
          logging. If anything changed, the puppet is saved and private-chat portal
          metadata is refreshed in a background task.
          """
          update = False
          address = info.address if isinstance(info, Profile) else info
          if address.number and address.number != self.number:
              await self.handle_number_receive(address.number)
              update = True
          self.log.debug("Updating info with %s (source: %s)", info, source.mxid)
          async with self._update_info_lock:
              if isinstance(info, Profile) or self.name is None:
                  update = await self._update_name(info) or update
              if isinstance(info, Profile):
                  update = await self._update_avatar(info.avatar) or update
              elif self.config["bridge.contact_list_names"] != "disallow" and self.number:
                  # Try to use a contact list avatar
                  update = await self._update_avatar(f"contact-{self.number}") or update
              if update:
                  await self.update()
                  background_task.create(self._try_update_portal_meta())

      @staticmethod
      def fmt_phone(number: str) -> str:
          """Format ``number`` in international notation.

          Returns ``number`` unchanged when the optional ``phonenumbers`` library is
          unavailable or cannot parse it.
          """
          if phonenumbers is None:
              return number
          try:
              parsed = phonenumbers.parse(number)
          except phonenumbers.NumberParseException:
              # parse() is called without a default region, so input it cannot
              # interpret must not abort displayname/portal updates.
              return number
          fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
          return phonenumbers.format_number(parsed, fmt)

      @classmethod
      def _get_displayname(cls, info: Profile | Address) -> tuple[str, int]:
          """Build the ghost displayname for ``info`` and return ``(name, quality)``.

          Quality values:
          - 100: a profile name, or a contact name when ``bridge.contact_list_names``
            is ``"prefer"``.
          - 90: a profile name while contact names are preferred.
          - 50: a contact name used as a fallback under ``"allow"``.
          - 10: no name is available, including when ``info`` is a bare Address.

          The final string comes from ``bridge.displayname_template``. Its
          ``displayname`` field is the first non-empty entry listed in
          ``bridge.displayname_preference``.
          """
          quality = 10
          if isinstance(info, Profile):
              address = info.address
              name = None
              contact_names = cls.config["bridge.contact_list_names"]
              if info.profile_name:
                  name = info.profile_name
                  quality = 90 if contact_names == "prefer" else 100
              if info.contact_name:
                  if contact_names == "prefer":
                      quality = 100
                      name = info.contact_name
                  elif contact_names == "allow" and not name:
                      quality = 50
                      name = info.contact_name
              # The raw name may contain NUL-separated parts; first/last come from those.
              names = name.split("\x00") if name else []
          else:
              address = info
              names = []
          data = {
              "first_name": names[0] if len(names) > 0 else "",
              "last_name": names[-1] if len(names) > 1 else "",
              "full_name": " ".join(names),
              "phone": cls.fmt_phone(address.number) if address.number else None,
              "uuid": str(address.uuid) if address.uuid else None,
              "displayname": "Unknown user",
          }
          for pref in cls.config["bridge.displayname_preference"]:
              value = data.get(pref.replace(" ", "_"))
              if value:
                  data["displayname"] = value
                  break
          return cls.config["bridge.displayname_template"].format(**data), quality

      async def _update_name(self, info: Profile | Address) -> bool:
          """Set the ghost displayname from ``info`` if its quality is high enough.

          Returns True when puppet state changed and needs saving.
          """
          name, quality = self._get_displayname(info)
          if quality >= self.name_quality and (name != self.name or not self.name_set):
              self.log.debug(
                  "Updating name from '%s' to '%s' (quality: %d)", self.name, name, quality
              )
              self.name = name
              self.name_quality = quality
              try:
                  await self.default_mxid_intent.set_displayname(self.name)
                  self.name_set = True
              except Exception:
                  self.log.exception("Error setting displayname")
                  self.name_set = False
              return True
          elif name != self.name or not self.name_set:
              self.log.debug(
                  "Not updating name from '%s' to '%s', new quality (%d) is lower than old (%d)",
                  self.name,
                  name,
                  quality,
                  self.name_quality,
              )
          elif self.name_quality == 0:
              # Name matches, but quality is not stored in database - store it now
              self.name_quality = quality
              return True
          return False

      @staticmethod
      async def upload_avatar(
          self: Puppet | p.Portal, path: str, intent: IntentAPI
      ) -> bool | tuple[str, ContentURI]:
          """Read the avatar file at ``path`` and upload it through ``intent``.

          This is a staticmethod with an explicit ``self`` so that both puppets and
          portals can use it. Relative paths are resolved against
          ``signal.avatar_dir``.

          Returns False when there is nothing to upload: an empty path, a missing or
          empty file, or a hash equal to the already-set avatar. Otherwise it returns
          ``(sha256 hex digest, uploaded mxc URI)``.
          """
          if not path:
              return False
          if not os.path.isabs(path):
              path = os.path.join(self.config["signal.avatar_dir"], path)
          # NOTE(review): this is a synchronous read inside a coroutine.
          try:
              with open(path, "rb") as file:
                  data = file.read()
          except FileNotFoundError:
              return False
          if not data:
              return False
          new_hash = hashlib.sha256(data).hexdigest()
          if self.avatar_set and new_hash == self.avatar_hash:
              return False
          mxc = await intent.upload_media(data, async_upload=self.config["homeserver.async_media"])
          return new_hash, mxc

      async def _update_avatar(self, path: str) -> bool:
          """Upload the avatar at ``path`` and set it on the default ghost.

          Returns True when the stored avatar hash/URL changed.
          """
          res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
          if res is False:
              return False
          self.avatar_hash, self.avatar_url = res
          try:
              await self.default_mxid_intent.set_avatar_url(self.avatar_url)
              self.avatar_set = True
          except Exception:
              self.log.exception("Error setting avatar")
              self.avatar_set = False
          return True

      async def _try_update_portal_meta(self) -> None:
          """Run _update_portal_meta(), logging instead of raising on failure."""
          try:
              await self._update_portal_meta()
          except Exception:
              self.log.exception("Error updating portal meta")

      async def _update_portal_meta(self) -> None:
          """Push this puppet's name, avatar and number to its private-chat portals."""
          async for portal in p.Portal.find_private_chats_with(self.uuid):
              if portal.receiver == self.number:
                  # This is a note to self chat, don't change the name
                  continue
              try:
                  await portal.update_puppet_name(self.name)
                  await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
                  if self.number:
                      await portal.update_puppet_number(self.fmt_phone(self.number))
              except Exception:
                  self.log.exception(f"Error updating portal meta for {portal.receiver}")

      async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
          """Return True unless ``room_id`` is this user's note-to-self direct portal."""
          portal: p.Portal = await p.Portal.get_by_mxid(room_id)
          # Leave all portals except the notes to self room
          return not (portal and portal.is_direct and portal.chat_id == self.uuid)

      # region Database getters

      def _add_number_to_cache(self) -> None:
          """Register this puppet under its number.

          If another puppet (with a different UUID) was cached under the same number,
          that puppet's in-memory ``number`` is cleared.
          """
          if not self.number:
              return
          existing = self.by_number.get(self.number)
          if existing and existing.uuid != self.uuid and existing != self:
              existing.number = None
          self.by_number[self.number] = self

      def _add_to_cache(self) -> None:
          """Register this puppet in every applicable class-level cache."""
          self.by_uuid[self.uuid] = self
          self._add_number_to_cache()
          if self.custom_mxid:
              self.by_custom_mxid[self.custom_mxid] = self

      async def save(self) -> None:
          """Persist the puppet (alias for ``update()``)."""
          await self.update()

      @classmethod
      async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
          """Look up the puppet for a ghost mxid.

          Returns None if ``mxid`` is not a ghost mxid.
          """
          uuid = cls.get_id_from_mxid(mxid)
          if not uuid:
              return None
          return await cls.get_by_uuid(uuid, create=create)

      @classmethod
      @async_getter_lock
      async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
          """Find the puppet using ``mxid`` as its double-puppet account.

          Looks in the cache first, then the database.
          """
          try:
              return cls.by_custom_mxid[mxid]
          except KeyError:
              pass
          puppet = cast(cls, await super().get_by_custom_mxid(mxid))
          if puppet:
              puppet._add_to_cache()
              return puppet
          return None

      @classmethod
      def get_id_from_mxid(cls, mxid: UserID) -> UUID | None:
          """Extract the Signal UUID from a ghost mxid.

          Returns None if the mxid does not match the template or the localpart
          is not a valid UUID.
          """
          identifier = cls.mxid_template.parse(mxid)
          if not identifier:
              return None
          try:
              return UUID(identifier.upper())
          except ValueError:
              return None

      @classmethod
      def get_mxid_from_id(cls, uuid: UUID) -> UserID:
          """Build the ghost mxid for ``uuid``, using the lowercase UUID as the localpart."""
          return UserID(cls.mxid_template.format_full(str(uuid).lower()))

      @classmethod
      @async_getter_lock
      async def get_by_number(
          cls, number: str, /, *, resolve_via: str | None = None, raise_resolve: bool = False
      ) -> Puppet | None:
          """Find a puppet by phone number.

          The lookup order is the cache, then the database, then (if
          ``resolve_via`` is given) a UUID lookup through signald as that account.
          Resolution errors are re-raised when ``raise_resolve`` is True; otherwise
          they are logged and None is returned.
          """
          try:
              return cls.by_number[number]
          except KeyError:
              pass
          puppet = cast(cls, await super().get_by_number(number))
          if puppet is not None:
              puppet._add_to_cache()
              return puppet
          if resolve_via:
              cls.log.debug(
                  f"Couldn't find puppet with number {number}, resolving UUID via {resolve_via}"
              )
              try:
                  uuid = await cls.signal.find_uuid(resolve_via, number)
              except UnregisteredUserError:
                  if raise_resolve:
                      raise
                  cls.log.debug(f"Resolving {number} via {resolve_via} threw UnregisteredUserError")
                  return None
              except Exception:
                  if raise_resolve:
                      raise
                  cls.log.exception(f"Failed to resolve {number} via {resolve_via}")
                  return None
              if uuid:
                  cls.log.debug(f"Found {uuid} for {number} after resolving via {resolve_via}")
                  return await cls.get_by_uuid(uuid, number=number)
              else:
                  cls.log.debug(f"Didn't find UUID for {number} via {resolve_via}")
          return None

      @classmethod
      async def get_by_address(
          cls,
          address: Address,
          create: bool = True,
          resolve_via: str | None = None,
          raise_resolve: bool = False,
      ) -> Puppet | None:
          """Find a puppet by Signal address.

          Uses the UUID when the address has one, otherwise the number. ``create``
          applies only to the UUID path. ``resolve_via`` and ``raise_resolve`` apply
          only to the number path.
          """
          if not address.uuid:
              return await cls.get_by_number(
                  address.number, resolve_via=resolve_via, raise_resolve=raise_resolve
              )
          else:
              return await cls.get_by_uuid(address.uuid, create=create, number=address.number)

      @classmethod
      @async_getter_lock
      async def get_by_uuid(
          cls, uuid: UUID, /, *, create: bool = True, number: str | None = None
      ) -> Puppet | None:
          """Find a puppet by UUID.

          Looks in the cache, then the database. If no puppet is found and
          ``create`` is True, a new puppet (with ``number``) is inserted. Otherwise
          None is returned.
          """
          try:
              return cls.by_uuid[uuid]
          except KeyError:
              pass
          puppet = cast(cls, await super().get_by_uuid(uuid))
          if puppet is not None:
              puppet._add_to_cache()
              return puppet
          if create:
              puppet = cls(uuid, number)
              await puppet.insert()
              puppet._add_to_cache()
              return puppet
          return None

      @classmethod
      async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
          """Yield every puppet that has a custom mxid.

          Already-cached instances take precedence over freshly loaded rows.
          """
          puppets = await super().all_with_custom_mxid()
          puppet: cls
          for puppet in puppets:
              try:
                  yield cls.by_uuid[puppet.uuid]
              except KeyError:
                  puppet._add_to_cache()
                  yield puppet

      # endregion