Browse Source

Add way to migrate database

Tulir Asokan 5 years ago
parent
commit
0f36ee0168
3 changed files with 160 additions and 7 deletions
  1. 124 0
      database/migrate.go
  2. 5 6
      database/upgrades/2019-08-25-move-state-store-to-db.go
  3. 31 1
      main.go

+ 124 - 0
database/migrate.go

@@ -0,0 +1,124 @@
+package database
+
+import (
+	"fmt"
+	"math"
+	"strings"
+)
+
+func countRows(db *Database, table string) (int, error) {
+	countRow := db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", table))
+	var count int
+	err := countRow.Scan(&count)
+	return count, err
+}
+
+const VariableCountLimit = 512
+
+func migrateTable(old *Database, new *Database, table string, columns ...string) error {
+	columnNames := strings.Join(columns, ",")
+	fmt.Printf("Migrating %s: ", table)
+	rowCount, err := countRows(old, table)
+	if err != nil {
+		return err
+	}
+	fmt.Print("found ", rowCount, " rows of data, ")
+	rows, err := old.Query(fmt.Sprintf("SELECT %s FROM \"%s\"", columnNames, table))
+	if err != nil {
+		return err
+	}
+	serverColNames, err := rows.Columns()
+	if err != nil {
+		return err
+	}
+	colCount := len(serverColNames)
+	valueStringFormat := strings.Repeat("$%d, ", colCount)
+	valueStringFormat = fmt.Sprintf("(%s)", valueStringFormat[:len(valueStringFormat)-2])
+	cols := make([]interface{}, colCount)
+	colPtrs := make([]interface{}, colCount)
+	for i := 0; i < colCount; i++ {
+		colPtrs[i] = &cols[i]
+	}
+	batchSize := VariableCountLimit / colCount
+	values := make([]interface{}, batchSize*colCount)
+	valueStrings := make([]string, batchSize)
+	var inserted int64
+	batchCount := int(math.Ceil(float64(rowCount) / float64(batchSize)))
+	tx, err := new.Begin()
+	if err != nil {
+		return err
+	}
+	fmt.Printf("migrating in %d batches: ", batchCount)
+	for rowCount > 0 {
+		var i int
+		for ; rows.Next() && i < batchSize; i++ {
+			colPtrs := make([]interface{}, colCount)
+			valueStringArgs := make([]interface{}, colCount)
+			for j := 0; j < colCount; j++ {
+				pos := i*colCount + j
+				colPtrs[j] = &values[pos]
+				valueStringArgs[j] = pos + 1
+			}
+			valueStrings[i] = fmt.Sprintf(valueStringFormat, valueStringArgs...)
+			err = rows.Scan(colPtrs...)
+			if err != nil {
+				panic(err)
+			}
+		}
+		slicedValues := values
+		slicedValueStrings := valueStrings
+		if i < len(valueStrings) {
+			slicedValueStrings = slicedValueStrings[:i]
+			slicedValues = slicedValues[:i*colCount]
+		}
+		res, err := tx.Exec(fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES %s", table, columnNames, strings.Join(slicedValueStrings, ",")), slicedValues...)
+		if err != nil {
+			panic(err)
+		}
+		count, _ := res.RowsAffected()
+		inserted += count
+		rowCount -= batchSize
+		fmt.Print("#")
+	}
+	err = tx.Commit()
+	if err != nil {
+		return err
+	}
+	fmt.Println(" -- done with", inserted, "rows inserted")
+	return nil
+}
+
+func Migrate(old *Database, new *Database) {
+	err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "user", "mxid", "jid", "management_room", "client_id", "client_token", "server_token", "enc_key", "mac_key", "last_connection")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "puppet", "jid", "avatar", "displayname", "name_quality", "custom_mxid", "access_token", "next_batch", "avatar_url")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "user_portal", "user_jid", "portal_jid", "portal_receiver", "in_community")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "message", "chat_jid", "chat_receiver", "jid", "mxid", "sender", "content", "timestamp")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "mx_registrations", "user_id")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "mx_user_profile", "room_id", "user_id", "membership")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "mx_room_state", "room_id", "power_levels")
+	if err != nil {
+		panic(err)
+	}
+}

