Эх сурвалжийг харах

Only send one message at a time to catch all responses

Tulir Asokan 2 жил өмнө
parent
commit
5ee1c0d1af

+ 43 - 24
mauigpapi/mqtt/conn.py

@@ -91,7 +91,9 @@ class AndroidMQTT:
     _publish_waiters: dict[int, asyncio.Future]
     _response_waiters: dict[RealtimeTopic, asyncio.Future]
     _response_waiter_locks: dict[RealtimeTopic, asyncio.Lock]
-    _message_response_waiters: dict[str, asyncio.Future]
+    _message_response_waiter_lock: asyncio.Lock
+    _message_response_waiter_id: str | None
+    _message_response_waiter: asyncio.Future | None
     _disconnect_error: Exception | None
     _event_handlers: dict[Type[T], list[Callable[[T], Awaitable[None]]]]
     _outgoing_events: asyncio.Queue
@@ -112,7 +114,9 @@ class AndroidMQTT:
         self._iris_snapshot_at_ms = None
         self._publish_waiters = {}
         self._response_waiters = {}
-        self._message_response_waiters = {}
+        self._message_response_waiter_lock = asyncio.Lock()
+        self._message_response_waiter_id = None
+        self._message_response_waiter = None
         self._disconnect_error = None
         self._response_waiter_locks = defaultdict(lambda: asyncio.Lock())
         self._event_handlers = defaultdict(lambda: [])
@@ -449,6 +453,30 @@ class AndroidMQTT:
         for evt in self._parse_realtime_sub_item(topic, parsed_json):
             self._loop.create_task(self._dispatch(evt))
 
+    def _handle_send_response(self, message: pmc.MQTTMessage) -> None:
+        data = json.loads(message.payload.decode("utf-8"))
+        try:
+            ccid = data["payload"]["client_context"]
+        except KeyError:
+            self.log.warning(
+                "Didn't find client_context in send message response: %s", message.payload
+            )
+            ccid = self._message_response_waiter_id
+        else:
+            if ccid != self._message_response_waiter_id:
+                self.log.error(
+                    "Mismatching client_context in send message response (%s != %s)",
+                    ccid,
+                    self._message_response_waiter_id,
+                )
+                return
+        if self._message_response_waiter and not self._message_response_waiter.done():
+            self.log.debug("Got response to %s: %s", ccid, message.payload)
+            self._message_response_waiter.set_result(message)
+            self._message_response_waiter = None
+        else:
+            self.log.warning("Didn't find task waiting for response %s", message.payload)
+
     def _on_message_handler(self, client: MQTToTClient, _: Any, message: pmc.MQTTMessage) -> None:
         try:
             topic = RealtimeTopic.decode(message.topic)
@@ -461,19 +489,7 @@ class AndroidMQTT:
             elif topic == RealtimeTopic.REALTIME_SUB:
                 self._on_realtime_sub(message.payload)
             elif topic == RealtimeTopic.SEND_MESSAGE_RESPONSE:
-                try:
-                    data = json.loads(message.payload.decode("utf-8"))
-                    ccid = data["payload"]["client_context"]
-                    waiter = self._message_response_waiters.pop(ccid)
-                except KeyError as e:
-                    self.log.debug(
-                        "No handler for send message response: %s (missing %s key)",
-                        message.payload,
-                        e,
-                    )
-                else:
-                    self.log.trace("Got response to %s: %s", ccid, message.payload)
-                    waiter.set_result(message)
+                self._handle_send_response(message)
             else:
                 try:
                     waiter = self._response_waiters.pop(topic)
@@ -708,20 +724,23 @@ class AndroidMQTT:
             # "device_id": self.state.cookies["ig_did"],
             **kwargs,
         }
-        if action in (ThreadAction.MARK_SEEN,):
-            # Some commands don't have client_context in the response, so we can't properly match
-            # them to the requests. We probably don't need the data, so just ignore it.
-            await self.publish(RealtimeTopic.SEND_MESSAGE, payload=req)
-            return None
-        else:
-            fut = asyncio.Future()
-            self._message_response_waiters[client_context] = fut
+        lock_start = time.monotonic()
+        async with self._message_response_waiter_lock:
+            lock_wait_dur = time.monotonic() - lock_start
+            if lock_wait_dur > 1:
+                self.log.debug(f"Waited {lock_wait_dur:.3f} seconds to send {client_context}")
+            fut = self._message_response_waiter = asyncio.Future()
+            self._message_response_waiter_id = client_context
             await self.publish(RealtimeTopic.SEND_MESSAGE, req)
             self.log.trace(
                 f"Request published to {RealtimeTopic.SEND_MESSAGE}, "
                 f"waiting for response {RealtimeTopic.SEND_MESSAGE_RESPONSE}"
             )
-            resp = await fut
+            try:
+                resp = await asyncio.wait_for(fut, timeout=30000)
+            except asyncio.TimeoutError:
+                self.log.error(f"Request with ID {client_context} timed out!")
+                raise
             return CommandResponse.parse_json(resp.payload.decode("utf-8"))
 
     def send_item(

+ 14 - 2
mauigpapi/types/mqtt.py

@@ -62,10 +62,22 @@ class CommandResponsePayload(SerializableAttrs):
 
 @dataclass(kw_only=True)
 class CommandResponse(SerializableAttrs):
-    action: str
+    action: Optional[str] = None
     status: str
     status_code: Optional[str] = None
-    payload: CommandResponsePayload
+    message: Optional[str] = None
+    payload: Optional[CommandResponsePayload] = None
+
+    @property
+    def error_message(self) -> Optional[str]:
+        if self.status == "ok":
+            return None
+        if self.payload and self.payload.message:
+            return self.payload.message
+        elif self.message:
+            return self.message
+        else:
+            return "unknown response data"
 
 
 @dataclass(kw_only=True)

+ 2 - 2
mautrix_instagram/portal.py

@@ -558,7 +558,7 @@ class Portal(DBPortal, BasePortal):
         self.log.trace(f"Got response to message send {request_id}: {resp}")
         if resp.status != "ok":
             self.log.warning(f"Failed to handle {event_id}: {resp}")
-            raise Exception(f"Failed to handle event. Error: {resp.payload.message}")
+            raise Exception(f"Failed to handle event. Error: {resp.error_message}")
         else:
             self._msgid_dedup.appendleft(resp.payload.item_id)
             try:
@@ -621,7 +621,7 @@ class Portal(DBPortal, BasePortal):
                 self.thread_id, item_id=message.item_id, emoji=emoji
             )
             if resp.status != "ok":
-                if resp.payload.message == "invalid unicode emoji":
+                if resp.payload and resp.payload.message == "invalid unicode emoji":
                     # Instagram doesn't support this reaction. Notify the user, and redact it
                     # so that it doesn't get confusing.
                     await self.main_intent.redact(self.mxid, event_id, reason="Unsupported emoji")