historysync.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2022 Tulir Asokan, Sumner Evans
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. "errors"
  20. "fmt"
  21. "strconv"
  22. "strings"
  23. "time"
  24. waProto "go.mau.fi/whatsmeow/binary/proto"
  25. "google.golang.org/protobuf/proto"
  26. _ "github.com/mattn/go-sqlite3"
  27. log "maunium.net/go/maulogger/v2"
  28. "maunium.net/go/mautrix/id"
  29. )
  30. type HistorySyncQuery struct {
  31. db *Database
  32. log log.Logger
  33. }
  34. type HistorySyncConversation struct {
  35. db *Database
  36. log log.Logger
  37. UserID id.UserID
  38. ConversationID string
  39. PortalKey *PortalKey
  40. LastMessageTimestamp time.Time
  41. MuteEndTime time.Time
  42. Archived bool
  43. Pinned uint32
  44. DisappearingMode waProto.DisappearingMode_DisappearingModeInitiator
  45. EndOfHistoryTransferType waProto.Conversation_ConversationEndOfHistoryTransferType
  46. EphemeralExpiration *uint32
  47. MarkedAsUnread bool
  48. UnreadCount uint32
  49. }
  50. func (hsq *HistorySyncQuery) NewConversation() *HistorySyncConversation {
  51. return &HistorySyncConversation{
  52. db: hsq.db,
  53. log: hsq.log,
  54. PortalKey: &PortalKey{},
  55. }
  56. }
  57. func (hsq *HistorySyncQuery) NewConversationWithValues(
  58. userID id.UserID,
  59. conversationID string,
  60. portalKey *PortalKey,
  61. lastMessageTimestamp,
  62. muteEndTime uint64,
  63. archived bool,
  64. pinned uint32,
  65. disappearingMode waProto.DisappearingMode_DisappearingModeInitiator,
  66. endOfHistoryTransferType waProto.Conversation_ConversationEndOfHistoryTransferType,
  67. ephemeralExpiration *uint32,
  68. markedAsUnread bool,
  69. unreadCount uint32) *HistorySyncConversation {
  70. return &HistorySyncConversation{
  71. db: hsq.db,
  72. log: hsq.log,
  73. UserID: userID,
  74. ConversationID: conversationID,
  75. PortalKey: portalKey,
  76. LastMessageTimestamp: time.Unix(int64(lastMessageTimestamp), 0),
  77. MuteEndTime: time.Unix(int64(muteEndTime), 0),
  78. Archived: archived,
  79. Pinned: pinned,
  80. DisappearingMode: disappearingMode,
  81. EndOfHistoryTransferType: endOfHistoryTransferType,
  82. EphemeralExpiration: ephemeralExpiration,
  83. MarkedAsUnread: markedAsUnread,
  84. UnreadCount: unreadCount,
  85. }
  86. }
  87. const (
  88. getNMostRecentConversations = `
  89. SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count
  90. FROM history_sync_conversation
  91. WHERE user_mxid=$1
  92. ORDER BY last_message_timestamp DESC
  93. LIMIT $2
  94. `
  95. getConversationByPortal = `
  96. SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count
  97. FROM history_sync_conversation
  98. WHERE user_mxid=$1
  99. AND portal_jid=$2
  100. AND portal_receiver=$3
  101. `
  102. )
  103. func (hsc *HistorySyncConversation) Upsert() {
  104. _, err := hsc.db.Exec(`
  105. INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count)
  106. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
  107. ON CONFLICT (user_mxid, conversation_id)
  108. DO UPDATE SET
  109. portal_jid=EXCLUDED.portal_jid,
  110. portal_receiver=EXCLUDED.portal_receiver,
  111. last_message_timestamp=CASE
  112. WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp
  113. ELSE history_sync_conversation.last_message_timestamp
  114. END,
  115. archived=EXCLUDED.archived,
  116. pinned=EXCLUDED.pinned,
  117. mute_end_time=EXCLUDED.mute_end_time,
  118. disappearing_mode=EXCLUDED.disappearing_mode,
  119. end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type,
  120. ephemeral_expiration=EXCLUDED.ephemeral_expiration,
  121. marked_as_unread=EXCLUDED.marked_as_unread,
  122. unread_count=EXCLUDED.unread_count
  123. `,
  124. hsc.UserID,
  125. hsc.ConversationID,
  126. hsc.PortalKey.JID.String(),
  127. hsc.PortalKey.Receiver.String(),
  128. hsc.LastMessageTimestamp,
  129. hsc.Archived,
  130. hsc.Pinned,
  131. hsc.MuteEndTime,
  132. hsc.DisappearingMode,
  133. hsc.EndOfHistoryTransferType,
  134. hsc.EphemeralExpiration,
  135. hsc.MarkedAsUnread,
  136. hsc.UnreadCount)
  137. if err != nil {
  138. hsc.log.Warnfln("Failed to insert history sync conversation %s/%s: %v", hsc.UserID, hsc.ConversationID, err)
  139. }
  140. }
  141. func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation {
  142. err := row.Scan(
  143. &hsc.UserID,
  144. &hsc.ConversationID,
  145. &hsc.PortalKey.JID,
  146. &hsc.PortalKey.Receiver,
  147. &hsc.LastMessageTimestamp,
  148. &hsc.Archived,
  149. &hsc.Pinned,
  150. &hsc.MuteEndTime,
  151. &hsc.DisappearingMode,
  152. &hsc.EndOfHistoryTransferType,
  153. &hsc.EphemeralExpiration,
  154. &hsc.MarkedAsUnread,
  155. &hsc.UnreadCount)
  156. if err != nil {
  157. if !errors.Is(err, sql.ErrNoRows) {
  158. hsc.log.Errorln("Database scan failed:", err)
  159. }
  160. return nil
  161. }
  162. return hsc
  163. }
  164. func (hsq *HistorySyncQuery) GetNMostRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) {
  165. rows, err := hsq.db.Query(getNMostRecentConversations, userID, n)
  166. defer rows.Close()
  167. if err != nil || rows == nil {
  168. return nil
  169. }
  170. for rows.Next() {
  171. conversations = append(conversations, hsq.NewConversation().Scan(rows))
  172. }
  173. return
  174. }
  175. func (hsq *HistorySyncQuery) GetConversation(userID id.UserID, portalKey *PortalKey) (conversation *HistorySyncConversation) {
  176. rows, err := hsq.db.Query(getConversationByPortal, userID, portalKey.JID, portalKey.Receiver)
  177. defer rows.Close()
  178. if err != nil || rows == nil {
  179. return nil
  180. }
  181. if rows.Next() {
  182. conversation = hsq.NewConversation().Scan(rows)
  183. }
  184. return
  185. }
  186. func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) error {
  187. _, err := hsq.db.Exec("DELETE FROM history_sync_conversation WHERE user_mxid=$1", userID)
  188. return err
  189. }
  190. const (
  191. getMessagesBetween = `
  192. SELECT id, data
  193. FROM history_sync_message
  194. WHERE user_mxid=$1
  195. AND conversation_id=$2
  196. %s
  197. ORDER BY timestamp DESC
  198. %s
  199. `
  200. deleteMessages = `
  201. DELETE FROM history_sync_message
  202. WHERE id IN (%s)
  203. `
  204. )
  205. type HistorySyncMessage struct {
  206. db *Database
  207. log log.Logger
  208. ID int
  209. UserID id.UserID
  210. ConversationID string
  211. Timestamp time.Time
  212. Data []byte
  213. }
  214. type WrappedWebMessageInfo struct {
  215. ID int
  216. Message *waProto.WebMessageInfo
  217. }
  218. func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversationID string, message *waProto.HistorySyncMsg) (*HistorySyncMessage, error) {
  219. msgData, err := proto.Marshal(message)
  220. if err != nil {
  221. return nil, err
  222. }
  223. return &HistorySyncMessage{
  224. db: hsq.db,
  225. log: hsq.log,
  226. UserID: userID,
  227. ConversationID: conversationID,
  228. Timestamp: time.Unix(int64(message.Message.GetMessageTimestamp()), 0),
  229. Data: msgData,
  230. }, nil
  231. }
  232. func (hsm *HistorySyncMessage) Insert() {
  233. _, err := hsm.db.Exec(`
  234. INSERT INTO history_sync_message (user_mxid, conversation_id, timestamp, data)
  235. VALUES ($1, $2, $3, $4)
  236. `, hsm.UserID, hsm.ConversationID, hsm.Timestamp, hsm.Data)
  237. if err != nil {
  238. hsm.log.Warnfln("Failed to insert history sync message %s/%s: %v", hsm.ConversationID, hsm.Timestamp, err)
  239. }
  240. }
  241. func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*WrappedWebMessageInfo) {
  242. whereClauses := ""
  243. args := []interface{}{userID, conversationID}
  244. argNum := 3
  245. if startTime != nil {
  246. whereClauses += fmt.Sprintf(" AND timestamp >= $%d", argNum)
  247. args = append(args, startTime)
  248. argNum++
  249. }
  250. if endTime != nil {
  251. whereClauses += fmt.Sprintf(" AND timestamp <= $%d", argNum)
  252. args = append(args, endTime)
  253. }
  254. limitClause := ""
  255. if limit > 0 {
  256. limitClause = fmt.Sprintf("LIMIT %d", limit)
  257. }
  258. rows, err := hsq.db.Query(fmt.Sprintf(getMessagesBetween, whereClauses, limitClause), args...)
  259. defer rows.Close()
  260. if err != nil || rows == nil {
  261. return nil
  262. }
  263. var msgID int
  264. var msgData []byte
  265. for rows.Next() {
  266. err := rows.Scan(&msgID, &msgData)
  267. if err != nil {
  268. hsq.log.Error("Database scan failed: %v", err)
  269. continue
  270. }
  271. var historySyncMsg waProto.HistorySyncMsg
  272. err = proto.Unmarshal(msgData, &historySyncMsg)
  273. if err != nil {
  274. hsq.log.Errorf("Failed to unmarshal history sync message: %v", err)
  275. continue
  276. }
  277. messages = append(messages, &WrappedWebMessageInfo{
  278. ID: msgID,
  279. Message: historySyncMsg.Message,
  280. })
  281. }
  282. return
  283. }
  284. func (hsq *HistorySyncQuery) DeleteMessages(messages []*WrappedWebMessageInfo) error {
  285. messageIDs := make([]string, len(messages))
  286. for i, msg := range messages {
  287. messageIDs[i] = strconv.Itoa(msg.ID)
  288. }
  289. _, err := hsq.db.Exec(fmt.Sprintf(deleteMessages, strings.Join(messageIDs, ",")))
  290. return err
  291. }
  292. func (hsq *HistorySyncQuery) DeleteAllMessages(userID id.UserID) error {
  293. _, err := hsq.db.Exec("DELETE FROM history_sync_message WHERE user_mxid=$1", userID)
  294. return err
  295. }