message.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2021 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. "errors"
  20. "strings"
  21. "time"
  22. log "maunium.net/go/maulogger/v2"
  23. "maunium.net/go/mautrix/id"
  24. "maunium.net/go/mautrix/util/dbutil"
  25. "go.mau.fi/whatsmeow/types"
  26. )
  27. type MessageQuery struct {
  28. db *Database
  29. log log.Logger
  30. }
  31. func (mq *MessageQuery) New() *Message {
  32. return &Message{
  33. db: mq.db,
  34. log: mq.log,
  35. }
  36. }
  37. const (
  38. getAllMessagesQuery = `
  39. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  40. WHERE chat_jid=$1 AND chat_receiver=$2
  41. `
  42. getMessageByJIDQuery = `
  43. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  44. WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
  45. `
  46. getMessageByMXIDQuery = `
  47. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  48. WHERE mxid=$1
  49. `
  50. getLastMessageInChatQuery = `
  51. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  52. WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
  53. `
  54. getFirstMessageInChatQuery = `
  55. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  56. WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
  57. `
  58. getMessagesBetweenQuery = `
  59. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  60. WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC
  61. `
  62. )
  63. func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
  64. rows, err := mq.db.Query(getAllMessagesQuery, chat.JID, chat.Receiver)
  65. if err != nil || rows == nil {
  66. return nil
  67. }
  68. for rows.Next() {
  69. messages = append(messages, mq.New().Scan(rows))
  70. }
  71. return
  72. }
  73. func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.MessageID) *Message {
  74. return mq.maybeScan(mq.db.QueryRow(getMessageByJIDQuery, chat.JID, chat.Receiver, jid))
  75. }
  76. func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message {
  77. return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid))
  78. }
  79. func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
  80. return mq.GetLastInChatBefore(chat, time.Now().Add(60*time.Second))
  81. }
  82. func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Time) *Message {
  83. msg := mq.maybeScan(mq.db.QueryRow(getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix()))
  84. if msg == nil || msg.Timestamp.IsZero() {
  85. // Old db, we don't know what the last message is.
  86. return nil
  87. }
  88. return msg
  89. }
  90. func (mq *MessageQuery) GetFirstInChat(chat PortalKey) *Message {
  91. return mq.maybeScan(mq.db.QueryRow(getFirstMessageInChatQuery, chat.JID, chat.Receiver))
  92. }
  93. func (mq *MessageQuery) GetMessagesBetween(chat PortalKey, minTimestamp, maxTimestamp time.Time) (messages []*Message) {
  94. rows, err := mq.db.Query(getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix())
  95. if err != nil || rows == nil {
  96. return nil
  97. }
  98. for rows.Next() {
  99. messages = append(messages, mq.New().Scan(rows))
  100. }
  101. return
  102. }
  103. func (mq *MessageQuery) maybeScan(row *sql.Row) *Message {
  104. if row == nil {
  105. return nil
  106. }
  107. return mq.New().Scan(row)
  108. }
  109. type MessageErrorType string
  110. const (
  111. MsgNoError MessageErrorType = ""
  112. MsgErrDecryptionFailed MessageErrorType = "decryption_failed"
  113. MsgErrMediaNotFound MessageErrorType = "media_not_found"
  114. )
  115. type MessageType string
  116. const (
  117. MsgUnknown MessageType = ""
  118. MsgFake MessageType = "fake"
  119. MsgNormal MessageType = "message"
  120. MsgReaction MessageType = "reaction"
  121. MsgEdit MessageType = "edit"
  122. MsgMatrixPoll MessageType = "matrix-poll"
  123. )
  124. type Message struct {
  125. db *Database
  126. log log.Logger
  127. Chat PortalKey
  128. JID types.MessageID
  129. MXID id.EventID
  130. Sender types.JID
  131. Timestamp time.Time
  132. Sent bool
  133. Type MessageType
  134. Error MessageErrorType
  135. BroadcastListJID types.JID
  136. }
  137. func (msg *Message) IsFakeMXID() bool {
  138. return strings.HasPrefix(msg.MXID.String(), "net.maunium.whatsapp.fake::")
  139. }
  140. func (msg *Message) IsFakeJID() bool {
  141. return strings.HasPrefix(msg.JID, "FAKE::") || msg.JID == string(msg.MXID)
  142. }
  143. func (msg *Message) Scan(row dbutil.Scannable) *Message {
  144. var ts int64
  145. err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
  146. if err != nil {
  147. if !errors.Is(err, sql.ErrNoRows) {
  148. msg.log.Errorln("Database scan failed:", err)
  149. }
  150. return nil
  151. }
  152. if ts != 0 {
  153. msg.Timestamp = time.Unix(ts, 0)
  154. }
  155. return msg
  156. }
  157. func (msg *Message) Insert(txn dbutil.Execable) {
  158. if txn == nil {
  159. txn = msg.db
  160. }
  161. var sender interface{} = msg.Sender
  162. // Slightly hacky hack to allow inserting empty senders (used for post-backfill dummy events)
  163. if msg.Sender.IsEmpty() {
  164. sender = ""
  165. }
  166. _, err := txn.Exec(`
  167. INSERT INTO message
  168. (chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid)
  169. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
  170. `, msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
  171. if err != nil {
  172. msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
  173. }
  174. }
  175. func (msg *Message) MarkSent(ts time.Time) {
  176. msg.Sent = true
  177. msg.Timestamp = ts
  178. _, err := msg.db.Exec("UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4", ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID)
  179. if err != nil {
  180. msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
  181. }
  182. }
  183. func (msg *Message) UpdateMXID(txn dbutil.Execable, mxid id.EventID, newType MessageType, newError MessageErrorType) {
  184. if txn == nil {
  185. txn = msg.db
  186. }
  187. msg.MXID = mxid
  188. msg.Type = newType
  189. msg.Error = newError
  190. _, err := txn.Exec("UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6",
  191. mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
  192. if err != nil {
  193. msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
  194. }
  195. }
  196. func (msg *Message) Delete() {
  197. _, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID)
  198. if err != nil {
  199. msg.log.Warnfln("Failed to delete %s@%s: %v", msg.Chat, msg.JID, err)
  200. }
  201. }