|
@@ -3,13 +3,13 @@
|
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
-from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, TypeVar, Type
|
|
|
+from typing import Union, Optional, List, Dict, Any, Callable, Awaitable, Set, TypeVar, Type
|
|
|
from uuid import uuid4
|
|
|
import asyncio
|
|
|
|
|
|
from mautrix.util.logging import TraceLogger
|
|
|
|
|
|
-from .rpc import SignaldRPCClient
|
|
|
+from .rpc import CONNECT_EVENT, SignaldRPCClient
|
|
|
from .errors import UnexpectedError, UnexpectedResponse, make_linking_error
|
|
|
from .types import (Address, Quote, Attachment, Reaction, Account, Message, Contact, Group,
|
|
|
Profile, GroupID, GetIdentitiesResponse, ListenEvent, ListenAction, GroupV2)
|
|
@@ -20,16 +20,19 @@ EventHandler = Callable[[T], Awaitable[None]]
|
|
|
|
|
|
class SignaldClient(SignaldRPCClient):
|
|
|
_event_handlers: Dict[Type[T], List[EventHandler]]
|
|
|
+ _subscriptions: Set[str]
|
|
|
|
|
|
def __init__(self, socket_path: str = "/var/run/signald/signald.sock",
|
|
|
log: Optional[TraceLogger] = None,
|
|
|
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
|
|
super().__init__(socket_path, log, loop)
|
|
|
self._event_handlers = {}
|
|
|
+ self._subscriptions = set()
|
|
|
self.add_rpc_handler("message", self._parse_message)
|
|
|
self.add_rpc_handler("listen_started", self._parse_listen_start)
|
|
|
self.add_rpc_handler("listen_stopped", self._parse_listen_stop)
|
|
|
self.add_rpc_handler("version", self._log_version)
|
|
|
+ self.add_rpc_handler(CONNECT_EVENT, self._resubscribe)
|
|
|
|
|
|
def add_event_handler(self, event_class: Type[T], handler: EventHandler) -> None:
|
|
|
self._event_handlers.setdefault(event_class, []).append(handler)
|
|
@@ -75,6 +78,7 @@ class SignaldClient(SignaldRPCClient):
|
|
|
async def subscribe(self, username: str) -> bool:
|
|
|
try:
|
|
|
await self.request("subscribe", "subscribed", username=username)
|
|
|
+ self._subscriptions.add(username)
|
|
|
return True
|
|
|
except UnexpectedError as e:
|
|
|
self.log.debug("Failed to subscribe to %s: %s", username, e)
|
|
@@ -83,11 +87,18 @@ class SignaldClient(SignaldRPCClient):
|
|
|
async def unsubscribe(self, username: str) -> bool:
|
|
|
try:
|
|
|
await self.request("unsubscribe", "unsubscribed", username=username)
|
|
|
+ self._subscriptions.remove(username)
|
|
|
return True
|
|
|
except UnexpectedError as e:
|
|
|
self.log.debug("Failed to unsubscribe from %s: %s", username, e)
|
|
|
return False
|
|
|
|
|
|
+ async def _resubscribe(self, unused_data: Dict[str, Any]) -> None:
|
|
|
+ if self._subscriptions:
|
|
|
+ self.log.debug("Resubscribing to users")
|
|
|
+ for username in list(self._subscriptions):
|
|
|
+ await self.subscribe(username)
|
|
|
+
|
|
|
async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
|
|
|
) -> str:
|
|
|
resp = await self.request("register", "verification_required", username=phone,
|