conn.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2022 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import Any, Awaitable, Callable, Iterable, Type, TypeVar
  18. from collections import defaultdict
  19. from socket import error as SocketError, socket
  20. import asyncio
  21. import json
  22. import logging
  23. import re
  24. import time
  25. import zlib
  26. from yarl import URL
  27. import paho.mqtt.client as pmc
  28. from mautrix.util.logging import TraceLogger
  29. from ..errors import (
  30. IrisSubscribeError,
  31. MQTTConnectionUnauthorized,
  32. MQTTNotConnected,
  33. MQTTNotLoggedIn,
  34. )
  35. from ..proxy import ProxyHandler
  36. from ..state import AndroidState
  37. from ..types import (
  38. AppPresenceEventPayload,
  39. ClientConfigUpdatePayload,
  40. CommandResponse,
  41. IrisPayload,
  42. IrisPayloadData,
  43. LiveVideoCommentPayload,
  44. MessageSyncEvent,
  45. MessageSyncMessage,
  46. Operation,
  47. PubsubEvent,
  48. PubsubPayload,
  49. ReactionStatus,
  50. ReactionType,
  51. RealtimeDirectEvent,
  52. RealtimeZeroProvisionPayload,
  53. ThreadAction,
  54. ThreadItemType,
  55. ThreadRemoveEvent,
  56. ThreadSyncEvent,
  57. TypingStatus,
  58. )
  59. from .events import Connect, Disconnect, NewSequenceID, ProxyUpdate
  60. from .otclient import MQTToTClient
  61. from .subscription import GraphQLQueryID, RealtimeTopic, everclear_subscriptions
  62. from .thrift import ForegroundStateConfig, IncomingMessage, RealtimeClientInfo, RealtimeConfig
# PySocks is optional; ``socks`` is None when the package is not installed.
try:
    import socks
except ImportError:
    socks = None

# Generic event type used to annotate the event-handler registry.
T = TypeVar("T")

