puppet.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import (Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, Union,
  17. TYPE_CHECKING, cast)
  18. from uuid import UUID
  19. import asyncio
  20. from yarl import URL
  21. from mausignald.types import Address, Contact, Profile
  22. from mautrix.bridge import BasePuppet
  23. from mautrix.appservice import IntentAPI
  24. from mautrix.types import UserID, SyncToken, RoomID
  25. from mautrix.errors import MForbidden
  26. from mautrix.util.simple_template import SimpleTemplate
  27. from .db import Puppet as DBPuppet
  28. from .config import Config
  29. from . import portal as p, user as u
  30. if TYPE_CHECKING:
  31. from .__main__ import SignalBridge
  32. try:
  33. import phonenumbers
  34. except ImportError:
  35. phonenumbers = None
  class Puppet(DBPuppet, BasePuppet):
      """Matrix ghost user for one Signal user.

      A Signal user is identified by a UUID, a phone number, or both. Loaded
      puppets are kept in the class-level ``by_uuid``, ``by_number`` and
      ``by_custom_mxid`` dicts, which the lookup classmethods below consult
      before falling back to the database.
      """

      # Class-level caches shared by every instance.
      by_uuid: Dict[UUID, 'Puppet'] = {}
      by_number: Dict[str, 'Puppet'] = {}
      by_custom_mxid: Dict[UserID, 'Puppet'] = {}

      # Set by init_cls().
      hs_domain: str
      mxid_template: SimpleTemplate[str]
      config: Config

      # Ghost user derived from this puppet's address via get_mxid_from_id().
      default_mxid_intent: IntentAPI
      default_mxid: UserID

      # Guards handle_uuid_receive() so concurrent callers don't migrate twice.
      _uuid_lock: asyncio.Lock
      # Serializes the name update in update_info().
      _update_info_lock: asyncio.Lock

      def __init__(self, uuid: Optional[UUID], number: Optional[str], name: Optional[str] = None,
                   uuid_registered: bool = False, number_registered: bool = False,
                   custom_mxid: Optional[UserID] = None, access_token: Optional[str] = None,
                   next_batch: Optional[SyncToken] = None, base_url: Optional[URL] = None) -> None:
          """Create a puppet and derive its default ghost MXID and intent from its address."""
          super().__init__(uuid=uuid, number=number, name=name, uuid_registered=uuid_registered,
                           number_registered=number_registered, custom_mxid=custom_mxid,
                           access_token=access_token, next_batch=next_batch, base_url=base_url)
          # Child logger is named after the UUID, or the phone number while no UUID is known.
          self.log = self.log.getChild(str(uuid) if uuid else number)

          self.default_mxid = self.get_mxid_from_id(self.address)
          self.default_mxid_intent = self.az.intent.user(self.default_mxid)
          self.intent = self._fresh_intent()

          self._uuid_lock = asyncio.Lock()
          self._update_info_lock = asyncio.Lock()

      @classmethod
      def init_cls(cls, bridge: 'SignalBridge') -> AsyncIterable[Awaitable[None]]:
          """Bind bridge-wide config and clients to the class.

          Returns:
              An async iterable yielding one ``try_start()`` awaitable for each
              puppet that has a custom MXID.
          """
          cls.config = bridge.config
          cls.loop = bridge.loop
          cls.mx = bridge.matrix
          cls.az = bridge.az
          cls.hs_domain = cls.config["homeserver.domain"]
          cls.mxid_template = SimpleTemplate(cls.config["bridge.username_template"], "userid",
                                             prefix="@", suffix=f":{cls.hs_domain}", type=str)
          cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
          cls.homeserver_url_map = {server: URL(url) for server, url
                                    in cls.config["bridge.double_puppet_server_map"].items()}
          cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
          cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
                                         in cls.config["bridge.login_shared_secret_map"].items()}
          cls.login_device_name = "Signal Bridge"
          return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())

      def intent_for(self, portal: 'p.Portal') -> IntentAPI:
          """Return the intent this puppet should use in ``portal``.

          The portal whose ``chat_id`` equals this puppet's own address gets the
          default ghost intent; every other portal gets ``self.intent``.
          """
          if portal.chat_id == self.address:
              return self.default_mxid_intent
          return self.intent

      @property
      def is_registered(self) -> bool:
          """Registration flag for the identifier currently in use (UUID if known, else number)."""
          return self.uuid_registered if self.uuid is not None else self.number_registered

      @is_registered.setter
      def is_registered(self, value: bool) -> None:
          if self.uuid is not None:
              self.uuid_registered = value
          else:
              self.number_registered = value

      @property
      def address(self) -> Address:
          """Signal address built from this puppet's UUID and phone number."""
          return Address(uuid=self.uuid, number=self.number)

      async def handle_uuid_receive(self, uuid: UUID) -> None:
          """Record a newly discovered UUID for a puppet that only had a phone number.

          Does nothing if another call already stored a UUID while this one
          was waiting for the lock.
          """
          async with self._uuid_lock:
              if self.uuid:
                  # Received UUID was handled while this call was waiting
                  return
              await self._handle_uuid_receive(uuid)

      async def _handle_uuid_receive(self, uuid: UUID) -> None:
          """Migrate this number-only puppet to ``uuid``.

          Steps:
          - store the UUID on the puppet and on the matching bridge user;
          - register the puppet in ``by_uuid``;
          - notify the private-chat portals with this number;
          - switch the default ghost to the UUID-based MXID;
          - move room memberships from the old ghost to the new one.
          """
          self.log.debug(f"Found UUID for user: {uuid}")
          user = await u.User.get_by_username(self.number)
          if user and not user.uuid:
              # self.uuid is still None here (handle_uuid_receive returns early
              # otherwise), so the received UUID has to be used directly.
              user.uuid = uuid
              await user.update()
          await self._set_uuid(uuid)
          self.by_uuid[self.uuid] = self
          async for portal in p.Portal.find_private_chats_with(Address(number=self.number)):
              self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
              # NOTE(review): not awaited - confirm Portal.handle_uuid_receive is synchronous.
              portal.handle_uuid_receive(self.uuid)

          prev_intent = self.default_mxid_intent
          self.default_mxid = self.get_mxid_from_id(self.address)
          self.default_mxid_intent = self.az.intent.user(self.default_mxid)
          self.intent = self._fresh_intent()
          await self.default_mxid_intent.ensure_registered()
          if self.name:
              await self.default_mxid_intent.set_displayname(self.name)
          self.log = Puppet.log.getChild(str(uuid))

          self.log.debug(f"Migrating memberships {prev_intent.mxid}"
                         f" -> {self.default_mxid_intent.mxid}")
          try:
              joined_rooms = await prev_intent.get_joined_rooms()
          except MForbidden as e:
              self.log.debug(f"Got MForbidden ({e.message}) when getting joined rooms of old mxid, "
                             "assuming there are no rooms to rejoin")
              return
          for room_id in joined_rooms:
              # The old ghost invites the new one before leaving, so the new ghost can join.
              await prev_intent.invite_user(room_id, self.default_mxid)
              await prev_intent.leave_room(room_id)
              await self.default_mxid_intent.join_room_by_id(room_id)

      async def update_info(self, info: Union[Profile, Contact, Address]) -> None:
          """Update UUID and display name from a Signal profile, contact or address.

          Names from contacts or addresses are ignored when the puppet already
          has a name and ``bridge.allow_contact_list_name_updates`` is disabled.
          """
          if isinstance(info, (Contact, Address)):
              address = info.address if isinstance(info, Contact) else info
              if address.uuid and not self.uuid:
                  await self.handle_uuid_receive(address.uuid)
              if not self.config["bridge.allow_contact_list_name_updates"] and self.name is not None:
                  return

          name = info.name if isinstance(info, (Contact, Profile)) else None

          async with self._update_info_lock:
              update = False
              # A missing name is only applied when there is no name yet, which
              # yields the configured fallback displayname.
              if name is not None or self.name is None:
                  update = await self._update_name(name) or update
              if update:
                  await self.update()

      @staticmethod
      def fmt_phone(number: str) -> str:
          """Format ``number`` in international style, or return it unchanged.

          The input is returned as-is when ``phonenumbers`` is not installed or
          cannot parse it. Parsing without a region presumably needs a leading
          ``+``.
          """
          if phonenumbers is None:
              return number
          try:
              parsed = phonenumbers.parse(number)
          except phonenumbers.NumberParseException:
              # Don't let one malformed number break displayname generation.
              return number
          return phonenumbers.format_number(parsed, phonenumbers.PhoneNumberFormat.INTERNATIONAL)

      @classmethod
      def _get_displayname(cls, address: Address, name: Optional[str]) -> str:
          """Render the configured displayname template for ``address``/``name``.

          ``name`` may hold NUL-separated parts: the first is used as the first
          name and the last (when there are several) as the last name. The
          first non-empty value among ``bridge.displayname_preference`` fills
          the ``displayname`` template field.
          """
          names = name.split("\x00") if name else []
          data = {
              "first_name": names[0] if len(names) > 0 else "",
              "last_name": names[-1] if len(names) > 1 else "",
              # Skip empty parts so e.g. "John\x00" doesn't become "John ".
              "full_name": " ".join(part for part in names if part),
              "phone": cls.fmt_phone(address.number) if address.number else None,
              "uuid": str(address.uuid) if address.uuid else None,
              "displayname": "Unknown user",
          }
          for pref in cls.config["bridge.displayname_preference"]:
              value = data.get(pref.replace(" ", "_"))
              if value:
                  data["displayname"] = value
                  break

          return cls.config["bridge.displayname_template"].format(**data)

      async def _update_name(self, name: Optional[str]) -> bool:
          """Apply the rendered displayname; return True if it changed."""
          name = self._get_displayname(self.address, name)
          if name != self.name:
              self.name = name
              await self.default_mxid_intent.set_displayname(self.name)
              # Portal renames run in the background; the task is not awaited here.
              self.loop.create_task(self._update_portal_names())
              return True
          return False

      async def _update_portal_names(self) -> None:
          """Propagate this puppet's name to its private-chat portals."""
          async for portal in p.Portal.find_private_chats_with(self.address):
              if portal.receiver == self.number:
                  # This is a note to self chat, don't change the name
                  continue
              await portal.update_puppet_name(self.name)

      async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
          """Return whether the default ghost should leave ``room_id``.

          Returns False when the room is not a known portal, or when it is the
          portal whose chat_id is this puppet's own address. Returns True for
          any other portal.
          """
          portal = await p.Portal.get_by_mxid(room_id)
          # chat_id holds an Address (see intent_for), so compare it with the
          # address; comparing with the bare UUID is not expected to ever match.
          return portal is not None and portal.chat_id != self.address

      # region Database getters

      def _add_to_cache(self) -> None:
          """Register this puppet in every class-level cache it has a key for."""
          if self.uuid:
              self.by_uuid[self.uuid] = self
          if self.number:
              self.by_number[self.number] = self
          if self.custom_mxid:
              self.by_custom_mxid[self.custom_mxid] = self

      async def save(self) -> None:
          """Persist this puppet (delegates to ``update()``)."""
          await self.update()

      @classmethod
      async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
          """Look up the puppet for a ghost MXID; None if the MXID is not a ghost."""
          address = cls.get_id_from_mxid(mxid)
          if not address:
              return None
          return await cls.get_by_address(address, create)

      @classmethod
      async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
          """Look up the puppet whose custom MXID is ``mxid``, cache first."""
          try:
              return cls.by_custom_mxid[mxid]
          except KeyError:
              pass

          puppet = cast(cls, await super().get_by_custom_mxid(mxid))
          if puppet:
              puppet._add_to_cache()
              return puppet

          return None

      @classmethod
      def get_id_from_mxid(cls, mxid: UserID) -> Optional[Address]:
          """Parse a ghost MXID back into an Address (inverse of get_mxid_from_id).

          ``phone_<digits>`` localparts map to a number with a ``+`` prefix;
          anything else must parse as a UUID. Returns None otherwise.
          """
          identifier = cls.mxid_template.parse(mxid)
          if not identifier:
              return None
          if identifier.startswith("phone_"):
              return Address(number="+" + identifier[len("phone_"):])
          try:
              return Address(uuid=UUID(identifier.upper()))
          except ValueError:
              return None

      @classmethod
      def get_mxid_from_id(cls, address: Address) -> UserID:
          """Build the ghost MXID for ``address``, preferring the UUID.

          Raises:
              ValueError: if the address has neither a UUID nor a number.
          """
          if address.uuid:
              identifier = str(address.uuid).lower()
          elif address.number:
              identifier = f"phone_{address.number.lstrip('+')}"
          else:
              raise ValueError("Empty address")
          return UserID(cls.mxid_template.format_full(identifier))

      @classmethod
      async def get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
          """Look up (and optionally create) the puppet for ``address``.

          If ``address`` carries a UUID the found puppet lacks, the UUID
          migration runs immediately.

          Raises:
              ValueError: if ``address`` is not valid.
          """
          puppet = await cls._get_by_address(address, create)
          if puppet and address.uuid and not puppet.uuid:
              # We found a UUID for this user, store it ASAP
              await puppet.handle_uuid_receive(address.uuid)
          return puppet

      @classmethod
      async def _get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
          """Look up the puppet in the caches, then the DB, then create it if asked."""
          if not address.is_valid:
              raise ValueError("Empty address")

          if address.uuid:
              try:
                  return cls.by_uuid[address.uuid]
              except KeyError:
                  pass
          if address.number:
              try:
                  return cls.by_number[address.number]
              except KeyError:
                  pass

          puppet = cast(cls, await super().get_by_address(address))
          if puppet is not None:
              puppet._add_to_cache()
              return puppet

          if create:
              puppet = cls(address.uuid, address.number)
              await puppet.insert()
              puppet._add_to_cache()
              return puppet

          return None

      @classmethod
      async def all_with_custom_mxid(cls) -> AsyncGenerator['Puppet', None]:
          """Yield every puppet with a custom MXID.

          Already-cached instances are yielded in place of freshly loaded ones,
          so callers always get the canonical object.
          """
          puppets = await super().all_with_custom_mxid()
          for puppet in puppets:
              try:
                  yield cls.by_uuid[puppet.uuid]
              except KeyError:
                  try:
                      yield cls.by_number[puppet.number]
                  except KeyError:
                      puppet._add_to_cache()
                      yield puppet

      # endregion