浏览代码

Add support for intentional mentions

Tulir Asokan 2 年之前
父节点
当前提交
434f27c8b4
共有 9 个文件被更改,包括 91 次插入33 次删除
  1. 8 1
      backfill.go
  2. 1 0
      config/upgrade.go
  3. 18 15
      database/message.go
  4. 3 2
      database/upgrades/00-latest-revision.sql
  5. 2 0
      database/upgrades/20-message-sender-mxid.sql
  6. 2 0
      example-config.yaml
  7. 9 0
      formatter.go
  8. 21 15
      portal.go
  9. 27 0
      portal_convert.go

+ 8 - 1
backfill.go

@@ -217,7 +217,8 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
 		puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
 		puppet.UpdateInfo(source, msg.Author)
 		intent := puppet.IntentFor(portal)
-		replyTo := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true)
+		replyTo, replySenderMXID := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true)
+		mentions := portal.convertDiscordMentions(msg, replySenderMXID, false)
 
 		ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
 		log := log.With().
@@ -232,6 +233,11 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
 				// Only set reply for first event
 				replyTo = nil
 			}
+
+			part.Content.Mentions = mentions
+			// Only set mentions for first event, but keep empty object for rest
+			mentions = &event.Mentions{}
+
 			partName := part.AttachmentID
 			// Always use blank part name for first part so that replies and other things
 			// can reference it without knowing about attachments.
@@ -262,6 +268,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
 				SenderID:     msg.Author.ID,
 				Timestamp:    ts,
 				AttachmentID: part.AttachmentID,
+				SenderMXID:   intent.UserID,
 			})
 		}
 	}

+ 1 - 0
config/upgrade.go

@@ -85,6 +85,7 @@ func DoUpgrade(helper *up.Helper) {
 	helper.Copy(up.Bool, "bridge", "encryption", "require")
 	helper.Copy(up.Bool, "bridge", "encryption", "appservice")
 	helper.Copy(up.Bool, "bridge", "encryption", "allow_key_sharing")
+	helper.Copy(up.Bool, "bridge", "encryption", "plaintext_mentions")
 	helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "delete_outbound_on_ack")
 	helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "dont_store_outbound")
 	helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "ratchet_on_decrypt")

+ 18 - 15
database/message.go

@@ -19,7 +19,7 @@ type MessageQuery struct {
 }
 
 const (
-	messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid FROM message"
+	messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid FROM message"
 )
 
 func (mq *MessageQuery) New() *Message {
@@ -99,11 +99,11 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
 	if len(msgs) == 0 {
 		return
 	}
-	valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d)"
+	valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d, $%d)"
 	if mq.db.Dialect == dbutil.SQLite {
 		valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
 	}
-	params := make([]interface{}, 2+len(msgs)*7)
+	params := make([]interface{}, 2+len(msgs)*8)
 	placeholders := make([]string, len(msgs))
 	params[0] = key.ChannelID
 	params[1] = key.Receiver
@@ -116,7 +116,8 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
 		params[baseIndex+4] = msg.editTimestampVal()
 		params[baseIndex+5] = msg.ThreadID
 		params[baseIndex+6] = msg.MXID
