123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315 |
- package main
- import (
- "context"
- "crypto/sha256"
- "encoding/base64"
- "fmt"
- "sort"
- "github.com/bwmarrin/discordgo"
- "github.com/rs/zerolog"
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/bridge/bridgeconfig"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
- "go.mau.fi/mautrix-discord/database"
- )
- func (portal *Portal) forwardBackfillInitial(source *User) {
- defer portal.forwardBackfillLock.Unlock()
- // This should only be called from CreateMatrixRoom which locks forwardBackfillLock before creating the room.
- if portal.forwardBackfillLock.TryLock() {
- panic("forwardBackfillInitial() called without locking forwardBackfillLock")
- }
- limit := portal.bridge.Config.Bridge.Backfill.Limits.Initial.Channel
- if portal.GuildID == "" {
- limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.DM
- }
- if limit == 0 {
- return
- }
- log := portal.log.With().
- Str("action", "initial backfill").
- Str("room_id", portal.MXID.String()).
- Int("limit", limit).
- Logger()
- portal.backfillLimited(log, source, limit, "")
- }
- func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channel) {
- if portal.MXID == "" {
- return
- }
- limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel
- if portal.GuildID == "" {
- limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM
- }
- if limit == 0 {
- return
- }
- log := portal.log.With().
- Str("action", "missed event backfill").
- Str("room_id", portal.MXID.String()).
- Int("limit", limit).
- Logger()
- portal.forwardBackfillLock.Lock()
- defer portal.forwardBackfillLock.Unlock()
- lastMessage := portal.bridge.DB.Message.GetLast(portal.Key)
- if lastMessage == nil || meta.LastMessageID == "" {
- log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata")
- return
- } else if !shouldBackfill(lastMessage.DiscordID, meta.LastMessageID) {
- log.Debug().
- Str("last_bridged_message", lastMessage.DiscordID).
- Str("last_server_message", meta.LastMessageID).
- Msg("Not backfilling, last message in database is newer than last message in metadata")
- return
- }
- log.Debug().
- Str("last_bridged_message", lastMessage.DiscordID).
- Str("last_server_message", meta.LastMessageID).
- Msg("Backfilling missed messages")
- if limit < 0 {
- portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID)
- } else {
- portal.backfillLimited(log, source, limit, lastMessage.DiscordID)
- }
- }
- const messageFetchChunkSize = 50
- func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string) ([]*discordgo.Message, bool, error) {
- var messages []*discordgo.Message
- var before string
- var foundAll bool
- for {
- log.Debug().Str("before_id", before).Msg("Fetching messages for backfill")
- newMessages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, before, "", "")
- if err != nil {
- return nil, false, err
- }
- if until != "" {
- for i, msg := range newMessages {
- if compareMessageIDs(msg.ID, until) <= 0 {
- log.Debug().
- Str("message_id", msg.ID).
- Str("until_id", until).
- Msg("Found message that was already bridged")
- newMessages = newMessages[:i]
- foundAll = true
- break
- }
- }
- }
- messages = append(messages, newMessages...)
- log.Debug().Int("count", len(newMessages)).Msg("Added messages to backfill collection")
- if len(newMessages) < messageFetchChunkSize || len(messages) >= limit {
- break
- }
- before = newMessages[len(newMessages)-1].ID
- }
- if len(messages) > limit {
- foundAll = false
- messages = messages[:limit]
- }
- return messages, foundAll, nil
- }
- func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string) {
- messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after)
- if err != nil {
- log.Err(err).Msg("Error collecting messages to forward backfill")
- return
- }
- log.Info().
- Int("count", len(messages)).
- Bool("found_all", foundAll).
- Msg("Collected messages to backfill")
- sort.Sort(MessageSlice(messages))
- if !foundAll {
- _, err = portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
- MsgType: event.MsgNotice,
- Body: "Some messages may have been missed here while the bridge was offline.",
- }, nil, 0)
- if err != nil {
- log.Warn().Err(err).Msg("Failed to send missed message warning")
- } else {
- log.Debug().Msg("Sent warning about possibly missed messages")
- }
- }
- portal.sendBackfillBatch(log, source, messages)
- }
- func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string) {
- for {
- log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill")
- messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, "", after, "")
- if err != nil {
- log.Err(err).Msg("Error fetching chunk of messages to forward backfill")
- return
- }
- log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill")
- sort.Sort(MessageSlice(messages))
- portal.sendBackfillBatch(log, source, messages)
- if len(messages) < messageFetchChunkSize {
- // Assume that was all the missing messages
- log.Debug().Msg("Chunk had less than 50 messages, stopping backfill")
- return
- }
- after = messages[len(messages)-1].ID
- }
- }
- func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) {
- if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
- log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint")
- portal.forwardBatchSend(log, source, messages)
- } else {
- log.Debug().Msg("Not using hungryserv, sending messages one by one")
- for _, msg := range messages {
- portal.handleDiscordMessageCreate(source, msg, nil)
- }
- }
- }
- func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message) {
- evts, dbMessages := portal.convertMessageBatch(log, source, messages)
- if len(evts) == 0 {
- log.Warn().Msg("Didn't get any events to backfill")
- return
- }
- log.Info().Int("events", len(evts)).Msg("Converted messages to backfill")
- resp, err := portal.MainIntent().BatchSend(portal.MXID, &mautrix.ReqBatchSend{
- BeeperNewMessages: true,
- Events: evts,
- })
- if err != nil {
- log.Err(err).Msg("Error sending backfill batch")
- return
- }
- for i, evtID := range resp.EventIDs {
- dbMessages[i].MXID = evtID
- }
- portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages)
- log.Info().Msg("Inserted backfilled batch to database")
- }
- func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) ([]*event.Event, []database.Message) {
- evts := make([]*event.Event, 0, len(messages))
- dbMessages := make([]database.Message, 0, len(messages))
- ctx := context.Background()
- for _, msg := range messages {
- for _, mention := range msg.Mentions {
- puppet := portal.bridge.GetPuppetByID(mention.ID)
- puppet.UpdateInfo(nil, mention)
- }
- puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
- puppet.UpdateInfo(source, msg.Author)
- intent := puppet.IntentFor(portal)
- replyTo := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true)
- ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
- log := log.With().
- Str("message_id", msg.ID).
- Int("message_type", int(msg.Type)).
- Str("author_id", msg.Author.ID).
- Logger()
- parts := portal.convertDiscordMessage(log.WithContext(ctx), intent, msg)
- for i, part := range parts {
- if replyTo != nil {
- part.Content.RelatesTo = &event.RelatesTo{InReplyTo: replyTo}
- // Only set reply for first event
- replyTo = nil
- }
- partName := part.AttachmentID
- // Always use blank part name for first part so that replies and other things
- // can reference it without knowing about attachments.
- if i == 0 {
- partName = ""
- }
- evt := &event.Event{
- ID: portal.deterministicEventID(msg.ID, partName),
- Type: part.Type,
- Sender: intent.UserID,
- Timestamp: ts.UnixMilli(),
- Content: event.Content{
- Parsed: part.Content,
- Raw: part.Extra,
- },
- }
- var err error
- evt.Type, err = portal.encrypt(intent, &evt.Content, evt.Type)
- if err != nil {
- log.Err(err).Msg("Failed to encrypt event")
- continue
- }
- intent.AddDoublePuppetValue(&evt.Content)
- evts = append(evts, evt)
- dbMessages = append(dbMessages, database.Message{
- Channel: portal.Key,
- DiscordID: msg.ID,
- SenderID: msg.Author.ID,
- Timestamp: ts,
- AttachmentID: part.AttachmentID,
- })
- }
- }
- return evts, dbMessages
- }
- func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID {
- data := fmt.Sprintf("%s/discord/%s/%s", portal.MXID, messageID, partName)
- sum := sha256.Sum256([]byte(data))
- return id.EventID(fmt.Sprintf("$%s:discord.com", base64.RawURLEncoding.EncodeToString(sum[:])))
- }
// compareMessageIDs orders two Discord message IDs given as strings.
//
// It returns -1 when id1 sorts before id2, 1 when id2 sorts before id1, and 0
// when they are equal. A shorter string always sorts before a longer one;
// strings of equal length are compared bytewise.
func compareMessageIDs(id1, id2 string) int {
	switch {
	case id1 == id2:
		return 0
	case len(id1) != len(id2):
		// Different lengths: the shorter ID is the lower one.
		if len(id1) < len(id2) {
			return -1
		}
		return 1
	case id1 < id2:
		return -1
	default:
		return 1
	}
}
- func shouldBackfill(latestBridgedIDStr, latestIDFromServerStr string) bool {
- return compareMessageIDs(latestBridgedIDStr, latestIDFromServerStr) == -1
- }
- type MessageSlice []*discordgo.Message
- var _ sort.Interface = (MessageSlice)(nil)
- func (a MessageSlice) Len() int {
- return len(a)
- }
- func (a MessageSlice) Swap(i, j int) {
- a[i], a[j] = a[j], a[i]
- }
- func (a MessageSlice) Less(i, j int) bool {
- return compareMessageIDs(a[i].ID, a[j].ID) == -1
- }
|