userportal.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "time"
  6. log "maunium.net/go/maulogger/v2"
  7. "maunium.net/go/mautrix/util/dbutil"
  8. )
  9. const (
  10. UserPortalTypeDM = "dm"
  11. UserPortalTypeGuild = "guild"
  12. UserPortalTypeThread = "thread"
  13. )
  14. type UserPortal struct {
  15. DiscordID string
  16. Type string
  17. Timestamp time.Time
  18. InSpace bool
  19. }
  20. func (up UserPortal) Scan(l log.Logger, row dbutil.Scannable) *UserPortal {
  21. var ts int64
  22. err := row.Scan(&up.DiscordID, &up.Type, &ts, &up.InSpace)
  23. if err != nil {
  24. l.Errorln("Error scanning user portal:", err)
  25. panic(err)
  26. }
  27. up.Timestamp = time.UnixMilli(ts).UTC()
  28. return &up
  29. }
  30. func (u *User) scanUserPortals(rows dbutil.Rows) []UserPortal {
  31. var ups []UserPortal
  32. for rows.Next() {
  33. up := UserPortal{}.Scan(u.log, rows)
  34. if up != nil {
  35. ups = append(ups, *up)
  36. }
  37. }
  38. return ups
  39. }
  40. func (u *User) GetPortals() []UserPortal {
  41. rows, err := u.db.Query("SELECT discord_id, type, timestamp, in_space FROM user_portal WHERE user_mxid=$1", u.MXID)
  42. if err != nil {
  43. u.log.Errorln("Failed to get portals:", err)
  44. panic(err)
  45. }
  46. return u.scanUserPortals(rows)
  47. }
  48. func (u *User) IsInSpace(discordID string) (isIn bool) {
  49. query := `SELECT in_space FROM user_portal WHERE user_mxid=$1 AND discord_id=$2`
  50. err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn)
  51. if err != nil && !errors.Is(err, sql.ErrNoRows) {
  52. u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err)
  53. panic(err)
  54. }
  55. return
  56. }
  57. func (u *User) IsInPortal(discordID string) (isIn bool) {
  58. query := `SELECT EXISTS(SELECT 1 FROM user_portal WHERE user_mxid=$1 AND discord_id=$2)`
  59. err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn)
  60. if err != nil && !errors.Is(err, sql.ErrNoRows) {
  61. u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err)
  62. panic(err)
  63. }
  64. return
  65. }
  66. func (u *User) MarkInPortal(portal UserPortal) {
  67. query := `
  68. INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space)
  69. VALUES ($1, $2, $3, $4, $5)
  70. ON CONFLICT (discord_id, user_mxid) DO UPDATE
  71. SET timestamp=excluded.timestamp, in_space=excluded.in_space
  72. `
  73. _, err := u.db.Exec(query, portal.DiscordID, portal.Type, u.MXID, portal.Timestamp.UnixMilli(), portal.InSpace)
  74. if err != nil {
  75. u.log.Errorfln("Failed to insert user portal %s/%s: %v", u.MXID, portal.DiscordID, err)
  76. panic(err)
  77. }
  78. }
  79. func (u *User) MarkNotInPortal(discordID string) {
  80. query := `DELETE FROM user_portal WHERE user_mxid=$1 AND discord_id=$2`
  81. _, err := u.db.Exec(query, u.MXID, discordID)
  82. if err != nil {
  83. u.log.Errorfln("Failed to remove user portal %s/%s: %v", u.MXID, discordID, err)
  84. panic(err)
  85. }
  86. }
  87. func (u *User) PortalHasOtherUsers(discordID string) (hasOtherUsers bool) {
  88. query := `SELECT COUNT(*) > 0 FROM user_portal WHERE user_mxid<>$1 AND discord_id=$2`
  89. err := u.db.QueryRow(query, u.MXID, discordID).Scan(&hasOtherUsers)
  90. if err != nil {
  91. u.log.Errorfln("Failed to check if %s has users other than %s: %v", discordID, u.MXID, err)
  92. panic(err)
  93. }
  94. return
  95. }
  96. func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal {
  97. query := `
  98. DELETE FROM user_portal
  99. WHERE user_mxid=$1 AND timestamp<$2 AND type IN ('dm', 'guild')
  100. RETURNING discord_id, type, timestamp, in_space
  101. `
  102. rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli())
  103. if err != nil {
  104. u.log.Errorln("Failed to prune user guild list:", err)
  105. panic(err)
  106. }
  107. return u.scanUserPortals(rows)
  108. }