Selaa lähdekoodia

Make backfilling code compatible with threads

This doesn't trigger thread backfill yet, but the backfill methods can
handle threads now.
Tulir Asokan 2 vuotta sitten
vanhempi
sitoutus
8ebad277f5
6 muutettua tiedostoa jossa 84 lisäystä ja 35 poistoa
  1. 76 32
      backfill.go
  2. 1 0
      config/bridge.go
  3. 2 0
      config/upgrade.go
  4. 2 0
      example-config.yaml
  5. 1 1
      portal.go
  6. 2 2
      user.go

+ 76 - 32
backfill.go

@@ -17,7 +17,7 @@ import (
 	"go.mau.fi/mautrix-discord/database"
 )
 
-func (portal *Portal) forwardBackfillInitial(source *User) {
+func (portal *Portal) forwardBackfillInitial(source *User, thread *Thread) {
 	defer portal.forwardBackfillLock.Unlock()
 	// This should only be called from CreateMatrixRoom which locks forwardBackfillLock before creating the room.
 	if portal.forwardBackfillLock.TryLock() {
@@ -27,21 +27,27 @@ func (portal *Portal) forwardBackfillInitial(source *User) {
 	limit := portal.bridge.Config.Bridge.Backfill.Limits.Initial.Channel
 	if portal.GuildID == "" {
 		limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.DM
+		if thread != nil {
+			limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.Thread
+		}
 	}
 	if limit == 0 {
 		return
 	}
 
-	log := portal.log.With().
+	with := portal.log.With().
 		Str("action", "initial backfill").
 		Str("room_id", portal.MXID.String()).
-		Int("limit", limit).
-		Logger()
+		Int("limit", limit)
+	if thread != nil {
+		with = with.Str("thread_id", thread.ID)
+	}
+	log := with.Logger()
 
-	portal.backfillLimited(log, source, limit, "")
+	portal.backfillLimited(log, source, limit, "", thread)
 }
 
-func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channel) {
+func (portal *Portal) ForwardBackfillMissed(source *User, serverLastMessageID string, thread *Thread) {
 	if portal.MXID == "" {
 		return
 	}
@@ -49,50 +55,65 @@ func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channe
 	limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel
 	if portal.GuildID == "" {
 		limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM
+		if thread != nil {
+			limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.Thread
+		}
 	}
 	if limit == 0 {
 		return
 	}
-	log := portal.log.With().
+	with := portal.log.With().
 		Str("action", "missed event backfill").
 		Str("room_id", portal.MXID.String()).
-		Int("limit", limit).
-		Logger()
+		Int("limit", limit)
+	if thread != nil {
+		with = with.Str("thread_id", thread.ID)
+	}
+	log := with.Logger()
 
 	portal.forwardBackfillLock.Lock()
 	defer portal.forwardBackfillLock.Unlock()
 
-	lastMessage := portal.bridge.DB.Message.GetLast(portal.Key)
-	if lastMessage == nil || meta.LastMessageID == "" {
+	var lastMessage *database.Message
+	if thread != nil {
+		lastMessage = portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID)
+	} else {
+		lastMessage = portal.bridge.DB.Message.GetLast(portal.Key)
+	}
+	if lastMessage == nil || serverLastMessageID == "" {
 		log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata")
 		return
-	} else if !shouldBackfill(lastMessage.DiscordID, meta.LastMessageID) {
+	} else if !shouldBackfill(lastMessage.DiscordID, serverLastMessageID) {
 		log.Debug().
 			Str("last_bridged_message", lastMessage.DiscordID).
-			Str("last_server_message", meta.LastMessageID).
+			Str("last_server_message", serverLastMessageID).
 			Msg("Not backfilling, last message in database is newer than last message in metadata")
 		return
 	}
 	log.Debug().
 		Str("last_bridged_message", lastMessage.DiscordID).
-		Str("last_server_message", meta.LastMessageID).
+		Str("last_server_message", serverLastMessageID).
 		Msg("Backfilling missed messages")
 	if limit < 0 {
-		portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID)
+		portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID, thread)
 	} else {
-		portal.backfillLimited(log, source, limit, lastMessage.DiscordID)
+		portal.backfillLimited(log, source, limit, lastMessage.DiscordID, thread)
 	}
 }
 
 const messageFetchChunkSize = 50
 
-func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string) ([]*discordgo.Message, bool, error) {
+func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string, thread *Thread) ([]*discordgo.Message, bool, error) {
 	var messages []*discordgo.Message
 	var before string
 	var foundAll bool
+	protoChannelID := portal.Key.ChannelID
+	if thread != nil {
+		protoChannelID = thread.ID
+	}
 	for {
 		log.Debug().Str("before_id", before).Msg("Fetching messages for backfill")
-		newMessages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, before, "", "")
+		newMessages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, before, "", "")
 		if err != nil {
 			return nil, false, err
 		}
