Browse Source

Fix read receipt responses sometimes breaking message sending

Tulir Asokan 4 years ago
parent
commit
43da674c11
1 changed files with 29 additions and 6 deletions
  1. 29 6
      mauigpapi/mqtt/conn.py

+ 29 - 6
mauigpapi/mqtt/conn.py

@@ -69,6 +69,7 @@ class AndroidMQTT:
     _publish_waiters: Dict[int, asyncio.Future]
     _publish_waiters: Dict[int, asyncio.Future]
     _response_waiters: Dict[RealtimeTopic, asyncio.Future]
     _response_waiters: Dict[RealtimeTopic, asyncio.Future]
     _response_waiter_locks: Dict[RealtimeTopic, asyncio.Lock]
     _response_waiter_locks: Dict[RealtimeTopic, asyncio.Lock]
+    _message_response_waiters: Dict[str, asyncio.Future]
     _disconnect_error: Optional[Exception]
     _disconnect_error: Optional[Exception]
     _event_handlers: Dict[Type[T], List[Callable[[T], Awaitable[None]]]]
     _event_handlers: Dict[Type[T], List[Callable[[T], Awaitable[None]]]]
 
 
@@ -82,6 +83,7 @@ class AndroidMQTT:
         self._iris_snapshot_at_ms = None
         self._iris_snapshot_at_ms = None
         self._publish_waiters = {}
         self._publish_waiters = {}
         self._response_waiters = {}
         self._response_waiters = {}
+        self._message_response_waiters = {}
         self._disconnect_error = None
         self._disconnect_error = None
         self._response_waiter_locks = defaultdict(lambda: asyncio.Lock())
         self._response_waiter_locks = defaultdict(lambda: asyncio.Lock())
         self._event_handlers = defaultdict(lambda: [])
         self._event_handlers = defaultdict(lambda: [])
@@ -349,14 +351,25 @@ class AndroidMQTT:
                 self._on_pubsub(message.payload)
                 self._on_pubsub(message.payload)
             elif topic == RealtimeTopic.REALTIME_SUB:
             elif topic == RealtimeTopic.REALTIME_SUB:
                 self._on_realtime_sub(message.payload)
                 self._on_realtime_sub(message.payload)
+            elif topic == RealtimeTopic.SEND_MESSAGE_RESPONSE:
+                try:
+                    data = json.loads(message.payload.decode("utf-8"))
+                    ccid = data["payload"]["client_context"]
+                    waiter = self._message_response_waiters.pop(ccid)
+                except KeyError as e:
+                    self.log.debug("No handler (%s) for send message response: %s",
+                                   e, message.payload)
+                else:
+                    self.log.trace("Got response to %s: %s", ccid, message.payload)
+                    waiter.set_result(message)
             else:
             else:
-                self.log.trace("Other message payload: %s", message.payload)
                 try:
                 try:
                     waiter = self._response_waiters.pop(topic)
                     waiter = self._response_waiters.pop(topic)
                 except KeyError:
                 except KeyError:
                     self.log.debug("No handler for MQTT message in %s: %s",
                     self.log.debug("No handler for MQTT message in %s: %s",
                                    topic.value, message.payload)
                                    topic.value, message.payload)
                 else:
                 else:
+                    self.log.trace("Got response %s: %s", topic.value, message.payload)
                     waiter.set_result(message)
                     waiter.set_result(message)
         except Exception:
         except Exception:
             self.log.exception("Error in incoming MQTT message handler")
             self.log.exception("Error in incoming MQTT message handler")
@@ -506,7 +519,7 @@ class AndroidMQTT:
     async def send_command(self, thread_id: str, action: ThreadAction,
     async def send_command(self, thread_id: str, action: ThreadAction,
                            client_context: Optional[str] = None,
                            client_context: Optional[str] = None,
                            offline_threading_id: Optional[str] = None, **kwargs: Any
                            offline_threading_id: Optional[str] = None, **kwargs: Any
-                           ) -> CommandResponse:
+                           ) -> Optional[CommandResponse]:
         client_context = client_context or str(uuid4())
         client_context = client_context or str(uuid4())
         req = {
         req = {
             "thread_id": thread_id,
             "thread_id": thread_id,
@@ -516,9 +529,19 @@ class AndroidMQTT:
             # "device_id": self.state.cookies["ig_did"],
             # "device_id": self.state.cookies["ig_did"],
             **kwargs,
             **kwargs,
         }
         }
-        resp = await self.request(RealtimeTopic.SEND_MESSAGE, RealtimeTopic.SEND_MESSAGE_RESPONSE,
-                                  payload=req)
-        return CommandResponse.parse_json(resp.payload.decode("utf-8"))
+        if action in (ThreadAction.MARK_SEEN,):
+            # Some commands don't have client_context in the response, so we can't properly match
+            # them to the requests. We probably don't need the data, so just ignore it.
+            await self.publish(RealtimeTopic.SEND_MESSAGE, payload=req)
+            return None
+        else:
+            fut = asyncio.Future()
+            self._message_response_waiters[client_context] = fut
+            await self.publish(RealtimeTopic.SEND_MESSAGE, req)
+            self.log.trace(f"Request published to {RealtimeTopic.SEND_MESSAGE}, "
+                           f"waiting for response {RealtimeTopic.SEND_MESSAGE_RESPONSE}")
+            resp = await fut
+            return CommandResponse.parse_json(resp.payload.decode("utf-8"))
 
 
     def send_item(self, thread_id: str, item_type: ThreadItemType, shh_mode: bool = False,
     def send_item(self, thread_id: str, item_type: ThreadItemType, shh_mode: bool = False,
                   client_context: Optional[str] = None, offline_threading_id: Optional[str] = None,
                   client_context: Optional[str] = None, offline_threading_id: Optional[str] = None,
@@ -591,7 +614,7 @@ class AndroidMQTT:
                               offline_threading_id=offline_threading_id)
                               offline_threading_id=offline_threading_id)
 
 
     def mark_seen(self, thread_id: str, item_id: str, client_context: Optional[str] = None,
     def mark_seen(self, thread_id: str, item_id: str, client_context: Optional[str] = None,
-                  offline_threading_id: Optional[str] = None) -> Awaitable[CommandResponse]:
+                  offline_threading_id: Optional[str] = None) -> Awaitable[None]:
         return self.send_command(thread_id, item_id=item_id, action=ThreadAction.MARK_SEEN,
         return self.send_command(thread_id, item_id=item_id, action=ThreadAction.MARK_SEEN,
                                  client_context=client_context,
                                  client_context=client_context,
                                  offline_threading_id=offline_threading_id)
                                  offline_threading_id=offline_threading_id)