Переглянути джерело

Fix an issue where additional users weren't being persisted

Gary Kramlich 3 роки тому
батько
коміт
44443b4079

+ 1 - 1
database/database.go

@@ -40,7 +40,7 @@ func New(dbType, uri string, maxOpenConns, maxIdleConns int, baseLog log.Logger)
 
 	dbLog := baseLog.Sub("Database")
 
-	if err := migrations.Run(conn, dbLog); err != nil {
+	if err := migrations.Run(conn, dbLog, dbType); err != nil {
 		return nil, err
 	}
 

+ 1 - 0
database/migrations/06-remove-unique-user-constraint.postgres.sql

@@ -0,0 +1 @@
+ALTER TABLE "user" DROP CONSTRAINT user_id_key;

+ 18 - 0
database/migrations/06-remove-unique-user-constraint.sqlite.sql

@@ -0,0 +1,18 @@
+PRAGMA foreign_keys=off;
+
+ALTER TABLE "user" RENAME TO "old_user";
+
+CREATE TABLE "user" (
+	mxid TEXT PRIMARY KEY,
+	id   TEXT,
+
+	management_room TEXT,
+
+	token TEXT
+);
+
+INSERT INTO "user" SELECT mxid, id, management_room, token FROM "old_user";
+
+DROP TABLE "old_user";
+
+PRAGMA foreign_keys=on;

+ 42 - 10
database/migrations/migrations.go

@@ -3,19 +3,39 @@ package migrations
 import (
 	"database/sql"
 	"embed"
+	"fmt"
+	"sort"
 
 	"github.com/lopezator/migrator"
 	log "maunium.net/go/maulogger/v2"
 )
 
 //go:embed *.sql
-var migrations embed.FS
+var embeddedMigrations embed.FS
+
+var (
+	commonMigrations = []string{
+		"01-initial.sql",
+		"02-attachments.sql",
+		"03-emoji.sql",
+		"04-custom-puppet.sql",
+		"05-additional-puppet-fields.sql",
+	}
+
+	sqliteMigrations = []string{
+		"06-remove-unique-user-constraint.sqlite.sql",
+	}
+
+	postgresMigrations = []string{
+		"06-remove-unique-user-constraint.postgres.sql",
+	}
+)
 
 func migrationFromFile(filename string) *migrator.Migration {
 	return &migrator.Migration{
 		Name: filename,
 		Func: func(tx *sql.Tx) error {
-			data, err := migrations.ReadFile(filename)
+			data, err := embeddedMigrations.ReadFile(filename)
 			if err != nil {
 				return err
 			}
@@ -29,22 +49,34 @@ func migrationFromFile(filename string) *migrator.Migration {
 	}
 }
 
-func Run(db *sql.DB, baseLog log.Logger) error {
+func Run(db *sql.DB, baseLog log.Logger, dialect string) error {
 	subLogger := baseLog.Sub("Migrations")
 	logger := migrator.LoggerFunc(func(msg string, args ...interface{}) {
 		subLogger.Infof(msg, args...)
 	})
 
+	migrationNames := commonMigrations
+	switch dialect {
+	case "sqlite3":
+		migrationNames = append(migrationNames, sqliteMigrations...)
+	case "postgres":
+		migrationNames = append(migrationNames, postgresMigrations...)
+	}
+
+	sort.Strings(migrationNames)
+
+	migrations := make([]interface{}, len(migrationNames))
+	for idx, name := range migrationNames {
+		fmt.Printf("migration: %s\n", name)
+		migrations[idx] = migrationFromFile(name)
+	}
+
+	fmt.Printf("migrations(%d)\n", len(migrations))
+
 	m, err := migrator.New(
 		migrator.TableName("version"),
 		migrator.WithLogger(logger),
-		migrator.Migrations(
-			migrationFromFile("01-initial.sql"),
-			migrationFromFile("02-attachments.sql"),
-			migrationFromFile("03-emoji.sql"),
-			migrationFromFile("04-custom-puppet.sql"),
-			migrationFromFile("05-additional-puppet-fields.sql"),
-		),
+		migrator.Migrations(migrations...),
 	)
 	if err != nil {
 		return err