Browse Source

Track remote state and subscribe failure more correctly

Max Sandholm 2 years ago
parent
commit
a432383445
2 changed files with 14 additions and 4 deletions
  1. 2 0
      mautrix_signal/signal.py
  2. 12 4
      mautrix_signal/user.py

+ 2 - 0
mautrix_signal/signal.py

@@ -415,6 +415,8 @@ class SignalHandler(SignaldClient):
                     f"Successfully subscribed {user.username}, running sync in background"
                     f"Successfully subscribed {user.username}, running sync in background"
                 )
                 )
                 background_task.create(user.sync())
                 background_task.create(user.sync())
+            else:
+                user.username = None
         if self.delete_unknown_accounts:
         if self.delete_unknown_accounts:
             self.log.debug("Checking for unknown accounts to delete")
             self.log.debug("Checking for unknown accounts to delete")
             for account in await self.list_accounts():
             for account in await self.list_accounts():

+ 12 - 4
mautrix_signal/user.py

@@ -73,6 +73,7 @@ class User(DBUser, BaseUser):
     _sync_lock: asyncio.Lock
     _sync_lock: asyncio.Lock
     _notice_room_lock: asyncio.Lock
     _notice_room_lock: asyncio.Lock
     _connected: bool
     _connected: bool
+    _state_id: str
     _websocket_connection_state: BridgeStateEvent | None
     _websocket_connection_state: BridgeStateEvent | None
     _latest_non_transient_bridge_state: datetime | None
     _latest_non_transient_bridge_state: datetime | None
 
 
@@ -88,6 +89,7 @@ class User(DBUser, BaseUser):
         self._notice_room_lock = asyncio.Lock()
         self._notice_room_lock = asyncio.Lock()
         self._sync_lock = asyncio.Lock()
         self._sync_lock = asyncio.Lock()
         self._connected = False
         self._connected = False
+        self._state_id = self.username
         self._websocket_connection_state = None
         self._websocket_connection_state = None
         perms = self.config.get_permissions(mxid)
         perms = self.config.get_permissions(mxid)
         self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
         self.relay_whitelisted, self.is_whitelisted, self.is_admin, self.permission_level = perms
@@ -134,7 +136,7 @@ class User(DBUser, BaseUser):
     async def fill_bridge_state(self, state: BridgeState) -> None:
     async def fill_bridge_state(self, state: BridgeState) -> None:
         await super().fill_bridge_state(state)
         await super().fill_bridge_state(state)
         if not state.remote_id:
         if not state.remote_id:
-            state.remote_id = self.username
+            state.remote_id = self._state_id
         if self.address:
         if self.address:
             puppet = await self.get_puppet()
             puppet = await self.get_puppet()
             state.remote_name = puppet.name or self.username
             state.remote_name = puppet.name or self.username
@@ -151,6 +153,7 @@ class User(DBUser, BaseUser):
 
 
     async def handle_auth_failure(self, e: Exception) -> None:
     async def handle_auth_failure(self, e: Exception) -> None:
         if isinstance(e, AuthorizationFailedError):
         if isinstance(e, AuthorizationFailedError):
+            self.username = None
             await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error=str(e))
             await self.push_bridge_state(BridgeStateEvent.BAD_CREDENTIALS, error=str(e))
 
 
     async def get_puppet(self) -> pu.Puppet | None:
     async def get_puppet(self) -> pu.Puppet | None:
@@ -168,9 +171,12 @@ class User(DBUser, BaseUser):
         self.uuid = account.address.uuid
         self.uuid = account.address.uuid
         self._add_to_cache()
         self._add_to_cache()
         await self.update()
         await self.update()
-        await self.bridge.signal.subscribe(self.username)
-        background_task.create(self.sync())
-        self._track_metric(METRIC_LOGGED_IN, True)
+        if await self.bridge.signal.subscribe(self.username):
+            background_task.create(self.sync())
+            self._track_metric(METRIC_LOGGED_IN, True)
+            self.log.debug("Successfully subscribed")
+        else:
+            self.username = None
 
 
     def on_websocket_connection_state_change(
     def on_websocket_connection_state_change(
         self, evt: WebsocketConnectionStateChangeEvent
         self, evt: WebsocketConnectionStateChangeEvent
@@ -255,6 +261,8 @@ class User(DBUser, BaseUser):
             self.log.info("Websocket state unchanged, not reporting new bridge state")
             self.log.info("Websocket state unchanged, not reporting new bridge state")
             self._latest_non_transient_bridge_state = now
             self._latest_non_transient_bridge_state = now
         else:
         else:
+            if bridge_state == BridgeStateEvent.BAD_CREDENTIALS:
+                self.username = None
             background_task.create(self.push_bridge_state(bridge_state))
             background_task.create(self.push_bridge_state(bridge_state))
             self._latest_non_transient_bridge_state = now
             self._latest_non_transient_bridge_state = now
             self._websocket_connection_state = bridge_state
             self._websocket_connection_state = bridge_state