userportal.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "time"
  6. log "maunium.net/go/maulogger/v2"
  7. "maunium.net/go/mautrix/util/dbutil"
  8. )
  // Intended values for UserPortal.Type, i.e. the type column of user_portal.
  // NOTE(review): nothing in this file enforces that only these values are stored.
  const (
  	// UserPortalTypeDM tags a user_portal row that refers to a DM.
  	UserPortalTypeDM = "dm"
  	// UserPortalTypeGuild tags a user_portal row that refers to a guild.
  	UserPortalTypeGuild = "guild"
  )
  // UserPortal is the in-memory form of one user_portal row belonging to a user.
  type UserPortal struct {
  	// DiscordID is the discord_id column identifying the portal.
  	DiscordID string
  	// Type is the type column; normally UserPortalTypeDM or UserPortalTypeGuild.
  	Type string
  	// Timestamp is the timestamp column, persisted as Unix milliseconds.
  	Timestamp time.Time
  	// InSpace mirrors the in_space column (presumably whether the portal has
  	// been added to the user's space — confirm against callers).
  	InSpace bool
  }
  19. func (up UserPortal) Scan(l log.Logger, row dbutil.Scannable) *UserPortal {
  20. var ts int64
  21. err := row.Scan(&up.DiscordID, &up.Type, &ts, &up.InSpace)
  22. if err != nil {
  23. l.Errorln("Error scanning user portal:", err)
  24. panic(err)
  25. }
  26. up.Timestamp = time.UnixMilli(ts)
  27. return &up
  28. }
  29. func (u *User) scanUserPortals(rows *sql.Rows) []UserPortal {
  30. var ups []UserPortal
  31. for rows.Next() {
  32. up := UserPortal{}.Scan(u.log, rows)
  33. if up != nil {
  34. ups = append(ups, *up)
  35. }
  36. }
  37. return ups
  38. }
  39. func (u *User) GetPortals() []UserPortal {
  40. rows, err := u.db.Query("SELECT discord_id, type, timestamp, in_space FROM user_portal WHERE user_mxid=$1", u.MXID)
  41. if err != nil {
  42. u.log.Errorln("Failed to get portals:", err)
  43. panic(err)
  44. }
  45. return u.scanUserPortals(rows)
  46. }
  47. func (u *User) IsInSpace(discordID string) (isIn bool) {
  48. query := `SELECT in_space FROM user_portal WHERE user_mxid=$1 AND discord_id=$2`
  49. err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn)
  50. if err != nil && !errors.Is(err, sql.ErrNoRows) {
  51. u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err)
  52. panic(err)
  53. }
  54. return
  55. }
  56. func (u *User) MarkInPortal(portal UserPortal) {
  57. query := `
  58. INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space)
  59. VALUES ($1, $2, $3, $4, $5)
  60. ON CONFLICT (discord_id, user_mxid) DO UPDATE
  61. SET timestamp=excluded.timestamp, in_space=excluded.in_space
  62. `
  63. _, err := u.db.Exec(query, portal.DiscordID, portal.Type, u.MXID, portal.Timestamp.UnixMilli(), portal.InSpace)
  64. if err != nil {
  65. u.log.Errorfln("Failed to insert user portal %s/%s: %v", u.MXID, portal.DiscordID, err)
  66. panic(err)
  67. }
  68. }
  69. func (u *User) MarkNotInPortal(discordID string) {
  70. query := `DELETE FROM user_portal WHERE user_mxid=$1 AND discord_id=$2`
  71. _, err := u.db.Exec(query, u.MXID, discordID)
  72. if err != nil {
  73. u.log.Errorfln("Failed to remove user portal %s/%s: %v", u.MXID, discordID, err)
  74. panic(err)
  75. }
  76. }
  77. func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal {
  78. query := `
  79. DELETE FROM user_portal
  80. WHERE user_mxid=$1 AND timestamp<$2
  81. RETURNING discord_id, type, timestamp, in_space
  82. `
  83. rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli())
  84. if err != nil {
  85. u.log.Errorln("Failed to prune user guild list:", err)
  86. panic(err)
  87. }
  88. return u.scanUserPortals(rows)
  89. }