thread.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. log "maunium.net/go/maulogger/v2"
  6. "maunium.net/go/mautrix/id"
  7. "maunium.net/go/mautrix/util/dbutil"
  8. )
  9. type ThreadQuery struct {
  10. db *Database
  11. log log.Logger
  12. }
  13. const (
  14. threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid FROM thread"
  15. )
  16. func (tq *ThreadQuery) New() *Thread {
  17. return &Thread{
  18. db: tq.db,
  19. log: tq.log,
  20. }
  21. }
  22. func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread {
  23. query := threadSelect + " WHERE dcid=$1"
  24. row := tq.db.QueryRow(query, discordID)
  25. if row == nil {
  26. return nil
  27. }
  28. return tq.New().Scan(row)
  29. }
//func (tq *ThreadQuery) GetByDiscordRootMsg(channelID, messageID string) *Thread {
//	query := threadSelect + " WHERE parent_chan_id=$1 AND root_msg_dcid=$2"
//
//	row := tq.db.QueryRow(query, channelID, messageID)
//	if row == nil {
//		return nil
//	}
//
//	return tq.New().Scan(row)
//}
  40. func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
  41. query := threadSelect + " WHERE root_msg_mxid=$1"
  42. row := tq.db.QueryRow(query, mxid)
  43. if row == nil {
  44. return nil
  45. }
  46. return tq.New().Scan(row)
  47. }
  48. type Thread struct {
  49. db *Database
  50. log log.Logger
  51. ID string
  52. ParentID string
  53. RootDiscordID string
  54. RootMXID id.EventID
  55. }
  56. func (t *Thread) Scan(row dbutil.Scannable) *Thread {
  57. err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID)
  58. if err != nil {
  59. if !errors.Is(err, sql.ErrNoRows) {
  60. t.log.Errorln("Database scan failed:", err)
  61. panic(err)
  62. }
  63. return nil
  64. }
  65. return t
  66. }
  67. func (t *Thread) Insert() {
  68. query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid) VALUES ($1, $2, $3, $4)"
  69. _, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID)
  70. if err != nil {
  71. t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err)
  72. panic(err)
  73. }
  74. }
  75. func (t *Thread) Delete() {
  76. query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2"
  77. _, err := t.db.Exec(query, t.ID, t.ParentID)
  78. if err != nil {
  79. t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err)
  80. panic(err)
  81. }
  82. }