migrations.go 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package migrations
import (
	"database/sql"
	"embed"
	"fmt"
	"sort"

	"github.com/lopezator/migrator"
	log "maunium.net/go/maulogger/v2"
)
// embeddedMigrations holds every *.sql file in this package's directory,
// compiled into the binary. migrationFromFile reads migration bodies from
// it by file name.
//
//go:embed *.sql
var embeddedMigrations embed.FS
  11. var (
  12. commonMigrations = []string{
  13. "01-initial.sql",
  14. "02-attachments.sql",
  15. "03-emoji.sql",
  16. "04-custom-puppet.sql",
  17. "05-additional-puppet-fields.sql",
  18. }
  19. sqliteMigrations = []string{
  20. "06-remove-unique-user-constraint.sqlite.sql",
  21. }
  22. postgresMigrations = []string{
  23. "06-remove-unique-user-constraint.postgres.sql",
  24. }
  25. )
  26. func migrationFromFile(filename string) *migrator.Migration {
  27. return &migrator.Migration{
  28. Name: filename,
  29. Func: func(tx *sql.Tx) error {
  30. data, err := embeddedMigrations.ReadFile(filename)
  31. if err != nil {
  32. return err
  33. }
  34. if _, err := tx.Exec(string(data)); err != nil {
  35. return err
  36. }
  37. return nil
  38. },
  39. }
  40. }
  41. func Run(db *sql.DB, baseLog log.Logger, dialect string) error {
  42. subLogger := baseLog.Sub("Migrations")
  43. logger := migrator.LoggerFunc(func(msg string, args ...interface{}) {
  44. subLogger.Infof(msg, args...)
  45. })
  46. migrationNames := commonMigrations
  47. switch dialect {
  48. case "sqlite3":
  49. migrationNames = append(migrationNames, sqliteMigrations...)
  50. case "postgres":
  51. migrationNames = append(migrationNames, postgresMigrations...)
  52. }
  53. sort.Strings(migrationNames)
  54. migrations := make([]interface{}, len(migrationNames))
  55. for idx, name := range migrationNames {
  56. migrations[idx] = migrationFromFile(name)
  57. }
  58. m, err := migrator.New(
  59. migrator.TableName("version"),
  60. migrator.WithLogger(logger),
  61. migrator.Migrations(migrations...),
  62. )
  63. if err != nil {
  64. return err
  65. }
  66. if err := m.Migrate(db); err != nil {
  67. return err
  68. }
  69. return nil
  70. }