puppet.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
  18. from uuid import UUID
  19. import asyncio
  20. import hashlib
  21. import os.path
  22. from yarl import URL
  23. from mausignald.errors import UnregisteredUserError
  24. from mausignald.types import Address, Profile
  25. from mautrix.appservice import IntentAPI
  26. from mautrix.bridge import BasePuppet, async_getter_lock
  27. from mautrix.errors import MForbidden
  28. from mautrix.types import (
  29. ContentURI,
  30. EventType,
  31. PowerLevelStateEventContent,
  32. RoomID,
  33. SyncToken,
  34. UserID,
  35. )
  36. from mautrix.util.simple_template import SimpleTemplate
  37. from . import portal as p, signal, user as u
  38. from .config import Config
  39. from .db import Puppet as DBPuppet
  40. if TYPE_CHECKING:
  41. from .__main__ import SignalBridge
  42. try:
  43. import phonenumbers
  44. except ImportError:
  45. phonenumbers = None
  46. class Puppet(DBPuppet, BasePuppet):
  47. by_uuid: dict[UUID, Puppet] = {}
  48. by_number: dict[str, Puppet] = {}
  49. by_custom_mxid: dict[UserID, Puppet] = {}
  50. hs_domain: str
  51. mxid_template: SimpleTemplate[str]
  52. config: Config
  53. signal: signal.SignalHandler
  54. default_mxid_intent: IntentAPI
  55. default_mxid: UserID
  56. _uuid_lock: asyncio.Lock
  57. _update_info_lock: asyncio.Lock
  58. def __init__(
  59. self,
  60. uuid: UUID,
  61. number: str | None,
  62. name: str | None = None,
  63. name_quality: int = 0,
  64. avatar_url: ContentURI | None = None,
  65. avatar_hash: str | None = None,
  66. name_set: bool = False,
  67. avatar_set: bool = False,
  68. is_registered: bool = False,
  69. custom_mxid: UserID | None = None,
  70. access_token: str | None = None,
  71. next_batch: SyncToken | None = None,
  72. base_url: URL | None = None,
  73. ) -> None:
  74. assert uuid, "UUID must be set for ghosts"
  75. assert isinstance(uuid, UUID)
  76. super().__init__(
  77. uuid=uuid,
  78. number=number,
  79. name=name,
  80. name_quality=name_quality,
  81. avatar_url=avatar_url,
  82. avatar_hash=avatar_hash,
  83. name_set=name_set,
  84. avatar_set=avatar_set,
  85. is_registered=is_registered,
  86. custom_mxid=custom_mxid,
  87. access_token=access_token,
  88. next_batch=next_batch,
  89. base_url=base_url,
  90. )
  91. self.log = self.log.getChild(str(uuid) if uuid else number)
  92. self.default_mxid = self.get_mxid_from_id(self.uuid)
  93. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  94. self.intent = self._fresh_intent()
  95. self._uuid_lock = asyncio.Lock()
  96. self._update_info_lock = asyncio.Lock()
  97. @classmethod
  98. def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
  99. cls.config = bridge.config
  100. cls.loop = bridge.loop
  101. cls.signal = bridge.signal
  102. cls.mx = bridge.matrix
  103. cls.az = bridge.az
  104. cls.hs_domain = cls.config["homeserver.domain"]
  105. cls.mxid_template = SimpleTemplate(
  106. cls.config["bridge.username_template"],
  107. "userid",
  108. prefix="@",
  109. suffix=f":{cls.hs_domain}",
  110. type=str,
  111. )
  112. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  113. cls.homeserver_url_map = {
  114. server: URL(url)
  115. for server, url in cls.config["bridge.double_puppet_server_map"].items()
  116. }
  117. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  118. cls.login_shared_secret_map = {
  119. server: secret.encode("utf-8")
  120. for server, secret in cls.config["bridge.login_shared_secret_map"].items()
  121. }
  122. cls.login_device_name = "Signal Bridge"
  123. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  124. def intent_for(self, portal: p.Portal) -> IntentAPI:
  125. if portal.chat_id == self.uuid:
  126. return self.default_mxid_intent
  127. return self.intent
  128. @property
  129. def address(self) -> Address:
  130. return Address(uuid=self.uuid, number=self.number)
  131. async def handle_number_receive(self, number: str) -> None:
  132. async with self._uuid_lock:
  133. if self.number == number:
  134. return
  135. if self.number:
  136. self.by_number.pop(self.number, None)
  137. self.number = number
  138. self._add_number_to_cache()
  139. await self._update_number()
  140. async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
  141. self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
  142. try:
  143. joined_rooms = await prev_intent.get_joined_rooms()
  144. except MForbidden as e:
  145. self.log.debug(
  146. f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
  147. "assuming there are no rooms to rejoin"
  148. )
  149. return
  150. for room_id in joined_rooms:
  151. await prev_intent.invite_user(room_id, self.default_mxid)
  152. await self._migrate_powers(prev_intent, new_intent, room_id)
  153. await prev_intent.leave_room(room_id)
  154. await new_intent.join_room_by_id(room_id)
  155. async def _migrate_powers(
  156. self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
  157. ) -> None:
  158. try:
  159. powers: PowerLevelStateEventContent
  160. powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
  161. user_level = powers.get_user_level(prev_intent.mxid)
  162. pl_state_level = powers.get_event_level(EventType.ROOM_POWER_LEVELS)
  163. if user_level >= pl_state_level > powers.users_default:
  164. powers.ensure_user_level(new_intent.mxid, user_level)
  165. await prev_intent.send_state_event(room_id, EventType.ROOM_POWER_LEVELS, powers)
  166. except Exception:
  167. self.log.warning("Failed to migrate power levels", exc_info=True)
  168. async def update_info(self, info: Profile | Address, source: u.User) -> None:
  169. update = False
  170. address = info.address if isinstance(info, Profile) else info
  171. if address.number and address.number != self.number:
  172. await self.handle_number_receive(address.number)
  173. update = True
  174. self.log.debug("Updating info with %s (source: %s)", info, source.mxid)
  175. async with self._update_info_lock:
  176. if isinstance(info, Profile) or self.name is None:
  177. update = await self._update_name(info) or update
  178. if isinstance(info, Profile):
  179. update = await self._update_avatar(info.avatar) or update
  180. elif self.config["bridge.contact_list_names"] != "disallow" and self.number:
  181. # Try to use a contact list avatar
  182. update = await self._update_avatar(f"contact-{self.number}") or update
  183. if update:
  184. await self.update()
  185. asyncio.create_task(self._update_portal_meta())
  186. @staticmethod
  187. def fmt_phone(number: str) -> str:
  188. if phonenumbers is None:
  189. return number
  190. parsed = phonenumbers.parse(number)
  191. fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
  192. return phonenumbers.format_number(parsed, fmt)
  193. @classmethod
  194. def _get_displayname(cls, info: Profile | Address) -> tuple[str, int]:
  195. quality = 10
  196. if isinstance(info, Profile):
  197. address = info.address
  198. name = None
  199. contact_names = cls.config["bridge.contact_list_names"]
  200. if info.profile_name:
  201. name = info.profile_name
  202. quality = 90 if contact_names == "prefer" else 100
  203. if info.contact_name:
  204. if contact_names == "prefer":
  205. quality = 100
  206. name = info.contact_name
  207. elif contact_names == "allow" and not name:
  208. quality = 50
  209. name = info.contact_name
  210. names = name.split("\x00") if name else []
  211. else:
  212. address = info
  213. names = []
  214. data = {
  215. "first_name": names[0] if len(names) > 0 else "",
  216. "last_name": names[-1] if len(names) > 1 else "",
  217. "full_name": " ".join(names),
  218. "phone": cls.fmt_phone(address.number) if address.number else None,
  219. "uuid": str(address.uuid) if address.uuid else None,
  220. "displayname": "Unknown user",
  221. }
  222. for pref in cls.config["bridge.displayname_preference"]:
  223. value = data.get(pref.replace(" ", "_"))
  224. if value:
  225. data["displayname"] = value
  226. break
  227. return cls.config["bridge.displayname_template"].format(**data), quality
  228. async def _update_name(self, info: Profile | Address) -> bool:
  229. name, quality = self._get_displayname(info)
  230. if quality >= self.name_quality and (name != self.name or not self.name_set):
  231. self.log.debug(
  232. "Updating name from '%s' to '%s' (quality: %d)", self.name, name, quality
  233. )
  234. self.name = name
  235. self.name_quality = quality
  236. try:
  237. await self.default_mxid_intent.set_displayname(self.name)
  238. self.name_set = True
  239. except Exception:
  240. self.log.exception("Error setting displayname")
  241. self.name_set = False
  242. return True
  243. elif name != self.name or not self.name_set:
  244. self.log.debug(
  245. "Not updating name from '%s' to '%s', new quality (%d) is lower than old (%d)",
  246. self.name,
  247. name,
  248. quality,
  249. self.name_quality,
  250. )
  251. elif self.name_quality == 0:
  252. # Name matches, but quality is not stored in database - store it now
  253. self.name_quality = quality
  254. return True
  255. return False
  256. @staticmethod
  257. async def upload_avatar(
  258. self: Puppet | p.Portal, path: str, intent: IntentAPI
  259. ) -> bool | tuple[str, ContentURI]:
  260. if not path:
  261. return False
  262. if not path.startswith("/"):
  263. path = os.path.join(self.config["signal.avatar_dir"], path)
  264. try:
  265. with open(path, "rb") as file:
  266. data = file.read()
  267. except FileNotFoundError:
  268. return False
  269. if not data:
  270. return False
  271. new_hash = hashlib.sha256(data).hexdigest()
  272. if self.avatar_set and new_hash == self.avatar_hash:
  273. return False
  274. mxc = await intent.upload_media(data, async_upload=self.config["homeserver.async_media"])
  275. return new_hash, mxc
  276. async def _update_avatar(self, path: str) -> bool:
  277. res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
  278. if res is False:
  279. return False
  280. self.avatar_hash, self.avatar_url = res
  281. try:
  282. await self.default_mxid_intent.set_avatar_url(self.avatar_url)
  283. self.avatar_set = True
  284. except Exception:
  285. self.log.exception("Error setting avatar")
  286. self.avatar_set = False
  287. return True
  288. async def _update_portal_meta(self) -> None:
  289. async for portal in p.Portal.find_private_chats_with(self.uuid):
  290. if portal.receiver == self.number:
  291. # This is a note to self chat, don't change the name
  292. continue
  293. try:
  294. await portal.update_puppet_name(self.name)
  295. await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
  296. if self.number:
  297. await portal.update_puppet_number(self.fmt_phone(self.number))
  298. except Exception:
  299. self.log.exception(f"Error updating portal meta for {portal.receiver}")
  300. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  301. portal: p.Portal = await p.Portal.get_by_mxid(room_id)
  302. # Leave all portals except the notes to self room
  303. return not (portal and portal.is_direct and portal.chat_id == self.uuid)
  304. # region Database getters
  305. def _add_number_to_cache(self) -> None:
  306. if self.number:
  307. try:
  308. existing = self.by_number[self.number]
  309. if existing and existing.uuid != self.uuid and existing != self:
  310. existing.number = None
  311. except KeyError:
  312. pass
  313. self.by_number[self.number] = self
  314. def _add_to_cache(self) -> None:
  315. self.by_uuid[self.uuid] = self
  316. self._add_number_to_cache()
  317. if self.custom_mxid:
  318. self.by_custom_mxid[self.custom_mxid] = self
  319. async def save(self) -> None:
  320. await self.update()
  321. @classmethod
  322. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
  323. uuid = cls.get_id_from_mxid(mxid)
  324. if not uuid:
  325. return None
  326. return await cls.get_by_uuid(uuid, create=create)
  327. @classmethod
  328. @async_getter_lock
  329. async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
  330. try:
  331. return cls.by_custom_mxid[mxid]
  332. except KeyError:
  333. pass
  334. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  335. if puppet:
  336. puppet._add_to_cache()
  337. return puppet
  338. return None
  339. @classmethod
  340. def get_id_from_mxid(cls, mxid: UserID) -> UUID | None:
  341. identifier = cls.mxid_template.parse(mxid)
  342. if not identifier:
  343. return None
  344. try:
  345. return UUID(identifier.upper())
  346. except ValueError:
  347. return None
  348. @classmethod
  349. def get_mxid_from_id(cls, uuid: UUID) -> UserID:
  350. return UserID(cls.mxid_template.format_full(str(uuid).lower()))
  351. @classmethod
  352. @async_getter_lock
  353. async def get_by_number(
  354. cls, number: str, /, *, resolve_via: str | None = None, raise_resolve: bool = False
  355. ) -> Puppet | None:
  356. try:
  357. return cls.by_number[number]
  358. except KeyError:
  359. pass
  360. puppet = cast(cls, await super().get_by_number(number))
  361. if puppet is not None:
  362. puppet._add_to_cache()
  363. return puppet
  364. if resolve_via:
  365. cls.log.debug(
  366. f"Couldn't find puppet with number {number}, resolving UUID via {resolve_via}"
  367. )
  368. try:
  369. uuid = await cls.signal.find_uuid(resolve_via, number)
  370. except UnregisteredUserError:
  371. if raise_resolve:
  372. raise
  373. cls.log.debug(f"Resolving {number} via {resolve_via} threw UnregisteredUserError")
  374. return None
  375. except Exception:
  376. if raise_resolve:
  377. raise
  378. cls.log.exception(f"Failed to resolve {number} via {resolve_via}")
  379. return None
  380. if uuid:
  381. cls.log.debug(f"Found {uuid} for {number} after resolving via {resolve_via}")
  382. return await cls.get_by_uuid(uuid, number=number)
  383. else:
  384. cls.log.debug(f"Didn't find UUID for {number} via {resolve_via}")
  385. return None
  386. @classmethod
  387. async def get_by_address(
  388. cls,
  389. address: Address,
  390. create: bool = True,
  391. resolve_via: str | None = None,
  392. raise_resolve: bool = False,
  393. ) -> Puppet | None:
  394. if not address.uuid:
  395. return await cls.get_by_number(
  396. address.number, resolve_via=resolve_via, raise_resolve=raise_resolve
  397. )
  398. else:
  399. return await cls.get_by_uuid(address.uuid, create=create, number=address.number)
  400. @classmethod
  401. @async_getter_lock
  402. async def get_by_uuid(
  403. cls, uuid: UUID, /, *, create: bool = True, number: str | None = None
  404. ) -> Puppet | None:
  405. try:
  406. return cls.by_uuid[uuid]
  407. except KeyError:
  408. pass
  409. puppet = cast(cls, await super().get_by_uuid(uuid))
  410. if puppet is not None:
  411. puppet._add_to_cache()
  412. return puppet
  413. if create:
  414. puppet = cls(uuid, number)
  415. await puppet.insert()
  416. puppet._add_to_cache()
  417. return puppet
  418. return None
  419. @classmethod
  420. async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
  421. puppets = await super().all_with_custom_mxid()
  422. puppet: cls
  423. for index, puppet in enumerate(puppets):
  424. try:
  425. yield cls.by_uuid[puppet.uuid]
  426. except KeyError:
  427. puppet._add_to_cache()
  428. yield puppet
  429. # endregion