thread.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "go.mau.fi/util/dbutil"
  6. log "maunium.net/go/maulogger/v2"
  7. "maunium.net/go/mautrix/id"
  8. )
  9. type ThreadQuery struct {
  10. db *Database
  11. log log.Logger
  12. }
  // threadSelect is the common SELECT prefix for thread lookups; callers append
  // their own WHERE clause. The column order must stay in sync with the
  // destination order used by Thread.Scan.
  const threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid FROM thread"
  16. func (tq *ThreadQuery) New() *Thread {
  17. return &Thread{
  18. db: tq.db,
  19. log: tq.log,
  20. }
  21. }
  22. func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread {
  23. query := threadSelect + " WHERE dcid=$1"
  24. row := tq.db.QueryRow(query, discordID)
  25. if row == nil {
  26. return nil
  27. }
  28. return tq.New().Scan(row)
  29. }
  30. func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
  31. query := threadSelect + " WHERE root_msg_mxid=$1"
  32. row := tq.db.QueryRow(query, mxid)
  33. if row == nil {
  34. return nil
  35. }
  36. return tq.New().Scan(row)
  37. }
  38. func (tq *ThreadQuery) GetByMatrixRootOrCreationNoticeMsg(mxid id.EventID) *Thread {
  39. query := threadSelect + " WHERE root_msg_mxid=$1 OR creation_notice_mxid=$1"
  40. row := tq.db.QueryRow(query, mxid)
  41. if row == nil {
  42. return nil
  43. }
  44. return tq.New().Scan(row)
  45. }
  46. type Thread struct {
  47. db *Database
  48. log log.Logger
  49. ID string
  50. ParentID string
  51. RootDiscordID string
  52. RootMXID id.EventID
  53. CreationNoticeMXID id.EventID
  54. }
  55. func (t *Thread) Scan(row dbutil.Scannable) *Thread {
  56. err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID, &t.CreationNoticeMXID)
  57. if err != nil {
  58. if !errors.Is(err, sql.ErrNoRows) {
  59. t.log.Errorln("Database scan failed:", err)
  60. panic(err)
  61. }
  62. return nil
  63. }
  64. return t
  65. }
  66. func (t *Thread) Insert() {
  67. query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid) VALUES ($1, $2, $3, $4, $5)"
  68. _, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID, t.CreationNoticeMXID)
  69. if err != nil {
  70. t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err)
  71. panic(err)
  72. }
  73. }
  74. func (t *Thread) Update() {
  75. query := "UPDATE thread SET creation_notice_mxid=$2 WHERE dcid=$1"
  76. _, err := t.db.Exec(query, t.ID, t.CreationNoticeMXID)
  77. if err != nil {
  78. t.log.Warnfln("Failed to update %s@%s: %v", t.ID, t.ParentID, err)
  79. panic(err)
  80. }
  81. }
  82. func (t *Thread) Delete() {
  83. query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2"
  84. _, err := t.db.Exec(query, t.ID, t.ParentID)
  85. if err != nil {
  86. t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err)
  87. panic(err)
  88. }
  89. }