backfill.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. package main
  2. import (
  3. "context"
  4. "crypto/sha256"
  5. "encoding/base64"
  6. "fmt"
  7. "sort"
  8. "github.com/bwmarrin/discordgo"
  9. "github.com/rs/zerolog"
  10. "maunium.net/go/mautrix"
  11. "maunium.net/go/mautrix/bridge/bridgeconfig"
  12. "maunium.net/go/mautrix/event"
  13. "maunium.net/go/mautrix/id"
  14. "go.mau.fi/mautrix-discord/database"
  15. )
  16. func (portal *Portal) forwardBackfillInitial(source *User) {
  17. defer portal.forwardBackfillLock.Unlock()
  18. // This should only be called from CreateMatrixRoom which locks forwardBackfillLock before creating the room.
  19. if portal.forwardBackfillLock.TryLock() {
  20. panic("forwardBackfillInitial() called without locking forwardBackfillLock")
  21. }
  22. limit := portal.bridge.Config.Bridge.Backfill.Limits.Initial.Channel
  23. if portal.GuildID == "" {
  24. limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.DM
  25. }
  26. if limit == 0 {
  27. return
  28. }
  29. log := portal.log.With().
  30. Str("action", "initial backfill").
  31. Str("room_id", portal.MXID.String()).
  32. Int("limit", limit).
  33. Logger()
  34. portal.backfillLimited(log, source, limit, "")
  35. }
  36. func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channel) {
  37. if portal.MXID == "" {
  38. return
  39. }
  40. limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel
  41. if portal.GuildID == "" {
  42. limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM
  43. }
  44. if limit == 0 {
  45. return
  46. }
  47. log := portal.log.With().
  48. Str("action", "missed event backfill").
  49. Str("room_id", portal.MXID.String()).
  50. Int("limit", limit).
  51. Logger()
  52. portal.forwardBackfillLock.Lock()
  53. defer portal.forwardBackfillLock.Unlock()
  54. lastMessage := portal.bridge.DB.Message.GetLast(portal.Key)
  55. if lastMessage == nil || meta.LastMessageID == "" {
  56. log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata")
  57. return
  58. } else if !shouldBackfill(lastMessage.DiscordID, meta.LastMessageID) {
  59. log.Debug().
  60. Str("last_bridged_message", lastMessage.DiscordID).
  61. Str("last_server_message", meta.LastMessageID).
  62. Msg("Not backfilling, last message in database is newer than last message in metadata")
  63. return
  64. }
  65. log.Debug().
  66. Str("last_bridged_message", lastMessage.DiscordID).
  67. Str("last_server_message", meta.LastMessageID).
  68. Msg("Backfilling missed messages")
  69. if limit < 0 {
  70. portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID)
  71. } else {
  72. portal.backfillLimited(log, source, limit, lastMessage.DiscordID)
  73. }
  74. }
  75. const messageFetchChunkSize = 50
  76. func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string) ([]*discordgo.Message, bool, error) {
  77. var messages []*discordgo.Message
  78. var before string
  79. var foundAll bool
  80. for {
  81. log.Debug().Str("before_id", before).Msg("Fetching messages for backfill")
  82. newMessages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, before, "", "")
  83. if err != nil {
  84. return nil, false, err
  85. }
  86. if until != "" {
  87. for i, msg := range newMessages {
  88. if compareMessageIDs(msg.ID, until) <= 0 {
  89. log.Debug().
  90. Str("message_id", msg.ID).
  91. Str("until_id", until).
  92. Msg("Found message that was already bridged")
  93. newMessages = newMessages[:i]
  94. foundAll = true
  95. break
  96. }
  97. }
  98. }
  99. messages = append(messages, newMessages...)
  100. log.Debug().Int("count", len(newMessages)).Msg("Added messages to backfill collection")
  101. if len(newMessages) < messageFetchChunkSize || len(messages) >= limit {
  102. break
  103. }
  104. before = newMessages[len(newMessages)-1].ID
  105. }
  106. if len(messages) > limit {
  107. foundAll = false
  108. messages = messages[:limit]
  109. }
  110. return messages, foundAll, nil
  111. }
  112. func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string) {
  113. messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after)
  114. if err != nil {
  115. log.Err(err).Msg("Error collecting messages to forward backfill")
  116. return
  117. }
  118. log.Info().
  119. Int("count", len(messages)).
  120. Bool("found_all", foundAll).
  121. Msg("Collected messages to backfill")
  122. sort.Sort(MessageSlice(messages))
  123. if !foundAll {
  124. _, err = portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
  125. MsgType: event.MsgNotice,
  126. Body: "Some messages may have been missed here while the bridge was offline.",
  127. }, nil, 0)
  128. if err != nil {
  129. log.Warn().Err(err).Msg("Failed to send missed message warning")
  130. } else {
  131. log.Debug().Msg("Sent warning about possibly missed messages")
  132. }
  133. }
  134. portal.sendBackfillBatch(log, source, messages)
  135. }
  136. func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string) {
  137. for {
  138. log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill")
  139. messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, "", after, "")
  140. if err != nil {
  141. log.Err(err).Msg("Error fetching chunk of messages to forward backfill")
  142. return
  143. }
  144. log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill")
  145. sort.Sort(MessageSlice(messages))
  146. portal.sendBackfillBatch(log, source, messages)
  147. if len(messages) < messageFetchChunkSize {
  148. // Assume that was all the missing messages
  149. log.Debug().Msg("Chunk had less than 50 messages, stopping backfill")
  150. return
  151. }
  152. after = messages[len(messages)-1].ID
  153. }
  154. }
  155. func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) {
  156. if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
  157. log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint")
  158. portal.forwardBatchSend(log, source, messages)
  159. } else {
  160. log.Debug().Msg("Not using hungryserv, sending messages one by one")
  161. for _, msg := range messages {
  162. portal.handleDiscordMessageCreate(source, msg, nil)
  163. }
  164. }
  165. }
  166. func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message) {
  167. evts, dbMessages := portal.convertMessageBatch(log, source, messages)
  168. if len(evts) == 0 {
  169. log.Warn().Msg("Didn't get any events to backfill")
  170. return
  171. }
  172. log.Info().Int("events", len(evts)).Msg("Converted messages to backfill")
  173. resp, err := portal.MainIntent().BatchSend(portal.MXID, &mautrix.ReqBatchSend{
  174. BeeperNewMessages: true,
  175. Events: evts,
  176. })
  177. if err != nil {
  178. log.Err(err).Msg("Error sending backfill batch")
  179. return
  180. }
  181. for i, evtID := range resp.EventIDs {
  182. dbMessages[i].MXID = evtID
  183. }
  184. portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages)
  185. log.Info().Msg("Inserted backfilled batch to database")
  186. }
  187. func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) ([]*event.Event, []database.Message) {
  188. evts := make([]*event.Event, 0, len(messages))
  189. dbMessages := make([]database.Message, 0, len(messages))
  190. ctx := context.Background()
  191. for _, msg := range messages {
  192. for _, mention := range msg.Mentions {
  193. puppet := portal.bridge.GetPuppetByID(mention.ID)
  194. puppet.UpdateInfo(nil, mention)
  195. }
  196. puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
  197. puppet.UpdateInfo(source, msg.Author)
  198. intent := puppet.IntentFor(portal)
  199. replyTo := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true)
  200. ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
  201. log := log.With().
  202. Str("message_id", msg.ID).
  203. Int("message_type", int(msg.Type)).
  204. Str("author_id", msg.Author.ID).
  205. Logger()
  206. parts := portal.convertDiscordMessage(log.WithContext(ctx), intent, msg)
  207. for i, part := range parts {
  208. if replyTo != nil {
  209. part.Content.RelatesTo = &event.RelatesTo{InReplyTo: replyTo}
  210. // Only set reply for first event
  211. replyTo = nil
  212. }
  213. partName := part.AttachmentID
  214. // Always use blank part name for first part so that replies and other things
  215. // can reference it without knowing about attachments.
  216. if i == 0 {
  217. partName = ""
  218. }
  219. evt := &event.Event{
  220. ID: portal.deterministicEventID(msg.ID, partName),
  221. Type: part.Type,
  222. Sender: intent.UserID,
  223. Timestamp: ts.UnixMilli(),
  224. Content: event.Content{
  225. Parsed: part.Content,
  226. Raw: part.Extra,
  227. },
  228. }
  229. var err error
  230. evt.Type, err = portal.encrypt(intent, &evt.Content, evt.Type)
  231. if err != nil {
  232. log.Err(err).Msg("Failed to encrypt event")
  233. continue
  234. }
  235. intent.AddDoublePuppetValue(&evt.Content)
  236. evts = append(evts, evt)
  237. dbMessages = append(dbMessages, database.Message{
  238. Channel: portal.Key,
  239. DiscordID: msg.ID,
  240. SenderID: msg.Author.ID,
  241. Timestamp: ts,
  242. AttachmentID: part.AttachmentID,
  243. })
  244. }
  245. }
  246. return evts, dbMessages
  247. }
  248. func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID {
  249. data := fmt.Sprintf("%s/discord/%s/%s", portal.MXID, messageID, partName)
  250. sum := sha256.Sum256([]byte(data))
  251. return id.EventID(fmt.Sprintf("$%s:discord.com", base64.RawURLEncoding.EncodeToString(sum[:])))
  252. }
  253. // compareMessageIDs compares two Discord message IDs.
  254. //
  255. // If the first ID is lower, -1 is returned.
  256. // If the second ID is lower, 1 is returned.
  257. // If the IDs are equal, 0 is returned.
  258. func compareMessageIDs(id1, id2 string) int {
  259. if id1 == id2 {
  260. return 0
  261. }
  262. if len(id1) < len(id2) {
  263. return -1
  264. } else if len(id2) < len(id1) {
  265. return 1
  266. }
  267. if id1 < id2 {
  268. return -1
  269. }
  270. return 1
  271. }
  272. func shouldBackfill(latestBridgedIDStr, latestIDFromServerStr string) bool {
  273. return compareMessageIDs(latestBridgedIDStr, latestIDFromServerStr) == -1
  274. }
  275. type MessageSlice []*discordgo.Message
  276. var _ sort.Interface = (MessageSlice)(nil)
  277. func (a MessageSlice) Len() int {
  278. return len(a)
  279. }
  280. func (a MessageSlice) Swap(i, j int) {
  281. a[i], a[j] = a[j], a[i]
  282. }
  283. func (a MessageSlice) Less(i, j int) bool {
  284. return compareMessageIDs(a[i].ID, a[j].ID) == -1
  285. }