puppetquery.go 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. package database
  2. import (
  3. log "maunium.net/go/maulogger/v2"
  4. "maunium.net/go/mautrix/id"
  5. )
  6. type PuppetQuery struct {
  7. db *Database
  8. log log.Logger
  9. }
  10. func (pq *PuppetQuery) New() *Puppet {
  11. return &Puppet{
  12. db: pq.db,
  13. log: pq.log,
  14. EnablePresence: true,
  15. }
  16. }
  17. func (pq *PuppetQuery) Get(id string) *Puppet {
  18. return pq.get(puppetSelect+" WHERE id=$1", id)
  19. }
  20. func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
  21. return pq.get(puppetSelect+" WHERE custom_mxid=$1", mxid)
  22. }
  23. func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet {
  24. row := pq.db.QueryRow(query, args...)
  25. if row == nil {
  26. return nil
  27. }
  28. return pq.New().Scan(row)
  29. }
  30. func (pq *PuppetQuery) GetAll() []*Puppet {
  31. return pq.getAll(puppetSelect)
  32. }
  33. func (pq *PuppetQuery) GetAllWithCustomMXID() []*Puppet {
  34. return pq.getAll(puppetSelect + " WHERE custom_mxid<>''")
  35. }
  36. func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet {
  37. rows, err := pq.db.Query(query, args...)
  38. if err != nil || rows == nil {
  39. return nil
  40. }
  41. defer rows.Close()
  42. puppets := []*Puppet{}
  43. for rows.Next() {
  44. puppets = append(puppets, pq.New().Scan(rows))
  45. }
  46. return puppets
  47. }