message.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "time"
  8. "go.mau.fi/util/dbutil"
  9. log "maunium.net/go/maulogger/v2"
  10. "maunium.net/go/mautrix/id"
  11. )
  12. type MessageQuery struct {
  13. db *Database
  14. log log.Logger
  15. }
  16. const (
  17. messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid FROM message"
  18. )
  19. func (mq *MessageQuery) New() *Message {
  20. return &Message{
  21. db: mq.db,
  22. log: mq.log,
  23. }
  24. }
  25. func (mq *MessageQuery) scanAll(rows dbutil.Rows, err error) []*Message {
  26. if err != nil {
  27. mq.log.Warnfln("Failed to query many messages: %v", err)
  28. panic(err)
  29. } else if rows == nil {
  30. return nil
  31. }
  32. var messages []*Message
  33. for rows.Next() {
  34. messages = append(messages, mq.New().Scan(rows))
  35. }
  36. return messages
  37. }
  38. func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) []*Message {
  39. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC"
  40. return mq.scanAll(mq.db.Query(query, key.ChannelID, key.Receiver, discordID))
  41. }
  42. func (mq *MessageQuery) GetFirstByDiscordID(key PortalKey, discordID string) *Message {
  43. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC LIMIT 1"
  44. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
  45. }
  46. func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Message {
  47. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id DESC LIMIT 1"
  48. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
  49. }
  50. func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time.Time) *Message {
  51. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND timestamp<=$4 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
  52. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID, ts.UnixMilli()))
  53. }
  54. func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
  55. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
  56. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
  57. }
  58. func (mq *MessageQuery) GetLast(key PortalKey) *Message {
  59. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 ORDER BY timestamp DESC LIMIT 1"
  60. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver))
  61. }
  62. func (mq *MessageQuery) DeleteAll(key PortalKey) {
  63. query := "DELETE FROM message WHERE dc_chan_id=$1 AND dc_chan_receiver=$2"
  64. _, err := mq.db.Exec(query, key.ChannelID, key.Receiver)
  65. if err != nil {
  66. mq.log.Warnfln("Failed to delete messages of %s: %v", key, err)
  67. panic(err)
  68. }
  69. }
  70. func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
  71. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3"
  72. row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, mxid)
  73. if row == nil {
  74. return nil
  75. }
  76. return mq.New().Scan(row)
  77. }
  78. func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
  79. if len(msgs) == 0 {
  80. return
  81. }
  82. valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d, $%d)"
  83. if mq.db.Dialect == dbutil.SQLite {
  84. valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
  85. }
  86. params := make([]interface{}, 2+len(msgs)*8)
  87. placeholders := make([]string, len(msgs))
  88. params[0] = key.ChannelID
  89. params[1] = key.Receiver
  90. for i, msg := range msgs {
  91. baseIndex := 2 + i*7
  92. params[baseIndex] = msg.DiscordID
  93. params[baseIndex+1] = msg.AttachmentID
  94. params[baseIndex+2] = msg.SenderID
  95. params[baseIndex+3] = msg.Timestamp.UnixMilli()
  96. params[baseIndex+4] = msg.editTimestampVal()
  97. params[baseIndex+5] = msg.ThreadID
  98. params[baseIndex+6] = msg.MXID
  99. params[baseIndex+7] = msg.SenderMXID.String()
  100. placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8)
  101. }
  102. _, err := mq.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
  103. if err != nil {
  104. mq.log.Warnfln("Failed to insert %d messages: %v", len(msgs), err)
  105. panic(err)
  106. }
  107. }
  108. type Message struct {
  109. db *Database
  110. log log.Logger
  111. DiscordID string
  112. AttachmentID string
  113. Channel PortalKey
  114. SenderID string
  115. Timestamp time.Time
  116. EditTimestamp time.Time
  117. ThreadID string
  118. MXID id.EventID
  119. SenderMXID id.UserID
  120. }
  121. func (m *Message) DiscordProtoChannelID() string {
  122. if m.ThreadID != "" {
  123. return m.ThreadID
  124. } else {
  125. return m.Channel.ChannelID
  126. }
  127. }
  128. func (m *Message) Scan(row dbutil.Scannable) *Message {
  129. var ts, editTS int64
  130. err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID, &m.SenderMXID)
  131. if err != nil {
  132. if !errors.Is(err, sql.ErrNoRows) {
  133. m.log.Errorln("Database scan failed:", err)
  134. panic(err)
  135. }
  136. return nil
  137. }
  138. if ts != 0 {
  139. m.Timestamp = time.UnixMilli(ts).UTC()
  140. }
  141. if editTS != 0 {
  142. m.EditTimestamp = time.Unix(0, editTS).UTC()
  143. }
  144. return m
  145. }
  146. const messageInsertQuery = `
  147. INSERT INTO message (
  148. dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid
  149. )
  150. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
  151. `
  152. var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", "%s", 1)
  153. type MessagePart struct {
  154. AttachmentID string
  155. MXID id.EventID
  156. }
  157. func (m *Message) editTimestampVal() int64 {
  158. if m.EditTimestamp.IsZero() {
  159. return 0
  160. }
  161. return m.EditTimestamp.UnixNano()
  162. }
  163. func (m *Message) MassInsertParts(msgs []MessagePart) {
  164. if len(msgs) == 0 {
  165. return
  166. }
  167. valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d, $8)"
  168. if m.db.Dialect == dbutil.SQLite {
  169. valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
  170. }
  171. params := make([]interface{}, 8+len(msgs)*2)
  172. placeholders := make([]string, len(msgs))
  173. params[0] = m.DiscordID
  174. params[1] = m.Channel.ChannelID
  175. params[2] = m.Channel.Receiver
  176. params[3] = m.SenderID
  177. params[4] = m.Timestamp.UnixMilli()
  178. params[5] = m.editTimestampVal()
  179. params[6] = m.ThreadID
  180. params[7] = m.SenderMXID.String()
  181. for i, msg := range msgs {
  182. params[8+i*2] = msg.AttachmentID
  183. params[8+i*2+1] = msg.MXID
  184. placeholders[i] = fmt.Sprintf(valueStringFormat, 8+i*2+1, 8+i*2+2)
  185. }
  186. _, err := m.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
  187. if err != nil {
  188. m.log.Warnfln("Failed to insert %d parts of %s@%s: %v", len(msgs), m.DiscordID, m.Channel, err)
  189. panic(err)
  190. }
  191. }
  192. func (m *Message) Insert() {
  193. _, err := m.db.Exec(messageInsertQuery,
  194. m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
  195. m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID, m.SenderMXID.String())
  196. if err != nil {
  197. m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
  198. panic(err)
  199. }
  200. }
  201. const editUpdateQuery = `
  202. UPDATE message
  203. SET dc_edit_timestamp=$1
  204. WHERE dcid=$2 AND dc_attachment_id=$3 AND dc_chan_id=$4 AND dc_chan_receiver=$5 AND dc_edit_timestamp<$1
  205. `
  206. func (m *Message) UpdateEditTimestamp(ts time.Time) {
  207. _, err := m.db.Exec(editUpdateQuery, ts.UnixNano(), m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver)
  208. if err != nil {
  209. m.log.Warnfln("Failed to update edit timestamp of %s@%s: %v", m.DiscordID, m.Channel, err)
  210. panic(err)
  211. }
  212. }
  213. func (m *Message) Delete() {
  214. query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3 AND dc_attachment_id=$4"
  215. _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.AttachmentID)
  216. if err != nil {
  217. m.log.Warnfln("Failed to delete %q of %s@%s: %v", m.AttachmentID, m.DiscordID, m.Channel, err)
  218. panic(err)
  219. }
  220. }