浏览代码

Store edit timestamp in database to deduplicate edits. Fixes #86

Tulir Asokan 2 年之前
父节点
当前提交
4324b60a2c

+ 52 - 28
database/message.go

@@ -19,7 +19,7 @@ type MessageQuery struct {
 }
 
 const (
-	messageSelect = "SELECT dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message"
+	messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid FROM message"
 )
 
 func (mq *MessageQuery) New() *Message {
@@ -46,17 +46,17 @@ func (mq *MessageQuery) scanAll(rows dbutil.Rows, err error) []*Message {
 }
 
 func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) []*Message {
-	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC"
+	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC"
 	return mq.scanAll(mq.db.Query(query, key.ChannelID, key.Receiver, discordID))
 }
 
 func (mq *MessageQuery) GetFirstByDiscordID(key PortalKey, discordID string) *Message {
-	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC LIMIT 1"
+	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC LIMIT 1"
 	return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
 }
 
 func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Message {
-	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id DESC LIMIT 1"
+	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id DESC LIMIT 1"
 	return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
 }
 
@@ -66,12 +66,12 @@ func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time
 }
 
 func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
-	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND dc_edit_index=0 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
+	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
 	return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
 }
 
 func (mq *MessageQuery) GetLast(key PortalKey) *Message {
-	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_edit_index=0 ORDER BY timestamp DESC LIMIT 1"
+	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 ORDER BY timestamp DESC LIMIT 1"
 	return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver))
 }
 
@@ -99,7 +99,7 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
 	if len(msgs) == 0 {
 		return
 	}
-	valueStringFormat := "($%d, $%d, $%d, $1, $2, $%d, $%d, $%d, $%d)"
+	valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d)"
 	if mq.db.Dialect == dbutil.SQLite {
 		valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
 	}
@@ -111,9 +111,9 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
 		baseIndex := 2 + i*7
 		params[baseIndex] = msg.DiscordID
 		params[baseIndex+1] = msg.AttachmentID
-		params[baseIndex+2] = msg.EditIndex
-		params[baseIndex+3] = msg.SenderID
-		params[baseIndex+4] = msg.Timestamp.UnixMilli()
+		params[baseIndex+2] = msg.SenderID
+		params[baseIndex+3] = msg.Timestamp.UnixMilli()
+		params[baseIndex+4] = msg.editTimestampVal()
 		params[baseIndex+5] = msg.ThreadID
 		params[baseIndex+6] = msg.MXID
 		placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7)
@@ -129,13 +129,13 @@ type Message struct {
 	db  *Database
 	log log.Logger
 
-	DiscordID    string
-	AttachmentID string
-	EditIndex    int
-	Channel      PortalKey
-	SenderID     string
-	Timestamp    time.Time
-	ThreadID     string
+	DiscordID     string
+	AttachmentID  string
+	Channel       PortalKey
+	SenderID      string
+	Timestamp     time.Time
+	EditTimestamp time.Time
+	ThreadID      string
 
 	MXID id.EventID
 }
@@ -149,9 +149,9 @@ func (m *Message) DiscordProtoChannelID() string {
 }
 
 func (m *Message) Scan(row dbutil.Scannable) *Message {
-	var ts int64
+	var ts, editTS int64
 
-	err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.ThreadID, &m.MXID)
+	err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID)
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 			m.log.Errorln("Database scan failed:", err)
@@ -162,7 +162,10 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
 	}
 
 	if ts != 0 {
-		m.Timestamp = time.UnixMilli(ts)
+		m.Timestamp = time.UnixMilli(ts).UTC()
+	}
+	if editTS != 0 {
+		m.EditTimestamp = time.Unix(0, editTS).UTC()
 	}
 
 	return m
@@ -170,7 +173,7 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
 
 const messageInsertQuery = `
 	INSERT INTO message (
-		dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid
+		dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid
 	)
 	VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
 `
@@ -182,6 +185,13 @@ type MessagePart struct {
 	MXID         id.EventID
 }
 
