Browse Source

Backfill threads when found and from server thread list sync

Tulir Asokan 2 years ago
parent
commit
11b91dc299
6 changed files with 97 additions and 24 deletions
  1. 4 10
      backfill.go
  2. 1 1
      go.mod
  3. 2 2
      go.sum
  4. 3 10
      portal.go
  5. 60 1
      thread.go
  6. 27 0
      user.go

+ 4 - 10
backfill.go

@@ -29,6 +29,7 @@ func (portal *Portal) forwardBackfillInitial(source *User, thread *Thread) {
 		limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.DM
 		if thread != nil {
 			limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.Thread
+			thread.initialBackfillAttempted = true
 		}
 	}
 	if limit == 0 {
@@ -225,16 +226,9 @@ func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, message
 	for i, evtID := range resp.EventIDs {
 		dbMessages[i].MXID = evtID
 		if metas[i] != nil && metas[i].Flags == discordgo.MessageFlagsHasThread {
-			thread = portal.bridge.GetThreadByID(metas[i].ID, &dbMessages[i])
-			log.Debug().
-				Str("message_id", metas[i].ID).
-				Str("event_id", evtID.String()).
-				Msg("Marked backfilled message as thread root")
-			if thread.CreationNoticeMXID == "" {
-				// TODO proper context
-				ctx := log.WithContext(context.Background())
-				portal.sendThreadCreationNotice(ctx, thread)
-			}
+			// TODO proper context
+			ctx := log.WithContext(context.Background())
+			portal.bridge.threadFound(ctx, source, &dbMessages[i], metas[i].ID, metas[i].Thread)
 		}
 	}
 	portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages)

+ 1 - 1
go.mod

@@ -38,4 +38,4 @@ require (
 	maunium.net/go/mauflag v1.0.0 // indirect
 )
 
-replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230512133900-5b12693331c0
+replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230618183737-3c7afd8d8596

+ 2 - 2
go.sum

@@ -1,6 +1,6 @@
 github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
-github.com/beeper/discordgo v0.0.0-20230512133900-5b12693331c0 h1:ECBEbC4ruaXzcVJJ4UurkGpT/Xlm9ZnwsHiHn9gjPZw=
-github.com/beeper/discordgo v0.0.0-20230512133900-5b12693331c0/go.mod h1:59+AOzzjmL6onAh62nuLXmn7dJCaC/owDLWbGtjTcFA=
+github.com/beeper/discordgo v0.0.0-20230618183737-3c7afd8d8596 h1:PxtbetWbVi2OlACDNtx6YJahhXt/rhiEsGqtOOLSx4o=
+github.com/beeper/discordgo v0.0.0-20230618183737-3c7afd8d8596/go.mod h1:59+AOzzjmL6onAh62nuLXmn7dJCaC/owDLWbGtjTcFA=
 github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
 github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=

+ 3 - 10
portal.go

@@ -683,11 +683,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
 	} else {
 		firstDBMessage := portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, intent.UserID, dbParts)
 		if msg.Flags == discordgo.MessageFlagsHasThread {
-			thread = portal.bridge.GetThreadByID(msg.ID, firstDBMessage)
-			log.Debug().Msg("Marked message as thread root")
-			if thread.CreationNoticeMXID == "" {
-				portal.sendThreadCreationNotice(ctx, thread)
-			}
+			portal.bridge.threadFound(ctx, user, firstDBMessage, msg.ID, msg.Thread)
 		}
 	}
 }
