Procházet zdrojové kódy

historysync: delete history sync messages once backfilled

Sumner Evans před 3 roky
rodič
revize
830c294b91

+ 30 - 4
database/historysync.go

@@ -20,6 +20,8 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
+	"strconv"
+	"strings"
 	"time"
 
 	waProto "go.mau.fi/whatsmeow/binary/proto"
@@ -203,7 +205,7 @@ func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) error {
 
 const (
 	getMessagesBetween = `
-		SELECT data
+		SELECT id, data
 		  FROM history_sync_message
 		 WHERE user_mxid=$1
 		   AND conversation_id=$2
@@ -211,18 +213,28 @@ const (
 		 ORDER BY timestamp DESC
 		 %s
 	`
+	deleteMessages = `
+		DELETE FROM history_sync_message
+		 WHERE id IN (%s)
+	`
 )
 
 type HistorySyncMessage struct {
 	db  *Database
 	log log.Logger
 
+	ID             int
 	UserID         id.UserID
 	ConversationID string
 	Timestamp      time.Time
 	Data           []byte
 }
 
+type WrappedWebMessageInfo struct {
+	ID      int
+	Message *waProto.WebMessageInfo
+}
+
 func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversationID string, message *waProto.HistorySyncMsg) (*HistorySyncMessage, error) {
 	msgData, err := proto.Marshal(message)
 	if err != nil {
@@ -248,7 +260,7 @@ func (hsm *HistorySyncMessage) Insert() {
 	}
 }
 
-func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) {
+func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*WrappedWebMessageInfo) {
 	whereClauses := ""
 	args := []interface{}{userID, conversationID}
 	argNum := 3
@@ -272,9 +284,10 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID
 	if err != nil || rows == nil {
 		return nil
 	}
+	var msgID int
 	var msgData []byte
 	for rows.Next() {
-		err := rows.Scan(&msgData)
+		err := rows.Scan(&msgID, &msgData)
 		if err != nil {
 			hsq.log.Error("Database scan failed: %v", err)
 			continue
@@ -285,11 +298,24 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID
 			hsq.log.Errorf("Failed to unmarshal history sync message: %v", err)
 			continue
 		}
-		messages = append(messages, historySyncMsg.Message)
+		messages = append(messages, &WrappedWebMessageInfo{
+			ID:      msgID,
+			Message: historySyncMsg.Message,
+		})
 	}
 	return
 }
 
+func (hsq *HistorySyncQuery) DeleteMessages(messages []*WrappedWebMessageInfo) error {
+	messageIDs := make([]string, len(messages))
+	for i, msg := range messages {
+		messageIDs[i] = strconv.Itoa(msg.ID)
+	}
+
+	_, err := hsq.db.Exec(fmt.Sprintf(deleteMessages, strings.Join(messageIDs, ",")))
+	return err
+}
+
 func (hsq *HistorySyncQuery) DeleteAllMessages(userID id.UserID) error {
 	_, err := hsq.db.Exec("DELETE FROM history_sync_message WHERE user_mxid=$1", userID)
 	return err

+ 1 - 0
database/message.go

@@ -143,6 +143,7 @@ type Message struct {
 	db  *Database
 	log log.Logger
 
+	ID        int
 	Chat      PortalKey
 	JID       types.MessageID
 	MXID      id.EventID

+ 2 - 0
database/upgrades/2022-03-18-historysync-store.go

@@ -34,6 +34,7 @@ func init() {
 			}
 			_, err = tx.Exec(`
 				CREATE TABLE history_sync_message (
+					id                       SERIAL PRIMARY KEY,
 					user_mxid                TEXT,
 					conversation_id          TEXT,
 					timestamp                TIMESTAMP,
@@ -74,6 +75,7 @@ func init() {
 			}
 			_, err = tx.Exec(`
 				CREATE TABLE history_sync_message (
+					id                       INTEGER PRIMARY KEY,
 					user_mxid                TEXT,
 					conversation_id          TEXT,
 					timestamp                DATETIME,

+ 10 - 5
historysync.go

@@ -126,7 +126,7 @@ func (user *User) handleBackfillRequestsLoop(backfillRequests chan *database.Bac
 					break
 				}
 
-				var msgs []*waProto.WebMessageInfo
+				var msgs []*database.WrappedWebMessageInfo
 				if len(toBackfill) <= req.MaxBatchEvents {
 					msgs = toBackfill
 					toBackfill = toBackfill[0:0]
@@ -144,9 +144,14 @@ func (user *User) handleBackfillRequestsLoop(backfillRequests chan *database.Bac
 			user.log.Debugf("Finished backfilling %d messages in %s", len(allMsgs), portal.Key.JID)
 			if len(insertionEventIds) > 0 {
 				portal.sendPostBackfillDummy(
-					time.Unix(int64(allMsgs[len(allMsgs)-1].GetMessageTimestamp()), 0),
+					time.Unix(int64(allMsgs[len(allMsgs)-1].Message.GetMessageTimestamp()), 0),
 					insertionEventIds[0])
 			}
+			user.log.Debugf("Deleting %d history sync messages after backfilling", len(allMsgs))
+			err := user.bridge.DB.HistorySyncQuery.DeleteMessages(allMsgs)
+			if err != nil {
+				user.log.Warnf("Failed to delete %d history sync messages after backfilling: %v", len(allMsgs), err)
+			}
 		} else {
 			user.log.Debugfln("Not backfilling %s: no bridgeable messages found", portal.Key.JID)
 		}
@@ -288,14 +293,14 @@ var (
 	MSC2716Marker         = event.Type{Type: "org.matrix.msc2716.marker", Class: event.MessageEventType}
 )
 
-func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo) []id.EventID {
+func (portal *Portal) backfill(source *User, messages []*database.WrappedWebMessageInfo) []id.EventID {
 	portal.backfillLock.Lock()
 	defer portal.backfillLock.Unlock()
 
 	var historyBatch, newBatch mautrix.ReqBatchSend
 	var historyBatchInfos, newBatchInfos []*wrappedInfo
 
-	firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].GetMessageTimestamp()), 0)
+	firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].Message.GetMessageTimestamp()), 0)
 
 	historyBatch.StateEventsAtStart = make([]*event.Event, 0)
 	newBatch.StateEventsAtStart = make([]*event.Event, 0)
@@ -350,7 +355,7 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo)
 	portal.log.Infofln("Processing history sync with %d messages", len(messages))
 	// The messages are ordered newest to oldest, so iterate them in reverse order.
 	for i := len(messages) - 1; i >= 0; i-- {
-		webMsg := messages[i]
+		webMsg := messages[i].Message
 		msgType := getMessageType(webMsg.GetMessage())
 		if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" {
 			if msgType != "ignore" {