portal.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. package database
import (
	"database/sql"
	"errors"

	"github.com/bwmarrin/discordgo"
	log "maunium.net/go/maulogger/v2"
	"maunium.net/go/mautrix/id"
	"maunium.net/go/mautrix/util/dbutil"
)
  9. // language=postgresql
  10. const (
  11. portalSelect = `
  12. SELECT dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid,
  13. plain_name, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
  14. encrypted, in_space, first_event_id, relay_webhook_id, relay_webhook_secret
  15. FROM portal
  16. `
  17. )
  18. type PortalKey struct {
  19. ChannelID string
  20. Receiver string
  21. }
  22. func NewPortalKey(channelID, receiver string) PortalKey {
  23. return PortalKey{
  24. ChannelID: channelID,
  25. Receiver: receiver,
  26. }
  27. }
  28. func (key PortalKey) String() string {
  29. if key.Receiver == "" {
  30. return key.ChannelID
  31. }
  32. return key.ChannelID + "-" + key.Receiver
  33. }
  34. type PortalQuery struct {
  35. db *Database
  36. log log.Logger
  37. }
  38. func (pq *PortalQuery) New() *Portal {
  39. return &Portal{
  40. db: pq.db,
  41. log: pq.log,
  42. }
  43. }
  44. func (pq *PortalQuery) GetAll() []*Portal {
  45. return pq.getAll(portalSelect)
  46. }
  47. func (pq *PortalQuery) GetAllInGuild(guildID string) []*Portal {
  48. return pq.getAll(portalSelect+" WHERE dc_guild_id=$1", guildID)
  49. }
  50. func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
  51. return pq.get(portalSelect+" WHERE dcid=$1 AND (receiver=$2 OR receiver='')", key.ChannelID, key.Receiver)
  52. }
  53. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  54. return pq.get(portalSelect+" WHERE mxid=$1", mxid)
  55. }
  56. func (pq *PortalQuery) FindPrivateChatsWith(id string) []*Portal {
  57. return pq.getAll(portalSelect+" WHERE other_user_id=$1 AND type=$2", id, discordgo.ChannelTypeDM)
  58. }
  59. func (pq *PortalQuery) FindPrivateChatsOf(receiver string) []*Portal {
  60. query := portalSelect + " portal WHERE receiver=$1 AND type=$2;"
  61. return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
  62. }
  63. func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
  64. rows, err := pq.db.Query(query, args...)
  65. if err != nil || rows == nil {
  66. return nil
  67. }
  68. defer rows.Close()
  69. var portals []*Portal
  70. for rows.Next() {
  71. portals = append(portals, pq.New().Scan(rows))
  72. }
  73. return portals
  74. }
  75. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  76. return pq.New().Scan(pq.db.QueryRow(query, args...))
  77. }
  78. type Portal struct {
  79. db *Database
  80. log log.Logger
  81. Key PortalKey
  82. Type discordgo.ChannelType
  83. OtherUserID string
  84. ParentID string
  85. GuildID string
  86. MXID id.RoomID
  87. PlainName string
  88. Name string
  89. NameSet bool
  90. Topic string
  91. TopicSet bool
  92. Avatar string
  93. AvatarURL id.ContentURI
  94. AvatarSet bool
  95. Encrypted bool
  96. InSpace id.RoomID
  97. FirstEventID id.EventID
  98. RelayWebhookID string
  99. RelayWebhookSecret string
  100. }
  101. func (p *Portal) Scan(row dbutil.Scannable) *Portal {
  102. var otherUserID, guildID, parentID, mxid, firstEventID, relayWebhookID, relayWebhookSecret sql.NullString
  103. var chanType int32
  104. var avatarURL string
  105. err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &chanType, &otherUserID, &guildID, &parentID,
  106. &mxid, &p.PlainName, &p.Name, &p.NameSet, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet,
  107. &p.Encrypted, &p.InSpace, &firstEventID, &relayWebhookID, &relayWebhookSecret)
  108. if err != nil {
  109. if err != sql.ErrNoRows {
  110. p.log.Errorln("Database scan failed:", err)
  111. panic(err)
  112. }
  113. return nil
  114. }
  115. p.MXID = id.RoomID(mxid.String)
  116. p.OtherUserID = otherUserID.String
  117. p.GuildID = guildID.String
  118. p.ParentID = parentID.String
  119. p.Type = discordgo.ChannelType(chanType)
  120. p.FirstEventID = id.EventID(firstEventID.String)
  121. p.AvatarURL, _ = id.ParseContentURI(avatarURL)
  122. p.RelayWebhookID = relayWebhookID.String
  123. p.RelayWebhookSecret = relayWebhookSecret.String
  124. return p
  125. }
  126. func (p *Portal) Insert() {
  127. query := `
  128. INSERT INTO portal (dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid,
  129. plain_name, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
  130. encrypted, in_space, first_event_id, relay_webhook_id, relay_webhook_secret)
  131. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20)
  132. `
  133. _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.Type,
  134. strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
  135. p.PlainName, p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
  136. p.Encrypted, p.InSpace, p.FirstEventID.String(), strPtr(p.RelayWebhookID), strPtr(p.RelayWebhookSecret))
  137. if err != nil {
  138. p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
  139. panic(err)
  140. }
  141. }
  142. func (p *Portal) Update() {
  143. query := `
  144. UPDATE portal
  145. SET type=$1, other_user_id=$2, dc_guild_id=$3, dc_parent_id=$4, mxid=$5,
  146. plain_name=$6, name=$7, name_set=$8, topic=$9, topic_set=$10, avatar=$11, avatar_url=$12, avatar_set=$13,
  147. encrypted=$14, in_space=$15, first_event_id=$16, relay_webhook_id=$17, relay_webhook_secret=$18
  148. WHERE dcid=$19 AND receiver=$20
  149. `
  150. _, err := p.db.Exec(query,
  151. p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
  152. p.PlainName, p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
  153. p.Encrypted, p.InSpace, p.FirstEventID.String(), strPtr(p.RelayWebhookID), strPtr(p.RelayWebhookSecret),
  154. p.Key.ChannelID, p.Key.Receiver)
  155. if err != nil {
  156. p.log.Warnfln("Failed to update %s: %v", p.Key, err)
  157. panic(err)
  158. }
  159. }
  160. func (p *Portal) Delete() {
  161. query := "DELETE FROM portal WHERE dcid=$1 AND receiver=$2"
  162. _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver)
  163. if err != nil {
  164. p.log.Warnfln("Failed to delete %s: %v", p.Key, err)
  165. panic(err)
  166. }
  167. }