-		placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7)
+		params[baseIndex+7] = msg.SenderMXID.String()
+		placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8)
 	}
 	_, err := mq.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
 	if err != nil {
@@ -137,7 +138,8 @@ type Message struct {
 	EditTimestamp time.Time
 	ThreadID      string
 
-	MXID id.EventID
+	MXID       id.EventID
+	SenderMXID id.UserID
 }
 
 func (m *Message) DiscordProtoChannelID() string {
@@ -151,7 +153,7 @@ func (m *Message) DiscordProtoChannelID() string {
 func (m *Message) Scan(row dbutil.Scannable) *Message {
 	var ts, editTS int64
 
-	err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID)
+	err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID, &m.SenderMXID)
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 			m.log.Errorln("Database scan failed:", err)
@@ -173,12 +175,12 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
 
 const messageInsertQuery = `
 	INSERT INTO message (
-		dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid
+		dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid
 	)
-	VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+	VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
 `
 
-var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9)", "%s", 1)
+var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", "%s", 1)
 
 type MessagePart struct {
 	AttachmentID string
@@ -196,11 +198,11 @@ func (m *Message) MassInsertParts(msgs []MessagePart) {
 	if len(msgs) == 0 {
 		return
 	}
-	valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d)"
+	valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d, $8)"
 	if m.db.Dialect == dbutil.SQLite {
 		valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
 	}
-	params := make([]interface{}, 7+len(msgs)*2)
+	params := make([]interface{}, 8+len(msgs)*2)
 	placeholders := make([]string, len(msgs))
 	params[0] = m.DiscordID
 	params[1] = m.Channel.ChannelID
@@ -209,10 +211,11 @@ func (m *Message) MassInsertParts(msgs []MessagePart) {
 	params[4] = m.Timestamp.UnixMilli()
 	params[5] = m.editTimestampVal()
 	params[6] = m.ThreadID
+	params[7] = m.SenderMXID.String()
 	for i, msg := range msgs {
-		params[7+i*2] = msg.AttachmentID
-		params[7+i*2+1] = msg.MXID
-		placeholders[i] = fmt.Sprintf(valueStringFormat, 7+i*2+1, 7+i*2+2)
+		params[8+i*2] = msg.AttachmentID
+		params[8+i*2+1] = msg.MXID
+		placeholders[i] = fmt.Sprintf(valueStringFormat, 8+i*2+1, 8+i*2+2)
 	}
 	_, err := m.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
 	if err != nil {
@@ -224,7 +227,7 @@ func (m *Message) MassInsertParts(msgs []MessagePart) {
 func (m *Message) Insert() {
 	_, err := m.db.Exec(messageInsertQuery,
 		m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
-		m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID)
+		m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID, m.SenderMXID.String())
 
 	if err != nil {
 		m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)

+ 3 - 2
database/upgrades/00-latest-revision.sql

@@ -1,4 +1,4 @@
--- v0 -> v19: Latest revision
+-- v0 -> v20 (compatible with v19+): Latest revision
 
 CREATE TABLE guild (
     dcid       TEXT PRIMARY KEY,
@@ -113,7 +113,8 @@ CREATE TABLE message (
     dc_edit_timestamp BIGINT NOT NULL,
     dc_thread_id      TEXT   NOT NULL,
 
-    mxid TEXT NOT NULL UNIQUE,
+    mxid        TEXT NOT NULL UNIQUE,
+    sender_mxid TEXT NOT NULL DEFAULT '',
 
     PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver),
     CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE

+ 2 - 0
database/upgrades/20-message-sender-mxid.sql

@@ -0,0 +1,2 @@
+-- v20 (compatible with v19+): Store message sender Matrix user ID
+ALTER TABLE message ADD COLUMN sender_mxid TEXT NOT NULL DEFAULT '';

+ 2 - 0
example-config.yaml

@@ -247,6 +247,8 @@ bridge:
         # Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled.
         # You must use a client that supports requesting keys from other users to use this feature.
         allow_key_sharing: false
+        # Should users mentions be in the event wire content to enable the server to send push notifications?
+        plaintext_mentions: false
         # Options for deleting megolm sessions from the bridge.
         delete_keys:
             # Beeper-specific: delete outbound sessions when hungryserv confirms

+ 9 - 0
formatter.go

@@ -26,6 +26,7 @@ import (
 	"github.com/yuin/goldmark/extension"
 	"github.com/yuin/goldmark/parser"
 	"github.com/yuin/goldmark/util"
+	"golang.org/x/exp/slices"
 
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/format"
@@ -93,6 +94,7 @@ func (portal *Portal) renderDiscordMarkdownOnlyHTML(text string, allowInlineLink
 
 const formatterContextPortalKey = "fi.mau.discord.portal"
 const formatterContextAllowedMentionsKey = "fi.mau.discord.allowed_mentions"
+const formatterContextInputAllowedMentionsKey = "fi.mau.discord.input_allowed_mentions"
 
 func appendIfNotContains(arr []string, newItem string) []string {
 	for _, item := range arr {
@@ -135,6 +137,10 @@ func (br *DiscordBridge) pillConverter(displayname, mxid, eventID string, ctx fo
 			}
 		}
 	} else if mxid[0] == '@' {
+		allowedMentions, _ := ctx.ReturnData[formatterContextInputAllowedMentionsKey].([]id.UserID)
+		if allowedMentions != nil && !slices.Contains(allowedMentions, id.UserID(mxid)) {
+			return displayname
+		}
 		mentions := ctx.ReturnData[formatterContextAllowedMentionsKey].(*discordgo.MessageAllowedMentions)
 		parsedID, ok := br.ParsePuppetMXID(id.UserID(mxid))
 		if ok {
@@ -219,6 +225,9 @@ func (portal *Portal) parseMatrixHTML(content *event.MessageEventContent) (strin
 		ctx := format.NewContext()
 		ctx.ReturnData[formatterContextPortalKey] = portal
 		ctx.ReturnData[formatterContextAllowedMentionsKey] = allowedMentions
+		if content.Mentions != nil {
+			ctx.ReturnData[formatterContextInputAllowedMentionsKey] = content.Mentions.UserIDs
+		}
 		return variationselector.FullyQualify(matrixHTMLParser.Parse(content.FormattedBody, ctx)), allowedMentions
 	} else {
 		return variationselector.FullyQualify(escapeDiscordMarkdown(content.Body)), allowedMentions

+ 21 - 15
portal.go

@@ -584,13 +584,14 @@ func (portal *Portal) ensureUserInvited(user *User, ignoreCache bool) bool {
 	return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat(), ignoreCache)
 }
 
-func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, parts []database.MessagePart) {
+func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, senderMXID id.UserID, parts []database.MessagePart) {
 	msg := portal.bridge.DB.Message.New()
 	msg.Channel = portal.Key
 	msg.DiscordID = discordID
 	msg.SenderID = authorID
 	msg.Timestamp = timestamp
 	msg.ThreadID = threadID
+	msg.SenderMXID = senderMXID
 	msg.MassInsertParts(parts)
 }
 
@@ -618,11 +619,6 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
 	}
 	log.Debug().Msg("Starting handling of Discord message")
 
-	for _, mention := range msg.Mentions {
-		puppet := portal.bridge.GetPuppetByID(mention.ID)
-		puppet.UpdateInfo(nil, mention)
-	}
-
 	puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
 	puppet.UpdateInfo(user, msg.Author)
 	intent := puppet.IntentFor(portal)
@@ -638,7 +634,8 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
 			lastThreadEvent = lastInThread.MXID
 		}
 	}
-	replyTo := portal.getReplyTarget(user, discordThreadID, msg.MessageReference, msg.Embeds, false)
+	replyTo, replySenderMXID := portal.getReplyTarget(user, discordThreadID, msg.MessageReference, msg.Embeds, false)
+	mentions := portal.convertDiscordMentions(msg, replySenderMXID, true)
 
 	ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
 	parts := portal.convertDiscordMessage(ctx, intent, msg)
@@ -658,6 +655,11 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
 			// Only set reply for first event
 			replyTo = nil
 		}
+
+		part.Content.Mentions = mentions
+		// Only set mentions for first event, but keep empty object for rest
+		mentions = &event.Mentions{}
+
 		resp, err := portal.sendMatrixMessage(intent, part.Type, part.Content, part.Extra, ts.UnixMilli())
 		if err != nil {
 			log.Err(err).
@@ -674,7 +676,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
 	} else if len(dbParts) == 0 {
 		log.Warn().Msg("All parts of message failed to send to Matrix")
 	} else {
-		portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, dbParts)
+		portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, intent.UserID, dbParts)
 	}
 }
 
@@ -684,7 +686,7 @@ func isReplyEmbed(embed *discordgo.MessageEmbed) bool {
 	return hackyReplyPattern.MatchString(embed.Description)
 }
 
-func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discordgo.MessageReference, embeds []*discordgo.MessageEmbed, allowNonExistent bool) *event.InReplyTo {
+func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discordgo.MessageReference, embeds []*discordgo.MessageEmbed, allowNonExistent bool) (*event.InReplyTo, id.UserID) {
 	if ref == nil && len(embeds) > 0 {
 		match := hackyReplyPattern.FindStringSubmatch(embeds[0].Description)
 		if match != nil && match[1] == portal.GuildID && (match[2] == portal.Key.ChannelID || match[2] == threadID) {
@@ -696,7 +698,7 @@ func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discord
 		}
 	}
 	if ref == nil {
-		return nil
+		return nil, ""
 	}
 	isHungry := portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry
 	if !isHungry {
@@ -709,25 +711,25 @@ func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discord
 	if ref.ChannelID != portal.Key.ChannelID && ref.ChannelID != threadID && crossRoomReplies {
 		targetPortal = portal.bridge.GetExistingPortalByID(database.PortalKey{ChannelID: ref.ChannelID, Receiver: source.DiscordID})
 		if targetPortal == nil {
-			return nil
+			return nil, ""
 		}
 	}
 	replyToMsg := portal.bridge.DB.Message.GetByDiscordID(targetPortal.Key, ref.MessageID)
 	if len(replyToMsg) > 0 {
 		if !crossRoomReplies {
-			return &event.InReplyTo{EventID: replyToMsg[0].MXID}
+			return &event.InReplyTo{EventID: replyToMsg[0].MXID}, replyToMsg[0].SenderMXID
 		}
 		return &event.InReplyTo{
 			EventID:        replyToMsg[0].MXID,
 			UnstableRoomID: targetPortal.MXID,
-		}
+		}, replyToMsg[0].SenderMXID
 	} else if allowNonExistent {
 		return &event.InReplyTo{
 			EventID:        targetPortal.deterministicEventID(ref.MessageID, ""),
 			UnstableRoomID: targetPortal.MXID,
-		}
+		}, ""
 	}
