backfill.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. package main
  2. import (
  3. "context"
  4. "crypto/sha256"
  5. "encoding/base64"
  6. "fmt"
  7. "sort"
  8. "github.com/bwmarrin/discordgo"
  9. "github.com/rs/zerolog"
  10. "maunium.net/go/mautrix"
  11. "maunium.net/go/mautrix/event"
  12. "maunium.net/go/mautrix/id"
  13. "go.mau.fi/mautrix-discord/database"
  14. )
  15. func (portal *Portal) forwardBackfillInitial(source *User, thread *Thread) {
  16. log := portal.log
  17. defer func() {
  18. log.Debug().Msg("Forward backfill finished, unlocking lock")
  19. portal.forwardBackfillLock.Unlock()
  20. }()
  21. // This should only be called from CreateMatrixRoom which locks forwardBackfillLock before creating the room.
  22. if portal.forwardBackfillLock.TryLock() {
  23. panic("forwardBackfillInitial() called without locking forwardBackfillLock")
  24. }
  25. limit := portal.bridge.Config.Bridge.Backfill.Limits.Initial.Channel
  26. if portal.GuildID == "" {
  27. limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.DM
  28. if thread != nil {
  29. limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.Thread
  30. thread.initialBackfillAttempted = true
  31. }
  32. }
  33. if limit == 0 {
  34. return
  35. }
  36. with := log.With().
  37. Str("action", "initial backfill").
  38. Str("room_id", portal.MXID.String()).
  39. Int("limit", limit)
  40. if thread != nil {
  41. with = with.Str("thread_id", thread.ID)
  42. }
  43. log = with.Logger()
  44. portal.backfillLimited(log, source, limit, "", thread)
  45. }
  46. func (portal *Portal) ForwardBackfillMissed(source *User, serverLastMessageID string, thread *Thread) {
  47. if portal.MXID == "" {
  48. return
  49. }
  50. limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel
  51. if portal.GuildID == "" {
  52. limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM
  53. if thread != nil {
  54. limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.Thread
  55. }
  56. }
  57. if limit == 0 {
  58. return
  59. }
  60. with := portal.log.With().
  61. Str("action", "missed event backfill").
  62. Str("room_id", portal.MXID.String()).
  63. Int("limit", limit)
  64. if thread != nil {
  65. with = with.Str("thread_id", thread.ID)
  66. }
  67. log := with.Logger()
  68. portal.forwardBackfillLock.Lock()
  69. defer portal.forwardBackfillLock.Unlock()
  70. var lastMessage *database.Message
  71. if thread != nil {
  72. lastMessage = portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID)
  73. } else {
  74. lastMessage = portal.bridge.DB.Message.GetLast(portal.Key)
  75. }
  76. if lastMessage == nil || serverLastMessageID == "" {
  77. log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata")
  78. return
  79. } else if !shouldBackfill(lastMessage.DiscordID, serverLastMessageID) {
  80. log.Debug().
  81. Str("last_bridged_message", lastMessage.DiscordID).
  82. Str("last_server_message", serverLastMessageID).
  83. Msg("Not backfilling, last message in database is newer than last message in metadata")
  84. return
  85. }
  86. log.Debug().
  87. Str("last_bridged_message", lastMessage.DiscordID).
  88. Str("last_server_message", serverLastMessageID).
  89. Msg("Backfilling missed messages")
  90. if limit < 0 {
  91. portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID, thread)
  92. } else {
  93. portal.backfillLimited(log, source, limit, lastMessage.DiscordID, thread)
  94. }
  95. }
  96. const messageFetchChunkSize = 50
  97. func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string, thread *Thread) ([]*discordgo.Message, bool, error) {
  98. var messages []*discordgo.Message
  99. var before string
  100. var foundAll bool
  101. protoChannelID := portal.Key.ChannelID
  102. if thread != nil {
  103. protoChannelID = thread.ID
  104. }
  105. for {
  106. log.Debug().Str("before_id", before).Msg("Fetching messages for backfill")
  107. newMessages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, before, "", "")
  108. if err != nil {
  109. return nil, false, err
  110. }
  111. if until != "" {
  112. for i, msg := range newMessages {
  113. if compareMessageIDs(msg.ID, until) <= 0 {
  114. log.Debug().
  115. Str("message_id", msg.ID).
  116. Str("until_id", until).
  117. Msg("Found message that was already bridged")
  118. newMessages = newMessages[:i]
  119. foundAll = true
  120. break
  121. }
  122. }
  123. }
  124. messages = append(messages, newMessages...)
  125. log.Debug().Int("count", len(newMessages)).Msg("Added messages to backfill collection")
  126. if len(newMessages) < messageFetchChunkSize || len(messages) >= limit {
  127. break
  128. }
  129. before = newMessages[len(newMessages)-1].ID
  130. }
  131. if len(messages) > limit {
  132. foundAll = false
  133. messages = messages[:limit]
  134. }
  135. return messages, foundAll, nil
  136. }
  137. func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string, thread *Thread) {
  138. messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after, thread)
  139. if err != nil {
  140. log.Err(err).Msg("Error collecting messages to forward backfill")
  141. return
  142. }
  143. log.Info().
  144. Int("count", len(messages)).
  145. Bool("found_all", foundAll).
  146. Msg("Collected messages to backfill")
  147. sort.Sort(MessageSlice(messages))
  148. if !foundAll && after != "" {
  149. _, err = portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
  150. MsgType: event.MsgNotice,
  151. Body: "Some messages may have been missed here while the bridge was offline.",
  152. }, nil, 0)
  153. if err != nil {
  154. log.Warn().Err(err).Msg("Failed to send missed message warning")
  155. } else {
  156. log.Debug().Msg("Sent warning about possibly missed messages")
  157. }
  158. }
  159. portal.sendBackfillBatch(log, source, messages, thread)
  160. }
  161. func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string, thread *Thread) {
  162. protoChannelID := portal.Key.ChannelID
  163. if thread != nil {
  164. protoChannelID = thread.ID
  165. }
  166. for {
  167. log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill")
  168. messages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, "", after, "")
  169. if err != nil {
  170. log.Err(err).Msg("Error fetching chunk of messages to forward backfill")
  171. return
  172. }
  173. log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill")
  174. sort.Sort(MessageSlice(messages))
  175. portal.sendBackfillBatch(log, source, messages, thread)
  176. if len(messages) < messageFetchChunkSize {
  177. // Assume that was all the missing messages
  178. log.Debug().Msg("Chunk had less than 50 messages, stopping backfill")
  179. return
  180. }
  181. after = messages[len(messages)-1].ID
  182. }
  183. }
  184. func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) {
  185. if portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) {
  186. log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint")
  187. portal.forwardBatchSend(log, source, messages, thread)
  188. } else {
  189. log.Debug().Msg("Not using hungryserv, sending messages one by one")
  190. for _, msg := range messages {
  191. portal.handleDiscordMessageCreate(source, msg, thread)
  192. }
  193. }
  194. }
  195. func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) {
  196. evts, metas, dbMessages := portal.convertMessageBatch(log, source, messages, thread)
  197. if len(evts) == 0 {
  198. log.Warn().Msg("Didn't get any events to backfill")
  199. return
  200. }
  201. log.Info().Int("events", len(evts)).Msg("Converted messages to backfill")
  202. resp, err := portal.MainIntent().BeeperBatchSend(portal.MXID, &mautrix.ReqBeeperBatchSend{
  203. Forward: true,
  204. Events: evts,
  205. })
  206. if err != nil {
  207. log.Err(err).Msg("Error sending backfill batch")
  208. return
  209. }
  210. for i, evtID := range resp.EventIDs {
  211. dbMessages[i].MXID = evtID
  212. if metas[i] != nil && metas[i].Flags == discordgo.MessageFlagsHasThread {
  213. // TODO proper context
  214. ctx := log.WithContext(context.Background())
  215. portal.bridge.threadFound(ctx, source, &dbMessages[i], metas[i].ID, metas[i].Thread)
  216. }
  217. }
  218. portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages)
  219. }
  220. func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) ([]*event.Event, []*discordgo.Message, []database.Message) {
  221. var discordThreadID string
  222. var threadRootEvent, lastThreadEvent id.EventID
  223. if thread != nil {
  224. discordThreadID = thread.ID
  225. threadRootEvent = thread.RootMXID
  226. lastThreadEvent = threadRootEvent
  227. lastInThread := portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID)
  228. if lastInThread != nil {
  229. lastThreadEvent = lastInThread.MXID
  230. }
  231. }
  232. evts := make([]*event.Event, 0, len(messages))
  233. dbMessages := make([]database.Message, 0, len(messages))
  234. metas := make([]*discordgo.Message, 0, len(messages))
  235. ctx := context.Background()
  236. for _, msg := range messages {
  237. for _, mention := range msg.Mentions {
  238. puppet := portal.bridge.GetPuppetByID(mention.ID)
  239. puppet.UpdateInfo(nil, mention, nil)
  240. }
  241. puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
  242. puppet.UpdateInfo(source, msg.Author, msg)
  243. intent := puppet.IntentFor(portal)
  244. replyTo := portal.getReplyTarget(source, discordThreadID, msg.MessageReference, msg.Embeds, true)
  245. mentions := portal.convertDiscordMentions(msg, false)
  246. ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
  247. log := log.With().
  248. Str("message_id", msg.ID).
  249. Int("message_type", int(msg.Type)).
  250. Str("author_id", msg.Author.ID).
  251. Logger()
  252. parts := portal.convertDiscordMessage(log.WithContext(ctx), puppet, intent, msg)
  253. for i, part := range parts {
  254. if (replyTo != nil || threadRootEvent != "") && part.Content.RelatesTo == nil {
  255. part.Content.RelatesTo = &event.RelatesTo{}
  256. }
  257. if threadRootEvent != "" {
  258. part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent)
  259. }
  260. if replyTo != nil {
  261. part.Content.RelatesTo.SetReplyTo(replyTo.EventID)
  262. // Only set reply for first event
  263. replyTo = nil
  264. }
  265. part.Content.Mentions = mentions
  266. // Only set mentions for first event, but keep empty object for rest
  267. mentions = &event.Mentions{}
  268. partName := part.AttachmentID
  269. // Always use blank part name for first part so that replies and other things
  270. // can reference it without knowing about attachments.
  271. if i == 0 {
  272. partName = ""
  273. }
  274. evt := &event.Event{
  275. ID: portal.deterministicEventID(msg.ID, partName),
  276. Type: part.Type,
  277. Sender: intent.UserID,
  278. Timestamp: ts.UnixMilli(),
  279. Content: event.Content{
  280. Parsed: part.Content,
  281. Raw: part.Extra,
  282. },
  283. }
  284. var err error
  285. evt.Type, err = portal.encrypt(intent, &evt.Content, evt.Type)
  286. if err != nil {
  287. log.Err(err).Msg("Failed to encrypt event")
  288. continue
  289. }
  290. intent.AddDoublePuppetValue(&evt.Content)
  291. evts = append(evts, evt)
  292. dbMessages = append(dbMessages, database.Message{
  293. Channel: portal.Key,
  294. DiscordID: msg.ID,
  295. SenderID: msg.Author.ID,
  296. Timestamp: ts,
  297. AttachmentID: part.AttachmentID,
  298. SenderMXID: intent.UserID,
  299. })
  300. if i == 0 {
  301. metas = append(metas, msg)
  302. } else {
  303. metas = append(metas, nil)
  304. }
  305. lastThreadEvent = evt.ID
  306. }
  307. }
  308. return evts, metas, dbMessages
  309. }
  310. func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID {
  311. data := fmt.Sprintf("%s/discord/%s/%s", portal.MXID, messageID, partName)
  312. sum := sha256.Sum256([]byte(data))
  313. return id.EventID(fmt.Sprintf("$%s:discord.com", base64.RawURLEncoding.EncodeToString(sum[:])))
  314. }
  315. // compareMessageIDs compares two Discord message IDs.
  316. //
  317. // If the first ID is lower, -1 is returned.
  318. // If the second ID is lower, 1 is returned.
  319. // If the IDs are equal, 0 is returned.
  320. func compareMessageIDs(id1, id2 string) int {
  321. if id1 == id2 {
  322. return 0
  323. }
  324. if len(id1) < len(id2) {
  325. return -1
  326. } else if len(id2) < len(id1) {
  327. return 1
  328. }
  329. if id1 < id2 {
  330. return -1
  331. }
  332. return 1
  333. }
  334. func shouldBackfill(latestBridgedIDStr, latestIDFromServerStr string) bool {
  335. return compareMessageIDs(latestBridgedIDStr, latestIDFromServerStr) == -1
  336. }
  337. type MessageSlice []*discordgo.Message
  338. var _ sort.Interface = (MessageSlice)(nil)
  339. func (a MessageSlice) Len() int {
  340. return len(a)
  341. }
  342. func (a MessageSlice) Swap(i, j int) {
  343. a[i], a[j] = a[j], a[i]
  344. }
  345. func (a MessageSlice) Less(i, j int) bool {
  346. return compareMessageIDs(a[i].ID, a[j].ID) == -1
  347. }