user.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import (Dict, Optional, AsyncIterable, Awaitable, AsyncGenerator, List, TYPE_CHECKING,
  17. cast)
  18. from collections import defaultdict
  19. import asyncio
  20. import logging
  21. from mauigpapi.mqtt import (AndroidMQTT, Connect, Disconnect, GraphQLSubscription,
  22. SkywalkerSubscription)
  23. from mauigpapi.http import AndroidAPI
  24. from mauigpapi.state import AndroidState
  25. from mauigpapi.types import CurrentUser, MessageSyncEvent
  26. from mauigpapi.errors import IGNotLoggedInError
  27. from mautrix.bridge import BaseUser
  28. from mautrix.types import UserID, RoomID, EventID, TextMessageEventContent, MessageType
  29. from mautrix.appservice import AppService
  30. from mautrix.util.opt_prometheus import Summary, Gauge, async_time
  31. from .db import User as DBUser, Portal as DBPortal
  32. from .config import Config
  33. from . import puppet as pu, portal as po
  34. if TYPE_CHECKING:
  35. from .__main__ import InstagramBridge
  36. METRIC_MESSAGE = Summary("bridge_on_message", "calls to handle_message")
  37. METRIC_RECEIPT = Summary("bridge_on_receipt", "calls to handle_receipt")
  38. METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
  39. METRIC_CONNECTED = Gauge("bridge_connected", "Bridged users connected to Instagram")
  40. class User(DBUser, BaseUser):
  41. by_mxid: Dict[UserID, 'User'] = {}
  42. by_igpk: Dict[int, 'User'] = {}
  43. config: Config
  44. az: AppService
  45. loop: asyncio.AbstractEventLoop
  46. client: Optional[AndroidAPI]
  47. mqtt: Optional[AndroidMQTT]
  48. _listen_task: Optional[asyncio.Task] = None
  49. permission_level: str
  50. username: Optional[str]
  51. _notice_room_lock: asyncio.Lock
  52. _notice_send_lock: asyncio.Lock
  53. _is_logged_in: bool
  54. def __init__(self, mxid: UserID, igpk: Optional[int] = None,
  55. state: Optional[AndroidState] = None, notice_room: Optional[RoomID] = None
  56. ) -> None:
  57. super().__init__(mxid=mxid, igpk=igpk, state=state, notice_room=notice_room)
  58. self._notice_room_lock = asyncio.Lock()
  59. self._notice_send_lock = asyncio.Lock()
  60. perms = self.config.get_permissions(mxid)
  61. self.is_whitelisted, self.is_admin, self.permission_level = perms
  62. self.log = self.log.getChild(self.mxid)
  63. self.client = None
  64. self.username = None
  65. self.dm_update_lock = asyncio.Lock()
  66. self._metric_value = defaultdict(lambda: False)
  67. self._is_logged_in = False
  68. self._listen_task = None
  69. self.command_status = None
  70. @classmethod
  71. def init_cls(cls, bridge: 'InstagramBridge') -> AsyncIterable[Awaitable[None]]:
  72. cls.bridge = bridge
  73. cls.config = bridge.config
  74. cls.az = bridge.az
  75. cls.loop = bridge.loop
  76. return (user.try_connect() async for user in cls.all_logged_in())
  77. # region Connection management
  78. async def is_logged_in(self) -> bool:
  79. return bool(self.client) and self._is_logged_in
  80. async def try_connect(self) -> None:
  81. try:
  82. await self.connect()
  83. except Exception:
  84. self.log.exception("Error while connecting to Instagram")
  85. async def connect(self) -> None:
  86. client = AndroidAPI(self.state)
  87. try:
  88. resp = await client.current_user()
  89. except IGNotLoggedInError as e:
  90. self.log.warning(f"Failed to connect to Instagram: {e}")
  91. # TODO show reason?
  92. await self.send_bridge_notice("You have been logged out of Instagram")
  93. return
  94. self.client = client
  95. self._is_logged_in = True
  96. self.igpk = resp.user.pk
  97. self.username = resp.user.username
  98. self._track_metric(METRIC_LOGGED_IN, True)
  99. self.by_igpk[self.igpk] = self
  100. self.mqtt = AndroidMQTT(self.state, loop=self.loop,
  101. log=logging.getLogger("mau.instagram.mqtt").getChild(self.mxid))
  102. self.mqtt.add_event_handler(Connect, self.on_connect)
  103. self.mqtt.add_event_handler(Disconnect, self.on_disconnect)
  104. self.mqtt.add_event_handler(MessageSyncEvent, self.handle_message)
  105. await self.update()
  106. self.loop.create_task(self._try_sync_puppet(resp.user))
  107. self.loop.create_task(self._try_sync())
  108. async def on_connect(self, evt: Connect) -> None:
  109. self._track_metric(METRIC_CONNECTED, True)
  110. async def on_disconnect(self, evt: Disconnect) -> None:
  111. self._track_metric(METRIC_CONNECTED, False)
  112. # TODO this stuff could probably be moved to mautrix-python
  113. async def get_notice_room(self) -> RoomID:
  114. if not self.notice_room:
  115. async with self._notice_room_lock:
  116. # If someone already created the room while this call was waiting,
  117. # don't make a new room
  118. if self.notice_room:
  119. return self.notice_room
  120. self.notice_room = await self.az.intent.create_room(
  121. is_direct=True, invitees=[self.mxid],
  122. topic="Instagram bridge notices")
  123. await self.update()
  124. return self.notice_room
  125. async def send_bridge_notice(self, text: str, edit: Optional[EventID] = None,
  126. important: bool = False) -> Optional[EventID]:
  127. event_id = None
  128. try:
  129. self.log.debug("Sending bridge notice: %s", text)
  130. content = TextMessageEventContent(body=text, msgtype=(MessageType.TEXT if important
  131. else MessageType.NOTICE))
  132. if edit:
  133. content.set_edit(edit)
  134. # This is locked to prevent notices going out in the wrong order
  135. async with self._notice_send_lock:
  136. event_id = await self.az.intent.send_message(await self.get_notice_room(), content)
  137. except Exception:
  138. self.log.warning("Failed to send bridge notice", exc_info=True)
  139. return edit or event_id
  140. async def _try_sync_puppet(self, user_info: CurrentUser) -> None:
  141. puppet = await pu.Puppet.get_by_pk(self.igpk)
  142. try:
  143. await puppet.update_info(user_info)
  144. except Exception:
  145. self.log.exception("Failed to update own puppet info")
  146. try:
  147. if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
  148. self.log.info(f"Automatically enabling custom puppet")
  149. await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
  150. except Exception:
  151. self.log.exception("Failed to automatically enable custom puppet")
  152. async def _try_sync(self) -> None:
  153. try:
  154. await self.sync()
  155. except Exception:
  156. self.log.exception("Exception while syncing")
  157. async def get_direct_chats(self) -> Dict[UserID, List[RoomID]]:
  158. return {
  159. pu.Puppet.get_mxid_from_id(portal.other_user_pk): [portal.mxid]
  160. for portal in await DBPortal.find_private_chats_of(self.igpk)
  161. if portal.mxid
  162. }
  163. async def sync(self) -> None:
  164. resp = await self.client.get_inbox()
  165. limit = self.config["bridge.initial_conversation_sync"]
  166. threads = sorted(resp.inbox.threads, key=lambda thread: thread.last_activity_at)
  167. if limit < 0:
  168. limit = len(threads)
  169. for i, thread in enumerate(threads):
  170. portal = await po.Portal.get_by_thread(thread, self.igpk)
  171. if portal.mxid or i < limit:
  172. await portal.create_matrix_room(self, thread)
  173. await self.update_direct_chats()
  174. self._listen_task = self.loop.create_task(self.mqtt.listen(
  175. graphql_subs={GraphQLSubscription.app_presence(),
  176. GraphQLSubscription.direct_typing(self.state.user_id),
  177. GraphQLSubscription.direct_status()},
  178. skywalker_subs={SkywalkerSubscription.direct_sub(self.state.user_id),
  179. SkywalkerSubscription.live_sub(self.state.user_id)},
  180. seq_id=resp.seq_id, snapshot_at_ms=resp.snapshot_at_ms))
  181. async def stop(self) -> None:
  182. if self.mqtt:
  183. self.mqtt.disconnect()
  184. self._track_metric(METRIC_CONNECTED, False)
  185. await self.update()
  186. async def logout(self) -> None:
  187. if self.mqtt:
  188. self.mqtt.disconnect()
  189. self._track_metric(METRIC_CONNECTED, False)
  190. self._track_metric(METRIC_LOGGED_IN, False)
  191. puppet = await pu.Puppet.get_by_pk(self.igpk, create=False)
  192. if puppet and puppet.is_real_user:
  193. await puppet.switch_mxid(None, None)
  194. try:
  195. del self.by_igpk[self.igpk]
  196. except KeyError:
  197. pass
  198. self.client = None
  199. self.mqtt = None
  200. self.state = None
  201. self._is_logged_in = False
  202. await self.update()
  203. # endregion
  204. # region Event handlers
  205. @async_time(METRIC_MESSAGE)
  206. async def handle_message(self, evt: MessageSyncEvent) -> None:
  207. # We don't care about messages with no sender
  208. if not evt.message.user_id:
  209. return
  210. portal = await po.Portal.get_by_thread_id(evt.message.thread_id, receiver=self.igpk)
  211. if not portal.mxid:
  212. # TODO try to find the thread?
  213. self.log.warning(f"Ignoring message to unknown thread {evt.message.thread_id}")
  214. return
  215. sender = await pu.Puppet.get_by_pk(evt.message.user_id)
  216. await portal.handle_instagram_item(self, sender, evt.message)
  217. # @async_time(METRIC_RECEIPT)
  218. # async def handle_receipt(self, evt: ConversationReadEntry) -> None:
  219. # portal = await po.Portal.get_by_twid(evt.conversation_id, self.twid,
  220. # conv_type=evt.conversation.type)
  221. # if not portal.mxid:
  222. # return
  223. # sender = await pu.Puppet.get_by_twid(self.twid)
  224. # await portal.handle_twitter_receipt(sender, int(evt.last_read_event_id))
  225. # endregion
  226. # region Database getters
  227. def _add_to_cache(self) -> None:
  228. self.by_mxid[self.mxid] = self
  229. if self.igpk:
  230. self.by_igpk[self.igpk] = self
  231. @classmethod
  232. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
  233. # Never allow ghosts to be users
  234. if pu.Puppet.get_id_from_mxid(mxid):
  235. return None
  236. try:
  237. return cls.by_mxid[mxid]
  238. except KeyError:
  239. pass
  240. user = cast(cls, await super().get_by_mxid(mxid))
  241. if user is not None:
  242. user._add_to_cache()
  243. return user
  244. if create:
  245. user = cls(mxid)
  246. await user.insert()
  247. user._add_to_cache()
  248. return user
  249. return None
  250. @classmethod
  251. async def get_by_igpk(cls, igpk: int) -> Optional['User']:
  252. try:
  253. return cls.by_igpk[igpk]
  254. except KeyError:
  255. pass
  256. user = cast(cls, await super().get_by_igpk(igpk))
  257. if user is not None:
  258. user._add_to_cache()
  259. return user
  260. return None
  261. @classmethod
  262. async def all_logged_in(cls) -> AsyncGenerator['User', None]:
  263. users = await super().all_logged_in()
  264. user: cls
  265. for index, user in enumerate(users):
  266. try:
  267. yield cls.by_mxid[user.mxid]
  268. except KeyError:
  269. user._add_to_cache()
  270. yield user
  271. # endregion