puppet.py 17 KB


  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import (
  17. Optional,
  18. Dict,
  19. AsyncIterable,
  20. Awaitable,
  21. AsyncGenerator,
  22. Union,
  23. Tuple,
  24. TYPE_CHECKING,
  25. cast,
  26. )
  27. from uuid import UUID
  28. import hashlib
  29. import asyncio
  30. import os.path
  31. from yarl import URL
  32. from mausignald.types import Address, Contact, Profile
  33. from mautrix.bridge import BasePuppet, async_getter_lock
  34. from mautrix.appservice import IntentAPI
  35. from mautrix.types import (
  36. UserID,
  37. SyncToken,
  38. RoomID,
  39. ContentURI,
  40. EventType,
  41. PowerLevelStateEventContent,
  42. )
  43. from mautrix.errors import MForbidden
  44. from mautrix.util.simple_template import SimpleTemplate
  45. from .db import Puppet as DBPuppet
  46. from .config import Config
  47. from . import portal as p, user as u
  48. if TYPE_CHECKING:
  49. from .__main__ import SignalBridge
  50. try:
  51. import phonenumbers
  52. except ImportError:
  53. phonenumbers = None
  class Puppet(DBPuppet, BasePuppet):
      """Bridge-side representation of one Signal user.

      Instances are cached in class-level dictionaries keyed by Signal UUID,
      phone number and custom (double puppet) Matrix user ID, so the same
      Signal user resolves to the same object for the lifetime of the process.
      """

      # Process-wide caches, filled by _add_to_cache().
      by_uuid: Dict[UUID, "Puppet"] = {}
      by_number: Dict[str, "Puppet"] = {}
      by_custom_mxid: Dict[UserID, "Puppet"] = {}

      # Class-level settings, populated by init_cls().
      hs_domain: str
      mxid_template: SimpleTemplate[str]
      config: Config

      # Intent and Matrix user ID of the bridge-owned ghost for this user.
      default_mxid_intent: IntentAPI
      default_mxid: UserID

      # _uuid_lock serializes UUID/number discovery, _update_info_lock serializes
      # profile (name/avatar) updates.
      _uuid_lock: asyncio.Lock
      _update_info_lock: asyncio.Lock

      def __init__(
          self,
          uuid: Optional[UUID],
          number: Optional[str],
          name: Optional[str] = None,
          avatar_url: Optional[ContentURI] = None,
          avatar_hash: Optional[str] = None,
          name_set: bool = False,
          avatar_set: bool = False,
          uuid_registered: bool = False,
          number_registered: bool = False,
          custom_mxid: Optional[UserID] = None,
          access_token: Optional[str] = None,
          next_batch: Optional[SyncToken] = None,
          base_url: Optional[URL] = None,
      ) -> None:
          """Create a puppet from database fields and set up its intents and locks.

          Raises:
              ValueError: if both ``uuid`` and ``number`` are empty (raised by
                  :meth:`get_mxid_from_id`).
          """
          super().__init__(
              uuid=uuid,
              number=number,
              name=name,
              avatar_url=avatar_url,
              avatar_hash=avatar_hash,
              name_set=name_set,
              avatar_set=avatar_set,
              uuid_registered=uuid_registered,
              number_registered=number_registered,
              custom_mxid=custom_mxid,
              access_token=access_token,
              next_batch=next_batch,
              base_url=base_url,
          )
          # Prefer the UUID as the logger name, fall back to the phone number.
          self.log = self.log.getChild(str(uuid) if uuid else number)

          self.default_mxid = self.get_mxid_from_id(self.address)
          self.default_mxid_intent = self.az.intent.user(self.default_mxid)
          self.intent = self._fresh_intent()

          self._uuid_lock = asyncio.Lock()
          self._update_info_lock = asyncio.Lock()

      @classmethod
      def init_cls(cls, bridge: "SignalBridge") -> AsyncIterable[Awaitable[None]]:
          """Copy bridge-wide state and configuration onto the class.

          Returns:
              An async iterable yielding one ``try_start()`` awaitable for every
              puppet that has a custom mxid.
          """
          cls.config = bridge.config
          cls.loop = bridge.loop
          cls.mx = bridge.matrix
          cls.az = bridge.az
          cls.hs_domain = cls.config["homeserver.domain"]
          cls.mxid_template = SimpleTemplate(
              cls.config["bridge.username_template"],
              "userid",
              prefix="@",
              suffix=f":{cls.hs_domain}",
              type=str,
          )
          cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
          cls.homeserver_url_map = {
              server: URL(url)
              for server, url in cls.config["bridge.double_puppet_server_map"].items()
          }
          cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
          cls.login_shared_secret_map = {
              server: secret.encode("utf-8")
              for server, secret in cls.config["bridge.login_shared_secret_map"].items()
          }
          cls.login_device_name = "Signal Bridge"
          return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())

      def intent_for(self, portal: "p.Portal") -> IntentAPI:
          """Return the intent to act with in ``portal``.

          In the private chat with this very user the default (bridge-owned)
          intent is used; everywhere else the current ``self.intent`` is used.
          """
          if portal.chat_id == self.address:
              return self.default_mxid_intent
          return self.intent

      @property
      def is_registered(self) -> bool:
          """Registration flag matching the identifier the puppet currently uses."""
          return self.uuid_registered if self.uuid is not None else self.number_registered

      @is_registered.setter
      def is_registered(self, value: bool) -> None:
          if self.uuid is not None:
              self.uuid_registered = value
          else:
              self.number_registered = value

      @property
      def address(self) -> Address:
          """Signal address built from the currently known UUID and number."""
          return Address(uuid=self.uuid, number=self.number)

      async def handle_uuid_receive(self, uuid: UUID) -> None:
          """Store a newly discovered UUID unless one is already known."""
          async with self._uuid_lock:
              if self.uuid:
                  # Received UUID was handled while this call was waiting
                  return
              await self._handle_uuid_receive(uuid)

      async def handle_number_receive(self, number: str) -> None:
          """Store a newly discovered phone number unless one is already known.

          Private chat portals that were keyed by the number are updated, and
          rooms joined by the number-based ghost (if that ghost was registered)
          are migrated to the current default ghost.
          """
          async with self._uuid_lock:
              if self.number:
                  return
              self.number = number
              self.by_number[self.number] = self
              await self._set_number(number)
              await self._repoint_private_portals(number)
              prev_mxid = self.get_mxid_from_id(Address(number=number))
              if await self.az.state_store.is_registered(prev_mxid):
                  prev_intent = self.az.intent.user(prev_mxid)
                  await self._migrate_memberships(prev_intent, self.default_mxid_intent)

      async def _repoint_private_portals(self, number: str) -> None:
          """Pass the current UUID to every private chat portal found for ``number``."""
          async for portal in p.Portal.find_private_chats_with(Address(number=number)):
              self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
              result = portal.handle_uuid_receive(self.uuid)
              # NOTE(review): Portal.handle_uuid_receive is defined outside this
              # file; if it is a coroutine function the call has to be awaited,
              # otherwise the update never runs. Handle both cases.
              if asyncio.iscoroutine(result):
                  await result

      async def _handle_uuid_receive(self, uuid: UUID) -> None:
          """Apply a newly discovered UUID. Callers must hold ``_uuid_lock``.

          Updates the matching bridge user (if any), the caches, the database,
          the private chat portals and the default ghost, then migrates room
          memberships from the previous ghost to the new one.
          """
          self.log.debug(f"Found UUID for user: {uuid}")
          user = await u.User.get_by_username(self.number)
          if user and not user.uuid:
              # self.uuid is still unset here, so the received value must be used.
              user.uuid = uuid
              user.by_uuid[user.uuid] = user
              await user.update()
          self.uuid = uuid
          self.by_uuid[self.uuid] = self
          await self._set_uuid(uuid)
          await self._repoint_private_portals(self.number)
          # The default mxid is derived from the address, which now has a UUID.
          prev_intent = self.default_mxid_intent
          self.default_mxid = self.get_mxid_from_id(self.address)
          self.default_mxid_intent = self.az.intent.user(self.default_mxid)
          self.intent = self._fresh_intent()
          await self.default_mxid_intent.ensure_registered()
          if self.name:
              await self.default_mxid_intent.set_displayname(self.name)
          self.log = Puppet.log.getChild(str(uuid))
          await self._migrate_memberships(prev_intent, self.default_mxid_intent)

      async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
          """Move every room membership of ``prev_intent`` over to ``new_intent``.

          For each joined room the old ghost invites the new one, copies its
          power level where applicable, leaves, and the new ghost joins.
          """
          self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
          try:
              joined_rooms = await prev_intent.get_joined_rooms()
          except MForbidden as e:
              self.log.debug(
                  f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
                  "assuming there are no rooms to rejoin"
              )
              return
          for room_id in joined_rooms:
              # Invite the account that will actually join below.
              await prev_intent.invite_user(room_id, new_intent.mxid)
              await self._migrate_powers(prev_intent, new_intent, room_id)
              await prev_intent.leave_room(room_id)
              await new_intent.join_room_by_id(room_id)

      async def _migrate_powers(
          self, prev_intent: IntentAPI, new_intent: IntentAPI, room_id: RoomID
      ) -> None:
          """Give ``new_intent`` the power level of ``prev_intent`` in ``room_id``.

          Only done when the old ghost is allowed to send power level events and
          that requirement is above the default user level. Failures are logged
          and otherwise ignored.
          """
          try:
              powers: PowerLevelStateEventContent
              powers = await prev_intent.get_state_event(room_id, EventType.ROOM_POWER_LEVELS)
              user_level = powers.get_user_level(prev_intent.mxid)
              pl_state_level = powers.get_event_level(EventType.ROOM_POWER_LEVELS)
              if user_level >= pl_state_level > powers.users_default:
                  powers.ensure_user_level(new_intent.mxid, user_level)
                  await prev_intent.send_state_event(room_id, EventType.ROOM_POWER_LEVELS, powers)
          except Exception:
              self.log.warning("Failed to migrate power levels", exc_info=True)

      async def update_info(self, info: Union[Profile, Contact, Address]) -> None:
          """Update identifiers, display name and avatar from Signal data.

          The ``bridge.contact_list_names`` option decides whether the contact
          list name or the profile name wins. When anything changed the puppet
          is saved and the portal metadata refresh is started as a background
          task.
          """
          address = info.address if isinstance(info, (Contact, Profile)) else info
          if address.uuid and not self.uuid:
              await self.handle_uuid_receive(address.uuid)
          if address.number and not self.number:
              await self.handle_number_receive(address.number)

          contact_names = self.config["bridge.contact_list_names"]
          if isinstance(info, Profile) and contact_names != "prefer" and info.profile_name:
              name = info.profile_name
          elif isinstance(info, (Contact, Profile)) and contact_names != "disallow":
              name = info.name
              if not name and isinstance(info, Profile) and info.profile_name:
                  # Contact list name is preferred, but was not found, fall back to profile
                  name = info.profile_name
          else:
              name = None

          async with self._update_info_lock:
              update = False
              # Without a new name, only (re)generate the name if none is stored yet.
              if name is not None or self.name is None:
                  update = await self._update_name(name) or update
              if isinstance(info, Profile):
                  update = await self._update_avatar(info.avatar) or update
              elif contact_names != "disallow" and self.number:
                  update = await self._update_avatar(f"contact-{self.number}") or update
              if update:
                  await self.update()
                  asyncio.create_task(self._update_portal_meta())

      @staticmethod
      def fmt_phone(number: str) -> str:
          """Format ``number`` in international notation.

          Returns the input unchanged when the optional ``phonenumbers`` package
          is not installed or the number cannot be parsed.
          """
          if phonenumbers is None:
              return number
          try:
              parsed = phonenumbers.parse(number)
          except phonenumbers.NumberParseException:
              # A malformed number must not break display name generation.
              return number
          fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
          return phonenumbers.format_number(parsed, fmt)

      @classmethod
      def _get_displayname(cls, address: Address, name: Optional[str]) -> str:
          """Render the Matrix display name for ``address``.

          ``name`` may contain several parts separated by NUL characters. The
          first non-empty field listed in ``bridge.displayname_preference``
          becomes ``displayname``; ``bridge.displayname_template`` is then
          formatted with all fields.
          """
          names = name.split("\x00") if name else []
          data = {
              "first_name": names[0] if len(names) > 0 else "",
              "last_name": names[-1] if len(names) > 1 else "",
              "full_name": " ".join(names),
              "phone": cls.fmt_phone(address.number) if address.number else None,
              "uuid": str(address.uuid) if address.uuid else None,
              "displayname": "Unknown user",
          }
          for pref in cls.config["bridge.displayname_preference"]:
              # Preferences may be written with spaces, e.g. "full name".
              value = data.get(pref.replace(" ", "_"))
              if value:
                  data["displayname"] = value
                  break
          return cls.config["bridge.displayname_template"].format(**data)

      async def _update_name(self, name: Optional[str]) -> bool:
          """Set the display name derived from ``name``.

          Returns:
              True if the stored name changed or a push to Matrix was attempted,
              False if nothing needed to be done.
          """
          name = self._get_displayname(self.address, name)
          if name != self.name or not self.name_set:
              self.name = name
              try:
                  await self.default_mxid_intent.set_displayname(self.name)
                  self.name_set = True
              except Exception:
                  self.log.exception("Error setting displayname")
                  self.name_set = False
              return True
          return False

      @staticmethod
      async def upload_avatar(
          self: Union["Puppet", "p.Portal"],
          path: str,
          intent: IntentAPI,
      ) -> Union[bool, Tuple[str, ContentURI]]:
          """Upload the avatar file at ``path`` if it differs from the stored one.

          This is a static method taking an explicit ``self`` so that it can be
          called with either a puppet or a portal.

          Returns:
              ``False`` when there is nothing to upload (empty path, missing or
              empty file, or unchanged hash while the avatar is already set),
              otherwise a ``(sha256 hex digest, content URI)`` tuple.
          """
          if not path:
              return False
          if not path.startswith("/"):
              # Relative paths are resolved against the configured avatar directory.
              path = os.path.join(self.config["signal.avatar_dir"], path)
          try:
              with open(path, "rb") as file:
                  data = file.read()
          except FileNotFoundError:
              return False
          if not data:
              return False
          new_hash = hashlib.sha256(data).hexdigest()
          if self.avatar_set and new_hash == self.avatar_hash:
              return False
          mxc = await intent.upload_media(data)
          return new_hash, mxc

      async def _update_avatar(self, path: str) -> bool:
          """Upload and set the avatar found at ``path``.

          Returns:
              True if a new avatar was uploaded (even if setting it on the
              profile failed), False otherwise.
          """
          res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
          if res is False:
              return False
          self.avatar_hash, self.avatar_url = res
          try:
              await self.default_mxid_intent.set_avatar_url(self.avatar_url)
              self.avatar_set = True
          except Exception:
              self.log.exception("Error setting avatar")
              self.avatar_set = False
          return True

      async def _update_portal_meta(self) -> None:
          """Push the current name and avatar to the private chat portals."""
          async for portal in p.Portal.find_private_chats_with(self.address):
              if portal.receiver == self.number:
                  # This is a note to self chat, don't change the name
                  continue
              try:
                  await portal.update_puppet_name(self.name)
                  await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
              except Exception:
                  self.log.exception(f"Error updating portal meta for {portal.receiver}")

      async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
          """Tell whether the default ghost should leave ``room_id``.

          Returns True only when the room is a known portal other than the
          private chat with this user. The comparison uses ``self.address`` so
          it agrees with :meth:`intent_for`, which keeps the default ghost
          active in exactly that chat.
          """
          portal = await p.Portal.get_by_mxid(room_id)
          return bool(portal) and portal.chat_id != self.address

      # region Database getters

      def _add_to_cache(self) -> None:
          """Register this puppet in every cache it has a key for."""
          if self.uuid:
              self.by_uuid[self.uuid] = self
          if self.number:
              self.by_number[self.number] = self
          if self.custom_mxid:
              self.by_custom_mxid[self.custom_mxid] = self

      async def save(self) -> None:
          """Persist the puppet; alias for ``update()``."""
          await self.update()

      @classmethod
      async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional["Puppet"]:
          """Look up the puppet owning the ghost ``mxid``.

          Returns None when ``mxid`` does not match the ghost username template.
          """
          address = cls.get_id_from_mxid(mxid)
          if not address:
              return None
          return await cls.get_by_address(address, create)

      @classmethod
      @async_getter_lock
      async def get_by_custom_mxid(cls, mxid: UserID) -> Optional["Puppet"]:
          """Look up the puppet whose custom mxid is ``mxid``, cache first."""
          try:
              return cls.by_custom_mxid[mxid]
          except KeyError:
              pass

          puppet = cast(cls, await super().get_by_custom_mxid(mxid))
          if not puppet:
              return None
          puppet._add_to_cache()
          return puppet

      @classmethod
      def get_id_from_mxid(cls, mxid: UserID) -> Optional[Address]:
          """Parse a ghost mxid into a Signal address.

          Returns None when the mxid does not match the template or the
          identifier is neither ``phone_<digits>`` nor a valid UUID.
          """
          identifier = cls.mxid_template.parse(mxid)
          if not identifier:
              return None
          if identifier.startswith("phone_"):
              return Address(number="+" + identifier[len("phone_") :])
          try:
              return Address(uuid=UUID(identifier.upper()))
          except ValueError:
              return None

      @classmethod
      def get_mxid_from_id(cls, address: Address) -> UserID:
          """Build the ghost mxid for ``address``, preferring the UUID.

          Raises:
              ValueError: if the address has neither a UUID nor a number.
          """
          if address.uuid:
              identifier = str(address.uuid).lower()
          elif address.number:
              identifier = f"phone_{address.number.lstrip('+')}"
          else:
              raise ValueError("Empty address")
          return UserID(cls.mxid_template.format_full(identifier))

      @classmethod
      @async_getter_lock
      async def get_by_address(cls, address: Address, create: bool = True) -> Optional["Puppet"]:
          """Look up (and optionally create) the puppet for ``address``.

          Raises:
              ValueError: if ``address`` is not valid.
          """
          puppet = await cls._get_by_address(address, create)
          if puppet and address.uuid and not puppet.uuid:
              # We found a UUID for this user, store it ASAP
              await puppet.handle_uuid_receive(address.uuid)
          return puppet

      @classmethod
      async def _get_by_address(cls, address: Address, create: bool = True) -> Optional["Puppet"]:
          """Resolve ``address`` via the caches, then the database, then creation."""
          if not address.is_valid:
              raise ValueError("Empty address")
          if address.uuid:
              try:
                  return cls.by_uuid[address.uuid]
              except KeyError:
                  pass
          if address.number:
              try:
                  return cls.by_number[address.number]
              except KeyError:
                  pass

          puppet = cast(cls, await super().get_by_address(address))
          if puppet is not None:
              puppet._add_to_cache()
              return puppet

          if create:
              puppet = cls(address.uuid, address.number)
              await puppet.insert()
              puppet._add_to_cache()
              return puppet

          return None

      @classmethod
      async def all_with_custom_mxid(cls) -> AsyncGenerator["Puppet", None]:
          """Yield every puppet with a custom mxid, reusing cached instances.

          A database row that is not cached yet (by UUID or by number) is added
          to the caches before being yielded.
          """
          puppets = await super().all_with_custom_mxid()
          puppet: cls
          for puppet in puppets:
              cached = cls.by_uuid.get(puppet.uuid)
              if cached is None:
                  cached = cls.by_number.get(puppet.number)
              if cached is None:
                  puppet._add_to_cache()
                  cached = puppet
              yield cached

      # endregion