portal.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. package database
import (
	"database/sql"
	"errors"

	"github.com/bwmarrin/discordgo"
	"go.mau.fi/util/dbutil"
	log "maunium.net/go/maulogger/v2"
	"maunium.net/go/mautrix/id"
)
  9. // language=postgresql
  10. const (
  11. portalSelect = `
  12. SELECT dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid,
  13. plain_name, name, name_set, friend_nick, topic, topic_set, avatar, avatar_url, avatar_set,
  14. encrypted, in_space, first_event_id, relay_webhook_id, relay_webhook_secret
  15. FROM portal
  16. `
  17. )
  18. type PortalKey struct {
  19. ChannelID string
  20. Receiver string
  21. }
  22. func NewPortalKey(channelID, receiver string) PortalKey {
  23. return PortalKey{
  24. ChannelID: channelID,
  25. Receiver: receiver,
  26. }
  27. }
  28. func (key PortalKey) String() string {
  29. if key.Receiver == "" {
  30. return key.ChannelID
  31. }
  32. return key.ChannelID + "-" + key.Receiver
  33. }
  34. type PortalQuery struct {
  35. db *Database
  36. log log.Logger
  37. }
  38. func (pq *PortalQuery) New() *Portal {
  39. return &Portal{
  40. db: pq.db,
  41. log: pq.log,
  42. }
  43. }
  44. func (pq *PortalQuery) GetAll() []*Portal {
  45. return pq.getAll(portalSelect)
  46. }
  47. func (pq *PortalQuery) GetAllInGuild(guildID string) []*Portal {
  48. return pq.getAll(portalSelect+" WHERE dc_guild_id=$1", guildID)
  49. }
  50. func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
  51. return pq.get(portalSelect+" WHERE dcid=$1 AND (receiver=$2 OR receiver='')", key.ChannelID, key.Receiver)
  52. }
  53. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  54. return pq.get(portalSelect+" WHERE mxid=$1", mxid)
  55. }
  56. func (pq *PortalQuery) FindPrivateChatBetween(id, receiver string) *Portal {
  57. return pq.get(portalSelect+" WHERE other_user_id=$1 AND receiver=$2 AND type=$3", id, receiver, discordgo.ChannelTypeDM)
  58. }
  59. func (pq *PortalQuery) FindPrivateChatsWith(id string) []*Portal {
  60. return pq.getAll(portalSelect+" WHERE other_user_id=$1 AND type=$2", id, discordgo.ChannelTypeDM)
  61. }
  62. func (pq *PortalQuery) FindPrivateChatsOf(receiver string) []*Portal {
  63. query := portalSelect + " portal WHERE receiver=$1 AND type=$2;"
  64. return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
  65. }
  66. func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
  67. rows, err := pq.db.Query(query, args...)
  68. if err != nil || rows == nil {
  69. return nil
  70. }
  71. defer rows.Close()
  72. var portals []*Portal
  73. for rows.Next() {
  74. portals = append(portals, pq.New().Scan(rows))
  75. }
  76. return portals
  77. }
  78. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  79. return pq.New().Scan(pq.db.QueryRow(query, args...))
  80. }
  81. type Portal struct {
  82. db *Database
  83. log log.Logger
  84. Key PortalKey
  85. Type discordgo.ChannelType
  86. OtherUserID string
  87. ParentID string
  88. GuildID string
  89. MXID id.RoomID
  90. PlainName string
  91. Name string
  92. NameSet bool
  93. FriendNick bool
  94. Topic string
  95. TopicSet bool
  96. Avatar string
  97. AvatarURL id.ContentURI
  98. AvatarSet bool
  99. Encrypted bool
  100. InSpace id.RoomID
  101. FirstEventID id.EventID
  102. RelayWebhookID string
  103. RelayWebhookSecret string
  104. }
  105. func (p *Portal) Scan(row dbutil.Scannable) *Portal {
  106. var otherUserID, guildID, parentID, mxid, firstEventID, relayWebhookID, relayWebhookSecret sql.NullString
  107. var chanType int32
  108. var avatarURL string
  109. err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &chanType, &otherUserID, &guildID, &parentID,
  110. &mxid, &p.PlainName, &p.Name, &p.NameSet, &p.FriendNick, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet,
  111. &p.Encrypted, &p.InSpace, &firstEventID, &relayWebhookID, &relayWebhookSecret)
  112. if err != nil {
  113. if err != sql.ErrNoRows {
  114. p.log.Errorln("Database scan failed:", err)
  115. panic(err)
  116. }
  117. return nil
  118. }
  119. p.MXID = id.RoomID(mxid.String)
  120. p.OtherUserID = otherUserID.String
  121. p.GuildID = guildID.String
  122. p.ParentID = parentID.String
  123. p.Type = discordgo.ChannelType(chanType)
  124. p.FirstEventID = id.EventID(firstEventID.String)
  125. p.AvatarURL, _ = id.ParseContentURI(avatarURL)
  126. p.RelayWebhookID = relayWebhookID.String
  127. p.RelayWebhookSecret = relayWebhookSecret.String
  128. return p
  129. }
  130. func (p *Portal) Insert() {
  131. query := `
  132. INSERT INTO portal (dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid,
  133. plain_name, name, name_set, friend_nick, topic, topic_set, avatar, avatar_url, avatar_set,
  134. encrypted, in_space, first_event_id, relay_webhook_id, relay_webhook_secret)
  135. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
  136. `
  137. _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.Type,
  138. strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
  139. p.PlainName, p.Name, p.NameSet, p.FriendNick, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
  140. p.Encrypted, p.InSpace, p.FirstEventID.String(), strPtr(p.RelayWebhookID), strPtr(p.RelayWebhookSecret))
  141. if err != nil {
  142. p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
  143. panic(err)
  144. }
  145. }
  146. func (p *Portal) Update() {
  147. query := `
  148. UPDATE portal
  149. SET type=$1, other_user_id=$2, dc_guild_id=$3, dc_parent_id=$4, mxid=$5,
  150. plain_name=$6, name=$7, name_set=$8, friend_nick=$9, topic=$10, topic_set=$11,
  151. avatar=$12, avatar_url=$13, avatar_set=$14, encrypted=$15, in_space=$16, first_event_id=$17,
  152. relay_webhook_id=$18, relay_webhook_secret=$19
  153. WHERE dcid=$20 AND receiver=$21
  154. `
  155. _, err := p.db.Exec(query,
  156. p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
  157. p.PlainName, p.Name, p.NameSet, p.FriendNick, p.Topic, p.TopicSet,
  158. p.Avatar, p.AvatarURL.String(), p.AvatarSet, p.Encrypted, p.InSpace, p.FirstEventID.String(),
  159. strPtr(p.RelayWebhookID), strPtr(p.RelayWebhookSecret),
  160. p.Key.ChannelID, p.Key.Receiver)
  161. if err != nil {
  162. p.log.Warnfln("Failed to update %s: %v", p.Key, err)
  163. panic(err)
  164. }
  165. }
  166. func (p *Portal) Delete() {
  167. query := "DELETE FROM portal WHERE dcid=$1 AND receiver=$2"
  168. _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver)
  169. if err != nil {
  170. p.log.Warnfln("Failed to delete %s: %v", p.Key, err)
  171. panic(err)
  172. }
  173. }