user.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import (Dict, Optional, AsyncIterable, Awaitable, AsyncGenerator, List, TYPE_CHECKING,
  17. cast)
  18. from collections import defaultdict
  19. import asyncio
  20. import logging
  21. import time
  22. from mauigpapi import AndroidAPI, AndroidState, AndroidMQTT
  23. from mauigpapi.mqtt import Connect, Disconnect, GraphQLSubscription, SkywalkerSubscription
  24. from mauigpapi.types import (CurrentUser, MessageSyncEvent, Operation, RealtimeDirectEvent,
  25. ActivityIndicatorData, TypingStatus, ThreadSyncEvent, Thread)
  26. from mauigpapi.errors import IGNotLoggedInError, MQTTNotLoggedIn, MQTTNotConnected
  27. from mautrix.bridge import BaseUser, async_getter_lock
  28. from mautrix.types import UserID, RoomID, EventID, TextMessageEventContent, MessageType
  29. from mautrix.appservice import AppService
  30. from mautrix.util.opt_prometheus import Summary, Gauge, async_time
  31. from mautrix.util.logging import TraceLogger
  32. from .db import User as DBUser, Portal as DBPortal
  33. from .config import Config
  34. from . import puppet as pu, portal as po
  35. if TYPE_CHECKING:
  36. from .__main__ import InstagramBridge
