migrate.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. package database
  2. import (
  3. "fmt"
  4. "math"
  5. "strings"
  6. )
  7. func countRows(db *Database, table string) (int, error) {
  8. countRow := db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", table))
  9. var count int
  10. err := countRow.Scan(&count)
  11. return count, err
  12. }
  13. const VariableCountLimit = 512
  14. func migrateTable(old *Database, new *Database, table string, columns ...string) error {
  15. columnNames := strings.Join(columns, ",")
  16. fmt.Printf("Migrating %s: ", table)
  17. rowCount, err := countRows(old, table)
  18. if err != nil {
  19. return err
  20. }
  21. fmt.Print("found ", rowCount, " rows of data, ")
  22. rows, err := old.Query(fmt.Sprintf("SELECT %s FROM \"%s\"", columnNames, table))
  23. if err != nil {
  24. return err
  25. }
  26. serverColNames, err := rows.Columns()
  27. if err != nil {
  28. return err
  29. }
  30. colCount := len(serverColNames)
  31. valueStringFormat := strings.Repeat("$%d, ", colCount)
  32. valueStringFormat = fmt.Sprintf("(%s)", valueStringFormat[:len(valueStringFormat)-2])
  33. cols := make([]interface{}, colCount)
  34. colPtrs := make([]interface{}, colCount)
  35. for i := 0; i < colCount; i++ {
  36. colPtrs[i] = &cols[i]
  37. }
  38. batchSize := VariableCountLimit / colCount
  39. values := make([]interface{}, batchSize*colCount)
  40. valueStrings := make([]string, batchSize)
  41. var inserted int64
  42. batchCount := int(math.Ceil(float64(rowCount) / float64(batchSize)))
  43. tx, err := new.Begin()
  44. if err != nil {
  45. return err
  46. }
  47. fmt.Printf("migrating in %d batches: ", batchCount)
  48. for rowCount > 0 {
  49. var i int
  50. for ; rows.Next() && i < batchSize; i++ {
  51. colPtrs := make([]interface{}, colCount)
  52. valueStringArgs := make([]interface{}, colCount)
  53. for j := 0; j < colCount; j++ {
  54. pos := i*colCount + j
  55. colPtrs[j] = &values[pos]
  56. valueStringArgs[j] = pos + 1
  57. }
  58. valueStrings[i] = fmt.Sprintf(valueStringFormat, valueStringArgs...)
  59. err = rows.Scan(colPtrs...)
  60. if err != nil {
  61. panic(err)
  62. }
  63. }
  64. slicedValues := values
  65. slicedValueStrings := valueStrings
  66. if i < len(valueStrings) {
  67. slicedValueStrings = slicedValueStrings[:i]
  68. slicedValues = slicedValues[:i*colCount]
  69. }
  70. res, err := tx.Exec(fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES %s", table, columnNames, strings.Join(slicedValueStrings, ",")), slicedValues...)
  71. if err != nil {
  72. panic(err)
  73. }
  74. count, _ := res.RowsAffected()
  75. inserted += count
  76. rowCount -= batchSize
  77. fmt.Print("#")
  78. }
  79. err = tx.Commit()
  80. if err != nil {
  81. return err
  82. }
  83. fmt.Println(" -- done with", inserted, "rows inserted")
  84. return nil
  85. }
  86. func Migrate(old *Database, new *Database) {
  87. err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url", "encrypted")
  88. if err != nil {
  89. panic(err)
  90. }
  91. err = migrateTable(old, new, "user", "mxid", "jid", "management_room", "client_id", "client_token", "server_token", "enc_key", "mac_key", "last_connection")
  92. if err != nil {
  93. panic(err)
  94. }
  95. err = migrateTable(old, new, "puppet", "jid", "avatar", "displayname", "name_quality", "custom_mxid", "access_token", "next_batch", "avatar_url")
  96. if err != nil {
  97. panic(err)
  98. }
  99. err = migrateTable(old, new, "user_portal", "user_jid", "portal_jid", "portal_receiver", "in_community")
  100. if err != nil {
  101. panic(err)
  102. }
  103. err = migrateTable(old, new, "message", "chat_jid", "chat_receiver", "jid", "mxid", "sender", "content", "timestamp")
  104. if err != nil {
  105. panic(err)
  106. }
  107. err = migrateTable(old, new, "mx_registrations", "user_id")
  108. if err != nil {
  109. panic(err)
  110. }
  111. err = migrateTable(old, new, "mx_user_profile", "room_id", "user_id", "membership")
  112. if err != nil {
  113. panic(err)
  114. }
  115. err = migrateTable(old, new, "mx_room_state", "room_id", "power_levels")
  116. if err != nil {
  117. panic(err)
  118. }
  119. err = migrateTable(old, new, "crypto_account", "device_id", "shared", "sync_token", "account")
  120. if err != nil {
  121. panic(err)
  122. }
  123. err = migrateTable(old, new, "crypto_message_index", "sender_key", "session_id", `"index"`, "event_id", "timestamp")
  124. if err != nil {
  125. panic(err)
  126. }
  127. err = migrateTable(old, new, "crypto_tracked_user", "user_id")
  128. if err != nil {
  129. panic(err)
  130. }
  131. err = migrateTable(old, new, "crypto_device", "user_id", "device_id", "identity_key", "signing_key", "trust", "deleted", "name")
  132. if err != nil {
  133. panic(err)
  134. }
  135. err = migrateTable(old, new, "crypto_olm_session", "session_id", "sender_key", "session", "created_at", "last_used")
  136. if err != nil {
  137. panic(err)
  138. }
  139. err = migrateTable(old, new, "crypto_megolm_inbound_session", "session_id", "sender_key", "signing_key", "room_id", "session", "forwarding_chains")
  140. if err != nil {
  141. panic(err)
  142. }
  143. err = migrateTable(old, new, "crypto_megolm_outbound_session", "room_id", "session_id", "session", "shared", "max_messages", "message_count", "max_age", "created_at", "last_used")
  144. if err != nil {
  145. panic(err)
  146. }
  147. }