puppet.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2021 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast
  18. from uuid import UUID
  19. import asyncio
  20. import hashlib
  21. import os.path
  22. from yarl import URL
  23. from mausignald.types import Address, Contact, Profile
  24. from mautrix.appservice import IntentAPI
  25. from mautrix.bridge import BasePuppet, async_getter_lock
  26. from mautrix.errors import MForbidden
  27. from mautrix.types import (
  28. ContentURI,
  29. EventType,
  30. PowerLevelStateEventContent,
  31. RoomID,
  32. SyncToken,
  33. UserID,
  34. )
  35. from mautrix.util.simple_template import SimpleTemplate
  36. from . import portal as p, user as u
  37. from .config import Config
  38. from .db import Puppet as DBPuppet
  39. if TYPE_CHECKING:
  40. from .__main__ import SignalBridge
  41. try:
  42. import phonenumbers
  43. except ImportError:
  44. phonenumbers = None
  45. class Puppet(DBPuppet, BasePuppet):
  46. by_uuid: dict[UUID, Puppet] = {}
  47. by_number: dict[str, Puppet] = {}
  48. by_custom_mxid: dict[UserID, Puppet] = {}
  49. hs_domain: str
  50. mxid_template: SimpleTemplate[str]
  51. config: Config
  52. default_mxid_intent: IntentAPI
  53. default_mxid: UserID
  54. _uuid_lock: asyncio.Lock
  55. _update_info_lock: asyncio.Lock
  56. def __init__(
  57. self,
  58. uuid: UUID | None,
  59. number: str | None,
  60. name: str | None = None,
  61. avatar_url: ContentURI | None = None,
  62. avatar_hash: str | None = None,
  63. name_set: bool = False,
  64. avatar_set: bool = False,
  65. uuid_registered: bool = False,
  66. number_registered: bool = False,
  67. custom_mxid: UserID | None = None,
  68. access_token: str | None = None,
  69. next_batch: SyncToken | None = None,
  70. base_url: URL | None = None,
  71. ) -> None:
  72. super().__init__(
  73. uuid=uuid,
  74. number=number,
  75. name=name,
  76. avatar_url=avatar_url,
  77. avatar_hash=avatar_hash,
  78. name_set=name_set,
  79. avatar_set=avatar_set,
  80. uuid_registered=uuid_registered,
  81. number_registered=number_registered,
  82. custom_mxid=custom_mxid,
  83. access_token=access_token,
  84. next_batch=next_batch,
  85. base_url=base_url,
  86. )
  87. self.log = self.log.getChild(str(uuid) if uuid else number)
  88. self.default_mxid = self.get_mxid_from_id(self.address)
  89. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  90. self.intent = self._fresh_intent()
  91. self._uuid_lock = asyncio.Lock()
  92. self._update_info_lock = asyncio.Lock()
  93. @classmethod
  94. def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
  95. cls.config = bridge.config
  96. cls.loop = bridge.loop
  97. cls.mx = bridge.matrix
  98. cls.az = bridge.az
  99. cls.hs_domain = cls.config["homeserver.domain"]
  100. cls.mxid_template = SimpleTemplate(
  101. cls.config["bridge.username_template"],
  102. "userid",
  103. prefix="@",
  104. suffix=f":{cls.hs_domain}",
  105. type=str,
  106. )
  107. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  108. cls.homeserver_url_map = {
  109. server: URL(url)
  110. for server, url in cls.config["bridge.double_puppet_server_map"].items()
  111. }
  112. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  113. cls.login_shared_secret_map = {
  114. server: secret.encode("utf-8")
  115. for server, secret in cls.config["bridge.login_shared_secret_map"].items()
  116. }
  117. cls.login_device_name = "Signal Bridge"
  118. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  119. def intent_for(self, portal: p.Portal) -> IntentAPI:
  120. if portal.chat_id == self.address:
  121. return self.default_mxid_intent
  122. return self.intent
  123. @property
  124. def is_registered(self) -> bool:
  125. return self.uuid_registered if self.uuid is not None else self.number_registered
  126. @is_registered.setter
  127. def is_registered(self, value: bool) -> None:
  128. if self.uuid is not None:
  129. self.uuid_registered = value
  130. else:
  131. self.number_registered = value
  132. @property
  133. def address(self) -> Address:
  134. return Address(uuid=self.uuid, number=self.number)
  135. async def handle_uuid_receive(self, uuid: UUID) -> None:
  136. async with self._uuid_lock:
  137. if self.uuid:
  138. # Received UUID was handled while this call was waiting
  139. return
  140. await self._handle_uuid_receive(uuid)
  141. async def handle_number_receive(self, number: str) -> None:
  142. async with self._uuid_lock:
  143. if self.number:
  144. return
  145. self.number = number
  146. self.by_number[self.number] = self
  147. await self._set_number(number)
  148. async for portal in p.Portal.find_private_chats_with(Address(number=number)):
  149. self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
  150. portal.handle_uuid_receive(self.uuid)
  151. prev_mxid = self.get_mxid_from_id(Address(number=number))
  152. if await self.az.state_store.is_registered(prev_mxid):
  153. prev_intent = self.az.intent.user(prev_mxid)
  154. await self._migrate_memberships(prev_intent, self.default_mxid_intent)
  155. async def _handle_uuid_receive(self, uuid: UUID) -> None:
  156. self.log.debug(f"Found UUID for user: {uuid}")
  157. user = await u.User.get_by_username(self.number)
  158. if user and not user.uuid:
  159. user.uuid = self.uuid
  160. user.by_uuid[user.uuid] = user
  161. await user.update()
  162. self.uuid = uuid
  163. self.by_uuid[self.uuid] = self
  164. await self._set_uuid(uuid)
  165. async for portal in p.Portal.find_private_chats_with(Address(number=self.number)):
  166. self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
  167. portal.handle_uuid_receive(self.uuid)
  168. prev_intent = self.default_mxid_intent
  169. self.default_mxid = self.get_mxid_from_id(self.address)
  170. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  171. self.intent = self._fresh_intent()
  172. await self.default_mxid_intent.ensure_registered()
  173. if self.name:
  174. await self.default_mxid_intent.set_displayname(self.name)
  175. self.log = Puppet.log.getChild(str(uuid))
  176. await self._migrate_memberships(prev_intent, self.default_mxid_intent)
  177. async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
  178. self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
  179. try:
  180. joined_rooms = await prev_intent.get_joined_rooms()
  181. except MForbidden as e:
  182. self.log.debug(
  183. f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
  184. "assuming there are no rooms to rejoin"
  185. )
  186. return
  187. for room_id in joined_rooms:
  188. await prev_intent.invite_user(room_id, self.default_mxid)
  189. await self._migrate_powers(prev_intent, new_intent, room_id)
  190. await prev_intent.leave_room(room_id)
  191. await new_intent.join_room_by_id(room_id)
  192. async def _migrate_powers(
  193. self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
  194. ) -> None:
  195. try:
  196. powers: PowerLevelStateEventContent
  197. powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
  198. user_level = powers.get_user_level(prev_intent.mxid)
  199. pl_state_level = powers.get_event_level(EventType.ROOM_POWER_LEVELS)
  200. if user_level >= pl_state_level > powers.users_default:
  201. powers.ensure_user_level(new_intent.mxid, user_level)
  202. await prev_intent.send_state_event(room_id, EventType.ROOM_POWER_LEVELS, powers)
  203. except Exception:
  204. self.log.warning("Failed to migrate power levels", exc_info=True)
  205. async def update_info(self, info: Profile | Contact | Address) -> None:
  206. address = info.address if isinstance(info, (Contact, Profile)) else info
  207. if address.uuid and not self.uuid:
  208. await self.handle_uuid_receive(address.uuid)
  209. if address.number and not self.number:
  210. await self.handle_number_receive(address.number)
  211. self.log.debug("Updating info for %s", address)
  212. contact_names = self.config["bridge.contact_list_names"]
  213. name = None
  214. if isinstance(info, Profile):
  215. if info.profile_name:
  216. self.log.debug(
  217. "Found profile name on profile for %s: '%s'", address, info.profile_name
  218. )
  219. name = info.profile_name
  220. if contact_names == "prefer" or (contact_names == "allow" and not name):
  221. # Try and overwrite the name with the contact name if that's the preference, or we
  222. # didn't get a profile name.
  223. self.log.debug("Found contact name on profile for %s: '%s'", address, info.name)
  224. name = info.name or name
  225. elif isinstance(info, Contact) and contact_names != "disallow":
  226. self.log.debug("Found contact name on contact for %s: '%s'", address, info.name)
  227. name = info.name
  228. self.log.debug("Using name '%s' for %s", name, address)
  229. async with self._update_info_lock:
  230. update = False
  231. if name is not None or self.name is None:
  232. update = await self._update_name(name) or update
  233. if isinstance(info, Profile):
  234. update = await self._update_avatar(info.avatar) or update
  235. elif contact_names != "disallow" and self.number:
  236. update = await self._update_avatar(f"contact-{self.number}") or update
  237. if update:
  238. await self.update()
  239. asyncio.create_task(self._update_portal_meta())
  240. @staticmethod
  241. def fmt_phone(number: str) -> str:
  242. if phonenumbers is None:
  243. return number
  244. parsed = phonenumbers.parse(number)
  245. fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
  246. return phonenumbers.format_number(parsed, fmt)
  247. @classmethod
  248. def _get_displayname(cls, address: Address, name: str | None) -> str:
  249. names = name.split("\x00") if name else []
  250. data = {
  251. "first_name": names[0] if len(names) > 0 else "",
  252. "last_name": names[-1] if len(names) > 1 else "",
  253. "full_name": " ".join(names),
  254. "phone": cls.fmt_phone(address.number) if address.number else None,
  255. "uuid": str(address.uuid) if address.uuid else None,
  256. "displayname": "Unknown user",
  257. }
  258. for pref in cls.config["bridge.displayname_preference"]:
  259. value = data.get(pref.replace(" ", "_"))
  260. if value:
  261. data["displayname"] = value
  262. break
  263. return cls.config["bridge.displayname_template"].format(**data)
  264. async def _update_name(self, name: str | None) -> bool:
  265. name = self._get_displayname(self.address, name)
  266. if name != self.name or not self.name_set:
  267. self.name = name
  268. try:
  269. await self.default_mxid_intent.set_displayname(self.name)
  270. self.name_set = True
  271. except Exception:
  272. self.log.exception("Error setting displayname")
  273. self.name_set = False
  274. return True
  275. return False
  276. @staticmethod
  277. async def upload_avatar(
  278. self: Puppet | p.Portal, path: str, intent: IntentAPI
  279. ) -> bool | tuple[str, ContentURI]:
  280. if not path:
  281. return False
  282. if not path.startswith("/"):
  283. path = os.path.join(self.config["signal.avatar_dir"], path)
  284. try:
  285. with open(path, "rb") as file:
  286. data = file.read()
  287. except FileNotFoundError:
  288. return False
  289. if not data:
  290. return False
  291. new_hash = hashlib.sha256(data).hexdigest()
  292. if self.avatar_set and new_hash == self.avatar_hash:
  293. return False
  294. mxc = await intent.upload_media(data, async_upload=self.config["homeserver.async_media"])
  295. return new_hash, mxc
  296. async def _update_avatar(self, path: str) -> bool:
  297. res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
  298. if res is False:
  299. return False
  300. self.avatar_hash, self.avatar_url = res
  301. try:
  302. await self.default_mxid_intent.set_avatar_url(self.avatar_url)
  303. self.avatar_set = True
  304. except Exception:
  305. self.log.exception("Error setting avatar")
  306. self.avatar_set = False
  307. return True
  308. async def _update_portal_meta(self) -> None:
  309. async for portal in p.Portal.find_private_chats_with(self.address):
  310. if portal.receiver == self.number:
  311. # This is a note to self chat, don't change the name
  312. continue
  313. try:
  314. await portal.update_puppet_name(self.name)
  315. await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
  316. except Exception:
  317. self.log.exception(f"Error updating portal meta for {portal.receiver}")
  318. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  319. portal: p.Portal = await p.Portal.get_by_mxid(room_id)
  320. if not portal or not portal.is_direct:
  321. return True
  322. elif portal.chat_id.uuid and self.uuid:
  323. return portal.chat_id.uuid != self.uuid
  324. elif portal.chat_id.number and self.number:
  325. return portal.chat_id.number != self.number
  326. else:
  327. return True
  328. # region Database getters
  329. def _add_to_cache(self) -> None:
  330. if self.uuid:
  331. self.by_uuid[self.uuid] = self
  332. if self.number:
  333. self.by_number[self.number] = self
  334. if self.custom_mxid:
  335. self.by_custom_mxid[self.custom_mxid] = self
  336. async def save(self) -> None:
  337. await self.update()
  338. @classmethod
  339. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Puppet | None:
  340. address = cls.get_id_from_mxid(mxid)
  341. if not address:
  342. return None
  343. return await cls.get_by_address(address, create)
  344. @classmethod
  345. @async_getter_lock
  346. async def get_by_custom_mxid(cls, mxid: UserID) -> Puppet | None:
  347. try:
  348. return cls.by_custom_mxid[mxid]
  349. except KeyError:
  350. pass
  351. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  352. if puppet:
  353. puppet._add_to_cache()
  354. return puppet
  355. return None
  356. @classmethod
  357. def get_id_from_mxid(cls, mxid: UserID) -> Address | None:
  358. identifier = cls.mxid_template.parse(mxid)
  359. if not identifier:
  360. return None
  361. if identifier.startswith("phone_"):
  362. return Address(number="+" + identifier[len("phone_") :])
  363. else:
  364. try:
  365. return Address(uuid=UUID(identifier.upper()))
  366. except ValueError:
  367. return None
  368. @classmethod
  369. def get_mxid_from_id(cls, address: Address) -> UserID:
  370. if address.uuid:
  371. identifier = str(address.uuid).lower()
  372. elif address.number:
  373. identifier = f"phone_{address.number.lstrip('+')}"
  374. else:
  375. raise ValueError("Empty address")
  376. return UserID(cls.mxid_template.format_full(identifier))
  377. @classmethod
  378. @async_getter_lock
  379. async def get_by_address(cls, address: Address, create: bool = True) -> Puppet | None:
  380. puppet = await cls._get_by_address(address, create)
  381. if puppet and address.uuid and not puppet.uuid:
  382. # We found a UUID for this user, store it ASAP
  383. await puppet.handle_uuid_receive(address.uuid)
  384. return puppet
  385. @classmethod
  386. async def _get_by_address(cls, address: Address, create: bool = True) -> Puppet | None:
  387. if not address.is_valid:
  388. raise ValueError("Empty address")
  389. if address.uuid:
  390. try:
  391. return cls.by_uuid[address.uuid]
  392. except KeyError:
  393. pass
  394. if address.number:
  395. try:
  396. return cls.by_number[address.number]
  397. except KeyError:
  398. pass
  399. puppet = cast(cls, await super().get_by_address(address))
  400. if puppet is not None:
  401. puppet._add_to_cache()
  402. return puppet
  403. if create:
  404. puppet = cls(address.uuid, address.number)
  405. await puppet.insert()
  406. puppet._add_to_cache()
  407. return puppet
  408. return None
  409. @classmethod
  410. async def all_with_custom_mxid(cls) -> AsyncGenerator[Puppet, None]:
  411. puppets = await super().all_with_custom_mxid()
  412. puppet: cls
  413. for index, puppet in enumerate(puppets):
  414. try:
  415. yield cls.by_uuid[puppet.uuid]
  416. except KeyError:
  417. try:
  418. yield cls.by_number[puppet.number]
  419. except KeyError:
  420. puppet._add_to_cache()
  421. yield puppet
  422. # endregion