conn.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2022 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import Any, Awaitable, Callable, Iterable, Type, TypeVar
  18. from collections import defaultdict
  19. from socket import error as SocketError, socket
  20. import asyncio
  21. import json
  22. import logging
  23. import re
  24. import time
  25. import urllib.request
  26. import zlib
  27. from paho.mqtt.client import MQTTMessage, WebsocketConnectionError
  28. from yarl import URL
  29. import paho.mqtt.client
  30. from mautrix.util.logging import TraceLogger
  31. from ..errors import IrisSubscribeError, MQTTNotConnected, MQTTNotLoggedIn
  32. from ..state import AndroidState
  33. from ..types import (
  34. AppPresenceEventPayload,
  35. ClientConfigUpdatePayload,
  36. CommandResponse,
  37. IrisPayload,
  38. IrisPayloadData,
  39. LiveVideoCommentPayload,
  40. MessageSyncEvent,
  41. MessageSyncMessage,
  42. PubsubEvent,
  43. PubsubPayload,
  44. ReactionStatus,
  45. RealtimeDirectEvent,
  46. RealtimeZeroProvisionPayload,
  47. ThreadAction,
  48. ThreadItemType,
  49. ThreadSyncEvent,
  50. TypingStatus,
  51. )
  52. from .events import Connect, Disconnect, NewSequenceID
  53. from .otclient import MQTToTClient
  54. from .subscription import GraphQLQueryID, RealtimeTopic, everclear_subscriptions
  55. from .thrift import ForegroundStateConfig, IncomingMessage, RealtimeClientInfo, RealtimeConfig
  56. try:
  57. import socks
  58. except ImportError:
  59. socks = None
  60. T = TypeVar("T")
  61. ACTIVITY_INDICATOR_REGEX = re.compile(
  62. r"/direct_v2/threads/([\w_]+)/activity_indicator_id/([\w_]+)"
  63. )
  64. INBOX_THREAD_REGEX = re.compile(r"/direct_v2/inbox/threads/([\w_]+)")