@@ -123,8 +144,8 @@ func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User,
 	return messages, foundAll, nil
 }
 
-func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string) {
-	messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after)
+func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string, thread *Thread) {
+	messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after, thread)
 	if err != nil {
 		log.Err(err).Msg("Error collecting messages to forward backfill")
 		return
@@ -145,13 +166,17 @@ func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit in
 			log.Debug().Msg("Sent warning about possibly missed messages")
 		}
 	}
-	portal.sendBackfillBatch(log, source, messages)
+	portal.sendBackfillBatch(log, source, messages, thread)
 }
 
-func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string) {
+func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string, thread *Thread) {
+	protoChannelID := portal.Key.ChannelID
+	if thread != nil {
+		protoChannelID = thread.ID
+	}
 	for {
 		log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill")
-		messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, "", after, "")
+		messages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, "", after, "")
 		if err != nil {
 			log.Err(err).Msg("Error fetching chunk of messages to forward backfill")
 			return
@@ -159,7 +184,7 @@ func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User,
 		log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill")
 		sort.Sort(MessageSlice(messages))
 
-		portal.sendBackfillBatch(log, source, messages)
+		portal.sendBackfillBatch(log, source, messages, thread)
 
 		if len(messages) < messageFetchChunkSize {
 			// Assume that was all the missing messages
@@ -170,20 +195,20 @@ func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User,
 	}
 }
 
-func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) {
+func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) {
 	if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
 		log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint")
-		portal.forwardBatchSend(log, source, messages)
+		portal.forwardBatchSend(log, source, messages, thread)
 	} else {
 		log.Debug().Msg("Not using hungryserv, sending messages one by one")
 		for _, msg := range messages {
-			portal.handleDiscordMessageCreate(source, msg, nil)
+			portal.handleDiscordMessageCreate(source, msg, thread)
 		}
 	}
 }
 
-func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message) {
-	evts, dbMessages := portal.convertMessageBatch(log, source, messages)
+func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) {
+	evts, dbMessages := portal.convertMessageBatch(log, source, messages, thread)
 	if len(evts) == 0 {
 		log.Warn().Msg("Didn't get any events to backfill")
 		return
@@ -204,7 +229,19 @@ func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, message
 	log.Info().Msg("Inserted backfilled batch to database")
 }
 
-func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) ([]*event.Event, []database.Message) {
+func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) ([]*event.Event, []database.Message) {
+	var discordThreadID string
+	var threadRootEvent, lastThreadEvent id.EventID
+	if thread != nil {
+		discordThreadID = thread.ID
+		threadRootEvent = thread.RootMXID
+		lastThreadEvent = threadRootEvent
+		lastInThread := portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID)
+		if lastInThread != nil {
+			lastThreadEvent = lastInThread.MXID
+		}
+	}
+
 	evts := make([]*event.Event, 0, len(messages))
 	dbMessages := make([]database.Message, 0, len(messages))
 	ctx := context.Background()
@@ -217,7 +254,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
 		puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
 		puppet.UpdateInfo(source, msg.Author, msg.WebhookID)
 		intent := puppet.IntentFor(portal)
