portal.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package database
  2. import (
  3. "database/sql"
  4. "github.com/bwmarrin/discordgo"
  5. log "maunium.net/go/maulogger/v2"
  6. "maunium.net/go/mautrix/id"
  7. "maunium.net/go/mautrix/util/dbutil"
  8. )
  9. // language=postgresql
  10. const (
  11. portalSelect = `
  12. SELECT dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid,
  13. plain_name, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
  14. encrypted, in_space, first_event_id
  15. FROM portal
  16. `
  17. )
  18. type PortalKey struct {
  19. ChannelID string
  20. Receiver string
  21. }
  22. func NewPortalKey(channelID, receiver string) PortalKey {
  23. return PortalKey{
  24. ChannelID: channelID,
  25. Receiver: receiver,
  26. }
  27. }
  28. func (key PortalKey) String() string {
  29. if key.Receiver == "" {
  30. return key.ChannelID
  31. }
  32. return key.ChannelID + "-" + key.Receiver
  33. }
  34. type PortalQuery struct {
  35. db *Database
  36. log log.Logger
  37. }
  38. func (pq *PortalQuery) New() *Portal {
  39. return &Portal{
  40. db: pq.db,
  41. log: pq.log,
  42. }
  43. }
  44. func (pq *PortalQuery) GetAll() []*Portal {
  45. return pq.getAll(portalSelect)
  46. }
  47. func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
  48. return pq.get(portalSelect+" WHERE dcid=$1 AND (receiver=$2 OR receiver='')", key.ChannelID, key.Receiver)
  49. }
  50. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  51. return pq.get(portalSelect+" WHERE mxid=$1", mxid)
  52. }
  53. func (pq *PortalQuery) FindPrivateChatsWith(id string) []*Portal {
  54. return pq.getAll(portalSelect+" WHERE other_user_id=$1 AND type=$2", id, discordgo.ChannelTypeDM)
  55. }
  56. func (pq *PortalQuery) FindPrivateChatsOf(receiver string) []*Portal {
  57. query := portalSelect + " portal WHERE receiver=$1 AND type=$2;"
  58. return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
  59. }
  60. func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
  61. rows, err := pq.db.Query(query, args...)
  62. if err != nil || rows == nil {
  63. return nil
  64. }
  65. defer rows.Close()
  66. var portals []*Portal
  67. for rows.Next() {
  68. portals = append(portals, pq.New().Scan(rows))
  69. }
  70. return portals
  71. }
  72. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  73. return pq.New().Scan(pq.db.QueryRow(query, args...))
  74. }
  75. type Portal struct {
  76. db *Database
  77. log log.Logger
  78. Key PortalKey
  79. Type discordgo.ChannelType
  80. OtherUserID string
  81. ParentID string
  82. GuildID string
  83. MXID id.RoomID
  84. PlainName string
  85. Name string
  86. NameSet bool
  87. Topic string
  88. TopicSet bool
  89. Avatar string
  90. AvatarURL id.ContentURI
  91. AvatarSet bool
  92. Encrypted bool
  93. InSpace id.RoomID
  94. FirstEventID id.EventID
  95. }
  96. func (p *Portal) Scan(row dbutil.Scannable) *Portal {
  97. var otherUserID, guildID, parentID, mxid, firstEventID sql.NullString
  98. var chanType int32
  99. var avatarURL string
  100. err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &chanType, &otherUserID, &guildID, &parentID,
  101. &mxid, &p.PlainName, &p.Name, &p.NameSet, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet,
  102. &p.Encrypted, &p.InSpace, &firstEventID)
  103. if err != nil {
  104. if err != sql.ErrNoRows {
  105. p.log.Errorln("Database scan failed:", err)
  106. panic(err)
  107. }
  108. return nil
  109. }
  110. p.MXID = id.RoomID(mxid.String)
  111. p.OtherUserID = otherUserID.String
  112. p.GuildID = guildID.String
  113. p.ParentID = parentID.String
  114. p.Type = discordgo.ChannelType(chanType)
  115. p.FirstEventID = id.EventID(firstEventID.String)
  116. p.AvatarURL, _ = id.ParseContentURI(avatarURL)
  117. return p
  118. }
  119. func (p *Portal) Insert() {
  120. query := `
  121. INSERT INTO portal (dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid,
  122. plain_name, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
  123. encrypted, in_space, first_event_id)
  124. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18)
  125. `
  126. _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.Type,
  127. strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
  128. p.PlainName, p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
  129. p.Encrypted, p.InSpace, p.FirstEventID.String())
  130. if err != nil {
  131. p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
  132. panic(err)
  133. }
  134. }
  135. func (p *Portal) Update() {
  136. query := `
  137. UPDATE portal
  138. SET type=$1, other_user_id=$2, dc_guild_id=$3, dc_parent_id=$4, mxid=$5,
  139. plain_name=$6, name=$7, name_set=$8, topic=$9, topic_set=$10, avatar=$11, avatar_url=$12, avatar_set=$13,
  140. encrypted=$14, in_space=$15, first_event_id=$16
  141. WHERE dcid=$17 AND receiver=$18
  142. `
  143. _, err := p.db.Exec(query,
  144. p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
  145. p.PlainName, p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
  146. p.Encrypted, p.InSpace, p.FirstEventID.String(),
  147. p.Key.ChannelID, p.Key.Receiver)
  148. if err != nil {
  149. p.log.Warnfln("Failed to update %s: %v", p.Key, err)
  150. panic(err)
  151. }
  152. }
  153. func (p *Portal) Delete() {
  154. query := "DELETE FROM portal WHERE dcid=$1 AND receiver=$2"
  155. _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver)
  156. if err != nil {
  157. p.log.Warnfln("Failed to delete %s: %v", p.Key, err)
  158. panic(err)
  159. }
  160. }