1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- package main
- import (
- "sync"
- "time"
- "github.com/bwmarrin/discordgo"
- "maunium.net/go/mautrix/id"
- "go.mau.fi/mautrix-discord/database"
- )
- type Thread struct {
- *database.Thread
- Parent *Portal
- creationNoticeLock sync.Mutex
- }
- func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread {
- br.threadsLock.Lock()
- defer br.threadsLock.Unlock()
- thread, ok := br.threadsByID[id]
- if !ok {
- return br.loadThread(br.DB.Thread.GetByDiscordID(id), id, root)
- }
- return thread
- }
- func (br *DiscordBridge) GetThreadByRootMXID(mxid id.EventID) *Thread {
- br.threadsLock.Lock()
- defer br.threadsLock.Unlock()
- thread, ok := br.threadsByRootMXID[mxid]
- if !ok {
- return br.loadThread(br.DB.Thread.GetByMatrixRootMsg(mxid), "", nil)
- }
- return thread
- }
- func (br *DiscordBridge) GetThreadByRootOrCreationNoticeMXID(mxid id.EventID) *Thread {
- br.threadsLock.Lock()
- defer br.threadsLock.Unlock()
- thread, ok := br.threadsByRootMXID[mxid]
- if !ok {
- thread, ok = br.threadsByCreationNoticeMXID[mxid]
- if !ok {
- return br.loadThread(br.DB.Thread.GetByMatrixRootOrCreationNoticeMsg(mxid), "", nil)
- }
- }
- return thread
- }
- func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread {
- if dbThread == nil {
- if root == nil {
- return nil
- }
- dbThread = br.DB.Thread.New()
- dbThread.ID = id
- dbThread.RootDiscordID = root.DiscordID
- dbThread.RootMXID = root.MXID
- dbThread.ParentID = root.Channel.ChannelID
- dbThread.Insert()
- }
- thread := &Thread{
- Thread: dbThread,
- }
- thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, ""))
- br.threadsByID[thread.ID] = thread
- br.threadsByRootMXID[thread.RootMXID] = thread
- if thread.CreationNoticeMXID != "" {
- br.threadsByCreationNoticeMXID[thread.CreationNoticeMXID] = thread
- }
- return thread
- }
- func (thread *Thread) Join(user *User) {
- if user.IsInPortal(thread.ID) {
- return
- }
- user.log.Debugfln("Joining thread %s@%s", thread.ID, thread.ParentID)
- err := user.Session.ThreadJoinWithLocation(thread.ID, discordgo.ThreadJoinLocationContextMenu)
- if err != nil {
- user.log.Errorfln("Error joining thread %s@%s: %v", thread.ID, thread.ParentID, err)
- } else {
- user.MarkInPortal(database.UserPortal{
- DiscordID: thread.ID,
- Type: database.UserPortalTypeThread,
- Timestamp: time.Now(),
- })
- }
- }
|