|
@@ -69,6 +69,7 @@ class AndroidMQTT:
|
|
|
_publish_waiters: Dict[int, asyncio.Future]
|
|
|
_response_waiters: Dict[RealtimeTopic, asyncio.Future]
|
|
|
_response_waiter_locks: Dict[RealtimeTopic, asyncio.Lock]
|
|
|
+ _message_response_waiters: Dict[str, asyncio.Future]
|
|
|
_disconnect_error: Optional[Exception]
|
|
|
_event_handlers: Dict[Type[T], List[Callable[[T], Awaitable[None]]]]
|
|
|
|
|
@@ -82,6 +83,7 @@ class AndroidMQTT:
|
|
|
self._iris_snapshot_at_ms = None
|
|
|
self._publish_waiters = {}
|
|
|
self._response_waiters = {}
|
|
|
+ self._message_response_waiters = {}
|
|
|
self._disconnect_error = None
|
|
|
self._response_waiter_locks = defaultdict(lambda: asyncio.Lock())
|
|
|
self._event_handlers = defaultdict(lambda: [])
|
|
@@ -349,14 +351,25 @@ class AndroidMQTT:
|
|
|
self._on_pubsub(message.payload)
|
|
|
elif topic == RealtimeTopic.REALTIME_SUB:
|
|
|
self._on_realtime_sub(message.payload)
|
|
|
+ elif topic == RealtimeTopic.SEND_MESSAGE_RESPONSE:
|
|
|
+ try:
|
|
|
+ data = json.loads(message.payload.decode("utf-8"))
|
|
|
+ ccid = data["payload"]["client_context"]
|
|
|
+ waiter = self._message_response_waiters.pop(ccid)
|
|
|
+ except KeyError as e:
|
|
|
+ self.log.debug("No handler (%s) for send message response: %s",
|
|
|
+ e, message.payload)
|
|
|
+ else:
|
|
|
+ self.log.trace("Got response to %s: %s", ccid, message.payload)
|
|
|
+ waiter.set_result(message)
|
|
|
else:
|
|
|
- self.log.trace("Other message payload: %s", message.payload)
|
|
|
try:
|
|
|
waiter = self._response_waiters.pop(topic)
|
|
|
except KeyError:
|
|
|
self.log.debug("No handler for MQTT message in %s: %s",
|
|
|
topic.value, message.payload)
|
|
|
else:
|
|
|
+ self.log.trace("Got response %s: %s", topic.value, message.payload)
|
|
|
waiter.set_result(message)
|
|
|
except Exception:
|
|
|
self.log.exception("Error in incoming MQTT message handler")
|
|
@@ -506,7 +519,7 @@ class AndroidMQTT:
|
|
|
async def send_command(self, thread_id: str, action: ThreadAction,
|
|
|
client_context: Optional[str] = None,
|
|
|
offline_threading_id: Optional[str] = None, **kwargs: Any
|
|
|
- ) -> CommandResponse:
|
|
|
+ ) -> Optional[CommandResponse]:
|
|
|
client_context = client_context or str(uuid4())
|
|
|
req = {
|
|
|
"thread_id": thread_id,
|
|
@@ -516,9 +529,19 @@ class AndroidMQTT:
|
|
|
# "device_id": self.state.cookies["ig_did"],
|
|
|
**kwargs,
|
|
|
}
|
|
|
- resp = await self.request(RealtimeTopic.SEND_MESSAGE, RealtimeTopic.SEND_MESSAGE_RESPONSE,
|
|
|
- payload=req)
|
|
|
- return CommandResponse.parse_json(resp.payload.decode("utf-8"))
|
|
|
+ if action in (ThreadAction.MARK_SEEN,):
|
|
|
+ # Some commands don't have client_context in the response, so we can't properly match
|
|
|
+ # them to the requests. We probably don't need the data, so just ignore it.
|
|
|
+ await self.publish(RealtimeTopic.SEND_MESSAGE, payload=req)
|
|
|
+ return None
|
|
|
+ else:
|
|
|
+ fut = asyncio.Future()
|
|
|
+ self._message_response_waiters[client_context] = fut
|
|
|
+ await self.publish(RealtimeTopic.SEND_MESSAGE, req)
|
|
|
+ self.log.trace(f"Request published to {RealtimeTopic.SEND_MESSAGE}, "
|
|
|
+ f"waiting for response {RealtimeTopic.SEND_MESSAGE_RESPONSE}")
|
|
|
+ resp = await fut
|
|
|
+ return CommandResponse.parse_json(resp.payload.decode("utf-8"))
|
|
|
|
|
|
def send_item(self, thread_id: str, item_type: ThreadItemType, shh_mode: bool = False,
|
|
|
client_context: Optional[str] = None, offline_threading_id: Optional[str] = None,
|
|
@@ -591,7 +614,7 @@ class AndroidMQTT:
|
|
|
offline_threading_id=offline_threading_id)
|
|
|
|
|
|
def mark_seen(self, thread_id: str, item_id: str, client_context: Optional[str] = None,
|
|
|
- offline_threading_id: Optional[str] = None) -> Awaitable[CommandResponse]:
|
|
|
+ offline_threading_id: Optional[str] = None) -> Awaitable[None]:
|
|
|
return self.send_command(thread_id, item_id=item_id, action=ThreadAction.MARK_SEEN,
|
|
|
client_context=client_context,
|
|
|
offline_threading_id=offline_threading_id)
|