backfill.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. package main
  2. import (
  3. "crypto/sha256"
  4. "encoding/base64"
  5. "fmt"
  6. "sort"
  7. "github.com/bwmarrin/discordgo"
  8. "github.com/rs/zerolog"
  9. "maunium.net/go/mautrix"
  10. "maunium.net/go/mautrix/bridge/bridgeconfig"
  11. "maunium.net/go/mautrix/event"
  12. "maunium.net/go/mautrix/id"
  13. "go.mau.fi/mautrix-discord/database"
  14. )
  15. func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channel) {
  16. limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel
  17. if portal.GuildID == "" {
  18. limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM
  19. }
  20. if limit == 0 {
  21. return
  22. }
  23. log := portal.zlog.With().
  24. Str("action", "missed event backfill").
  25. Int("limit", limit).
  26. Logger()
  27. portal.forwardBackfillLock.Lock()
  28. defer portal.forwardBackfillLock.Unlock()
  29. lastMessage := portal.bridge.DB.Message.GetLast(portal.Key)
  30. if lastMessage == nil || meta.LastMessageID == "" {
  31. log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata")
  32. return
  33. } else if !shouldBackfill(lastMessage.DiscordID, meta.LastMessageID) {
  34. log.Debug().
  35. Str("last_bridged_message", lastMessage.DiscordID).
  36. Str("last_server_message", meta.LastMessageID).
  37. Msg("Not backfilling, last message in database is newer than last message in metadata")
  38. return
  39. }
  40. log.Debug().
  41. Str("last_bridged_message", lastMessage.DiscordID).
  42. Str("last_server_message", meta.LastMessageID).
  43. Msg("Backfilling missed messages")
  44. if limit < 0 {
  45. portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID)
  46. } else {
  47. portal.backfillLimitedMissed(log, source, limit, lastMessage.DiscordID)
  48. }
  49. }
// messageFetchChunkSize is the number of messages requested from Discord in
// each ChannelMessages call made while backfilling.
const messageFetchChunkSize = 50
  51. func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string) ([]*discordgo.Message, bool, error) {
  52. var messages []*discordgo.Message
  53. var before string
  54. var foundAll bool
  55. for {
  56. log.Debug().Str("before_id", before).Msg("Fetching messages for backfill")
  57. newMessages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, before, "", "")
  58. if err != nil {
  59. return nil, false, err
  60. }
  61. for i, msg := range newMessages {
  62. if compareMessageIDs(msg.ID, until) <= 0 {
  63. log.Debug().
  64. Str("message_id", msg.ID).
  65. Str("until_id", until).
  66. Msg("Found message that was already bridged")
  67. newMessages = newMessages[:i]
  68. foundAll = true
  69. break
  70. }
  71. }
  72. messages = append(messages, newMessages...)
  73. log.Debug().Int("count", len(newMessages)).Msg("Added messages to backfill collection")
  74. if len(newMessages) <= messageFetchChunkSize || len(messages) >= limit {
  75. break
  76. }
  77. before = newMessages[len(newMessages)-1].ID
  78. }
  79. if len(messages) > limit {
  80. messages = messages[:limit]
  81. }
  82. return messages, foundAll, nil
  83. }
  84. func (portal *Portal) backfillLimitedMissed(log zerolog.Logger, source *User, limit int, after string) {
  85. messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after)
  86. if err != nil {
  87. log.Err(err).Msg("Error collecting messages to forward backfill")
  88. return
  89. }
  90. log.Info().
  91. Int("count", len(messages)).
  92. Bool("found_all", foundAll).
  93. Msg("Collected messages to backfill")
  94. sort.Sort(MessageSlice(messages))
  95. if !foundAll {
  96. _, err = portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
  97. MsgType: event.MsgNotice,
  98. Body: "Some messages may have been missed here while the bridge was offline.",
  99. }, nil, 0)
  100. if err != nil {
  101. log.Warn().Err(err).Msg("Failed to send missed message warning")
  102. } else {
  103. log.Debug().Msg("Sent warning about possibly missed messages")
  104. }
  105. }
  106. portal.sendBackfillBatch(log, source, messages)
  107. }
  108. func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string) {
  109. for {
  110. log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill")
  111. messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, "", after, "")
  112. if err != nil {
  113. log.Err(err).Msg("Error fetching chunk of messages to forward backfill")
  114. return
  115. }
  116. log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill")
  117. sort.Sort(MessageSlice(messages))
  118. portal.sendBackfillBatch(log, source, messages)
  119. if len(messages) < messageFetchChunkSize {
  120. // Assume that was all the missing messages
  121. log.Debug().Msg("Chunk had less than 50 messages, stopping backfill")
  122. return
  123. }
  124. after = messages[len(messages)-1].ID
  125. }
  126. }
  127. func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) {
  128. if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
  129. log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint")
  130. portal.forwardBatchSend(log, source, messages)
  131. } else {
  132. log.Debug().Msg("Not using hungryserv, sending messages one by one")
  133. for _, msg := range messages {
  134. portal.handleDiscordMessageCreate(source, msg, nil)
  135. }
  136. }
  137. }
