123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- package upgrades
- import (
- "database/sql"
- "fmt"
- "strings"
- log "maunium.net/go/maulogger/v2"
- )
// Dialect identifies which SQL dialect the upgrade functions are running
// against, so that they can pick dialect-specific statements.
type Dialect int

// Known dialects. The numeric values follow declaration order, starting at 0.
const (
	// Postgres is selected by the dialect name "postgres".
	Postgres = Dialect(iota)
	// SQLite is selected by the dialect name "sqlite3".
	SQLite
)
- type upgradeFunc func(*sql.Tx, context) error
- type context struct {
- dialect Dialect
- db *sql.DB
- log log.Logger
- }
- type upgrade struct {
- message string
- fn upgradeFunc
- }
- const NumberOfUpgrades = 15
- var upgrades [NumberOfUpgrades]upgrade
- var UnsupportedDatabaseVersion = fmt.Errorf("unsupported database version")
// GetVersion returns the schema version stored in the database.
//
// The version table is created if it does not exist yet. An empty version
// table means no upgrade has been applied, which is reported as version 0.
// Any other failure (creating the table, running the query, or converting the
// stored value) is returned as an error together with -1, so that callers
// never mistake an unreadable version for a fresh database.
func GetVersion(db *sql.DB) (int, error) {
	if _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)"); err != nil {
		return -1, err
	}

	var version int
	err := db.QueryRow("SELECT version FROM version LIMIT 1").Scan(&version)
	switch {
	case err == sql.ErrNoRows:
		// No row yet: the database has never been upgraded.
		return 0, nil
	case err != nil:
		return -1, err
	}
	return version, nil
}
// SetVersion records version as the current schema version inside tx.
//
// Any existing rows in the version table are removed first and a single new
// row is inserted. The first statement that fails aborts the operation and
// its error is returned; committing or rolling back tx is left to the caller.
func SetVersion(tx *sql.Tx, version int) error {
	if _, clearErr := tx.Exec("DELETE FROM version"); clearErr != nil {
		return clearErr
	}
	_, insertErr := tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
	return insertErr
}
- func Run(log log.Logger, dialectName string, db *sql.DB) error {
- var dialect Dialect
- switch strings.ToLower(dialectName) {
- case "postgres":
- dialect = Postgres
- case "sqlite3":
- dialect = SQLite
- default:
- return fmt.Errorf("unknown dialect %s", dialectName)
- }
- version, err := GetVersion(db)
- if err != nil {
- return err
- }
- if version > NumberOfUpgrades {
- return UnsupportedDatabaseVersion
- }
- log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
- for i, upgrade := range upgrades[version:] {
- log.Infofln("Upgrading database to v%d: %s", version+i+1, upgrade.message)
- tx, err := db.Begin()
- if err != nil {
- return err
- }
- err = upgrade.fn(tx, context{dialect, db, log})
- if err != nil {
- return err
- }
- err = SetVersion(tx, version+i+1)
- if err != nil {
- return err
- }
- err = tx.Commit()
- if err != nil {
- return err
- }
- }
- return nil
- }
|