thread.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. package main
  2. import (
  3. "context"
  4. "sync"
  5. "time"
  6. "github.com/bwmarrin/discordgo"
  7. "github.com/rs/zerolog"
  8. "golang.org/x/exp/slices"
  9. "maunium.net/go/mautrix/id"
  10. "go.mau.fi/mautrix-discord/database"
  11. )
  12. type Thread struct {
  13. *database.Thread
  14. Parent *Portal
  15. creationNoticeLock sync.Mutex
  16. initialBackfillAttempted bool
  17. }
  18. func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread {
  19. br.threadsLock.Lock()
  20. defer br.threadsLock.Unlock()
  21. thread, ok := br.threadsByID[id]
  22. if !ok {
  23. return br.loadThread(br.DB.Thread.GetByDiscordID(id), id, root)
  24. }
  25. return thread
  26. }
  27. func (br *DiscordBridge) GetThreadByRootMXID(mxid id.EventID) *Thread {
  28. br.threadsLock.Lock()
  29. defer br.threadsLock.Unlock()
  30. thread, ok := br.threadsByRootMXID[mxid]
  31. if !ok {
  32. return br.loadThread(br.DB.Thread.GetByMatrixRootMsg(mxid), "", nil)
  33. }
  34. return thread
  35. }
  36. func (br *DiscordBridge) GetThreadByRootOrCreationNoticeMXID(mxid id.EventID) *Thread {
  37. br.threadsLock.Lock()
  38. defer br.threadsLock.Unlock()
  39. thread, ok := br.threadsByRootMXID[mxid]
  40. if !ok {
  41. thread, ok = br.threadsByCreationNoticeMXID[mxid]
  42. if !ok {
  43. return br.loadThread(br.DB.Thread.GetByMatrixRootOrCreationNoticeMsg(mxid), "", nil)
  44. }
  45. }
  46. return thread
  47. }
  48. func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread {
  49. if dbThread == nil {
  50. if root == nil {
  51. return nil
  52. }
  53. dbThread = br.DB.Thread.New()
  54. dbThread.ID = id
  55. dbThread.RootDiscordID = root.DiscordID
  56. dbThread.RootMXID = root.MXID
  57. dbThread.ParentID = root.Channel.ChannelID
  58. dbThread.Insert()
  59. }
  60. thread := &Thread{
  61. Thread: dbThread,
  62. }
  63. thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, ""))
  64. br.threadsByID[thread.ID] = thread
  65. br.threadsByRootMXID[thread.RootMXID] = thread
  66. if thread.CreationNoticeMXID != "" {
  67. br.threadsByCreationNoticeMXID[thread.CreationNoticeMXID] = thread
  68. }
  69. return thread
  70. }
  71. func (br *DiscordBridge) threadFound(ctx context.Context, source *User, rootMessage *database.Message, id string, metadata *discordgo.Channel) {
  72. thread := br.GetThreadByID(id, rootMessage)
  73. log := zerolog.Ctx(ctx)
  74. log.Debug().Msg("Marked message as thread root")
  75. if thread.CreationNoticeMXID == "" {
  76. thread.Parent.sendThreadCreationNotice(ctx, thread)
  77. }
  78. // TODO member_ids_preview is probably not guaranteed to contain the source user
  79. if source != nil && metadata != nil && slices.Contains(metadata.MemberIDsPreview, source.DiscordID) && !source.IsInPortal(thread.ID) {
  80. source.MarkInPortal(database.UserPortal{
  81. DiscordID: thread.ID,
  82. Type: database.UserPortalTypeThread,
  83. Timestamp: time.Now(),
  84. })
  85. if metadata.MessageCount > 0 {
  86. go thread.maybeInitialBackfill(source)
  87. } else {
  88. thread.initialBackfillAttempted = true
  89. }
  90. }
  91. }
  92. func (thread *Thread) maybeInitialBackfill(source *User) {
  93. if thread.initialBackfillAttempted || thread.Parent.bridge.Config.Bridge.Backfill.Limits.Initial.Thread == 0 {
  94. return
  95. }
  96. thread.Parent.forwardBackfillLock.Lock()
  97. if thread.Parent.bridge.DB.Message.GetLastInThread(thread.Parent.Key, thread.ID) != nil {
  98. thread.Parent.forwardBackfillLock.Unlock()
  99. return
  100. }
  101. thread.Parent.forwardBackfillInitial(source, thread)
  102. }
  103. func (thread *Thread) Join(user *User) {
  104. if user.IsInPortal(thread.ID) {
  105. return
  106. }
  107. log := user.log.With().Str("thread_id", thread.ID).Str("channel_id", thread.ParentID).Logger()
  108. log.Debug().Msg("Joining thread")
  109. var doBackfill, backfillStarted bool
  110. if !thread.initialBackfillAttempted && thread.Parent.bridge.Config.Bridge.Backfill.Limits.Initial.Thread > 0 {
  111. thread.Parent.forwardBackfillLock.Lock()
  112. lastMessage := thread.Parent.bridge.DB.Message.GetLastInThread(thread.Parent.Key, thread.ID)
  113. if lastMessage != nil {
  114. thread.Parent.forwardBackfillLock.Unlock()
  115. } else {
  116. doBackfill = true
  117. defer func() {
  118. if !backfillStarted {
  119. thread.Parent.forwardBackfillLock.Unlock()
  120. }
  121. }()
  122. }
  123. }
  124. var err error
  125. if user.Session.IsUser {
  126. err = user.Session.ThreadJoinWithLocation(thread.ID, discordgo.ThreadJoinLocationContextMenu)
  127. } else {
  128. err = user.Session.ThreadJoin(thread.ID)
  129. }
  130. if err != nil {
  131. log.Error().Err(err).Msg("Error joining thread")
  132. } else {
  133. user.MarkInPortal(database.UserPortal{
  134. DiscordID: thread.ID,
  135. Type: database.UserPortalTypeThread,
  136. Timestamp: time.Now(),
  137. })
  138. if doBackfill {
  139. go thread.Parent.forwardBackfillInitial(user, thread)
  140. backfillStarted = true
  141. }
  142. }
  143. }