upgrades.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. package upgrades
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. log "maunium.net/go/maulogger/v2"
  8. )
  // Dialect identifies the SQL flavour the schema upgrades are run against.
// Run resolves it from a dialect name and hands it to every upgrade function
// through context.
type Dialect int

// Known dialects, numbered from zero in the order listed.
const (
	Postgres Dialect = iota
	SQLite
)

// String returns the dialect's name as accepted by Run ("postgres" or
// "sqlite3"), or the empty string for a value that is not a known dialect.
func (d Dialect) String() string {
	names := [...]string{Postgres: "postgres", SQLite: "sqlite3"}
	if d < Postgres || int(d) >= len(names) {
		return ""
	}
	return names[d]
}
  24. type upgradeFunc func(*sql.Tx, context) error
  25. type context struct {
  26. dialect Dialect
  27. db *sql.DB
  28. log log.Logger
  29. }
  30. type upgrade struct {
  31. message string
  32. fn upgradeFunc
  33. }
  34. const NumberOfUpgrades = 36
  35. var upgrades [NumberOfUpgrades]upgrade
  // UnsupportedDatabaseVersion is the sentinel error Run wraps when the stored
// schema version lies outside the range this build can upgrade from. Match it
// with errors.Is rather than by comparing messages.
var UnsupportedDatabaseVersion = errors.New("unsupported database schema version")
  // GetVersion reads the schema version stored in the version table, creating
// the table first if it does not exist. A table with no rows counts as
// version 0. On any failure it returns -1 together with the error.
func GetVersion(db *sql.DB) (int, error) {
	if _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)"); err != nil {
		return -1, err
	}

	var current int
	err := db.QueryRow("SELECT version FROM version LIMIT 1").Scan(&current)
	switch {
	case err == nil || errors.Is(err, sql.ErrNoRows):
		// No row: Scan leaves current untouched at its zero value.
		return current, nil
	default:
		return -1, err
	}
}
  // SetVersion records version as the current schema version inside tx. Any
// rows already in the version table are deleted first; if that fails, the
// insert is not attempted and the delete error is returned.
func SetVersion(tx *sql.Tx, version int) error {
	_, err := tx.Exec("DELETE FROM version")
	if err == nil {
		_, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
	}
	return err
}
  // execMany executes each query in order inside tx, stopping at and returning
// the first error. With no queries it does nothing and returns nil.
func execMany(tx *sql.Tx, queries ...string) error {
	for i := range queries {
		if _, err := tx.Exec(queries[i]); err != nil {
			return err
		}
	}
	return nil
}
  66. func Run(log log.Logger, dialectName string, db *sql.DB) error {
  67. var dialect Dialect
  68. switch strings.ToLower(dialectName) {
  69. case "postgres":
  70. dialect = Postgres
  71. case "sqlite3":
  72. dialect = SQLite
  73. default:
  74. return fmt.Errorf("unknown dialect %s", dialectName)
  75. }
  76. version, err := GetVersion(db)
  77. if err != nil {
  78. return err
  79. }
  80. if version > NumberOfUpgrades {
  81. return fmt.Errorf("%w: currently on v%d, latest known: v%d", UnsupportedDatabaseVersion, version, NumberOfUpgrades)
  82. }
  83. log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
  84. for i, upgradeItem := range upgrades[version:] {
  85. if upgradeItem.fn == nil {
  86. continue
  87. }
  88. log.Infofln("Upgrading database to v%d: %s", version+i+1, upgradeItem.message)
  89. var tx *sql.Tx
  90. tx, err = db.Begin()
  91. if err != nil {
  92. return err
  93. }
  94. err = upgradeItem.fn(tx, context{dialect, db, log})
  95. if err != nil {
  96. return err
  97. }
  98. err = SetVersion(tx, version+i+1)
  99. if err != nil {
  100. return err
  101. }
  102. err = tx.Commit()
  103. if err != nil {
  104. return err
  105. }
  106. }
  107. return nil
  108. }