portalquery.go 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. package database
  2. import (
  3. "github.com/bwmarrin/discordgo"
  4. log "maunium.net/go/maulogger/v2"
  5. "maunium.net/go/mautrix/id"
  6. )
  7. type PortalQuery struct {
  8. db *Database
  9. log log.Logger
  10. }
  11. func (pq *PortalQuery) New() *Portal {
  12. return &Portal{
  13. db: pq.db,
  14. log: pq.log,
  15. }
  16. }
  17. func (pq *PortalQuery) GetAll() []*Portal {
  18. return pq.getAll("SELECT * FROM portal")
  19. }
  20. func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
  21. return pq.get("SELECT * FROM portal WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver)
  22. }
  23. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  24. return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
  25. }
  26. func (pq *PortalQuery) GetAllByID(id string) []*Portal {
  27. return pq.getAll("SELECT * FROM portal WHERE receiver=$1", id)
  28. }
  29. func (pq *PortalQuery) FindPrivateChats(receiver string) []*Portal {
  30. query := "SELECT * FROM portal WHERE receiver=$1 AND type=$2;"
  31. return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
  32. }
  33. func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
  34. rows, err := pq.db.Query(query, args...)
  35. if err != nil || rows == nil {
  36. return nil
  37. }
  38. defer rows.Close()
  39. portals := []*Portal{}
  40. for rows.Next() {
  41. portals = append(portals, pq.New().Scan(rows))
  42. }
  43. return portals
  44. }
  45. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  46. row := pq.db.QueryRow(query, args...)
  47. if row == nil {
  48. return nil
  49. }
  50. return pq.New().Scan(row)
  51. }