portalquery.go 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. package database
  2. import (
  3. log "maunium.net/go/maulogger/v2"
  4. "maunium.net/go/mautrix/id"
  5. )
  6. type PortalQuery struct {
  7. db *Database
  8. log log.Logger
  9. }
  10. func (pq *PortalQuery) New() *Portal {
  11. return &Portal{
  12. db: pq.db,
  13. log: pq.log,
  14. }
  15. }
  16. func (pq *PortalQuery) GetAll() []*Portal {
  17. return pq.getAll("SELECT * FROM portal")
  18. }
  19. func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
  20. return pq.get("SELECT * FROM portal WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver)
  21. }
  22. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  23. return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
  24. }
  25. func (pq *PortalQuery) GetAllByID(id string) []*Portal {
  26. return pq.getAll("SELECT * FROM portal WHERE receiver=$1", id)
  27. }
  28. func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
  29. rows, err := pq.db.Query(query, args...)
  30. if err != nil || rows == nil {
  31. return nil
  32. }
  33. defer rows.Close()
  34. portals := []*Portal{}
  35. for rows.Next() {
  36. portals = append(portals, pq.New().Scan(rows))
  37. }
  38. return portals
  39. }
  40. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  41. row := pq.db.QueryRow(query, args...)
  42. if row == nil {
  43. return nil
  44. }
  45. return pq.New().Scan(row)
  46. }