conn.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2022 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import Any, Awaitable, Callable, Iterable, Type, TypeVar
  18. from collections import defaultdict
  19. from socket import error as SocketError, socket
  20. import asyncio
  21. import json
  22. import logging
  23. import re
  24. import time
  25. import zlib
  26. from yarl import URL
  27. import paho.mqtt.client as pmc
  28. from mautrix.util import background_task
  29. from mautrix.util.logging import TraceLogger
  30. from mautrix.util.proxy import ProxyHandler, proxy_with_retry
  31. from ..errors import (
  32. IrisSubscribeError,
  33. MQTTConnectionUnauthorized,
  34. MQTTNotConnected,
  35. MQTTNotLoggedIn,
  36. MQTTReconnectionError,
  37. )
  38. from ..state import AndroidState
  39. from ..types import (
  40. AppPresenceEventPayload,
  41. ClientConfigUpdatePayload,
  42. CommandResponse,
  43. IrisPayload,
  44. IrisPayloadData,
  45. LiveVideoCommentPayload,
  46. MessageSyncEvent,
  47. MessageSyncMessage,
  48. Operation,
  49. PubsubEvent,
  50. PubsubPayload,
  51. ReactionStatus,
  52. ReactionType,
  53. RealtimeDirectEvent,
  54. RealtimeZeroProvisionPayload,
  55. ThreadAction,
  56. ThreadItemType,
  57. ThreadRemoveEvent,
  58. ThreadSyncEvent,
  59. TypingStatus,
  60. )
  61. from .events import Connect, Disconnect, NewSequenceID, ProxyUpdate
  62. from .otclient import MQTToTClient
  63. from .subscription import GraphQLQueryID, RealtimeTopic, everclear_subscriptions
  64. from .thrift import ForegroundStateConfig, IncomingMessage, RealtimeClientInfo, RealtimeConfig
  65. try:
  66. import socks
  67. except ImportError:
  68. socks = None
  69. T = TypeVar("T")
  70. ACTIVITY_INDICATOR_REGEX = re.compile(
  71. r"/direct_v2/threads/([\w_]+)/activity_indicator_id/([\w_]+)"
  72. )
  73. INBOX_THREAD_REGEX = re.compile(r"/direct_v2/inbox/threads/([\w_]+)")
  74. REQUEST_TIMEOUT = 60 * 3
  75. DEFAULT_KEEPALIVE = 60
  76. REQUEST_KEEPALIVE = 5
  77. class AndroidMQTT:
  78. _loop: asyncio.AbstractEventLoop
  79. _client: MQTToTClient
  80. log: TraceLogger
  81. state: AndroidState
  82. _graphql_subs: set[str]
  83. _skywalker_subs: set[str]
  84. _iris_seq_id: int | None
  85. _iris_snapshot_at_ms: int | None
  86. _publish_waiters: dict[int, asyncio.Future]
  87. _response_waiters: dict[RealtimeTopic, asyncio.Future]
  88. _response_waiter_locks: dict[RealtimeTopic, asyncio.Lock]
  89. _message_response_waiter_lock: asyncio.Lock
  90. _message_response_waiter_id: str | None
  91. _message_response_waiter: asyncio.Future | None
  92. _disconnect_error: Exception | None
  93. _event_handlers: dict[Type[T], list[Callable[[T], Awaitable[None]]]]
  94. _outgoing_events: asyncio.Queue
  95. _event_dispatcher_task: asyncio.Task | None
  96. # region Initialization
  97. def __init__(
  98. self,
  99. state: AndroidState,
  100. log: TraceLogger | None = None,
  101. proxy_handler: ProxyHandler | None = None,
  102. ) -> None:
  103. self._graphql_subs = set()
  104. self._skywalker_subs = set()
  105. self._iris_seq_id = None
  106. self._iris_snapshot_at_ms = None
  107. self._publish_waiters = {}
  108. self._response_waiters = {}
  109. self._message_response_waiter_lock = asyncio.Lock()
  110. self._message_response_waiter_id = None
  111. self._message_response_waiter = None
  112. self._disconnect_error = None
  113. self._response_waiter_locks = defaultdict(lambda: asyncio.Lock())
  114. self._event_handlers = defaultdict(lambda: [])
  115. self._event_dispatcher_task = None
  116. self._outgoing_events = asyncio.Queue()
  117. self.log = log or logging.getLogger("mauigpapi.mqtt")
  118. self._loop = asyncio.get_running_loop()
  119. self.state = state
  120. self._client = MQTToTClient(
  121. client_id=self._form_client_id(),
  122. clean_session=True,
  123. protocol=pmc.MQTTv31,
  124. transport="tcp",
  125. )
  126. self.proxy_handler = proxy_handler
  127. self.setup_proxy()
  128. self._client.enable_logger()
  129. self._client.tls_set()
  130. # mqtt.max_inflight_messages_set(20) # The rest will get queued
  131. # mqtt.max_queued_messages_set(0) # Unlimited messages can be queued
  132. # mqtt.message_retry_set(20) # Retry sending for at least 20 seconds
  133. # mqtt.reconnect_delay_set(min_delay=1, max_delay=120)
  134. self._client.connect_async("edge-mqtt.facebook.com", 443, keepalive=60)
  135. self._client.on_message = self._on_message_handler
  136. self._client.on_publish = self._on_publish_handler
  137. self._client.on_connect = self._on_connect_handler
  138. self._client.on_disconnect = self._on_disconnect_handler
  139. self._client.on_socket_open = self._on_socket_open
  140. self._client.on_socket_close = self._on_socket_close
  141. self._client.on_socket_register_write = self._on_socket_register_write
  142. self._client.on_socket_unregister_write = self._on_socket_unregister_write
  143. def setup_proxy(self):
  144. http_proxy = self.proxy_handler.get_proxy_url() if self.proxy_handler else None
  145. if http_proxy:
  146. if not socks:
  147. self.log.warning("http_proxy is set, but pysocks is not installed")
  148. else:
  149. proxy_url = URL(http_proxy)
  150. proxy_type = {
  151. "http": socks.HTTP,
  152. "https": socks.HTTP,
  153. "socks": socks.SOCKS5,
  154. "socks5": socks.SOCKS5,
  155. "socks4": socks.SOCKS4,
  156. }[proxy_url.scheme]
  157. self._client.proxy_set(
  158. proxy_type=proxy_type,
  159. proxy_addr=proxy_url.host,
  160. proxy_port=proxy_url.port,
  161. proxy_username=proxy_url.user,
  162. proxy_password=proxy_url.password,
  163. )
  164. def _clear_publish_waiters(self) -> None:
  165. for waiter in self._publish_waiters.values():
  166. if not waiter.done():
  167. waiter.set_exception(MQTTNotConnected("MQTT disconnected before PUBACK received"))
  168. self._publish_waiters = {}
    def _form_client_id(self) -> bytes:
        """Build the MQTToT connect payload that is passed to the client as its client ID.

        The payload is a thrift-serialized ``RealtimeConfig`` with the device/user identity,
        the topic IDs to subscribe to, the session authorization as the password, and
        app-specific info, compressed with zlib at level 9.
        """
        # Topics subscribed at connect time; the trailing numbers are presumably each
        # topic's encoded ID (the list below converts topic.encoded to int).
        subscribe_topics = [
            RealtimeTopic.PUBSUB,  # 88
            RealtimeTopic.SUB_IRIS_RESPONSE,  # 135
            RealtimeTopic.RS_REQ,  # 244
            RealtimeTopic.REALTIME_SUB,  # 149
            RealtimeTopic.REGION_HINT,  # 150
            RealtimeTopic.RS_RESP,  # 245
            RealtimeTopic.T_RTC_LOG,  # 274
            RealtimeTopic.SEND_MESSAGE_RESPONSE,  # 133
            RealtimeTopic.MESSAGE_SYNC,  # 146
            RealtimeTopic.LIGHTSPEED_RESPONSE,  # 179
            RealtimeTopic.UNKNOWN_PP,  # 34
        ]
        subscribe_topic_ids = [int(topic.encoded) for topic in subscribe_topics]
        # The session's authorization value is sent as the MQTT "password".
        password = f"authorization={self.state.session.authorization}"
        cfg = RealtimeConfig(
            # Only the first 20 characters of the phone ID are used as the identifier.
            client_identifier=self.state.device.phone_id[:20],
            client_info=RealtimeClientInfo(
                user_id=int(self.state.user_id),
                user_agent=self.state.user_agent,
                # NOTE(review): capability bitmask and the other numeric flags below presumably
                # mirror the official Android app; their individual meanings are not verified.
                client_capabilities=0b10110111,
                endpoint_capabilities=0,
                publish_format=1,
                no_automatic_foreground=True,
                make_user_available_in_foreground=False,
                device_id=self.state.device.phone_id,
                is_initially_foreground=False,
                network_type=1,
                network_subtype=-1,
                # Current time in milliseconds, truncated to 32 bits.
                client_mqtt_session_id=int(time.time() * 1000) & 0xFFFFFFFF,
                subscribe_topics=subscribe_topic_ids,
                client_type="cookie_auth",
                app_id=567067343352427,
                # region_preference=self.state.session.region_hint or "LLA",
                device_secret="",
                client_stack=3,
            ),
            password=password,
            app_specific_info={
                "capabilities": self.state.application.CAPABILITIES,
                "app_version": self.state.application.APP_VERSION,
                "everclear_subscriptions": json.dumps(everclear_subscriptions),
                "User-Agent": self.state.user_agent,
                # Locale such as en_US is sent in header form (en-US).
                "Accept-Language": self.state.device.language.replace("_", "-"),
                "platform": "android",
                "ig_mqtt_route": "django",
                "pubsub_msg_type_blacklist": "direct, typing_type",
                "auth_cache_enabled": "1",
            },
        )
        return zlib.compress(cfg.to_thrift(), level=9)
  221. # endregion
  222. def _on_socket_open(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  223. self._loop.add_reader(sock, client.loop_read)
  224. def _on_socket_close(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  225. self._loop.remove_reader(sock)
  226. def _on_socket_register_write(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  227. self._loop.add_writer(sock, client.loop_write)
  228. def _on_socket_unregister_write(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  229. self._loop.remove_writer(sock)
  230. def _on_connect_handler(
  231. self, client: MQTToTClient, _: Any, flags: dict[str, Any], rc: int
  232. ) -> None:
  233. if rc != 0:
  234. err = pmc.connack_string(rc)
  235. self.log.error("MQTT Connection Error: %s (%d)", err, rc)
  236. if rc == pmc.CONNACK_REFUSED_NOT_AUTHORIZED:
  237. self._disconnect_error = MQTTConnectionUnauthorized()
  238. self.disconnect()
  239. return
  240. self._loop.create_task(self._post_connect())
  241. def _on_disconnect_handler(self, client: MQTToTClient, _: Any, rc: int) -> None:
  242. err_str = "Generic error." if rc == pmc.MQTT_ERR_NOMEM else pmc.error_string(rc)
  243. self.log.debug("MQTT disconnection code %d: %s", rc, err_str)
  244. self._clear_publish_waiters()
    async def _post_connect(self) -> None:
        """Run after a successful CONNACK: dispatch Connect and restore subscriptions.

        Re-sends the remembered GraphQL and skywalker subscriptions. If an iris sequence ID is
        stored, re-subscribes to iris, retrying every 5 seconds on timeout for at most 5
        attempts. An IrisSubscribeError, or the fifth timeout, is stored as
        ``_disconnect_error`` and the client is disconnected.
        """
        await self._dispatch(Connect())
        self.log.debug("Re-subscribing to things after connect")
        if self._graphql_subs:
            # The awaited value is the publish future's result (resolved on PUBACK).
            res = await self.graphql_subscribe(self._graphql_subs)
            self.log.trace("GraphQL subscribe response: %s", res)
        if self._skywalker_subs:
            res = await self.skywalker_subscribe(self._skywalker_subs)
            self.log.trace("Skywalker subscribe response: %s", res)
        # NOTE(review): truthiness check, so a stored seq_id of 0 also skips iris — confirm intent.
        if self._iris_seq_id:
            retry = 0
            while True:
                try:
                    await self.iris_subscribe(self._iris_seq_id, self._iris_snapshot_at_ms)
                    break
                except (asyncio.TimeoutError, IrisSubscribeError) as e:
                    self.log.exception("Error requesting iris subscribe")
                    retry += 1
                    # Server-reported errors are not retried; timeouts are retried up to 5 times.
                    if retry >= 5 or isinstance(e, IrisSubscribeError):
                        self._disconnect_error = e
                        self.disconnect()
                        break
                    await asyncio.sleep(5)
                    self.log.debug("Retrying iris subscribe")
  269. def _on_publish_handler(self, client: MQTToTClient, _: Any, mid: int) -> None:
  270. try:
  271. waiter = self._publish_waiters[mid]
  272. except KeyError:
  273. self.log.trace(f"Got publish confirmation for {mid}, but no waiters")
  274. return
  275. self.log.trace(f"Got publish confirmation for {mid}")
  276. if not waiter.done():
  277. waiter.set_result(None)
  278. # region Incoming event parsing
  279. def _parse_direct_thread_path(self, path: str) -> dict:
  280. try:
  281. blank, direct_v2, threads, thread_id, *rest = path.split("/")
  282. except (ValueError, IndexError) as e:
  283. self.log.debug(f"Got {e!r} while parsing path {path}")
  284. raise
  285. if (blank, direct_v2, threads) != ("", "direct_v2", "threads"):
  286. self.log.debug(f"Got unexpected first parts in direct thread path {path}")
  287. raise ValueError("unexpected first three parts in _parse_direct_thread_path")
  288. additional = {"thread_id": thread_id}
  289. if rest:
  290. subitem_key = rest[0]
  291. if subitem_key == "approval_required_for_new_members":
  292. additional["approval_required_for_new_members"] = True
  293. elif subitem_key == "thread_image":
  294. additional["is_thread_image"] = True
  295. elif subitem_key == "participants" and len(rest) > 2 and rest[2] == "has_seen":
  296. additional["has_seen"] = int(rest[1])
  297. elif subitem_key == "items":
  298. additional["item_id"] = rest[1]
  299. if len(rest) > 4 and rest[2] == "reactions":
  300. additional["reaction_type"] = ReactionType(rest[3])
  301. additional["reaction_user_id"] = int(rest[4])
  302. elif subitem_key in "admin_user_ids":
  303. additional["admin_user_id"] = int(rest[1])
  304. elif subitem_key == "activity_indicator_id":
  305. additional["activity_indicator_id"] = rest[1]
  306. self.log.trace("Parsed path %s -> %s", path, additional)
  307. return additional
    def _on_messager_sync_item(self, part: IrisPayloadData, parsed_item: IrisPayload) -> bool:
        """Convert one iris sync *part* into an event and queue it on ``_outgoing_events``.

        ``/direct_v2/threads/...`` paths become MessageSyncEvents; reaction paths are reshaped
        into ``new_reaction`` data. ``/direct_v2/inbox/threads/...`` paths become a
        ThreadRemoveEvent (REMOVE op) or a ThreadSyncEvent. Returns True if an event was
        queued, False (after logging a warning) for unsupported paths.
        """
        if part.path.startswith("/direct_v2/threads/"):
            raw_message = {
                "path": part.path,
                "op": part.op,
                **self._parse_direct_thread_path(part.path),
            }
            try:
                json_value = json.loads(part.value)
                if "reaction_type" in raw_message:
                    self.log.trace("Treating %s as new reaction data", json_value)
                    # The reaction dict is nested as new_reaction and enriched in place, moving
                    # the IDs parsed from the path into it.
                    raw_message["new_reaction"] = json_value
                    json_value["sender_id"] = raw_message.pop("reaction_user_id")
                    json_value["type"] = raw_message.pop("reaction_type")
                    json_value["client_context"] = parsed_item.mutation_token
                    if part.op == Operation.REMOVE:
                        # A removed reaction carries no emoji or timestamp.
                        json_value["emoji"] = None
                        json_value["timestamp"] = None
                else:
                    # Non-reaction values are merged over the fields parsed from the path.
                    raw_message = {
                        **raw_message,
                        **json_value,
                    }
            except (json.JSONDecodeError, TypeError):
                # Values that are not JSON, or whose JSON can't be used as a dict (TypeError),
                # are passed through unchanged under "value".
                raw_message["value"] = part.value
            message = MessageSyncMessage.deserialize(raw_message)
            evt = MessageSyncEvent(iris=parsed_item, message=message)
        elif part.path.startswith("/direct_v2/inbox/threads/"):
            if part.op == Operation.REMOVE:
                # Path shape: /direct_v2/inbox/threads/<thread_id>[/...]
                blank, direct_v2, inbox, threads, thread_id, *_ = part.path.split("/")
                evt = ThreadRemoveEvent.deserialize(
                    {
                        "thread_id": thread_id,
                        "path": part.path,
                        "op": part.op,
                        **json.loads(part.value),
                    }
                )
            else:
                evt = ThreadSyncEvent.deserialize(
                    {
                        "path": part.path,
                        "op": part.op,
                        **json.loads(part.value),
                    }
                )
        else:
            self.log.warning(f"Unsupported path {part.path}")
            return False
        self._outgoing_events.put_nowait(evt)
        return True
  359. def _on_message_sync(self, payload: bytes) -> None:
  360. parsed = json.loads(payload.decode("utf-8"))
  361. self.log.trace("Got message sync event: %s", parsed)
  362. has_items = False
  363. for sync_item in parsed:
  364. parsed_item = IrisPayload.deserialize(sync_item)
  365. if self._iris_seq_id < parsed_item.seq_id:
  366. self.log.trace(f"Got new seq_id: {parsed_item.seq_id}")
  367. self._iris_seq_id = parsed_item.seq_id
  368. self._iris_snapshot_at_ms = int(time.time() * 1000)
  369. background_task.create(
  370. self._dispatch(NewSequenceID(self._iris_seq_id, self._iris_snapshot_at_ms))
  371. )
  372. for part in parsed_item.data:
  373. has_items = self._on_messager_sync_item(part, parsed_item) or has_items
  374. if has_items and not self._event_dispatcher_task:
  375. self._event_dispatcher_task = asyncio.create_task(self._dispatcher_loop())
  376. def _on_pubsub(self, payload: bytes) -> None:
  377. parsed_thrift = IncomingMessage.from_thrift(payload)
  378. self.log.trace(f"Got pubsub event {parsed_thrift.topic} / {parsed_thrift.payload}")
  379. message = PubsubPayload.parse_json(parsed_thrift.payload)
  380. for data in message.data:
  381. match = ACTIVITY_INDICATOR_REGEX.match(data.path)
  382. if match:
  383. evt = PubsubEvent(
  384. data=data,
  385. base=message,
  386. thread_id=match.group(1),
  387. activity_indicator_id=match.group(2),
  388. )
  389. self._loop.create_task(self._dispatch(evt))
  390. elif not data.double_publish:
  391. self.log.debug("Pubsub: no activity indicator on data: %s", data)
  392. else:
  393. self.log.debug("Pubsub: double publish: %s", data.path)
  394. def _parse_realtime_sub_item(self, topic: str | GraphQLQueryID, raw: dict) -> Iterable[Any]:
  395. if topic == GraphQLQueryID.APP_PRESENCE:
  396. yield AppPresenceEventPayload.deserialize(raw).presence_event
  397. elif topic == GraphQLQueryID.ZERO_PROVISION:
  398. yield RealtimeZeroProvisionPayload.deserialize(raw).zero_product_provisioning_event
  399. elif topic == GraphQLQueryID.CLIENT_CONFIG_UPDATE:
  400. yield ClientConfigUpdatePayload.deserialize(raw).client_config_update_event
  401. elif topic == GraphQLQueryID.LIVE_REALTIME_COMMENTS:
  402. yield LiveVideoCommentPayload.deserialize(raw).live_video_comment_event
  403. elif topic == "direct":
  404. event = raw["event"]
  405. for item in raw["data"]:
  406. yield RealtimeDirectEvent.deserialize(
  407. {
  408. "event": event,
  409. **self._parse_direct_thread_path(item["path"]),
  410. **item,
  411. }
  412. )
  413. def _on_realtime_sub(self, payload: bytes) -> None:
  414. parsed_thrift = IncomingMessage.from_thrift(payload)
  415. try:
  416. topic = GraphQLQueryID(parsed_thrift.topic)
  417. except ValueError:
  418. topic = parsed_thrift.topic
  419. self.log.trace(f"Got realtime sub event {topic} / {parsed_thrift.payload}")
  420. allowed = (
  421. "direct",
  422. GraphQLQueryID.APP_PRESENCE,
  423. GraphQLQueryID.ZERO_PROVISION,
  424. GraphQLQueryID.CLIENT_CONFIG_UPDATE,
  425. GraphQLQueryID.LIVE_REALTIME_COMMENTS,
  426. )
  427. if topic not in allowed:
  428. return
  429. parsed_json = json.loads(parsed_thrift.payload)
  430. for evt in self._parse_realtime_sub_item(topic, parsed_json):
  431. self._loop.create_task(self._dispatch(evt))
    def _handle_send_response(self, message: pmc.MQTTMessage) -> None:
        """Resolve the pending send-message waiter with a SEND_MESSAGE_RESPONSE *message*.

        The response's ``payload.client_context`` must equal ``_message_response_waiter_id``.
        A mismatch is logged and the message ignored. A response without a client_context is
        assumed to belong to whichever send is currently waiting.
        """
        # NOTE(review): only KeyError is handled; a non-dict "payload" would raise TypeError,
        # which is left to the caller's catch-all in _on_message_handler.
        data = json.loads(message.payload.decode("utf-8"))
        try:
            ccid = data["payload"]["client_context"]
        except KeyError:
            self.log.warning(
                "Didn't find client_context in send message response: %s", message.payload
            )
            ccid = self._message_response_waiter_id
        else:
            if ccid != self._message_response_waiter_id:
                self.log.error(
                    "Mismatching client_context in send message response (%s != %s)",
                    ccid,
                    self._message_response_waiter_id,
                )
                return
        if self._message_response_waiter and not self._message_response_waiter.done():
            self.log.debug("Got response to %s: %s", ccid, message.payload)
            self._message_response_waiter.set_result(message)
            # The single waiter slot is cleared once it has been resolved.
            self._message_response_waiter = None
            self._message_response_waiter_id = None
        else:
            self.log.warning("Didn't find task waiting for response %s", message.payload)
    def _on_message_handler(self, client: MQTToTClient, _: Any, message: pmc.MQTTMessage) -> None:
        """paho on_message callback: decompress an incoming message and route it by topic.

        MESSAGE_SYNC, PUBSUB, REALTIME_SUB and SEND_MESSAGE_RESPONSE have dedicated handlers.
        A message on UNKNOWN_PP triggers a background reconnect. Any other topic resolves the
        waiter that request() registered for that topic, if one exists. Exceptions are logged
        here and never propagated back into paho.
        """
        try:
            topic = RealtimeTopic.decode(message.topic)
            # Instagram Android MQTT messages are always compressed
            message.payload = zlib.decompress(message.payload)
            if topic == RealtimeTopic.MESSAGE_SYNC:
                self._on_message_sync(message.payload)
            elif topic == RealtimeTopic.PUBSUB:
                self._on_pubsub(message.payload)
            elif topic == RealtimeTopic.REALTIME_SUB:
                self._on_realtime_sub(message.payload)
            elif topic == RealtimeTopic.SEND_MESSAGE_RESPONSE:
                self._handle_send_response(message)
            elif topic == RealtimeTopic.UNKNOWN_PP:
                self.log.warning("Reconnecting after receiving /pp message")
                background_task.create(self._reconnect())
            else:
                try:
                    # Responses are matched to request() waiters by topic alone.
                    waiter = self._response_waiters.pop(topic)
                except KeyError:
                    self.log.debug(
                        "No handler for MQTT message in %s: %s", topic.value, message.payload
                    )
                else:
                    if not waiter.done():
                        waiter.set_result(message)
                        self.log.trace("Got response %s: %s", topic.value, message.payload)
                    else:
                        # The waiter already finished (e.g. timed out) before the reply arrived.
                        self.log.debug(
                            "Got response in %s, but waiter was already cancelled: %s",
                            topic,
                            message.payload,
                        )
        except Exception:
            self.log.exception("Error in incoming MQTT message handler")
            self.log.trace("Errored MQTT payload: %s", message.payload)
  492. # endregion
  493. async def _reconnect(self) -> None:
  494. try:
  495. self._client.reconnect()
  496. return
  497. except (SocketError, OSError, pmc.WebsocketConnectionError) as e:
  498. self.log.exception("Error reconnecting to MQTT")
  499. raise MQTTReconnectionError("MQTT reconnection failed") from e
  500. def add_event_handler(
  501. self, evt_type: Type[T], handler: Callable[[T], Awaitable[None]]
  502. ) -> None:
  503. self._event_handlers[evt_type].append(handler)
  504. async def _dispatch(self, evt: T) -> None:
  505. for handler in self._event_handlers[type(evt)]:
  506. try:
  507. await handler(evt)
  508. except Exception:
  509. self.log.exception(f"Error in {type(evt).__name__} handler")
  510. def disconnect(self) -> None:
  511. self._client.disconnect()
    async def _dispatcher_loop(self) -> None:
        """Dispatch events from ``_outgoing_events`` one at a time until cancelled.

        Each dispatch is wrapped in asyncio.shield, so cancelling this task does not cancel a
        dispatch already in progress. On cancellation the queue is swapped for a fresh one,
        every event still queued is dispatched, and CancelledError is re-raised.
        """
        # Unique-ish tag so log lines from successive loops can be told apart.
        loop_id = f"{hex(id(self))}#{time.monotonic()}"
        self.log.debug(f"Dispatcher loop {loop_id} starting")
        try:
            while True:
                evt = await self._outgoing_events.get()
                await asyncio.shield(self._dispatch(evt))
        except asyncio.CancelledError:
            # NOTE(review): a shielded dispatch interrupted by the cancel keeps running in the
            # background, so it may overlap with the draining below.
            tasks = self._outgoing_events
            self._outgoing_events = asyncio.Queue()
            if not tasks.empty():
                self.log.debug(
                    f"Dispatcher loop {loop_id} stopping after dispatching {tasks.qsize()} events"
                )
            while not tasks.empty():
                await self._dispatch(tasks.get_nowait())
            raise
        finally:
            self.log.debug(f"Dispatcher loop {loop_id} stopped")
    async def listen(
        self,
        graphql_subs: set[str] | None = None,
        skywalker_subs: set[str] | None = None,
        seq_id: int | None = None,
        snapshot_at_ms: int | None = None,
        retry_limit: int = 10,
    ) -> None:
        """Connect to Instagram's MQTT edge and keep the connection serviced until it ends.

        :param graphql_subs: GraphQL subscriptions to (re)send after every connect.
        :param skywalker_subs: Skywalker (pubsub) subscriptions to (re)send after every connect.
        :param seq_id: Iris sequence ID to resume message sync from; iris is only subscribed
            after connecting when this is truthy.
        :param snapshot_at_ms: Snapshot timestamp (ms) paired with *seq_id*.
        :param retry_limit: Passed to proxy_with_retry as ``max_retries``.
        :raises Exception: the stored ``_disconnect_error`` (for example
            MQTTConnectionUnauthorized or IrisSubscribeError) once listening stops. Errors that
            proxy_with_retry does not retry, or gives up on, presumably propagate as well.
        """
        self._graphql_subs = graphql_subs or set()
        self._skywalker_subs = skywalker_subs or set()
        self._iris_seq_id = seq_id
        self._iris_snapshot_at_ms = snapshot_at_ms

        self.log.debug("Connecting to Instagram MQTT")

        async def connect_and_watch() -> None:
            # Connect, then run paho's housekeeping (loop_misc) about once per second. Socket
            # reads and writes are driven by the loop reader/writer callbacks instead. Returns
            # on a deliberate disconnect; raises on connection errors so the retry wrapper
            # can reconnect.
            await self._reconnect()

            while True:
                try:
                    await asyncio.sleep(1)
                except asyncio.CancelledError:
                    self.disconnect()
                    # this might not be necessary
                    self._client.loop_misc()
                    return
                rc = self._client.loop_misc()

                # If disconnect() has been called
                # Beware, internal API, may have to change this to something more stable!
                if self._client._state == pmc.mqtt_cs_disconnecting:
                    return  # Stop listening

                if rc != pmc.MQTT_ERR_SUCCESS:
                    # If known/expected error
                    if rc == pmc.MQTT_ERR_CONN_LOST:
                        await self._dispatch(Disconnect(reason="Connection lost, retrying"))
                        raise MQTTNotConnected("MQTT_ERR_CONN_LOST")
                    elif rc == pmc.MQTT_ERR_NOMEM:
                        # This error is wrongly classified
                        # See https://github.com/eclipse/paho.mqtt.python/issues/340
                        await self._dispatch(Disconnect(reason="Connection lost, retrying"))
                        raise MQTTNotConnected("MQTT_ERR_NOMEM")
                    elif rc == pmc.MQTT_ERR_CONN_REFUSED:
                        # MQTTNotLoggedIn is not in retryable_exceptions below.
                        await self._dispatch(Disconnect(reason="Connection refused, retrying"))
                        raise MQTTNotLoggedIn("MQTT_ERR_CONN_REFUSED")
                    elif rc == pmc.MQTT_ERR_NO_CONN:
                        await self._dispatch(Disconnect(reason="Connection dropped, retrying"))
                        raise MQTTNotConnected("MQTT_ERR_NO_CONN")
                    else:
                        err = pmc.error_string(rc)
                        self.log.error("MQTT Error: %s", err)
                        await self._dispatch(Disconnect(reason=f"MQTT Error: {err}, retrying"))
                        raise MQTTNotConnected(err)

        await proxy_with_retry(
            "mqtt.listen",
            lambda: connect_and_watch(),
            logger=self.log,
            proxy_handler=self.proxy_handler,
            on_proxy_change=lambda: self._dispatch(ProxyUpdate()),
            max_retries=retry_limit,
            retryable_exceptions=(MQTTNotConnected, MQTTReconnectionError),
            # Wait 1s * errors, max 5s for fast reconnect or die
            max_wait_seconds=5,
            multiply_wait_seconds=1,
            # If connection stable for >1h, reset the error counter
            reset_after_seconds=3600,
        )

        # Cancelling makes the dispatcher drain its remaining queued events before stopping.
        # NOTE(review): this cleanup is skipped if proxy_with_retry raises.
        if self._event_dispatcher_task:
            self._event_dispatcher_task.cancel()
            self._event_dispatcher_task = None

        if self._disconnect_error:
            self.log.info("disconnect_error is set, raising and clearing variable")
            err = self._disconnect_error
            self._disconnect_error = None
            raise err
  602. # region Basic outgoing MQTT
  603. @staticmethod
  604. def _publish_cancel_later(fut: asyncio.Future) -> None:
  605. if not fut.done():
  606. fut.set_exception(asyncio.TimeoutError("MQTT publish timed out"))
  607. @staticmethod
  608. def _request_cancel_later(fut: asyncio.Future) -> None:
  609. if not fut.done():
  610. fut.set_exception(asyncio.TimeoutError("MQTT request timed out"))
  611. # The following two functions mutate the client keepalive (cheeky) to temporarily increase
  612. # ping attempts during read/write to MQTT. If things are flowing this should change nothing,
  613. # as pings only send when idle. It should, however, allow the client to detect a bad MQTT
  614. # connection much quicker than the default keepalive.
  615. def set_request_keepalive(self):
  616. self._client._keepalive = REQUEST_KEEPALIVE
  617. def maybe_reset_keepalive(self):
  618. # Reset the keepalive back to the default value if we have no pending publish/receive
  619. if (
  620. not self._response_waiters
  621. and not self._publish_waiters
  622. and not self._message_response_waiter
  623. ):
  624. self._client._keepalive = DEFAULT_KEEPALIVE
  625. def publish(self, topic: RealtimeTopic, payload: str | bytes | dict) -> asyncio.Future:
  626. if isinstance(payload, dict):
  627. payload = json.dumps(payload)
  628. if isinstance(payload, str):
  629. payload = payload.encode("utf-8")
  630. self.log.trace(f"Publishing message in {topic.value} ({topic.encoded}): {payload}")
  631. payload = zlib.compress(payload, level=9)
  632. self.set_request_keepalive()
  633. info = self._client.publish(topic.encoded, payload, qos=1)
  634. self.log.trace(f"Published message ID: {info.mid}")
  635. fut = self._loop.create_future()
  636. timeout_handle = self._loop.call_later(REQUEST_TIMEOUT, self._publish_cancel_later, fut)
  637. fut.add_done_callback(lambda _: timeout_handle.cancel())
  638. fut.add_done_callback(lambda _: self.maybe_reset_keepalive())
  639. self._publish_waiters[info.mid] = fut
  640. return fut
  641. async def request(
  642. self,
  643. topic: RealtimeTopic,
  644. response: RealtimeTopic,
  645. payload: str | bytes | dict,
  646. timeout: int | None = None,
  647. ) -> pmc.MQTTMessage:
  648. async with self._response_waiter_locks[response]:
  649. fut = self._loop.create_future()
  650. self._response_waiters[response] = fut
  651. background_task.create(self.publish(topic, payload))
  652. self.log.trace(
  653. f"Request publish to {topic.value} queued, waiting for response {response.name}"
  654. )
  655. timeout_handle = self._loop.call_later(
  656. timeout or REQUEST_TIMEOUT, self._request_cancel_later, fut
  657. )
  658. fut.add_done_callback(lambda _: timeout_handle.cancel())
  659. fut.add_done_callback(lambda _: self.maybe_reset_keepalive())
  660. return await fut
  async def iris_subscribe(self, seq_id: int, snapshot_at_ms: int) -> None:
      """Subscribe to the Iris stream from ``seq_id``/``snapshot_at_ms``.

      The method does the following:

      * Sends a ``SUB_IRIS`` request with a 20-second timeout.
      * Parses the JSON reply.
      * If the reply reports a ``latest_seq_id`` greater than ``self._iris_seq_id``,
        advances the stored sequence ID and snapshot timestamp.
      * In that same case, dispatches a ``NewSequenceID`` event in a background task.

      Args:
          seq_id: Sequence ID to resume from.
          snapshot_at_ms: Snapshot timestamp paired with ``seq_id``.

      Raises:
          IrisSubscribeError: If the reply carries both a truthy ``error_type`` and a
              truthy ``error_message``.
      """
      self.log.debug(f"Requesting iris subscribe {seq_id}/{snapshot_at_ms}")
      resp = await self.request(
          RealtimeTopic.SUB_IRIS,
          RealtimeTopic.SUB_IRIS_RESPONSE,
          {
              "seq_id": seq_id,
              "snapshot_at_ms": snapshot_at_ms,
              "snapshot_app_version": self.state.application.APP_VERSION,
              "timezone_offset": int(self.state.device.timezone_offset),
              "subscription_type": "message",
          },
          timeout=20,
      )
      # Decode once; the text is used for both logging and JSON parsing.
      resp_text = resp.payload.decode("utf-8")
      self.log.debug("Iris subscribe response: %s", resp_text)
      resp_dict = json.loads(resp_text)
      # .get(): a reply without the error fields must not fail with KeyError.
      error_type = resp_dict.get("error_type")
      error_message = resp_dict.get("error_message")
      if error_type and error_message:
          raise IrisSubscribeError(error_type, error_message)
      latest_seq_id = resp_dict.get("latest_seq_id")
      if latest_seq_id is None:
          # Nothing to compare: ``None > int`` would raise TypeError.
          self.log.warning("Iris subscribe response did not include latest_seq_id")
          return
      if latest_seq_id > self._iris_seq_id:
          self.log.info(f"Latest sequence ID is {latest_seq_id}, catching up from {seq_id}")
          self._iris_seq_id = latest_seq_id
          self._iris_snapshot_at_ms = resp_dict.get("subscribed_at_ms", int(time.time() * 1000))
          background_task.create(
              self._dispatch(NewSequenceID(self._iris_seq_id, self._iris_snapshot_at_ms))
          )
  def graphql_subscribe(self, subs: set[str]) -> asyncio.Future:
      """Record ``subs`` as active GraphQL subscriptions and publish a ``sub`` request.

      ``self._graphql_subs`` is updated in place first. The return value is whatever
      ``self.publish`` returns for the ``REALTIME_SUB`` message.
      """
      self._graphql_subs |= subs
      message = {"sub": list(subs)}
      return self.publish(RealtimeTopic.REALTIME_SUB, message)
  def graphql_unsubscribe(self, subs: set[str]) -> asyncio.Future:
      """Forget ``subs`` as GraphQL subscriptions and publish an ``unsub`` request.

      ``self._graphql_subs`` is reduced in place first. The return value is whatever
      ``self.publish`` returns for the ``REALTIME_SUB`` message.
      """
      self._graphql_subs -= subs
      message = {"unsub": list(subs)}
      return self.publish(RealtimeTopic.REALTIME_SUB, message)
  def skywalker_subscribe(self, subs: set[str]) -> asyncio.Future:
      """Record ``subs`` as active Skywalker subscriptions and publish a ``sub`` request.

      ``self._skywalker_subs`` is updated in place first. The return value is whatever
      ``self.publish`` returns for the ``PUBSUB`` message.
      """
      self._skywalker_subs |= subs
      message = {"sub": list(subs)}
      return self.publish(RealtimeTopic.PUBSUB, message)
  def skywalker_unsubscribe(self, subs: set[str]) -> asyncio.Future:
      """Forget ``subs`` as Skywalker subscriptions and publish an ``unsub`` request.

      ``self._skywalker_subs`` is reduced in place first. The return value is whatever
      ``self.publish`` returns for the ``PUBSUB`` message.
      """
      self._skywalker_subs -= subs
      message = {"unsub": list(subs)}
      return self.publish(RealtimeTopic.PUBSUB, message)
  699. # endregion
  700. # region Actually sending messages and stuff
  async def send_foreground_state(self, state: ForegroundStateConfig) -> None:
      """Publish a foreground-state update and adopt its keepalive value, if any.

      ``state.to_thrift()`` is zlib-compressed at level 9 and published on
      ``FOREGROUND_STATE``. Once the publish completes, a truthy
      ``state.keep_alive_timeout`` is copied onto ``self._client._keepalive``.
      """
      self.log.debug("Updating foreground state: %s", state)
      body = zlib.compress(state.to_thrift(), level=9)
      await self.publish(RealtimeTopic.FOREGROUND_STATE, body)
      keep_alive = state.keep_alive_timeout
      if not keep_alive:
          return
      self._client._keepalive = keep_alive
  async def send_command(
      self,
      thread_id: str,
      action: ThreadAction,
      client_context: str | None = None,
      **kwargs: Any,
  ) -> CommandResponse | None:
      """Send a thread command over ``SEND_MESSAGE`` and wait for its response.

      Request body:

      * ``thread_id``, ``client_context``, ``offline_threading_id`` (the same value as
        ``client_context``) and ``action.value``.
      * Followed by ``kwargs``, which can override those keys because they come last.

      Commands are serialized through ``self._message_response_waiter_lock``. A warning
      is logged if acquiring it took more than one second.

      Args:
          thread_id: Target thread.
          action: Command to perform; its ``.value`` is sent.
          client_context: Request identifier. A falsy value is replaced with
              ``self.state.gen_client_context()``.
          **kwargs: Extra request-body fields.

      Returns:
          ``CommandResponse.parse_json`` of the UTF-8 payload the waiter future resolves
          with. The future is completed elsewhere, presumably when a
          ``SEND_MESSAGE_RESPONSE`` matching ``_message_response_waiter_id`` arrives.
          TODO confirm.

      Raises:
          asyncio.TimeoutError: If no response arrives within ``REQUEST_TIMEOUT`` (logged
              before re-raising).
      """
      # Logged before a context is generated, so this shows the caller's value (maybe None).
      self.log.debug(f"Preparing to send {action} to {thread_id} with {client_context}")
      if not client_context:
          client_context = self.state.gen_client_context()
      request_body = {
          "thread_id": thread_id,
          "client_context": client_context,
          "offline_threading_id": client_context,
          "action": action.value,
          **kwargs,
      }
      waited_since = time.monotonic()
      async with self._message_response_waiter_lock:
          waited = time.monotonic() - waited_since
          if waited > 1:
              self.log.warning(f"Waited {waited:.3f} seconds to send {client_context}")
          response_future = asyncio.Future()
          self._message_response_waiter = response_future
          self._message_response_waiter_id = client_context
          background_task.create(self.publish(RealtimeTopic.SEND_MESSAGE, request_body))
          self.log.debug(
              f"Request publish to {RealtimeTopic.SEND_MESSAGE} queued, "
              f"waiting for response {RealtimeTopic.SEND_MESSAGE_RESPONSE}"
          )
          try:
              message = await asyncio.wait_for(response_future, timeout=REQUEST_TIMEOUT)
          except asyncio.TimeoutError:
              self.log.error(f"Request with ID {client_context} timed out!")
              raise
          finally:
              # Runs on success, timeout and cancellation alike.
              self.maybe_reset_keepalive()
      return CommandResponse.parse_json(message.payload.decode("utf-8"))
  def send_item(
      self,
      thread_id: str,
      item_type: ThreadItemType,
      shh_mode: bool = False,
      client_context: str | None = None,
      **kwargs: Any,
  ) -> Awaitable[CommandResponse]:
      """Send a ``SEND_ITEM`` command carrying an item of ``item_type`` to ``thread_id``.

      The request body gets these fields, in this order:

      * ``item_type`` (the enum's ``.value``).
      * ``is_shh_mode``, as ``str(int(shh_mode))`` (``"1"``/``"0"`` for bools).
      * Any extra ``kwargs``.

      Returns the (not yet awaited) ``send_command`` coroutine.
      """
      return self.send_command(
          thread_id,
          action=ThreadAction.SEND_ITEM,
          client_context=client_context,
          item_type=item_type.value,
          is_shh_mode=str(int(shh_mode)),
          **kwargs,
      )
  def send_hashtag(
      self,
      thread_id: str,
      hashtag: str,
      text: str = "",
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Share ``hashtag`` (sent as ``item_id``) into ``thread_id`` with optional ``text``."""
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.HASHTAG,
          shh_mode=shh_mode,
          client_context=client_context,
          text=text,
          item_id=hashtag,
      )
  def send_like(
      self, thread_id: str, shh_mode: bool = False, client_context: str | None = None
  ) -> Awaitable[CommandResponse]:
      """Send a ``LIKE`` item to ``thread_id``; no extra body fields are added."""
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.LIKE,
          client_context=client_context,
          shh_mode=shh_mode,
      )
  def send_location(
      self,
      thread_id: str,
      venue_id: str,
      text: str = "",
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Share venue ``venue_id`` (sent as ``item_id``) into ``thread_id`` as a LOCATION item."""
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.LOCATION,
          shh_mode=shh_mode,
          client_context=client_context,
          text=text,
          item_id=venue_id,
      )
  def send_media(
      self,
      thread_id: str,
      media_id: str,
      text: str = "",
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Share media ``media_id`` into ``thread_id`` as a MEDIA_SHARE item.

      Unlike the other share helpers, the identifier goes in a ``media_id`` field,
      not ``item_id``.
      """
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.MEDIA_SHARE,
          shh_mode=shh_mode,
          client_context=client_context,
          text=text,
          media_id=media_id,
      )
  def send_profile(
      self,
      thread_id: str,
      user_id: str,
      text: str = "",
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Share user ``user_id`` (sent as ``item_id``) into ``thread_id`` as a PROFILE item."""
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.PROFILE,
          shh_mode=shh_mode,
          client_context=client_context,
          text=text,
          item_id=user_id,
      )
  def send_reaction(
      self,
      thread_id: str,
      emoji: str,
      item_id: str,
      reaction_status: ReactionStatus = ReactionStatus.CREATED,
      target_item_type: ThreadItemType = ThreadItemType.TEXT,
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """React with ``emoji`` to item ``item_id`` in ``thread_id`` (a REACTION item).

      Fixed body fields: ``node_type="item"``, ``reaction_type="like"`` and
      ``reaction_action_source="double_tap"``. ``reaction_status`` and
      ``target_item_type`` are sent as their enum ``.value``.
      """
      # Key order is kept as-is because it determines the order of fields in the body.
      reaction_fields = {
          "reaction_status": reaction_status.value,
          "node_type": "item",
          "reaction_type": "like",
          "target_item_type": target_item_type.value,
          "emoji": emoji,
          "item_id": item_id,
          "reaction_action_source": "double_tap",
      }
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.REACTION,
          shh_mode=shh_mode,
          client_context=client_context,
          **reaction_fields,
      )
  def send_user_story(
      self,
      thread_id: str,
      media_id: str,
      text: str = "",
      shh_mode: bool = False,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Share story media ``media_id`` into ``thread_id`` as a REEL_SHARE item.

      Here the identifier is sent as ``item_id`` (``send_media`` uses ``media_id``).
      """
      return self.send_item(
          thread_id,
          item_type=ThreadItemType.REEL_SHARE,
          shh_mode=shh_mode,
          client_context=client_context,
          text=text,
          item_id=media_id,
      )
  def send_text(
      self,
      thread_id: str,
      text: str = "",
      urls: list[str] | None = None,
      shh_mode: bool = False,
      client_context: str | None = None,
      replied_to_item_id: str | None = None,
      replied_to_client_context: str | None = None,
      mentioned_user_ids: list[int] | None = None,
  ) -> Awaitable[CommandResponse]:
      """Send a text message, or a link message whenever ``urls`` is not ``None``.

      Item types:

      * TEXT: the body carries ``text``.
      * LINK (chosen even for an empty list): the body carries ``link_text`` and
        ``link_urls`` as a JSON array.

      A non-empty ``mentioned_user_ids`` is JSON-encoded as a list of strings and
      sets ``sampled=True``. Both ``replied_to_*`` fields are always forwarded, even
      when ``None``.
      """
      if urls is None:
          item_type = ThreadItemType.TEXT
          fields = {"text": text}
      else:
          item_type = ThreadItemType.LINK
          fields = {
              "link_text": text,
              "link_urls": json.dumps(urls or []),
          }
      if mentioned_user_ids:
          fields["mentioned_user_ids"] = json.dumps([str(uid) for uid in mentioned_user_ids])
          fields["sampled"] = True
      return self.send_item(
          thread_id,
          **fields,
          shh_mode=shh_mode,
          item_type=item_type,
          client_context=client_context,
          replied_to_item_id=replied_to_item_id,
          replied_to_client_context=replied_to_client_context,
      )
  def mark_seen(
      self, thread_id: str, item_id: str, client_context: str | None = None
  ) -> Awaitable[CommandResponse]:
      """Send a ``MARK_SEEN`` command for ``item_id`` in ``thread_id``.

      Returns the (not yet awaited) ``send_command`` coroutine. Awaiting it yields the
      parsed ``CommandResponse``, not ``None``, so the annotation matches the sibling
      thread commands.
      """
      return self.send_command(
          thread_id,
          item_id=item_id,
          action=ThreadAction.MARK_SEEN,
          client_context=client_context,
      )
  def mark_visual_item_seen(
      self, thread_id: str, item_id: str, client_context: str | None = None
  ) -> Awaitable[CommandResponse]:
      """Send a ``MARK_VISUAL_ITEM_SEEN`` command for ``item_id`` in ``thread_id``."""
      return self.send_command(
          thread_id,
          action=ThreadAction.MARK_VISUAL_ITEM_SEEN,
          client_context=client_context,
          item_id=item_id,
      )
  def indicate_activity(
      self,
      thread_id: str,
      activity_status: TypingStatus = TypingStatus.TEXT,
      client_context: str | None = None,
  ) -> Awaitable[CommandResponse]:
      """Send an ``INDICATE_ACTIVITY`` command with ``activity_status.value`` to ``thread_id``."""
      return self.send_command(
          thread_id,
          action=ThreadAction.INDICATE_ACTIVITY,
          client_context=client_context,
          activity_status=activity_status.value,
      )
  936. # endregion