cryptostore.go 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package database
  2. import (
  3. "database/sql"
  4. log "maunium.net/go/maulogger/v2"
  5. "maunium.net/go/mautrix/crypto"
  6. "maunium.net/go/mautrix/id"
  7. )
  8. type SQLCryptoStore struct {
  9. *crypto.SQLCryptoStore
  10. UserID id.UserID
  11. GhostIDFormat string
  12. }
  13. var _ crypto.Store = (*SQLCryptoStore)(nil)
  14. func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
  15. return &SQLCryptoStore{
  16. SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "",
  17. []byte("maunium.net/go/mautrix-whatsapp"),
  18. &cryptoLogger{db.log.Sub("CryptoStore")}),
  19. UserID: userID,
  20. GhostIDFormat: ghostIDFormat,
  21. }
  22. }
  23. func (store *SQLCryptoStore) FindDeviceID() id.DeviceID {
  24. var deviceID id.DeviceID
  25. query := `SELECT device_id FROM crypto_account WHERE account_id=$1`
  26. err := store.DB.QueryRow(query, store.AccountID).Scan(&deviceID)
  27. if err != nil && err != sql.ErrNoRows {
  28. store.Log.Warn("Failed to scan device ID: %v", err)
  29. }
  30. return deviceID
  31. }
  32. func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) ([]id.UserID, error) {
  33. query := `
  34. SELECT user_id FROM mx_user_profile
  35. WHERE room_id=$1
  36. AND (membership='join' OR membership='invite')
  37. AND user_id<>$2
  38. AND user_id NOT LIKE $3
  39. `
  40. members := []id.UserID{}
  41. rows, err := store.DB.Query(query, roomID, store.UserID, store.GhostIDFormat)
  42. if err != nil {
  43. return members, err
  44. }
  45. for rows.Next() {
  46. var userID id.UserID
  47. err := rows.Scan(&userID)
  48. if err != nil {
  49. store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
  50. return members, err
  51. }
  52. members = append(members, userID)
  53. }
  54. return members, nil
  55. }
  56. // TODO merge this with the one in the parent package
  57. type cryptoLogger struct {
  58. int log.Logger
  59. }
  60. var levelTrace = log.Level{
  61. Name: "TRACE",
  62. Severity: -10,
  63. Color: -1,
  64. }
  65. func (c *cryptoLogger) Error(message string, args ...interface{}) {
  66. c.int.Errorfln(message, args...)
  67. }
  68. func (c *cryptoLogger) Warn(message string, args ...interface{}) {
  69. c.int.Warnfln(message, args...)
  70. }
  71. func (c *cryptoLogger) Debug(message string, args ...interface{}) {
  72. c.int.Debugfln(message, args...)
  73. }
  74. func (c *cryptoLogger) Trace(message string, args ...interface{}) {
  75. c.int.Logfln(levelTrace, message, args...)
  76. }