puppet.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package database
  import (
  	"database/sql"
  	"errors"

  	log "maunium.net/go/maulogger/v2"
  	"maunium.net/go/mautrix/id"
  	"maunium.net/go/mautrix/util/dbutil"
  )
  // puppetSelect is the shared SELECT prefix for every puppet lookup. Its column
  // order is the order Puppet.Scan reads them in; callers append their own WHERE
  // clause (hence the trailing space).
  const puppetSelect = "SELECT id, display_name, avatar, avatar_url, custom_mxid, access_token, next_batch FROM puppet "
  13. type PuppetQuery struct {
  14. db *Database
  15. log log.Logger
  16. }
  17. func (pq *PuppetQuery) New() *Puppet {
  18. return &Puppet{
  19. db: pq.db,
  20. log: pq.log,
  21. }
  22. }
  23. func (pq *PuppetQuery) Get(id string) *Puppet {
  24. return pq.get(puppetSelect+" WHERE id=$1", id)
  25. }
  26. func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
  27. return pq.get(puppetSelect+" WHERE custom_mxid=$1", mxid)
  28. }
  29. func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet {
  30. row := pq.db.QueryRow(query, args...)
  31. if row == nil {
  32. return nil
  33. }
  34. return pq.New().Scan(row)
  35. }
  36. func (pq *PuppetQuery) GetAll() []*Puppet {
  37. return pq.getAll(puppetSelect)
  38. }
  39. func (pq *PuppetQuery) GetAllWithCustomMXID() []*Puppet {
  40. return pq.getAll(puppetSelect + " WHERE custom_mxid<>''")
  41. }
  42. func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet {
  43. rows, err := pq.db.Query(query, args...)
  44. if err != nil || rows == nil {
  45. return nil
  46. }
  47. defer rows.Close()
  48. puppets := []*Puppet{}
  49. for rows.Next() {
  50. puppets = append(puppets, pq.New().Scan(rows))
  51. }
  52. return puppets
  53. }
  54. type Puppet struct {
  55. db *Database
  56. log log.Logger
  57. ID string
  58. DisplayName string
  59. Avatar string
  60. AvatarURL id.ContentURI
  61. CustomMXID id.UserID
  62. AccessToken string
  63. NextBatch string
  64. }
  // Scan fills p from a single result row and returns p.
  //
  // The row's columns must be in the order selected by puppetSelect: id,
  // display_name, avatar, avatar_url, custom_mxid, access_token, next_batch.
  // NULL columns become empty strings. On any scan error Scan returns nil; the
  // error is logged unless it is (or wraps) sql.ErrNoRows, which is the expected
  // outcome of a lookup that matched nothing.
  func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
  	var did, displayName, avatar, avatarURL sql.NullString
  	var customMXID, accessToken, nextBatch sql.NullString
  	err := row.Scan(&did, &displayName, &avatar, &avatarURL,
  		&customMXID, &accessToken, &nextBatch)
  	if err != nil {
  		// errors.Is also recognises an ErrNoRows wrapped by an intermediate layer.
  		if !errors.Is(err, sql.ErrNoRows) {
  			p.log.Errorln("Database scan failed:", err)
  		}
  		return nil
  	}
  	p.ID = did.String
  	p.DisplayName = displayName.String
  	p.Avatar = avatar.String
  	// The parse error is ignored on purpose so an unset or malformed
  	// avatar_url does not reject the whole row. Presumably this leaves
  	// AvatarURL at its zero value — confirm against id.ParseContentURI.
  	p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
  	p.CustomMXID = id.UserID(customMXID.String)
  	p.AccessToken = accessToken.String
  	p.NextBatch = nextBatch.String
  	return p
  }
  85. func (p *Puppet) Insert() {
  86. query := "INSERT INTO puppet" +
  87. " (id, display_name, avatar, avatar_url," +
  88. " custom_mxid, access_token, next_batch)" +
  89. " VALUES ($1, $2, $3, $4, $5, $6, $7)"
  90. _, err := p.db.Exec(query, p.ID, p.DisplayName, p.Avatar,
  91. p.AvatarURL.String(), p.CustomMXID, p.AccessToken,
  92. p.NextBatch)
  93. if err != nil {
  94. p.log.Warnfln("Failed to insert %s: %v", p.ID, err)
  95. }
  96. }
  97. func (p *Puppet) Update() {
  98. query := "UPDATE puppet" +
  99. " SET display_name=$1, avatar=$2, avatar_url=$3, " +
  100. " custom_mxid=$4, access_token=$5, next_batch=$6" +
  101. " WHERE id=$7"
  102. _, err := p.db.Exec(query, p.DisplayName, p.Avatar, p.AvatarURL.String(),
  103. p.CustomMXID, p.AccessToken, p.NextBatch,
  104. p.ID)
  105. if err != nil {
  106. p.log.Warnfln("Failed to update %s: %v", p.ID, err)
  107. }
  108. }