Parcourir la source

Add logout command/api

Tulir Asokan il y a 4 ans
Parent
commit
0b1c63f16d

+ 8 - 0
mausignald/signald.py

@@ -80,6 +80,14 @@ class SignaldClient(SignaldRPCClient):
             self.log.debug("Failed to subscribe to %s: %s", username, e)
             return False
 
+    async def unsubscribe(self, username: str) -> bool:
+        try:
+            await self.request("unsubscribe", "unsubscribed", username=username)
+            return True
+        except UnexpectedError as e:
+            self.log.debug("Failed to unsubscribe from %s: %s", username, e)
+            return False
+
     async def register(self, phone: str, voice: bool = False, captcha: Optional[str] = None
                        ) -> str:
         resp = await self.request("register", "verification_required", username=phone,

+ 10 - 0
mautrix_signal/commands/auth.py

@@ -101,6 +101,16 @@ async def enter_register_code(evt: CommandEvent) -> None:
         await evt.reply(f"Successfully logged in as {pu.Puppet.fmt_phone(evt.sender.username)}")
 
 
+@command_handler(needs_auth=True, management_only=True, help_section=SECTION_AUTH,
+                 help_text="Remove all local data about your Signal link")
+async def logout(evt: CommandEvent) -> None:
+    if not evt.sender.username:
+        await evt.reply("You're not logged in")
+        return
+    await evt.sender.logout()
+    await evt.reply("Successfully logged out")
+
+
 @command_handler(needs_auth=True, management_only=True, help_args="<_access token_>",
                  help_section=SECTION_AUTH, help_text="Replace your Signal account's Matrix puppet"
                                                       " with your Matrix account")

+ 1 - 0
mautrix_signal/config.py

@@ -50,6 +50,7 @@ class Config(BaseBridgeConfig):
         copy("signal.socket_path")
         copy("signal.outgoing_attachment_dir")
         copy("signal.avatar_dir")
+        copy("signal.data_dir")
         copy("signal.remove_file_after_handling")
 
         copy("metrics.enabled")

+ 2 - 0
mautrix_signal/example-config.yaml

@@ -66,6 +66,8 @@ signal:
     outgoing_attachment_dir: /tmp
     # Directory where signald stores avatars for groups.
     avatar_dir: ~/.config/signald/avatars
+    # Directory where signald stores auth data. Used to delete data when logging out.
+    data_dir: ~/.config/signald/data
     # Whether or not message attachments should be removed from disk after they're bridged.
     remove_file_after_handling: true
 

+ 23 - 1
mautrix_signal/user.py

@@ -13,10 +13,12 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Dict, Optional, AsyncGenerator, Union, TYPE_CHECKING, cast
+from typing import Dict, Optional, AsyncGenerator, TYPE_CHECKING, cast
 from collections import defaultdict
 from uuid import UUID
 import asyncio
+import os.path
+import shutil
 
 from mausignald.types import Account, Address, Contact, Group, GroupV2, ListenEvent, ListenAction
 from mautrix.bridge import BaseUser
@@ -74,6 +76,26 @@ class User(DBUser, BaseUser):
     async def is_logged_in(self) -> bool:
         return bool(self.username)
 
+    async def logout(self) -> None:
+        if not self.username:
+            return
+        username = self.username
+        self.username = None
+        self.uuid = None
+        await self.update()
+        await self.bridge.signal.unsubscribe(username)
+        # Wait a while for signald to finish disconnecting
+        await asyncio.sleep(1)
+        path = os.path.join(self.config["signal.data_dir"], username)
+        extra_dir = f"{path}.d/"
+        try:
+            self.log.debug("Removing %s", path)
+            os.remove(path)
+        except FileNotFoundError as e:
+            self.log.warning(f"Failed to remove signald data file: {e}")
+        self.log.debug("Removing %s", extra_dir)
+        shutil.rmtree(extra_dir, ignore_errors=True)
+
     async def on_signin(self, account: Account) -> None:
         self.username = account.username
         self.uuid = account.uuid

+ 10 - 7
mautrix_signal/web/provisioning_api.py

@@ -45,12 +45,12 @@ class ProvisioningAPI:
         self.app.router.add_options("/api/link/wait", self.login_options)
         # self.app.router.add_options("/api/register", self.login_options)
         # self.app.router.add_options("/api/register/code", self.login_options)
-        # self.app.router.add_options("/api/logout", self.login_options)
+        self.app.router.add_options("/api/logout", self.login_options)
         self.app.router.add_post("/api/link", self.link)
         self.app.router.add_post("/api/link/wait", self.link_wait)
         # self.app.router.add_post("/api/register", self.register)
         # self.app.router.add_post("/api/register/code", self.register_code)
-        # self.app.router.add_post("/api/logout", self.logout)
+        self.app.router.add_post("/api/logout", self.logout)
 
     @property
     def _acao_headers(self) -> Dict[str, str]:
@@ -106,7 +106,7 @@ class ProvisioningAPI:
             data["signal"] = {
                 "number": number or user.username,
                 "uuid": str(uuid or user.uuid or ""),
-                "name": profile.name if profile else null,
+                "name": profile.name if profile else None,
             }
         return web.json_response(data, headers=self._acao_headers)
 
@@ -151,7 +151,10 @@ class ProvisioningAPI:
             "uuid": str(account.uuid),
         })
 
-    # async def logout(self, request: web.Request) -> web.Response:
-    #     user = await self.check_token(request)
-    #     await user.()
-    #     return web.json_response({}, headers=self._acao_headers)
+    async def logout(self, request: web.Request) -> web.Response:
+        user = await self.check_token(request)
+        if not user.username:
+            raise web.HTTPNotFound(text='''{"error": "You're not logged in"}''',
+                                   headers=self._headers)
+        await user.logout()
+        return web.json_response({}, headers=self._acao_headers)