upgrades.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package upgrades
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. log "maunium.net/go/maulogger/v2"
  8. )
  9. type Dialect int
  10. const (
  11. Postgres Dialect = iota
  12. SQLite
  13. )
  14. func (dialect Dialect) String() string {
  15. switch dialect {
  16. case Postgres:
  17. return "postgres"
  18. case SQLite:
  19. return "sqlite3"
  20. default:
  21. return ""
  22. }
  23. }
  24. type upgradeFunc func(*sql.Tx, context) error
  25. type context struct {
  26. dialect Dialect
  27. db *sql.DB
  28. log log.Logger
  29. }
  30. type upgrade struct {
  31. message string
  32. fn upgradeFunc
  33. }
  34. const NumberOfUpgrades = 46
  35. var upgrades [NumberOfUpgrades]upgrade
  36. var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
  37. var ErrForeignTables = fmt.Errorf("the database contains foreign tables")
  38. var ErrNotOwned = fmt.Errorf("the database is owned by")
  39. var IgnoreForeignTables = false
  40. const databaseOwner = "mautrix-whatsapp"
  41. func GetVersion(db *sql.DB) (int, error) {
  42. _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
  43. if err != nil {
  44. return -1, err
  45. }
  46. version := 0
  47. err = db.QueryRow("SELECT version FROM version LIMIT 1").Scan(&version)
  48. if err != nil && !errors.Is(err, sql.ErrNoRows) {
  49. return -1, err
  50. }
  51. return version, nil
  52. }
  53. const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
  54. const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND table_name=$1)"
  55. func tableExists(dialect Dialect, db *sql.DB, table string) (exists bool) {
  56. if dialect == SQLite {
  57. _ = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
  58. } else if dialect == Postgres {
  59. _ = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
  60. }
  61. return
  62. }
  63. const createOwnerTable = `
  64. CREATE TABLE IF NOT EXISTS database_owner (
  65. key INTEGER PRIMARY KEY DEFAULT 0,
  66. owner TEXT NOT NULL
  67. )
  68. `
  69. func CheckDatabaseOwner(dialect Dialect, db *sql.DB) error {
  70. var owner string
  71. if !IgnoreForeignTables {
  72. if tableExists(dialect, db, "state_groups_state") {
  73. return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
  74. } else if tableExists(dialect, db, "goose_db_version") {
  75. return fmt.Errorf("%w (found goose_db_version, possibly belonging to Dendrite)", ErrForeignTables)
  76. }
  77. }
  78. if _, err := db.Exec(createOwnerTable); err != nil {
  79. return fmt.Errorf("failed to ensure database owner table exists: %w", err)
  80. } else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
  81. _, err = db.Exec("INSERT INTO database_owner (owner) VALUES ($1)", databaseOwner)
  82. if err != nil {
  83. return fmt.Errorf("failed to insert database owner: %w", err)
  84. }
  85. } else if err != nil {
  86. return fmt.Errorf("failed to check database owner: %w", err)
  87. } else if owner != databaseOwner {
  88. return fmt.Errorf("%w %s", ErrNotOwned, owner)
  89. }
  90. return nil
  91. }
  92. func SetVersion(tx *sql.Tx, version int) error {
  93. _, err := tx.Exec("DELETE FROM version")
  94. if err != nil {
  95. return err
  96. }
  97. _, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
  98. return err
  99. }
  100. func execMany(tx *sql.Tx, queries ...string) error {
  101. for _, query := range queries {
  102. _, err := tx.Exec(query)
  103. if err != nil {
  104. return err
  105. }
  106. }
  107. return nil
  108. }
  109. func Run(log log.Logger, dialectName string, db *sql.DB) error {
  110. var dialect Dialect
  111. switch strings.ToLower(dialectName) {
  112. case "postgres":
  113. dialect = Postgres
  114. case "sqlite3":
  115. dialect = SQLite
  116. default:
  117. return fmt.Errorf("unknown dialect %s", dialectName)
  118. }
  119. err := CheckDatabaseOwner(dialect, db)
  120. if err != nil {
  121. return err
  122. }
  123. version, err := GetVersion(db)
  124. if err != nil {
  125. return err
  126. }
  127. if version > NumberOfUpgrades {
  128. return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, NumberOfUpgrades)
  129. }
  130. log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
  131. for i, upgradeItem := range upgrades[version:] {
  132. if upgradeItem.fn == nil {
  133. continue
  134. }
  135. log.Infofln("Upgrading database to v%d: %s", version+i+1, upgradeItem.message)
  136. var tx *sql.Tx
  137. tx, err = db.Begin()
  138. if err != nil {
  139. return err
  140. }
  141. err = upgradeItem.fn(tx, context{dialect, db, log})
  142. if err != nil {
  143. return err
  144. }
  145. err = SetVersion(tx, version+i+1)
  146. if err != nil {
  147. return err
  148. }
  149. err = tx.Commit()
  150. if err != nil {
  151. return err
  152. }
  153. }
  154. return nil
  155. }