user.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "time"
  6. log "maunium.net/go/maulogger/v2"
  7. "maunium.net/go/mautrix/id"
  8. "maunium.net/go/mautrix/util/dbutil"
  9. )
  10. type UserQuery struct {
  11. db *Database
  12. log log.Logger
  13. }
  14. func (uq *UserQuery) New() *User {
  15. return &User{
  16. db: uq.db,
  17. log: uq.log,
  18. }
  19. }
  20. func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
  21. query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE mxid=$1`
  22. return uq.New().Scan(uq.db.QueryRow(query, userID))
  23. }
  24. func (uq *UserQuery) GetByID(id string) *User {
  25. query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE dcid=$1`
  26. return uq.New().Scan(uq.db.QueryRow(query, id))
  27. }
  28. func (uq *UserQuery) GetAllWithToken() []*User {
  29. query := `
  30. SELECT mxid, dcid, discord_token, management_room, space_room
  31. FROM "user" WHERE discord_token IS NOT NULL
  32. `
  33. rows, err := uq.db.Query(query)
  34. if err != nil || rows == nil {
  35. return nil
  36. }
  37. var users []*User
  38. for rows.Next() {
  39. user := uq.New().Scan(rows)
  40. if user != nil {
  41. users = append(users, user)
  42. }
  43. }
  44. return users
  45. }
  46. type User struct {
  47. db *Database
  48. log log.Logger
  49. MXID id.UserID
  50. DiscordID string
  51. DiscordToken string
  52. ManagementRoom id.RoomID
  53. SpaceRoom id.RoomID
  54. }
  // UserGuild is one row of the user_guild table as seen from a single user.
  type UserGuild struct {
  	// GuildID is the value of the guild_id column.
  	GuildID string
  	// Timestamp is the timestamp column, which is stored as Unix milliseconds
  	// and converted with time.UnixMilli when read.
  	Timestamp time.Time
  	// InSpace is the value of the in_space column.
  	InSpace bool
  }
  60. func (u *User) GetGuilds() []UserGuild {
  61. res, err := u.db.Query("SELECT guild_id, timestamp, in_space FROM user_guild WHERE user_mxid=$1", u.MXID)
  62. if err != nil {
  63. u.log.Errorln("Failed to get guilds:", err)
  64. panic(err)
  65. return nil
  66. }
  67. var guilds []UserGuild
  68. for res.Next() {
  69. var guild UserGuild
  70. var ts int64
  71. err = res.Scan(&guild.GuildID, &ts, &guild.InSpace)
  72. if err != nil {
  73. u.log.Errorln("Error scanning user guild:", err)
  74. panic(err)
  75. } else {
  76. guild.Timestamp = time.UnixMilli(ts)
  77. guilds = append(guilds, guild)
  78. }
  79. }
  80. return guilds
  81. }
  82. func (u *User) IsInSpace(guildID string) (isIn bool) {
  83. query := `SELECT in_space FROM user_guild WHERE user_mxid=$1 AND guild_id=$2`
  84. err := u.db.QueryRow(query, u.MXID, guildID).Scan(&isIn)
  85. if err != nil && !errors.Is(err, sql.ErrNoRows) {
  86. u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, guildID, err)
  87. panic(err)
  88. }
  89. return
  90. }
  91. func (u *User) MarkInGuild(guild UserGuild) {
  92. query := `
  93. INSERT INTO user_guild (guild_id, user_mxid, timestamp, in_space)
  94. VALUES ($1, $2, $3, $4)
  95. ON CONFLICT (guild_id, user_mxid) DO UPDATE
  96. SET timestamp=excluded.timestamp, in_space=excluded.in_space
  97. `
  98. _, err := u.db.Exec(query, guild.GuildID, u.MXID, guild.Timestamp.UnixMilli(), guild.InSpace)
  99. if err != nil {
  100. u.log.Errorfln("Failed to insert user guild %s/%s: %v", u.MXID, guild.GuildID, err)
  101. panic(err)
  102. }
  103. }
  104. func (u *User) MarkNotInGuild(guildID string) {
  105. query := `DELETE FROM user_guild WHERE user_mxid=$1 AND guild_id=$2`
  106. _, err := u.db.Exec(query, u.MXID, guildID)
  107. if err != nil {
  108. u.log.Errorfln("Failed to remove user guild %s/%s: %v", u.MXID, guildID, err)
  109. panic(err)
  110. }
  111. }
  112. func (u *User) PruneGuildList(beforeTS time.Time) {
  113. _, err := u.db.Exec("DELETE FROM user_guild WHERE user_mxid=$1 AND timestamp<$2", u.MXID, beforeTS.UnixMilli())
  114. if err != nil {
  115. u.log.Errorln("Failed to prune user guild list:", err)
  116. panic(err)
  117. }
  118. }
  119. func (u *User) Scan(row dbutil.Scannable) *User {
  120. var discordID, managementRoom, spaceRoom, discordToken sql.NullString
  121. err := row.Scan(&u.MXID, &discordID, &discordToken, &managementRoom, &spaceRoom)
  122. if err != nil {
  123. if err != sql.ErrNoRows {
  124. u.log.Errorln("Database scan failed:", err)
  125. panic(err)
  126. }
  127. return nil
  128. }
  129. u.DiscordID = discordID.String
  130. u.DiscordToken = discordToken.String
  131. u.ManagementRoom = id.RoomID(managementRoom.String)
  132. u.SpaceRoom = id.RoomID(spaceRoom.String)
  133. return u
  134. }
  135. func (u *User) Insert() {
  136. query := `INSERT INTO "user" (mxid, dcid, discord_token, management_room, space_room) VALUES ($1, $2, $3, $4, $5)`
  137. _, err := u.db.Exec(query, u.MXID, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)))
  138. if err != nil {
  139. u.log.Warnfln("Failed to insert %s: %v", u.MXID, err)
  140. panic(err)
  141. }
  142. }
  143. func (u *User) Update() {
  144. query := `UPDATE "user" SET dcid=$1, discord_token=$2, management_room=$3, space_room=$4 WHERE mxid=$5`
  145. _, err := u.db.Exec(query, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)), u.MXID)
  146. if err != nil {
  147. u.log.Warnfln("Failed to update %q: %v", u.MXID, err)
  148. panic(err)
  149. }
  150. }