migrate.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. package database
  2. import (
  3. "fmt"
  4. "math"
  5. "strings"
  6. )
  7. func countRows(db *Database, table string) (int, error) {
  8. countRow := db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", table))
  9. var count int
  10. err := countRow.Scan(&count)
  11. return count, err
  12. }
// VariableCountLimit is the maximum number of bound values that migrateTable
// places in a single INSERT statement; the number of rows per batch is this
// limit divided by the number of columns being copied.
// NOTE(review): 512 is presumably chosen to stay below the target database's
// bound-parameter limit — confirm against the drivers in use.
const VariableCountLimit = 512
  14. func migrateTable(old *Database, new *Database, table string, columns ...string) error {
  15. columnNames := strings.Join(columns, ",")
  16. fmt.Printf("Migrating %s: ", table)
  17. rowCount, err := countRows(old, table)
  18. if err != nil {
  19. return err
  20. }
  21. fmt.Print("found ", rowCount, " rows of data, ")
  22. rows, err := old.Query(fmt.Sprintf("SELECT %s FROM \"%s\"", columnNames, table))
  23. if err != nil {
  24. return err
  25. }
  26. serverColNames, err := rows.Columns()
  27. if err != nil {
  28. return err
  29. }
  30. colCount := len(serverColNames)
  31. valueStringFormat := strings.Repeat("$%d, ", colCount)
  32. valueStringFormat = fmt.Sprintf("(%s)", valueStringFormat[:len(valueStringFormat)-2])
  33. cols := make([]interface{}, colCount)
  34. colPtrs := make([]interface{}, colCount)
  35. for i := 0; i < colCount; i++ {
  36. colPtrs[i] = &cols[i]
  37. }
  38. batchSize := VariableCountLimit / colCount
  39. values := make([]interface{}, batchSize*colCount)
  40. valueStrings := make([]string, batchSize)
  41. var inserted int64
  42. batchCount := int(math.Ceil(float64(rowCount) / float64(batchSize)))
  43. tx, err := new.Begin()
  44. if err != nil {
  45. return err
  46. }
  47. fmt.Printf("migrating in %d batches: ", batchCount)
  48. for rowCount > 0 {
  49. var i int
  50. for ; rows.Next() && i < batchSize; i++ {
  51. colPtrs := make([]interface{}, colCount)
  52. valueStringArgs := make([]interface{}, colCount)
  53. for j := 0; j < colCount; j++ {
  54. pos := i*colCount + j
  55. colPtrs[j] = &values[pos]
  56. valueStringArgs[j] = pos + 1
  57. }
  58. valueStrings[i] = fmt.Sprintf(valueStringFormat, valueStringArgs...)
  59. err = rows.Scan(colPtrs...)
  60. if err != nil {
  61. panic(err)
  62. }
  63. }
  64. slicedValues := values
  65. slicedValueStrings := valueStrings
  66. if i < len(valueStrings) {
  67. slicedValueStrings = slicedValueStrings[:i]
  68. slicedValues = slicedValues[:i*colCount]
  69. }
  70. if len(slicedValues) == 0 {
  71. break
  72. }
  73. res, err := tx.Exec(fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES %s", table, columnNames, strings.Join(slicedValueStrings, ",")), slicedValues...)
  74. if err != nil {
  75. panic(err)
  76. }
  77. count, _ := res.RowsAffected()
  78. inserted += count
  79. rowCount -= batchSize
  80. fmt.Print("#")
  81. }
  82. err = tx.Commit()
  83. if err != nil {
  84. return err
  85. }
  86. fmt.Println(" -- done with", inserted, "rows inserted")
  87. return nil
  88. }
  89. func Migrate(old *Database, new *Database) {
  90. err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url", "encrypted", "first_event_id", "next_batch_id", "relay_user_id")
  91. if err != nil {
  92. panic(err)
  93. }
  94. err = migrateTable(old, new, "user", "mxid", "management_room", "username", "agent", "device")
  95. if err != nil {
  96. panic(err)
  97. }
  98. err = migrateTable(old, new, "puppet", "username", "avatar", "displayname", "name_quality", "custom_mxid", "access_token", "next_batch", "avatar_url", "enable_presence", "enable_receipts")
  99. if err != nil {
  100. panic(err)
  101. }
  102. err = migrateTable(old, new, "message", "chat_jid", "chat_receiver", "jid", "mxid", "sender", "timestamp", "sent", "decryption_error")
  103. if err != nil {
  104. panic(err)
  105. }
  106. err = migrateTable(old, new, "mx_registrations", "user_id")
  107. if err != nil {
  108. panic(err)
  109. }
  110. err = migrateTable(old, new, "mx_user_profile", "room_id", "user_id", "membership", "displayname", "avatar_url")
  111. if err != nil {
  112. panic(err)
  113. }
  114. err = migrateTable(old, new, "mx_room_state", "room_id", "power_levels")
  115. if err != nil {
  116. panic(err)
  117. }
  118. err = migrateTable(old, new, "crypto_account", "account_id", "device_id", "shared", "sync_token", "account")
  119. if err != nil {
  120. panic(err)
  121. }
  122. err = migrateTable(old, new, "crypto_message_index", "sender_key", "session_id", `"index"`, "event_id", "timestamp")
  123. if err != nil {
  124. panic(err)
  125. }
  126. err = migrateTable(old, new, "crypto_tracked_user", "user_id")
  127. if err != nil {
  128. panic(err)
  129. }
  130. err = migrateTable(old, new, "crypto_device", "user_id", "device_id", "identity_key", "signing_key", "trust", "deleted", "name")
  131. if err != nil {
  132. panic(err)
  133. }
  134. err = migrateTable(old, new, "crypto_olm_session", "account_id", "session_id", "sender_key", "session", "created_at", "last_used")
  135. if err != nil {
  136. panic(err)
  137. }
  138. err = migrateTable(old, new, "crypto_megolm_inbound_session", "account_id", "session_id", "sender_key", "signing_key", "room_id", "session", "forwarding_chains", "withheld_code", "withheld_reason")
  139. if err != nil {
  140. panic(err)
  141. }
  142. err = migrateTable(old, new, "crypto_megolm_outbound_session", "account_id", "room_id", "session_id", "session", "shared", "max_messages", "message_count", "max_age", "created_at", "last_used")
  143. if err != nil {
  144. panic(err)
  145. }
  146. err = migrateTable(old, new, "crypto_cross_signing_keys", "user_id", "usage", "key")
  147. if err != nil {
  148. panic(err)
  149. }
  150. err = migrateTable(old, new, "crypto_cross_signing_signatures", "signed_user_id", "signed_key", "signer_user_id", "signer_key", "signature")
  151. if err != nil {
  152. panic(err)
  153. }
  154. // Migrate whatsmeow tables.
  155. err = migrateTable(old, new, "whatsmeow_device", "jid", "registration_id", "noise_key", "identity_key", "signed_pre_key", "signed_pre_key_id", "signed_pre_key_sig", "adv_key", "adv_details", "adv_account_sig", "adv_device_sig", "platform", "business_name", "push_name")
  156. if err != nil {
  157. panic(err)
  158. }
  159. err = migrateTable(old, new, "whatsmeow_identity_keys", "our_jid", "their_id", "identity")
  160. if err != nil {
  161. panic(err)
  162. }
  163. err = migrateTable(old, new, "whatsmeow_pre_keys", "jid", "key_id", "key", "uploaded")
  164. if err != nil {
  165. panic(err)
  166. }
  167. err = migrateTable(old, new, "whatsmeow_sessions", "our_jid", "their_id", "session")
  168. if err != nil {
  169. panic(err)
  170. }
  171. err = migrateTable(old, new, "whatsmeow_sender_keys", "our_jid", "chat_id", "sender_id", "sender_key")
  172. if err != nil {
  173. panic(err)
  174. }
  175. err = migrateTable(old, new, "whatsmeow_app_state_sync_keys", "jid", "key_id", "key_data", "timestamp", "fingerprint")
  176. if err != nil {
  177. panic(err)
  178. }
  179. err = migrateTable(old, new, "whatsmeow_app_state_version", "jid", "name", "version", "hash")
  180. if err != nil {
  181. panic(err)
  182. }
  183. err = migrateTable(old, new, "whatsmeow_app_state_mutation_macs", "jid", "name", "version", "index_mac", "value_mac")
  184. if err != nil {
  185. panic(err)
  186. }
  187. err = migrateTable(old, new, "whatsmeow_contacts", "our_jid", "their_jid", "first_name", "full_name", "push_name", "business_name")
  188. if err != nil {
  189. panic(err)
  190. }
  191. err = migrateTable(old, new, "whatsmeow_chat_settings", "our_jid", "chat_jid", "muted_until", "pinned", "archived")
  192. if err != nil {
  193. panic(err)
  194. }
  195. }