thread.go 2.3 KB

Lines 1–92
  1. package main
  2. import (
  3. "sync"
  4. "time"
  5. "github.com/bwmarrin/discordgo"
  6. "maunium.net/go/mautrix/id"
  7. "go.mau.fi/mautrix-discord/database"
  8. )
  9. type Thread struct {
  10. *database.Thread
  11. Parent *Portal
  12. creationNoticeLock sync.Mutex
  13. }
  14. func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread {
  15. br.threadsLock.Lock()
  16. defer br.threadsLock.Unlock()
  17. thread, ok := br.threadsByID[id]
  18. if !ok {
  19. return br.loadThread(br.DB.Thread.GetByDiscordID(id), id, root)
  20. }
  21. return thread
  22. }
  23. func (br *DiscordBridge) GetThreadByRootMXID(mxid id.EventID) *Thread {
  24. br.threadsLock.Lock()
  25. defer br.threadsLock.Unlock()
  26. thread, ok := br.threadsByRootMXID[mxid]
  27. if !ok {
  28. return br.loadThread(br.DB.Thread.GetByMatrixRootMsg(mxid), "", nil)
  29. }
  30. return thread
  31. }
  32. func (br *DiscordBridge) GetThreadByRootOrCreationNoticeMXID(mxid id.EventID) *Thread {
  33. br.threadsLock.Lock()
  34. defer br.threadsLock.Unlock()
  35. thread, ok := br.threadsByRootMXID[mxid]
  36. if !ok {
  37. thread, ok = br.threadsByCreationNoticeMXID[mxid]
  38. if !ok {
  39. return br.loadThread(br.DB.Thread.GetByMatrixRootOrCreationNoticeMsg(mxid), "", nil)
  40. }
  41. }
  42. return thread
  43. }
  44. func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread {
  45. if dbThread == nil {
  46. if root == nil {
  47. return nil
  48. }
  49. dbThread = br.DB.Thread.New()
  50. dbThread.ID = id
  51. dbThread.RootDiscordID = root.DiscordID
  52. dbThread.RootMXID = root.MXID
  53. dbThread.ParentID = root.Channel.ChannelID
  54. dbThread.Insert()
  55. }
  56. thread := &Thread{
  57. Thread: dbThread,
  58. }
  59. thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, ""))
  60. br.threadsByID[thread.ID] = thread
  61. br.threadsByRootMXID[thread.RootMXID] = thread
  62. if thread.CreationNoticeMXID != "" {
  63. br.threadsByCreationNoticeMXID[thread.CreationNoticeMXID] = thread
  64. }
  65. return thread
  66. }
  67. func (thread *Thread) Join(user *User) {
  68. if user.IsInPortal(thread.ID) {
  69. return
  70. }
  71. user.log.Debugfln("Joining thread %s@%s", thread.ID, thread.ParentID)
  72. err := user.Session.ThreadJoinWithLocation(thread.ID, discordgo.ThreadJoinLocationContextMenu)
  73. if err != nil {
  74. user.log.Errorfln("Error joining thread %s@%s: %v", thread.ID, thread.ParentID, err)
  75. } else {
  76. user.MarkInPortal(database.UserPortal{
  77. DiscordID: thread.ID,
  78. Type: database.UserPortalTypeThread,
  79. Timestamp: time.Now(),
  80. })
  81. }
  82. }