Sfoglia il codice sorgente

Store message sender mxid in database

Tulir Asokan 2 anni fa
parent
commit
559ac719a4

+ 19 - 18
database/message.go

@@ -44,27 +44,27 @@ func (mq *MessageQuery) New() *Message {
 
 
 const (
 const (
 	getAllMessagesQuery = `
 	getAllMessagesQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2
 		WHERE chat_jid=$1 AND chat_receiver=$2
 	`
 	`
 	getMessageByJIDQuery = `
 	getMessageByJIDQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
 	`
 	`
 	getMessageByMXIDQuery = `
 	getMessageByMXIDQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
 		WHERE mxid=$1
 		WHERE mxid=$1
 	`
 	`
 	getLastMessageInChatQuery = `
 	getLastMessageInChatQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
 	`
 	`
 	getFirstMessageInChatQuery = `
 	getFirstMessageInChatQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
 	`
 	`
 	getMessagesBetweenQuery = `
 	getMessagesBetweenQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC
 	`
 	`
 )
 )
@@ -146,14 +146,15 @@ type Message struct {
 	db  *Database
 	db  *Database
 	log log.Logger
 	log log.Logger
 
 
-	Chat      PortalKey
-	JID       types.MessageID
-	MXID      id.EventID
-	Sender    types.JID
-	Timestamp time.Time
-	Sent      bool
-	Type      MessageType
-	Error     MessageErrorType
+	Chat       PortalKey
+	JID        types.MessageID
+	MXID       id.EventID
+	Sender     types.JID
+	SenderMXID id.UserID
+	Timestamp  time.Time
+	Sent       bool
+	Type       MessageType
+	Error      MessageErrorType
 
 
 	BroadcastListJID types.JID
 	BroadcastListJID types.JID
 }
 }
