puppet.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import (Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, Union, Tuple,
  17. TYPE_CHECKING, cast)
  18. from uuid import UUID
  19. import hashlib
  20. import asyncio
  21. import os.path
  22. from yarl import URL
  23. from mausignald.types import Address, Contact, Profile
  24. from mautrix.bridge import BasePuppet, async_getter_lock
  25. from mautrix.appservice import IntentAPI
  26. from mautrix.types import UserID, SyncToken, RoomID, ContentURI
  27. from mautrix.errors import MForbidden
  28. from mautrix.util.simple_template import SimpleTemplate
  29. from .db import Puppet as DBPuppet
  30. from .config import Config
  31. from . import portal as p, user as u
  32. if TYPE_CHECKING:
  33. from .__main__ import SignalBridge
  34. try:
  35. import phonenumbers
  36. except ImportError:
  37. phonenumbers = None
  38. class Puppet(DBPuppet, BasePuppet):
  39. by_uuid: Dict[UUID, 'Puppet'] = {}
  40. by_number: Dict[str, 'Puppet'] = {}
  41. by_custom_mxid: Dict[UserID, 'Puppet'] = {}
  42. hs_domain: str
  43. mxid_template: SimpleTemplate[str]
  44. config: Config
  45. default_mxid_intent: IntentAPI
  46. default_mxid: UserID
  47. _uuid_lock: asyncio.Lock
  48. _update_info_lock: asyncio.Lock
  49. def __init__(self, uuid: Optional[UUID], number: Optional[str], name: Optional[str] = None,
  50. avatar_url: Optional[ContentURI] = None, avatar_hash: Optional[str] = None,
  51. name_set: bool = False, avatar_set: bool = False, uuid_registered: bool = False,
  52. number_registered: bool = False, custom_mxid: Optional[UserID] = None,
  53. access_token: Optional[str] = None, next_batch: Optional[SyncToken] = None,
  54. base_url: Optional[URL] = None) -> None:
  55. super().__init__(uuid=uuid, number=number, name=name, avatar_url=avatar_url,
  56. avatar_hash=avatar_hash, name_set=name_set, avatar_set=avatar_set,
  57. uuid_registered=uuid_registered, number_registered=number_registered,
  58. custom_mxid=custom_mxid, access_token=access_token, next_batch=next_batch,
  59. base_url=base_url)
  60. self.log = self.log.getChild(str(uuid) if uuid else number)
  61. self.default_mxid = self.get_mxid_from_id(self.address)
  62. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  63. self.intent = self._fresh_intent()
  64. self._uuid_lock = asyncio.Lock()
  65. self._update_info_lock = asyncio.Lock()
  66. @classmethod
  67. def init_cls(cls, bridge: 'SignalBridge') -> AsyncIterable[Awaitable[None]]:
  68. cls.config = bridge.config
  69. cls.loop = bridge.loop
  70. cls.mx = bridge.matrix
  71. cls.az = bridge.az
  72. cls.hs_domain = cls.config["homeserver.domain"]
  73. cls.mxid_template = SimpleTemplate(cls.config["bridge.username_template"], "userid",
  74. prefix="@", suffix=f":{cls.hs_domain}", type=str)
  75. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  76. cls.homeserver_url_map = {server: URL(url) for server, url
  77. in cls.config["bridge.double_puppet_server_map"].items()}
  78. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  79. cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
  80. in cls.config["bridge.login_shared_secret_map"].items()}
  81. cls.login_device_name = "Signal Bridge"
  82. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  83. def intent_for(self, portal: 'p.Portal') -> IntentAPI:
  84. if portal.chat_id == self.address:
  85. return self.default_mxid_intent
  86. return self.intent
  87. @property
  88. def is_registered(self) -> bool:
  89. return self.uuid_registered if self.uuid is not None else self.number_registered
  90. @is_registered.setter
  91. def is_registered(self, value: bool) -> None:
  92. if self.uuid is not None:
  93. self.uuid_registered = value
  94. else:
  95. self.number_registered = value
  96. @property
  97. def address(self) -> Address:
  98. return Address(uuid=self.uuid, number=self.number)
  99. async def handle_uuid_receive(self, uuid: UUID) -> None:
  100. async with self._uuid_lock:
  101. if self.uuid:
  102. # Received UUID was handled while this call was waiting
  103. return
  104. await self._handle_uuid_receive(uuid)
  105. async def _handle_uuid_receive(self, uuid: UUID) -> None:
  106. self.log.debug(f"Found UUID for user: {uuid}")
  107. user = await u.User.get_by_username(self.number)
  108. if user and not user.uuid:
  109. user.uuid = self.uuid
  110. user.by_uuid[user.uuid] = user
  111. await user.update()
  112. await self._set_uuid(uuid)
  113. self.by_uuid[self.uuid] = self
  114. async for portal in p.Portal.find_private_chats_with(Address(number=self.number)):
  115. self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
  116. portal.handle_uuid_receive(self.uuid)
  117. prev_intent = self.default_mxid_intent
  118. self.default_mxid = self.get_mxid_from_id(self.address)
  119. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  120. self.intent = self._fresh_intent()
  121. await self.default_mxid_intent.ensure_registered()
  122. if self.name:
  123. await self.default_mxid_intent.set_displayname(self.name)
  124. self.log = Puppet.log.getChild(str(uuid))
  125. self.log.debug(f"Migrating memberships {prev_intent.mxid}"
  126. f" -> {self.default_mxid_intent.mxid}")
  127. try:
  128. joined_rooms = await prev_intent.get_joined_rooms()
  129. except MForbidden as e:
  130. self.log.debug(f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
  131. "assuming there are no rooms to rejoin")
  132. return
  133. for room_id in joined_rooms:
  134. await prev_intent.invite_user(room_id, self.default_mxid)
  135. await prev_intent.leave_room(room_id)
  136. await self.default_mxid_intent.join_room_by_id(room_id)
  137. async def update_info(self, info: Union[Profile, Contact, Address]) -> None:
  138. if isinstance(info, (Contact, Address)):
  139. address = info.address if isinstance(info, Contact) else info
  140. if address.uuid and not self.uuid:
  141. await self.handle_uuid_receive(address.uuid)
  142. contact_names = self.config["bridge.contact_list_names"]
  143. if isinstance(info, Profile) and contact_names != "prefer" and info.profile_name:
  144. name = info.profile_name
  145. elif isinstance(info, (Contact, Profile)) and contact_names != "disallow":
  146. name = info.name
  147. else:
  148. name = None
  149. async with self._update_info_lock:
  150. update = False
  151. if name is not None or self.name is None:
  152. update = await self._update_name(name) or update
  153. if isinstance(info, Profile):
  154. update = await self._update_avatar(info.avatar) or update
  155. if update:
  156. await self.update()
  157. self.loop.create_task(self._update_portal_meta())
  158. @staticmethod
  159. def fmt_phone(number: str) -> str:
  160. if phonenumbers is None:
  161. return number
  162. parsed = phonenumbers.parse(number)
  163. fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
  164. return phonenumbers.format_number(parsed, fmt)
  165. @classmethod
  166. def _get_displayname(cls, address: Address, name: Optional[str]) -> str:
  167. names = name.split("\x00") if name else []
  168. data = {
  169. "first_name": names[0] if len(names) > 0 else "",
  170. "last_name": names[-1] if len(names) > 1 else "",
  171. "full_name": " ".join(names),
  172. "phone": cls.fmt_phone(address.number) if address.number else None,
  173. "uuid": str(address.uuid) if address.uuid else None,
  174. "displayname": "Unknown user",
  175. }
  176. for pref in cls.config["bridge.displayname_preference"]:
  177. value = data.get(pref.replace(" ", "_"))
  178. if value:
  179. data["displayname"] = value
  180. break
  181. return cls.config["bridge.displayname_template"].format(**data)
  182. async def _update_name(self, name: Optional[str]) -> bool:
  183. name = self._get_displayname(self.address, name)
  184. if name != self.name or not self.name_set:
  185. self.name = name
  186. try:
  187. await self.default_mxid_intent.set_displayname(self.name)
  188. self.name_set = True
  189. except Exception:
  190. self.log.exception("Error setting displayname")
  191. self.name_set = False
  192. return True
  193. return False
  194. @staticmethod
  195. async def upload_avatar(self: Union['Puppet', 'p.Portal'], path: str, intent: IntentAPI,
  196. ) -> Union[bool, Tuple[str, ContentURI]]:
  197. if not path:
  198. return False
  199. if not path.startswith("/"):
  200. path = os.path.join(self.config["signal.avatar_dir"], path)
  201. try:
  202. with open(path, "rb") as file:
  203. data = file.read()
  204. except FileNotFoundError:
  205. return False
  206. new_hash = hashlib.sha256(data).hexdigest()
  207. if self.avatar_set and new_hash == self.avatar_hash:
  208. return False
  209. mxc = await intent.upload_media(data)
  210. return new_hash, mxc
  211. async def _update_avatar(self, path: str) -> bool:
  212. res = await Puppet.upload_avatar(self, path, self.default_mxid_intent)
  213. if res is False:
  214. return False
  215. self.avatar_hash, self.avatar_url = res
  216. try:
  217. await self.default_mxid_intent.set_avatar_url(self.avatar_url)
  218. self.avatar_set = True
  219. except Exception:
  220. self.log.exception("Error setting avatar")
  221. self.avatar_set = False
  222. return True
  223. async def _update_portal_meta(self) -> None:
  224. async for portal in p.Portal.find_private_chats_with(self.address):
  225. if portal.receiver == self.number:
  226. # This is a note to self chat, don't change the name
  227. continue
  228. try:
  229. await portal.update_puppet_name(self.name)
  230. await portal.update_puppet_avatar(self.avatar_hash, self.avatar_url)
  231. except Exception:
  232. self.log.exception(f"Error updating portal meta for {portal.receiver}")
  233. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  234. portal = await p.Portal.get_by_mxid(room_id)
  235. return portal and portal.chat_id != self.uuid
  236. # region Database getters
  237. def _add_to_cache(self) -> None:
  238. if self.uuid:
  239. self.by_uuid[self.uuid] = self
  240. if self.number:
  241. self.by_number[self.number] = self
  242. if self.custom_mxid:
  243. self.by_custom_mxid[self.custom_mxid] = self
  244. async def save(self) -> None:
  245. await self.update()
  246. @classmethod
  247. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
  248. address = cls.get_id_from_mxid(mxid)
  249. if not address:
  250. return None
  251. return await cls.get_by_address(address, create)
  252. @classmethod
  253. @async_getter_lock
  254. async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
  255. try:
  256. return cls.by_custom_mxid[mxid]
  257. except KeyError:
  258. pass
  259. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  260. if puppet:
  261. puppet._add_to_cache()
  262. return puppet
  263. return None
  264. @classmethod
  265. def get_id_from_mxid(cls, mxid: UserID) -> Optional[Address]:
  266. identifier = cls.mxid_template.parse(mxid)
  267. if not identifier:
  268. return None
  269. if identifier.startswith("phone_"):
  270. return Address(number="+" + identifier[len("phone_"):])
  271. else:
  272. try:
  273. return Address(uuid=UUID(identifier.upper()))
  274. except ValueError:
  275. return None
  276. @classmethod
  277. def get_mxid_from_id(cls, address: Address) -> UserID:
  278. if address.uuid:
  279. identifier = str(address.uuid).lower()
  280. elif address.number:
  281. identifier = f"phone_{address.number.lstrip('+')}"
  282. else:
  283. raise ValueError("Empty address")
  284. return UserID(cls.mxid_template.format_full(identifier))
  285. @classmethod
  286. @async_getter_lock
  287. async def get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
  288. puppet = await cls._get_by_address(address, create)
  289. if puppet and address.uuid and not puppet.uuid:
  290. # We found a UUID for this user, store it ASAP
  291. await puppet.handle_uuid_receive(address.uuid)
  292. return puppet
  293. @classmethod
  294. async def _get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
  295. if not address.is_valid:
  296. raise ValueError("Empty address")
  297. if address.uuid:
  298. try:
  299. return cls.by_uuid[address.uuid]
  300. except KeyError:
  301. pass
  302. if address.number:
  303. try:
  304. return cls.by_number[address.number]
  305. except KeyError:
  306. pass
  307. puppet = cast(cls, await super().get_by_address(address))
  308. if puppet is not None:
  309. puppet._add_to_cache()
  310. return puppet
  311. if create:
  312. puppet = cls(address.uuid, address.number)
  313. await puppet.insert()
  314. puppet._add_to_cache()
  315. return puppet
  316. return None
  317. @classmethod
  318. async def all_with_custom_mxid(cls) -> AsyncGenerator['Puppet', None]:
  319. puppets = await super().all_with_custom_mxid()
  320. puppet: cls
  321. for index, puppet in enumerate(puppets):
  322. try:
  323. yield cls.by_uuid[puppet.uuid]
  324. except KeyError:
  325. try:
  326. yield cls.by_number[puppet.number]
  327. except KeyError:
  328. puppet._add_to_cache()
  329. yield puppet
  330. # endregion