userportal.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "time"
  6. log "maunium.net/go/maulogger/v2"
  7. "maunium.net/go/mautrix/util/dbutil"
  8. )
  // Values used for the type column of the user_portal table.
const (
	// UserPortalTypeDM is the type value for DM portals.
	UserPortalTypeDM = "dm"
	// UserPortalTypeGuild is the type value for guild portals.
	UserPortalTypeGuild = "guild"
	// UserPortalTypeThread is the type value for thread portals.
	UserPortalTypeThread = "thread"
)
  14. type UserPortal struct {
  15. DiscordID string
  16. Type string
  17. Timestamp time.Time
  18. InSpace bool
  19. }
  20. func (up UserPortal) Scan(l log.Logger, row dbutil.Scannable) *UserPortal {
  21. var ts int64
  22. err := row.Scan(&up.DiscordID, &up.Type, &ts, &up.InSpace)
  23. if err != nil {
  24. l.Errorln("Error scanning user portal:", err)
  25. panic(err)
  26. }
  27. up.Timestamp = time.UnixMilli(ts)
  28. return &up
  29. }
  30. func (u *User) scanUserPortals(rows *sql.Rows) []UserPortal {
  31. var ups []UserPortal
  32. for rows.Next() {
  33. up := UserPortal{}.Scan(u.log, rows)
  34. if up != nil {
  35. ups = append(ups, *up)
  36. }
  37. }
  38. return ups
  39. }
  40. func (u *User) GetPortals() []UserPortal {
  41. rows, err := u.db.Query("SELECT discord_id, type, timestamp, in_space FROM user_portal WHERE user_mxid=$1", u.MXID)
  42. if err != nil {
  43. u.log.Errorln("Failed to get portals:", err)
  44. panic(err)
  45. }
  46. return u.scanUserPortals(rows)
  47. }
  48. func (u *User) IsInSpace(discordID string) (isIn bool) {
  49. query := `SELECT in_space FROM user_portal WHERE user_mxid=$1 AND discord_id=$2`
  50. err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn)
  51. if err != nil && !errors.Is(err, sql.ErrNoRows) {
  52. u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err)
  53. panic(err)
  54. }
  55. return
  56. }
  57. func (u *User) IsInPortal(discordID string) (isIn bool) {
  58. query := `SELECT EXISTS(SELECT 1 FROM user_portal WHERE user_mxid=$1 AND discord_id=$2)`
  59. err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn)
  60. if err != nil && !errors.Is(err, sql.ErrNoRows) {
  61. u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err)
  62. panic(err)
  63. }
  64. return
  65. }
  66. func (u *User) MarkInPortal(portal UserPortal) {
  67. query := `
  68. INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space)
  69. VALUES ($1, $2, $3, $4, $5)
  70. ON CONFLICT (discord_id, user_mxid) DO UPDATE
  71. SET timestamp=excluded.timestamp, in_space=excluded.in_space
  72. `
  73. _, err := u.db.Exec(query, portal.DiscordID, portal.Type, u.MXID, portal.Timestamp.UnixMilli(), portal.InSpace)
  74. if err != nil {
  75. u.log.Errorfln("Failed to insert user portal %s/%s: %v", u.MXID, portal.DiscordID, err)
  76. panic(err)
  77. }
  78. }
  79. func (u *User) MarkNotInPortal(discordID string) {
  80. query := `DELETE FROM user_portal WHERE user_mxid=$1 AND discord_id=$2`
  81. _, err := u.db.Exec(query, u.MXID, discordID)
  82. if err != nil {
  83. u.log.Errorfln("Failed to remove user portal %s/%s: %v", u.MXID, discordID, err)
  84. panic(err)
  85. }
  86. }
  87. func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal {
  88. query := `
  89. DELETE FROM user_portal
  90. WHERE user_mxid=$1 AND timestamp<$2 AND type IN ('dm', 'guild')
  91. RETURNING discord_id, type, timestamp, in_space
  92. `
  93. rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli())
  94. if err != nil {
  95. u.log.Errorln("Failed to prune user guild list:", err)
  96. panic(err)
  97. }
  98. return u.scanUserPortals(rows)
  99. }