upgrades.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. package upgrades
  import (
	"database/sql"
	"errors"
	"fmt"
	"strings"

	log "maunium.net/go/maulogger/v2"
)
  // Dialect selects which SQL database flavour the upgrade steps target.
type Dialect int

// Supported dialects. Run maps the driver names "postgres" and "sqlite3"
// onto these values.
const (
	Postgres Dialect = 0
	SQLite   Dialect = 1
)
  13. type upgradeFunc func(Dialect, *sql.Tx, *sql.DB) error
  14. type upgrade struct {
  15. message string
  16. fn upgradeFunc
  17. }
  18. const NumberOfUpgrades = 9
  19. var upgrades [NumberOfUpgrades]upgrade
  20. var UnsupportedDatabaseVersion = fmt.Errorf("unsupported database version")
  21. func getVersion(dialect Dialect, db *sql.DB) (int, error) {
  22. _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
  23. if err != nil {
  24. return -1, err
  25. }
  26. version := 0
  27. row := db.QueryRow("SELECT version FROM version LIMIT 1")
  28. if row != nil {
  29. _ = row.Scan(&version)
  30. }
  31. return version, nil
  32. }
  33. func setVersion(dialect Dialect, tx *sql.Tx, version int) error {
  34. _, err := tx.Exec("DELETE FROM version")
  35. if err != nil {
  36. return err
  37. }
  38. _, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
  39. return err
  40. }
  41. func Run(log log.Logger, dialectName string, db *sql.DB) error {
  42. var dialect Dialect
  43. switch strings.ToLower(dialectName) {
  44. case "postgres":
  45. dialect = Postgres
  46. case "sqlite3":
  47. dialect = SQLite
  48. default:
  49. return fmt.Errorf("unknown dialect %s", dialectName)
  50. }
  51. version, err := getVersion(dialect, db)
  52. if err != nil {
  53. return err
  54. }
  55. if version > NumberOfUpgrades {
  56. return UnsupportedDatabaseVersion
  57. }
  58. log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
  59. for i, upgrade := range upgrades[version:] {
  60. log.Infofln("Upgrading database to v%d: %s", version+i+1, upgrade.message)
  61. tx, err := db.Begin()
  62. if err != nil {
  63. return err
  64. }
  65. err = upgrade.fn(dialect, tx, db)
  66. if err != nil {
  67. return err
  68. }
  69. err = setVersion(dialect, tx, version+i+1)
  70. if err != nil {
  71. return err
  72. }
  73. err = tx.Commit()
  74. if err != nil {
  75. return err
  76. }
  77. }
  78. return nil
  79. }