-	return nil
+	return nil, ""
 }
 
 const JoinThreadReaction = "join thread"
@@ -895,7 +897,10 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
 			Msg("Dropping non-text edit")
 		return
 	}
+	converted.Content.Mentions = portal.convertDiscordMentions(msg, "", false)
 	converted.Content.SetEdit(existing[0].MXID)
+	// Never actually mention new users of edits, only include mentions inside m.new_content
+	converted.Content.Mentions = &event.Mentions{}
 	if converted.Extra != nil {
 		converted.Extra = map[string]any{
 			"m.new_content": converted.Extra,
@@ -1585,6 +1590,7 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
 		} else {
 			dbMsg.SenderID = portal.RelayWebhookID
 		}
+		dbMsg.SenderMXID = sender.MXID
 		dbMsg.Timestamp, _ = discordgo.SnowflakeTimestamp(msg.ID)
 		dbMsg.ThreadID = threadID
 		dbMsg.Insert()

+ 27 - 0
portal_convert.go

@@ -26,6 +26,8 @@ import (
 
 	"github.com/bwmarrin/discordgo"
 	"github.com/rs/zerolog"
+	"golang.org/x/exp/slices"
+	"maunium.net/go/mautrix/id"
 
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix/appservice"
@@ -518,6 +520,31 @@ func isPlainGifMessage(msg *discordgo.Message) bool {
 	return len(msg.Embeds) == 1 && msg.Embeds[0].Video != nil && msg.Embeds[0].URL == msg.Content && msg.Embeds[0].Type == discordgo.EmbedTypeGifv
 }
 
+func (portal *Portal) convertDiscordMentions(msg *discordgo.Message, replySender id.UserID, syncGhosts bool) *event.Mentions {
+	var matrixMentions event.Mentions
+	for _, mention := range msg.Mentions {
+		puppet := portal.bridge.GetPuppetByID(mention.ID)
+		if syncGhosts {
+			puppet.UpdateInfo(nil, mention)
+		}
+		user := portal.bridge.GetUserByID(mention.ID)
+		if user != nil {
+			matrixMentions.UserIDs = append(matrixMentions.UserIDs, user.MXID)
+		} else {
+			matrixMentions.UserIDs = append(matrixMentions.UserIDs, puppet.MXID)
+		}
+	}
+	if replySender != "" {
+		matrixMentions.UserIDs = append(matrixMentions.UserIDs, replySender)
+	}
+	slices.Sort(matrixMentions.UserIDs)
+	matrixMentions.UserIDs = slices.Compact(matrixMentions.UserIDs)
+	if msg.MentionEveryone {
+		matrixMentions.Room = true
+	}
+	return &matrixMentions
+}
+
 func (portal *Portal) convertDiscordTextMessage(ctx context.Context, intent *appservice.IntentAPI, msg *discordgo.Message) *ConvertedMessage {
 	log := zerolog.Ctx(ctx)
 	if msg.Type == discordgo.MessageTypeCall {