portal.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. package database
import (
	"database/sql"
	"errors"

	"github.com/bwmarrin/discordgo"
	log "maunium.net/go/maulogger/v2"
	"maunium.net/go/mautrix/id"
	"maunium.net/go/mautrix/util/dbutil"
)
  9. const (
  10. portalSelect = "SELECT dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, " +
  11. " mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, in_space, first_event_id" +
  12. " FROM portal"
  13. )
  14. type PortalKey struct {
  15. ChannelID string
  16. Receiver string
  17. }
  18. func NewPortalKey(channelID, receiver string) PortalKey {
  19. return PortalKey{
  20. ChannelID: channelID,
  21. Receiver: receiver,
  22. }
  23. }
  24. func (key PortalKey) String() string {
  25. if key.Receiver == "" {
  26. return key.ChannelID
  27. }
  28. return key.ChannelID + "-" + key.Receiver
  29. }
  30. type PortalQuery struct {
  31. db *Database
  32. log log.Logger
  33. }
  34. func (pq *PortalQuery) New() *Portal {
  35. return &Portal{
  36. db: pq.db,
  37. log: pq.log,
  38. }
  39. }
  40. func (pq *PortalQuery) GetAll() []*Portal {
  41. return pq.getAll(portalSelect)
  42. }
  43. func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
  44. return pq.get(portalSelect+" WHERE dcid=$1 AND (receiver=$2 OR receiver='')", key.ChannelID, key.Receiver)
  45. }
  46. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  47. return pq.get(portalSelect+" WHERE mxid=$1", mxid)
  48. }
  49. func (pq *PortalQuery) FindPrivateChatsWith(id string) []*Portal {
  50. return pq.getAll(portalSelect+" WHERE other_user_id=$1 AND type=$2", id, discordgo.ChannelTypeDM)
  51. }
  52. func (pq *PortalQuery) FindPrivateChatsOf(receiver string) []*Portal {
  53. query := portalSelect + " portal WHERE receiver=$1 AND type=$2;"
  54. return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
  55. }
  56. func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
  57. rows, err := pq.db.Query(query, args...)
  58. if err != nil || rows == nil {
  59. return nil
  60. }
  61. defer rows.Close()
  62. var portals []*Portal
  63. for rows.Next() {
  64. portals = append(portals, pq.New().Scan(rows))
  65. }
  66. return portals
  67. }
  68. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  69. return pq.New().Scan(pq.db.QueryRow(query, args...))
  70. }
  71. type Portal struct {
  72. db *Database
  73. log log.Logger
  74. Key PortalKey
  75. Type discordgo.ChannelType
  76. OtherUserID string
  77. ParentID string
  78. GuildID string
  79. MXID id.RoomID
  80. Name string
  81. NameSet bool
  82. Topic string
  83. TopicSet bool
  84. Avatar string
  85. AvatarURL id.ContentURI
  86. AvatarSet bool
  87. Encrypted bool
  88. InSpace id.RoomID
  89. FirstEventID id.EventID
  90. }
  91. func (p *Portal) Scan(row dbutil.Scannable) *Portal {
  92. var otherUserID, guildID, parentID, mxid, firstEventID sql.NullString
  93. var chanType int32
  94. var avatarURL string
  95. err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &chanType, &otherUserID, &guildID, &parentID,
  96. &mxid, &p.Name, &p.NameSet, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet,
  97. &p.Encrypted, &p.InSpace, &firstEventID)
  98. if err != nil {
  99. if err != sql.ErrNoRows {
  100. p.log.Errorln("Database scan failed:", err)
  101. panic(err)
  102. }
  103. return nil
  104. }
  105. p.MXID = id.RoomID(mxid.String)
  106. p.OtherUserID = otherUserID.String
  107. p.GuildID = guildID.String
  108. p.ParentID = parentID.String
  109. p.Type = discordgo.ChannelType(chanType)
  110. p.FirstEventID = id.EventID(firstEventID.String)
  111. p.AvatarURL, _ = id.ParseContentURI(avatarURL)
  112. return p
  113. }
  114. func (p *Portal) Insert() {
  115. query := `
  116. INSERT INTO portal (dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid,
  117. name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
  118. encrypted, in_space, first_event_id)
  119. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
  120. `
  121. _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.Type,
  122. strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
  123. p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
  124. p.Encrypted, p.InSpace, p.FirstEventID.String())
  125. if err != nil {
  126. p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
  127. panic(err)
  128. }
  129. }
  130. func (p *Portal) Update() {
  131. query := `
  132. UPDATE portal SET type=$1, other_user_id=$2, dc_guild_id=$3, dc_parent_id=$4, mxid=$5,
  133. name=$6, name_set=$7, topic=$8, topic_set=$9, avatar=$10, avatar_url=$11, avatar_set=$12,
  134. encrypted=$13, in_space=$14, first_event_id=$15
  135. WHERE dcid=$16 AND receiver=$17
  136. `
  137. _, err := p.db.Exec(query,
  138. p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
  139. p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
  140. p.Encrypted, p.InSpace, p.FirstEventID.String(),
  141. p.Key.ChannelID, p.Key.Receiver)
  142. if err != nil {
  143. p.log.Warnfln("Failed to update %s: %v", p.Key, err)
  144. panic(err)
  145. }
  146. }
  147. func (p *Portal) Delete() {
  148. query := "DELETE FROM portal WHERE dcid=$1 AND receiver=$2"
  149. _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver)
  150. if err != nil {
  151. p.log.Warnfln("Failed to delete %s: %v", p.Key, err)
  152. panic(err)
  153. }
  154. }