class AndroidMQTT:
    """Realtime MQTT client for Instagram, connecting to ``edge-mqtt.facebook.com:443``.

    Socket I/O is driven from an asyncio event loop via paho's socket callbacks.
    Incoming messages are parsed into event objects and delivered to handlers
    registered with :meth:`add_event_handler`. Subscriptions and the iris sequence
    position are remembered so they can be re-established after a reconnect.
    """

    # Event loop the socket reader/writer callbacks and background tasks are scheduled on.
    _loop: asyncio.AbstractEventLoop
    # Underlying paho-based MQTT client.
    _client: MQTToTClient
    log: TraceLogger
    # Account/device state used to build the client ID and command contexts.
    state: AndroidState
    # GraphQL subscription IDs, re-sent after every (re)connect.
    _graphql_subs: set[str]
    # Skywalker (PUBSUB topic) subscriptions, re-sent after every (re)connect.
    _skywalker_subs: set[str]
    # Last known iris sequence ID and its snapshot timestamp (ms); None until known.
    _iris_seq_id: int | None
    _iris_snapshot_at_ms: int | None
    # MQTT message ID -> future resolved when paho reports the publish as sent.
    _publish_waiters: dict[int, asyncio.Future]
    # Response topic -> future resolved by the next message on that topic (see request()).
    _response_waiters: dict[RealtimeTopic, asyncio.Future]
    # Serializes request() calls per response topic, since only one waiter fits per topic.
    _response_waiter_locks: dict[RealtimeTopic, asyncio.Lock]
    # client_context -> future resolved by the matching SEND_MESSAGE_RESPONSE.
    _message_response_waiters: dict[str, asyncio.Future]
    # Error recorded by background tasks; raised (and cleared) when listen() exits.
    _disconnect_error: Exception | None
    # Event type -> handlers awaited by _dispatch().
    _event_handlers: dict[Type[T], list[Callable[[T], Awaitable[None]]]]
    # Message-sync events waiting to be dispatched in order by _dispatcher_loop().
    _outgoing_events: asyncio.Queue
    _event_dispatcher_task: asyncio.Task | None

    # region Initialization
  83. def __init__(
  84. self,
  85. state: AndroidState,
  86. loop: asyncio.AbstractEventLoop | None = None,
  87. log: TraceLogger | None = None,
  88. ) -> None:
  89. self._graphql_subs = set()
  90. self._skywalker_subs = set()
  91. self._iris_seq_id = None
  92. self._iris_snapshot_at_ms = None
  93. self._publish_waiters = {}
  94. self._response_waiters = {}
  95. self._message_response_waiters = {}
  96. self._disconnect_error = None
  97. self._response_waiter_locks = defaultdict(lambda: asyncio.Lock())
  98. self._event_handlers = defaultdict(lambda: [])
  99. self._event_dispatcher_task = None
  100. self._outgoing_events = asyncio.Queue()
  101. self.log = log or logging.getLogger("mauigpapi.mqtt")
  102. self._loop = loop or asyncio.get_event_loop()
  103. self.state = state
  104. self._client = MQTToTClient(
  105. client_id=self._form_client_id(),
  106. clean_session=True,
  107. protocol=paho.mqtt.client.MQTTv31,
  108. transport="tcp",
  109. )
  110. try:
  111. http_proxy = urllib.request.getproxies()["http"]
  112. except KeyError:
  113. http_proxy = None
  114. if http_proxy and socks and URL:
  115. proxy_url = URL(http_proxy)
  116. proxy_type = {
  117. "http": socks.HTTP,
  118. "https": socks.HTTP,
  119. "socks": socks.SOCKS5,
  120. "socks5": socks.SOCKS5,
  121. "socks4": socks.SOCKS4,
  122. }[proxy_url.scheme]
  123. self._client.proxy_set(
  124. proxy_type=proxy_type,
  125. proxy_addr=proxy_url.host,
  126. proxy_port=proxy_url.port,
  127. proxy_username=proxy_url.user,
  128. proxy_password=proxy_url.password,
  129. )
  130. self._client.enable_logger()
  131. self._client.tls_set()
  132. # mqtt.max_inflight_messages_set(20) # The rest will get queued
  133. # mqtt.max_queued_messages_set(0) # Unlimited messages can be queued
  134. # mqtt.message_retry_set(20) # Retry sending for at least 20 seconds
  135. # mqtt.reconnect_delay_set(min_delay=1, max_delay=120)
  136. self._client.connect_async("edge-mqtt.facebook.com", 443, keepalive=60)
  137. self._client.on_message = self._on_message_handler
  138. self._client.on_publish = self._on_publish_handler
  139. self._client.on_connect = self._on_connect_handler
  140. # self._client.on_disconnect = self._on_disconnect_handler
  141. self._client.on_socket_open = self._on_socket_open
  142. self._client.on_socket_close = self._on_socket_close
  143. self._client.on_socket_register_write = self._on_socket_register_write
  144. self._client.on_socket_unregister_write = self._on_socket_unregister_write
  145. def _form_client_id(self) -> bytes:
  146. subscribe_topics = [
  147. RealtimeTopic.PUBSUB,
  148. RealtimeTopic.SUB_IRIS_RESPONSE,
  149. RealtimeTopic.REALTIME_SUB,
  150. RealtimeTopic.REGION_HINT,
  151. RealtimeTopic.SEND_MESSAGE_RESPONSE,
  152. RealtimeTopic.MESSAGE_SYNC,
  153. RealtimeTopic.UNKNOWN_179,
  154. RealtimeTopic.UNKNOWN_PP,
  155. ]
  156. subscribe_topic_ids = [int(topic.encoded) for topic in subscribe_topics]
  157. password = f"sessionid={self.state.cookies['sessionid']}"
  158. cfg = RealtimeConfig(
  159. client_identifier=self.state.device.phone_id[:20],
  160. client_info=RealtimeClientInfo(
  161. user_id=int(self.state.user_id),
  162. user_agent=self.state.user_agent,
  163. client_capabilities=0b10110111,
  164. endpoint_capabilities=0,
  165. publish_format=1,
  166. no_automatic_foreground=True,
  167. make_user_available_in_foreground=False,
  168. device_id=self.state.device.phone_id,
  169. is_initially_foreground=False,
  170. network_type=1,
  171. network_subtype=0,
  172. client_mqtt_session_id=int(time.time() * 1000) & 0xFFFFFFFF,
  173. subscribe_topics=subscribe_topic_ids,
  174. client_type="cookie_auth",
  175. app_id=567067343352427,
  176. region_preference=self.state.session.region_hint or "LLA",
  177. device_secret="",
  178. client_stack=3,
  179. ),
  180. password=password,
  181. app_specific_info={
  182. "app_version": self.state.application.APP_VERSION,
  183. "X-IG-Capabilities": self.state.application.CAPABILITIES,
  184. "everclear_subscriptions": json.dumps(everclear_subscriptions),
  185. "User-Agent": self.state.user_agent,
  186. "Accept-Language": self.state.device.language.replace("_", "-"),
  187. "platform": "android",
  188. "ig_mqtt_route": "django",
  189. "pubsub_msg_type_blacklist": "direct, typing_type",
  190. "auth_cache_enabled": "0",
  191. },
  192. )
  193. return zlib.compress(cfg.to_thrift(), level=9)
  194. # endregion
  195. def _on_socket_open(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  196. self._loop.add_reader(sock, client.loop_read)
  197. def _on_socket_close(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  198. self._loop.remove_reader(sock)
  199. def _on_socket_register_write(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  200. self._loop.add_writer(sock, client.loop_write)
  201. def _on_socket_unregister_write(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  202. self._loop.remove_writer(sock)
  203. def _on_connect_handler(
  204. self, client: MQTToTClient, _: Any, flags: dict[str, Any], rc: int
  205. ) -> None:
  206. if rc != 0:
  207. err = paho.mqtt.client.connack_string(rc)
  208. self.log.error("MQTT Connection Error: %s (%d)", err, rc)
  209. return
  210. self._loop.create_task(self._post_connect())
  211. async def _post_connect(self) -> None:
  212. await self._dispatch(Connect())
  213. self.log.debug("Re-subscribing to things after connect")
  214. if self._graphql_subs:
  215. res = await self.graphql_subscribe(self._graphql_subs)
  216. self.log.trace("GraphQL subscribe response: %s", res)
  217. if self._skywalker_subs:
  218. res = await self.skywalker_subscribe(self._skywalker_subs)
  219. self.log.trace("Skywalker subscribe response: %s", res)
  220. if self._iris_seq_id:
  221. retry = 0
  222. while True:
  223. try:
  224. await self.iris_subscribe(self._iris_seq_id, self._iris_snapshot_at_ms)
  225. break
  226. except (asyncio.TimeoutError, IrisSubscribeError) as e:
  227. self.log.exception("Error requesting iris subscribe")
  228. retry += 1
  229. if retry >= 5 or isinstance(e, IrisSubscribeError):
  230. self._disconnect_error = e
  231. self.disconnect()
  232. break
  233. await asyncio.sleep(5)
  234. self.log.debug("Retrying iris subscribe")
  235. def _on_publish_handler(self, client: MQTToTClient, _: Any, mid: int) -> None:
  236. try:
  237. waiter = self._publish_waiters[mid]
  238. except KeyError:
  239. self.log.trace(f"Got publish confirmation for {mid}, but no waiters")
  240. return
  241. self.log.trace(f"Got publish confirmation for {mid}")
  242. waiter.set_result(None)
  243. # region Incoming event parsing
  244. def _parse_direct_thread_path(self, path: str) -> dict:
  245. try:
  246. blank, direct_v2, threads, thread_id, *rest = path.split("/")
  247. except (ValueError, IndexError) as e:
  248. self.log.debug(f"Got {e!r} while parsing path {path}")
  249. raise
  250. if (blank, direct_v2, threads) != ("", "direct_v2", "threads"):
  251. self.log.debug(f"Got unexpected first parts in direct thread path {path}")
  252. raise ValueError("unexpected first three parts in _parse_direct_thread_path")
  253. additional = {"thread_id": thread_id}
  254. if rest:
  255. subitem_key = rest[0]
  256. if subitem_key == "approval_required_for_new_members":
  257. additional["approval_required_for_new_members"] = True
  258. elif subitem_key == "participants" and len(rest) > 2 and rest[2] == "has_seen":
  259. additional["has_seen"] = int(rest[1])
  260. elif subitem_key == "items":
  261. additional["item_id"] = rest[1]
  262. # TODO wtf is this?
  263. # it has something to do with reactions
  264. if len(rest) > 4:
  265. additional[rest[2]] = {
  266. rest[3]: rest[4],
  267. }
  268. elif subitem_key in "admin_user_ids":
  269. additional["admin_user_id"] = int(rest[1])
  270. elif subitem_key == "activity_indicator_id":
  271. additional["activity_indicator_id"] = rest[1]
  272. self.log.trace("Parsed path %s -> %s", path, additional)
  273. return additional
  274. def _on_messager_sync_item(self, part: IrisPayloadData, parsed_item: IrisPayload) -> bool:
  275. if part.path.startswith("/direct_v2/threads/"):
  276. raw_message = {
  277. "path": part.path,
  278. "op": part.op,
  279. **self._parse_direct_thread_path(part.path),
  280. }
  281. try:
  282. raw_message = {
  283. **raw_message,
  284. **json.loads(part.value),
  285. }
  286. except (json.JSONDecodeError, TypeError):
  287. raw_message["value"] = part.value
  288. message = MessageSyncMessage.deserialize(raw_message)
  289. evt = MessageSyncEvent(iris=parsed_item, message=message)
  290. elif part.path.startswith("/direct_v2/inbox/threads/"):
  291. raw_message = {
  292. "path": part.path,
  293. "op": part.op,
  294. **json.loads(part.value),
  295. }
  296. evt = ThreadSyncEvent.deserialize(raw_message)
  297. else:
  298. self.log.warning(f"Unsupported path {part.path}")
  299. return False
  300. self._outgoing_events.put_nowait(evt)
  301. return True
  302. def _on_message_sync(self, payload: bytes) -> None:
  303. parsed = json.loads(payload.decode("utf-8"))
  304. self.log.trace("Got message sync event: %s", parsed)
  305. has_items = False
  306. for sync_item in parsed:
  307. parsed_item = IrisPayload.deserialize(sync_item)
  308. if self._iris_seq_id < parsed_item.seq_id:
  309. self.log.trace(f"Got new seq_id: {parsed_item.seq_id}")
  310. self._iris_seq_id = parsed_item.seq_id
  311. self._iris_snapshot_at_ms = int(time.time() * 1000)
  312. asyncio.create_task(
  313. self._dispatch(NewSequenceID(self._iris_seq_id, self._iris_snapshot_at_ms))
  314. )
  315. for part in parsed_item.data:
  316. has_items = self._on_messager_sync_item(part, parsed_item) or has_items
  317. if has_items and not self._event_dispatcher_task:
  318. self._event_dispatcher_task = asyncio.create_task(self._dispatcher_loop())
  319. def _on_pubsub(self, payload: bytes) -> None:
  320. parsed_thrift = IncomingMessage.from_thrift(payload)
  321. self.log.trace(f"Got pubsub event {parsed_thrift.topic} / {parsed_thrift.payload}")
  322. message = PubsubPayload.parse_json(parsed_thrift.payload)
  323. for data in message.data:
  324. match = ACTIVITY_INDICATOR_REGEX.match(data.path)
  325. if match:
  326. evt = PubsubEvent(
  327. data=data,
  328. base=message,
  329. thread_id=match.group(1),
  330. activity_indicator_id=match.group(2),
  331. )
  332. self._loop.create_task(self._dispatch(evt))
  333. elif not data.double_publish:
  334. self.log.debug("Pubsub: no activity indicator on data: %s", data)
  335. else:
  336. self.log.debug("Pubsub: double publish: %s", data.path)
  337. def _parse_realtime_sub_item(self, topic: str | GraphQLQueryID, raw: dict) -> Iterable[Any]:
  338. if topic == GraphQLQueryID.APP_PRESENCE:
  339. yield AppPresenceEventPayload.deserialize(raw).presence_event
  340. elif topic == GraphQLQueryID.ZERO_PROVISION:
  341. yield RealtimeZeroProvisionPayload.deserialize(raw).zero_product_provisioning_event
  342. elif topic == GraphQLQueryID.CLIENT_CONFIG_UPDATE:
  343. yield ClientConfigUpdatePayload.deserialize(raw).client_config_update_event
  344. elif topic == GraphQLQueryID.LIVE_REALTIME_COMMENTS:
  345. yield LiveVideoCommentPayload.deserialize(raw).live_video_comment_event
  346. elif topic == "direct":
  347. event = raw["event"]
  348. for item in raw["data"]:
  349. yield RealtimeDirectEvent.deserialize(
  350. {
  351. "event": event,
  352. **self._parse_direct_thread_path(item["path"]),
  353. **item,
  354. }
  355. )
  356. def _on_realtime_sub(self, payload: bytes) -> None:
  357. parsed_thrift = IncomingMessage.from_thrift(payload)
  358. try:
  359. topic = GraphQLQueryID(parsed_thrift.topic)
  360. except ValueError:
  361. topic = parsed_thrift.topic
  362. self.log.trace(f"Got realtime sub event {topic} / {parsed_thrift.payload}")
  363. allowed = (
  364. "direct",
  365. GraphQLQueryID.APP_PRESENCE,
  366. GraphQLQueryID.ZERO_PROVISION,
  367. GraphQLQueryID.CLIENT_CONFIG_UPDATE,
  368. GraphQLQueryID.LIVE_REALTIME_COMMENTS,
  369. )
  370. if topic not in allowed:
  371. return
  372. parsed_json = json.loads(parsed_thrift.payload)
  373. for evt in self._parse_realtime_sub_item(topic, parsed_json):
  374. self._loop.create_task(self._dispatch(evt))
  375. def _on_message_handler(self, client: MQTToTClient, _: Any, message: MQTTMessage) -> None:
  376. try:
  377. topic = RealtimeTopic.decode(message.topic)
  378. # Instagram Android MQTT messages are always compressed
  379. message.payload = zlib.decompress(message.payload)
  380. if topic == RealtimeTopic.MESSAGE_SYNC:
  381. self._on_message_sync(message.payload)
  382. elif topic == RealtimeTopic.PUBSUB:
  383. self._on_pubsub(message.payload)
  384. elif topic == RealtimeTopic.REALTIME_SUB:
  385. self._on_realtime_sub(message.payload)
  386. elif topic == RealtimeTopic.SEND_MESSAGE_RESPONSE:
  387. try:
  388. data = json.loads(message.payload.decode("utf-8"))
  389. ccid = data["payload"]["client_context"]
  390. waiter = self._message_response_waiters.pop(ccid)
  391. except KeyError as e:
  392. self.log.debug(
  393. "No handler (%s) for send message response: %s", e, message.payload
  394. )
  395. else:
  396. self.log.trace("Got response to %s: %s", ccid, message.payload)
  397. waiter.set_result(message)
  398. else:
  399. try:
  400. waiter = self._response_waiters.pop(topic)
  401. except KeyError:
  402. self.log.debug(
  403. "No handler for MQTT message in %s: %s", topic.value, message.payload
  404. )
  405. else:
  406. self.log.trace("Got response %s: %s", topic.value, message.payload)
  407. waiter.set_result(message)
  408. except Exception:
  409. self.log.exception("Error in incoming MQTT message handler")
  410. self.log.trace("Errored MQTT payload: %s", message.payload)
  411. # endregion
  412. async def _reconnect(self) -> None:
  413. try:
  414. self.log.trace("Trying to reconnect to MQTT")
  415. self._client.reconnect()
  416. except (SocketError, OSError, WebsocketConnectionError) as e:
  417. # TODO custom class
  418. raise MQTTNotLoggedIn("MQTT reconnection failed") from e
  419. def add_event_handler(
  420. self, evt_type: Type[T], handler: Callable[[T], Awaitable[None]]
  421. ) -> None:
  422. self._event_handlers[evt_type].append(handler)
  423. async def _dispatch(self, evt: T) -> None:
  424. for handler in self._event_handlers[type(evt)]:
  425. try:
  426. await handler(evt)
  427. except Exception:
  428. self.log.exception(f"Error in {type(evt).__name__} handler")
  429. def disconnect(self) -> None:
  430. self._client.disconnect()
    async def _dispatcher_loop(self) -> None:
        """Dispatch queued message-sync events one at a time, in queue order.

        Runs until the task is cancelled (``listen()`` cancels it when it exits).
        On cancellation the queue is swapped for a fresh one, every event still in
        the old queue is dispatched, and the ``CancelledError`` is re-raised.
        """
        # Identifier that distinguishes log lines of successive dispatcher loops.
        loop_id = f"{hex(id(self))}#{time.monotonic()}"
        self.log.debug(f"Dispatcher loop {loop_id} starting")
        try:
            while True:
                evt = await self._outgoing_events.get()
                # shield(): cancelling this task must not abort the dispatch in progress.
                await asyncio.shield(self._dispatch(evt))
        except asyncio.CancelledError:
            tasks = self._outgoing_events
            # New events go to a fresh queue; only the already-queued ones are drained here.
            self._outgoing_events = asyncio.Queue()
            if not tasks.empty():
                self.log.debug(
                    f"Dispatcher loop {loop_id} stopping after dispatching {tasks.qsize()} events"
                )
            # Drain what is left so queued events are not lost on shutdown.
            while not tasks.empty():
                await self._dispatch(tasks.get_nowait())
            raise
        finally:
            self.log.debug(f"Dispatcher loop {loop_id} stopped")
  450. async def listen(
  451. self,
  452. graphql_subs: set[str] | None = None,
  453. skywalker_subs: set[str] | None = None,
  454. seq_id: int = None,
  455. snapshot_at_ms: int = None,
  456. retry_limit: int = 5,
  457. ) -> None:
  458. self._graphql_subs = graphql_subs or set()
  459. self._skywalker_subs = skywalker_subs or set()
  460. self._iris_seq_id = seq_id
  461. self._iris_snapshot_at_ms = snapshot_at_ms
  462. self.log.debug("Connecting to Instagram MQTT")
  463. await self._reconnect()
  464. connection_retries = 0
  465. while True:
  466. try:
  467. await asyncio.sleep(1)
  468. except asyncio.CancelledError:
  469. self.disconnect()
  470. # this might not be necessary
  471. self._client.loop_misc()
  472. break
  473. rc = self._client.loop_misc()
  474. # If disconnect() has been called
  475. # Beware, internal API, may have to change this to something more stable!
  476. if self._client._state == paho.mqtt.client.mqtt_cs_disconnecting:
  477. break # Stop listening
  478. if rc != paho.mqtt.client.MQTT_ERR_SUCCESS:
  479. # If known/expected error
  480. if rc == paho.mqtt.client.MQTT_ERR_CONN_LOST:
  481. await self._dispatch(Disconnect(reason="Connection lost, retrying"))
  482. elif rc == paho.mqtt.client.MQTT_ERR_NOMEM:
  483. # This error is wrongly classified
  484. # See https://github.com/eclipse/paho.mqtt.python/issues/340
  485. await self._dispatch(Disconnect(reason="Connection lost, retrying"))
  486. elif rc == paho.mqtt.client.MQTT_ERR_CONN_REFUSED:
  487. raise MQTTNotLoggedIn("MQTT connection refused")
  488. elif rc == paho.mqtt.client.MQTT_ERR_NO_CONN:
  489. if connection_retries > retry_limit:
  490. raise MQTTNotConnected(f"Connection failed {connection_retries} times")
  491. sleep = connection_retries * 2
  492. await self._dispatch(
  493. Disconnect(
  494. reason="MQTT Error: no connection, retrying "
  495. f"in {connection_retries} seconds"
  496. )
  497. )
  498. await asyncio.sleep(sleep)
  499. else:
  500. err = paho.mqtt.client.error_string(rc)
  501. self.log.error("MQTT Error: %s", err)
  502. await self._dispatch(Disconnect(reason=f"MQTT Error: {err}, retrying"))
  503. await self._reconnect()
  504. connection_retries += 1
  505. else:
  506. connection_retries = 0
  507. if self._event_dispatcher_task:
  508. self._event_dispatcher_task.cancel()
  509. self._event_dispatcher_task = None
  510. if self._disconnect_error:
  511. self.log.info("disconnect_error is set, raising and clearing variable")
  512. err = self._disconnect_error
  513. self._disconnect_error = None
  514. raise err
  515. # region Basic outgoing MQTT
  516. def publish(self, topic: RealtimeTopic, payload: str | bytes | dict) -> asyncio.Future:
  517. if isinstance(payload, dict):
  518. payload = json.dumps(payload)
  519. if isinstance(payload, str):
  520. payload = payload.encode("utf-8")
  521. self.log.trace(f"Publishing message in {topic.value} ({topic.encoded}): {payload}")
  522. payload = zlib.compress(payload, level=9)
  523. info = self._client.publish(topic.encoded, payload, qos=1)
  524. self.log.trace(f"Published message ID: {info.mid}")
  525. fut = asyncio.Future()
  526. self._publish_waiters[info.mid] = fut
  527. return fut
  528. async def request(
  529. self,
  530. topic: RealtimeTopic,
  531. response: RealtimeTopic,
  532. payload: str | bytes | dict,
  533. timeout: int | None = None,
  534. ) -> MQTTMessage:
  535. async with self._response_waiter_locks[response]:
  536. fut = asyncio.Future()
  537. self._response_waiters[response] = fut
  538. await self.publish(topic, payload)
  539. self.log.trace(
  540. f"Request published to {topic.value}, waiting for response {response.name}"
  541. )
  542. return await asyncio.wait_for(fut, timeout)
  543. async def iris_subscribe(self, seq_id: int, snapshot_at_ms: int) -> None:
  544. self.log.debug(f"Requesting iris subscribe {seq_id}/{snapshot_at_ms}")
  545. resp = await self.request(
  546. RealtimeTopic.SUB_IRIS,
  547. RealtimeTopic.SUB_IRIS_RESPONSE,
  548. {"seq_id": seq_id, "snapshot_at_ms": snapshot_at_ms},
  549. timeout=20 * 1000,
  550. )
  551. self.log.debug("Iris subscribe response: %s", resp.payload.decode("utf-8"))
  552. resp_dict = json.loads(resp.payload.decode("utf-8"))
  553. if resp_dict["error_type"] and resp_dict["error_message"]:
  554. raise IrisSubscribeError(resp_dict["error_type"], resp_dict["error_message"])
  555. latest_seq_id = resp_dict.get("latest_seq_id")
  556. if latest_seq_id > self._iris_seq_id:
  557. self.log.info(f"Latest sequence ID is {latest_seq_id}, catching up from {seq_id}")
  558. self._iris_seq_id = latest_seq_id
  559. self._iris_snapshot_at_ms = resp_dict.get("subscribed_at_ms", int(time.time() * 1000))
  560. asyncio.create_task(
  561. self._dispatch(NewSequenceID(self._iris_seq_id, self._iris_snapshot_at_ms))
  562. )
  563. def graphql_subscribe(self, subs: set[str]) -> asyncio.Future:
  564. self._graphql_subs |= subs
  565. return self.publish(RealtimeTopic.REALTIME_SUB, {"sub": list(subs)})
  566. def graphql_unsubscribe(self, subs: set[str]) -> asyncio.Future:
  567. self._graphql_subs -= subs
  568. return self.publish(RealtimeTopic.REALTIME_SUB, {"unsub": list(subs)})
  569. def skywalker_subscribe(self, subs: set[str]) -> asyncio.Future:
  570. self._skywalker_subs |= subs
  571. return self.publish(RealtimeTopic.PUBSUB, {"sub": list(subs)})
  572. def skywalker_unsubscribe(self, subs: set[str]) -> asyncio.Future:
  573. self._skywalker_subs -= subs
  574. return self.publish(RealtimeTopic.PUBSUB, {"unsub": list(subs)})
  575. # endregion
  576. # region Actually sending messages and stuff
  577. async def send_foreground_state(self, state: ForegroundStateConfig) -> None:
  578. self.log.debug("Updating foreground state: %s", state)
  579. await self.publish(
  580. RealtimeTopic.FOREGROUND_STATE, zlib.compress(state.to_thrift(), level=9)
  581. )
  582. if state.keep_alive_timeout:
  583. self._client._keepalive = state.keep_alive_timeout
  584. async def send_command(
  585. self,
  586. thread_id: str,
  587. action: ThreadAction,
  588. client_context: str | None = None,
  589. **kwargs: Any,
  590. ) -> CommandResponse | None:
  591. client_context = client_context or self.state.gen_client_context()
  592. req = {
  593. "thread_id": thread_id,
  594. "client_context": client_context,
  595. "offline_threading_id": client_context,
  596. "action": action.value,
  597. # "device_id": self.state.cookies["ig_did"],
  598. **kwargs,
  599. }
  600. if action in (ThreadAction.MARK_SEEN,):
  601. # Some commands don't have client_context in the response, so we can't properly match
  602. # them to the requests. We probably don't need the data, so just ignore it.
  603. await self.publish(RealtimeTopic.SEND_MESSAGE, payload=req)
  604. return None
  605. else:
  606. fut = asyncio.Future()
  607. self._message_response_waiters[client_context] = fut
  608. await self.publish(RealtimeTopic.SEND_MESSAGE, req)
  609. self.log.trace(
  610. f"Request published to {RealtimeTopic.SEND_MESSAGE}, "
  611. f"waiting for response {RealtimeTopic.SEND_MESSAGE_RESPONSE}"
  612. )
  613. resp = await fut
  614. return CommandResponse.parse_json(resp.payload.decode("utf-8"))
  615. def send_item(
  616. self,
  617. thread_id: str,
  618. item_type: ThreadItemType,
  619. shh_mode: bool = False,
  620. client_context: str | None = None,
  621. **kwargs: Any,
  622. ) -> Awaitable[CommandResponse]:
  623. return self.send_command(
  624. thread_id,
  625. item_type=item_type.value,
  626. is_shh_mode=str(int(shh_mode)),
  627. action=ThreadAction.SEND_ITEM,
  628. client_context=client_context,
  629. **kwargs,
  630. )
  631. def send_hashtag(
  632. self,
  633. thread_id: str,
  634. hashtag: str,
  635. text: str = "",
  636. shh_mode: bool = False,
  637. client_context: str | None = None,
  638. ) -> Awaitable[CommandResponse]:
  639. return self.send_item(
  640. thread_id,
  641. text=text,
  642. item_id=hashtag,
  643. shh_mode=shh_mode,
  644. item_type=ThreadItemType.HASHTAG,
  645. client_context=client_context,
  646. )
  647. def send_like(
  648. self, thread_id: str, shh_mode: bool = False, client_context: str | None = None
  649. ) -> Awaitable[CommandResponse]:
  650. return self.send_item(
  651. thread_id,
  652. shh_mode=shh_mode,
  653. item_type=ThreadItemType.LIKE,
  654. client_context=client_context,
  655. )
  656. def send_location(
  657. self,
  658. thread_id: str,
  659. venue_id: str,
  660. text: str = "",
  661. shh_mode: bool = False,
  662. client_context: str | None = None,
  663. ) -> Awaitable[CommandResponse]:
  664. return self.send_item(
  665. thread_id,
  666. text=text,
  667. item_id=venue_id,
  668. shh_mode=shh_mode,
  669. item_type=ThreadItemType.LOCATION,
  670. client_context=client_context,
  671. )
  672. def send_media(
  673. self,
  674. thread_id: str,
  675. media_id: str,
  676. text: str = "",
  677. shh_mode: bool = False,
  678. client_context: str | None = None,
  679. ) -> Awaitable[CommandResponse]:
  680. return self.send_item(
  681. thread_id,
  682. text=text,
  683. media_id=media_id,
  684. shh_mode=shh_mode,
  685. item_type=ThreadItemType.MEDIA_SHARE,
  686. client_context=client_context,
  687. )
  688. def send_profile(
  689. self,
  690. thread_id: str,
  691. user_id: str,
  692. text: str = "",
  693. shh_mode: bool = False,
  694. client_context: str | None = None,
  695. ) -> Awaitable[CommandResponse]:
  696. return self.send_item(
  697. thread_id,
  698. text=text,
  699. item_id=user_id,
  700. shh_mode=shh_mode,
  701. item_type=ThreadItemType.PROFILE,
  702. client_context=client_context,
  703. )
  704. def send_reaction(
  705. self,
  706. thread_id: str,
  707. emoji: str,
  708. item_id: str,
  709. reaction_status: ReactionStatus = ReactionStatus.CREATED,
  710. target_item_type: ThreadItemType = ThreadItemType.TEXT,
  711. shh_mode: bool = False,
  712. client_context: str | None = None,
  713. ) -> Awaitable[CommandResponse]:
  714. return self.send_item(
  715. thread_id,
  716. reaction_status=reaction_status.value,
  717. node_type="item",
  718. reaction_type="like",
  719. target_item_type=target_item_type.value,
  720. emoji=emoji,
  721. item_id=item_id,
  722. reaction_action_source="double_tap",
  723. shh_mode=shh_mode,
  724. item_type=ThreadItemType.REACTION,
  725. client_context=client_context,
  726. )
  727. def send_user_story(
  728. self,
  729. thread_id: str,
  730. media_id: str,
  731. text: str = "",
  732. shh_mode: bool = False,
  733. client_context: str | None = None,
  734. ) -> Awaitable[CommandResponse]:
  735. return self.send_item(
  736. thread_id,
  737. text=text,
  738. item_id=media_id,
  739. shh_mode=shh_mode,
  740. item_type=ThreadItemType.REEL_SHARE,
  741. client_context=client_context,
  742. )
  743. def send_text(
  744. self,
  745. thread_id: str,
  746. text: str = "",
  747. urls: list[str] | None = None,
  748. shh_mode: bool = False,
  749. client_context: str | None = None,
  750. replied_to_item_id: str | None = None,
  751. replied_to_client_context: str | None = None,
  752. ) -> Awaitable[CommandResponse]:
  753. args = {
  754. "text": text,
  755. }
  756. item_type = ThreadItemType.TEXT
  757. if urls is not None:
  758. args = {
  759. "link_text": text,
  760. "link_urls": json.dumps(urls or []),
  761. }
  762. item_type = ThreadItemType.LINK
  763. return self.send_item(
  764. thread_id,
  765. **args,
  766. shh_mode=shh_mode,
  767. item_type=item_type,
  768. client_context=client_context,
  769. replied_to_item_id=replied_to_item_id,
  770. replied_to_client_context=replied_to_client_context,
  771. )
  772. def mark_seen(
  773. self, thread_id: str, item_id: str, client_context: str | None = None
  774. ) -> Awaitable[None]:
  775. return self.send_command(
  776. thread_id,
  777. item_id=item_id,
  778. action=ThreadAction.MARK_SEEN,
  779. client_context=client_context,
  780. )
  781. def mark_visual_item_seen(
  782. self, thread_id: str, item_id: str, client_context: str | None = None
  783. ) -> Awaitable[CommandResponse]:
  784. return self.send_command(
  785. thread_id,
  786. item_id=item_id,
  787. action=ThreadAction.MARK_VISUAL_ITEM_SEEN,
  788. client_context=client_context,
  789. )
  790. def indicate_activity(
  791. self,
  792. thread_id: str,
  793. activity_status: TypingStatus = TypingStatus.TEXT,
  794. client_context: str | None = None,
  795. ) -> Awaitable[CommandResponse]:
  796. return self.send_command(
  797. thread_id,
  798. activity_status=activity_status.value,
  799. action=ThreadAction.INDICATE_ACTIVITY,
  800. client_context=client_context,
  801. )
  802. # endregion