userportal.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "time"
  6. log "maunium.net/go/maulogger/v2"
  7. "maunium.net/go/mautrix/util/dbutil"
  8. )
  9. const (
  10. UserPortalTypeDM = "dm"
  11. UserPortalTypeGuild = "guild"
  12. )
  13. type UserPortal struct {
  14. DiscordID string
  15. Type string
  16. Timestamp time.Time
  17. InSpace bool
  18. }
  19. func (up UserPortal) Scan(l log.Logger, row dbutil.Scannable) *UserPortal {
  20. var ts int64
  21. err := row.Scan(&up.DiscordID, &up.Type, &ts, &up.InSpace)
  22. if err != nil {
  23. l.Errorln("Error scanning user portal:", err)
  24. panic(err)
  25. return nil
  26. }
  27. up.Timestamp = time.UnixMilli(ts)
  28. return &up
  29. }
  30. func (u *User) scanUserPortals(rows *sql.Rows) []UserPortal {
  31. var ups []UserPortal
  32. for rows.Next() {
  33. up := UserPortal{}.Scan(u.log, rows)
  34. if up != nil {
  35. ups = append(ups, *up)
  36. }
  37. }
  38. return ups
  39. }
  40. func (u *User) GetPortals() []UserPortal {
  41. rows, err := u.db.Query("SELECT discord_id, type, timestamp, in_space FROM user_portal WHERE user_mxid=$1", u.MXID)
  42. if err != nil {
  43. u.log.Errorln("Failed to get portals:", err)
  44. panic(err)
  45. return nil
  46. }
  47. return u.scanUserPortals(rows)
  48. }
  49. func (u *User) IsInSpace(discordID string) (isIn bool) {
  50. query := `SELECT in_space FROM user_portal WHERE user_mxid=$1 AND discord_id=$2`
  51. err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn)
  52. if err != nil && !errors.Is(err, sql.ErrNoRows) {
  53. u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err)
  54. panic(err)
  55. }
  56. return
  57. }
  58. func (u *User) MarkInPortal(portal UserPortal) {
  59. query := `
  60. INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space)
  61. VALUES ($1, $2, $3, $4, $5)
  62. ON CONFLICT (discord_id, user_mxid) DO UPDATE
  63. SET timestamp=excluded.timestamp, in_space=excluded.in_space
  64. `
  65. _, err := u.db.Exec(query, portal.DiscordID, portal.Type, u.MXID, portal.Timestamp.UnixMilli(), portal.InSpace)
  66. if err != nil {
  67. u.log.Errorfln("Failed to insert user portal %s/%s: %v", u.MXID, portal.DiscordID, err)
  68. panic(err)
  69. }
  70. }
  71. func (u *User) MarkNotInPortal(discordID string) {
  72. query := `DELETE FROM user_portal WHERE user_mxid=$1 AND discord_id=$2`
  73. _, err := u.db.Exec(query, u.MXID, discordID)
  74. if err != nil {
  75. u.log.Errorfln("Failed to remove user portal %s/%s: %v", u.MXID, discordID, err)
  76. panic(err)
  77. }
  78. }
  79. func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal {
  80. query := `
  81. DELETE FROM user_portal
  82. WHERE user_mxid=$1 AND timestamp<$2
  83. RETURNING discord_id, type, timestamp, in_space
  84. `
  85. rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli())
  86. if err != nil {
  87. u.log.Errorln("Failed to prune user guild list:", err)
  88. panic(err)
  89. }
  90. return u.scanUserPortals(rows)
  91. }