message.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "time"
  6. log "maunium.net/go/maulogger/v2"
  7. "maunium.net/go/mautrix/id"
  8. "maunium.net/go/mautrix/util/dbutil"
  9. )
  10. type MessageQuery struct {
  11. db *Database
  12. log log.Logger
  13. }
  14. const (
  15. messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid FROM message"
  16. )
  17. func (mq *MessageQuery) New() *Message {
  18. return &Message{
  19. db: mq.db,
  20. log: mq.log,
  21. }
  22. }
  23. func (mq *MessageQuery) GetAll(key PortalKey) []*Message {
  24. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2"
  25. rows, err := mq.db.Query(query, key.ChannelID, key.Receiver)
  26. if err != nil || rows == nil {
  27. return nil
  28. }
  29. var messages []*Message
  30. for rows.Next() {
  31. messages = append(messages, mq.New().Scan(rows))
  32. }
  33. return messages
  34. }
  35. func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message {
  36. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
  37. row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)
  38. if row == nil {
  39. mq.log.Debugfln("failed to find existing message for discord_id %s", discordID)
  40. return nil
  41. }
  42. return mq.New().Scan(row)
  43. }
  44. func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
  45. query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3"
  46. row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, mxid)
  47. if row == nil {
  48. return nil
  49. }
  50. return mq.New().Scan(row)
  51. }
  52. type Message struct {
  53. db *Database
  54. log log.Logger
  55. DiscordID string
  56. Channel PortalKey
  57. SenderID string
  58. Timestamp time.Time
  59. MXID id.EventID
  60. }
  61. func (m *Message) Scan(row dbutil.Scannable) *Message {
  62. var ts int64
  63. err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.MXID)
  64. if err != nil {
  65. if !errors.Is(err, sql.ErrNoRows) {
  66. m.log.Errorln("Database scan failed:", err)
  67. }
  68. return nil
  69. }
  70. if ts != 0 {
  71. m.Timestamp = time.Unix(ts, 0)
  72. }
  73. return m
  74. }
  75. func (m *Message) Insert() {
  76. query := "INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid) VALUES ($1, $2, $3, $4, $5, $6)"
  77. _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, m.Timestamp.Unix(), m.MXID)
  78. if err != nil {
  79. m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
  80. }
  81. }
  82. func (m *Message) Delete() {
  83. query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3"
  84. _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver)
  85. if err != nil {
  86. m.log.Warnfln("Failed to delete %s@%s: %v", m.DiscordID, m.Channel, err)
  87. }
  88. }