1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- package upgrades
- import (
- "database/sql"
- "fmt"
- "strings"
- log "maunium.net/go/maulogger/v2"
- )
- type Dialect int
- const (
- Postgres Dialect = iota
- SQLite
- )
- type upgradeFunc func(Dialect, *sql.Tx, *sql.DB) error
- type upgrade struct {
- message string
- fn upgradeFunc
- }
- const NumberOfUpgrades = 9
- var upgrades [NumberOfUpgrades]upgrade
- var UnsupportedDatabaseVersion = fmt.Errorf("unsupported database version")
- func getVersion(dialect Dialect, db *sql.DB) (int, error) {
- _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
- if err != nil {
- return -1, err
- }
- version := 0
- row := db.QueryRow("SELECT version FROM version LIMIT 1")
- if row != nil {
- _ = row.Scan(&version)
- }
- return version, nil
- }
- func setVersion(dialect Dialect, tx *sql.Tx, version int) error {
- _, err := tx.Exec("DELETE FROM version")
- if err != nil {
- return err
- }
- _, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
- return err
- }
- func Run(log log.Logger, dialectName string, db *sql.DB) error {
- var dialect Dialect
- switch strings.ToLower(dialectName) {
- case "postgres":
- dialect = Postgres
- case "sqlite3":
- dialect = SQLite
- default:
- return fmt.Errorf("unknown dialect %s", dialectName)
- }
- version, err := getVersion(dialect, db)
- if err != nil {
- return err
- }
- if version > NumberOfUpgrades {
- return UnsupportedDatabaseVersion
- }
- log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
- for i, upgrade := range upgrades[version:] {
- log.Infofln("Upgrading database to v%d: %s", version+i+1, upgrade.message)
- tx, err := db.Begin()
- if err != nil {
- return err
- }
- err = upgrade.fn(dialect, tx, db)
- if err != nil {
- return err
- }
- err = setVersion(dialect, tx, version+i+1)
- if err != nil {
- return err
- }
- err = tx.Commit()
- if err != nil {
- return err
- }
- }
- return nil
- }
|