upgrades.go 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. package upgrades
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. log "maunium.net/go/maulogger/v2"
  7. )
  8. type Dialect int
  9. const (
  10. Postgres Dialect = iota
  11. SQLite
  12. )
  13. type upgradeFunc func(Dialect, *sql.Tx, *sql.DB) error
  14. type upgrade struct {
  15. message string
  16. fn upgradeFunc
  17. }
  18. const NumberOfUpgrades = 6
  19. var upgrades [NumberOfUpgrades]upgrade
  20. func getVersion(dialect Dialect, db *sql.DB) (int, error) {
  21. _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
  22. if err != nil {
  23. return -1, err
  24. }
  25. version := 0
  26. row := db.QueryRow("SELECT version FROM version LIMIT 1")
  27. if row != nil {
  28. _ = row.Scan(&version)
  29. }
  30. return version, nil
  31. }
  32. func setVersion(dialect Dialect, tx *sql.Tx, version int) error {
  33. _, err := tx.Exec("DELETE FROM version")
  34. if err != nil {
  35. return err
  36. }
  37. _, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
  38. return err
  39. }
  40. func Run(log log.Logger, dialectName string, db *sql.DB) error {
  41. var dialect Dialect
  42. switch strings.ToLower(dialectName) {
  43. case "postgres":
  44. dialect = Postgres
  45. case "sqlite3":
  46. dialect = SQLite
  47. default:
  48. return fmt.Errorf("unknown dialect %s", dialectName)
  49. }
  50. version, err := getVersion(dialect, db)
  51. if err != nil {
  52. return err
  53. }
  54. log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
  55. for i, upgrade := range upgrades[version:] {
  56. log.Infofln("Upgrading database to v%d: %s", version+i+1, upgrade.message)
  57. tx, err := db.Begin()
  58. if err != nil {
  59. return err
  60. }
  61. err = upgrade.fn(dialect, tx, db)
  62. if err != nil {
  63. return err
  64. }
  65. err = setVersion(dialect, tx, version+i+1)
  66. if err != nil {
  67. return err
  68. }
  69. err = tx.Commit()
  70. if err != nil {
  71. return err
  72. }
  73. }
  74. return nil
  75. }