@@ -817,11 +813,7 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
 	}
 
 	if msg.Flags == discordgo.MessageFlagsHasThread {
-		thread := portal.bridge.GetThreadByID(msg.ID, existing[0])
-		log.Debug().Msg("Marked message as thread root")
-		if thread.CreationNoticeMXID == "" {
-			portal.sendThreadCreationNotice(ctx, thread)
-		}
+		portal.bridge.threadFound(ctx, user, existing[0], msg.ID, msg.Thread)
 	}
 
 	if msg.Author == nil {
@@ -1476,6 +1468,7 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
 		existingThread := portal.bridge.GetThreadByRootMXID(threadRoot)
 		if existingThread != nil {
 			threadID = existingThread.ID
+			existingThread.initialBackfillAttempted = true
 		} else {
 			if isWebhookSend {
 				// TODO start thread with bot?

+ 60 - 1
thread.go

@@ -1,10 +1,13 @@
 package main
 
 import (
+	"context"
 	"sync"
 	"time"
 
 	"github.com/bwmarrin/discordgo"
+	"github.com/rs/zerolog"
+	"golang.org/x/exp/slices"
 	"maunium.net/go/mautrix/id"
 
 	"go.mau.fi/mautrix-discord/database"
@@ -14,7 +17,8 @@ type Thread struct {
 	*database.Thread
 	Parent *Portal
 
-	creationNoticeLock sync.Mutex
+	creationNoticeLock       sync.Mutex
+	initialBackfillAttempted bool
 }
 
 func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread {
@@ -74,12 +78,63 @@ func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *
 	return thread
 }
 
+func (br *DiscordBridge) threadFound(ctx context.Context, source *User, rootMessage *database.Message, id string, metadata *discordgo.Channel) {
+	thread := br.GetThreadByID(id, rootMessage)
+	log := zerolog.Ctx(ctx)
+	log.Debug().Msg("Marked message as thread root")
+	if thread.CreationNoticeMXID == "" {
+		thread.Parent.sendThreadCreationNotice(ctx, thread)
+	}
+	// TODO member_ids_preview is probably not guaranteed to contain the source user
+	if source != nil && metadata != nil && slices.Contains(metadata.MemberIDsPreview, source.DiscordID) && !source.IsInPortal(thread.ID) {
+		source.MarkInPortal(database.UserPortal{
+			DiscordID: thread.ID,
+			Type:      database.UserPortalTypeThread,
+			Timestamp: time.Now(),
+		})
+		if metadata.MessageCount > 0 {
+			go thread.maybeInitialBackfill(source)
+		} else {
+			thread.initialBackfillAttempted = true
+		}
+	}
+}
+
+func (thread *Thread) maybeInitialBackfill(source *User) {
+	if thread.initialBackfillAttempted || thread.Parent.bridge.Config.Bridge.Backfill.Limits.Initial.Thread == 0 {
+		return
+	}
+	thread.Parent.forwardBackfillLock.Lock()
+	if thread.Parent.bridge.DB.Message.GetLastInThread(thread.Parent.Key, thread.ID) != nil {
+		thread.Parent.forwardBackfillLock.Unlock()
+		return
+	}
+	thread.Parent.forwardBackfillInitial(source, thread)
+}
+
 func (thread *Thread) Join(user *User) {
 	if user.IsInPortal(thread.ID) {
 		return
 	}
 	log := user.log.With().Str("thread_id", thread.ID).Str("channel_id", thread.ParentID).Logger()
 	log.Debug().Msg("Joining thread")
+
+	var doBackfill, backfillStarted bool
+	if !thread.initialBackfillAttempted && thread.Parent.bridge.Config.Bridge.Backfill.Limits.Initial.Thread > 0 {
+		thread.Parent.forwardBackfillLock.Lock()
+		lastMessage := thread.Parent.bridge.DB.Message.GetLastInThread(thread.Parent.Key, thread.ID)
+		if lastMessage != nil {
+			thread.Parent.forwardBackfillLock.Unlock()
+		} else {
+			doBackfill = true
+			defer func() {
+				if !backfillStarted {
+					thread.Parent.forwardBackfillLock.Unlock()
+				}
+			}()
+		}
+	}
+
 	var err error
 	if user.Session.IsUser {
 		err = user.Session.ThreadJoinWithLocation(thread.ID, discordgo.ThreadJoinLocationContextMenu)
@@ -94,5 +149,9 @@ func (thread *Thread) Join(user *User) {
 			Type:      database.UserPortalTypeThread,
 			Timestamp: time.Now(),
 		})
+		if doBackfill {
+			go thread.Parent.forwardBackfillInitial(user, thread)
+			backfillStarted = true
+		}
 	}
 }

+ 27 - 0
user.go

@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"math/rand"
@@ -650,6 +651,8 @@ func (user *User) eventHandler(rawEvt any) {
 		user.typingStartHandler(evt)
 	case *discordgo.InteractionSuccess:
 		user.interactionSuccessHandler(evt)
+	case *discordgo.ThreadListSync:
+		user.threadListSyncHandler(evt)
 	case *discordgo.Event:
 		// Ignore
 	default:
@@ -1038,6 +1041,30 @@ func (user *User) guildUpdateHandler(g *discordgo.GuildUpdate) {
 	user.handleGuild(g.Guild, time.Now(), user.IsInSpace(g.ID))
 }
 
+func (user *User) threadListSyncHandler(t *discordgo.ThreadListSync) {
+	for _, meta := range t.Threads {
+		log := user.log.With().
+			Str("action", "thread list sync").
+			Str("guild_id", t.GuildID).
+			Str("parent_id", meta.ParentID).
+			Str("thread_id", meta.ID).
+			Logger()
+		ctx := log.WithContext(context.Background())
+		thread := user.bridge.GetThreadByID(meta.ID, nil)
+		if thread == nil {
+			msg := user.bridge.DB.Message.GetByDiscordID(database.NewPortalKey(meta.ParentID, ""), meta.ID)
+			if len(msg) == 0 {
+				log.Debug().Msg("Found unknown thread in thread list sync and don't have message")
+			} else {
+				log.Debug().Msg("Found unknown thread in thread list sync for existing message, creating thread")
+				user.bridge.threadFound(ctx, user, msg[0], meta.ID, meta)
+			}
+		} else {
+			thread.Parent.ForwardBackfillMissed(user, meta.LastMessageID, thread)
+		}
+	}
+}
+
 func (user *User) channelCreateHandler(c *discordgo.ChannelCreate) {
 	if user.getGuildBridgingMode(c.GuildID) < database.GuildBridgeEverything {
 		user.log.Debug().