@@ -168,7 +169,7 @@ func (msg *Message) IsFakeJID() bool {
 
 
 func (msg *Message) Scan(row dbutil.Scannable) *Message {
 func (msg *Message) Scan(row dbutil.Scannable) *Message {
 	var ts int64
 	var ts int64
-	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
+	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.SenderMXID, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
 	if err != nil {
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 		if !errors.Is(err, sql.ErrNoRows) {
 			msg.log.Errorln("Database scan failed:", err)
 			msg.log.Errorln("Database scan failed:", err)
@@ -192,9 +193,9 @@ func (msg *Message) Insert(txn dbutil.Execable) {
 	}
 	}
 	_, err := txn.Exec(`
 	_, err := txn.Exec(`
 		INSERT INTO message
 		INSERT INTO message
-			(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid)
-		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
-	`, msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
+			(chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid)
+		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
+	`, msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
 	if err != nil {
 	if err != nil {
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 	}
 	}

+ 2 - 1
database/upgrades/00-latest-revision.sql

@@ -1,4 +1,4 @@
--- v0 -> v55: Latest revision
+-- v0 -> v56 (compatible with v45+): Latest revision
 
 
 CREATE TABLE "user" (
 CREATE TABLE "user" (
     mxid     TEXT PRIMARY KEY,
     mxid     TEXT PRIMARY KEY,
@@ -70,6 +70,7 @@ CREATE TABLE message (
     jid           TEXT,
     jid           TEXT,
     mxid          TEXT UNIQUE,
     mxid          TEXT UNIQUE,
     sender        TEXT,
     sender        TEXT,
+    sender_mxid   TEXT NOT NULL DEFAULT '',
     timestamp     BIGINT,
     timestamp     BIGINT,
     sent          BOOLEAN,
     sent          BOOLEAN,
     error         error_type,
     error         error_type,

+ 2 - 0
database/upgrades/56-message-sender-mxid.sql

@@ -0,0 +1,2 @@
+-- v56 (compatible with v45+): Store whether custom contact info has been set for a puppet
+ALTER TABLE message ADD COLUMN sender_mxid TEXT NOT NULL DEFAULT '';

+ 7 - 1
historysync.go

@@ -45,6 +45,8 @@ type wrappedInfo struct {
 	Type  database.MessageType
 	Type  database.MessageType
 	Error database.MessageErrorType
 	Error database.MessageErrorType
 
 
+	SenderMXID id.UserID
+
 	ReactionTarget types.MessageID
 	ReactionTarget types.MessageID
 
 
 	MediaKey []byte
 	MediaKey []byte
@@ -268,6 +270,7 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
 		msg.MXID = resp.EventID
 		msg.MXID = resp.EventID
 		msg.JID = types.MessageID(resp.EventID)
 		msg.JID = types.MessageID(resp.EventID)
 		msg.Timestamp = conv.LastMessageTimestamp
 		msg.Timestamp = conv.LastMessageTimestamp
+		msg.SenderMXID = portal.MainIntent().UserID
 		msg.Sent = true
 		msg.Sent = true
 		msg.Type = database.MsgFake
 		msg.Type = database.MsgFake
 		msg.Insert(nil)
 		msg.Insert(nil)
@@ -749,6 +752,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
 	mainInfo := &wrappedInfo{
 	mainInfo := &wrappedInfo{
 		MessageInfo:     info,
 		MessageInfo:     info,
 		Type:            database.MsgNormal,
 		Type:            database.MsgNormal,
+		SenderMXID:      mainEvt.Sender,
 		Error:           converted.Error,
 		Error:           converted.Error,
 		MediaKey:        converted.MediaKey,
 		MediaKey:        converted.MediaKey,
 		ExpirationStart: expirationStart,
 		ExpirationStart: expirationStart,
@@ -783,6 +787,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
 				*eventsArray = append(*eventsArray, reactionEvent)
 				*eventsArray = append(*eventsArray, reactionEvent)
 				*infoArray = append(*infoArray, &wrappedInfo{
 				*infoArray = append(*infoArray, &wrappedInfo{
 					MessageInfo:    reactionInfo,
 					MessageInfo:    reactionInfo,
+					SenderMXID:     reactionEvent.Sender,
 					ReactionTarget: info.ID,
 					ReactionTarget: info.ID,
 					Type:           database.MsgReaction,
 					Type:           database.MsgReaction,
 				})
 				})
@@ -872,7 +877,7 @@ func (portal *Portal) finishBatch(txn dbutil.Transaction, eventIDs []id.EventID,
 		}
 		}
 
 
 		eventID := eventIDs[i]
 		eventID := eventIDs[i]
-		portal.markHandled(txn, nil, info.MessageInfo, eventID, true, false, info.Type, info.Error)
+		portal.markHandled(txn, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, info.Error)
 		if info.Type == database.MsgReaction {
 		if info.Type == database.MsgReaction {
 			portal.upsertReaction(txn, nil, info.ReactionTarget, info.Sender, eventID, info.ID)
 			portal.upsertReaction(txn, nil, info.ReactionTarget, info.Sender, eventID, info.ID)
 		}
 		}
@@ -896,6 +901,7 @@ func (portal *Portal) sendPostBackfillDummy(lastTimestamp time.Time, insertionEv
 	msg := portal.bridge.DB.Message.New()
 	msg := portal.bridge.DB.Message.New()
 	msg.Chat = portal.Key
 	msg.Chat = portal.Key
 	msg.MXID = resp.EventID
 	msg.MXID = resp.EventID
+	msg.SenderMXID = portal.MainIntent().UserID
 	msg.JID = types.MessageID(resp.EventID)
 	msg.JID = types.MessageID(resp.EventID)
 	msg.Timestamp = lastTimestamp.Add(1 * time.Second)
 	msg.Timestamp = lastTimestamp.Add(1 * time.Second)
 	msg.Sent = true
 	msg.Sent = true

+ 24 - 23
portal.go

@@ -670,7 +670,7 @@ func (portal *Portal) handleUndecryptableMessage(source *User, evt *events.Undec
 		portal.log.Errorfln("Failed to send decryption error of %s to Matrix: %v", evt.Info.ID, err)
 		portal.log.Errorfln("Failed to send decryption error of %s to Matrix: %v", evt.Info.ID, err)
 		return
 		return
 	}
 	}
-	portal.finishHandling(nil, &evt.Info, resp.EventID, database.MsgUnknown, database.MsgErrDecryptionFailed)
+	portal.finishHandling(nil, &evt.Info, resp.EventID, intent.UserID, database.MsgUnknown, database.MsgErrDecryptionFailed)
 }
 }
 
 
 func (portal *Portal) handleFakeMessage(msg fakeMessage) {
 func (portal *Portal) handleFakeMessage(msg fakeMessage) {
@@ -703,7 +703,7 @@ func (portal *Portal) handleFakeMessage(msg fakeMessage) {
 			MessageSource: types.MessageSource{
 			MessageSource: types.MessageSource{
 				Sender: msg.Sender,
 				Sender: msg.Sender,
 			},
 			},
-		}, resp.EventID, database.MsgFake, database.MsgNoError)
+		}, resp.EventID, intent.UserID, database.MsgFake, database.MsgNoError)
 	}
 	}
 }
 }
 
 
@@ -818,7 +818,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 			}
 			}
 		}
 		}
 		if len(eventID) != 0 {
 		if len(eventID) != 0 {
-			portal.finishHandling(existingMsg, &evt.Info, eventID, dbMsgType, converted.Error)
+			portal.finishHandling(existingMsg, &evt.Info, eventID, intent.UserID, dbMsgType, converted.Error)
 		}
 		}
 	} else if msgType == "reaction" || msgType == "encrypted reaction" {
 	} else if msgType == "reaction" || msgType == "encrypted reaction" {
 		if evt.Message.GetEncReactionMessage() != nil {
 		if evt.Message.GetEncReactionMessage() != nil {
@@ -863,7 +863,7 @@ func (portal *Portal) isRecentlyHandled(id types.MessageID, error database.Messa
 	return false
 	return false
 }
 }
 
 
-func (portal *Portal) markHandled(txn dbutil.Transaction, msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent bool, msgType database.MessageType, errType database.MessageErrorType) *database.Message {
+func (portal *Portal) markHandled(txn dbutil.Transaction, msg *database.Message, info *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, isSent, recent bool, msgType database.MessageType, errType database.MessageErrorType) *database.Message {
 	if msg == nil {
 	if msg == nil {
 		msg = portal.bridge.DB.Message.New()
 		msg = portal.bridge.DB.Message.New()
 		msg.Chat = portal.Key
 		msg.Chat = portal.Key
@@ -871,6 +871,7 @@ func (portal *Portal) markHandled(txn dbutil.Transaction, msg *database.Message,
 		msg.MXID = mxid
 		msg.MXID = mxid
 		msg.Timestamp = info.Timestamp
 		msg.Timestamp = info.Timestamp
 		msg.Sender = info.Sender
 		msg.Sender = info.Sender
+		msg.SenderMXID = senderMXID
 		msg.Sent = isSent
 		msg.Sent = isSent
 		msg.Type = msgType
 		msg.Type = msgType
 		msg.Error = errType
 		msg.Error = errType
@@ -922,8 +923,8 @@ func (portal *Portal) getMessageIntent(user *User, info *types.MessageInfo, msgT
 	return intent
 	return intent
 }
 }
 
 
-func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, msgType database.MessageType, errType database.MessageErrorType) {
-	portal.markHandled(nil, existing, message, mxid, true, true, msgType, errType)
+func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, msgType database.MessageType, errType database.MessageErrorType) {
+	portal.markHandled(nil, existing, message, mxid, senderMXID, true, true, msgType, errType)
 	portal.sendDeliveryReceipt(mxid)
 	portal.sendDeliveryReceipt(mxid)
 	var suffix string
 	var suffix string
 	if errType == database.MsgErrDecryptionFailed {
 	if errType == database.MsgErrDecryptionFailed {
@@ -1881,19 +1882,20 @@ func (portal *Portal) MainIntent() *appservice.IntentAPI {
 	return portal.bridge.Bot
 	return portal.bridge.Bot
 }
 }
 
 
-func (portal *Portal) addReplyMention(content *event.MessageEventContent, sender types.JID) {
-	if content.Mentions == nil {
+func (portal *Portal) addReplyMention(content *event.MessageEventContent, sender types.JID, senderMXID id.UserID) {
+	if content.Mentions == nil || (sender.IsEmpty() && senderMXID == "") {
 		return
 		return
 	}
 	}
-	var mxid id.UserID
-	if user := portal.bridge.GetUserByJID(sender); user != nil {
-		mxid = user.MXID
-	} else {
-		puppet := portal.bridge.GetPuppetByJID(sender)
-		mxid = puppet.MXID
+	if senderMXID == "" {
+		if user := portal.bridge.GetUserByJID(sender); user != nil {
+			senderMXID = user.MXID
+		} else {
+			puppet := portal.bridge.GetPuppetByJID(sender)
+			senderMXID = puppet.MXID
+		}
 	}
 	}
-	if slices.Contains(content.Mentions.UserIDs, mxid) {
-		content.Mentions.UserIDs = append(content.Mentions.UserIDs, mxid)
+	if senderMXID != "" && !slices.Contains(content.Mentions.UserIDs, senderMXID) {
+		content.Mentions.UserIDs = append(content.Mentions.UserIDs, senderMXID)
 	}
 	}
 }
 }
 
 
@@ -1925,13 +1927,12 @@ func (portal *Portal) SetReply(content *event.MessageEventContent, replyTo *Repl
 	if message == nil || message.IsFakeMXID() {
 	if message == nil || message.IsFakeMXID() {
 		if isBackfill && portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
 		if isBackfill && portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
 			content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(targetPortal.deterministicEventID(replyTo.Sender, replyTo.MessageID, ""))
 			content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(targetPortal.deterministicEventID(replyTo.Sender, replyTo.MessageID, ""))
-			portal.addReplyMention(content, replyTo.Sender)
+			portal.addReplyMention(content, replyTo.Sender, "")
 			return true
 			return true
 		}
 		}
 		return false
 		return false
 	}
 	}
-	// TODO store sender mxid in db message
-	portal.addReplyMention(content, message.Sender)
+	portal.addReplyMention(content, message.Sender, message.SenderMXID)
 	content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(message.MXID)
 	content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(message.MXID)
 	if portal.bridge.Config.Bridge.DisableReplyFallbacks {
 	if portal.bridge.Config.Bridge.DisableReplyFallbacks {
 		return true
 		return true
@@ -1973,7 +1974,7 @@ func (portal *Portal) HandleMessageReaction(intent *appservice.IntentAPI, user *
 		if err != nil {
 		if err != nil {
 			portal.log.Errorfln("Failed to redact reaction %s/%s from %s to %s: %v", existing.MXID, existing.JID, info.Sender, targetJID, err)
 			portal.log.Errorfln("Failed to redact reaction %s/%s from %s to %s: %v", existing.MXID, existing.JID, info.Sender, targetJID, err)
 		}
 		}
-		portal.finishHandling(existingMsg, info, resp.EventID, database.MsgReaction, database.MsgNoError)
+		portal.finishHandling(existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, database.MsgNoError)
 		existing.Delete()
 		existing.Delete()
 	} else {
 	} else {
 		target := portal.bridge.DB.Message.GetByJID(portal.Key, targetJID)
 		target := portal.bridge.DB.Message.GetByJID(portal.Key, targetJID)
@@ -1994,7 +1995,7 @@ func (portal *Portal) HandleMessageReaction(intent *appservice.IntentAPI, user *
 			return
 			return
 		}
 		}
 
 
-		portal.finishHandling(existingMsg, info, resp.EventID, database.MsgReaction, database.MsgNoError)
+		portal.finishHandling(existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, database.MsgNoError)
 		portal.upsertReaction(nil, intent, target.JID, info.Sender, resp.EventID, info.ID)
 		portal.upsertReaction(nil, intent, target.JID, info.Sender, resp.EventID, info.ID)
 	}
 	}
 }
 }
@@ -4134,7 +4135,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
 	}
 	}
 	info := portal.generateMessageInfo(sender)
 	info := portal.generateMessageInfo(sender)
 	if dbMsg == nil {
 	if dbMsg == nil {
-		dbMsg = portal.markHandled(nil, nil, info, evt.ID, false, true, dbMsgType, database.MsgNoError)
+		dbMsg = portal.markHandled(nil, nil, info, evt.ID, evt.Sender, false, true, dbMsgType, database.MsgNoError)
 	} else {
 	} else {
 		info.ID = dbMsg.JID
 		info.ID = dbMsg.JID
 	}
 	}
@@ -4189,7 +4190,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error
 		return fmt.Errorf("unknown target event %s", content.RelatesTo.EventID)
 		return fmt.Errorf("unknown target event %s", content.RelatesTo.EventID)
 	}
 	}
 	info := portal.generateMessageInfo(sender)
 	info := portal.generateMessageInfo(sender)
-	dbMsg := portal.markHandled(nil, nil, info, evt.ID, false, true, database.MsgReaction, database.MsgNoError)
+	dbMsg := portal.markHandled(nil, nil, info, evt.ID, evt.Sender, false, true, database.MsgReaction, database.MsgNoError)
 	portal.upsertReaction(nil, nil, target.JID, sender.JID, evt.ID, info.ID)
 	portal.upsertReaction(nil, nil, target.JID, sender.JID, evt.ID, info.ID)
 	portal.log.Debugln("Sending reaction", evt.ID, "to WhatsApp", info.ID)
 	portal.log.Debugln("Sending reaction", evt.ID, "to WhatsApp", info.ID)
 	resp, err := portal.sendReactionToWhatsApp(sender, info.ID, target, content.RelatesTo.Key, evt.Timestamp)
 	resp, err := portal.sendReactionToWhatsApp(sender, info.ID, target, content.RelatesTo.Key, evt.Timestamp)