message.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "time"
  8. log "maunium.net/go/maulogger/v2"
  9. "maunium.net/go/mautrix/id"
  10. "maunium.net/go/mautrix/util/dbutil"
  11. )
  // MessageQuery groups the queries that operate on the message table.
  // Messages created through New share its database handle and logger.
  type MessageQuery struct {
  	db  *Database
  	log log.Logger
  }
  // messageSelect is the common SELECT prefix for loading message rows. The
  // column order must stay in sync with the Scan call in Message.Scan.
  const messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, " +
  	"timestamp, dc_edit_timestamp, dc_thread_id, mxid FROM message"
  19. func (mq *MessageQuery) New() *Message {
  20. return &Message{
  21. db: mq.db,
  22. log: mq.log,
  23. }
  24. }
  25. func (mq *MessageQuery) scanAll(rows dbutil.Rows, err error) []*Message {
  26. if err != nil {
  27. mq.log.Warnfln("Failed to query many messages: %v", err)
  28. panic(err)
  29. } else if rows == nil {
  30. return nil
  31. }
  32. var messages []*Message
  33. for rows.Next() {
  34. messages = append(messages, mq.New().Scan(rows))
  35. }
  36. return messages
  37. }
  38. func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) []*Message {
  39. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC"
  40. return mq.scanAll(mq.db.Query(query, key.ChannelID, key.Receiver, discordID))
  41. }
  42. func (mq *MessageQuery) GetFirstByDiscordID(key PortalKey, discordID string) *Message {
  43. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC LIMIT 1"
  44. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
  45. }
  46. func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Message {
  47. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id DESC LIMIT 1"
  48. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
  49. }
  50. func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time.Time) *Message {
  51. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND timestamp<=$4 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
  52. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID, ts.UnixMilli()))
  53. }
  54. func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
  55. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
  56. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
  57. }
  58. func (mq *MessageQuery) GetLast(key PortalKey) *Message {
  59. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 ORDER BY timestamp DESC LIMIT 1"
  60. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver))
  61. }
  62. func (mq *MessageQuery) DeleteAll(key PortalKey) {
  63. query := "DELETE FROM message WHERE dc_chan_id=$1 AND dc_chan_receiver=$2"
  64. _, err := mq.db.Exec(query, key.ChannelID, key.Receiver)
  65. if err != nil {
  66. mq.log.Warnfln("Failed to delete messages of %s: %v", key, err)
  67. panic(err)
  68. }
  69. }
  70. func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
  71. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3"
  72. row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, mxid)
  73. if row == nil {
  74. return nil
  75. }
  76. return mq.New().Scan(row)
  77. }
  78. func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
  79. if len(msgs) == 0 {
  80. return
  81. }
  82. valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d)"
  83. if mq.db.Dialect == dbutil.SQLite {
  84. valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
  85. }
  86. params := make([]interface{}, 2+len(msgs)*7)
  87. placeholders := make([]string, len(msgs))
  88. params[0] = key.ChannelID
  89. params[1] = key.Receiver
  90. for i, msg := range msgs {
  91. baseIndex := 2 + i*7
  92. params[baseIndex] = msg.DiscordID
  93. params[baseIndex+1] = msg.AttachmentID
  94. params[baseIndex+2] = msg.SenderID
  95. params[baseIndex+3] = msg.Timestamp.UnixMilli()
  96. params[baseIndex+4] = msg.editTimestampVal()
  97. params[baseIndex+5] = msg.ThreadID
  98. params[baseIndex+6] = msg.MXID
  99. placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7)
  100. }
  101. _, err := mq.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
  102. if err != nil {
  103. mq.log.Warnfln("Failed to insert %d messages: %v", len(msgs), err)
  104. panic(err)
  105. }
  106. }
  // Message is one row of the message table. A single Discord message may be
  // stored as several rows that share DiscordID and differ by AttachmentID
  // (see MassInsertParts and GetByDiscordID).
  type Message struct {
  	db  *Database
  	log log.Logger

  	// DiscordID is stored in the dcid column.
  	DiscordID string
  	// AttachmentID tells apart the rows of one Discord message (dc_attachment_id).
  	AttachmentID string
  	// Channel identifies the portal (dc_chan_id, dc_chan_receiver).
  	Channel PortalKey
  	// SenderID is stored in the dc_sender column.
  	SenderID string
  	// Timestamp is persisted as Unix milliseconds; it is read back in UTC.
  	Timestamp time.Time
  	// EditTimestamp is persisted as Unix nanoseconds. The zero time is stored
  	// as 0, and a stored 0 is read back as the zero time.
  	EditTimestamp time.Time
  	// ThreadID is stored in dc_thread_id. When non-empty,
  	// DiscordProtoChannelID returns it instead of the channel ID.
  	ThreadID string
  	// MXID is the Matrix event ID stored in the mxid column.
  	MXID id.EventID
  }
  119. func (m *Message) DiscordProtoChannelID() string {
  120. if m.ThreadID != "" {
  121. return m.ThreadID
  122. } else {
  123. return m.Channel.ChannelID
  124. }
  125. }
  126. func (m *Message) Scan(row dbutil.Scannable) *Message {
  127. var ts, editTS int64
  128. err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID)
  129. if err != nil {
  130. if !errors.Is(err, sql.ErrNoRows) {
  131. m.log.Errorln("Database scan failed:", err)
  132. panic(err)
  133. }
  134. return nil
  135. }
  136. if ts != 0 {
  137. m.Timestamp = time.UnixMilli(ts).UTC()
  138. }
  139. if editTS != 0 {
  140. m.EditTimestamp = time.Unix(0, editTS).UTC()
  141. }
  142. return m
  143. }
  // messageInsertQuery inserts a single message row. Its column order matches
  // the argument order used by Message.Insert.
  const messageInsertQuery = `
  INSERT INTO message (
  dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid
  )
  VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
  `

  // messageMassInsertTemplate is messageInsertQuery with its single VALUES
  // tuple replaced by a %s verb. Fill it with fmt.Sprintf and a comma-separated
  // list of row tuples.
  // NOTE(review): strings.Replace silently does nothing if the VALUES text
  // above is ever reformatted. The template would then contain no %s, and the
  // mass inserts would produce broken SQL, so keep the two strings in sync.
  var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9)", "%s", 1)
  // MessagePart holds the per-row values that MassInsertParts varies. Every
  // other column of each inserted row comes from the receiving Message.
  type MessagePart struct {
  	AttachmentID string
  	MXID         id.EventID
  }
  155. func (m *Message) editTimestampVal() int64 {
  156. if m.EditTimestamp.IsZero() {
  157. return 0
  158. }
  159. return m.EditTimestamp.UnixNano()
  160. }
  161. func (m *Message) MassInsertParts(msgs []MessagePart) {
  162. if len(msgs) == 0 {
  163. return
  164. }
  165. valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d)"
  166. if m.db.Dialect == dbutil.SQLite {
  167. valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
  168. }
  169. params := make([]interface{}, 7+len(msgs)*2)
  170. placeholders := make([]string, len(msgs))
  171. params[0] = m.DiscordID
  172. params[1] = m.Channel.ChannelID
  173. params[2] = m.Channel.Receiver
  174. params[3] = m.SenderID
  175. params[4] = m.Timestamp.UnixMilli()
  176. params[5] = m.editTimestampVal()
  177. params[6] = m.ThreadID
  178. for i, msg := range msgs {
  179. params[7+i*2] = msg.AttachmentID
  180. params[7+i*2+1] = msg.MXID
  181. placeholders[i] = fmt.Sprintf(valueStringFormat, 7+i*2+1, 7+i*2+2)
  182. }
  183. _, err := m.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
  184. if err != nil {
  185. m.log.Warnfln("Failed to insert %d parts of %s@%s: %v", len(msgs), m.DiscordID, m.Channel, err)
  186. panic(err)
  187. }
  188. }
  189. func (m *Message) Insert() {
  190. _, err := m.db.Exec(messageInsertQuery,
  191. m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
  192. m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID)
  193. if err != nil {
  194. m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
  195. panic(err)
  196. }
  197. }
  // editUpdateQuery sets dc_edit_timestamp of one message row to $1, but only
  // if the stored value is smaller, so the column never moves backwards.
  // $2..$5 identify the row: dcid, dc_attachment_id, dc_chan_id and
  // dc_chan_receiver.
  const editUpdateQuery = `
  UPDATE message
  SET dc_edit_timestamp=$1
  WHERE dcid=$2 AND dc_attachment_id=$3 AND dc_chan_id=$4 AND dc_chan_receiver=$5 AND dc_edit_timestamp<$1
  `
  203. func (m *Message) UpdateEditTimestamp(ts time.Time) {
  204. _, err := m.db.Exec(editUpdateQuery, ts.UnixNano(), m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver)
  205. if err != nil {
  206. m.log.Warnfln("Failed to update edit timestamp of %s@%s: %v", m.DiscordID, m.Channel, err)
  207. panic(err)
  208. }
  209. }
  210. func (m *Message) Delete() {
  211. query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3 AND dc_attachment_id=$4"
  212. _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.AttachmentID)
  213. if err != nil {
  214. m.log.Warnfln("Failed to delete %q of %s@%s: %v", m.AttachmentID, m.DiscordID, m.Channel, err)
  215. panic(err)
  216. }
  217. }