puppet.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
  18. from uuid import UUID
  19. import asyncio
  20. import hashlib
  21. import os.path
  22. from yarl import URL
  23. from mausignald.errors import UnregisteredUserError
  24. from mausignald.types import Address, Profile
  25. from mautrix.appservice import IntentAPI
  26. from mautrix.bridge import BasePuppet, async_getter_lock
  27. from mautrix.errors import MForbidden
  28. from mautrix.types import (
  29. ContentURI,
  30. EventType,
  31. PowerLevelStateEventContent,
  32. RoomID,
  33. SyncToken,
  34. UserID,
  35. )
  36. from mautrix.util.simple_template import SimpleTemplate
  37. from . import portal as p, signal, user as u
  38. from .config import Config
  39. from .db import Puppet as DBPuppet
  40. if TYPE_CHECKING:
  41. from .__main__ import SignalBridge
  42. try:
  43. import phonenumbers
  44. except ImportError:
  45. phonenumbers = None
  class Puppet(DBPuppet, BasePuppet):
      """Matrix ghost user representing one Signal account, keyed by its Signal UUID.

      Loaded instances are cached in the class-level ``by_uuid``, ``by_number`` and
      ``by_custom_mxid`` dicts (see ``_add_to_cache``). The ``get_by_*`` classmethods
      consult those caches before falling back to the database layer inherited from
      ``DBPuppet``.
      """

      by_uuid: dict[UUID, Puppet] = {}
      by_number: dict[str, Puppet] = {}
      by_custom_mxid: dict[UserID, Puppet] = {}
      # Strong references to fire-and-forget tasks started by puppets. The event loop
      # only keeps weak references to tasks, so an unreferenced task can be garbage
      # collected before it finishes. Entries remove themselves when the task is done.
      _background_tasks: set[asyncio.Task] = set()

      hs_domain: str
      mxid_template: SimpleTemplate[str]

      config: Config
      signal: signal.SignalHandler

      default_mxid_intent: IntentAPI
      default_mxid: UserID

      _uuid_lock: asyncio.Lock
      _update_info_lock: asyncio.Lock

      def __init__(
          self,
          uuid: UUID,
          number: str | None,
          name: str | None = None,
          name_quality: int = 0,
          avatar_url: ContentURI | None = None,
          avatar_hash: str | None = None,
          name_set: bool = False,
          avatar_set: bool = False,
          is_registered: bool = False,
          custom_mxid: UserID | None = None,
          access_token: str | None = None,
          next_batch: SyncToken | None = None,
          base_url: URL | None = None,
      ) -> None:
          """Create a puppet; ``uuid`` must be a non-empty ``UUID`` instance."""
          assert uuid, "UUID must be set for ghosts"
          assert isinstance(uuid, UUID)
          super().__init__(
              uuid=uuid,
              number=number,
              name=name,
              name_quality=name_quality,
              avatar_url=avatar_url,
              avatar_hash=avatar_hash,
              name_set=name_set,
              avatar_set=avatar_set,
              is_registered=is_registered,
              custom_mxid=custom_mxid,
              access_token=access_token,
              next_batch=next_batch,
              base_url=base_url,
          )
          self.log = self.log.getChild(str(uuid) if uuid else number)

          # The "default" mxid is the bridge-owned ghost derived from the UUID via
          # bridge.username_template (see get_mxid_from_id).
          self.default_mxid = self.get_mxid_from_id(self.uuid)
          self.default_mxid_intent = self.az.intent.user(self.default_mxid)
          # NOTE(review): _fresh_intent comes from BasePuppet; presumably it returns the
          # custom-mxid intent when one is configured, else the default intent — confirm.
          self.intent = self._fresh_intent()

          self._uuid_lock = asyncio.Lock()
          self._update_info_lock = asyncio.Lock()

      @classmethod
      def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
          """Copy bridge-wide settings onto the class.

          Returns an async iterable yielding one ``try_start()`` awaitable per puppet
          that has a custom mxid stored in the database.
          """
          cls.config = bridge.config
          cls.loop = bridge.loop
          cls.signal = bridge.signal
          cls.mx = bridge.matrix
          cls.az = bridge.az
          cls.hs_domain = cls.config["homeserver.domain"]
          cls.mxid_template = SimpleTemplate(
              cls.config["bridge.username_template"],
              "userid",
              prefix="@",
              suffix=f":{cls.hs_domain}",
              type=str,
          )
          cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
          cls.homeserver_url_map = {
              server: URL(url)
              for server, url in cls.config["bridge.double_puppet_server_map"].items()
          }
          cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
          cls.login_shared_secret_map = {
              server: secret.encode("utf-8")
              for server, secret in cls.config["bridge.login_shared_secret_map"].items()
          }
          cls.login_device_name = "Signal Bridge"
          return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())

      def intent_for(self, portal: p.Portal) -> IntentAPI:
          """Pick the intent to act as in ``portal``.

          Uses the bridge-owned default ghost when the portal's ``chat_id`` equals this
          puppet's address, otherwise ``self.intent``.
          """
          # NOTE(review): this compares chat_id to an Address, while
          # default_puppet_should_leave_room compares it to a UUID. One of the two is
          # likely never true depending on Portal.chat_id's type — confirm in portal.py.
          if portal.chat_id == self.address:
              return self.default_mxid_intent
          return self.intent

      @property
      def address(self) -> Address:
          """Signal address built from this puppet's UUID and (possibly None) number."""
          return Address(uuid=self.uuid, number=self.number)

      async def handle_number_receive(self, number: str) -> None:
          """Record ``number`` as this puppet's phone number in the cache and database.

          No-op if the number is unchanged. Serialized by ``_uuid_lock``.
          """
          async with self._uuid_lock:
              if self.number == number:
                  return
              # Only drop the old cache entry if it still points at this puppet: the
              # number may already have been re-assigned to (and cached for) another
              # puppet, whose entry must not be evicted.
              if self.number and self.by_number.get(self.number) is self:
                  del self.by_number[self.number]
              self.number = number
              self.by_number[self.number] = self
              await self._set_number(number)

      async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
          """Move every room membership of ``prev_intent`` over to ``new_intent``.

          For each joined room: invite ``default_mxid``, copy power levels, leave with
          the old user and join with the new one. If listing the old user's rooms is
          forbidden, nothing is migrated.
          """
          self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
          try:
              joined_rooms = await prev_intent.get_joined_rooms()
          except MForbidden as e:
              self.log.debug(
                  f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
                  "assuming there are no rooms to rejoin"
              )
              return
          for room_id in joined_rooms:
              await prev_intent.invite_user(room_id, self.default_mxid)
              # Powers must be copied while prev_intent is still in the room.
              await self._migrate_powers(prev_intent, new_intent, room_id)
              await prev_intent.leave_room(room_id)
              await new_intent.join_room_by_id(room_id)

      async def _migrate_powers(
          self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
      ) -> None:
          """Give ``new_intent`` the same power level ``prev_intent`` has in ``room_id``.

          Only done when the old user may edit power levels and its level is above
          ``users_default``. Any failure is logged as a warning and swallowed.
          """
          try:
              powers: PowerLevelStateEventContent
              powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
              user_level = powers.get_user_level(prev_intent.mxid)
              pl_state_level = powers.get_event_level(EventType.ROOM_POWER_LEVELS)
              if user_level >= pl_state_level > powers.users_default:
                  powers.ensure_user_level(new_intent.mxid, user_level)
                  await prev_intent.send_state_event(room_id, EventType.ROOM_POWER_LEVELS, powers)
          except Exception:
              # Best effort: a failed power migration must not abort the room migration.
              self.log.warning("Failed to migrate power levels", exc_info=True)

      async def update_info(self, info: Profile | Address, source: u.User) -> None:
          """Refresh number, displayname and avatar from ``info``.

          If anything changed, persist the puppet and schedule a background update of
          its private-chat portals.

          Args:
              info: A full ``Profile``, or just an ``Address`` when no profile is known.
              source: The bridge user the info came from (used only for logging).
          """
          update = False
          address = info.address if isinstance(info, Profile) else info
          if address.number and address.number != self.number:
              await self.handle_number_receive(address.number)
              update = True
          self.log.debug("Updating info with %s (source: %s)", info, source.mxid)
          async with self._update_info_lock:
              # An Address alone only carries a name via number/uuid, so it is used only
              # when there is no name at all yet.
              if isinstance(info, Profile) or self.name is None:
                  update = await self._update_name(info) or update
              if isinstance(info, Profile):
                  update = await self._update_avatar(info.avatar) or update
              elif self.config["bridge.contact_list_names"] != "disallow" and self.number:
                  # Try to use a contact list avatar
                  update = await self._update_avatar(f"contact-{self.number}") or update
              if update:
                  await self.update()
                  # Keep a strong reference so the task is not garbage-collected mid-run.
                  task = asyncio.create_task(self._update_portal_meta())
                  self._background_tasks.add(task)
                  task.add_done_callback(self._background_tasks.discard)

      @staticmethod
      def fmt_phone(number: str) -> str:
          """Format ``number`` in international notation if ``phonenumbers`` is available.

          Returns ``number`` unchanged when the library is not installed or cannot parse
          it. Without a leading ``+`` and no default region, ``parse`` raises.
          """
          if phonenumbers is None:
              return number
          try:
              parsed = phonenumbers.parse(number)
          except phonenumbers.NumberParseException:
              # Display formatting is cosmetic; never let an odd number break callers.
              return number
          return phonenumbers.format_number(parsed, phonenumbers.PhoneNumberFormat.INTERNATIONAL)

      @classmethod
      def _get_displayname(cls, info: Profile | Address) -> tuple[str, int]:
          """Build a displayname from ``info`` and rate its quality.

          Quality values used here:
          - 100: profile name, or contact name when ``contact_list_names`` is "prefer".
          - 90: profile name when contact names are preferred.
          - 50: contact name when only "allow"ed and no profile name exists.
          - 10: no name available, only number/uuid.

          The string is rendered with ``bridge.displayname_template`` using the first
          non-empty field listed in ``bridge.displayname_preference``.
          """
          quality = 10
          if isinstance(info, Profile):
              address = info.address
              name = None
              contact_names = cls.config["bridge.contact_list_names"]
              if info.profile_name:
                  name = info.profile_name
                  quality = 90 if contact_names == "prefer" else 100
              if info.contact_name:
                  if contact_names == "prefer":
                      quality = 100
                      name = info.contact_name
                  elif contact_names == "allow" and not name:
                      quality = 50
                      name = info.contact_name
              # Name parts are NUL-separated (first part = first name, last = last name).
              names = name.split("\x00") if name else []
          else:
              address = info
              names = []
          data = {
              "first_name": names[0] if len(names) > 0 else "",
              "last_name": names[-1] if len(names) > 1 else "",
              "full_name": " ".join(names),
              "phone": cls.fmt_phone(address.number) if address.number else None,
              "uuid": str(address.uuid) if address.uuid else None,
              "displayname": "Unknown user",
          }
          for pref in cls.config["bridge.displayname_preference"]:
              value = data.get(pref.replace(" ", "_"))
              if value:
                  data["displayname"] = value
                  break

          return cls.config["bridge.displayname_template"].format(**data), quality

      async def _update_name(self, info: Profile | Address) -> bool:
          """Apply a new displayname if its quality is at least the stored quality.

          Returns True when puppet state changed and should be persisted.
          """
          name, quality = self._get_displayname(info)
          if quality >= self.name_quality and (name != self.name or not self.name_set):
              self.log.debug(
                  "Updating name from '%s' to '%s' (quality: %d)", self.name, name, quality
              )
              self.name = name
              self.name_quality = quality
              try:
                  await self.default_mxid_intent.set_displayname(self.name)
                  self.name_set = True
              except Exception:
                  self.log.exception("Error setting displayname")
                  self.name_set = False
              return True
          elif name != self.name or not self.name_set:
              self.log.debug(
                  "Not updating name from '%s' to '%s', new quality (%d) is lower than old (%d)",
                  self.name,
                  name,
                  quality,
                  self.name_quality,
              )
          elif self.name_quality == 0:
              # Name matches, but quality is not stored in database - store it now
              self.name_quality = quality
              return True
          return False

      @staticmethod
      async def upload_avatar(
          self: Puppet | p.Portal, path: str, intent: IntentAPI
      ) -> bool | tuple[str, ContentURI]:
          """Upload the avatar file at ``path`` via ``intent`` if it differs from the current one.

          Shared by puppets and portals, hence the explicit ``self`` on a staticmethod.
          Relative paths are resolved against ``signal.avatar_dir``.

          Returns:
              False if ``path`` is empty, the file is missing or empty, or its SHA-256
              equals the already-set ``avatar_hash``; otherwise ``(hash, mxc_uri)``.
          """
          if not path:
              return False
          if not path.startswith("/"):
              path = os.path.join(self.config["signal.avatar_dir"], path)
          try:
              with open(path, "rb") as file:
                  data = file.read()
          except FileNotFoundError:
              return False
          if not data:
              return False
          new_hash = hashlib.sha256(data).hexdigest()
          if self.avatar_set and new_hash == self.avatar_hash:
              return False
          mxc = await intent.upload_media(data, async_upload=self.config["homeserver.async_media"])
          return new_hash, mxc

      async def _update_avatar(self, path: str) -> bool:
          """Upload and set the ghost avatar from ``path``; True if puppet state changed."""
          res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
          if res is False:
              return False
          self.avatar_hash, self.avatar_url = res
          try:
              await self.default_mxid_intent.set_avatar_url(self.avatar_url)
              self.avatar_set = True
          except Exception:
              self.log.exception("Error setting avatar")
              self.avatar_set = False
          return True

      async def _update_portal_meta(self) -> None:
          """Push this puppet's name, avatar and number to its private-chat portals.

          Portals whose receiver is this puppet's own number (note-to-self) are skipped.
          Errors are logged per portal.
          """
          async for portal in p.Portal.find_private_chats_with(self.uuid):
              if portal.receiver == self.number:
                  # This is a note to self chat, don't change the name
                  continue
              try:
                  await portal.update_puppet_name(self.name)
                  await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
                  if self.number:
                      await portal.update_puppet_number(self.fmt_phone(self.number))
              except Exception:
                  self.log.exception(f"Error updating portal meta for {portal.receiver}")

      async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
          """Return True unless ``room_id`` is this puppet's direct (note-to-self) portal."""
          portal: p.Portal = await p.Portal.get_by_mxid(room_id)
          # Leave all portals except the notes to self room
          # NOTE(review): compares chat_id to a UUID, whereas intent_for compares it to
          # an Address — confirm which matches Portal.chat_id's actual type.
          return not (portal and portal.is_direct and portal.chat_id == self.uuid)

      # region Database getters

      def _add_to_cache(self) -> None:
          """Register this instance in the uuid / number / custom-mxid caches."""
          self.by_uuid[self.uuid] = self
          if self.number:
              self.by_number[self.number] = self
          if self.custom_mxid:
              self.by_custom_mxid[self.custom_mxid] = self

      async def save(self) -> None:
          """Persist the puppet (alias for ``update``)."""
          await self.update()

      @classmethod
      async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
          """Look up the puppet for a bridge ghost mxid; None if the mxid isn't one."""
          uuid = cls.get_id_from_mxid(mxid)
          if not uuid:
              return None
          return await cls.get_by_uuid(uuid, create=create)

      @classmethod
      @async_getter_lock
      async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
          """Find the puppet whose custom (non-ghost) mxid is ``mxid``, cache first."""
          try:
              return cls.by_custom_mxid[mxid]
          except KeyError:
              pass

          puppet = cast(cls, await super().get_by_custom_mxid(mxid))
          if puppet:
              puppet._add_to_cache()
              return puppet

          return None

      @classmethod
      def get_id_from_mxid(cls, mxid: UserID) -> UUID | None:
          """Extract the UUID from a ghost mxid; None if it doesn't match or isn't a UUID."""
          identifier = cls.mxid_template.parse(mxid)
          if not identifier:
              return None
          try:
              return UUID(identifier.upper())
          except ValueError:
              return None

      @classmethod
      def get_mxid_from_id(cls, uuid: UUID) -> UserID:
          """Build the ghost mxid for ``uuid`` (lowercase) from the username template."""
          return UserID(cls.mxid_template.format_full(str(uuid).lower()))

      @classmethod
      @async_getter_lock
      async def get_by_number(
          cls, number: str, /, *, resolve_via: str | None = None, raise_resolve: bool = False
      ) -> Puppet | None:
          """Find a puppet by phone number: cache, then database, then optional resolution.

          Args:
              number: Phone number to look up.
              resolve_via: If set and no puppet is known, passed to
                  ``signal.find_uuid`` (presumably the bridge account used for the
                  lookup — confirm) to resolve the number to a UUID.
              raise_resolve: Re-raise resolution errors instead of logging them and
                  returning None.
          """
          try:
              return cls.by_number[number]
          except KeyError:
              pass

          puppet = cast(cls, await super().get_by_number(number))
          if puppet is not None:
              puppet._add_to_cache()
              return puppet

          if resolve_via:
              cls.log.debug(
                  f"Couldn't find puppet with number {number}, resolving UUID via {resolve_via}"
              )
              try:
                  uuid = await cls.signal.find_uuid(resolve_via, number)
              except UnregisteredUserError:
                  if raise_resolve:
                      raise
                  cls.log.debug(f"Resolving {number} via {resolve_via} threw UnregisteredUserError")
                  return None
              except Exception:
                  if raise_resolve:
                      raise
                  cls.log.exception(f"Failed to resolve {number} via {resolve_via}")
                  return None
              if uuid:
                  cls.log.debug(f"Found {uuid} for {number} after resolving via {resolve_via}")
                  return await cls.get_by_uuid(uuid, number=number)
              else:
                  cls.log.debug(f"Didn't find UUID for {number} via {resolve_via}")
          return None

      @classmethod
      async def get_by_address(
          cls,
          address: Address,
          create: bool = True,
          resolve_via: str | None = None,
          raise_resolve: bool = False,
      ) -> Puppet | None:
          """Find a puppet by UUID if the address has one, otherwise by number.

          ``create`` applies only to the UUID path; ``resolve_via``/``raise_resolve``
          only to the number path.
          """
          if not address.uuid:
              return await cls.get_by_number(
                  address.number, resolve_via=resolve_via, raise_resolve=raise_resolve
              )
          else:
              return await cls.get_by_uuid(address.uuid, create=create, number=address.number)

      @classmethod
      @async_getter_lock
      async def get_by_uuid(
          cls, uuid: UUID, /, *, create: bool = True, number: str | None = None
      ) -> Puppet | None:
          """Find a puppet by UUID (cache, then database), inserting a new one if ``create``.

          ``number`` is only used when a new puppet is created.
          """
          try:
              return cls.by_uuid[uuid]
          except KeyError:
              pass

          puppet = cast(cls, await super().get_by_uuid(uuid))
          if puppet is not None:
              puppet._add_to_cache()
              return puppet

          if create:
              puppet = cls(uuid, number)
              await puppet.insert()
              puppet._add_to_cache()
              return puppet

          return None

      @classmethod
      async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
          """Yield every puppet with a custom mxid, preferring already-cached instances."""
          puppets = await super().all_with_custom_mxid()
          puppet: cls
          for puppet in puppets:
              try:
                  # Reuse the live instance so callers never see a duplicate object.
                  yield cls.by_uuid[puppet.uuid]
              except KeyError:
                  puppet._add_to_cache()
                  yield puppet

      # endregion