upgrades.go 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package upgrades
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. log "maunium.net/go/maulogger/v2"
  7. )
  // Dialect identifies the SQL database flavor the upgrades are run against.
  // Run derives it from the dialect name it is given and passes it to every
  // upgrade function.
  type Dialect int

  // Supported dialects. Postgres is 0 and SQLite is 1.
  const (
  	Postgres Dialect = 0
  	SQLite   Dialect = 1
  )
  13. type upgradeFunc func(Dialect, *sql.Tx) error
  14. type upgrade struct {
  15. message string
  16. fn upgradeFunc
  17. }
  18. var upgrades [2]upgrade
  19. func getVersion(dialect Dialect, db *sql.DB) (int, error) {
  20. _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
  21. if err != nil {
  22. return -1, err
  23. }
  24. version := 0
  25. row := db.QueryRow("SELECT version FROM version LIMIT 1")
  26. if row != nil {
  27. _ = row.Scan(&version)
  28. }
  29. return version, nil
  30. }
  31. func setVersion(dialect Dialect, tx *sql.Tx, version int) error {
  32. _, err := tx.Exec("DELETE FROM version")
  33. if err != nil {
  34. return err
  35. }
  36. _, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
  37. return err
  38. }
  39. func Run(log log.Logger, dialectName string, db *sql.DB) error {
  40. var dialect Dialect
  41. switch strings.ToLower(dialectName) {
  42. case "postgres":
  43. dialect = Postgres
  44. case "sqlite3":
  45. dialect = SQLite
  46. default:
  47. return fmt.Errorf("unknown dialect %s", dialectName)
  48. }
  49. version, err := getVersion(dialect, db)
  50. if err != nil {
  51. return err
  52. }
  53. log.Infofln("Database currently on v%d, latest: v%d", version, len(upgrades))
  54. for i, upgrade := range upgrades[version:] {
  55. log.Infofln("Upgrading database to v%d: %s", i+1, upgrade.message)
  56. tx, err := db.Begin()
  57. if err != nil {
  58. return err
  59. }
  60. err = upgrade.fn(dialect, tx)
  61. if err != nil {
  62. return err
  63. }
  64. err = setVersion(dialect, tx, i+1)
  65. if err != nil {
  66. return err
  67. }
  68. err = tx.Commit()
  69. if err != nil {
  70. return err
  71. }
  72. }
  73. return nil
  74. }