historysync.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2022 Tulir Asokan, Sumner Evans
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. "errors"
  20. "fmt"
  21. "time"
  22. _ "github.com/mattn/go-sqlite3"
  23. "go.mau.fi/util/dbutil"
  24. waProto "go.mau.fi/whatsmeow/binary/proto"
  25. "google.golang.org/protobuf/proto"
  26. log "maunium.net/go/maulogger/v2"
  27. "maunium.net/go/mautrix/id"
  28. )
  29. type HistorySyncQuery struct {
  30. db *Database
  31. log log.Logger
  32. }
  33. type HistorySyncConversation struct {
  34. db *Database
  35. log log.Logger
  36. UserID id.UserID
  37. ConversationID string
  38. PortalKey *PortalKey
  39. LastMessageTimestamp time.Time
  40. MuteEndTime time.Time
  41. Archived bool
  42. Pinned uint32
  43. DisappearingMode waProto.DisappearingMode_Initiator
  44. EndOfHistoryTransferType waProto.Conversation_EndOfHistoryTransferType
  45. EphemeralExpiration *uint32
  46. MarkedAsUnread bool
  47. UnreadCount uint32
  48. }
  49. func (hsq *HistorySyncQuery) NewConversation() *HistorySyncConversation {
  50. return &HistorySyncConversation{
  51. db: hsq.db,
  52. log: hsq.log,
  53. PortalKey: &PortalKey{},
  54. }
  55. }
  56. func (hsq *HistorySyncQuery) NewConversationWithValues(
  57. userID id.UserID,
  58. conversationID string,
  59. portalKey *PortalKey,
  60. lastMessageTimestamp,
  61. muteEndTime uint64,
  62. archived bool,
  63. pinned uint32,
  64. disappearingMode waProto.DisappearingMode_Initiator,
  65. endOfHistoryTransferType waProto.Conversation_EndOfHistoryTransferType,
  66. ephemeralExpiration *uint32,
  67. markedAsUnread bool,
  68. unreadCount uint32) *HistorySyncConversation {
  69. return &HistorySyncConversation{
  70. db: hsq.db,
  71. log: hsq.log,
  72. UserID: userID,
  73. ConversationID: conversationID,
  74. PortalKey: portalKey,
  75. LastMessageTimestamp: time.Unix(int64(lastMessageTimestamp), 0).UTC(),
  76. MuteEndTime: time.Unix(int64(muteEndTime), 0).UTC(),
  77. Archived: archived,
  78. Pinned: pinned,
  79. DisappearingMode: disappearingMode,
  80. EndOfHistoryTransferType: endOfHistoryTransferType,
  81. EphemeralExpiration: ephemeralExpiration,
  82. MarkedAsUnread: markedAsUnread,
  83. UnreadCount: unreadCount,
  84. }
  85. }
  86. const (
  87. getNMostRecentConversations = `
  88. SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count
  89. FROM history_sync_conversation
  90. WHERE user_mxid=$1
  91. ORDER BY last_message_timestamp DESC
  92. LIMIT $2
  93. `
  94. getConversationByPortal = `
  95. SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count
  96. FROM history_sync_conversation
  97. WHERE user_mxid=$1
  98. AND portal_jid=$2
  99. AND portal_receiver=$3
  100. `
  101. )
  102. func (hsc *HistorySyncConversation) Upsert() {
  103. _, err := hsc.db.Exec(`
  104. INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count)
  105. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
  106. ON CONFLICT (user_mxid, conversation_id)
  107. DO UPDATE SET
  108. last_message_timestamp=CASE
  109. WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp
  110. ELSE history_sync_conversation.last_message_timestamp
  111. END,
  112. end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type
  113. `,
  114. hsc.UserID,
  115. hsc.ConversationID,
  116. hsc.PortalKey.JID.String(),
  117. hsc.PortalKey.Receiver.String(),
  118. hsc.LastMessageTimestamp,
  119. hsc.Archived,
  120. hsc.Pinned,
  121. hsc.MuteEndTime,
  122. hsc.DisappearingMode,
  123. hsc.EndOfHistoryTransferType,
  124. hsc.EphemeralExpiration,
  125. hsc.MarkedAsUnread,
  126. hsc.UnreadCount)
  127. if err != nil {
  128. hsc.log.Warnfln("Failed to insert history sync conversation %s/%s: %v", hsc.UserID, hsc.ConversationID, err)
  129. }
  130. }
  131. func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConversation {
  132. err := row.Scan(
  133. &hsc.UserID,
  134. &hsc.ConversationID,
  135. &hsc.PortalKey.JID,
  136. &hsc.PortalKey.Receiver,
  137. &hsc.LastMessageTimestamp,
  138. &hsc.Archived,
  139. &hsc.Pinned,
  140. &hsc.MuteEndTime,
  141. &hsc.DisappearingMode,
  142. &hsc.EndOfHistoryTransferType,
  143. &hsc.EphemeralExpiration,
  144. &hsc.MarkedAsUnread,
  145. &hsc.UnreadCount)
  146. if err != nil {
  147. if !errors.Is(err, sql.ErrNoRows) {
  148. hsc.log.Errorln("Database scan failed:", err)
  149. }
  150. return nil
  151. }
  152. return hsc
  153. }
  154. func (hsq *HistorySyncQuery) GetRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) {
  155. nPtr := &n
  156. // Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit.
  157. if n < 0 && hsq.db.Dialect == dbutil.Postgres {
  158. nPtr = nil
  159. }
  160. rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr)
  161. defer rows.Close()
  162. if err != nil || rows == nil {
  163. return nil
  164. }
  165. for rows.Next() {
  166. conversations = append(conversations, hsq.NewConversation().Scan(rows))
  167. }
  168. return
  169. }
  170. func (hsq *HistorySyncQuery) GetConversation(userID id.UserID, portalKey PortalKey) (conversation *HistorySyncConversation) {
  171. rows, err := hsq.db.Query(getConversationByPortal, userID, portalKey.JID, portalKey.Receiver)
  172. defer rows.Close()
  173. if err != nil || rows == nil {
  174. return nil
  175. }
  176. if rows.Next() {
  177. conversation = hsq.NewConversation().Scan(rows)
  178. }
  179. return
  180. }
  181. func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) {
  182. _, err := hsq.db.Exec("DELETE FROM history_sync_conversation WHERE user_mxid=$1", userID)
  183. if err != nil {
  184. hsq.log.Warnfln("Failed to delete historical chat info for %s/%s: %v", userID, err)
  185. }
  186. }
  187. const (
  188. getMessagesBetween = `
  189. SELECT data FROM history_sync_message
  190. WHERE user_mxid=$1 AND conversation_id=$2
  191. %s
  192. ORDER BY timestamp DESC
  193. %s
  194. `
  195. deleteMessagesBetweenExclusive = `
  196. DELETE FROM history_sync_message
  197. WHERE user_mxid=$1 AND conversation_id=$2 AND timestamp<$3 AND timestamp>$4
  198. `
  199. )
  200. type HistorySyncMessage struct {
  201. db *Database
  202. log log.Logger
  203. UserID id.UserID
  204. ConversationID string
  205. MessageID string
  206. Timestamp time.Time
  207. Data []byte
  208. }
  209. func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversationID, messageID string, message *waProto.HistorySyncMsg) (*HistorySyncMessage, error) {
  210. msgData, err := proto.Marshal(message)
  211. if err != nil {
  212. return nil, err
  213. }
  214. return &HistorySyncMessage{
  215. db: hsq.db,
  216. log: hsq.log,
  217. UserID: userID,
  218. ConversationID: conversationID,
  219. MessageID: messageID,
  220. Timestamp: time.Unix(int64(message.Message.GetMessageTimestamp()), 0),
  221. Data: msgData,
  222. }, nil
  223. }
  224. func (hsm *HistorySyncMessage) Insert() error {
  225. _, err := hsm.db.Exec(`
  226. INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time)
  227. VALUES ($1, $2, $3, $4, $5, $6)
  228. ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
  229. `, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now())
  230. return err
  231. }
  232. func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) {
  233. whereClauses := ""
  234. args := []interface{}{userID, conversationID}
  235. argNum := 3
  236. if startTime != nil {
  237. whereClauses += fmt.Sprintf(" AND timestamp >= $%d", argNum)
  238. args = append(args, startTime)
  239. argNum++
  240. }
  241. if endTime != nil {
  242. whereClauses += fmt.Sprintf(" AND timestamp <= $%d", argNum)
  243. args = append(args, endTime)
  244. }
  245. limitClause := ""
  246. if limit > 0 {
  247. limitClause = fmt.Sprintf("LIMIT %d", limit)
  248. }
  249. rows, err := hsq.db.Query(fmt.Sprintf(getMessagesBetween, whereClauses, limitClause), args...)
  250. defer rows.Close()
  251. if err != nil || rows == nil {
  252. if err != nil && !errors.Is(err, sql.ErrNoRows) {
  253. hsq.log.Warnfln("Failed to query messages between range: %v", err)
  254. }
  255. return nil
  256. }
  257. var msgData []byte
  258. for rows.Next() {
  259. err = rows.Scan(&msgData)
  260. if err != nil {
  261. hsq.log.Errorfln("Database scan failed: %v", err)
  262. continue
  263. }
  264. var historySyncMsg waProto.HistorySyncMsg
  265. err = proto.Unmarshal(msgData, &historySyncMsg)
  266. if err != nil {
  267. hsq.log.Errorfln("Failed to unmarshal history sync message: %v", err)
  268. continue
  269. }
  270. messages = append(messages, historySyncMsg.Message)
  271. }
  272. return
  273. }
  274. func (hsq *HistorySyncQuery) DeleteMessages(userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error {
  275. newest := messages[0]
  276. beforeTS := time.Unix(int64(newest.GetMessageTimestamp())+1, 0)
  277. oldest := messages[len(messages)-1]
  278. afterTS := time.Unix(int64(oldest.GetMessageTimestamp())-1, 0)
  279. _, err := hsq.db.Exec(deleteMessagesBetweenExclusive, userID, conversationID, beforeTS, afterTS)
  280. return err
  281. }
  282. func (hsq *HistorySyncQuery) DeleteAllMessages(userID id.UserID) {
  283. _, err := hsq.db.Exec("DELETE FROM history_sync_message WHERE user_mxid=$1", userID)
  284. if err != nil {
  285. hsq.log.Warnfln("Failed to delete historical messages for %s: %v", userID, err)
  286. }
  287. }
  288. func (hsq *HistorySyncQuery) DeleteAllMessagesForPortal(userID id.UserID, portalKey PortalKey) {
  289. _, err := hsq.db.Exec(`
  290. DELETE FROM history_sync_message
  291. WHERE user_mxid=$1 AND conversation_id=$2
  292. `, userID, portalKey.JID)
  293. if err != nil {
  294. hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, portalKey.JID, err)
  295. }
  296. }
  297. func (hsq *HistorySyncQuery) DeleteConversation(userID id.UserID, jid string) {
  298. // This will also clear history_sync_message as there's a foreign key constraint
  299. _, err := hsq.db.Exec(`
  300. DELETE FROM history_sync_conversation
  301. WHERE user_mxid=$1 AND conversation_id=$2
  302. `, userID, jid)
  303. if err != nil {
  304. hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, jid, err)
  305. }
  306. }