Эх сурвалжийг харах

Reroute broadcast list messages to correct DM portal. Fixes #411

Tulir Asokan 3 жил өмнө
parent
commit
ca5fcc42ba

+ 12 - 11
database/message.go

@@ -43,27 +43,27 @@ func (mq *MessageQuery) New() *Message {
 
 const (
 	getAllMessagesQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2
 	`
 	getMessageByJIDQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
 	`
 	getMessageByMXIDQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
 		WHERE mxid=$1
 	`
 	getLastMessageInChatQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
 	`
 	getFirstMessageInChatQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
 	`
 	getMessagesBetweenQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true ORDER BY timestamp ASC
 	`
 )
@@ -133,7 +133,8 @@ type Message struct {
 	Timestamp time.Time
 	Sent      bool
 
-	DecryptionError bool
+	DecryptionError  bool
+	BroadcastListJID types.JID
 }
 
 func (msg *Message) IsFakeMXID() bool {
@@ -146,7 +147,7 @@ func (msg *Message) IsFakeJID() bool {
 
 func (msg *Message) Scan(row Scannable) *Message {
 	var ts int64
-	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError)
+	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError, &msg.BroadcastListJID)
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 			msg.log.Errorln("Database scan failed:", err)
@@ -166,9 +167,9 @@ func (msg *Message) Insert() {
 		sender = ""
 	}
 	_, err := msg.db.Exec(`INSERT INTO message
-			(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error)
-			VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
-		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError)
+			(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid)
+			VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
+		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError, msg.BroadcastListJID)
 	if err != nil {
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 	}

+ 12 - 0
database/upgrades/2021-12-25-broadcast-list-message-source.go

@@ -0,0 +1,12 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[32] = upgrade{"Store source broadcast list in message table", func(tx *sql.Tx, ctx context) error {
+		_, err := tx.Exec(`ALTER TABLE message ADD COLUMN broadcast_list_jid TEXT`)
+		return err
+	}}
+}

+ 1 - 1
database/upgrades/upgrades.go

@@ -39,7 +39,7 @@ type upgrade struct {
 	fn      upgradeFunc
 }
 
-const NumberOfUpgrades = 32
+const NumberOfUpgrades = 33
 
 var upgrades [NumberOfUpgrades]upgrade
 

+ 1 - 1
go.mod

@@ -8,7 +8,7 @@ require (
 	github.com/mattn/go-sqlite3 v1.14.9
 	github.com/prometheus/client_golang v1.11.0
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
-	go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058
+	go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164
 	golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d
 	google.golang.org/protobuf v1.27.1
 	gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b

+ 2 - 2
go.sum

@@ -139,8 +139,8 @@ github.com/tidwall/sjson v1.2.3 h1:5+deguEhHSEjmuICXZ21uSSsXotWMA0orU783+Z7Cp8=
 github.com/tidwall/sjson v1.2.3/go.mod h1:5WdjKx3AQMvCJ4RG6/2UYT7dLrGvJUV1x4jdTAyGvZs=
 go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910 h1:9FFhG0OmkuMau5UEaTgiUQ+7cSbtbOQ7hiWKdN8OI3I=
 go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910/go.mod h1:AufGrvVh+00Nc07Jm4hTquh7yleZyn20tKJI2wCPAKg=
-go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058 h1:5z1PUeFB4XaTtUzXM2n8nK6c+Uu+Mkzm5JliSTCsFL0=
-go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058/go.mod h1:8jUjOAi3xtGubxcZgG8uSHpAdyQXBRbWAfxkctX/4y4=
+go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164 h1:uA2QfpClxXnrRzkAy08UXJ5P7Wc/QiQFLKZSVAgXg5w=
+go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164/go.mod h1:8jUjOAi3xtGubxcZgG8uSHpAdyQXBRbWAfxkctX/4y4=
 golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

+ 31 - 11
portal.go

@@ -43,21 +43,20 @@ import (
 	"golang.org/x/image/webp"
 	"google.golang.org/protobuf/proto"
 
-	"maunium.net/go/mautrix/format"
-
-	"go.mau.fi/whatsmeow"
-	waProto "go.mau.fi/whatsmeow/binary/proto"
-	"go.mau.fi/whatsmeow/types"
-	"go.mau.fi/whatsmeow/types/events"
-
 	log "maunium.net/go/maulogger/v2"
 
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix/appservice"
 	"maunium.net/go/mautrix/crypto/attachment"
 	"maunium.net/go/mautrix/event"
+	"maunium.net/go/mautrix/format"
 	"maunium.net/go/mautrix/id"
 
+	"go.mau.fi/whatsmeow"
+	waProto "go.mau.fi/whatsmeow/binary/proto"
+	"go.mau.fi/whatsmeow/types"
+	"go.mau.fi/whatsmeow/types/events"
+
 	"maunium.net/go/mautrix-whatsapp/database"
 )
 
@@ -452,6 +451,12 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 	}
 	converted := portal.convertMessage(intent, source, &evt.Info, evt.Message)
 	if converted != nil {
+		if evt.Info.IsIncomingBroadcast() {
+			if converted.Extra == nil {
+				converted.Extra = map[string]interface{}{}
+			}
+			converted.Extra["fi.mau.whatsapp.source_broadcast_list"] = evt.Info.Chat.String()
+		}
 		var eventID id.EventID
 		if existingMsg != nil {
 			converted.Content.SetEdit(existingMsg.MXID)
@@ -522,6 +527,9 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo
 		msg.Sender = info.Sender
 		msg.Sent = isSent
 		msg.DecryptionError = decryptionError
+		if info.IsIncomingBroadcast() {
+			msg.BroadcastListJID = info.Chat
+		}
 		msg.Insert()
 	} else {
 		msg.UpdateMXID(mxid, decryptionError)
@@ -2288,13 +2296,25 @@ func (portal *Portal) HandleMatrixReadReceipt(sender *User, eventID id.EventID,
 	}
 	groupedMessages := make(map[types.JID][]types.MessageID)
 	for _, msg := range messages {
-		if !msg.IsFakeJID() {
-			groupedMessages[msg.Sender] = append(groupedMessages[msg.Sender], msg.JID)
-		}
+		var key types.JID
+		if msg.IsFakeJID() || msg.Sender.User == sender.JID.User {
+			// Don't send read receipts for own messages or fake messages
+			continue
+		} else if !portal.IsPrivateChat() {
+			key = msg.Sender
+		} else if !msg.BroadcastListJID.IsEmpty() {
+			key = msg.BroadcastListJID
+		} // else: blank key (participant field isn't needed in direct chat read receipts)
+		groupedMessages[key] = append(groupedMessages[key], msg.JID)
 	}
 	portal.log.Debugfln("Sending read receipts by %s: %v", sender.JID, groupedMessages)
 	for messageSender, ids := range groupedMessages {
-		err := sender.Client.MarkRead(ids, receiptTimestamp, portal.Key.JID, messageSender)
+		chatJID := portal.Key.JID
+		if messageSender.Server == types.BroadcastServer {
+			chatJID = messageSender
+			messageSender = portal.Key.JID
+		}
+		err := sender.Client.MarkRead(ids, receiptTimestamp, chatJID, messageSender)
 		if err != nil {
 			portal.log.Warnfln("Failed to mark %v as read by %s: %v", ids, sender.JID, err)
 		}

+ 23 - 9
user.go

@@ -28,21 +28,20 @@ import (
 
 	log "maunium.net/go/maulogger/v2"
 
-	"go.mau.fi/whatsmeow/appstate"
+	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/event"
+	"maunium.net/go/mautrix/format"
+	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/pushrules"
 
 	"go.mau.fi/whatsmeow"
+	"go.mau.fi/whatsmeow/appstate"
 	"go.mau.fi/whatsmeow/store"
 	"go.mau.fi/whatsmeow/types"
 	"go.mau.fi/whatsmeow/types/events"
 	waLog "go.mau.fi/whatsmeow/util/log"
 
-	"maunium.net/go/mautrix"
-	"maunium.net/go/mautrix/event"
-	"maunium.net/go/mautrix/format"
-	"maunium.net/go/mautrix/id"
-
 	"maunium.net/go/mautrix-whatsapp/database"
 )
 
@@ -442,7 +441,7 @@ func (user *User) HandleEvent(event interface{}) {
 	case *events.ChatPresence:
 		go user.handleChatPresence(v)
 	case *events.Message:
-		portal := user.GetPortalByJID(v.Info.Chat)
+		portal := user.GetPortalByMessageSource(v.Info.MessageSource)
 		portal.messages <- PortalMessage{evt: v, source: user}
 	case *events.CallOffer:
 		user.handleCallStart(v.CallCreator, v.CallID, "", v.Timestamp)
@@ -470,7 +469,7 @@ func (user *User) HandleEvent(event interface{}) {
 	case *events.CallTerminate, *events.CallRelayLatency, *events.CallAccept, *events.UnknownCallEvent:
 		// ignore
 	case *events.UndecryptableMessage:
-		portal := user.GetPortalByJID(v.Info.Chat)
+		portal := user.GetPortalByMessageSource(v.Info.MessageSource)
 		portal.messages <- PortalMessage{undecryptable: v, source: user}
 	case *events.HistorySync:
 		user.historySyncs <- v
@@ -667,6 +666,21 @@ func (user *User) handleLoggedOut(onConnect bool) {
 	user.sendBridgeState(BridgeState{StateEvent: StateBadCredentials, Error: WANotLoggedIn})
 }
 
+func (user *User) GetPortalByMessageSource(ms types.MessageSource) *Portal {
+	jid := ms.Chat
+	if ms.IsIncomingBroadcast() {
+		if ms.IsFromMe {
+			jid = ms.BroadcastListOwner.ToNonAD()
+		} else {
+			jid = ms.Sender.ToNonAD()
+		}
+		if jid.IsEmpty() {
+			return nil
+		}
+	}
+	return user.bridge.GetPortalByJID(database.NewPortalKey(jid, user.JID))
+}
+
 func (user *User) GetPortalByJID(jid types.JID) *Portal {
 	return user.bridge.GetPortalByJID(database.NewPortalKey(jid, user.JID))
 }
@@ -737,7 +751,7 @@ func (user *User) handleReceipt(receipt *events.Receipt) {
 	if receipt.Type != events.ReceiptTypeRead && receipt.Type != events.ReceiptTypeReadSelf {
 		return
 	}
-	portal := user.GetPortalByJID(receipt.Chat)
+	portal := user.GetPortalByMessageSource(receipt.MessageSource)
 	if portal == nil || len(portal.MXID) == 0 {
 		return
 	}