+func (m *Message) editTimestampVal() int64 {
+	if m.EditTimestamp.IsZero() {
+		return 0
+	}
+	return m.EditTimestamp.UnixNano()
+}
+
 func (m *Message) MassInsertParts(msgs []MessagePart) {
 	if len(msgs) == 0 {
 		return
@@ -193,11 +203,11 @@ func (m *Message) MassInsertParts(msgs []MessagePart) {
 	params := make([]interface{}, 7+len(msgs)*2)
 	placeholders := make([]string, len(msgs))
 	params[0] = m.DiscordID
-	params[1] = m.EditIndex
-	params[2] = m.Channel.ChannelID
-	params[3] = m.Channel.Receiver
-	params[4] = m.SenderID
-	params[5] = m.Timestamp.UnixMilli()
+	params[1] = m.Channel.ChannelID
+	params[2] = m.Channel.Receiver
+	params[3] = m.SenderID
+	params[4] = m.Timestamp.UnixMilli()
+	params[5] = m.editTimestampVal()
 	params[6] = m.ThreadID
 	for i, msg := range msgs {
 		params[7+i*2] = msg.AttachmentID
@@ -213,8 +223,8 @@ func (m *Message) MassInsertParts(msgs []MessagePart) {
 
 func (m *Message) Insert() {
 	_, err := m.db.Exec(messageInsertQuery,
-		m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
-		m.Timestamp.UnixMilli(), m.ThreadID, m.MXID)
+		m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
+		m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID)
 
 	if err != nil {
 		m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
@@ -222,6 +232,20 @@ func (m *Message) Insert() {
 	}
 }
 
+const editUpdateQuery = `
+	UPDATE message
+	SET dc_edit_timestamp=$1
+	WHERE dcid=$2 AND dc_attachment_id=$3 AND dc_chan_id=$4 AND dc_chan_receiver=$5 AND dc_edit_timestamp<$1
+`
+
+func (m *Message) UpdateEditTimestamp(ts time.Time) {
+	_, err := m.db.Exec(editUpdateQuery, ts.UnixNano(), m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver)
+	if err != nil {
+		m.log.Warnfln("Failed to update edit timestamp of %s@%s: %v", m.DiscordID, m.Channel, err)
+		panic(err)
+	}
+}
+
 func (m *Message) Delete() {
 	query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3 AND dc_attachment_id=$4"
 	_, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.AttachmentID)

+ 12 - 13
database/upgrades/00-latest-revision.sql

@@ -1,4 +1,4 @@
--- v0 -> v18 (compatible with v15+): Latest revision
+-- v0 -> v19: Latest revision
 
 CREATE TABLE guild (
     dcid       TEXT PRIMARY KEY,
@@ -104,18 +104,18 @@ CREATE TABLE user_portal (
 );
 
 CREATE TABLE message (
-    dcid             TEXT,
-    dc_attachment_id TEXT,
-    dc_edit_index    INTEGER,
-    dc_chan_id       TEXT,
-    dc_chan_receiver TEXT,
-    dc_sender        TEXT   NOT NULL,
-    timestamp        BIGINT NOT NULL,
-    dc_thread_id     TEXT   NOT NULL,
+    dcid              TEXT,
+    dc_attachment_id  TEXT,
+    dc_chan_id        TEXT,
+    dc_chan_receiver  TEXT,
+    dc_sender         TEXT   NOT NULL,
+    timestamp         BIGINT NOT NULL,
+    dc_edit_timestamp BIGINT NOT NULL,
+    dc_thread_id      TEXT   NOT NULL,
 
     mxid TEXT NOT NULL UNIQUE,
 
-    PRIMARY KEY (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver),
+    PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver),
     CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE
 );
 
@@ -127,13 +127,12 @@ CREATE TABLE reaction (
     dc_emoji_name    TEXT,
     dc_thread_id     TEXT NOT NULL,
 
-    dc_first_attachment_id TEXT    NOT NULL,
-    _dc_first_edit_index   INTEGER NOT NULL DEFAULT 0,
+    dc_first_attachment_id TEXT NOT NULL,
 
     mxid TEXT NOT NULL UNIQUE,
 
     PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name),
-    CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, _dc_first_edit_index, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE
+    CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE
 );
 
 CREATE TABLE role (

+ 15 - 0
database/upgrades/19-message-edit-ts.postgres.sql

@@ -0,0 +1,15 @@
+-- v19: Replace dc_edit_index with dc_edit_timestamp
+-- transaction: off
+BEGIN;
+
+ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey;
+ALTER TABLE message DROP CONSTRAINT message_pkey;
+ALTER TABLE message DROP COLUMN dc_edit_index;
+ALTER TABLE reaction DROP COLUMN _dc_first_edit_index;
+ALTER TABLE message ADD PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver);
+ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE;
+
+ALTER TABLE message ADD COLUMN dc_edit_timestamp BIGINT NOT NULL DEFAULT 0;
+ALTER TABLE message ALTER COLUMN dc_edit_timestamp DROP DEFAULT;
+
+COMMIT;

+ 48 - 0
database/upgrades/19-message-edit-ts.sqlite.sql

@@ -0,0 +1,48 @@
+-- v19: Replace dc_edit_index with dc_edit_timestamp
+-- transaction: off
+PRAGMA foreign_keys = OFF;
+BEGIN;
+
+CREATE TABLE message_new (
+    dcid              TEXT,
+    dc_attachment_id  TEXT,
+    dc_chan_id        TEXT,
+    dc_chan_receiver  TEXT,
+    dc_sender         TEXT   NOT NULL,
+    timestamp         BIGINT NOT NULL,
+    dc_edit_timestamp BIGINT NOT NULL,
+    dc_thread_id      TEXT   NOT NULL,
+
+    mxid TEXT NOT NULL UNIQUE,
+
+    PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver),
+    CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE
+);
+INSERT INTO message_new (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid)
+    SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, 0, dc_thread_id, mxid FROM message;
+DROP TABLE message;
+ALTER TABLE message_new RENAME TO message;
+
+CREATE TABLE reaction_new (
+    dc_chan_id       TEXT,
+    dc_chan_receiver TEXT,
+    dc_msg_id        TEXT,
+    dc_sender        TEXT,
+    dc_emoji_name    TEXT,
+    dc_thread_id     TEXT NOT NULL,
+
+    dc_first_attachment_id TEXT NOT NULL,
+
+    mxid TEXT NOT NULL UNIQUE,
+
+    PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name),
+    CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE
+);
+INSERT INTO reaction_new (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, dc_first_attachment_id, mxid)
+    SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, dc_first_attachment_id, mxid FROM reaction;
+DROP TABLE reaction;
+ALTER TABLE reaction_new RENAME TO reaction;
+
+PRAGMA foreign_key_check;
+COMMIT;
+PRAGMA foreign_keys = ON;

+ 18 - 7
portal.go

@@ -583,11 +583,10 @@ func (portal *Portal) ensureUserInvited(user *User, ignoreCache bool) bool {
 	return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat(), ignoreCache)
 }
 
-func (portal *Portal) markMessageHandled(discordID string, editIndex int, authorID string, timestamp time.Time, threadID string, parts []database.MessagePart) {
+func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, parts []database.MessagePart) {
 	msg := portal.bridge.DB.Message.New()
 	msg.Channel = portal.Key
 	msg.DiscordID = discordID
-	msg.EditIndex = editIndex
 	msg.SenderID = authorID
 	msg.Timestamp = timestamp
 	msg.ThreadID = threadID
@@ -674,7 +673,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
 	} else if len(dbParts) == 0 {
 		log.Warn().Msg("All parts of message failed to send to Matrix")
 	} else {
-		portal.markMessageHandled(msg.ID, 0, msg.Author.ID, ts, discordThreadID, dbParts)
+		portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, dbParts)
 	}
 }
 
