message.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2021 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. "errors"
  20. "fmt"
  21. "strings"
  22. "time"
  23. "go.mau.fi/util/dbutil"
  24. "go.mau.fi/whatsmeow/types"
  25. log "maunium.net/go/maulogger/v2"
  26. "maunium.net/go/mautrix/id"
  27. )
  28. type MessageQuery struct {
  29. db *Database
  30. log log.Logger
  31. }
  32. func (mq *MessageQuery) New() *Message {
  33. return &Message{
  34. db: mq.db,
  35. log: mq.log,
  36. }
  37. }
  38. const (
  39. getAllMessagesQuery = `
  40. SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
  41. WHERE chat_jid=$1 AND chat_receiver=$2
  42. `
  43. getMessageByJIDQuery = `
  44. SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
  45. WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
  46. `
  47. getMessageByMXIDQuery = `
  48. SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
  49. WHERE mxid=$1
  50. `
  51. getLastMessageInChatQuery = `
  52. SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
  53. WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
  54. `
  55. getFirstMessageInChatQuery = `
  56. SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
  57. WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
  58. `
  59. getMessagesBetweenQuery = `
  60. SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
  61. WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC
  62. `
  63. )
  64. func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
  65. rows, err := mq.db.Query(getAllMessagesQuery, chat.JID, chat.Receiver)
  66. if err != nil || rows == nil {
  67. return nil
  68. }
  69. for rows.Next() {
  70. messages = append(messages, mq.New().Scan(rows))
  71. }
  72. return
  73. }
  74. func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.MessageID) *Message {
  75. return mq.maybeScan(mq.db.QueryRow(getMessageByJIDQuery, chat.JID, chat.Receiver, jid))
  76. }
  77. func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message {
  78. return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid))
  79. }
  80. func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
  81. return mq.GetLastInChatBefore(chat, time.Now().Add(60*time.Second))
  82. }
  83. func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Time) *Message {
  84. msg := mq.maybeScan(mq.db.QueryRow(getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix()))
  85. if msg == nil || msg.Timestamp.IsZero() {
  86. // Old db, we don't know what the last message is.
  87. return nil
  88. }
  89. return msg
  90. }
  91. func (mq *MessageQuery) GetFirstInChat(chat PortalKey) *Message {
  92. return mq.maybeScan(mq.db.QueryRow(getFirstMessageInChatQuery, chat.JID, chat.Receiver))
  93. }
  94. func (mq *MessageQuery) GetMessagesBetween(chat PortalKey, minTimestamp, maxTimestamp time.Time) (messages []*Message) {
  95. rows, err := mq.db.Query(getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix())
  96. if err != nil || rows == nil {
  97. return nil
  98. }
  99. for rows.Next() {
  100. messages = append(messages, mq.New().Scan(rows))
  101. }
  102. return
  103. }
  104. func (mq *MessageQuery) maybeScan(row *sql.Row) *Message {
  105. if row == nil {
  106. return nil
  107. }
  108. return mq.New().Scan(row)
  109. }
  110. type MessageErrorType string
  111. const (
  112. MsgNoError MessageErrorType = ""
  113. MsgErrDecryptionFailed MessageErrorType = "decryption_failed"
  114. MsgErrMediaNotFound MessageErrorType = "media_not_found"
  115. )
  116. type MessageType string
  117. const (
  118. MsgUnknown MessageType = ""
  119. MsgFake MessageType = "fake"
  120. MsgNormal MessageType = "message"
  121. MsgReaction MessageType = "reaction"
  122. MsgEdit MessageType = "edit"
  123. MsgMatrixPoll MessageType = "matrix-poll"
  124. MsgBeeperGallery MessageType = "beeper-gallery"
  125. )
  126. type Message struct {
  127. db *Database
  128. log log.Logger
  129. Chat PortalKey
  130. JID types.MessageID
  131. MXID id.EventID
  132. Sender types.JID
  133. SenderMXID id.UserID
  134. Timestamp time.Time
  135. Sent bool
  136. Type MessageType
  137. Error MessageErrorType
  138. GalleryPart int
  139. BroadcastListJID types.JID
  140. }
  141. func (msg *Message) IsFakeMXID() bool {
  142. return strings.HasPrefix(msg.MXID.String(), "net.maunium.whatsapp.fake::")
  143. }
  144. func (msg *Message) IsFakeJID() bool {
  145. return strings.HasPrefix(msg.JID, "FAKE::") || msg.JID == string(msg.MXID)
  146. }
  147. const fakeGalleryMXIDFormat = "com.beeper.gallery::%d:%s"
  148. func (msg *Message) Scan(row dbutil.Scannable) *Message {
  149. var ts int64
  150. err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.SenderMXID, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
  151. if err != nil {
  152. if !errors.Is(err, sql.ErrNoRows) {
  153. msg.log.Errorln("Database scan failed:", err)
  154. }
  155. return nil
  156. }
  157. if strings.HasPrefix(msg.MXID.String(), "com.beeper.gallery::") {
  158. _, err = fmt.Sscanf(msg.MXID.String(), fakeGalleryMXIDFormat, &msg.GalleryPart, &msg.MXID)
  159. if err != nil {
  160. msg.log.Errorln("Parsing gallery MXID failed:", err)
  161. }
  162. }
  163. if ts != 0 {
  164. msg.Timestamp = time.Unix(ts, 0)
  165. }
  166. return msg
  167. }
  168. func (msg *Message) Insert(txn dbutil.Execable) {
  169. if txn == nil {
  170. txn = msg.db
  171. }
  172. var sender interface{} = msg.Sender
  173. // Slightly hacky hack to allow inserting empty senders (used for post-backfill dummy events)
  174. if msg.Sender.IsEmpty() {
  175. sender = ""
  176. }
  177. mxid := msg.MXID.String()
  178. if msg.GalleryPart != 0 {
  179. mxid = fmt.Sprintf(fakeGalleryMXIDFormat, msg.GalleryPart, mxid)
  180. }
  181. _, err := txn.Exec(`
  182. INSERT INTO message
  183. (chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid)
  184. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
  185. `, msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
  186. if err != nil {
  187. msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
  188. }
  189. }
  190. func (msg *Message) MarkSent(ts time.Time) {
  191. msg.Sent = true
  192. msg.Timestamp = ts
  193. _, err := msg.db.Exec("UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4", ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID)
  194. if err != nil {
  195. msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
  196. }
  197. }
  198. func (msg *Message) UpdateMXID(txn dbutil.Execable, mxid id.EventID, newType MessageType, newError MessageErrorType) {
  199. if txn == nil {
  200. txn = msg.db
  201. }
  202. msg.MXID = mxid
  203. msg.Type = newType
  204. msg.Error = newError
  205. _, err := txn.Exec("UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6",
  206. mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
  207. if err != nil {
  208. msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
  209. }
  210. }
  211. func (msg *Message) Delete() {
  212. _, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID)
  213. if err != nil {
  214. msg.log.Warnfln("Failed to delete %s@%s: %v", msg.Chat, msg.JID, err)
  215. }
  216. }