puppet.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import (Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, Union,
  17. TYPE_CHECKING, cast)
  18. from uuid import UUID
  19. import asyncio
  20. from yarl import URL
  21. from mausignald.types import Address, Contact, Profile
  22. from mautrix.bridge import BasePuppet
  23. from mautrix.appservice import IntentAPI
  24. from mautrix.types import UserID, SyncToken, RoomID
  25. from mautrix.util.simple_template import SimpleTemplate
  26. from .db import Puppet as DBPuppet
  27. from .config import Config
  28. from . import portal as p, user as u
  29. if TYPE_CHECKING:
  30. from .__main__ import SignalBridge
  31. try:
  32. import phonenumbers
  33. except ImportError:
  34. phonenumbers = None
  35. class Puppet(DBPuppet, BasePuppet):
    # In-memory caches shared by all instances (class-level dicts). They are
    # filled by _add_to_cache() and keyed by Signal UUID, phone number and
    # custom Matrix user ID respectively.
    by_uuid: Dict[UUID, 'Puppet'] = {}
    by_number: Dict[str, 'Puppet'] = {}
    by_custom_mxid: Dict[UserID, 'Puppet'] = {}
    # Class-wide settings, assigned in init_cls().
    hs_domain: str
    mxid_template: SimpleTemplate[str]
    config: Config
    # Per-instance Matrix identity, assigned in __init__() and refreshed in
    # _handle_uuid_receive().
    default_mxid_intent: IntentAPI
    default_mxid: UserID
    # Held around the body of handle_uuid_receive().
    _uuid_lock: asyncio.Lock
    # Held around the name update in update_info().
    _update_info_lock: asyncio.Lock
  46. def __init__(self, uuid: Optional[UUID], number: Optional[str], name: Optional[str] = None,
  47. uuid_registered: bool = False, number_registered: bool = False,
  48. custom_mxid: Optional[UserID] = None, access_token: Optional[str] = None,
  49. next_batch: Optional[SyncToken] = None, base_url: Optional[URL] = None) -> None:
  50. super().__init__(uuid=uuid, number=number, name=name, uuid_registered=uuid_registered,
  51. number_registered=number_registered, custom_mxid=custom_mxid,
  52. access_token=access_token, next_batch=next_batch, base_url=base_url)
  53. self.log = self.log.getChild(str(uuid) or number)
  54. self.default_mxid = self.get_mxid_from_id(self.address)
  55. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  56. self.intent = self._fresh_intent()
  57. self._uuid_lock = asyncio.Lock()
  58. self._update_info_lock = asyncio.Lock()
  59. @classmethod
  60. def init_cls(cls, bridge: 'SignalBridge') -> AsyncIterable[Awaitable[None]]:
  61. cls.config = bridge.config
  62. cls.loop = bridge.loop
  63. cls.mx = bridge.matrix
  64. cls.az = bridge.az
  65. cls.hs_domain = cls.config["homeserver.domain"]
  66. cls.mxid_template = SimpleTemplate(cls.config["bridge.username_template"], "userid",
  67. prefix="@", suffix=f":{cls.hs_domain}", type=str)
  68. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  69. cls.homeserver_url_map = {server: URL(url) for server, url
  70. in cls.config["bridge.double_puppet_server_map"].items()}
  71. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  72. cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
  73. in cls.config["bridge.login_shared_secret_map"].items()}
  74. cls.login_device_name = "Signal Bridge"
  75. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  76. def intent_for(self, portal: 'p.Portal') -> IntentAPI:
  77. if portal.chat_id == self.address:
  78. return self.default_mxid_intent
  79. return self.intent
  80. @property
  81. def is_registered(self) -> bool:
  82. return self.uuid_registered if self.uuid is not None else self.number_registered
  83. @is_registered.setter
  84. def is_registered(self, value: bool) -> None:
  85. if self.uuid is not None:
  86. self.uuid_registered = value
  87. else:
  88. self.number_registered = value
  89. @property
  90. def address(self) -> Address:
  91. return Address(uuid=self.uuid, number=self.number)
  92. async def handle_uuid_receive(self, uuid: UUID) -> None:
  93. async with self._uuid_lock:
  94. if self.uuid:
  95. # Received UUID was handled while this call was waiting
  96. return
  97. await self._handle_uuid_receive(uuid)
  98. async def _handle_uuid_receive(self, uuid: UUID) -> None:
  99. self.log.debug(f"Found UUID for user: {uuid}")
  100. user = await u.User.get_by_username(self.number)
  101. if user and not user.uuid:
  102. user.uuid = self.uuid
  103. await user.update()
  104. await self._set_uuid(uuid)
  105. self.by_uuid[self.uuid] = self
  106. async for portal in p.Portal.find_private_chats_with(Address(number=self.number)):
  107. self.log.trace(f"Updating chat_id of private chat portal {portal.receiver}")
  108. portal.handle_uuid_receive(self.uuid)
  109. prev_intent = self.default_mxid_intent
  110. self.default_mxid = self.get_mxid_from_id(self.address)
  111. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  112. self.intent = self._fresh_intent()
  113. await self.intent.ensure_registered()
  114. if self.name:
  115. await self.intent.set_displayname(self.name)
  116. self.log = self.log.getChild(str(uuid))
  117. self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {self.default_mxid_intent}")
  118. for room_id in await prev_intent.get_joined_rooms():
  119. await prev_intent.invite_user(room_id, self.default_mxid)
  120. await self.default_mxid_intent.join_room_by_id(room_id)
  121. await prev_intent.leave_room(room_id)
  122. async def update_info(self, info: Union[Profile, Contact]) -> None:
  123. if isinstance(info, (Contact, Address)):
  124. address = info.address if isinstance(info, Contact) else info
  125. if address.uuid and not self.uuid:
  126. await self.handle_uuid_receive(address.uuid)
  127. if not self.config["bridge.allow_contact_list_name_updates"] and self.name is not None:
  128. return
  129. name = info.name if isinstance(info, (Contact, Profile)) else None
  130. async with self._update_info_lock:
  131. update = False
  132. update = await self._update_name(name) or update
  133. if update:
  134. await self.update()
  135. @staticmethod
  136. def fmt_phone(number: str) -> str:
  137. if phonenumbers is None:
  138. return number
  139. parsed = phonenumbers.parse(number)
  140. fmt = phonenumbers.PhoneNumberFormat.INTERNATIONAL
  141. return phonenumbers.format_number(parsed, fmt)
  142. @classmethod
  143. def _get_displayname(cls, address: Address, name: Optional[str]) -> str:
  144. names = name.split("\x00") if name else []
  145. data = {
  146. "first_name": names[0] if len(names) > 0 else "",
  147. "last_name": names[-1] if len(names) > 1 else "",
  148. "full_name": " ".join(names),
  149. "phone": cls.fmt_phone(address.number) if address.number else None,
  150. "uuid": str(address.uuid) if address.uuid else None,
  151. "displayname": "Unknown user",
  152. }
  153. for pref in cls.config["bridge.displayname_preference"]:
  154. value = data.get(pref.replace(" ", "_"))
  155. if value:
  156. data["displayname"] = value
  157. break
  158. return cls.config["bridge.displayname_template"].format(**data)
  159. async def _update_name(self, name: Optional[str]) -> bool:
  160. name = self._get_displayname(self.address, name)
  161. if name != self.name:
  162. self.name = name
  163. await self.default_mxid_intent.set_displayname(self.name)
  164. self.loop.create_task(self._update_portal_names())
  165. return True
  166. return False
  167. async def _update_portal_names(self) -> None:
  168. async for portal in p.Portal.find_private_chats_with(self.address):
  169. if portal.receiver == self.number:
  170. # This is a note to self chat, don't change the name
  171. continue
  172. await portal.update_puppet_name(self.name)
  173. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  174. portal = await p.Portal.get_by_mxid(room_id)
  175. return portal and portal.chat_id != self.uuid
  176. # region Database getters
  177. def _add_to_cache(self) -> None:
  178. if self.uuid:
  179. self.by_uuid[self.uuid] = self
  180. if self.number:
  181. self.by_number[self.number] = self
  182. if self.custom_mxid:
  183. self.by_custom_mxid[self.custom_mxid] = self
    async def save(self) -> None:
        """Persist the puppet by delegating to ``update()``."""
        await self.update()
  186. @classmethod
  187. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
  188. address = cls.get_id_from_mxid(mxid)
  189. if not address:
  190. return None
  191. return await cls.get_by_address(address, create)
  192. @classmethod
  193. async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
  194. try:
  195. return cls.by_custom_mxid[mxid]
  196. except KeyError:
  197. pass
  198. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  199. if puppet:
  200. puppet._add_to_cache()
  201. return puppet
  202. return None
  203. @classmethod
  204. def get_id_from_mxid(cls, mxid: UserID) -> Optional[Address]:
  205. identifier = cls.mxid_template.parse(mxid)
  206. if not identifier:
  207. return None
  208. if identifier.startswith("phone_"):
  209. return Address(number="+" + identifier[len("phone_"):])
  210. else:
  211. try:
  212. return Address(uuid=UUID(identifier.upper()))
  213. except ValueError:
  214. return None
  215. @classmethod
  216. def get_mxid_from_id(cls, address: Address) -> UserID:
  217. if address.uuid:
  218. identifier = str(address.uuid).lower()
  219. elif address.number:
  220. identifier = f"phone_{address.number.lstrip('+')}"
  221. else:
  222. raise ValueError("Empty address")
  223. return UserID(cls.mxid_template.format_full(identifier))
  224. @classmethod
  225. async def get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
  226. puppet = await cls._get_by_address(address, create)
  227. if puppet and address.uuid and not puppet.uuid:
  228. # We found a UUID for this user, store it ASAP
  229. await puppet.handle_uuid_receive(address.uuid)
  230. return puppet
  231. @classmethod
  232. async def _get_by_address(cls, address: Address, create: bool = True) -> Optional['Puppet']:
  233. if not address.is_valid:
  234. raise ValueError("Empty address")
  235. if address.uuid:
  236. try:
  237. return cls.by_uuid[address.uuid]
  238. except KeyError:
  239. pass
  240. if address.number:
  241. try:
  242. return cls.by_number[address.number]
  243. except KeyError:
  244. pass
  245. puppet = cast(cls, await super().get_by_address(address))
  246. if puppet is not None:
  247. puppet._add_to_cache()
  248. return puppet
  249. if create:
  250. puppet = cls(address.uuid, address.number)
  251. await puppet.insert()
  252. puppet._add_to_cache()
  253. return puppet
  254. return None
  255. @classmethod
  256. async def all_with_custom_mxid(cls) -> AsyncGenerator['Puppet', None]:
  257. puppets = await super().all_with_custom_mxid()
  258. puppet: cls
  259. for index, puppet in enumerate(puppets):
  260. try:
  261. yield cls.by_uuid[puppet.uuid]
  262. except KeyError:
  263. try:
  264. yield cls.by_number[puppet.number]
  265. except KeyError:
  266. puppet._add_to_cache()
  267. yield puppet
  268. # endregion