message.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # mautrix-signal - A Matrix-Signal puppeting bridge
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from __future__ import annotations
  17. from typing import TYPE_CHECKING, ClassVar
  18. from uuid import UUID
  19. from attr import dataclass
  20. import asyncpg
  21. from mausignald.types import GroupID
  22. from mautrix.types import EventID, RoomID
  23. from mautrix.util.async_db import Database, Scheme
  24. from .util import ensure_uuid
  25. fake_db = Database.create("") if TYPE_CHECKING else None
  26. @dataclass
  27. class Message:
  28. db: ClassVar[Database] = fake_db
  29. mxid: EventID
  30. mx_room: RoomID
  31. sender: UUID
  32. timestamp: int
  33. signal_chat_id: GroupID | UUID
  34. signal_receiver: str
  35. async def insert(self) -> None:
  36. q = """
  37. INSERT INTO message (mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver)
  38. VALUES ($1, $2, $3, $4, $5, $6)
  39. """
  40. await self.db.execute(
  41. q,
  42. self.mxid,
  43. self.mx_room,
  44. self.sender,
  45. self.timestamp,
  46. str(self.signal_chat_id),
  47. self.signal_receiver,
  48. )
  49. async def delete(self) -> None:
  50. q = """
  51. DELETE FROM message
  52. WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
  53. """
  54. await self.db.execute(
  55. q,
  56. self.sender,
  57. self.timestamp,
  58. str(self.signal_chat_id),
  59. self.signal_receiver,
  60. )
  61. @classmethod
  62. async def delete_all(cls, room_id: RoomID) -> None:
  63. await cls.db.execute("DELETE FROM message WHERE mx_room=$1", room_id)
  64. @classmethod
  65. def _from_row(cls, row: asyncpg.Record | None) -> Message | None:
  66. if row is None:
  67. return None
  68. data = {**row}
  69. chat_id = data.pop("signal_chat_id")
  70. if data["signal_receiver"]:
  71. chat_id = ensure_uuid(chat_id)
  72. sender = ensure_uuid(data.pop("sender"))
  73. return cls(signal_chat_id=chat_id, sender=sender, **data)
  74. @classmethod
  75. async def get_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> Message | None:
  76. q = """
  77. SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
  78. WHERE mxid=$1 AND mx_room=$2
  79. """
  80. return cls._from_row(await cls.db.fetchrow(q, mxid, mx_room))
  81. @classmethod
  82. async def get_by_signal_id(
  83. cls,
  84. sender: UUID,
  85. timestamp: int,
  86. signal_chat_id: GroupID | UUID,
  87. signal_receiver: str = "",
  88. ) -> Message | None:
  89. q = """
  90. SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
  91. WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4
  92. """
  93. return cls._from_row(
  94. await cls.db.fetchrow(q, sender, timestamp, str(signal_chat_id), signal_receiver)
  95. )
  96. @classmethod
  97. async def find_by_timestamps(cls, timestamps: list[int]) -> list[Message]:
  98. if cls.db.scheme in (Scheme.POSTGRES, Scheme.COCKROACH):
  99. q = """
  100. SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
  101. WHERE timestamp=ANY($1)
  102. """
  103. rows = await cls.db.fetch(q, timestamps)
  104. else:
  105. placeholders = ", ".join("?" for _ in range(len(timestamps)))
  106. q = f"""
  107. SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
  108. WHERE timestamp IN ({placeholders})
  109. """
  110. rows = await cls.db.fetch(q, *timestamps)
  111. return [cls._from_row(row) for row in rows]
  112. @classmethod
  113. async def find_by_sender_timestamp(cls, sender: UUID, timestamp: int) -> Message | None:
  114. q = """
  115. SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
  116. WHERE sender=$1 AND timestamp=$2
  117. """
  118. return cls._from_row(await cls.db.fetchrow(q, sender, timestamp))
  119. @classmethod
  120. async def get_first_before(cls, mx_room: RoomID, timestamp: int) -> Message | None:
  121. q = """
  122. SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message
  123. WHERE mx_room=$1 AND timestamp <= $2
  124. ORDER BY timestamp DESC
  125. LIMIT 1
  126. """
  127. return cls._from_row(await cls.db.fetchrow(q, mx_room, timestamp))