# Metrics (via mautrix.util.opt_prometheus): the Summary objects time the MQTT event
# handlers they decorate with @async_time; the Gauges track how many users are logged
# in / connected and are updated through _track_metric.
METRIC_MESSAGE = Summary("bridge_on_message", "calls to handle_message")
METRIC_THREAD_SYNC = Summary("bridge_on_thread_sync", "calls to handle_thread_sync")
METRIC_RTD = Summary("bridge_on_rtd", "calls to handle_rtd")
METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
METRIC_CONNECTED = Gauge("bridge_connected", "Bridged users connected to Instagram")
  42. class User(DBUser, BaseUser):
  43. ig_base_log: TraceLogger = logging.getLogger("mau.instagram")
  44. _activity_indicator_ids: Dict[str, int] = {}
  45. by_mxid: Dict[UserID, 'User'] = {}
  46. by_igpk: Dict[int, 'User'] = {}
  47. config: Config
  48. az: AppService
  49. loop: asyncio.AbstractEventLoop
  50. client: Optional[AndroidAPI]
  51. mqtt: Optional[AndroidMQTT]
  52. _listen_task: Optional[asyncio.Task] = None
  53. permission_level: str
  54. username: Optional[str]
  55. _notice_room_lock: asyncio.Lock
  56. _notice_send_lock: asyncio.Lock
  57. _is_logged_in: bool
  58. _is_connected: bool
  59. shutdown: bool
  60. remote_typing_status: Optional[TypingStatus]
  61. def __init__(self, mxid: UserID, igpk: Optional[int] = None,
  62. state: Optional[AndroidState] = None, notice_room: Optional[RoomID] = None
  63. ) -> None:
  64. super().__init__(mxid=mxid, igpk=igpk, state=state, notice_room=notice_room)
  65. self._notice_room_lock = asyncio.Lock()
  66. self._notice_send_lock = asyncio.Lock()
  67. perms = self.config.get_permissions(mxid)
  68. self.is_whitelisted, self.is_admin, self.permission_level = perms
  69. self.log = self.log.getChild(self.mxid)
  70. self.client = None
  71. self.mqtt = None
  72. self.username = None
  73. self.dm_update_lock = asyncio.Lock()
  74. self._metric_value = defaultdict(lambda: False)
  75. self._is_logged_in = False
  76. self._is_connected = False
  77. self.shutdown = False
  78. self._listen_task = None
  79. self.command_status = None
  80. self.remote_typing_status = None
  81. @classmethod
  82. def init_cls(cls, bridge: 'InstagramBridge') -> AsyncIterable[Awaitable[None]]:
  83. cls.bridge = bridge
  84. cls.config = bridge.config
  85. cls.az = bridge.az
  86. cls.loop = bridge.loop
  87. return (user.try_connect() async for user in cls.all_logged_in())
  88. # region Connection management
  89. async def is_logged_in(self) -> bool:
  90. return bool(self.client) and self._is_logged_in
  91. async def try_connect(self) -> None:
  92. try:
  93. await self.connect()
  94. except Exception:
  95. self.log.exception("Error while connecting to Instagram")
  96. @property
  97. def api_log(self) -> TraceLogger:
  98. return self.ig_base_log.getChild("http").getChild(self.mxid)
  99. @property
  100. def is_connected(self) -> bool:
  101. return bool(self.client) and bool(self.mqtt) and self._is_connected
  102. async def connect(self) -> None:
  103. client = AndroidAPI(self.state, log=self.api_log)
  104. try:
  105. resp = await client.current_user()
  106. except IGNotLoggedInError as e:
  107. self.log.warning(f"Failed to connect to Instagram: {e}")
  108. # TODO show reason?
  109. await self.send_bridge_notice("You have been logged out of Instagram",
  110. important=True)
  111. return
  112. self.client = client
  113. self._is_logged_in = True
  114. self.igpk = resp.user.pk
  115. self.username = resp.user.username
  116. self._track_metric(METRIC_LOGGED_IN, True)
  117. self.by_igpk[self.igpk] = self
  118. self.mqtt = AndroidMQTT(self.state, loop=self.loop,
  119. log=self.ig_base_log.getChild("mqtt").getChild(self.mxid))
  120. self.mqtt.add_event_handler(Connect, self.on_connect)
  121. self.mqtt.add_event_handler(Disconnect, self.on_disconnect)
  122. self.mqtt.add_event_handler(MessageSyncEvent, self.handle_message)
  123. self.mqtt.add_event_handler(ThreadSyncEvent, self.handle_thread_sync)
  124. self.mqtt.add_event_handler(RealtimeDirectEvent, self.handle_rtd)
  125. await self.update()
  126. self.loop.create_task(self._try_sync_puppet(resp.user))
  127. self.loop.create_task(self._try_sync())
  128. async def on_connect(self, evt: Connect) -> None:
  129. self.log.debug("Connected to Instagram")
  130. self._track_metric(METRIC_CONNECTED, True)
  131. self._is_connected = True
  132. await self.send_bridge_notice("Connected to Instagram")
  133. async def on_disconnect(self, evt: Disconnect) -> None:
  134. self.log.debug("Disconnected from Instagram")
  135. self._track_metric(METRIC_CONNECTED, False)
  136. self._is_connected = False
  137. # TODO this stuff could probably be moved to mautrix-python
  138. async def get_notice_room(self) -> RoomID:
  139. if not self.notice_room:
  140. async with self._notice_room_lock:
  141. # If someone already created the room while this call was waiting,
  142. # don't make a new room
  143. if self.notice_room:
  144. return self.notice_room
  145. self.notice_room = await self.az.intent.create_room(
  146. is_direct=True, invitees=[self.mxid],
  147. topic="Instagram bridge notices")
  148. await self.update()
  149. return self.notice_room
  150. async def send_bridge_notice(self, text: str, edit: Optional[EventID] = None,
  151. important: bool = False) -> Optional[EventID]:
  152. if not important and not self.config["bridge.unimportant_bridge_notices"]:
  153. self.log.debug("Not sending unimportant bridge notice: %s", text)
  154. return
  155. event_id = None
  156. try:
  157. self.log.debug("Sending bridge notice: %s", text)
  158. content = TextMessageEventContent(body=text, msgtype=(MessageType.TEXT if important
  159. else MessageType.NOTICE))
  160. if edit:
  161. content.set_edit(edit)
  162. # This is locked to prevent notices going out in the wrong order
  163. async with self._notice_send_lock:
  164. event_id = await self.az.intent.send_message(await self.get_notice_room(), content)
  165. except Exception:
  166. self.log.warning("Failed to send bridge notice", exc_info=True)
  167. return edit or event_id
  168. async def _try_sync_puppet(self, user_info: CurrentUser) -> None:
  169. puppet = await pu.Puppet.get_by_pk(self.igpk)
  170. try:
  171. await puppet.update_info(user_info, self)
  172. except Exception:
  173. self.log.exception("Failed to update own puppet info")
  174. try:
  175. if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
  176. self.log.info(f"Automatically enabling custom puppet")
  177. await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
  178. except Exception:
  179. self.log.exception("Failed to automatically enable custom puppet")
  180. async def _try_sync(self) -> None:
  181. try:
  182. await self.sync()
  183. except Exception:
  184. self.log.exception("Exception while syncing")
  185. async def get_direct_chats(self) -> Dict[UserID, List[RoomID]]:
  186. return {
  187. pu.Puppet.get_mxid_from_id(portal.other_user_pk): [portal.mxid]
  188. for portal in await DBPortal.find_private_chats_of(self.igpk)
  189. if portal.mxid
  190. }
  191. async def refresh(self, resync: bool = True) -> None:
  192. await self.stop_listen()
  193. if resync:
  194. await self.sync()
  195. else:
  196. await self.start_listen()
  197. async def _sync_thread(self, thread: Thread, min_active_at: int) -> None:
  198. portal = await po.Portal.get_by_thread(thread, self.igpk)
  199. if portal.mxid:
  200. self.log.debug(f"{thread.thread_id} has a portal, syncing and backfilling...")
  201. await portal.update_matrix_room(self, thread, backfill=True)
  202. elif thread.last_activity_at > min_active_at:
  203. self.log.debug(f"{thread.thread_id} has been active recently, creating portal...")
  204. await portal.create_matrix_room(self, thread)
  205. else:
  206. self.log.debug(f"{thread.thread_id} is not active and doesn't have a portal")
  207. async def sync(self) -> None:
  208. resp = await self.client.get_inbox()
  209. max_age = self.config["bridge.portal_create_max_age"] * 1_000_000
  210. limit = self.config["bridge.chat_sync_limit"]
  211. min_active_at = (time.time() * 1_000_000) - max_age
  212. i = 0
  213. async for thread in self.client.iter_inbox(start_at=resp):
  214. try:
  215. await self._sync_thread(thread, min_active_at)
  216. except Exception:
  217. self.log.exception(f"Error syncing thread {thread.thread_id}")
  218. i += 1
  219. if i >= limit:
  220. break
  221. try:
  222. await self.update_direct_chats()
  223. except Exception:
  224. self.log.exception("Error updating direct chat list")
  225. if not self._listen_task:
  226. await self.start_listen(resp.seq_id, resp.snapshot_at_ms)
  227. async def start_listen(self, seq_id: Optional[int] = None, snapshot_at_ms: Optional[int] = None
  228. ) -> None:
  229. self.shutdown = False
  230. if not seq_id:
  231. resp = await self.client.get_inbox(limit=1)
  232. seq_id, snapshot_at_ms = resp.seq_id, resp.snapshot_at_ms
  233. task = self.listen(seq_id=seq_id, snapshot_at_ms=snapshot_at_ms)
  234. self._listen_task = self.loop.create_task(task)
  235. async def listen(self, seq_id: int, snapshot_at_ms: int) -> None:
  236. try:
  237. await self.mqtt.listen(
  238. graphql_subs={GraphQLSubscription.app_presence(),
  239. GraphQLSubscription.direct_typing(self.state.user_id),
  240. GraphQLSubscription.direct_status()},
  241. skywalker_subs={SkywalkerSubscription.direct_sub(self.state.user_id),
  242. SkywalkerSubscription.live_sub(self.state.user_id)},
  243. seq_id=seq_id, snapshot_at_ms=snapshot_at_ms)
  244. except (MQTTNotConnected, MQTTNotLoggedIn) as e:
  245. await self.send_bridge_notice(f"Error in listener: {e}", important=True)
  246. self.mqtt.disconnect()
  247. except Exception:
  248. self.log.exception("Fatal error in listener")
  249. await self.send_bridge_notice("Fatal error in listener (see logs for more info)",
  250. important=True)
  251. self.mqtt.disconnect()
  252. else:
  253. if not self.shutdown:
  254. await self.send_bridge_notice("Instagram connection closed without error")
  255. finally:
  256. self._listen_task = None
  257. self._is_connected = False
  258. self._track_metric(METRIC_CONNECTED, False)
  259. async def stop_listen(self) -> None:
  260. if self.mqtt:
  261. self.shutdown = True
  262. self.mqtt.disconnect()
  263. if self._listen_task:
  264. await self._listen_task
  265. self.shutdown = False
  266. self._track_metric(METRIC_CONNECTED, False)
  267. self._is_connected = False
  268. await self.update()
  269. async def logout(self) -> None:
  270. if self.client:
  271. try:
  272. await self.client.logout(one_tap_app_login=False)
  273. except Exception:
  274. self.log.debug("Exception logging out", exc_info=True)
  275. if self.mqtt:
  276. self.mqtt.disconnect()
  277. self._track_metric(METRIC_CONNECTED, False)
  278. self._track_metric(METRIC_LOGGED_IN, False)
  279. puppet = await pu.Puppet.get_by_pk(self.igpk, create=False)
  280. if puppet and puppet.is_real_user:
  281. await puppet.switch_mxid(None, None)
  282. try:
  283. del self.by_igpk[self.igpk]
  284. except KeyError:
  285. pass
  286. self.client = None
  287. self.mqtt = None
  288. self.state = None
  289. self.igpk = None
  290. self._is_logged_in = False
  291. await self.update()
  292. # endregion
  293. # region Event handlers
  294. @async_time(METRIC_MESSAGE)
  295. async def handle_message(self, evt: MessageSyncEvent) -> None:
  296. portal = await po.Portal.get_by_thread_id(evt.message.thread_id, receiver=self.igpk)
  297. if not portal or not portal.mxid:
  298. self.log.debug("Got message in thread with no portal, getting info...")
  299. resp = await self.client.get_thread(evt.message.thread_id)
  300. portal = await po.Portal.get_by_thread(resp.thread, self.igpk)
  301. self.log.debug("Got info for unknown portal, creating room")
  302. await portal.create_matrix_room(self, resp.thread)
  303. if not portal.mxid:
  304. self.log.warning("Room creation appears to have failed, "
  305. f"dropping message in {evt.message.thread_id}")
  306. return
  307. self.log.trace(f"Received message sync event {evt.message}")
  308. sender = await pu.Puppet.get_by_pk(evt.message.user_id) if evt.message.user_id else None
  309. if evt.message.op == Operation.ADD:
  310. if not sender:
  311. # I don't think we care about adds with no sender
  312. return
  313. await portal.handle_instagram_item(self, sender, evt.message)
  314. elif evt.message.op == Operation.REMOVE:
  315. # Removes don't have a sender, only the message sender can unsend messages anyway
  316. await portal.handle_instagram_remove(evt.message.item_id)
  317. elif evt.message.op == Operation.REPLACE:
  318. await portal.handle_instagram_update(evt.message)
  319. @async_time(METRIC_THREAD_SYNC)
  320. async def handle_thread_sync(self, evt: ThreadSyncEvent) -> None:
  321. self.log.trace("Received thread sync event %s", evt)
  322. portal = await po.Portal.get_by_thread(evt, receiver=self.igpk)
  323. await portal.create_matrix_room(self, evt)
  324. @async_time(METRIC_RTD)
  325. async def handle_rtd(self, evt: RealtimeDirectEvent) -> None:
  326. if not isinstance(evt.value, ActivityIndicatorData):
  327. return
  328. now = int(time.time() * 1000)
  329. date = int(evt.value.timestamp) // 1000
  330. expiry = date + evt.value.ttl
  331. if expiry < now:
  332. return
  333. if evt.activity_indicator_id in self._activity_indicator_ids:
  334. return
  335. # TODO clear expired items from this dict
  336. self._activity_indicator_ids[evt.activity_indicator_id] = expiry
  337. puppet = await pu.Puppet.get_by_pk(int(evt.value.sender_id))
  338. portal = await po.Portal.get_by_thread_id(evt.thread_id, receiver=self.igpk)
  339. if not puppet or not portal or not portal.mxid:
  340. return
  341. is_typing = evt.value.activity_status != TypingStatus.OFF
  342. if puppet.pk == self.igpk:
  343. self.remote_typing_status = TypingStatus.TEXT if is_typing else TypingStatus.OFF
  344. await puppet.intent_for(portal).set_typing(portal.mxid, is_typing=is_typing,
  345. timeout=evt.value.ttl)
  346. # endregion
  347. # region Database getters
  348. def _add_to_cache(self) -> None:
  349. self.by_mxid[self.mxid] = self
  350. if self.igpk:
  351. self.by_igpk[self.igpk] = self
  352. @classmethod
  353. @async_getter_lock
  354. async def get_by_mxid(cls, mxid: UserID, *, create: bool = True) -> Optional['User']:
  355. # Never allow ghosts to be users
  356. if pu.Puppet.get_id_from_mxid(mxid):
  357. return None
  358. try:
  359. return cls.by_mxid[mxid]
  360. except KeyError:
  361. pass
  362. user = cast(cls, await super().get_by_mxid(mxid))
  363. if user is not None:
  364. user._add_to_cache()
  365. return user
  366. if create:
  367. user = cls(mxid)
  368. await user.insert()
  369. user._add_to_cache()
  370. return user
  371. return None
  372. @classmethod
  373. @async_getter_lock
  374. async def get_by_igpk(cls, igpk: int) -> Optional['User']:
  375. try:
  376. return cls.by_igpk[igpk]
  377. except KeyError:
  378. pass
  379. user = cast(cls, await super().get_by_igpk(igpk))
  380. if user is not None:
  381. user._add_to_cache()
  382. return user
  383. return None
  384. @classmethod
  385. async def all_logged_in(cls) -> AsyncGenerator['User', None]:
  386. users = await super().all_logged_in()
  387. user: cls
  388. for index, user in enumerate(users):
  389. try:
  390. yield cls.by_mxid[user.mxid]
  391. except KeyError:
  392. user._add_to_cache()
  393. yield user
  394. # endregion