|
@@ -0,0 +1,259 @@
|
|
|
+package main
|
|
|
+
|
|
|
+import (
|
|
|
+ "crypto/sha256"
|
|
|
+ "encoding/base64"
|
|
|
+ "fmt"
|
|
|
+ "sort"
|
|
|
+
|
|
|
+ "github.com/bwmarrin/discordgo"
|
|
|
+ "github.com/rs/zerolog"
|
|
|
+ "maunium.net/go/mautrix"
|
|
|
+ "maunium.net/go/mautrix/bridge/bridgeconfig"
|
|
|
+ "maunium.net/go/mautrix/event"
|
|
|
+ "maunium.net/go/mautrix/id"
|
|
|
+
|
|
|
+ "go.mau.fi/mautrix-discord/database"
|
|
|
+)
|
|
|
+
|
|
|
+func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channel) {
|
|
|
+ limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel
|
|
|
+ if portal.GuildID == "" {
|
|
|
+ limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM
|
|
|
+ }
|
|
|
+ if limit == 0 {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ log := portal.zlog.With().
|
|
|
+ Str("action", "missed event backfill").
|
|
|
+ Int("limit", limit).
|
|
|
+ Logger()
|
|
|
+
|
|
|
+ portal.forwardBackfillLock.Lock()
|
|
|
+ defer portal.forwardBackfillLock.Unlock()
|
|
|
+
|
|
|
+ lastMessage := portal.bridge.DB.Message.GetLast(portal.Key)
|
|
|
+ if lastMessage == nil || meta.LastMessageID == "" {
|
|
|
+ log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata")
|
|
|
+ return
|
|
|
+ } else if !shouldBackfill(lastMessage.DiscordID, meta.LastMessageID) {
|
|
|
+ log.Debug().
|
|
|
+ Str("last_bridged_message", lastMessage.DiscordID).
|
|
|
+ Str("last_server_message", meta.LastMessageID).
|
|
|
+ Msg("Not backfilling, last message in database is newer than last message in metadata")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ log.Debug().
|
|
|
+ Str("last_bridged_message", lastMessage.DiscordID).
|
|
|
+ Str("last_server_message", meta.LastMessageID).
|
|
|
+ Msg("Backfilling missed messages")
|
|
|
+ if limit < 0 {
|
|
|
+ portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID)
|
|
|
+ } else {
|
|
|
+ portal.backfillLimitedMissed(log, source, limit, lastMessage.DiscordID)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+const messageFetchChunkSize = 50
|
|
|
+
|
|
|
+func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string) ([]*discordgo.Message, bool, error) {
|
|
|
+ var messages []*discordgo.Message
|
|
|
+ var before string
|
|
|
+ var foundAll bool
|
|
|
+ for {
|
|
|
+ log.Debug().Str("before_id", before).Msg("Fetching messages for backfill")
|
|
|
+ newMessages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, before, "", "")
|
|
|
+ if err != nil {
|
|
|
+ return nil, false, err
|
|
|
+ }
|
|
|
+ for i, msg := range newMessages {
|
|
|
+ if compareMessageIDs(msg.ID, until) <= 0 {
|
|
|
+ log.Debug().
|
|
|
+ Str("message_id", msg.ID).
|
|
|
+ Str("until_id", until).
|
|
|
+ Msg("Found message that was already bridged")
|
|
|
+ newMessages = newMessages[:i]
|
|
|
+ foundAll = true
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ messages = append(messages, newMessages...)
|
|
|
+ log.Debug().Int("count", len(newMessages)).Msg("Added messages to backfill collection")
|
|
|
+ if len(newMessages) <= messageFetchChunkSize || len(messages) >= limit {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ before = newMessages[len(newMessages)-1].ID
|
|
|
+ }
|
|
|
+ if len(messages) > limit {
|
|
|
+ messages = messages[:limit]
|
|
|
+ }
|
|
|
+ return messages, foundAll, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (portal *Portal) backfillLimitedMissed(log zerolog.Logger, source *User, limit int, after string) {
|
|
|
+ messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after)
|
|
|
+ if err != nil {
|
|
|
+ log.Err(err).Msg("Error collecting messages to forward backfill")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ log.Info().
|
|
|
+ Int("count", len(messages)).
|
|
|
+ Bool("found_all", foundAll).
|
|
|
+ Msg("Collected messages to backfill")
|
|
|
+ sort.Sort(MessageSlice(messages))
|
|
|
+ if !foundAll {
|
|
|
+ _, err = portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
|
|
|
+ MsgType: event.MsgNotice,
|
|
|
+ Body: "Some messages may have been missed here while the bridge was offline.",
|
|
|
+ }, nil, 0)
|
|
|
+ if err != nil {
|
|
|
+ log.Warn().Err(err).Msg("Failed to send missed message warning")
|
|
|
+ } else {
|
|
|
+ log.Debug().Msg("Sent warning about possibly missed messages")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ portal.sendBackfillBatch(log, source, messages)
|
|
|
+}
|
|
|
+
|
|
|
+func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string) {
|
|
|
+ for {
|
|
|
+ log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill")
|
|
|
+ messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, "", after, "")
|
|
|
+ if err != nil {
|
|
|
+ log.Err(err).Msg("Error fetching chunk of messages to forward backfill")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill")
|
|
|
+ sort.Sort(MessageSlice(messages))
|
|
|
+
|
|
|
+ portal.sendBackfillBatch(log, source, messages)
|
|
|
+
|
|
|
+ if len(messages) < messageFetchChunkSize {
|
|
|
+ // Assume that was all the missing messages
|
|
|
+ log.Debug().Msg("Chunk had less than 50 messages, stopping backfill")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ after = messages[len(messages)-1].ID
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) {
|
|
|
+ if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
|
|
|
+ log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint")
|
|
|
+ portal.forwardBatchSend(log, source, messages)
|
|
|
+ } else {
|
|
|
+ log.Debug().Msg("Not using hungryserv, sending messages one by one")
|
|
|
+ for _, msg := range messages {
|
|
|
+ portal.handleDiscordMessageCreate(source, msg, nil)
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message) {
|
|
|
+ evts := make([]*event.Event, 0, len(messages))
|
|
|
+ dbMessages := make([]database.Message, 0, len(messages))
|
|
|
+ for _, msg := range messages {
|
|
|
+ for _, mention := range msg.Mentions {
|
|
|
+ puppet := portal.bridge.GetPuppetByID(mention.ID)
|
|
|
+ puppet.UpdateInfo(nil, mention)
|
|
|
+ }
|
|
|
+
|
|
|
+ puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
|
|
|
+ puppet.UpdateInfo(source, msg.Author)
|
|
|
+ intent := puppet.IntentFor(portal)
|
|
|
+ replyTo := portal.getReplyTarget(source, msg.MessageReference, true)
|
|
|
+
|
|
|
+ ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
|
|
|
+ parts := portal.convertDiscordMessage(intent, msg)
|
|
|
+ for i, part := range parts {
|
|
|
+ if replyTo != nil {
|
|
|
+ part.Content.RelatesTo = &event.RelatesTo{InReplyTo: replyTo}
|
|
|
+ // Only set reply for first event
|
|
|
+ replyTo = nil
|
|
|
+ }
|
|
|
+ partName := part.AttachmentID
|
|
|
+ // Always use blank part name for first part so that replies and other things
|
|
|
+ // can reference it without knowing about attachments.
|
|
|
+ if i == 0 {
|
|
|
+ partName = ""
|
|
|
+ }
|
|
|
+ evts = append(evts, &event.Event{
|
|
|
+ ID: portal.deterministicEventID(msg.ID, partName),
|
|
|
+ Type: part.Type,
|
|
|
+ Sender: intent.UserID,
|
|
|
+ Timestamp: ts.UnixMilli(),
|
|
|
+ Content: event.Content{
|
|
|
+ Parsed: part.Content,
|
|
|
+ Raw: part.Extra,
|
|
|
+ },
|
|
|
+ })
|
|
|
+ dbMessages = append(dbMessages, database.Message{
|
|
|
+ Channel: portal.Key,
|
|
|
+ DiscordID: msg.ID,
|
|
|
+ SenderID: msg.Author.ID,
|
|
|
+ Timestamp: ts,
|
|
|
+ AttachmentID: part.AttachmentID,
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+ log.Info().Int("parts", len(evts)).Msg("Converted messages to backfill")
|
|
|
+ resp, err := portal.MainIntent().BatchSend(portal.MXID, &mautrix.ReqBatchSend{
|
|
|
+ BeeperNewMessages: true,
|
|
|
+ Events: evts,
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ log.Err(err).Msg("Error sending backfill batch")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ for i, evtID := range resp.EventIDs {
|
|
|
+ dbMessages[i].MXID = evtID
|
|
|
+ }
|
|
|
+ portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages)
|
|
|
+ log.Info().Msg("Inserted backfilled batch to database")
|
|
|
+}
|
|
|
+
|
|
|
+func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID {
|
|
|
+ data := fmt.Sprintf("%s/discord/%s/%s", portal.MXID, messageID, partName)
|
|
|
+ sum := sha256.Sum256([]byte(data))
|
|
|
+ return id.EventID(fmt.Sprintf("$%s:discord.com", base64.RawURLEncoding.EncodeToString(sum[:])))
|
|
|
+}
|
|
|
+
|
|
|
+// compareMessageIDs compares two Discord message IDs.
|
|
|
+//
|
|
|
+// If the first ID is lower, -1 is returned.
|
|
|
+// If the second ID is lower, 1 is returned.
|
|
|
+// If the IDs are equal, 0 is returned.
|
|
|
+func compareMessageIDs(id1, id2 string) int {
|
|
|
+ if id1 == id2 {
|
|
|
+ return 0
|
|
|
+ }
|
|
|
+ if len(id1) < len(id2) {
|
|
|
+ return -1
|
|
|
+ } else if len(id2) < len(id1) {
|
|
|
+ return 1
|
|
|
+ }
|
|
|
+ if id1 < id2 {
|
|
|
+ return -1
|
|
|
+ }
|
|
|
+ return 1
|
|
|
+}
|
|
|
+
|
|
|
+func shouldBackfill(latestBridgedIDStr, latestIDFromServerStr string) bool {
|
|
|
+ return compareMessageIDs(latestBridgedIDStr, latestIDFromServerStr) == -1
|
|
|
+}
|
|
|
+
|
|
|
+type MessageSlice []*discordgo.Message
|
|
|
+
|
|
|
+var _ sort.Interface = (MessageSlice)(nil)
|
|
|
+
|
|
|
+func (a MessageSlice) Len() int {
|
|
|
+ return len(a)
|
|
|
+}
|
|
|
+
|
|
|
+func (a MessageSlice) Swap(i, j int) {
|
|
|
+ a[i], a[j] = a[j], a[i]
|
|
|
+}
|
|
|
+
|
|
|
+func (a MessageSlice) Less(i, j int) bool {
|
|
|
+ return compareMessageIDs(a[i].ID, a[j].ID) == -1
|
|
|
+}
|