user.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Dict, Optional, AsyncGenerator, Union, TYPE_CHECKING, cast
  17. from collections import defaultdict
  18. from uuid import UUID
  19. import asyncio
  20. from mausignald.types import Account, Address, Contact, Group, GroupV2, ListenEvent, ListenAction
  21. from mautrix.bridge import BaseUser
  22. from mautrix.types import UserID, RoomID
  23. from mautrix.appservice import AppService
  24. from mautrix.util.opt_prometheus import Gauge
  25. from .db import User as DBUser
  26. from .config import Config
  27. from . import puppet as pu, portal as po
  28. if TYPE_CHECKING:
  29. from .__main__ import SignalBridge
  # Gauge from mautrix's optional-Prometheus wrapper; User.on_listen sets it to
  # True/False (via _track_metric) when the Signal listener starts/stops.
  METRIC_CONNECTED = Gauge(
      "bridge_connected",
      "Bridge users connected to Signal",
  )
  class User(DBUser, BaseUser):
      """A Matrix user of the bridge, optionally linked to a Signal account.

      Instances are cached in-process by Matrix ID (``by_mxid``) and, once a Signal
      username is attached, by that username (``by_username``).
      """

      # In-memory caches, filled by _add_to_cache().
      by_mxid: Dict[UserID, 'User'] = {}
      by_username: Dict[str, 'User'] = {}

      # Bridge-wide handles shared by all instances, assigned once by init_cls().
      config: Config
      az: AppService
      loop: asyncio.AbstractEventLoop
      bridge: 'SignalBridge'

      is_admin: bool
      permission_level: str

      _notice_room_lock: asyncio.Lock

      def __init__(self, mxid: UserID, username: Optional[str] = None,
                   uuid: Optional[UUID] = None, notice_room: Optional[RoomID] = None) -> None:
          super().__init__(mxid=mxid, username=username, uuid=uuid, notice_room=notice_room)
          self._notice_room_lock = asyncio.Lock()
          # get_permissions() yields (whitelisted, admin, permission level) for this mxid.
          perms = self.config.get_permissions(mxid)
          self.is_whitelisted, self.is_admin, self.permission_level = perms
          self.log = self.log.getChild(self.mxid)
          self.dm_update_lock = asyncio.Lock()
          self.command_status = None
          # Per-metric state, presumably consumed by BaseUser._track_metric (not shown
          # here); keys that were never set read as False.
          self._metric_value = defaultdict(bool)

      @classmethod
      def init_cls(cls, bridge: 'SignalBridge') -> None:
          """Attach the bridge and its config, appservice and event loop to the class."""
          cls.bridge = bridge
          cls.config = bridge.config
          cls.az = bridge.az
          cls.loop = bridge.loop

      @property
      def address(self) -> Optional[Address]:
          """The user's own Signal address, or ``None`` when no username is set."""
          if not self.username:
              return None
          return Address(uuid=self.uuid, number=self.username)

      async def is_logged_in(self) -> bool:
          """Return whether a Signal username is associated with this user."""
          return bool(self.username)

      async def on_signin(self, account: Account) -> None:
          """Store a freshly signed-in Signal account, subscribe to it and start a sync.

          The sync is scheduled as a background task, so this coroutine returns
          without waiting for it to finish.
          """
          self.username = account.username
          self.uuid = account.uuid
          await self.update()
          await self.bridge.signal.subscribe(self.username)
          self.loop.create_task(self.sync())

      def on_listen(self, evt: ListenEvent) -> None:
          """Log Signal listener state changes and update the connection gauge."""
          if evt.action == ListenAction.STARTED:
              self.log.info("Connected to Signal")
              self._track_metric(METRIC_CONNECTED, True)
          elif evt.action == ListenAction.STOPPED:
              if evt.exception:
                  self.log.warning(f"Disconnected from Signal: {evt.exception}")
              else:
                  self.log.info("Disconnected from Signal")
              self._track_metric(METRIC_CONNECTED, False)
          else:
              self.log.warning(f"Unrecognized listen action {evt.action}")

      async def _sync_puppet(self) -> None:
          """Point the user's own puppet at their Matrix account when auto-login allows it."""
          puppet = await pu.Puppet.get_by_address(self.address)
          if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
              self.log.info("Automatically enabling custom puppet")
              await puppet.switch_mxid(access_token="auto", mxid=self.mxid)

      async def sync(self) -> None:
          """Sync the user's own puppet, then contacts and groups.

          Each phase is independent: a failure is logged and never raised, so a
          broken puppet sync does not prevent the contact/group sync.
          """
          try:
              await self._sync_puppet()
          except Exception:
              self.log.exception("Error while syncing own puppet")
          try:
              await self._sync()
          except Exception:
              self.log.exception("Error while syncing")

      async def _sync_contact(self, contact: Contact, create_portals: bool) -> None:
          """Update the puppet for a Signal contact and optionally create its DM portal.

          The Signal profile is only fetched when the puppet has no name yet. The
          original author notes the lookup probably hits the Signal servers, so it is
          skipped otherwise. Without a usable profile, the contact entry itself is the
          info source.
          """
          self.log.trace("Syncing contact %s", contact)
          puppet = await pu.Puppet.get_by_address(contact.address)
          profile = None
          if not puppet.name:
              profile = await self.bridge.signal.get_profile(self.username, contact.address)
              if profile and profile.name:
                  self.log.trace("Got profile for %s: %s", contact.address, profile)
              else:
                  # A profile without a name is no better than the contact entry.
                  profile = None
          # Idea from the original code: instead of skipping the profile lookup for
          # named puppets, maybe listen for profile updates.
          info = profile or contact
          await puppet.update_info(info)
          if create_portals:
              portal = await po.Portal.get_by_chat_id(puppet.address, self.username, create=True)
              await portal.create_matrix_room(self, info)

      async def _sync_group_portal(self, chat_id, group: Union[Group, GroupV2],
                                   create_portals: bool) -> None:
          """Create the group's portal room, or refresh it if the room already exists."""
          portal = await po.Portal.get_by_chat_id(chat_id, create=True)
          if create_portals:
              await portal.create_matrix_room(self, group)
          elif portal.mxid:
              await portal.update_matrix_room(self, group)

      async def _sync_group(self, group: Group, create_portals: bool) -> None:
          """Sync a v1 Signal group, whose ID lives in ``group.group_id``."""
          self.log.trace("Syncing group %s", group)
          await self._sync_group_portal(group.group_id, group, create_portals)

      async def _sync_group_v2(self, group: GroupV2, create_portals: bool) -> None:
          """Sync a v2 Signal group, whose ID lives in ``group.id``."""
          self.log.trace("Syncing group %s", group.id)
          await self._sync_group_portal(group.id, group, create_portals)

      async def _sync(self) -> None:
          """Sync every Signal contact and group; per-item failures are logged and skipped."""
          create_contact_portal = self.config["bridge.autocreate_contact_portal"]
          for contact in await self.bridge.signal.list_contacts(self.username):
              try:
                  await self._sync_contact(contact, create_contact_portal)
              except Exception:
                  self.log.exception(f"Failed to sync contact {contact.address}")
          create_group_portal = self.config["bridge.autocreate_group_portal"]
          for group in await self.bridge.signal.list_groups(self.username):
              # v1 groups expose ``group_id`` but GroupV2 uses ``id``. Resolve it up front
              # so the error handler below cannot itself fail with AttributeError and
              # abort the remaining groups.
              if isinstance(group, GroupV2):
                  group_id = group.id
              else:
                  group_id = getattr(group, "group_id", None)
              try:
                  if isinstance(group, Group):
                      await self._sync_group(group, create_group_portal)
                  elif isinstance(group, GroupV2):
                      await self._sync_group_v2(group, create_group_portal)
                  else:
                      self.log.warning("Unknown return type in list_groups: %s", type(group))
              except Exception:
                  self.log.exception(f"Failed to sync group {group_id}")

      # region Database getters

      def _add_to_cache(self) -> None:
          """Register this instance in the mxid cache and, if logged in, the username cache."""
          self.by_mxid[self.mxid] = self
          if self.username:
              self.by_username[self.username] = self

      @classmethod
      async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
          """Look up a user by Matrix ID: memory cache first, then the database.

          :param mxid: Matrix user ID to look up.
          :param create: Create and insert a new user when none exists.
          :return: The user, or ``None`` if ``mxid`` belongs to a bridge ghost, or if
              no user exists and ``create`` is false.
          """
          # Never allow ghosts to be users
          if pu.Puppet.get_id_from_mxid(mxid):
              return None
          try:
              return cls.by_mxid[mxid]
          except KeyError:
              pass
          user = cast(cls, await super().get_by_mxid(mxid))
          if user is not None:
              user._add_to_cache()
              return user
          if create:
              user = cls(mxid)
              await user.insert()
              user._add_to_cache()
              return user
          return None

      @classmethod
      async def get_by_username(cls, username: str) -> Optional['User']:
          """Look up a user by Signal username: memory cache first, then the database."""
          try:
              return cls.by_username[username]
          except KeyError:
              pass
          user = cast(cls, await super().get_by_username(username))
          if user is not None:
              user._add_to_cache()
              return user
          return None

      @classmethod
      async def all_logged_in(cls) -> AsyncGenerator['User', None]:
          """Yield every logged-in user, preferring instances that are already cached."""
          users = await super().all_logged_in()
          user: cls
          for user in users:
              cached = cls.by_mxid.get(user.mxid)
              if cached is not None:
                  yield cached
              else:
                  user._add_to_cache()
                  yield user

      # endregion