# Activity-indicator paths: group 1 is the thread ID, group 2 the indicator ID.
ACTIVITY_INDICATOR_REGEX = re.compile(
    r"/direct_v2/threads/([\w_]+)/activity_indicator_id/([\w_]+)"
)
# Inbox thread paths: group 1 is the thread ID.
INBOX_THREAD_REGEX = re.compile(r"/direct_v2/inbox/threads/([\w_]+)")
class AndroidMQTT:
    """Asyncio wrapper around an ``MQTToTClient`` connection to Instagram MQTT."""

    # Event loop used for socket reader/writer registration and task creation.
    _loop: asyncio.AbstractEventLoop
    # Underlying MQTT client whose callbacks are wired to this object.
    _client: MQTToTClient
    log: TraceLogger
    state: AndroidState
    # Subscriptions that are re-sent after every successful connect.
    _graphql_subs: set[str]
    _skywalker_subs: set[str]
    # Latest known iris sequence ID and the timestamp (ms) recorded with it.
    _iris_seq_id: int | None
    _iris_snapshot_at_ms: int | None
    # Futures keyed by MQTT message ID, resolved on publish confirmation.
    _publish_waiters: dict[int, asyncio.Future]
    # Futures keyed by response topic, resolved when a message arrives on it.
    _response_waiters: dict[RealtimeTopic, asyncio.Future]
    # One lock per response topic so only one request waits on a topic at a time.
    _response_waiter_locks: dict[RealtimeTopic, asyncio.Lock]
    # State for the single in-flight send_command call.
    _message_response_waiter_lock: asyncio.Lock
    _message_response_waiter_id: str | None
    _message_response_waiter: asyncio.Future | None
    # Error to raise from listen() once the listen loop has exited.
    _disconnect_error: Exception | None
    # Registered async handlers, keyed by event class.
    _event_handlers: dict[Type[T], list[Callable[[T], Awaitable[None]]]]
    # Sync events waiting to be dispatched by the dispatcher task.
    _outgoing_events: asyncio.Queue
    _event_dispatcher_task: asyncio.Task | None
  91. # region Initialization
  92. def __init__(
  93. self,
  94. state: AndroidState,
  95. loop: asyncio.AbstractEventLoop | None = None,
  96. log: TraceLogger | None = None,
  97. proxy_handler: ProxyHandler | None = None,
  98. ) -> None:
  99. self._graphql_subs = set()
  100. self._skywalker_subs = set()
  101. self._iris_seq_id = None
  102. self._iris_snapshot_at_ms = None
  103. self._publish_waiters = {}
  104. self._response_waiters = {}
  105. self._message_response_waiter_lock = asyncio.Lock()
  106. self._message_response_waiter_id = None
  107. self._message_response_waiter = None
  108. self._disconnect_error = None
  109. self._response_waiter_locks = defaultdict(lambda: asyncio.Lock())
  110. self._event_handlers = defaultdict(lambda: [])
  111. self._event_dispatcher_task = None
  112. self._outgoing_events = asyncio.Queue()
  113. self.log = log or logging.getLogger("mauigpapi.mqtt")
  114. self._loop = loop or asyncio.get_event_loop()
  115. self.state = state
  116. self._client = MQTToTClient(
  117. client_id=self._form_client_id(),
  118. clean_session=True,
  119. protocol=pmc.MQTTv31,
  120. transport="tcp",
  121. )
  122. self.proxy_handler = proxy_handler
  123. self.setup_proxy()
  124. self._client.enable_logger()
  125. self._client.tls_set()
  126. # mqtt.max_inflight_messages_set(20) # The rest will get queued
  127. # mqtt.max_queued_messages_set(0) # Unlimited messages can be queued
  128. # mqtt.message_retry_set(20) # Retry sending for at least 20 seconds
  129. # mqtt.reconnect_delay_set(min_delay=1, max_delay=120)
  130. self._client.connect_async("edge-mqtt.facebook.com", 443, keepalive=60)
  131. self._client.on_message = self._on_message_handler
  132. self._client.on_publish = self._on_publish_handler
  133. self._client.on_connect = self._on_connect_handler
  134. self._client.on_disconnect = self._on_disconnect_handler
  135. self._client.on_socket_open = self._on_socket_open
  136. self._client.on_socket_close = self._on_socket_close
  137. self._client.on_socket_register_write = self._on_socket_register_write
  138. self._client.on_socket_unregister_write = self._on_socket_unregister_write
  139. def setup_proxy(self):
  140. http_proxy = self.proxy_handler.get_proxy_url()
  141. if http_proxy:
  142. if not socks:
  143. self.log.warning("http_proxy is set, but pysocks is not installed")
  144. else:
  145. proxy_url = URL(http_proxy)
  146. proxy_type = {
  147. "http": socks.HTTP,
  148. "https": socks.HTTP,
  149. "socks": socks.SOCKS5,
  150. "socks5": socks.SOCKS5,
  151. "socks4": socks.SOCKS4,
  152. }[proxy_url.scheme]
  153. self._client.proxy_set(
  154. proxy_type=proxy_type,
  155. proxy_addr=proxy_url.host,
  156. proxy_port=proxy_url.port,
  157. proxy_username=proxy_url.user,
  158. proxy_password=proxy_url.password,
  159. )
  160. def _clear_response_waiters(self) -> None:
  161. for waiter in self._response_waiters.values():
  162. if not waiter.done():
  163. waiter.set_exception(
  164. MQTTNotConnected("MQTT disconnected before request returned response")
  165. )
  166. for waiter in self._publish_waiters.values():
  167. if not waiter.done():
  168. waiter.set_exception(
  169. MQTTNotConnected("MQTT disconnected before request was published")
  170. )
  171. if self._message_response_waiter and not self._message_response_waiter.done():
  172. self._message_response_waiter.set_exception(
  173. MQTTNotConnected("MQTT disconnected before message send returned response")
  174. )
  175. self._message_response_waiter = None
  176. self._message_response_waiter_id = None
  177. self._response_waiters = {}
  178. self._publish_waiters = {}
  179. def _form_client_id(self) -> bytes:
  180. subscribe_topics = [
  181. RealtimeTopic.PUBSUB, # 88
  182. RealtimeTopic.SUB_IRIS_RESPONSE, # 135
  183. RealtimeTopic.RS_REQ, # 244
  184. RealtimeTopic.REALTIME_SUB, # 149
  185. RealtimeTopic.REGION_HINT, # 150
  186. RealtimeTopic.RS_RESP, # 245
  187. RealtimeTopic.T_RTC_LOG, # 274
  188. RealtimeTopic.SEND_MESSAGE_RESPONSE, # 133
  189. RealtimeTopic.MESSAGE_SYNC, # 146
  190. RealtimeTopic.LIGHTSPEED_RESPONSE, # 179
  191. RealtimeTopic.UNKNOWN_PP, # 34
  192. ]
  193. subscribe_topic_ids = [int(topic.encoded) for topic in subscribe_topics]
  194. password = f"authorization={self.state.session.authorization}"
  195. # if not self.state.session.authorization:
  196. # password = f"sessionid={self.state.cookies['sessionid']}"
  197. cfg = RealtimeConfig(
  198. client_identifier=self.state.device.phone_id[:20],
  199. client_info=RealtimeClientInfo(
  200. user_id=int(self.state.user_id),
  201. user_agent=self.state.user_agent,
  202. client_capabilities=0b10110111,
  203. endpoint_capabilities=0,
  204. publish_format=1,
  205. no_automatic_foreground=True,
  206. make_user_available_in_foreground=False,
  207. device_id=self.state.device.phone_id,
  208. is_initially_foreground=False,
  209. network_type=1,
  210. network_subtype=0,
  211. client_mqtt_session_id=int(time.time() * 1000) & 0xFFFFFFFF,
  212. subscribe_topics=subscribe_topic_ids,
  213. client_type="cookie_auth",
  214. app_id=567067343352427,
  215. region_preference=self.state.session.region_hint or "LLA",
  216. device_secret="",
  217. client_stack=3,
  218. ),
  219. password=password,
  220. app_specific_info={
  221. "app_version": self.state.application.APP_VERSION,
  222. "X-IG-Capabilities": self.state.application.CAPABILITIES,
  223. "everclear_subscriptions": json.dumps(everclear_subscriptions),
  224. "User-Agent": self.state.user_agent,
  225. "Accept-Language": self.state.device.language.replace("_", "-"),
  226. "platform": "android",
  227. "ig_mqtt_route": "django",
  228. "pubsub_msg_type_blacklist": "direct, typing_type",
  229. "auth_cache_enabled": "1",
  230. },
  231. )
  232. return zlib.compress(cfg.to_thrift(), level=9)
  233. # endregion
    def _on_socket_open(self, client: MQTToTClient, _: Any, sock: socket) -> None:
        """Register ``client.loop_read`` as the event loop's reader for ``sock``."""
        self._loop.add_reader(sock, client.loop_read)
    def _on_socket_close(self, client: MQTToTClient, _: Any, sock: socket) -> None:
        """Remove the event loop's reader registration for ``sock``."""
        self._loop.remove_reader(sock)
    def _on_socket_register_write(self, client: MQTToTClient, _: Any, sock: socket) -> None:
        """Register ``client.loop_write`` as the event loop's writer for ``sock``."""
        self._loop.add_writer(sock, client.loop_write)
    def _on_socket_unregister_write(self, client: MQTToTClient, _: Any, sock: socket) -> None:
        """Remove the event loop's writer registration for ``sock``."""
        self._loop.remove_writer(sock)
  242. def _on_connect_handler(
  243. self, client: MQTToTClient, _: Any, flags: dict[str, Any], rc: int
  244. ) -> None:
  245. if rc != 0:
  246. err = pmc.connack_string(rc)
  247. self.log.error("MQTT Connection Error: %s (%d)", err, rc)
  248. if rc == pmc.CONNACK_REFUSED_NOT_AUTHORIZED:
  249. self._disconnect_error = MQTTConnectionUnauthorized()
  250. self.disconnect()
  251. return
  252. self._loop.create_task(self._post_connect())
  253. def _on_disconnect_handler(self, client: MQTToTClient, _: Any, rc: int) -> None:
  254. err_str = "Generic error." if rc == pmc.MQTT_ERR_NOMEM else pmc.error_string(rc)
  255. self.log.debug(f"MQTT disconnection code %d: %s", rc, err_str)
  256. self._clear_response_waiters()
  257. async def _post_connect(self) -> None:
  258. await self._dispatch(Connect())
  259. self.log.debug("Re-subscribing to things after connect")
  260. if self._graphql_subs:
  261. res = await self.graphql_subscribe(self._graphql_subs)
  262. self.log.trace("GraphQL subscribe response: %s", res)
  263. if self._skywalker_subs:
  264. res = await self.skywalker_subscribe(self._skywalker_subs)
  265. self.log.trace("Skywalker subscribe response: %s", res)
  266. if self._iris_seq_id:
  267. retry = 0
  268. while True:
  269. try:
  270. await self.iris_subscribe(self._iris_seq_id, self._iris_snapshot_at_ms)
  271. break
  272. except (asyncio.TimeoutError, IrisSubscribeError) as e:
  273. self.log.exception("Error requesting iris subscribe")
  274. retry += 1
  275. if retry >= 5 or isinstance(e, IrisSubscribeError):
  276. self._disconnect_error = e
  277. self.disconnect()
  278. break
  279. await asyncio.sleep(5)
  280. self.log.debug("Retrying iris subscribe")
  281. def _on_publish_handler(self, client: MQTToTClient, _: Any, mid: int) -> None:
  282. try:
  283. waiter = self._publish_waiters[mid]
  284. except KeyError:
  285. self.log.trace(f"Got publish confirmation for {mid}, but no waiters")
  286. return
  287. self.log.trace(f"Got publish confirmation for {mid}")
  288. waiter.set_result(None)
  289. # region Incoming event parsing
  290. def _parse_direct_thread_path(self, path: str) -> dict:
  291. try:
  292. blank, direct_v2, threads, thread_id, *rest = path.split("/")
  293. except (ValueError, IndexError) as e:
  294. self.log.debug(f"Got {e!r} while parsing path {path}")
  295. raise
  296. if (blank, direct_v2, threads) != ("", "direct_v2", "threads"):
  297. self.log.debug(f"Got unexpected first parts in direct thread path {path}")
  298. raise ValueError("unexpected first three parts in _parse_direct_thread_path")
  299. additional = {"thread_id": thread_id}
  300. if rest:
  301. subitem_key = rest[0]
  302. if subitem_key == "approval_required_for_new_members":
  303. additional["approval_required_for_new_members"] = True
  304. elif subitem_key == "participants" and len(rest) > 2 and rest[2] == "has_seen":
  305. additional["has_seen"] = int(rest[1])
  306. elif subitem_key == "items":
  307. additional["item_id"] = rest[1]
  308. if len(rest) > 4 and rest[2] == "reactions":
  309. additional["reaction_type"] = ReactionType(rest[3])
  310. additional["reaction_user_id"] = int(rest[4])
  311. elif subitem_key in "admin_user_ids":
  312. additional["admin_user_id"] = int(rest[1])
  313. elif subitem_key == "activity_indicator_id":
  314. additional["activity_indicator_id"] = rest[1]
  315. self.log.trace("Parsed path %s -> %s", path, additional)
  316. return additional
  317. def _on_messager_sync_item(self, part: IrisPayloadData, parsed_item: IrisPayload) -> bool:
  318. if part.path.startswith("/direct_v2/threads/"):
  319. raw_message = {
  320. "path": part.path,
  321. "op": part.op,
  322. **self._parse_direct_thread_path(part.path),
  323. }
  324. try:
  325. json_value = json.loads(part.value)
  326. if "reaction_type" in raw_message:
  327. self.log.trace("Treating %s as new reaction data", json_value)
  328. raw_message["new_reaction"] = json_value
  329. json_value["sender_id"] = raw_message.pop("reaction_user_id")
  330. json_value["type"] = raw_message.pop("reaction_type")
  331. json_value["client_context"] = parsed_item.mutation_token
  332. if part.op == Operation.REMOVE:
  333. json_value["emoji"] = None
  334. json_value["timestamp"] = None
  335. else:
  336. raw_message = {
  337. **raw_message,
  338. **json_value,
  339. }
  340. except (json.JSONDecodeError, TypeError):
  341. raw_message["value"] = part.value
  342. message = MessageSyncMessage.deserialize(raw_message)
  343. evt = MessageSyncEvent(iris=parsed_item, message=message)
  344. elif part.path.startswith("/direct_v2/inbox/threads/"):
  345. if part.op == Operation.REMOVE:
  346. blank, direct_v2, inbox, threads, thread_id, *_ = part.path.split("/")
  347. evt = ThreadRemoveEvent.deserialize(
  348. {
  349. "thread_id": thread_id,
  350. "path": part.path,
  351. "op": part.op,
  352. **json.loads(part.value),
  353. }
  354. )
  355. else:
  356. evt = ThreadSyncEvent.deserialize(
  357. {
  358. "path": part.path,
  359. "op": part.op,
  360. **json.loads(part.value),
  361. }
  362. )
  363. else:
  364. self.log.warning(f"Unsupported path {part.path}")
  365. return False
  366. self._outgoing_events.put_nowait(evt)
  367. return True
  368. def _on_message_sync(self, payload: bytes) -> None:
  369. parsed = json.loads(payload.decode("utf-8"))
  370. self.log.trace("Got message sync event: %s", parsed)
  371. has_items = False
  372. for sync_item in parsed:
  373. parsed_item = IrisPayload.deserialize(sync_item)
  374. if self._iris_seq_id < parsed_item.seq_id:
  375. self.log.trace(f"Got new seq_id: {parsed_item.seq_id}")
  376. self._iris_seq_id = parsed_item.seq_id
  377. self._iris_snapshot_at_ms = int(time.time() * 1000)
  378. asyncio.create_task(
  379. self._dispatch(NewSequenceID(self._iris_seq_id, self._iris_snapshot_at_ms))
  380. )
  381. for part in parsed_item.data:
  382. has_items = self._on_messager_sync_item(part, parsed_item) or has_items
  383. if has_items and not self._event_dispatcher_task:
  384. self._event_dispatcher_task = asyncio.create_task(self._dispatcher_loop())
  385. def _on_pubsub(self, payload: bytes) -> None:
  386. parsed_thrift = IncomingMessage.from_thrift(payload)
  387. self.log.trace(f"Got pubsub event {parsed_thrift.topic} / {parsed_thrift.payload}")
  388. message = PubsubPayload.parse_json(parsed_thrift.payload)
  389. for data in message.data:
  390. match = ACTIVITY_INDICATOR_REGEX.match(data.path)
  391. if match:
  392. evt = PubsubEvent(
  393. data=data,
  394. base=message,
  395. thread_id=match.group(1),
  396. activity_indicator_id=match.group(2),
  397. )
  398. self._loop.create_task(self._dispatch(evt))
  399. elif not data.double_publish:
  400. self.log.debug("Pubsub: no activity indicator on data: %s", data)
  401. else:
  402. self.log.debug("Pubsub: double publish: %s", data.path)
  403. def _parse_realtime_sub_item(self, topic: str | GraphQLQueryID, raw: dict) -> Iterable[Any]:
  404. if topic == GraphQLQueryID.APP_PRESENCE:
  405. yield AppPresenceEventPayload.deserialize(raw).presence_event
  406. elif topic == GraphQLQueryID.ZERO_PROVISION:
  407. yield RealtimeZeroProvisionPayload.deserialize(raw).zero_product_provisioning_event
  408. elif topic == GraphQLQueryID.CLIENT_CONFIG_UPDATE:
  409. yield ClientConfigUpdatePayload.deserialize(raw).client_config_update_event
  410. elif topic == GraphQLQueryID.LIVE_REALTIME_COMMENTS:
  411. yield LiveVideoCommentPayload.deserialize(raw).live_video_comment_event
  412. elif topic == "direct":
  413. event = raw["event"]
  414. for item in raw["data"]:
  415. yield RealtimeDirectEvent.deserialize(
  416. {
  417. "event": event,
  418. **self._parse_direct_thread_path(item["path"]),
  419. **item,
  420. }
  421. )
  422. def _on_realtime_sub(self, payload: bytes) -> None:
  423. parsed_thrift = IncomingMessage.from_thrift(payload)
  424. try:
  425. topic = GraphQLQueryID(parsed_thrift.topic)
  426. except ValueError:
  427. topic = parsed_thrift.topic
  428. self.log.trace(f"Got realtime sub event {topic} / {parsed_thrift.payload}")
  429. allowed = (
  430. "direct",
  431. GraphQLQueryID.APP_PRESENCE,
  432. GraphQLQueryID.ZERO_PROVISION,
  433. GraphQLQueryID.CLIENT_CONFIG_UPDATE,
  434. GraphQLQueryID.LIVE_REALTIME_COMMENTS,
  435. )
  436. if topic not in allowed:
  437. return
  438. parsed_json = json.loads(parsed_thrift.payload)
  439. for evt in self._parse_realtime_sub_item(topic, parsed_json):
  440. self._loop.create_task(self._dispatch(evt))
  441. def _handle_send_response(self, message: pmc.MQTTMessage) -> None:
  442. data = json.loads(message.payload.decode("utf-8"))
  443. try:
  444. ccid = data["payload"]["client_context"]
  445. except KeyError:
  446. self.log.warning(
  447. "Didn't find client_context in send message response: %s", message.payload
  448. )
  449. ccid = self._message_response_waiter_id
  450. else:
  451. if ccid != self._message_response_waiter_id:
  452. self.log.error(
  453. "Mismatching client_context in send message response (%s != %s)",
  454. ccid,
  455. self._message_response_waiter_id,
  456. )
  457. return
  458. if self._message_response_waiter and not self._message_response_waiter.done():
  459. self.log.debug("Got response to %s: %s", ccid, message.payload)
  460. self._message_response_waiter.set_result(message)
  461. self._message_response_waiter = None
  462. self._message_response_waiter_id = None
  463. else:
  464. self.log.warning("Didn't find task waiting for response %s", message.payload)
  465. def _on_message_handler(self, client: MQTToTClient, _: Any, message: pmc.MQTTMessage) -> None:
  466. try:
  467. topic = RealtimeTopic.decode(message.topic)
  468. # Instagram Android MQTT messages are always compressed
  469. message.payload = zlib.decompress(message.payload)
  470. if topic == RealtimeTopic.MESSAGE_SYNC:
  471. self._on_message_sync(message.payload)
  472. elif topic == RealtimeTopic.PUBSUB:
  473. self._on_pubsub(message.payload)
  474. elif topic == RealtimeTopic.REALTIME_SUB:
  475. self._on_realtime_sub(message.payload)
  476. elif topic == RealtimeTopic.SEND_MESSAGE_RESPONSE:
  477. self._handle_send_response(message)
  478. else:
  479. try:
  480. waiter = self._response_waiters.pop(topic)
  481. except KeyError:
  482. self.log.debug(
  483. "No handler for MQTT message in %s: %s", topic.value, message.payload
  484. )
  485. else:
  486. self.log.trace("Got response %s: %s", topic.value, message.payload)
  487. waiter.set_result(message)
  488. except Exception:
  489. self.log.exception("Error in incoming MQTT message handler")
  490. self.log.trace("Errored MQTT payload: %s", message.payload)
  491. # endregion
  492. async def _reconnect(self) -> None:
  493. try:
  494. self.log.trace("Trying to reconnect to MQTT")
  495. self._client.reconnect()
  496. except (SocketError, OSError, pmc.WebsocketConnectionError) as e:
  497. # TODO custom class
  498. raise MQTTNotLoggedIn("MQTT reconnection failed") from e
    def add_event_handler(
        self, evt_type: Type[T], handler: Callable[[T], Awaitable[None]]
    ) -> None:
        """Register ``handler`` to be awaited for events whose type is ``evt_type``.

        Handlers for the same type are called in registration order.
        """
        self._event_handlers[evt_type].append(handler)
  503. async def _dispatch(self, evt: T) -> None:
  504. for handler in self._event_handlers[type(evt)]:
  505. try:
  506. await handler(evt)
  507. except Exception:
  508. self.log.exception(f"Error in {type(evt).__name__} handler")
    def disconnect(self) -> None:
        """Ask the underlying MQTT client to disconnect."""
        self._client.disconnect()
  511. async def _dispatcher_loop(self) -> None:
  512. loop_id = f"{hex(id(self))}#{time.monotonic()}"
  513. self.log.debug(f"Dispatcher loop {loop_id} starting")
  514. try:
  515. while True:
  516. evt = await self._outgoing_events.get()
  517. await asyncio.shield(self._dispatch(evt))
  518. except asyncio.CancelledError:
  519. tasks = self._outgoing_events
  520. self._outgoing_events = asyncio.Queue()
  521. if not tasks.empty():
  522. self.log.debug(
  523. f"Dispatcher loop {loop_id} stopping after dispatching {tasks.qsize()} events"
  524. )
  525. while not tasks.empty():
  526. await self._dispatch(tasks.get_nowait())
  527. raise
  528. finally:
  529. self.log.debug(f"Dispatcher loop {loop_id} stopped")
  530. async def listen(
  531. self,
  532. graphql_subs: set[str] | None = None,
  533. skywalker_subs: set[str] | None = None,
  534. seq_id: int = None,
  535. snapshot_at_ms: int = None,
  536. retry_limit: int = 5,
  537. ) -> None:
  538. self._graphql_subs = graphql_subs or set()
  539. self._skywalker_subs = skywalker_subs or set()
  540. self._iris_seq_id = seq_id
  541. self._iris_snapshot_at_ms = snapshot_at_ms
  542. self.log.debug("Connecting to Instagram MQTT")
  543. await self._reconnect()
  544. connection_retries = 0
  545. while True:
  546. try:
  547. await asyncio.sleep(1)
  548. except asyncio.CancelledError:
  549. self.disconnect()
  550. # this might not be necessary
  551. self._client.loop_misc()
  552. break
  553. rc = self._client.loop_misc()
  554. # If disconnect() has been called
  555. # Beware, internal API, may have to change this to something more stable!
  556. if self._client._state == pmc.mqtt_cs_disconnecting:
  557. break # Stop listening
  558. if rc != pmc.MQTT_ERR_SUCCESS:
  559. # If known/expected error
  560. if rc == pmc.MQTT_ERR_CONN_LOST:
  561. await self._dispatch(Disconnect(reason="Connection lost, retrying"))
  562. elif rc == pmc.MQTT_ERR_NOMEM:
  563. # This error is wrongly classified
  564. # See https://github.com/eclipse/paho.mqtt.python/issues/340
  565. await self._dispatch(Disconnect(reason="Connection lost, retrying"))
  566. elif rc == pmc.MQTT_ERR_CONN_REFUSED:
  567. raise MQTTNotLoggedIn("MQTT connection refused")
  568. elif rc == pmc.MQTT_ERR_NO_CONN:
  569. if connection_retries > retry_limit:
  570. raise MQTTNotConnected(f"Connection failed {connection_retries} times")
  571. if self.proxy_handler.update_proxy_url():
  572. self.setup_proxy()
  573. await self._dispatch(ProxyUpdate())
  574. sleep = connection_retries * 2
  575. await self._dispatch(
  576. Disconnect(
  577. reason="MQTT Error: no connection, retrying "
  578. f"in {connection_retries} seconds"
  579. )
  580. )
  581. await asyncio.sleep(sleep)
  582. else:
  583. err = pmc.error_string(rc)
  584. self.log.error("MQTT Error: %s", err)
  585. await self._dispatch(Disconnect(reason=f"MQTT Error: {err}, retrying"))
  586. await self._reconnect()
  587. connection_retries += 1
  588. else:
  589. connection_retries = 0
  590. if self._event_dispatcher_task:
  591. self._event_dispatcher_task.cancel()
  592. self._event_dispatcher_task = None
  593. if self._disconnect_error:
  594. self.log.info("disconnect_error is set, raising and clearing variable")
  595. err = self._disconnect_error
  596. self._disconnect_error = None
  597. raise err
  598. # region Basic outgoing MQTT
  599. def publish(self, topic: RealtimeTopic, payload: str | bytes | dict) -> asyncio.Future:
  600. if isinstance(payload, dict):
  601. payload = json.dumps(payload)
  602. if isinstance(payload, str):
  603. payload = payload.encode("utf-8")
  604. self.log.trace(f"Publishing message in {topic.value} ({topic.encoded}): {payload}")
  605. payload = zlib.compress(payload, level=9)
  606. info = self._client.publish(topic.encoded, payload, qos=1)
  607. self.log.trace(f"Published message ID: {info.mid}")
  608. fut = asyncio.Future()
  609. self._publish_waiters[info.mid] = fut
  610. return fut
  611. async def request(
  612. self,
  613. topic: RealtimeTopic,
  614. response: RealtimeTopic,
  615. payload: str | bytes | dict,
  616. timeout: int | None = None,
  617. ) -> pmc.MQTTMessage:
  618. async with self._response_waiter_locks[response]:
  619. fut = asyncio.Future()
  620. self._response_waiters[response] = fut
  621. await self.publish(topic, payload)
  622. self.log.trace(
  623. f"Request published to {topic.value}, waiting for response {response.name}"
  624. )
  625. return await asyncio.wait_for(fut, timeout)
  626. async def iris_subscribe(self, seq_id: int, snapshot_at_ms: int) -> None:
  627. self.log.debug(f"Requesting iris subscribe {seq_id}/{snapshot_at_ms}")
  628. resp = await self.request(
  629. RealtimeTopic.SUB_IRIS,
  630. RealtimeTopic.SUB_IRIS_RESPONSE,
  631. {
  632. "seq_id": seq_id,
  633. "snapshot_at_ms": snapshot_at_ms,
  634. "snapshot_app_version": self.state.application.APP_VERSION,
  635. "timezone_offset": int(self.state.device.timezone_offset),
  636. "subscription_type": "message",
  637. },
  638. timeout=20,
  639. )
  640. self.log.debug("Iris subscribe response: %s", resp.payload.decode("utf-8"))
  641. resp_dict = json.loads(resp.payload.decode("utf-8"))
  642. if resp_dict["error_type"] and resp_dict["error_message"]:
  643. raise IrisSubscribeError(resp_dict["error_type"], resp_dict["error_message"])
  644. latest_seq_id = resp_dict.get("latest_seq_id")
  645. if latest_seq_id > self._iris_seq_id:
  646. self.log.info(f"Latest sequence ID is {latest_seq_id}, catching up from {seq_id}")
  647. self._iris_seq_id = latest_seq_id
  648. self._iris_snapshot_at_ms = resp_dict.get("subscribed_at_ms", int(time.time() * 1000))
  649. asyncio.create_task(
  650. self._dispatch(NewSequenceID(self._iris_seq_id, self._iris_snapshot_at_ms))
  651. )
    def graphql_subscribe(self, subs: set[str]) -> asyncio.Future:
        """Remember ``subs`` and publish a ``sub`` request on ``REALTIME_SUB``."""
        self._graphql_subs |= subs
        return self.publish(RealtimeTopic.REALTIME_SUB, {"sub": list(subs)})
    def graphql_unsubscribe(self, subs: set[str]) -> asyncio.Future:
        """Forget ``subs`` and publish an ``unsub`` request on ``REALTIME_SUB``."""
        self._graphql_subs -= subs
        return self.publish(RealtimeTopic.REALTIME_SUB, {"unsub": list(subs)})
    def skywalker_subscribe(self, subs: set[str]) -> asyncio.Future:
        """Remember ``subs`` and publish a ``sub`` request on ``PUBSUB``."""
        self._skywalker_subs |= subs
        return self.publish(RealtimeTopic.PUBSUB, {"sub": list(subs)})
    def skywalker_unsubscribe(self, subs: set[str]) -> asyncio.Future:
        """Forget ``subs`` and publish an ``unsub`` request on ``PUBSUB``."""
        self._skywalker_subs -= subs
        return self.publish(RealtimeTopic.PUBSUB, {"unsub": list(subs)})
  664. # endregion
  665. # region Actually sending messages and stuff
  666. async def send_foreground_state(self, state: ForegroundStateConfig) -> None:
  667. self.log.debug("Updating foreground state: %s", state)
  668. await self.publish(
  669. RealtimeTopic.FOREGROUND_STATE, zlib.compress(state.to_thrift(), level=9)
  670. )
  671. if state.keep_alive_timeout:
  672. self._client._keepalive = state.keep_alive_timeout
  async def send_command(
      self,
      thread_id: str,
      action: ThreadAction,
      client_context: str | None = None,
      **kwargs: Any,
  ) -> CommandResponse | None:
      """Publish a thread command and wait for the matching response.

      Only one command is in flight at a time: the whole publish-and-wait
      sequence runs while holding ``self._message_response_waiter_lock``.
      The response future is stored in ``self._message_response_waiter``
      (with its ID in ``self._message_response_waiter_id``); it is presumably
      resolved by the ``SEND_MESSAGE_RESPONSE`` handler elsewhere in this
      class.

      Args:
          thread_id: Value of the ``thread_id`` field of the request.
          action: The action to perform; its ``.value`` is sent.
          client_context: Identifier for this request. When falsy, a new one
              is generated with ``self.state.gen_client_context()``. It is
              sent as both ``client_context`` and ``offline_threading_id``.
          **kwargs: Extra fields merged into the request payload as-is.

      Returns:
          The response parsed with ``CommandResponse.parse_json``.

      Raises:
          asyncio.TimeoutError: If no response arrives within 30 seconds.
      """
      # asyncio.wait_for() takes its timeout in seconds. The previous value
      # of 30000 (most likely meant as milliseconds) held the send lock for
      # more than eight hours whenever a response got lost, blocking every
      # subsequent command.
      response_timeout = 30

      client_context = client_context or self.state.gen_client_context()
      req = {
          "thread_id": thread_id,
          "client_context": client_context,
          "offline_threading_id": client_context,
          "action": action.value,
          # "device_id": self.state.cookies["ig_did"],
          **kwargs,
      }
      lock_start = time.monotonic()
      async with self._message_response_waiter_lock:
          lock_wait_dur = time.monotonic() - lock_start
          if lock_wait_dur > 1:
              self.log.debug(f"Waited {lock_wait_dur:.3f} seconds to send {client_context}")
          # Register the waiter before publishing so a fast response cannot
          # arrive before there is a future to resolve.
          fut = self._message_response_waiter = asyncio.Future()
          self._message_response_waiter_id = client_context
          await self.publish(RealtimeTopic.SEND_MESSAGE, req)
          self.log.trace(
              f"Request published to {RealtimeTopic.SEND_MESSAGE}, "
              f"waiting for response {RealtimeTopic.SEND_MESSAGE_RESPONSE}"
          )
          try:
              resp = await asyncio.wait_for(fut, timeout=response_timeout)
          except asyncio.TimeoutError:
              self.log.error(f"Request with ID {client_context} timed out!")
              raise
          return CommandResponse.parse_json(resp.payload.decode("utf-8"))
  def send_item(
      self,
      thread_id: str,
      item_type: ThreadItemType,
      shh_mode: bool = False,
      client_context: str | None = None,
      **kwargs: Any,
  ) -> Awaitable[CommandResponse]:
      """Send an item to a thread using the ``SEND_ITEM`` action.

      Args:
          thread_id: Thread identifier, forwarded to ``send_command``.
          item_type: Kind of item to send; its ``.value`` is sent as ``item_type``.
          shh_mode: Sent as ``is_shh_mode``, encoded as the string ``"1"`` or ``"0"``.
          client_context: Optional request identifier, forwarded to ``send_command``.
          **kwargs: Item-specific fields forwarded to ``send_command`` unchanged.

      Returns:
          The awaitable returned by ``send_command``.
      """
      item_type_value = item_type.value
      shh_flag = str(int(shh_mode))
      return self.send_command(
          thread_id,
          action=ThreadAction.SEND_ITEM,
          client_context=client_context,
          item_type=item_type_value,
          is_shh_mode=shh_flag,
          **kwargs,
      )
  def send_hashtag(
      self,
      thread_id: str,
      hashtag: str,
      text: str = "",
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Send a hashtag item to a thread.

      Args:
          thread_id: Thread identifier, forwarded to ``send_item``.
          hashtag: The hashtag to share; sent as the ``item_id`` field.
          text: Sent as the ``text`` field.
          shh_mode: Forwarded to ``send_item``.
          client_context: Optional request identifier, forwarded to ``send_item``.

      Returns:
          The awaitable returned by ``send_item``.
      """
      hashtag_fields = {"text": text, "item_id": hashtag}
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.HASHTAG,
          shh_mode=shh_mode,
          client_context=client_context,
          **hashtag_fields,
      )
  def send_like(
      self, thread_id: str, shh_mode: bool = False, client_context: str | None = None
  ) -> Awaitable[CommandResponse]:
      """Send a like item to a thread.

      Args:
          thread_id: Thread identifier, forwarded to ``send_item``.
          shh_mode: Forwarded to ``send_item``.
          client_context: Optional request identifier, forwarded to ``send_item``.

      Returns:
          The awaitable returned by ``send_item``.
      """
      like_type = ThreadItemType.LIKE
      return self.send_item(
          thread_id,
          item_type=like_type,
          shh_mode=shh_mode,
          client_context=client_context,
      )
  def send_location(
      self,
      thread_id: str,
      venue_id: str,
      text: str = "",
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Send a location item to a thread.

      Args:
          thread_id: Thread identifier, forwarded to ``send_item``.
          venue_id: The venue to share; sent as the ``item_id`` field.
          text: Sent as the ``text`` field.
          shh_mode: Forwarded to ``send_item``.
          client_context: Optional request identifier, forwarded to ``send_item``.

      Returns:
          The awaitable returned by ``send_item``.
      """
      location_fields = {"text": text, "item_id": venue_id}
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.LOCATION,
          shh_mode=shh_mode,
          client_context=client_context,
          **location_fields,
      )
  def send_media(
      self,
      thread_id: str,
      media_id: str,
      text: str = "",
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Send a media share item to a thread.

      Args:
          thread_id: Thread identifier, forwarded to ``send_item``.
          media_id: The media to share; sent as the ``media_id`` field.
          text: Sent as the ``text`` field.
          shh_mode: Forwarded to ``send_item``.
          client_context: Optional request identifier, forwarded to ``send_item``.

      Returns:
          The awaitable returned by ``send_item``.
      """
      media_fields = {"text": text, "media_id": media_id}
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.MEDIA_SHARE,
          shh_mode=shh_mode,
          client_context=client_context,
          **media_fields,
      )
  def send_profile(
      self,
      thread_id: str,
      user_id: str,
      text: str = "",
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Send a profile item to a thread.

      Args:
          thread_id: Thread identifier, forwarded to ``send_item``.
          user_id: The user whose profile is shared; sent as the ``item_id`` field.
          text: Sent as the ``text`` field.
          shh_mode: Forwarded to ``send_item``.
          client_context: Optional request identifier, forwarded to ``send_item``.

      Returns:
          The awaitable returned by ``send_item``.
      """
      profile_fields = {"text": text, "item_id": user_id}
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.PROFILE,
          shh_mode=shh_mode,
          client_context=client_context,
          **profile_fields,
      )
  def send_reaction(
      self,
      thread_id: str,
      emoji: str,
      item_id: str,
      reaction_status: ReactionStatus = ReactionStatus.CREATED,
      target_item_type: ThreadItemType = ThreadItemType.TEXT,
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Send a reaction targeting an existing item in a thread.

      The request always uses ``node_type="item"``, ``reaction_type="like"``
      and ``reaction_action_source="double_tap"``; the emoji is carried in the
      separate ``emoji`` field.

      Args:
          thread_id: Thread identifier, forwarded to ``send_item``.
          emoji: Sent as the ``emoji`` field.
          item_id: The item being reacted to; sent as the ``item_id`` field.
          reaction_status: Its ``.value`` is sent as ``reaction_status``.
          target_item_type: Its ``.value`` is sent as ``target_item_type``.
          shh_mode: Forwarded to ``send_item``.
          client_context: Optional request identifier, forwarded to ``send_item``.

      Returns:
          The awaitable returned by ``send_item``.
      """
      reaction_fields = {
          "reaction_status": reaction_status.value,
          "node_type": "item",
          "reaction_type": "like",
          "target_item_type": target_item_type.value,
          "emoji": emoji,
          "item_id": item_id,
          "reaction_action_source": "double_tap",
      }
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.REACTION,
          shh_mode=shh_mode,
          client_context=client_context,
          **reaction_fields,
      )
  def send_user_story(
      self,
      thread_id: str,
      media_id: str,
      text: str = "",
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Send a reel share item to a thread.

      Args:
          thread_id: Thread identifier, forwarded to ``send_item``.
          media_id: The media to share; sent as the ``item_id`` field.
          text: Sent as the ``text`` field.
          shh_mode: Forwarded to ``send_item``.
          client_context: Optional request identifier, forwarded to ``send_item``.

      Returns:
          The awaitable returned by ``send_item``.
      """
      story_fields = {"text": text, "item_id": media_id}
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.REEL_SHARE,
          shh_mode=shh_mode,
          client_context=client_context,
          **story_fields,
      )
  def send_text(
      self,
      thread_id: str,
      text: str = "",
      urls: list[str] | None = None,
      shh_mode: bool = False,
      client_context: str | None = None,
      replied_to_item_id: str | None = None,
      replied_to_client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Send a text message, or a link message when ``urls`` is given.

      When ``urls`` is ``None`` the message is sent as a ``TEXT`` item with a
      ``text`` field. Otherwise (including for an empty list) it is sent as a
      ``LINK`` item with ``link_text`` and a JSON-encoded ``link_urls`` field.

      Args:
          thread_id: Thread identifier, forwarded to ``send_item``.
          text: The message text.
          urls: URLs for a link message, or ``None`` for a plain text message.
          shh_mode: Forwarded to ``send_item``.
          client_context: Optional request identifier, forwarded to ``send_item``.
          replied_to_item_id: Sent as the ``replied_to_item_id`` field.
          replied_to_client_context: Sent as the ``replied_to_client_context`` field.

      Returns:
          The awaitable returned by ``send_item``.
      """
      if urls is None:
          kind = ThreadItemType.TEXT
          content = {"text": text}
      else:
          kind = ThreadItemType.LINK
          content = {
              "link_text": text,
              "link_urls": json.dumps(urls or []),
          }
      # The reply fields go after the content fields, matching the order in
      # which they have always been passed on.
      content["replied_to_item_id"] = replied_to_item_id
      content["replied_to_client_context"] = replied_to_client_context
      return self.send_item(
          thread_id,
          item_type=kind,
          shh_mode=shh_mode,
          client_context=client_context,
          **content,
      )
  def mark_seen(
      self, thread_id: str, item_id: str, client_context: str | None = None
  ) -> Awaitable[CommandResponse]:
      """Mark an item in a thread as seen using the ``MARK_SEEN`` action.

      The return annotation matches the sibling command helpers: the awaitable
      comes straight from ``send_command``, which yields the parsed response
      rather than ``None``.

      Args:
          thread_id: Thread identifier, forwarded to ``send_command``.
          item_id: The item to mark as seen; sent as the ``item_id`` field.
          client_context: Optional request identifier, forwarded to ``send_command``.

      Returns:
          The awaitable returned by ``send_command``.
      """
      return self.send_command(
          thread_id,
          item_id=item_id,
          action=ThreadAction.MARK_SEEN,
          client_context=client_context,
      )
  def mark_visual_item_seen(
      self, thread_id: str, item_id: str, client_context: str | None = None
  ) -> Awaitable[CommandResponse]:
      """Mark a visual item in a thread as seen using the ``MARK_VISUAL_ITEM_SEEN`` action.

      Args:
          thread_id: Thread identifier, forwarded to ``send_command``.
          item_id: The item to mark as seen; sent as the ``item_id`` field.
          client_context: Optional request identifier, forwarded to ``send_command``.

      Returns:
          The awaitable returned by ``send_command``.
      """
      seen_action = ThreadAction.MARK_VISUAL_ITEM_SEEN
      return self.send_command(
          thread_id,
          action=seen_action,
          client_context=client_context,
          item_id=item_id,
      )
  def indicate_activity(
      self,
      thread_id: str,
      activity_status: TypingStatus = TypingStatus.TEXT,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Report an activity status for a thread using the ``INDICATE_ACTIVITY`` action.

      Args:
          thread_id: Thread identifier, forwarded to ``send_command``.
          activity_status: Its ``.value`` is sent as the ``activity_status`` field.
          client_context: Optional request identifier, forwarded to ``send_command``.

      Returns:
          The awaitable returned by ``send_command``.
      """
      status_value = activity_status.value
      return self.send_command(
          thread_id,
          action=ThreadAction.INDICATE_ACTIVITY,
          client_context=client_context,
          activity_status=status_value,
      )
  894. # endregion