فهرست منبع

Handle finding conflicting numbers in puppet insert

Tulir Asokan 2 سال پیش
والد
کامیت
21312b3012
2فایلهای تغییر یافته به همراه27 افزوده شده و 11 حذف شده
  1. 14 7
      mautrix_signal/db/puppet.py
  2. 13 4
      mautrix_signal/puppet.py

+ 14 - 7
mautrix_signal/db/puppet.py

@@ -23,7 +23,7 @@ from yarl import URL
 import asyncpg
 
 from mautrix.types import ContentURI, SyncToken, UserID
-from mautrix.util.async_db import Database
+from mautrix.util.async_db import Connection, Database
 
 fake_db = Database.create("") if TYPE_CHECKING else None
 
@@ -69,6 +69,13 @@ class Puppet:
             self._base_url_str,
         )
 
+    async def _delete_existing_number(self, conn: Connection) -> None:
+        if not self.number:
+            return
+        await conn.execute(
+            "UPDATE puppet SET number=null WHERE number=$1 AND uuid<>$2", self.number, self.uuid
+        )
+
     async def insert(self) -> None:
         q = """
         INSERT INTO puppet (uuid, number, name, name_quality, avatar_hash, avatar_url,
@@ -76,14 +83,14 @@ class Puppet:
                             custom_mxid, access_token, next_batch, base_url)
         VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
         """
-        await self.db.execute(q, *self._values)
+        async with self.db.acquire() as conn, conn.transaction():
+            await self._delete_existing_number(conn)
+            await self.db.execute(q, *self._values)
 
-    async def _set_number(self, number: str) -> None:
+    async def _update_number(self) -> None:
         async with self.db.acquire() as conn, conn.transaction():
-            await conn.execute(
-                "UPDATE puppet SET number=null WHERE number=$1 AND uuid<>$2", number, self.uuid
-            )
-            await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", number, self.uuid)
+            await self._delete_existing_number(conn)
+            await conn.execute("UPDATE puppet SET number=$1 WHERE uuid=$2", self.number, self.uuid)
 
     async def update(self) -> None:
         q = """

+ 13 - 4
mautrix_signal/puppet.py

@@ -154,8 +154,8 @@ class Puppet(DBPuppet, BasePuppet):
             if self.number:
                 self.by_number.pop(self.number, None)
             self.number = number
-            self.by_number[self.number] = self
-            await self._set_number(number)
+            self._add_number_to_cache()
+            await self._update_number()
 
     async def _migrate_memberships(self, prev_intent: IntentAPI, new_intent: IntentAPI) -> None:
         self.log.debug(f"Migrating memberships {prev_intent.mxid} -> {new_intent.mxid}")
@@ -334,10 +334,19 @@ class Puppet(DBPuppet, BasePuppet):
 
     # region Database getters
 
-    def _add_to_cache(self) -> None:
-        self.by_uuid[self.uuid] = self
+    def _add_number_to_cache(self) -> None:
         if self.number:
+            try:
+                existing = self.by_number[self.number]
+                if existing and existing.uuid != self.uuid and existing != self:
+                    existing.number = None
+            except KeyError:
+                pass
             self.by_number[self.number] = self
+
+    def _add_to_cache(self) -> None:
+        self.by_uuid[self.uuid] = self
+        self._add_number_to_cache()
         if self.custom_mxid:
             self.by_custom_mxid[self.custom_mxid] = self