puppet.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package database
  2. import (
  3. "database/sql"
  4. log "maunium.net/go/maulogger/v2"
  5. "maunium.net/go/mautrix/id"
  6. "maunium.net/go/mautrix/util/dbutil"
  7. )
  8. const (
  9. puppetSelect = "SELECT id, name, name_set, avatar, avatar_url, avatar_set," +
  10. " contact_info_set, username, discriminator, is_bot, custom_mxid, access_token, next_batch" +
  11. " FROM puppet "
  12. )
  13. type PuppetQuery struct {
  14. db *Database
  15. log log.Logger
  16. }
  17. func (pq *PuppetQuery) New() *Puppet {
  18. return &Puppet{
  19. db: pq.db,
  20. log: pq.log,
  21. }
  22. }
  23. func (pq *PuppetQuery) Get(id string) *Puppet {
  24. return pq.get(puppetSelect+" WHERE id=$1", id)
  25. }
  26. func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
  27. return pq.get(puppetSelect+" WHERE custom_mxid=$1", mxid)
  28. }
  29. func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet {
  30. return pq.New().Scan(pq.db.QueryRow(query, args...))
  31. }
  32. func (pq *PuppetQuery) GetAll() []*Puppet {
  33. return pq.getAll(puppetSelect)
  34. }
  35. func (pq *PuppetQuery) GetAllWithCustomMXID() []*Puppet {
  36. return pq.getAll(puppetSelect + " WHERE custom_mxid<>''")
  37. }
  38. func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet {
  39. rows, err := pq.db.Query(query, args...)
  40. if err != nil || rows == nil {
  41. return nil
  42. }
  43. defer rows.Close()
  44. var puppets []*Puppet
  45. for rows.Next() {
  46. puppets = append(puppets, pq.New().Scan(rows))
  47. }
  48. return puppets
  49. }
  50. type Puppet struct {
  51. db *Database
  52. log log.Logger
  53. ID string
  54. Name string
  55. NameSet bool
  56. Avatar string
  57. AvatarURL id.ContentURI
  58. AvatarSet bool
  59. ContactInfoSet bool
  60. Username string
  61. Discriminator string
  62. IsBot bool
  63. CustomMXID id.UserID
  64. AccessToken string
  65. NextBatch string
  66. }
  67. func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
  68. var avatarURL string
  69. var customMXID, accessToken, nextBatch sql.NullString
  70. err := row.Scan(&p.ID, &p.Name, &p.NameSet, &p.Avatar, &avatarURL, &p.AvatarSet, &p.ContactInfoSet,
  71. &p.Username, &p.Discriminator, &p.IsBot, &customMXID, &accessToken, &nextBatch)
  72. if err != nil {
  73. if err != sql.ErrNoRows {
  74. p.log.Errorln("Database scan failed:", err)
  75. panic(err)
  76. }
  77. return nil
  78. }
  79. p.AvatarURL, _ = id.ParseContentURI(avatarURL)
  80. p.CustomMXID = id.UserID(customMXID.String)
  81. p.AccessToken = accessToken.String
  82. p.NextBatch = nextBatch.String
  83. return p
  84. }
  85. func (p *Puppet) Insert() {
  86. query := `
  87. INSERT INTO puppet (id, name, name_set, avatar, avatar_url, avatar_set, contact_info_set, username, discriminator, is_bot, custom_mxid, access_token, next_batch)
  88. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
  89. `
  90. _, err := p.db.Exec(query, p.ID, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, p.ContactInfoSet,
  91. p.Username, p.Discriminator, p.IsBot, strPtr(p.CustomMXID), strPtr(p.AccessToken), strPtr(p.NextBatch))
  92. if err != nil {
  93. p.log.Warnfln("Failed to insert %s: %v", p.ID, err)
  94. panic(err)
  95. }
  96. }
  97. func (p *Puppet) Update() {
  98. query := `
  99. UPDATE puppet SET name=$1, name_set=$2, avatar=$3, avatar_url=$4, avatar_set=$5, contact_info_set=$6,
  100. username=$7, discriminator=$8, is_bot=$9, custom_mxid=$10, access_token=$11, next_batch=$12
  101. WHERE id=$13
  102. `
  103. _, err := p.db.Exec(query, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, p.ContactInfoSet,
  104. p.Username, p.Discriminator, p.IsBot, strPtr(p.CustomMXID), strPtr(p.AccessToken), strPtr(p.NextBatch),
  105. p.ID)
  106. if err != nil {
  107. p.log.Warnfln("Failed to update %s: %v", p.ID, err)
  108. panic(err)
  109. }
  110. }