portalquery.go 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. package database
  2. import (
  3. "github.com/bwmarrin/discordgo"
  4. log "maunium.net/go/maulogger/v2"
  5. "maunium.net/go/mautrix/id"
  6. )
  7. const (
  8. portalSelect = "SELECT channel_id, receiver, mxid, name, topic, avatar," +
  9. " avatar_url, type, dmuser, first_event_id, encrypted" +
  10. " FROM portal"
  11. )
  12. type PortalQuery struct {
  13. db *Database
  14. log log.Logger
  15. }
  16. func (pq *PortalQuery) New() *Portal {
  17. return &Portal{
  18. db: pq.db,
  19. log: pq.log,
  20. }
  21. }
  22. func (pq *PortalQuery) GetAll() []*Portal {
  23. return pq.getAll(portalSelect)
  24. }
  25. func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
  26. return pq.get(portalSelect+" WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver)
  27. }
  28. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  29. return pq.get(portalSelect+" WHERE mxid=$1", mxid)
  30. }
  31. func (pq *PortalQuery) GetAllByID(id string) []*Portal {
  32. return pq.getAll(portalSelect+" WHERE receiver=$1", id)
  33. }
  34. func (pq *PortalQuery) FindPrivateChats(receiver string) []*Portal {
  35. query := portalSelect + " portal WHERE receiver=$1 AND type=$2;"
  36. return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
  37. }
  38. func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
  39. rows, err := pq.db.Query(query, args...)
  40. if err != nil || rows == nil {
  41. return nil
  42. }
  43. defer rows.Close()
  44. portals := []*Portal{}
  45. for rows.Next() {
  46. portals = append(portals, pq.New().Scan(rows))
  47. }
  48. return portals
  49. }
  50. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  51. row := pq.db.QueryRow(query, args...)
  52. if row == nil {
  53. return nil
  54. }
  55. return pq.New().Scan(row)
  56. }