message.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "time"
  6. log "maunium.net/go/maulogger/v2"
  7. "maunium.net/go/mautrix/id"
  8. "maunium.net/go/mautrix/util/dbutil"
  9. )
// MessageQuery runs lookups against the message table and creates Message
// values that share its database handle and logger.
type MessageQuery struct {
	// db is the handle used for every query issued by this helper.
	db *Database
	// log is copied into each Message created by New.
	log log.Logger
}
  14. const (
  15. messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message"
  16. )
  17. func (mq *MessageQuery) New() *Message {
  18. return &Message{
  19. db: mq.db,
  20. log: mq.log,
  21. }
  22. }
  23. func (mq *MessageQuery) GetAll(key PortalKey) []*Message {
  24. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2"
  25. rows, err := mq.db.Query(query, key.ChannelID, key.Receiver)
  26. if err != nil || rows == nil {
  27. return nil
  28. }
  29. var messages []*Message
  30. for rows.Next() {
  31. messages = append(messages, mq.New().Scan(rows))
  32. }
  33. return messages
  34. }
  35. func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message {
  36. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
  37. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
  38. }
  39. func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
  40. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC LIMIT 1"
  41. return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
  42. }
  43. func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
  44. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3"
  45. row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, mxid)
  46. if row == nil {
  47. return nil
  48. }
  49. return mq.New().Scan(row)
  50. }
// Message is one row of the message table: a Discord message ID together
// with the portal it belongs to and the corresponding Matrix event ID.
type Message struct {
	// db and log are injected by MessageQuery.New and used by Insert/Delete.
	db  *Database
	log log.Logger

	// DiscordID is stored in the dcid column.
	DiscordID string
	// Channel maps to dc_chan_id / dc_chan_receiver and is part of every
	// lookup and delete key.
	Channel PortalKey
	// SenderID is stored in the dc_sender column.
	SenderID string
	// Timestamp is persisted as Unix milliseconds; a stored 0 is scanned
	// back as the zero time.Time.
	Timestamp time.Time
	// ThreadID is stored in the nullable dc_thread_id column; NULL scans as "".
	ThreadID string
	// MXID is stored in the mxid column.
	MXID id.EventID
}
  61. func (m *Message) DiscordProtoChannelID() string {
  62. if m.ThreadID != "" {
  63. return m.ThreadID
  64. } else {
  65. return m.Channel.ChannelID
  66. }
  67. }
  68. func (m *Message) Scan(row dbutil.Scannable) *Message {
  69. var ts int64
  70. var threadID sql.NullString
  71. err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID)
  72. if err != nil {
  73. if !errors.Is(err, sql.ErrNoRows) {
  74. m.log.Errorln("Database scan failed:", err)
  75. panic(err)
  76. }
  77. return nil
  78. }
  79. if ts != 0 {
  80. m.Timestamp = time.UnixMilli(ts)
  81. }
  82. m.ThreadID = threadID.String
  83. return m
  84. }
  85. func (m *Message) Insert() {
  86. query := `
  87. INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid)
  88. VALUES ($1, $2, $3, $4, $5, $6, $7)
  89. `
  90. _, err := m.db.Exec(query,
  91. m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
  92. m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID)
  93. if err != nil {
  94. m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
  95. panic(err)
  96. }
  97. }
  98. func (m *Message) Delete() {
  99. query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3"
  100. _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver)
  101. if err != nil {
  102. m.log.Warnfln("Failed to delete %s@%s: %v", m.DiscordID, m.Channel, err)
  103. panic(err)
  104. }
  105. }