migrate.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package database
  2. import (
  3. "fmt"
  4. "math"
  5. "strings"
  6. )
  7. func countRows(db *Database, table string) (int, error) {
  8. countRow := db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", table))
  9. var count int
  10. err := countRow.Scan(&count)
  11. return count, err
  12. }
  13. const VariableCountLimit = 512
  14. func migrateTable(old *Database, new *Database, table string, columns ...string) error {
  15. columnNames := strings.Join(columns, ",")
  16. fmt.Printf("Migrating %s: ", table)
  17. rowCount, err := countRows(old, table)
  18. if err != nil {
  19. return err
  20. }
  21. fmt.Print("found ", rowCount, " rows of data, ")
  22. rows, err := old.Query(fmt.Sprintf("SELECT %s FROM \"%s\"", columnNames, table))
  23. if err != nil {
  24. return err
  25. }
  26. serverColNames, err := rows.Columns()
  27. if err != nil {
  28. return err
  29. }
  30. colCount := len(serverColNames)
  31. valueStringFormat := strings.Repeat("$%d, ", colCount)
  32. valueStringFormat = fmt.Sprintf("(%s)", valueStringFormat[:len(valueStringFormat)-2])
  33. cols := make([]interface{}, colCount)
  34. colPtrs := make([]interface{}, colCount)
  35. for i := 0; i < colCount; i++ {
  36. colPtrs[i] = &cols[i]
  37. }
  38. batchSize := VariableCountLimit / colCount
  39. values := make([]interface{}, batchSize*colCount)
  40. valueStrings := make([]string, batchSize)
  41. var inserted int64
  42. batchCount := int(math.Ceil(float64(rowCount) / float64(batchSize)))
  43. tx, err := new.Begin()
  44. if err != nil {
  45. return err
  46. }
  47. fmt.Printf("migrating in %d batches: ", batchCount)
  48. for rowCount > 0 {
  49. var i int
  50. for ; rows.Next() && i < batchSize; i++ {
  51. colPtrs := make([]interface{}, colCount)
  52. valueStringArgs := make([]interface{}, colCount)
  53. for j := 0; j < colCount; j++ {
  54. pos := i*colCount + j
  55. colPtrs[j] = &values[pos]
  56. valueStringArgs[j] = pos + 1
  57. }
  58. valueStrings[i] = fmt.Sprintf(valueStringFormat, valueStringArgs...)
  59. err = rows.Scan(colPtrs...)
  60. if err != nil {
  61. panic(err)
  62. }
  63. }
  64. slicedValues := values
  65. slicedValueStrings := valueStrings
  66. if i < len(valueStrings) {
  67. slicedValueStrings = slicedValueStrings[:i]
  68. slicedValues = slicedValues[:i*colCount]
  69. }
  70. res, err := tx.Exec(fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES %s", table, columnNames, strings.Join(slicedValueStrings, ",")), slicedValues...)
  71. if err != nil {
  72. panic(err)
  73. }
  74. count, _ := res.RowsAffected()
  75. inserted += count
  76. rowCount -= batchSize
  77. fmt.Print("#")
  78. }
  79. err = tx.Commit()
  80. if err != nil {
  81. return err
  82. }
  83. fmt.Println(" -- done with", inserted, "rows inserted")
  84. return nil
  85. }
  86. func Migrate(old *Database, new *Database) {
  87. err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url")
  88. if err != nil {
  89. panic(err)
  90. }
  91. err = migrateTable(old, new, "user", "mxid", "jid", "management_room", "client_id", "client_token", "server_token", "enc_key", "mac_key", "last_connection")
  92. if err != nil {
  93. panic(err)
  94. }
  95. err = migrateTable(old, new, "puppet", "jid", "avatar", "displayname", "name_quality", "custom_mxid", "access_token", "next_batch", "avatar_url")
  96. if err != nil {
  97. panic(err)
  98. }
  99. err = migrateTable(old, new, "user_portal", "user_jid", "portal_jid", "portal_receiver", "in_community")
  100. if err != nil {
  101. panic(err)
  102. }
  103. err = migrateTable(old, new, "message", "chat_jid", "chat_receiver", "jid", "mxid", "sender", "content", "timestamp")
  104. if err != nil {
  105. panic(err)
  106. }
  107. err = migrateTable(old, new, "mx_registrations", "user_id")
  108. if err != nil {
  109. panic(err)
  110. }
  111. err = migrateTable(old, new, "mx_user_profile", "room_id", "user_id", "membership")
  112. if err != nil {
  113. panic(err)
  114. }
  115. err = migrateTable(old, new, "mx_room_state", "room_id", "power_levels")
  116. if err != nil {
  117. panic(err)
  118. }
  119. }