@@ -778,6 +777,13 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
 		log.Warn().Msg("Dropping update of unknown message")
 		return
 	}
+	if msg.EditedTimestamp != nil && !msg.EditedTimestamp.After(existing[0].EditTimestamp) {
+		log.Debug().
+			Time("received_edit_ts", *msg.EditedTimestamp).
+			Time("db_edit_ts", existing[0].EditTimestamp).
+			Msg("Dropping update of message with older or equal edit timestamp")
+		return
+	}
 
 	if msg.Flags == discordgo.MessageFlagsHasThread {
 		thread := portal.bridge.GetThreadByID(msg.ID, existing[0])
@@ -885,8 +891,9 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
 
 	portal.sendDeliveryReceipt(resp.EventID)
 
-	//ts, _ := msg.Timestamp.Parse()
-	//portal.markMessageHandled(existing, msg.ID, resp.EventID, msg.Author.ID, ts)
+	if msg.EditedTimestamp != nil {
+		existing[0].UpdateEditTimestamp(*msg.EditedTimestamp)
+	}
 }
 
 func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) {
@@ -1386,16 +1393,20 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
 		if edits != nil {
 			discordContent, allowedMentions := portal.parseMatrixHTML(content.NewContent)
 			var err error
+			var msg *discordgo.Message
 			if !isWebhookSend {
 				// TODO save edit in message table
-				_, err = sess.ChannelMessageEdit(edits.DiscordProtoChannelID(), edits.DiscordID, discordContent)
+				msg, err = sess.ChannelMessageEdit(edits.DiscordProtoChannelID(), edits.DiscordID, discordContent)
 			} else {
-				_, err = relayClient.WebhookMessageEdit(portal.RelayWebhookID, portal.RelayWebhookSecret, edits.DiscordID, &discordgo.WebhookEdit{
+				msg, err = relayClient.WebhookMessageEdit(portal.RelayWebhookID, portal.RelayWebhookSecret, edits.DiscordID, &discordgo.WebhookEdit{
 					Content:         &discordContent,
 					AllowedMentions: allowedMentions,
 				})
 			}
 			go portal.sendMessageMetrics(evt, err, "Failed to edit")
+			if msg.EditedTimestamp != nil {
+				edits.UpdateEditTimestamp(*msg.EditedTimestamp)
+			}
 		} else {
 			go portal.sendMessageMetrics(evt, fmt.Errorf("%w %s", errUnknownEditTarget, editMXID), "Ignoring")
 		}