message.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "time"
  8. log "maunium.net/go/maulogger/v2"
  9. "maunium.net/go/mautrix/id"
  10. "maunium.net/go/mautrix/util/dbutil"
  11. )
  12. type MessageQuery struct {
  13. db *Database
  14. log log.Logger
  15. }
  // messageSelect is the column list shared by every message lookup. The column
// order must stay in sync with the destinations passed to row.Scan in
// (*Message).Scan.
const messageSelect = "SELECT dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message"
  19. func (mq *MessageQuery) New() *Message {
  20. return &Message{
  21. db: mq.db,
  22. log: mq.log,
  23. }
  24. }
  25. func (mq *MessageQuery) scanAll(rows *sql.Rows, err error) []*Message {
  26. if err != nil {
  27. mq.log.Warnfln("Failed to query many messages: %v", err)
  28. panic(err)
  29. return nil
  30. } else if rows == nil {
  31. return nil
  32. }
  33. var messages []*Message
  34. for rows.Next() {
  35. messages = append(messages, mq.New().Scan(rows))
  36. }
  37. return messages
  38. }
  39. func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) []*Message {
  40. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC"
  41. return mq.scanAll(mq.db.Query(query, key.ChannelID, key.Receiver, discordID))
  42. }
  43. func (mq *MessageQuery) GetFirstByDiscordID(key PortalKey, discordID string) *Message {
  44. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC LIMIT 1"
  45. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
  46. }
  47. func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Message {
  48. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id DESC LIMIT 1"
  49. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
  50. }
  51. func (mq *MessageQuery) GetClosestBefore(key PortalKey, ts time.Time) *Message {
  52. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND timestamp<=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
  53. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, ts.UnixMilli()))
  54. }
  55. func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
  56. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND dc_edit_index=0 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
  57. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
  58. }
  59. func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
  60. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3"
  61. row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, mxid)
  62. if row == nil {
  63. return nil
  64. }
  65. return mq.New().Scan(row)
  66. }
  67. type Message struct {
  68. db *Database
  69. log log.Logger
  70. DiscordID string
  71. AttachmentID string
  72. EditIndex int
  73. Channel PortalKey
  74. SenderID string
  75. Timestamp time.Time
  76. ThreadID string
  77. MXID id.EventID
  78. }
  79. func (m *Message) DiscordProtoChannelID() string {
  80. if m.ThreadID != "" {
  81. return m.ThreadID
  82. } else {
  83. return m.Channel.ChannelID
  84. }
  85. }
  86. func (m *Message) Scan(row dbutil.Scannable) *Message {
  87. var ts int64
  88. var threadID sql.NullString
  89. err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID)
  90. if err != nil {
  91. if !errors.Is(err, sql.ErrNoRows) {
  92. m.log.Errorln("Database scan failed:", err)
  93. panic(err)
  94. }
  95. return nil
  96. }
  97. if ts != 0 {
  98. m.Timestamp = time.UnixMilli(ts)
  99. }
  100. m.ThreadID = threadID.String
  101. return m
  102. }
  // messageInsertQuery inserts a single row into the message table. Its column
// order matches messageSelect.
const messageInsertQuery = `
INSERT INTO message (
dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`

// messageMassInsertTemplate is messageInsertQuery with its single VALUES tuple
// replaced by a %s verb, so MassInsert can splice in one tuple per part.
var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9)", "%s", 1)
  110. type MessagePart struct {
  111. AttachmentID string
  112. MXID id.EventID
  113. }
  114. func (m *Message) MassInsert(msgs []MessagePart) {
  115. if len(msgs) == 0 {
  116. return
  117. }
  118. valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d)"
  119. if m.db.Dialect == dbutil.SQLite {
  120. valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
  121. }
  122. params := make([]interface{}, 7+len(msgs)*2)
  123. placeholders := make([]string, len(msgs))
  124. params[0] = m.DiscordID
  125. params[1] = m.EditIndex
  126. params[2] = m.Channel.ChannelID
  127. params[3] = m.Channel.Receiver
  128. params[4] = m.SenderID
  129. params[5] = m.Timestamp.UnixMilli()
  130. params[6] = m.ThreadID
  131. for i, msg := range msgs {
  132. params[7+i*2] = msg.AttachmentID
  133. params[7+i*2+1] = msg.MXID
  134. placeholders[i] = fmt.Sprintf(valueStringFormat, 7+i*2+1, 7+i*2+2)
  135. }
  136. _, err := m.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
  137. if err != nil {
  138. m.log.Warnfln("Failed to insert %d parts of %s@%s: %v", len(msgs), m.DiscordID, m.Channel, err)
  139. panic(err)
  140. }
  141. }
  142. func (m *Message) Insert() {
  143. _, err := m.db.Exec(messageInsertQuery,
  144. m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
  145. m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID)
  146. if err != nil {
  147. m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
  148. panic(err)
  149. }
  150. }
  151. func (m *Message) Delete() {
  152. query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3 AND dc_attachment_id=$4"
  153. _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.AttachmentID)
  154. if err != nil {
  155. m.log.Warnfln("Failed to delete %q of %s@%s: %v", m.AttachmentID, m.DiscordID, m.Channel, err)
  156. panic(err)
  157. }
  158. }