+ 5 - 6
database/upgrades/2019-08-25-move-state-store-to-db.go

@@ -86,17 +86,16 @@ func init() {
 			roomStateTable = strings.Replace(roomStateTable, "TEXT", "JSONB", 1)
 		}
 
-		if data, err := ioutil.ReadFile("mx-state.json"); err != nil {
-			ctx.log.Debugln("mx-state.json not found, not migrating state store")
-			return nil
-		} else if err = json.Unmarshal(data, &store); err != nil {
-			return err
-		} else if _, err := tx.Exec(userProfileTable); err != nil {
+		if _, err := tx.Exec(userProfileTable); err != nil {
 			return err
 		} else if _, err = tx.Exec(roomStateTable); err != nil {
 			return err
 		} else if _, err = tx.Exec(registrationsTable); err != nil {
 			return err
+		} else if data, err := ioutil.ReadFile("mx-state.json"); err != nil {
+			ctx.log.Debugln("mx-state.json not found, not migrating state store")
+		} else if err = json.Unmarshal(data, &store); err != nil {
+			return err
 		} else if err = migrateRegistrations(tx, store.Registrations); err != nil {
 			return err
 		} else if err = migrateMemberships(tx, store.Memberships); err != nil {

+ 31 - 1
main.go

@@ -43,6 +43,7 @@ var configPath = flag.MakeFull("c", "config", "The path to your config file.", "
 var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
 var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
 var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if database is too new").Default("false").Bool()
+var migrateFrom = flag.Make().LongKey("migrate-db").Usage("Source database type and URI to migrate from.").Bool()
 var wantHelp, _ = flag.MakeHelpFlag()
 
 func (bridge *Bridge) GenerateRegistration() {
@@ -67,6 +68,32 @@ func (bridge *Bridge) GenerateRegistration() {
 	os.Exit(0)
 }
 
+func (bridge *Bridge) MigrateDatabase() {
+	oldDB, err := database.New(flag.Arg(0), flag.Arg(1))
+	if err != nil {
+		fmt.Println("Failed to open old database:", err)
+		os.Exit(30)
+	}
+	err = oldDB.Init()
+	if err != nil {
+		fmt.Println("Failed to upgrade old database:", err)
+		os.Exit(31)
+	}
+
+	newDB, err := database.New(bridge.Config.AppService.Database.Type, bridge.Config.AppService.Database.URI)
+	if err != nil {
+		bridge.Log.Fatalln("Failed to open new database:", err)
+		os.Exit(32)
+	}
+	err = newDB.Init()
+	if err != nil {
+		fmt.Println("Failed to upgrade new database:", err)
+		os.Exit(33)
+	}
+
+	database.Migrate(oldDB, newDB)
+}
+
 type Bridge struct {
 	AS             *appservice.AppService
 	EventProcessor *appservice.EventProcessor
@@ -265,6 +292,9 @@ func (bridge *Bridge) Main() {
 	if *generateRegistration {
 		bridge.GenerateRegistration()
 		return
+	} else if *migrateFrom {
+		bridge.MigrateDatabase()
+		return
 	}
 
 	bridge.Init()
@@ -285,7 +315,7 @@ func (bridge *Bridge) Main() {
 func main() {
 	flag.SetHelpTitles(
 		"mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.",
-		"mautrix-whatsapp [-h] [-c <path>] [-r <path>] [-g]")
+		"mautrix-whatsapp [-h] [-c <path>] [-r <path>] [-g] [--migrate-db <source type> <source uri>]")
 	err := flag.Parse()
 	if err != nil {
 		fmt.Fprintln(os.Stderr, err)