migrations.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package migrations
  2. import (
  3. "database/sql"
  4. "embed"
  5. "fmt"
  6. "sort"
  7. "github.com/lopezator/migrator"
  8. log "maunium.net/go/maulogger/v2"
  9. )
  10. //go:embed *.sql
  11. var embeddedMigrations embed.FS
  // commonMigrations lists the embedded SQL scripts applied on every database
  // dialect. Names are zero-padded numeric prefixes so that Run can order them
  // with a plain string sort.
  var commonMigrations = []string{
  	"01-initial.sql",
  	"02-attachments.sql",
  	"03-emoji.sql",
  	"04-custom-puppet.sql",
  	"05-additional-puppet-fields.sql",
  }

  // sqliteMigrations lists scripts Run adds only when the dialect is "sqlite3".
  var sqliteMigrations = []string{
  	"06-remove-unique-user-constraint.sqlite.sql",
  }

  // postgresMigrations lists scripts Run adds only when the dialect is "postgres".
  var postgresMigrations = []string{
  	"06-remove-unique-user-constraint.postgres.sql",
  }
  27. func migrationFromFile(filename string) *migrator.Migration {
  28. return &migrator.Migration{
  29. Name: filename,
  30. Func: func(tx *sql.Tx) error {
  31. data, err := embeddedMigrations.ReadFile(filename)
  32. if err != nil {
  33. return err
  34. }
  35. if _, err := tx.Exec(string(data)); err != nil {
  36. return err
  37. }
  38. return nil
  39. },
  40. }
  41. }
  42. func Run(db *sql.DB, baseLog log.Logger, dialect string) error {
  43. subLogger := baseLog.Sub("Migrations")
  44. logger := migrator.LoggerFunc(func(msg string, args ...interface{}) {
  45. subLogger.Infof(msg, args...)
  46. })
  47. migrationNames := commonMigrations
  48. switch dialect {
  49. case "sqlite3":
  50. migrationNames = append(migrationNames, sqliteMigrations...)
  51. case "postgres":
  52. migrationNames = append(migrationNames, postgresMigrations...)
  53. }
  54. sort.Strings(migrationNames)
  55. migrations := make([]interface{}, len(migrationNames))
  56. for idx, name := range migrationNames {
  57. fmt.Printf("migration: %s\n", name)
  58. migrations[idx] = migrationFromFile(name)
  59. }
  60. fmt.Printf("migrations(%d)\n", len(migrations))
  61. m, err := migrator.New(
  62. migrator.TableName("version"),
  63. migrator.WithLogger(logger),
  64. migrator.Migrations(migrations...),
  65. )
  66. if err != nil {
  67. return err
  68. }
  69. if err := m.Migrate(db); err != nil {
  70. return err
  71. }
  72. return nil
  73. }