-		replyTo := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true)
+		replyTo := portal.getReplyTarget(source, discordThreadID, msg.MessageReference, msg.Embeds, true)
 		mentions := portal.convertDiscordMentions(msg, false)
 
 		ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
@@ -228,8 +265,14 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
 			Logger()
 		parts := portal.convertDiscordMessage(log.WithContext(ctx), puppet, intent, msg)
 		for i, part := range parts {
+			if (replyTo != nil || threadRootEvent != "") && part.Content.RelatesTo == nil {
+				part.Content.RelatesTo = &event.RelatesTo{}
+			}
+			if threadRootEvent != "" {
+				part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent)
+			}
 			if replyTo != nil {
-				part.Content.RelatesTo = &event.RelatesTo{InReplyTo: replyTo}
+				part.Content.RelatesTo.SetReplyTo(replyTo.EventID)
 				// Only set reply for first event
 				replyTo = nil
 			}
@@ -270,6 +313,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
 				AttachmentID: part.AttachmentID,
 				SenderMXID:   intent.UserID,
 			})
+			lastThreadEvent = evt.ID
 		}
 	}
 	return evts, dbMessages

+ 1 - 0
config/bridge.go

@@ -207,6 +207,7 @@ func (mp *MediaPatterns) Avatar(userID, avatarID, ext string) id.ContentURI {
 type BackfillLimitPart struct {
 	DM      int `yaml:"dm"`
 	Channel int `yaml:"channel"`
+	Thread  int `yaml:"thread"`
 }
 
 func (bc *BridgeConfig) GetResendBridgeInfo() bool {

+ 2 - 0
config/upgrade.go

@@ -79,8 +79,10 @@ func DoUpgrade(helper *up.Helper) {
 	helper.Copy(up.Bool, "bridge", "backfill", "enabled")
 	helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "dm")
 	helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "channel")
+	helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "thread")
 	helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "dm")
 	helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "channel")
+	helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "thread")
 	helper.Copy(up.Int, "bridge", "backfill", "max_guild_members")
 	helper.Copy(up.Bool, "bridge", "encryption", "allow")
 	helper.Copy(up.Bool, "bridge", "encryption", "default")

+ 2 - 0
example-config.yaml

@@ -232,6 +232,7 @@ bridge:
             initial:
                 dm: 0
                 channel: 0
+                thread: 0
             # Missed message backfill (on startup).
             # 0 means backfill is disabled, -1 means fetch all messages since last bridged message.
             # When using unlimited backfill (-1), messages are backfilled as they are fetched.
@@ -239,6 +240,7 @@ bridge:
             missed:
                 dm: 0
                 channel: 0
+                thread: 0
         # Maximum members in a guild to enable backfilling. Set to -1 to disable limit.
         # This can be used as a rough heuristic to disable backfilling in channels that are too active.
         # Currently only applies to missed message backfill.

+ 1 - 1
portal.go

@@ -541,7 +541,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, channel *discordgo.Channel) e
 		portal.Update()
 	}
 
-	go portal.forwardBackfillInitial(user)
+	go portal.forwardBackfillInitial(user, nil)
 	backfillStarted = true
 
 	return nil

+ 2 - 2
user.go

@@ -864,7 +864,7 @@ func (user *User) handlePrivateChannel(portal *Portal, meta *discordgo.Channel,
 		}
 	} else {
 		portal.UpdateInfo(user, meta)
-		portal.ForwardBackfillMissed(user, meta)
+		portal.ForwardBackfillMissed(user, meta.LastMessageID, nil)
 	}
 	user.MarkInPortal(database.UserPortal{
 		DiscordID: portal.Key.ChannelID,
@@ -966,7 +966,7 @@ func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSp
 			} else {
 				portal.UpdateInfo(user, ch)
 				if user.bridge.Config.Bridge.Backfill.MaxGuildMembers < 0 || meta.MemberCount < user.bridge.Config.Bridge.Backfill.MaxGuildMembers {
-					portal.ForwardBackfillMissed(user, ch)
+					portal.ForwardBackfillMissed(user, ch.LastMessageID, nil)
 				}
 			}
 		}