Przeglądaj źródła

Create threads for backfilled messages

Tulir Asokan 2 lat temu
rodzic
commit
b77eea4586
3 zmienionych plików z 40 dodań i 6 usunięć
  1. 21 3
      backfill.go
  2. 13 3
      portal.go
  3. 6 0
      portal_convert.go

+ 21 - 3
backfill.go

@@ -208,7 +208,7 @@ func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messag
 }
 
 func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) {
-	evts, dbMessages := portal.convertMessageBatch(log, source, messages, thread)
+	evts, metas, dbMessages := portal.convertMessageBatch(log, source, messages, thread)
 	if len(evts) == 0 {
 		log.Warn().Msg("Didn't get any events to backfill")
 		return
@@ -224,12 +224,24 @@ func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, message
 	}
 	for i, evtID := range resp.EventIDs {
 		dbMessages[i].MXID = evtID
+		if metas[i] != nil && metas[i].Flags == discordgo.MessageFlagsHasThread {
+			thread = portal.bridge.GetThreadByID(metas[i].ID, &dbMessages[i])
+			log.Debug().
+				Str("message_id", metas[i].ID).
+				Str("event_id", evtID.String()).
+				Msg("Marked backfilled message as thread root")
+			if thread.CreationNoticeMXID == "" {
+				// TODO proper context
+				ctx := log.WithContext(context.Background())
+				portal.sendThreadCreationNotice(ctx, thread)
+			}
+		}
 	}
 	portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages)
 	log.Info().Msg("Inserted backfilled batch to database")
 }
 
-func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) ([]*event.Event, []database.Message) {
+func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) ([]*event.Event, []*discordgo.Message, []database.Message) {
 	var discordThreadID string
 	var threadRootEvent, lastThreadEvent id.EventID
 	if thread != nil {
@@ -244,6 +256,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
 
 	evts := make([]*event.Event, 0, len(messages))
 	dbMessages := make([]database.Message, 0, len(messages))
+	metas := make([]*discordgo.Message, 0, len(messages))
 	ctx := context.Background()
 	for _, msg := range messages {
 		for _, mention := range msg.Mentions {
@@ -313,10 +326,15 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
 				AttachmentID: part.AttachmentID,
 				SenderMXID:   intent.UserID,
 			})
+			if i == 0 {
+				metas = append(metas, msg)
+			} else {
+				metas = append(metas, nil)
+			}
 			lastThreadEvent = evt.ID
 		}
 	}
-	return evts, dbMessages
+	return evts, metas, dbMessages
 }
 
 func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID {

+ 13 - 3
portal.go

@@ -586,7 +586,7 @@ func (portal *Portal) ensureUserInvited(user *User, ignoreCache bool) bool {
 	return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat(), ignoreCache)
 }
 
-func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, senderMXID id.UserID, parts []database.MessagePart) {
+func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, senderMXID id.UserID, parts []database.MessagePart) *database.Message {
 	msg := portal.bridge.DB.Message.New()
 	msg.Channel = portal.Key
 	msg.DiscordID = discordID
@@ -595,6 +595,9 @@ func (portal *Portal) markMessageHandled(discordID string, authorID string, time
 	msg.ThreadID = threadID
 	msg.SenderMXID = senderMXID
 	msg.MassInsertParts(parts)
+	msg.MXID = parts[0].MXID
+	msg.AttachmentID = parts[0].AttachmentID
+	return msg
 }
 
 func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message, thread *Thread) {
@@ -678,7 +681,14 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
 	} else if len(dbParts) == 0 {
 		log.Warn().Msg("All parts of message failed to send to Matrix")
 	} else {
-		portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, intent.UserID, dbParts)
+		firstDBMessage := portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, intent.UserID, dbParts)
+		if msg.Flags == discordgo.MessageFlagsHasThread {
+			thread = portal.bridge.GetThreadByID(msg.ID, firstDBMessage)
+			log.Debug().Msg("Marked message as thread root")
+			if thread.CreationNoticeMXID == "" {
+				portal.sendThreadCreationNotice(ctx, thread)
+			}
+		}
 	}
 }
 
@@ -1463,7 +1473,7 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
 		}
 		return
 	} else if threadRoot := content.GetRelatesTo().GetThreadParent(); threadRoot != "" {
-		existingThread := portal.bridge.DB.Thread.GetByMatrixRootMsg(threadRoot)
+		existingThread := portal.bridge.GetThreadByRootMXID(threadRoot)
 		if existingThread != nil {
 			threadID = existingThread.ID
 		} else {

+ 6 - 0
portal_convert.go

@@ -308,6 +308,12 @@ func (portal *Portal) convertDiscordMessage(ctx context.Context, puppet *Puppet,
 			parts = append(parts, part)
 		}
 	}
+	if len(parts) == 0 && msg.Thread != nil {
+		parts = append(parts, &ConvertedMessage{Type: event.EventMessage, Content: &event.MessageEventContent{
+			MsgType: event.MsgText,
+			Body:    fmt.Sprintf("Created a thread: %s", msg.Thread.Name),
+		}})
+	}
 	for _, part := range parts {
 		puppet.addWebhookMeta(part, msg)
 		puppet.addMemberMeta(part, msg)