puppet.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
  18. from uuid import UUID
  19. import asyncio
  20. import hashlib
  21. import os.path
  22. from yarl import URL
  23. from mausignald.errors import UnregisteredUserError
  24. from mausignald.types import Address, Profile
  25. from mautrix.appservice import IntentAPI
  26. from mautrix.bridge import BasePuppet, async_getter_lock
  27. from mautrix.errors import MForbidden
  28. from mautrix.types import (
  29. ContentURI,
  30. EventType,
  31. PowerLevelStateEventContent,
  32. RoomID,
  33. SyncToken,
  34. UserID,
  35. )
  36. from mautrix.util.simple_template import SimpleTemplate
  37. from . import portal as p, signal, user as u
  38. from .config import Config
  39. from .db import Puppet as DBPuppet
  40. if TYPE_CHECKING:
  41. from .__main__ import SignalBridge
  42. try:
  43. import phonenumbers
  44. except ImportError:
  45. phonenumbers = None
  46. class Puppet(DBPuppet, BasePuppet):
  47. by_uuid: dict[UUID, Puppet] = {}
  48. by_number: dict[str, Puppet] = {}
  49. by_custom_mxid: dict[UserID, Puppet] = {}
  50. hs_domain: str
  51. mxid_template: SimpleTemplate[str]
  52. config: Config
  53. signal: signal.SignalHandler
  54. default_mxid_intent: IntentAPI
  55. default_mxid: UserID
  56. _uuid_lock: asyncio.Lock
  57. _update_info_lock: asyncio.Lock
  58. def __init__(
  59. self,
  60. uuid: UUID,
  61. number: str | None,
  62. name: str | None = None,
  63. name_quality: int = 0,
  64. avatar_url: ContentURI | None = None,
  65. avatar_hash: str | None = None,
  66. name_set: bool = False,
  67. avatar_set: bool = False,
  68. is_registered: bool = False,
  69. custom_mxid: UserID | None = None,
  70. access_token: str | None = None,
  71. next_batch: SyncToken | None = None,
  72. base_url: URL | None = None,
  73. ) -> None:
  74. assert uuid, "UUID must be set for ghosts"
  75. assert isinstance(uuid, UUID)
  76. super().__init__(
  77. uuid=uuid,
  78. number=number,
  79. name=name,
  80. name_quality=name_quality,
  81. avatar_url=avatar_url,
  82. avatar_hash=avatar_hash,
  83. name_set=name_set,
  84. avatar_set=avatar_set,
  85. is_registered=is_registered,
  86. custom_mxid=custom_mxid,
  87. access_token=access_token,
  88. next_batch=next_batch,
  89. base_url=base_url,
  90. )
  91. self.log = self.log.getChild(str(uuid) if uuid else number)
  92. self.default_mxid = self.get_mxid_from_id(self.uuid)
  93. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  94. self.intent = self._fresh_intent()
  95. self._uuid_lock = asyncio.Lock()
  96. self._update_info_lock = asyncio.Lock()
  97. @classmethod
  98. def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
  99. cls.config = bridge.config
  100. cls.loop = bridge.loop
  101. cls.signal = bridge.signal
  102. cls.mx = bridge.matrix
  103. cls.az = bridge.az
  104. cls.hs_domain = cls.config["homeserver.domain"]
  105. cls.mxid_template = SimpleTemplate(
  106. cls.config["bridge.username_template"],
  107. "userid",
  108. prefix="@",
  109. suffix=f":{cls.hs_domain}",
  110. type=str,
  111. )
  112. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  113. cls.homeserver_url_map = {
  114. server: URL(url)
  115. for server, url in cls.config["bridge.double_puppet_server_map"].items()
  116. }
  117. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  118. cls.login_shared_secret_map = {
  119. server: secret.encode("utf-8")
  120. for server, secret in cls.config["bridge.login_shared_secret_map"].items()
  121. }
  122. cls.login_device_name = "Signal Bridge"
  123. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  124. def intent_for(self, portal: p.Portal) -> IntentAPI:
  125. if portal.chat_id == self.uuid:
  126. return self.default_mxid_intent
  127. return self.intent
  128. @property
  129. def address(self) -> Address:
  130. return Address(uuid=self.uuid, number=self.number)
  131. async def handle_number_receive(self, number: str) -> None:
  132. async with self._uuid_lock:
  133. if self.number == number:
  134. return
  135. if self.number:
  136. self.by_number.pop(self.number, None)
  137. self.number = number
  138. self._add_number_to_cache()
  139. await self._update_number()
  140. async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
  141. self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
  142. try:
  143. joined_rooms = await prev_intent.get_joined_rooms()
  144. except MForbidden as e:
  145. self.log.debug(
  146. f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
  147. "assuming there are no rooms to rejoin"
  148. )
  149. return
  150. for room_id in joined_rooms:
  151. await prev_intent.invite_user(room_id, self.default_mxid)
  152. await self._migrate_powers(prev_intent, new_intent, room_id)
  153. await prev_intent.leave_room(room_id)
  154. await new_intent.join_room_by_id(room_id)
  155. async def _migrate_powers(
  156. self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
  157. ) -> None:
  158. try:
  159. powers: PowerLevelStateEventContent
  160. powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
  161. user_level = powers.get_user_level(prev_intent.mxid)
  162. pl_state_level = powers.get_event_level(EventType.ROOM_POWER_LEVELS)
  163. if user_level >= pl_state_level > powers.users_default:
  164. powers.ensure_user_level(new_intent.mxid, user_level)
  165. await prev_intent.send_state_event(room_id, EventType.ROOM_POWER_LEVELS, powers)
  166. except Exception:
  167. self.log.warning("Failed to migrate power levels", exc_info=True)
  168. async def update_info(self, info: Profile | Address, source: u.User) -> None:
  169. update = False
  170. address = info.address if isinstance(info, Profile) else info
  171. if address.number and address.number != self.number:
  172. await self.handle_number_receive(address.number)
  173. update = True
  174. self.log.debug("Updating info with %s (source: %s)", info, source.mxid)
  175. async with self._update_info_lock:
  176. if isinstance(info, Profile) or self.name is None:
  177. update = await self._update_name(info) or update
  178. if isinstance(info, Profile):
  179. update = await self._update_avatar(info.avatar) or update
  180. elif self.config["bridge.contact_list_names"] != "disallow" and self.number:
  181. # Try to use a contact list avatar
  182. update = await self._update_avatar(f"contact-{self.number}") or update
  183. if update:
  184. await self.update()
  185. asyncio.create_task(self._try_update_portal_meta())
  186. @staticmethod
  187. def fmt_phone(number: str) -> str:
  188. if phonenumbers is None:
  189. return number
  190. parsed = phonenumbers.parse(number)
  191. fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
  192. return phonenumbers.format_number(parsed, fmt)
  193. @classmethod
  194. def _get_displayname(cls, info: Profile | Address) -> tuple[str, int]:
  195. quality = 10
  196. if isinstance(info, Profile):
  197. address = info.address
  198. name = None
  199. contact_names = cls.config["bridge.contact_list_names"]
  200. if info.profile_name:
  201. name = info.profile_name
  202. quality = 90 if contact_names == "prefer" else 100
  203. if info.contact_name:
  204. if contact_names == "prefer":
  205. quality = 100
  206. name = info.contact_name
  207. elif contact_names == "allow" and not name:
  208. quality = 50
  209. name = info.contact_name
  210. names = name.split("\x00") if name else []
  211. else:
  212. address = info
  213. names = []
  214. data = {
  215. "first_name": names[0] if len(names) > 0 else "",
  216. "last_name": names[-1] if len(names) > 1 else "",
  217. "full_name": " ".join(names),
  218. "phone": cls.fmt_phone(address.number) if address.number else None,
  219. "uuid": str(address.uuid) if address.uuid else None,
  220. "displayname": "Unknown user",
  221. }
  222. for pref in cls.config["bridge.displayname_preference"]:
  223. value = data.get(pref.replace(" ", "_"))
  224. if value:
  225. data["displayname"] = value
  226. break
  227. return cls.config["bridge.displayname_template"].format(**data), quality
  228. async def _update_name(self, info: Profile | Address) -> bool:
  229. name, quality = self._get_displayname(info)
  230. if quality >= self.name_quality and (name != self.name or not self.name_set):
  231. self.log.debug(
  232. "Updating name from '%s' to '%s' (quality: %d)", self.name, name, quality
  233. )
  234. self.name = name
  235. self.name_quality = quality
  236. try:
  237. await self.default_mxid_intent.set_displayname(self.name)
  238. self.name_set = True
  239. except Exception:
  240. self.log.exception("Error setting displayname")
  241. self.name_set = False
  242. return True
  243. elif name != self.name or not self.name_set:
  244. self.log.debug(
  245. "Not updating name from '%s' to '%s', new quality (%d) is lower than old (%d)",
  246. self.name,
  247. name,
  248. quality,
  249. self.name_quality,
  250. )
  251. elif self.name_quality == 0:
  252. # Name matches, but quality is not stored in database - store it now
  253. self.name_quality = quality
  254. return True
  255. return False
  256. @staticmethod
  257. async def upload_avatar(
  258. self: Puppet | p.Portal, path: str, intent: IntentAPI
  259. ) -> bool | tuple[str, ContentURI]:
  260. if not path:
  261. return False
  262. if not path.startswith("/"):
  263. path = os.path.join(self.config["signal.avatar_dir"], path)
  264. try:
  265. with open(path, "rb") as file:
  266. data = file.read()
  267. except FileNotFoundError:
  268. return False
  269. if not data:
  270. return False
  271. new_hash = hashlib.sha256(data).hexdigest()
  272. if self.avatar_set and new_hash == self.avatar_hash:
  273. return False
  274. mxc = await intent.upload_media(data, async_upload=self.config["homeserver.async_media"])
  275. return new_hash, mxc
  276. async def _update_avatar(self, path: str) -> bool:
  277. res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
  278. if res is False:
  279. return False
  280. self.avatar_hash, self.avatar_url = res
  281. try:
  282. await self.default_mxid_intent.set_avatar_url(self.avatar_url)
  283. self.avatar_set = True
  284. except Exception:
  285. self.log.exception("Error setting avatar")
  286. self.avatar_set = False
  287. return True
  288. async def _try_update_portal_meta(self) -> None:
  289. try:
  290. await self._update_portal_meta()
  291. except Exception:
  292. self.log.exception("Error updating portal meta")
  293. async def _update_portal_meta(self) -> None:
  294. async for portal in p.Portal.find_private_chats_with(self.uuid):
  295. if portal.receiver == self.number:
  296. # This is a note to self chat, don't change the name
  297. continue
  298. try:
  299. await portal.update_puppet_name(self.name)
  300. await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
  301. if self.number:
  302. await portal.update_puppet_number(self.fmt_phone(self.number))
  303. except Exception:
  304. self.log.exception(f"Error updating portal meta for {portal.receiver}")
  305. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  306. portal: p.Portal = await p.Portal.get_by_mxid(room_id)
  307. # Leave all portals except the notes to self room
  308. return not (portal and portal.is_direct and portal.chat_id == self.uuid)
  309. # region Database getters
  310. def _add_number_to_cache(self) -> None:
  311. if self.number:
  312. try:
  313. existing = self.by_number[self.number]
  314. if existing and existing.uuid != self.uuid and existing != self:
  315. existing.number = None
  316. except KeyError:
  317. pass
  318. self.by_number[self.number] = self
  319. def _add_to_cache(self) -> None:
  320. self.by_uuid[self.uuid] = self
  321. self._add_number_to_cache()
  322. if self.custom_mxid:
  323. self.by_custom_mxid[self.custom_mxid] = self
  324. async def save(self) -> None:
  325. await self.update()
  326. @classmethod
  327. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
  328. uuid = cls.get_id_from_mxid(mxid)
  329. if not uuid:
  330. return None
  331. return await cls.get_by_uuid(uuid, create=create)
  332. @classmethod
  333. @async_getter_lock
  334. async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
  335. try:
  336. return cls.by_custom_mxid[mxid]
  337. except KeyError:
  338. pass
  339. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  340. if puppet:
  341. puppet._add_to_cache()
  342. return puppet
  343. return None
  344. @classmethod
  345. def get_id_from_mxid(cls, mxid: UserID) -> UUID | None:
  346. identifier = cls.mxid_template.parse(mxid)
  347. if not identifier:
  348. return None
  349. try:
  350. return UUID(identifier.upper())
  351. except ValueError:
  352. return None
  353. @classmethod
  354. def get_mxid_from_id(cls, uuid: UUID) -> UserID:
  355. return UserID(cls.mxid_template.format_full(str(uuid).lower()))
  356. @classmethod
  357. @async_getter_lock
  358. async def get_by_number(
  359. cls, number: str, /, *, resolve_via: str | None = None, raise_resolve: bool = False
  360. ) -> Puppet | None:
  361. try:
  362. return cls.by_number[number]
  363. except KeyError:
  364. pass
  365. puppet = cast(cls, await super().get_by_number(number))
  366. if puppet is not None:
  367. puppet._add_to_cache()
  368. return puppet
  369. if resolve_via:
  370. cls.log.debug(
  371. f"Couldn't find puppet with number {number}, resolving UUID via {resolve_via}"
  372. )
  373. try:
  374. uuid = await cls.signal.find_uuid(resolve_via, number)
  375. except UnregisteredUserError:
  376. if raise_resolve:
  377. raise
  378. cls.log.debug(f"Resolving {number} via {resolve_via} threw UnregisteredUserError")
  379. return None
  380. except Exception:
  381. if raise_resolve:
  382. raise
  383. cls.log.exception(f"Failed to resolve {number} via {resolve_via}")
  384. return None
  385. if uuid:
  386. cls.log.debug(f"Found {uuid} for {number} after resolving via {resolve_via}")
  387. return await cls.get_by_uuid(uuid, number=number)
  388. else:
  389. cls.log.debug(f"Didn't find UUID for {number} via {resolve_via}")
  390. return None
  391. @classmethod
  392. async def get_by_address(
  393. cls,
  394. address: Address,
  395. create: bool = True,
  396. resolve_via: str | None = None,
  397. raise_resolve: bool = False,
  398. ) -> Puppet | None:
  399. if not address.uuid:
  400. return await cls.get_by_number(
  401. address.number, resolve_via=resolve_via, raise_resolve=raise_resolve
  402. )
  403. else:
  404. return await cls.get_by_uuid(address.uuid, create=create, number=address.number)
  405. @classmethod
  406. @async_getter_lock
  407. async def get_by_uuid(
  408. cls, uuid: UUID, /, *, create: bool = True, number: str | None = None
  409. ) -> Puppet | None:
  410. try:
  411. return cls.by_uuid[uuid]
  412. except KeyError:
  413. pass
  414. puppet = cast(cls, await super().get_by_uuid(uuid))
  415. if puppet is not None:
  416. puppet._add_to_cache()
  417. return puppet
  418. if create:
  419. puppet = cls(uuid, number)
  420. await puppet.insert()
  421. puppet._add_to_cache()
  422. return puppet
  423. return None
  424. @classmethod
  425. async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
  426. puppets = await super().all_with_custom_mxid()
  427. puppet: cls
  428. for index, puppet in enumerate(puppets):
  429. try:
  430. yield cls.by_uuid[puppet.uuid]
  431. except KeyError:
  432. puppet._add_to_cache()
  433. yield puppet
  434. # endregion