message.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2021 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. "errors"
  20. "strings"
  21. "time"
  22. log "maunium.net/go/maulogger/v2"
  23. "maunium.net/go/mautrix/id"
  24. "go.mau.fi/whatsmeow/types"
  25. )
  26. type MessageQuery struct {
  27. db *Database
  28. log log.Logger
  29. }
  30. func (mq *MessageQuery) New() *Message {
  31. return &Message{
  32. db: mq.db,
  33. log: mq.log,
  34. }
  35. }
  36. const (
  37. getAllMessagesQuery = `
  38. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  39. WHERE chat_jid=$1 AND chat_receiver=$2
  40. `
  41. getMessageByJIDQuery = `
  42. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  43. WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
  44. `
  45. getMessageByMXIDQuery = `
  46. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  47. WHERE mxid=$1
  48. `
  49. getLastMessageInChatQuery = `
  50. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  51. WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
  52. `
  53. getFirstMessageInChatQuery = `
  54. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  55. WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
  56. `
  57. getMessagesBetweenQuery = `
  58. SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
  59. WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC
  60. `
  61. )
  62. func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
  63. rows, err := mq.db.Query(getAllMessagesQuery, chat.JID, chat.Receiver)
  64. if err != nil || rows == nil {
  65. return nil
  66. }
  67. for rows.Next() {
  68. messages = append(messages, mq.New().Scan(rows))
  69. }
  70. return
  71. }
  72. func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.MessageID) *Message {
  73. return mq.maybeScan(mq.db.QueryRow(getMessageByJIDQuery, chat.JID, chat.Receiver, jid))
  74. }
  75. func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message {
  76. return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid))
  77. }
  78. func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
  79. return mq.GetLastInChatBefore(chat, time.Now().Add(60*time.Second))
  80. }
  81. func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Time) *Message {
  82. msg := mq.maybeScan(mq.db.QueryRow(getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix()))
  83. if msg == nil || msg.Timestamp.IsZero() {
  84. // Old db, we don't know what the last message is.
  85. return nil
  86. }
  87. return msg
  88. }
  89. func (mq *MessageQuery) GetFirstInChat(chat PortalKey) *Message {
  90. return mq.maybeScan(mq.db.QueryRow(getFirstMessageInChatQuery, chat.JID, chat.Receiver))
  91. }
  92. func (mq *MessageQuery) GetMessagesBetween(chat PortalKey, minTimestamp, maxTimestamp time.Time) (messages []*Message) {
  93. rows, err := mq.db.Query(getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix())
  94. if err != nil || rows == nil {
  95. return nil
  96. }
  97. for rows.Next() {
  98. messages = append(messages, mq.New().Scan(rows))
  99. }
  100. return
  101. }
  102. func (mq *MessageQuery) maybeScan(row *sql.Row) *Message {
  103. if row == nil {
  104. return nil
  105. }
  106. return mq.New().Scan(row)
  107. }
  108. type MessageErrorType string
  109. const (
  110. MsgNoError MessageErrorType = ""
  111. MsgErrDecryptionFailed MessageErrorType = "decryption_failed"
  112. MsgErrMediaNotFound MessageErrorType = "media_not_found"
  113. )
  114. type MessageType string
  115. const (
  116. MsgUnknown MessageType = ""
  117. MsgFake MessageType = "fake"
  118. MsgNormal MessageType = "message"
  119. MsgReaction MessageType = "reaction"
  120. )
  121. type Message struct {
  122. db *Database
  123. log log.Logger
  124. Chat PortalKey
  125. JID types.MessageID
  126. MXID id.EventID
  127. Sender types.JID
  128. Timestamp time.Time
  129. Sent bool
  130. Type MessageType
  131. Error MessageErrorType
  132. BroadcastListJID types.JID
  133. }
  134. func (msg *Message) IsFakeMXID() bool {
  135. return strings.HasPrefix(msg.MXID.String(), "net.maunium.whatsapp.fake::")
  136. }
  137. func (msg *Message) IsFakeJID() bool {
  138. return strings.HasPrefix(msg.JID, "FAKE::") || msg.JID == string(msg.MXID)
  139. }
  140. func (msg *Message) Scan(row Scannable) *Message {
  141. var ts int64
  142. err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
  143. if err != nil {
  144. if !errors.Is(err, sql.ErrNoRows) {
  145. msg.log.Errorln("Database scan failed:", err)
  146. }
  147. return nil
  148. }
  149. if ts != 0 {
  150. msg.Timestamp = time.Unix(ts, 0)
  151. }
  152. return msg
  153. }
  154. func (msg *Message) Insert(txn *sql.Tx) {
  155. var sender interface{} = msg.Sender
  156. // Slightly hacky hack to allow inserting empty senders (used for post-backfill dummy events)
  157. if msg.Sender.IsEmpty() {
  158. sender = ""
  159. }
  160. query := `
  161. INSERT INTO message
  162. (chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid)
  163. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
  164. `
  165. args := []interface{}{
  166. msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID,
  167. }
  168. var err error
  169. if txn != nil {
  170. _, err = txn.Exec(query, args...)
  171. } else {
  172. _, err = msg.db.Exec(query, args...)
  173. }
  174. if err != nil {
  175. msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
  176. }
  177. }
  178. func (msg *Message) MarkSent(ts time.Time) {
  179. msg.Sent = true
  180. msg.Timestamp = ts
  181. _, err := msg.db.Exec("UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4", ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID)
  182. if err != nil {
  183. msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
  184. }
  185. }
  186. func (msg *Message) UpdateMXID(txn *sql.Tx, mxid id.EventID, newType MessageType, newError MessageErrorType) {
  187. msg.MXID = mxid
  188. msg.Type = newType
  189. msg.Error = newError
  190. query := "UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6"
  191. args := []interface{}{mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID}
  192. var err error
  193. if txn != nil {
  194. _, err = txn.Exec(query, args...)
  195. } else {
  196. _, err = msg.db.Exec(query, args...)
  197. }
  198. if err != nil {
  199. msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
  200. }
  201. }
  202. func (msg *Message) Delete() {
  203. _, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID)
  204. if err != nil {
  205. msg.log.Warnfln("Failed to delete %s@%s: %v", msg.Chat, msg.JID, err)
  206. }
  207. }