user.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import (Dict, Optional, AsyncIterable, Awaitable, AsyncGenerator, List, TYPE_CHECKING,
  17. cast)
  18. from collections import defaultdict
  19. import asyncio
  20. import logging
  21. from mauigpapi.mqtt import (AndroidMQTT, Connect, Disconnect, GraphQLSubscription,
  22. SkywalkerSubscription)
  23. from mauigpapi.http import AndroidAPI
  24. from mauigpapi.state import AndroidState
  25. from mauigpapi.types import CurrentUser, MessageSyncEvent
  26. from mauigpapi.errors import IGNotLoggedInError
  27. from mautrix.bridge import BaseUser
  28. from mautrix.types import UserID, RoomID, EventID, TextMessageEventContent, MessageType
  29. from mautrix.appservice import AppService
  30. from mautrix.util.opt_prometheus import Summary, Gauge, async_time
  31. from .db import User as DBUser, Portal as DBPortal
  32. from .config import Config
  33. from . import puppet as pu, portal as po
  34. if TYPE_CHECKING:
  35. from .__main__ import InstagramBridge
  36. METRIC_MESSAGE = Summary("bridge_on_message", "calls to handle_message")
  37. METRIC_RECEIPT = Summary("bridge_on_receipt", "calls to handle_receipt")
  38. METRIC_LOGGED_IN = Gauge("bridge_logged_in", "Users logged into the bridge")
  39. METRIC_CONNECTED = Gauge("bridge_connected", "Bridged users connected to Instagram")
  40. class User(DBUser, BaseUser):
  41. by_mxid: Dict[UserID, 'User'] = {}
  42. by_igpk: Dict[int, 'User'] = {}
  43. config: Config
  44. az: AppService
  45. loop: asyncio.AbstractEventLoop
  46. client: Optional[AndroidAPI]
  47. mqtt: Optional[AndroidMQTT]
  48. _listen_task: Optional[asyncio.Task] = None
  49. permission_level: str
  50. username: Optional[str]
  51. _notice_room_lock: asyncio.Lock
  52. _notice_send_lock: asyncio.Lock
  53. _is_logged_in: bool
  54. def __init__(self, mxid: UserID, igpk: Optional[int] = None,
  55. state: Optional[AndroidState] = None, notice_room: Optional[RoomID] = None
  56. ) -> None:
  57. super().__init__(mxid=mxid, igpk=igpk, state=state, notice_room=notice_room)
  58. self._notice_room_lock = asyncio.Lock()
  59. self._notice_send_lock = asyncio.Lock()
  60. perms = self.config.get_permissions(mxid)
  61. self.is_whitelisted, self.is_admin, self.permission_level = perms
  62. self.log = self.log.getChild(self.mxid)
  63. self.client = None
  64. self.username = None
  65. self.dm_update_lock = asyncio.Lock()
  66. self._metric_value = defaultdict(lambda: False)
  67. self._is_logged_in = False
  68. self._listen_task = None
  69. @classmethod
  70. def init_cls(cls, bridge: 'InstagramBridge') -> AsyncIterable[Awaitable[None]]:
  71. cls.bridge = bridge
  72. cls.config = bridge.config
  73. cls.az = bridge.az
  74. cls.loop = bridge.loop
  75. return (user.try_connect() async for user in cls.all_logged_in())
  76. # region Connection management
  77. async def is_logged_in(self) -> bool:
  78. return bool(self.client) and self._is_logged_in
  79. async def try_connect(self) -> None:
  80. try:
  81. await self.connect()
  82. except Exception:
  83. self.log.exception("Error while connecting to Instagram")
  84. async def connect(self) -> None:
  85. client = AndroidAPI(self.state)
  86. try:
  87. resp = await client.current_user()
  88. except IGNotLoggedInError as e:
  89. self.log.warning(f"Failed to connect to Instagram: {e}")
  90. # TODO show reason?
  91. await self.send_bridge_notice("You have been logged out of Instagram")
  92. return
  93. self.client = client
  94. self.igpk = resp.user.pk
  95. self.username = resp.user.username
  96. self._track_metric(METRIC_LOGGED_IN, True)
  97. self.by_igpk[self.igpk] = self
  98. self.mqtt = AndroidMQTT(self.state, loop=self.loop,
  99. log=logging.getLogger("mau.instagram.mqtt").getChild(self.mxid))
  100. self.mqtt.add_event_handler(Connect, self.on_connect)
  101. self.mqtt.add_event_handler(Disconnect, self.on_disconnect)
  102. self.mqtt.add_event_handler(MessageSyncEvent, self.handle_message)
  103. await self.update()
  104. self.loop.create_task(self._try_sync_puppet(resp.user))
  105. self.loop.create_task(self._try_sync())
  106. async def on_connect(self, evt: Connect) -> None:
  107. self._track_metric(METRIC_CONNECTED, True)
  108. async def on_disconnect(self, evt: Disconnect) -> None:
  109. self._track_metric(METRIC_CONNECTED, False)
  110. # TODO this stuff could probably be moved to mautrix-python
  111. async def get_notice_room(self) -> RoomID:
  112. if not self.notice_room:
  113. async with self._notice_room_lock:
  114. # If someone already created the room while this call was waiting,
  115. # don't make a new room
  116. if self.notice_room:
  117. return self.notice_room
  118. self.notice_room = await self.az.intent.create_room(
  119. is_direct=True, invitees=[self.mxid],
  120. topic="Instagram bridge notices")
  121. await self.update()
  122. return self.notice_room
  123. async def send_bridge_notice(self, text: str, edit: Optional[EventID] = None,
  124. important: bool = False) -> Optional[EventID]:
  125. event_id = None
  126. try:
  127. self.log.debug("Sending bridge notice: %s", text)
  128. content = TextMessageEventContent(body=text, msgtype=(MessageType.TEXT if important
  129. else MessageType.NOTICE))
  130. if edit:
  131. content.set_edit(edit)
  132. # This is locked to prevent notices going out in the wrong order
  133. async with self._notice_send_lock:
  134. event_id = await self.az.intent.send_message(await self.get_notice_room(), content)
  135. except Exception:
  136. self.log.warning("Failed to send bridge notice", exc_info=True)
  137. return edit or event_id
  138. async def _try_sync_puppet(self, user_info: CurrentUser) -> None:
  139. puppet = await pu.Puppet.get_by_pk(self.igpk)
  140. try:
  141. await puppet.update_info(user_info)
  142. except Exception:
  143. self.log.exception("Failed to update own puppet info")
  144. try:
  145. if puppet.custom_mxid != self.mxid and puppet.can_auto_login(self.mxid):
  146. self.log.info(f"Automatically enabling custom puppet")
  147. await puppet.switch_mxid(access_token="auto", mxid=self.mxid)
  148. except Exception:
  149. self.log.exception("Failed to automatically enable custom puppet")
  150. async def _try_sync(self) -> None:
  151. try:
  152. await self.sync()
  153. except Exception:
  154. self.log.exception("Exception while syncing")
  155. async def get_direct_chats(self) -> Dict[UserID, List[RoomID]]:
  156. return {
  157. pu.Puppet.get_mxid_from_id(portal.other_user_pk): [portal.mxid]
  158. for portal in await DBPortal.find_private_chats_of(self.igpk)
  159. if portal.mxid
  160. }
  161. async def sync(self) -> None:
  162. resp = await self.client.get_inbox()
  163. limit = self.config["bridge.initial_conversation_sync"]
  164. threads = sorted(resp.inbox.threads, key=lambda thread: thread.last_activity_at)
  165. if limit < 0:
  166. limit = len(threads)
  167. for i, thread in enumerate(threads):
  168. portal = await po.Portal.get_by_thread(thread, self.igpk)
  169. if portal.mxid or i < limit:
  170. await portal.create_matrix_room(self, thread)
  171. await self.update_direct_chats()
  172. self._listen_task = self.loop.create_task(self.mqtt.listen(
  173. graphql_subs={GraphQLSubscription.app_presence(),
  174. GraphQLSubscription.direct_typing(self.state.user_id),
  175. GraphQLSubscription.direct_status()},
  176. skywalker_subs={SkywalkerSubscription.direct_sub(self.state.user_id),
  177. SkywalkerSubscription.live_sub(self.state.user_id)},
  178. seq_id=resp.seq_id, snapshot_at_ms=resp.snapshot_at_ms))
  179. async def stop(self) -> None:
  180. if self.mqtt:
  181. self.mqtt.disconnect()
  182. self._track_metric(METRIC_CONNECTED, False)
  183. await self.update()
  184. async def logout(self) -> None:
  185. if self.mqtt:
  186. self.mqtt.disconnect()
  187. self._track_metric(METRIC_CONNECTED, False)
  188. self._track_metric(METRIC_LOGGED_IN, False)
  189. puppet = await pu.Puppet.get_by_pk(self.igpk, create=False)
  190. if puppet and puppet.is_real_user:
  191. await puppet.switch_mxid(None, None)
  192. try:
  193. del self.by_igpk[self.igpk]
  194. except KeyError:
  195. pass
  196. self.client = None
  197. self.mqtt = None
  198. self.state = None
  199. self._is_logged_in = False
  200. await self.update()
  201. # endregion
  202. # region Event handlers
  203. @async_time(METRIC_MESSAGE)
  204. async def handle_message(self, evt: MessageSyncEvent) -> None:
  205. portal = await po.Portal.get_by_thread_id(evt.message.thread_id, receiver=self.igpk)
  206. if not portal.mxid:
  207. # TODO try to find the thread?
  208. self.log.warning(f"Ignoring message to unknown thread {evt.message.thread_id}")
  209. return
  210. sender = await pu.Puppet.get_by_pk(evt.message.user_id)
  211. await portal.handle_instagram_item(self, sender, evt.message)
  212. # @async_time(METRIC_RECEIPT)
  213. # async def handle_receipt(self, evt: ConversationReadEntry) -> None:
  214. # portal = await po.Portal.get_by_twid(evt.conversation_id, self.twid,
  215. # conv_type=evt.conversation.type)
  216. # if not portal.mxid:
  217. # return
  218. # sender = await pu.Puppet.get_by_twid(self.twid)
  219. # await portal.handle_twitter_receipt(sender, int(evt.last_read_event_id))
  220. # endregion
  221. # region Database getters
  222. def _add_to_cache(self) -> None:
  223. self.by_mxid[self.mxid] = self
  224. if self.igpk:
  225. self.by_igpk[self.igpk] = self
  226. @classmethod
  227. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['User']:
  228. # Never allow ghosts to be users
  229. if pu.Puppet.get_id_from_mxid(mxid):
  230. return None
  231. try:
  232. return cls.by_mxid[mxid]
  233. except KeyError:
  234. pass
  235. user = cast(cls, await super().get_by_mxid(mxid))
  236. if user is not None:
  237. user._add_to_cache()
  238. return user
  239. if create:
  240. user = cls(mxid)
  241. await user.insert()
  242. user._add_to_cache()
  243. return user
  244. return None
  245. @classmethod
  246. async def get_by_igpk(cls, igpk: int) -> Optional['User']:
  247. try:
  248. return cls.by_igpk[igpk]
  249. except KeyError:
  250. pass
  251. user = cast(cls, await super().get_by_igpk(igpk))
  252. if user is not None:
  253. user._add_to_cache()
  254. return user
  255. return None
  256. @classmethod
  257. async def all_logged_in(cls) -> AsyncGenerator['User', None]:
  258. users = await super().all_logged_in()
  259. user: cls
  260. for index, user in enumerate(users):
  261. try:
  262. yield cls.by_mxid[user.mxid]
  263. except KeyError:
  264. user._add_to_cache()
  265. yield user
  266. # endregion