// forwardBatchSend converts the given Discord messages into Matrix events,
// sends them in a single BatchSend request and then inserts the matching rows
// into the message table.
//
// For every message, the puppets of all mentioned users and of the author are
// updated before the message is converted. One message may convert into
// several parts; each part becomes one event and one database row, appended
// to evts and dbMessages in lockstep so that index i in one corresponds to
// index i in the other.
//
// If the batch send fails, the error is logged and nothing is written to the
// database.
func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message) {
	evts := make([]*event.Event, 0, len(messages))
	dbMessages := make([]database.Message, 0, len(messages))
	for _, msg := range messages {
		for _, mention := range msg.Mentions {
			puppet := portal.bridge.GetPuppetByID(mention.ID)
			puppet.UpdateInfo(nil, mention)
		}
		puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
		puppet.UpdateInfo(source, msg.Author)
		intent := puppet.IntentFor(portal)
		replyTo := portal.getReplyTarget(source, msg.MessageReference, true)
		// NOTE(review): the error from SnowflakeTimestamp is discarded here;
		// presumably msg.ID is always a valid snowflake — confirm.
		ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
		parts := portal.convertDiscordMessage(intent, msg)
		for i, part := range parts {
			if replyTo != nil {
				part.Content.RelatesTo = &event.RelatesTo{InReplyTo: replyTo}
				// Only set reply for first event
				replyTo = nil
			}
			partName := part.AttachmentID
			// Always use blank part name for first part so that replies and other things
			// can reference it without knowing about attachments.
			if i == 0 {
				partName = ""
			}
			evts = append(evts, &event.Event{
				ID:        portal.deterministicEventID(msg.ID, partName),
				Type:      part.Type,
				Sender:    intent.UserID,
				Timestamp: ts.UnixMilli(),
				Content: event.Content{
					Parsed: part.Content,
					Raw:    part.Extra,
				},
			})
			// The database row keeps the real attachment ID even for the
			// first part; only the event ID derivation uses the blank name.
			dbMessages = append(dbMessages, database.Message{
				Channel:      portal.Key,
				DiscordID:    msg.ID,
				SenderID:     msg.Author.ID,
				Timestamp:    ts,
				AttachmentID: part.AttachmentID,
			})
		}
	}
	log.Info().Int("parts", len(evts)).Msg("Converted messages to backfill")
	resp, err := portal.MainIntent().BatchSend(portal.MXID, &mautrix.ReqBatchSend{
		BeeperNewMessages: true,
		Events:            evts,
	})
	if err != nil {
		log.Err(err).Msg("Error sending backfill batch")
		return
	}
	// NOTE(review): assumes resp.EventIDs has at most len(dbMessages) entries
	// and is in the same order as the submitted events — confirm against the
	// batch send endpoint; a longer response would index out of range.
	for i, evtID := range resp.EventIDs {
		dbMessages[i].MXID = evtID
	}
	portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages)
	log.Info().Msg("Inserted backfilled batch to database")
}
  198. func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID {
  199. data := fmt.Sprintf("%s/discord/%s/%s", portal.MXID, messageID, partName)
  200. sum := sha256.Sum256([]byte(data))
  201. return id.EventID(fmt.Sprintf("$%s:discord.com", base64.RawURLEncoding.EncodeToString(sum[:])))
  202. }
  203. // compareMessageIDs compares two Discord message IDs.
  204. //
  205. // If the first ID is lower, -1 is returned.
  206. // If the second ID is lower, 1 is returned.
  207. // If the IDs are equal, 0 is returned.
  208. func compareMessageIDs(id1, id2 string) int {
  209. if id1 == id2 {
  210. return 0
  211. }
  212. if len(id1) < len(id2) {
  213. return -1
  214. } else if len(id2) < len(id1) {
  215. return 1
  216. }
  217. if id1 < id2 {
  218. return -1
  219. }
  220. return 1
  221. }
  222. func shouldBackfill(latestBridgedIDStr, latestIDFromServerStr string) bool {
  223. return compareMessageIDs(latestBridgedIDStr, latestIDFromServerStr) == -1
  224. }
  225. type MessageSlice []*discordgo.Message
  226. var _ sort.Interface = (MessageSlice)(nil)
  227. func (a MessageSlice) Len() int {
  228. return len(a)
  229. }
  230. func (a MessageSlice) Swap(i, j int) {
  231. a[i], a[j] = a[j], a[i]
  232. }
  233. func (a MessageSlice) Less(i, j int) bool {
  234. return compareMessageIDs(a[i].ID, a[j].ID) == -1
  235. }