123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- package database
- import (
- "fmt"
- "math"
- "strings"
- )
- func countRows(db *Database, table string) (int, error) {
- countRow := db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", table))
- var count int
- err := countRow.Scan(&count)
- return count, err
- }
const (
	// VariableCountLimit is the maximum number of bind variables (placeholders)
	// migrateTable puts into a single multi-row INSERT statement; the number
	// of rows per statement is this limit divided by the column count.
	// NOTE(review): the reason for 512 specifically is not visible here —
	// presumably a conservative bound below the target driver's limit; confirm.
	VariableCountLimit = 512
)
- func migrateTable(old *Database, new *Database, table string, columns ...string) error {
- columnNames := strings.Join(columns, ",")
- fmt.Printf("Migrating %s: ", table)
- rowCount, err := countRows(old, table)
- if err != nil {
- return err
- }
- fmt.Print("found ", rowCount, " rows of data, ")
- rows, err := old.Query(fmt.Sprintf("SELECT %s FROM \"%s\"", columnNames, table))
- if err != nil {
- return err
- }
- serverColNames, err := rows.Columns()
- if err != nil {
- return err
- }
- colCount := len(serverColNames)
- valueStringFormat := strings.Repeat("$%d, ", colCount)
- valueStringFormat = fmt.Sprintf("(%s)", valueStringFormat[:len(valueStringFormat)-2])
- cols := make([]interface{}, colCount)
- colPtrs := make([]interface{}, colCount)
- for i := 0; i < colCount; i++ {
- colPtrs[i] = &cols[i]
- }
- batchSize := VariableCountLimit / colCount
- values := make([]interface{}, batchSize*colCount)
- valueStrings := make([]string, batchSize)
- var inserted int64
- batchCount := int(math.Ceil(float64(rowCount) / float64(batchSize)))
- tx, err := new.Begin()
- if err != nil {
- return err
- }
- fmt.Printf("migrating in %d batches: ", batchCount)
- for rowCount > 0 {
- var i int
- for ; rows.Next() && i < batchSize; i++ {
- colPtrs := make([]interface{}, colCount)
- valueStringArgs := make([]interface{}, colCount)
- for j := 0; j < colCount; j++ {
- pos := i*colCount + j
- colPtrs[j] = &values[pos]
- valueStringArgs[j] = pos + 1
- }
- valueStrings[i] = fmt.Sprintf(valueStringFormat, valueStringArgs...)
- err = rows.Scan(colPtrs...)
- if err != nil {
- panic(err)
- }
- }
- slicedValues := values
- slicedValueStrings := valueStrings
- if i < len(valueStrings) {
- slicedValueStrings = slicedValueStrings[:i]
- slicedValues = slicedValues[:i*colCount]
- }
- res, err := tx.Exec(fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES %s", table, columnNames, strings.Join(slicedValueStrings, ",")), slicedValues...)
- if err != nil {
- panic(err)
- }
- count, _ := res.RowsAffected()
- inserted += count
- rowCount -= batchSize
- fmt.Print("#")
- }
- err = tx.Commit()
- if err != nil {
- return err
- }
- fmt.Println(" -- done with", inserted, "rows inserted")
- return nil
- }
- func Migrate(old *Database, new *Database) {
- err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url", "encrypted")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "user", "mxid", "jid", "management_room", "client_id", "client_token", "server_token", "enc_key", "mac_key", "last_connection")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "puppet", "jid", "avatar", "displayname", "name_quality", "custom_mxid", "access_token", "next_batch", "avatar_url")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "user_portal", "user_jid", "portal_jid", "portal_receiver", "in_community")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "message", "chat_jid", "chat_receiver", "jid", "mxid", "sender", "content", "timestamp")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "mx_registrations", "user_id")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "mx_user_profile", "room_id", "user_id", "membership")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "mx_room_state", "room_id", "power_levels")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "crypto_account", "device_id", "shared", "sync_token", "account")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "crypto_message_index", "sender_key", "session_id", `"index"`, "event_id", "timestamp")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "crypto_tracked_user", "user_id")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "crypto_device", "user_id", "device_id", "identity_key", "signing_key", "trust", "deleted", "name")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "crypto_olm_session", "session_id", "sender_key", "session", "created_at", "last_used")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "crypto_megolm_inbound_session", "session_id", "sender_key", "signing_key", "room_id", "session", "forwarding_chains")
- if err != nil {
- panic(err)
- }
- err = migrateTable(old, new, "crypto_megolm_outbound_session", "room_id", "session_id", "session", "shared", "max_messages", "message_count", "max_age", "created_at", "last_used")
- if err != nil {
- panic(err)
- }
- }
|