puppet.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. package database
import (
	"database/sql"
	"errors"

	log "maunium.net/go/maulogger/v2"
	"maunium.net/go/mautrix/id"
	"maunium.net/go/mautrix/util/dbutil"
)
  8. const (
  9. puppetSelect = "SELECT id, name, name_set, avatar, avatar_url, avatar_set," +
  10. " custom_mxid, access_token, next_batch" +
  11. " FROM puppet "
  12. )
  13. type PuppetQuery struct {
  14. db *Database
  15. log log.Logger
  16. }
  17. func (pq *PuppetQuery) New() *Puppet {
  18. return &Puppet{
  19. db: pq.db,
  20. log: pq.log,
  21. }
  22. }
  23. func (pq *PuppetQuery) Get(id string) *Puppet {
  24. return pq.get(puppetSelect+" WHERE id=$1", id)
  25. }
  26. func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
  27. return pq.get(puppetSelect+" WHERE custom_mxid=$1", mxid)
  28. }
  29. func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet {
  30. return pq.New().Scan(pq.db.QueryRow(query, args...))
  31. }
  32. func (pq *PuppetQuery) GetAll() []*Puppet {
  33. return pq.getAll(puppetSelect)
  34. }
  35. func (pq *PuppetQuery) GetAllWithCustomMXID() []*Puppet {
  36. return pq.getAll(puppetSelect + " WHERE custom_mxid<>''")
  37. }
  38. func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet {
  39. rows, err := pq.db.Query(query, args...)
  40. if err != nil || rows == nil {
  41. return nil
  42. }
  43. defer rows.Close()
  44. var puppets []*Puppet
  45. for rows.Next() {
  46. puppets = append(puppets, pq.New().Scan(rows))
  47. }
  48. return puppets
  49. }
  50. type Puppet struct {
  51. db *Database
  52. log log.Logger
  53. ID string
  54. Name string
  55. NameSet bool
  56. Avatar string
  57. AvatarURL id.ContentURI
  58. AvatarSet bool
  59. CustomMXID id.UserID
  60. AccessToken string
  61. NextBatch string
  62. }
  63. func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
  64. var avatarURL string
  65. var customMXID, accessToken, nextBatch sql.NullString
  66. err := row.Scan(&p.ID, &p.Name, &p.NameSet, &p.Avatar, &avatarURL, &p.AvatarSet,
  67. &customMXID, &accessToken, &nextBatch)
  68. if err != nil {
  69. if err != sql.ErrNoRows {
  70. p.log.Errorln("Database scan failed:", err)
  71. panic(err)
  72. }
  73. return nil
  74. }
  75. p.AvatarURL, _ = id.ParseContentURI(avatarURL)
  76. p.CustomMXID = id.UserID(customMXID.String)
  77. p.AccessToken = accessToken.String
  78. p.NextBatch = nextBatch.String
  79. return p
  80. }
  81. func (p *Puppet) Insert() {
  82. query := `
  83. INSERT INTO puppet (id, name, name_set, avatar, avatar_url, avatar_set, custom_mxid, access_token, next_batch)
  84. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
  85. `
  86. _, err := p.db.Exec(query, p.ID, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
  87. strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch))
  88. if err != nil {
  89. p.log.Warnfln("Failed to insert %s: %v", p.ID, err)
  90. panic(err)
  91. }
  92. }
  93. func (p *Puppet) Update() {
  94. query := `
  95. UPDATE puppet SET name=$1, name_set=$2, avatar=$3, avatar_url=$4, avatar_set=$5,
  96. custom_mxid=$6, access_token=$7, next_batch=$8
  97. WHERE id=$9
  98. `
  99. _, err := p.db.Exec(query, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
  100. strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch),
  101. p.ID)
  102. if err != nil {
  103. p.log.Warnfln("Failed to update %s: %v", p.ID, err)
  104. panic(err)
  105. }
  106. }