guild.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package database
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. log "maunium.net/go/maulogger/v2"
  7. "maunium.net/go/mautrix/util/dbutil"
  8. )
  9. type GuildQuery struct {
  10. db *Database
  11. log log.Logger
  12. }
  // guildSelect is the column projection shared by every guild lookup.
// Callers append their own WHERE clause. The column order must match the
// order in which Guild.Scan reads them.
const guildSelect = "SELECT discord_id, guild_id, guild_name, bridge FROM guild"
  16. func (gq *GuildQuery) New() *Guild {
  17. return &Guild{
  18. db: gq.db,
  19. log: gq.log,
  20. }
  21. }
  22. func (gq *GuildQuery) Get(discordID, guildID string) *Guild {
  23. query := guildSelect + " WHERE discord_id=$1 AND guild_id=$2"
  24. row := gq.db.QueryRow(query, discordID, guildID)
  25. if row == nil {
  26. return nil
  27. }
  28. return gq.New().Scan(row)
  29. }
  30. func (gq *GuildQuery) GetAll(discordID string) []*Guild {
  31. query := guildSelect + " WHERE discord_id=$1"
  32. rows, err := gq.db.Query(query, discordID)
  33. if err != nil || rows == nil {
  34. return nil
  35. }
  36. guilds := []*Guild{}
  37. for rows.Next() {
  38. guilds = append(guilds, gq.New().Scan(rows))
  39. }
  40. return guilds
  41. }
  42. func (gq *GuildQuery) Prune(discordID string, guilds []string) {
  43. // We need this interface slice because a variadic function can't mix
  44. // arguements with a `...` expanded slice.
  45. args := []interface{}{discordID}
  46. nGuilds := len(guilds)
  47. if nGuilds <= 0 {
  48. return
  49. }
  50. gq.log.Debugfln("prunning guilds for %s", discordID)
  51. // Build the in query
  52. inQuery := "$2"
  53. for i := 1; i < nGuilds; i++ {
  54. inQuery += fmt.Sprintf(", $%d", i+2)
  55. }
  56. // Add the arguements for the build query
  57. for _, guildID := range guilds {
  58. args = append(args, guildID)
  59. }
  60. // Now remove any guilds that the user has left.
  61. query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id NOT IN (" +
  62. inQuery + ")"
  63. _, err := gq.db.Exec(query, args...)
  64. if err != nil {
  65. gq.log.Warnfln("Failed to remove old guilds for user %s: %v", discordID, err)
  66. }
  67. }
  68. type Guild struct {
  69. db *Database
  70. log log.Logger
  71. DiscordID string
  72. GuildID string
  73. GuildName string
  74. Bridge bool
  75. }
  76. func (g *Guild) Scan(row dbutil.Scannable) *Guild {
  77. err := row.Scan(&g.DiscordID, &g.GuildID, &g.GuildName, &g.Bridge)
  78. if err != nil {
  79. if !errors.Is(err, sql.ErrNoRows) {
  80. g.log.Errorln("Database scan failed:", err)
  81. }
  82. return nil
  83. }
  84. return g
  85. }
  86. func (g *Guild) Upsert() {
  87. query := "INSERT INTO guild" +
  88. " (discord_id, guild_id, guild_name, bridge)" +
  89. " VALUES ($1, $2, $3, $4)" +
  90. " ON CONFLICT(discord_id, guild_id)" +
  91. " DO UPDATE SET guild_name=excluded.guild_name, bridge=excluded.bridge"
  92. _, err := g.db.Exec(query, g.DiscordID, g.GuildID, g.GuildName, g.Bridge)
  93. if err != nil {
  94. g.log.Warnfln("Failed to upsert guild %s for %s: %v", g.GuildID, g.DiscordID, err)
  95. }
  96. }
  97. func (g *Guild) Delete() {
  98. query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id=$2"
  99. _, err := g.db.Exec(query, g.DiscordID, g.GuildID)
  100. if err != nil {
  101. g.log.Warnfln("Failed to delete guild %s for user %s: %v", g.GuildID, g.DiscordID, err)
  102. }
  103. }