portal.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. package database
  2. import (
  3. "database/sql"
  4. "github.com/bwmarrin/discordgo"
  5. log "maunium.net/go/maulogger/v2"
  6. "maunium.net/go/mautrix/id"
  7. "maunium.net/go/mautrix/util/dbutil"
  8. )
  // portalSelect is the shared SELECT prefix for every portal lookup. Its
  // column order must stay in sync with the targets passed to row.Scan in
  // Portal.Scan.
  const portalSelect = "SELECT " +
  	"dcid, receiver, mxid, name, topic, avatar, avatar_url, " +
  	"type, other_user_id, first_event_id, encrypted " +
  	"FROM portal"
  14. type PortalQuery struct {
  15. db *Database
  16. log log.Logger
  17. }
  18. func (pq *PortalQuery) New() *Portal {
  19. return &Portal{
  20. db: pq.db,
  21. log: pq.log,
  22. }
  23. }
  24. func (pq *PortalQuery) GetAll() []*Portal {
  25. return pq.getAll(portalSelect)
  26. }
  27. func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
  28. return pq.get(portalSelect+" WHERE dcid=$1 AND receiver=$2", key.ChannelID, key.Receiver)
  29. }
  30. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  31. return pq.get(portalSelect+" WHERE mxid=$1", mxid)
  32. }
  33. func (pq *PortalQuery) FindPrivateChatsWith(id string) []*Portal {
  34. return pq.getAll(portalSelect+" WHERE other_user_id=$1 AND type=$2", id, discordgo.ChannelTypeDM)
  35. }
  36. func (pq *PortalQuery) FindPrivateChatsOf(receiver string) []*Portal {
  37. query := portalSelect + " portal WHERE receiver=$1 AND type=$2;"
  38. return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
  39. }
  40. func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
  41. rows, err := pq.db.Query(query, args...)
  42. if err != nil || rows == nil {
  43. return nil
  44. }
  45. defer rows.Close()
  46. var portals []*Portal
  47. for rows.Next() {
  48. portals = append(portals, pq.New().Scan(rows))
  49. }
  50. return portals
  51. }
  52. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  53. row := pq.db.QueryRow(query, args...)
  54. if row == nil {
  55. return nil
  56. }
  57. return pq.New().Scan(row)
  58. }
  59. type Portal struct {
  60. db *Database
  61. log log.Logger
  62. Key PortalKey
  63. Type discordgo.ChannelType
  64. OtherUserID string
  65. MXID id.RoomID
  66. Name string
  67. Topic string
  68. Avatar string
  69. AvatarURL id.ContentURI
  70. Encrypted bool
  71. FirstEventID id.EventID
  72. }
  73. func (p *Portal) Scan(row dbutil.Scannable) *Portal {
  74. var mxid, avatarURL, firstEventID sql.NullString
  75. var typ sql.NullInt32
  76. err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name,
  77. &p.Topic, &p.Avatar, &avatarURL, &typ, &p.OtherUserID, &firstEventID,
  78. &p.Encrypted)
  79. if err != nil {
  80. if err != sql.ErrNoRows {
  81. p.log.Errorln("Database scan failed:", err)
  82. }
  83. return nil
  84. }
  85. p.MXID = id.RoomID(mxid.String)
  86. p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
  87. p.Type = discordgo.ChannelType(typ.Int32)
  88. p.FirstEventID = id.EventID(firstEventID.String)
  89. return p
  90. }
  91. func (p *Portal) mxidPtr() *id.RoomID {
  92. if p.MXID != "" {
  93. return &p.MXID
  94. }
  95. return nil
  96. }
  97. func (p *Portal) Insert() {
  98. query := "INSERT INTO portal" +
  99. " (dcid, receiver, mxid, name, topic, avatar, avatar_url," +
  100. " type, other_user_id, first_event_id, encrypted)" +
  101. " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
  102. _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(),
  103. p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.OtherUserID,
  104. p.FirstEventID.String(), p.Encrypted)
  105. if err != nil {
  106. p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
  107. }
  108. }
  109. func (p *Portal) Update() {
  110. query := "UPDATE portal SET" +
  111. " mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," +
  112. " other_user_id=$7, first_event_id=$8, encrypted=$9" +
  113. " WHERE dcid=$10 AND receiver=$11"
  114. _, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar,
  115. p.AvatarURL.String(), p.Type, p.OtherUserID, p.FirstEventID.String(),
  116. p.Encrypted,
  117. p.Key.ChannelID, p.Key.Receiver)
  118. if err != nil {
  119. p.log.Warnfln("Failed to update %s: %v", p.Key, err)
  120. }
  121. }
  122. func (p *Portal) Delete() {
  123. query := "DELETE FROM portal WHERE dcid=$1 AND receiver=$2"
  124. _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver)
  125. if err != nil {
  126. p.log.Warnfln("Failed to delete %s: %v", p.Key, err)
  127. }
  128. }