فهرست منبع

historysync: use userID, conversationID, messageID as PK

Sumner Evans 3 سال پیش
والد
کامیت
eb0a13a753
3فایلهای تغییر یافته به همراه32 افزوده شده و 37 حذف شده
  1. 19 24
      database/historysync.go
  2. 6 6
      database/upgrades/2022-03-18-historysync-store.go
  3. 7 7
      historysync.go

+ 19 - 24
database/historysync.go

@@ -20,7 +20,6 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
-	"strconv"
 	"strings"
 	"time"
 
@@ -205,7 +204,7 @@ func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) error {
 
 const (
 	getMessagesBetween = `
-		SELECT id, data
+		SELECT data
 		  FROM history_sync_message
 		 WHERE user_mxid=$1
 		   AND conversation_id=$2
@@ -215,7 +214,7 @@ const (
 	`
 	deleteMessages = `
 		DELETE FROM history_sync_message
-		 WHERE id IN (%s)
+		 WHERE %s
 	`
 )
 
@@ -223,19 +222,14 @@ type HistorySyncMessage struct {
 	db  *Database
 	log log.Logger
 
-	ID             int
 	UserID         id.UserID
 	ConversationID string
+	MessageID      string
 	Timestamp      time.Time
 	Data           []byte
 }
 
-type WrappedWebMessageInfo struct {
-	ID      int
-	Message *waProto.WebMessageInfo
-}
-
-func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversationID string, message *waProto.HistorySyncMsg) (*HistorySyncMessage, error) {
+func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversationID, messageID string, message *waProto.HistorySyncMsg) (*HistorySyncMessage, error) {
 	msgData, err := proto.Marshal(message)
 	if err != nil {
 		return nil, err
@@ -245,6 +239,7 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
 		log:            hsq.log,
 		UserID:         userID,
 		ConversationID: conversationID,
+		MessageID:      messageID,
 		Timestamp:      time.Unix(int64(message.Message.GetMessageTimestamp()), 0),
 		Data:           msgData,
 	}, nil
@@ -252,15 +247,16 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
 
 func (hsm *HistorySyncMessage) Insert() {
 	_, err := hsm.db.Exec(`
-		INSERT INTO history_sync_message (user_mxid, conversation_id, timestamp, data)
-		VALUES ($1, $2, $3, $4)
-	`, hsm.UserID, hsm.ConversationID, hsm.Timestamp, hsm.Data)
+		INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data)
+		VALUES ($1, $2, $3, $4, $5)
+		ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
+	`, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data)
 	if err != nil {
 		hsm.log.Warnfln("Failed to insert history sync message %s/%s: %v", hsm.ConversationID, hsm.Timestamp, err)
 	}
 }
 
-func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*WrappedWebMessageInfo) {
+func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) {
 	whereClauses := ""
 	args := []interface{}{userID, conversationID}
 	argNum := 3
@@ -284,10 +280,10 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID
 	if err != nil || rows == nil {
 		return nil
 	}
-	var msgID int
+
 	var msgData []byte
 	for rows.Next() {
-		err := rows.Scan(&msgID, &msgData)
+		err := rows.Scan(&msgData)
 		if err != nil {
 			hsq.log.Error("Database scan failed: %v", err)
 			continue
@@ -298,21 +294,20 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID
 			hsq.log.Errorf("Failed to unmarshal history sync message: %v", err)
 			continue
 		}
-		messages = append(messages, &WrappedWebMessageInfo{
-			ID:      msgID,
-			Message: historySyncMsg.Message,
-		})
+		messages = append(messages, historySyncMsg.Message)
 	}
 	return
 }
 
-func (hsq *HistorySyncQuery) DeleteMessages(messages []*WrappedWebMessageInfo) error {
-	messageIDs := make([]string, len(messages))
+func (hsq *HistorySyncQuery) DeleteMessages(userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error {
+	whereClauses := []string{}
+	preparedStatementArgs := []interface{}{userID, conversationID}
 	for i, msg := range messages {
-		messageIDs[i] = strconv.Itoa(msg.ID)
+		whereClauses = append(whereClauses, fmt.Sprintf("(user_mxid=$1 AND conversation_id=$2 AND message_id=$%d)", i+3))
+		preparedStatementArgs = append(preparedStatementArgs, msg.GetKey().GetId())
 	}
 
-	_, err := hsq.db.Exec(fmt.Sprintf(deleteMessages, strings.Join(messageIDs, ",")))
+	_, err := hsq.db.Exec(fmt.Sprintf(deleteMessages, strings.Join(whereClauses, " OR ")), preparedStatementArgs...)
 	return err
 }
 

+ 6 - 6
database/upgrades/2022-03-18-historysync-store.go

@@ -24,7 +24,6 @@ func init() {
 					unread_count                    INTEGER,
 
 					PRIMARY KEY (user_mxid, conversation_id),
-					UNIQUE (conversation_id),
 					FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
 					FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
 				)
@@ -34,14 +33,15 @@ func init() {
 			}
 			_, err = tx.Exec(`
 				CREATE TABLE history_sync_message (
-					id                       SERIAL PRIMARY KEY,
 					user_mxid                TEXT,
 					conversation_id          TEXT,
+					message_id               TEXT,
 					timestamp                TIMESTAMP,
 					data                     BYTEA,
 
+					PRIMARY KEY (user_mxid, conversation_id, message_id),
 					FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
-					FOREIGN KEY (conversation_id) REFERENCES history_sync_conversation(conversation_id) ON DELETE CASCADE
+					FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
 				)
 			`)
 			if err != nil {
@@ -65,7 +65,6 @@ func init() {
 					unread_count                    INTEGER,
 
 					PRIMARY KEY (user_mxid, conversation_id),
-					UNIQUE (conversation_id),
 					FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
 					FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
 				)
@@ -75,14 +74,15 @@ func init() {
 			}
 			_, err = tx.Exec(`
 				CREATE TABLE history_sync_message (
-					id                       INTEGER PRIMARY KEY,
 					user_mxid                TEXT,
 					conversation_id          TEXT,
+					message_id               TEXT,
 					timestamp                DATETIME,
 					data                     BLOB,
 
+					PRIMARY KEY (user_mxid, conversation_id, message_id),
 					FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
-					FOREIGN KEY (conversation_id) REFERENCES history_sync_conversation(conversation_id) ON DELETE CASCADE
+					FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
 				)
 			`)
 			if err != nil {

+ 7 - 7
historysync.go

@@ -134,7 +134,7 @@ func (user *User) createOrUpdatePortalAndBackfillWithLock(req *database.Backfill
 				break
 			}
 
-			var msgs []*database.WrappedWebMessageInfo
+			var msgs []*waProto.WebMessageInfo
 			if len(toBackfill) <= req.MaxBatchEvents {
 				msgs = toBackfill
 				toBackfill = toBackfill[0:0]
@@ -152,11 +152,11 @@ func (user *User) createOrUpdatePortalAndBackfillWithLock(req *database.Backfill
 		user.log.Debugf("Finished backfilling %d messages in %s", len(allMsgs), portal.Key.JID)
 		if len(insertionEventIds) > 0 {
 			portal.sendPostBackfillDummy(
-				time.Unix(int64(allMsgs[len(allMsgs)-1].Message.GetMessageTimestamp()), 0),
+				time.Unix(int64(allMsgs[len(allMsgs)-1].GetMessageTimestamp()), 0),
 				insertionEventIds[0])
 		}
 		user.log.Debugf("Deleting %d history sync messages after backfilling", len(allMsgs))
-		err := user.bridge.DB.HistorySyncQuery.DeleteMessages(allMsgs)
+		err := user.bridge.DB.HistorySyncQuery.DeleteMessages(user.MXID, conv.ConversationID, allMsgs)
 		if err != nil {
 			user.log.Warnf("Failed to delete %d history sync messages after backfilling: %v", len(allMsgs), err)
 		}
@@ -227,7 +227,7 @@ func (user *User) handleHistorySync(reCheckQueue chan bool, evt *waProto.History
 				continue
 			}
 
-			message, err := user.bridge.DB.HistorySyncQuery.NewMessageWithValues(user.MXID, conv.GetId(), msg)
+			message, err := user.bridge.DB.HistorySyncQuery.NewMessageWithValues(user.MXID, conv.GetId(), msg.Message.GetKey().GetId(), msg)
 			if err != nil {
 				user.log.Warnf("Failed to save message %s in %s. Error: %+v", msg.Message.Key.Id, conv.GetId(), err)
 				continue
@@ -306,11 +306,11 @@ var (
 	MSC2716Marker         = event.Type{Type: "org.matrix.msc2716.marker", Class: event.MessageEventType}
 )
 
-func (portal *Portal) backfill(source *User, messages []*database.WrappedWebMessageInfo) []id.EventID {
+func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo) []id.EventID {
 	var historyBatch, newBatch mautrix.ReqBatchSend
 	var historyBatchInfos, newBatchInfos []*wrappedInfo
 
-	firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].Message.GetMessageTimestamp()), 0)
+	firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].GetMessageTimestamp()), 0)
 
 	historyBatch.StateEventsAtStart = make([]*event.Event, 0)
 	newBatch.StateEventsAtStart = make([]*event.Event, 0)
@@ -365,7 +365,7 @@ func (portal *Portal) backfill(source *User, messages []*database.WrappedWebMess
 	portal.log.Infofln("Processing history sync with %d messages", len(messages))
 	// The messages are ordered newest to oldest, so iterate them in reverse order.
 	for i := len(messages) - 1; i >= 0; i-- {
-		webMsg := messages[i].Message
+		webMsg := messages[i]
 		msgType := getMessageType(webMsg.GetMessage())
 		if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" {
 			if msgType != "ignore" {