conn.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import (Union, Set, Optional, Any, Dict, Awaitable, Type, List, TypeVar, Callable,
  17. Iterable)
  18. from collections import defaultdict
  19. from socket import socket, error as SocketError
  20. from uuid import uuid4
  21. import urllib.request
  22. import logging
  23. import asyncio
  24. import random
  25. import zlib
  26. import time
  27. import json
  28. import re
  29. import paho.mqtt.client
  30. from paho.mqtt.client import MQTTMessage, WebsocketConnectionError
  31. from yarl import URL
  32. from mautrix.util.logging import TraceLogger
  33. from ..errors import MQTTNotLoggedIn, MQTTNotConnected, IrisSubscribeError
  34. from ..state import AndroidState
  35. from ..types import (CommandResponse, ThreadItemType, ThreadAction, ReactionStatus, TypingStatus,
  36. IrisPayload, PubsubPayload, AppPresenceEventPayload, RealtimeDirectEvent,
  37. RealtimeZeroProvisionPayload, ClientConfigUpdatePayload, MessageSyncEvent,
  38. MessageSyncMessage, LiveVideoCommentPayload, PubsubEvent, IrisPayloadData,
  39. ThreadSyncEvent)
  40. from .thrift import RealtimeConfig, RealtimeClientInfo, ForegroundStateConfig, IncomingMessage
  41. from .otclient import MQTToTClient
  42. from .subscription import everclear_subscriptions, RealtimeTopic, GraphQLQueryID
  43. from .events import Connect, Disconnect
  44. try:
  45. import socks
  46. except ImportError:
  47. socks = None
  48. T = TypeVar('T')
  49. ACTIVITY_INDICATOR_REGEX = re.compile(
  50. r"/direct_v2/threads/([\w_]+)/activity_indicator_id/([\w_]+)")
  51. INBOX_THREAD_REGEX = re.compile(
  52. r"/direct_v2/inbox/threads/([\w_]+)")
  53. class AndroidMQTT:
  54. _loop: asyncio.AbstractEventLoop
  55. _client: MQTToTClient
  56. log: TraceLogger
  57. state: AndroidState
  58. _graphql_subs: Set[str]
  59. _skywalker_subs: Set[str]
  60. _iris_seq_id: Optional[int]
  61. _iris_snapshot_at_ms: Optional[int]
  62. _publish_waiters: Dict[int, asyncio.Future]
  63. _response_waiters: Dict[RealtimeTopic, asyncio.Future]
  64. _response_waiter_locks: Dict[RealtimeTopic, asyncio.Lock]
  65. _message_response_waiters: Dict[str, asyncio.Future]
  66. _disconnect_error: Optional[Exception]
  67. _event_handlers: Dict[Type[T], List[Callable[[T], Awaitable[None]]]]
  68. # region Initialization
  69. def __init__(self, state: AndroidState, loop: Optional[asyncio.AbstractEventLoop] = None,
  70. log: Optional[TraceLogger] = None) -> None:
  71. self._graphql_subs = set()
  72. self._skywalker_subs = set()
  73. self._iris_seq_id = None
  74. self._iris_snapshot_at_ms = None
  75. self._publish_waiters = {}
  76. self._response_waiters = {}
  77. self._message_response_waiters = {}
  78. self._disconnect_error = None
  79. self._response_waiter_locks = defaultdict(lambda: asyncio.Lock())
  80. self._event_handlers = defaultdict(lambda: [])
  81. self.log = log or logging.getLogger("mauigpapi.mqtt")
  82. self._loop = loop or asyncio.get_event_loop()
  83. self.state = state
  84. self._client = MQTToTClient(
  85. client_id=self._form_client_id(),
  86. clean_session=True,
  87. protocol=paho.mqtt.client.MQTTv31,
  88. transport="tcp",
  89. )
  90. try:
  91. http_proxy = urllib.request.getproxies()["http"]
  92. except KeyError:
  93. http_proxy = None
  94. if http_proxy and socks and URL:
  95. proxy_url = URL(http_proxy)
  96. proxy_type = {
  97. "http": socks.HTTP,
  98. "https": socks.HTTP,
  99. "socks": socks.SOCKS5,
  100. "socks5": socks.SOCKS5,
  101. "socks4": socks.SOCKS4,
  102. }[proxy_url.scheme]
  103. self._client.proxy_set(proxy_type=proxy_type, proxy_addr=proxy_url.host,
  104. proxy_port=proxy_url.port, proxy_username=proxy_url.user,
  105. proxy_password=proxy_url.password)
  106. self._client.enable_logger()
  107. self._client.tls_set()
  108. # mqtt.max_inflight_messages_set(20) # The rest will get queued
  109. # mqtt.max_queued_messages_set(0) # Unlimited messages can be queued
  110. # mqtt.message_retry_set(20) # Retry sending for at least 20 seconds
  111. # mqtt.reconnect_delay_set(min_delay=1, max_delay=120)
  112. self._client.connect_async("edge-mqtt.facebook.com", 443, keepalive=60)
  113. self._client.on_message = self._on_message_handler
  114. self._client.on_publish = self._on_publish_handler
  115. self._client.on_connect = self._on_connect_handler
  116. # self._client.on_disconnect = self._on_disconnect_handler
  117. self._client.on_socket_open = self._on_socket_open
  118. self._client.on_socket_close = self._on_socket_close
  119. self._client.on_socket_register_write = self._on_socket_register_write
  120. self._client.on_socket_unregister_write = self._on_socket_unregister_write
  121. def _form_client_id(self) -> bytes:
  122. subscribe_topics = [RealtimeTopic.PUBSUB, RealtimeTopic.SUB_IRIS_RESPONSE,
  123. RealtimeTopic.REALTIME_SUB, RealtimeTopic.REGION_HINT,
  124. RealtimeTopic.SEND_MESSAGE_RESPONSE, RealtimeTopic.MESSAGE_SYNC,
  125. RealtimeTopic.UNKNOWN_179, RealtimeTopic.UNKNOWN_PP]
  126. subscribe_topic_ids = [int(topic.encoded) for topic in subscribe_topics]
  127. password = f"sessionid={self.state.cookies['sessionid']}"
  128. cfg = RealtimeConfig(
  129. client_identifier=self.state.device.phone_id[:20],
  130. client_info=RealtimeClientInfo(
  131. user_id=int(self.state.user_id),
  132. user_agent=self.state.user_agent,
  133. client_capabilities=0b10110111,
  134. endpoint_capabilities=0,
  135. publish_format=1,
  136. no_automatic_foreground=True,
  137. make_user_available_in_foreground=False,
  138. device_id=self.state.device.phone_id,
  139. is_initially_foreground=True,
  140. network_type=1,
  141. network_subtype=0,
  142. client_mqtt_session_id=int(time.time() * 1000) & 0xffffffff,
  143. subscribe_topics=subscribe_topic_ids,
  144. client_type="cookie_auth",
  145. app_id=567067343352427,
  146. region_preference=self.state.session.region_hint or "LLA",
  147. device_secret="",
  148. client_stack=3,
  149. ),
  150. password=password,
  151. app_specific_info={
  152. "app_version": self.state.application.APP_VERSION,
  153. "X-IG-Capabilities": self.state.application.CAPABILITIES,
  154. "everclear_subscriptions": json.dumps(everclear_subscriptions),
  155. "User-Agent": self.state.user_agent,
  156. "Accept-Language": self.state.device.language.replace("_", "-"),
  157. "platform": "android",
  158. "ig_mqtt_route": "django",
  159. "pubsub_msg_type_blacklist": "direct, typing_type",
  160. "auth_cache_enabled": "0",
  161. },
  162. )
  163. return zlib.compress(cfg.to_thrift(), level=9)
  164. # endregion
  165. def _on_socket_open(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  166. self._loop.add_reader(sock, client.loop_read)
  167. def _on_socket_close(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  168. self._loop.remove_reader(sock)
  169. def _on_socket_register_write(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  170. self._loop.add_writer(sock, client.loop_write)
  171. def _on_socket_unregister_write(self, client: MQTToTClient, _: Any, sock: socket) -> None:
  172. self._loop.remove_writer(sock)
  173. def _on_connect_handler(self, client: MQTToTClient, _: Any, flags: Dict[str, Any], rc: int
  174. ) -> None:
  175. if rc != 0:
  176. err = paho.mqtt.client.connack_string(rc)
  177. self.log.error("MQTT Connection Error: %s (%d)", err, rc)
  178. return
  179. self._loop.create_task(self._post_connect())
  180. async def _post_connect(self) -> None:
  181. await self._dispatch(Connect())
  182. self.log.debug("Re-subscribing to things after connect")
  183. if self._graphql_subs:
  184. res = await self.graphql_subscribe(self._graphql_subs)
  185. self.log.trace("GraphQL subscribe response: %s", res)
  186. if self._skywalker_subs:
  187. res = await self.skywalker_subscribe(self._skywalker_subs)
  188. self.log.trace("Skywalker subscribe response: %s", res)
  189. if self._iris_seq_id:
  190. retry = 0
  191. while True:
  192. try:
  193. await self.iris_subscribe(self._iris_seq_id, self._iris_snapshot_at_ms)
  194. break
  195. except (asyncio.TimeoutError, IrisSubscribeError) as e:
  196. self.log.exception("Error requesting iris subscribe")
  197. retry += 1
  198. if retry >= 5:
  199. self._disconnect_error = e
  200. self.disconnect()
  201. break
  202. await asyncio.sleep(5)
  203. self.log.debug("Retrying iris subscribe")
  204. def _on_publish_handler(self, client: MQTToTClient, _: Any, mid: int) -> None:
  205. try:
  206. waiter = self._publish_waiters[mid]
  207. except KeyError:
  208. self.log.trace(f"Got publish confirmation for {mid}, but no waiters")
  209. return
  210. self.log.trace(f"Got publish confirmation for {mid}")
  211. waiter.set_result(None)
  212. # region Incoming event parsing
  213. def _parse_direct_thread_path(self, path: str) -> dict:
  214. try:
  215. blank, direct_v2, threads, thread_id, *rest = path.split("/")
  216. except (ValueError, IndexError) as e:
  217. self.log.debug(f"Got {e!r} while parsing path {path}")
  218. raise
  219. if (blank, direct_v2, threads) != ("", "direct_v2", "threads"):
  220. self.log.debug(f"Got unexpected first parts in direct thread path {path}")
  221. raise ValueError("unexpected first three parts in _parse_direct_thread_path")
  222. additional = {
  223. "thread_id": thread_id
  224. }
  225. if rest:
  226. subitem_key = rest[0]
  227. if subitem_key == "approval_required_for_new_members":
  228. additional["approval_required_for_new_members"] = True
  229. elif subitem_key == "participants" and len(rest) > 2 and rest[2] == "has_seen":
  230. additional["has_seen"] = int(rest[1])
  231. elif subitem_key == "items":
  232. additional["item_id"] = rest[1]
  233. # TODO wtf is this?
  234. # it has something to do with reactions
  235. if len(rest) > 4:
  236. additional[rest[2]] = {
  237. rest[3]: rest[4],
  238. }
  239. elif subitem_key in "admin_user_ids":
  240. additional["admin_user_id"] = int(rest[1])
  241. elif subitem_key == "activity_indicator_id":
  242. additional["activity_indicator_id"] = rest[1]
  243. self.log.trace("Parsed path %s -> %s", path, additional)
  244. return additional
  245. def _on_messager_sync_item(self, part: IrisPayloadData, parsed_item: IrisPayload) -> None:
  246. if part.path.startswith("/direct_v2/threads/"):
  247. raw_message = {
  248. "path": part.path,
  249. "op": part.op,
  250. **self._parse_direct_thread_path(part.path),
  251. }
  252. try:
  253. raw_message = {
  254. **raw_message,
  255. **json.loads(part.value),
  256. }
  257. except (json.JSONDecodeError, TypeError):
  258. raw_message["value"] = part.value
  259. message = MessageSyncMessage.deserialize(raw_message)
  260. evt = MessageSyncEvent(iris=parsed_item, message=message)
  261. elif part.path.startswith("/direct_v2/inbox/threads/"):
  262. raw_message = {
  263. "path": part.path,
  264. "op": part.op,
  265. **json.loads(part.value),
  266. }
  267. evt = ThreadSyncEvent.deserialize(raw_message)
  268. else:
  269. self.log.warning(f"Unsupported path {part.path}")
  270. return
  271. self._loop.create_task(self._dispatch(evt))
  272. def _on_message_sync(self, payload: bytes) -> None:
  273. parsed = json.loads(payload.decode("utf-8"))
  274. self.log.trace("Got message sync event: %s", parsed)
  275. for sync_item in parsed:
  276. parsed_item = IrisPayload.deserialize(sync_item)
  277. if self._iris_seq_id < parsed_item.seq_id:
  278. self.log.trace(f"Got new seq_id: {parsed_item.seq_id}")
  279. self._iris_seq_id = parsed_item.seq_id
  280. self._iris_snapshot_at_ms = int(time.time() * 1000)
  281. for part in parsed_item.data:
  282. self._on_messager_sync_item(part, parsed_item)
  283. def _on_pubsub(self, payload: bytes) -> None:
  284. parsed_thrift = IncomingMessage.from_thrift(payload)
  285. self.log.trace(f"Got pubsub event {parsed_thrift.topic} / {parsed_thrift.payload}")
  286. message = PubsubPayload.parse_json(parsed_thrift.payload)
  287. for data in message.data:
  288. match = ACTIVITY_INDICATOR_REGEX.match(data.path)
  289. if match:
  290. evt = PubsubEvent(data=data, base=message, thread_id=match.group(1),
  291. activity_indicator_id=match.group(2))
  292. self._loop.create_task(self._dispatch(evt))
  293. elif not data.double_publish:
  294. self.log.debug("Pubsub: no activity indicator on data: %s", data)
  295. else:
  296. self.log.debug("Pubsub: double publish: %s", data.path)
  297. def _parse_realtime_sub_item(self, topic: Union[str, GraphQLQueryID], raw: dict
  298. ) -> Iterable[Any]:
  299. if topic == GraphQLQueryID.APP_PRESENCE:
  300. yield AppPresenceEventPayload.deserialize(raw).presence_event
  301. elif topic == GraphQLQueryID.ZERO_PROVISION:
  302. yield RealtimeZeroProvisionPayload.deserialize(raw).zero_product_provisioning_event
  303. elif topic == GraphQLQueryID.CLIENT_CONFIG_UPDATE:
  304. yield ClientConfigUpdatePayload.deserialize(raw).client_config_update_event
  305. elif topic == GraphQLQueryID.LIVE_REALTIME_COMMENTS:
  306. yield LiveVideoCommentPayload.deserialize(raw).live_video_comment_event
  307. elif topic == "direct":
  308. event = raw["event"]
  309. for item in raw["data"]:
  310. yield RealtimeDirectEvent.deserialize({
  311. "event": event,
  312. **self._parse_direct_thread_path(item["path"]),
  313. **item,
  314. })
  315. def _on_realtime_sub(self, payload: bytes) -> None:
  316. parsed_thrift = IncomingMessage.from_thrift(payload)
  317. try:
  318. topic = GraphQLQueryID(parsed_thrift.topic)
  319. except ValueError:
  320. topic = parsed_thrift.topic
  321. self.log.trace(f"Got realtime sub event {topic} / {parsed_thrift.payload}")
  322. allowed = ("direct", GraphQLQueryID.APP_PRESENCE, GraphQLQueryID.ZERO_PROVISION,
  323. GraphQLQueryID.CLIENT_CONFIG_UPDATE, GraphQLQueryID.LIVE_REALTIME_COMMENTS)
  324. if topic not in allowed:
  325. return
  326. parsed_json = json.loads(parsed_thrift.payload)
  327. for evt in self._parse_realtime_sub_item(topic, parsed_json):
  328. self._loop.create_task(self._dispatch(evt))
  329. def _on_message_handler(self, client: MQTToTClient, _: Any, message: MQTTMessage) -> None:
  330. try:
  331. topic = RealtimeTopic.decode(message.topic)
  332. # Instagram Android MQTT messages are always compressed
  333. message.payload = zlib.decompress(message.payload)
  334. if topic == RealtimeTopic.MESSAGE_SYNC:
  335. self._on_message_sync(message.payload)
  336. elif topic == RealtimeTopic.PUBSUB:
  337. self._on_pubsub(message.payload)
  338. elif topic == RealtimeTopic.REALTIME_SUB:
  339. self._on_realtime_sub(message.payload)
  340. elif topic == RealtimeTopic.SEND_MESSAGE_RESPONSE:
  341. try:
  342. data = json.loads(message.payload.decode("utf-8"))
  343. ccid = data["payload"]["client_context"]
  344. waiter = self._message_response_waiters.pop(ccid)
  345. except KeyError as e:
  346. self.log.debug("No handler (%s) for send message response: %s",
  347. e, message.payload)
  348. else:
  349. self.log.trace("Got response to %s: %s", ccid, message.payload)
  350. waiter.set_result(message)
  351. else:
  352. try:
  353. waiter = self._response_waiters.pop(topic)
  354. except KeyError:
  355. self.log.debug("No handler for MQTT message in %s: %s",
  356. topic.value, message.payload)
  357. else:
  358. self.log.trace("Got response %s: %s", topic.value, message.payload)
  359. waiter.set_result(message)
  360. except Exception:
  361. self.log.exception("Error in incoming MQTT message handler")
  362. self.log.trace("Errored MQTT payload: %s", message.payload)
  363. # endregion
  364. async def _reconnect(self) -> None:
  365. try:
  366. self.log.trace("Trying to reconnect to MQTT")
  367. self._client.reconnect()
  368. except (SocketError, OSError, WebsocketConnectionError) as e:
  369. # TODO custom class
  370. raise MQTTNotLoggedIn("MQTT reconnection failed") from e
  371. def add_event_handler(self, evt_type: Type[T], handler: Callable[[T], Awaitable[None]]
  372. ) -> None:
  373. self._event_handlers[evt_type].append(handler)
  374. async def _dispatch(self, evt: T) -> None:
  375. for handler in self._event_handlers[type(evt)]:
  376. try:
  377. await handler(evt)
  378. except Exception:
  379. self.log.exception(f"Error in {type(evt)} handler")
  380. def disconnect(self) -> None:
  381. self._client.disconnect()
  382. async def listen(self, graphql_subs: Set[str] = None, skywalker_subs: Set[str] = None,
  383. seq_id: int = None, snapshot_at_ms: int = None, retry_limit: int = 5) -> None:
  384. self._graphql_subs = graphql_subs or set()
  385. self._skywalker_subs = skywalker_subs or set()
  386. self._iris_seq_id = seq_id
  387. self._iris_snapshot_at_ms = snapshot_at_ms
  388. self.log.debug("Connecting to Instagram MQTT")
  389. await self._reconnect()
  390. connection_retries = 0
  391. while True:
  392. try:
  393. await asyncio.sleep(1)
  394. except asyncio.CancelledError:
  395. self.disconnect()
  396. # this might not be necessary
  397. self._client.loop_misc()
  398. break
  399. rc = self._client.loop_misc()
  400. # If disconnect() has been called
  401. # Beware, internal API, may have to change this to something more stable!
  402. if self._client._state == paho.mqtt.client.mqtt_cs_disconnecting:
  403. break # Stop listening
  404. if rc != paho.mqtt.client.MQTT_ERR_SUCCESS:
  405. # If known/expected error
  406. if rc == paho.mqtt.client.MQTT_ERR_CONN_LOST:
  407. await self._dispatch(Disconnect(reason="Connection lost, retrying"))
  408. elif rc == paho.mqtt.client.MQTT_ERR_NOMEM:
  409. # This error is wrongly classified
  410. # See https://github.com/eclipse/paho.mqtt.python/issues/340
  411. await self._dispatch(Disconnect(reason="Connection lost, retrying"))
  412. elif rc == paho.mqtt.client.MQTT_ERR_CONN_REFUSED:
  413. raise MQTTNotLoggedIn("MQTT connection refused")
  414. elif rc == paho.mqtt.client.MQTT_ERR_NO_CONN:
  415. if connection_retries > retry_limit:
  416. raise MQTTNotConnected(f"Connection failed {connection_retries} times")
  417. sleep = connection_retries * 2
  418. await self._dispatch(Disconnect(reason="MQTT Error: no connection, retrying "
  419. f"in {connection_retries} seconds"))
  420. await asyncio.sleep(sleep)
  421. else:
  422. err = paho.mqtt.client.error_string(rc)
  423. self.log.error("MQTT Error: %s", err)
  424. await self._dispatch(Disconnect(reason=f"MQTT Error: {err}, retrying"))
  425. await self._reconnect()
  426. connection_retries += 1
  427. else:
  428. connection_retries = 0
  429. if self._disconnect_error:
  430. self.log.info("disconnect_error is set, raising and clearing variable")
  431. err = self._disconnect_error
  432. self._disconnect_error = None
  433. raise err
  434. # region Basic outgoing MQTT
  435. def publish(self, topic: RealtimeTopic, payload: Union[str, bytes, dict]
  436. ) -> asyncio.Future:
  437. if isinstance(payload, dict):
  438. payload = json.dumps(payload)
  439. if isinstance(payload, str):
  440. payload = payload.encode("utf-8")
  441. self.log.trace(f"Publishing message in {topic.value} ({topic.encoded}): {payload}")
  442. payload = zlib.compress(payload, level=9)
  443. info = self._client.publish(topic.encoded, payload, qos=1)
  444. self.log.trace(f"Published message ID: {info.mid}")
  445. fut = asyncio.Future()
  446. self._publish_waiters[info.mid] = fut
  447. return fut
  448. async def request(self, topic: RealtimeTopic, response: RealtimeTopic,
  449. payload: Union[str, bytes, dict], timeout: Optional[int] = None
  450. ) -> MQTTMessage:
  451. async with self._response_waiter_locks[response]:
  452. fut = asyncio.Future()
  453. self._response_waiters[response] = fut
  454. await self.publish(topic, payload)
  455. self.log.trace(f"Request published to {topic.value}, "
  456. f"waiting for response {response.name}")
  457. return await asyncio.wait_for(fut, timeout)
  458. async def iris_subscribe(self, seq_id: int, snapshot_at_ms: int) -> None:
  459. self.log.debug(f"Requesting iris subscribe {seq_id}/{snapshot_at_ms}")
  460. resp = await self.request(RealtimeTopic.SUB_IRIS, RealtimeTopic.SUB_IRIS_RESPONSE,
  461. {"seq_id": seq_id, "snapshot_at_ms": snapshot_at_ms},
  462. timeout=20 * 1000)
  463. self.log.debug("Iris subscribe response: %s", resp.payload.decode("utf-8"))
  464. resp_dict = json.loads(resp.payload.decode("utf-8"))
  465. if resp_dict["error_type"] and resp_dict["error_message"]:
  466. raise IrisSubscribeError(resp_dict["error_type"], resp_dict["error_message"])
  467. def graphql_subscribe(self, subs: Set[str]) -> asyncio.Future:
  468. self._graphql_subs |= subs
  469. return self.publish(RealtimeTopic.REALTIME_SUB, {"sub": list(subs)})
  470. def graphql_unsubscribe(self, subs: Set[str]) -> asyncio.Future:
  471. self._graphql_subs -= subs
  472. return self.publish(RealtimeTopic.REALTIME_SUB, {"unsub": list(subs)})
  473. def skywalker_subscribe(self, subs: Set[str]) -> asyncio.Future:
  474. self._skywalker_subs |= subs
  475. return self.publish(RealtimeTopic.PUBSUB, {"sub": list(subs)})
  476. def skywalker_unsubscribe(self, subs: Set[str]) -> asyncio.Future:
  477. self._skywalker_subs -= subs
  478. return self.publish(RealtimeTopic.PUBSUB, {"unsub": list(subs)})
  479. # endregion
  480. # region Actually sending messages and stuff
  481. async def send_foreground_state(self, state: ForegroundStateConfig) -> None:
  482. self.log.debug("Updating foreground state: %s", state)
  483. await self.publish(RealtimeTopic.FOREGROUND_STATE,
  484. zlib.compress(state.to_thrift(), level=9))
  485. if state.keep_alive_timeout:
  486. self._client._keepalive = state.keep_alive_timeout
  487. async def send_command(self, thread_id: str, action: ThreadAction,
  488. client_context: Optional[str] = None, **kwargs: Any
  489. ) -> Optional[CommandResponse]:
  490. client_context = client_context or self.state.gen_client_context()
  491. req = {
  492. "thread_id": thread_id,
  493. "client_context": client_context,
  494. "offline_threading_id": client_context,
  495. "action": action.value,
  496. # "device_id": self.state.cookies["ig_did"],
  497. **kwargs,
  498. }
  499. if action in (ThreadAction.MARK_SEEN,):
  500. # Some commands don't have client_context in the response, so we can't properly match
  501. # them to the requests. We probably don't need the data, so just ignore it.
  502. await self.publish(RealtimeTopic.SEND_MESSAGE, payload=req)
  503. return None
  504. else:
  505. fut = asyncio.Future()
  506. self._message_response_waiters[client_context] = fut
  507. await self.publish(RealtimeTopic.SEND_MESSAGE, req)
  508. self.log.trace(f"Request published to {RealtimeTopic.SEND_MESSAGE}, "
  509. f"waiting for response {RealtimeTopic.SEND_MESSAGE_RESPONSE}")
  510. resp = await fut
  511. return CommandResponse.parse_json(resp.payload.decode("utf-8"))
  512. def send_item(self, thread_id: str, item_type: ThreadItemType, shh_mode: bool = False,
  513. client_context: Optional[str] = None, **kwargs: Any
  514. ) -> Awaitable[CommandResponse]:
  515. return self.send_command(thread_id, item_type=item_type.value,
  516. is_shh_mode=str(int(shh_mode)), action=ThreadAction.SEND_ITEM,
  517. client_context=client_context, **kwargs)
  518. def send_hashtag(self, thread_id: str, hashtag: str, text: str = "", shh_mode: bool = False,
  519. client_context: Optional[str] = None) -> Awaitable[CommandResponse]:
  520. return self.send_item(thread_id, text=text, item_id=hashtag, shh_mode=shh_mode,
  521. item_type=ThreadItemType.HASHTAG, client_context=client_context)
  522. def send_like(self, thread_id: str, shh_mode: bool = False,
  523. client_context: Optional[str] = None) -> Awaitable[CommandResponse]:
  524. return self.send_item(thread_id, shh_mode=shh_mode, item_type=ThreadItemType.LIKE,
  525. client_context=client_context)
  526. def send_location(self, thread_id: str, venue_id: str, text: str = "",
  527. shh_mode: bool = False, client_context: Optional[str] = None
  528. ) -> Awaitable[CommandResponse]:
  529. return self.send_item(thread_id, text=text, item_id=venue_id, shh_mode=shh_mode,
  530. item_type=ThreadItemType.LOCATION, client_context=client_context)
  531. def send_media(self, thread_id: str, media_id: str, text: str = "", shh_mode: bool = False,
  532. client_context: Optional[str] = None) -> Awaitable[CommandResponse]:
  533. return self.send_item(thread_id, text=text, media_id=media_id, shh_mode=shh_mode,
  534. item_type=ThreadItemType.MEDIA_SHARE, client_context=client_context)
  535. def send_profile(self, thread_id: str, user_id: str, text: str = "", shh_mode: bool = False,
  536. client_context: Optional[str] = None) -> Awaitable[CommandResponse]:
  537. return self.send_item(thread_id, text=text, item_id=user_id, shh_mode=shh_mode,
  538. item_type=ThreadItemType.PROFILE, client_context=client_context)
  539. def send_reaction(self, thread_id: str, emoji: str, item_id: str,
  540. reaction_status: ReactionStatus = ReactionStatus.CREATED,
  541. target_item_type: ThreadItemType = ThreadItemType.TEXT,
  542. shh_mode: bool = False, client_context: Optional[str] = None
  543. ) -> Awaitable[CommandResponse]:
  544. return self.send_item(thread_id, reaction_status=reaction_status.value, node_type="item",
  545. reaction_type="like", target_item_type=target_item_type.value,
  546. emoji=emoji, item_id=item_id, reaction_action_source="double_tap",
  547. shh_mode=shh_mode, item_type=ThreadItemType.REACTION,
  548. client_context=client_context)
  549. def send_user_story(self, thread_id: str, media_id: str, text: str = "",
  550. shh_mode: bool = False, client_context: Optional[str] = None
  551. ) -> Awaitable[CommandResponse]:
  552. return self.send_item(thread_id, text=text, item_id=media_id, shh_mode=shh_mode,
  553. item_type=ThreadItemType.REEL_SHARE, client_context=client_context)
  554. def send_text(self, thread_id: str, text: str = "", shh_mode: bool = False,
  555. client_context: Optional[str] = None) -> Awaitable[CommandResponse]:
  556. return self.send_item(thread_id, text=text, shh_mode=shh_mode,
  557. item_type=ThreadItemType.TEXT, client_context=client_context)
  558. def mark_seen(self, thread_id: str, item_id: str, client_context: Optional[str] = None
  559. ) -> Awaitable[None]:
  560. return self.send_command(thread_id, item_id=item_id, action=ThreadAction.MARK_SEEN,
  561. client_context=client_context)
  562. def mark_visual_item_seen(self, thread_id: str, item_id: str,
  563. client_context: Optional[str] = None) -> Awaitable[CommandResponse]:
  564. return self.send_command(thread_id, item_id=item_id,
  565. action=ThreadAction.MARK_VISUAL_ITEM_SEEN,
  566. client_context=client_context)
  567. def indicate_activity(self, thread_id: str, activity_status: TypingStatus = TypingStatus.TEXT,
  568. client_context: Optional[str] = None) -> Awaitable[CommandResponse]:
  569. return self.send_command(thread_id, activity_status=activity_status.value,
  570. action=ThreadAction.INDICATE_ACTIVITY,
  571. client_context=client_context)
  572. # endregion