浏览代码

Handle decryption errors from WhatsApp properly

Tulir Asokan 3 年之前
父节点
当前提交
b918b4f261
共有 5 个文件被更改,包括 166 次插入61 次删除
  1. 45 19
      database/message.go
  2. 12 0
      database/upgrades/2021-10-27-message-decryption-errors.go
  3. 1 1
      database/upgrades/upgrades.go
  4. 102 40
      portal.go
  5. 6 1
      user.go

+ 45 - 19
database/message.go

@@ -40,12 +40,34 @@ func (mq *MessageQuery) New() *Message {
 	}
 }
 
+const (
+	getAllMessagesQuery = `
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
+		WHERE chat_jid=$1 AND chat_receiver=$2
+	`
+	getMessageByJIDQuery = `
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
+		WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
+	`
+	getMessageByMXIDQuery = `
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
+		WHERE mxid=$1
+	`
+	getLastMessageInChatQuery = `
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
+		WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
+	`
+	getFirstMessageInChatQuery = `
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
+		WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
+	`
+)
+
 func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
-	rows, err := mq.db.Query("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver)
+	rows, err := mq.db.Query(getAllMessagesQuery, chat.JID, chat.Receiver)
 	if err != nil || rows == nil {
 		return nil
 	}
-	defer rows.Close()
 	for rows.Next() {
 		messages = append(messages, mq.New().Scan(rows))
 	}
@@ -53,23 +75,19 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
 }
 
 func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.MessageID) *Message {
-	return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+
-		"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
+	return mq.maybeScan(mq.db.QueryRow(getMessageByJIDQuery, chat.JID, chat.Receiver, jid))
 }
 
 func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message {
-	return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+
-		"FROM message WHERE mxid=$1", mxid)
+	return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid))
 }
 
 func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
-	return mq.GetLastInChatBefore(chat, time.Now().Add(60 * time.Second))
+	return mq.GetLastInChatBefore(chat, time.Now().Add(60*time.Second))
 }
 
 func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Time) *Message {
-	msg := mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+
-		"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1",
-		chat.JID, chat.Receiver, maxTimestamp.Unix())
+	msg := mq.maybeScan(mq.db.QueryRow(getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix()))
 	if msg == nil || msg.Timestamp.IsZero() {
 		// Old db, we don't know what the last message is.
 		return nil
@@ -78,13 +96,10 @@ func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Ti
 }
 
 func (mq *MessageQuery) GetFirstInChat(chat PortalKey) *Message {
-	return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+
-		"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1",
-		chat.JID, chat.Receiver)
+	return mq.maybeScan(mq.db.QueryRow(getFirstMessageInChatQuery, chat.JID, chat.Receiver))
 }
 
-func (mq *MessageQuery) get(query string, args ...interface{}) *Message {
-	row := mq.db.QueryRow(query, args...)
+func (mq *MessageQuery) maybeScan(row *sql.Row) *Message {
 	if row == nil {
 		return nil
 	}
@@ -101,6 +116,8 @@ type Message struct {
 	Sender    types.JID
 	Timestamp time.Time
 	Sent      bool
+
+	DecryptionError bool
 }
 
 func (msg *Message) IsFakeMXID() bool {
@@ -109,7 +126,7 @@ func (msg *Message) IsFakeMXID() bool {
 
 func (msg *Message) Scan(row Scannable) *Message {
 	var ts int64
-	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent)
+	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError)
 	if err != nil {
 		if err != sql.ErrNoRows {
 			msg.log.Errorln("Database scan failed:", err)
@@ -129,9 +146,9 @@ func (msg *Message) Insert() {
 		sender = ""
 	}
 	_, err := msg.db.Exec(`INSERT INTO message
-			(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent)
-			VALUES ($1, $2, $3, $4, $5, $6, $7)`,
-		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent)
+			(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error)
+			VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
+		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError)
 	if err != nil {
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 	}
@@ -145,6 +162,15 @@ func (msg *Message) MarkSent() {
 	}
 }
 
+func (msg *Message) UpdateMXID(mxid id.EventID, stillDecryptionError bool) {
+	msg.MXID = mxid
+	msg.DecryptionError = stillDecryptionError
+	_, err := msg.db.Exec("UPDATE message SET mxid=$4, decryption_error=$5 WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, stillDecryptionError)
+	if err != nil {
+		msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
+	}
+}
+
 func (msg *Message) Delete() {
 	_, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID)
 	if err != nil {

+ 12 - 0
database/upgrades/2021-10-27-message-decryption-errors.go

@@ -0,0 +1,12 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[27] = upgrade{"Add marker for WhatsApp decryption errors in message table", func(tx *sql.Tx, ctx context) error {
+		_, err := tx.Exec(`ALTER TABLE message ADD COLUMN decryption_error BOOLEAN NOT NULL DEFAULT false`)
+		return err
+	}}
+}

+ 1 - 1
database/upgrades/upgrades.go

@@ -39,7 +39,7 @@ type upgrade struct {
 	fn      upgradeFunc
 }
 
-const NumberOfUpgrades = 27
+const NumberOfUpgrades = 28
 
 var upgrades [NumberOfUpgrades]upgrade
 

+ 102 - 40
portal.go

@@ -42,6 +42,8 @@ import (
 	"golang.org/x/image/webp"
 	"google.golang.org/protobuf/proto"
 
+	"maunium.net/go/mautrix/format"
+
 	"go.mau.fi/whatsmeow"
 	waProto "go.mau.fi/whatsmeow/binary/proto"
 	"go.mau.fi/whatsmeow/types"
@@ -160,8 +162,14 @@ func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
 const recentlyHandledLength = 100
 
 type PortalMessage struct {
-	evt    *events.Message
-	source *User
+	evt           *events.Message
+	undecryptable *events.UndecryptableMessage
+	source        *User
+}
+
+type recentlyHandledWrapper struct {
+	id  types.MessageID
+	err bool
 }
 
 type Portal struct {
@@ -174,7 +182,7 @@ type Portal struct {
 	encryptLock    sync.Mutex
 	backfillLock   sync.Mutex
 
-	recentlyHandled      [recentlyHandledLength]types.MessageID
+	recentlyHandled      [recentlyHandledLength]recentlyHandledWrapper
 	recentlyHandledLock  sync.Mutex
 	recentlyHandledIndex uint8
 
@@ -185,8 +193,6 @@ type Portal struct {
 	hasRelaybot *bool
 }
 
-const MaxMessageAgeToCreatePortal = 5 * 60 // 5 minutes
-
 func (portal *Portal) syncDoublePuppetDetailsAfterCreate(source *User) {
 	doublePuppet := portal.bridge.GetPuppetByCustomMXID(source.MXID)
 	if doublePuppet == nil {
@@ -210,13 +216,20 @@ func (portal *Portal) handleMessageLoop() {
 			}
 			portal.syncDoublePuppetDetailsAfterCreate(msg.source)
 		}
-		//portal.backfillLock.Lock()
-		portal.handleMessage(msg.source, msg.evt)
-		//portal.backfillLock.Unlock()
+		if msg.evt != nil {
+			portal.handleMessage(msg.source, msg.evt)
+		} else if msg.undecryptable != nil {
+			portal.handleUndecryptableMessage(msg.source, msg.undecryptable)
+		} else {
+			portal.log.Warnln("Unexpected PortalMessage with no message: %+v", msg)
+		}
 	}
 }
 
 func (portal *Portal) shouldCreateRoom(msg PortalMessage) bool {
+	if msg.undecryptable != nil {
+		return true
+	}
 	waMsg := msg.evt.Message
 	supportedMessages := []interface{}{
 		waMsg.Conversation,
@@ -295,6 +308,30 @@ func (portal *Portal) convertMessage(intent *appservice.IntentAPI, source *User,
 	}
 }
 
+const UndecryptableMessage = "Decrypting message from WhatsApp failed, waiting for sender to re-send... " +
+	"([learn more](https://faq.whatsapp.com/general/security-and-privacy/seeing-waiting-for-this-message-this-may-take-a-while))"
+
+func (portal *Portal) handleUndecryptableMessage(source *User, evt *events.UndecryptableMessage) {
+	if len(portal.MXID) == 0 {
+		portal.log.Warnln("handleUndecryptableMessage called even though portal.MXID is empty")
+		return
+	} else if portal.isRecentlyHandled(evt.Info.ID, true) {
+		portal.log.Debugfln("Not handling %s (undecryptable): message was recently handled", evt.Info.ID)
+		return
+	} else if existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, evt.Info.ID); existingMsg != nil {
+		portal.log.Debugfln("Not handling %s (undecryptable): message is duplicate", evt.Info.ID)
+		return
+	}
+	intent := portal.getMessageIntent(source, &evt.Info)
+	content := format.RenderMarkdown(UndecryptableMessage, true, false)
+	content.MsgType = event.MsgNotice
+	resp, err := portal.sendMessage(intent, event.EventMessage, &content, evt.Info.Timestamp.UnixMilli())
+	if err != nil {
+		portal.log.Errorln("Failed to send decryption error of %s to Matrix: %v", evt.Info.ID, err)
+	}
+	portal.finishHandling(nil, &evt.Info, resp.EventID, true)
+}
+
 func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 	if len(portal.MXID) == 0 {
 		portal.log.Warnln("handleMessage called even though portal.MXID is empty")
@@ -304,24 +341,35 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 	msgType := portal.getMessageType(evt.Message)
 	if msgType == "ignore" {
 		return
-	} else if portal.isRecentlyHandled(msgID) {
+	} else if portal.isRecentlyHandled(msgID, false) {
 		portal.log.Debugfln("Not handling %s (%s): message was recently handled", msgID, msgType)
 		return
-	} else if portal.isDuplicate(msgID) {
-		portal.log.Debugfln("Not handling %s (%s): message is duplicate", msgID, msgType)
-		return
 	}
+	existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, msgID)
+	if existingMsg != nil {
+		if existingMsg.DecryptionError {
+			portal.log.Debugfln("Got decryptable version of previously undecryptable message %s (%s)", msgID, msgType)
+		} else {
+			portal.log.Debugfln("Not handling %s (%s): message is duplicate", msgID, msgType)
+			return
+		}
+	}
+
 	intent := portal.getMessageIntent(source, &evt.Info)
 	converted := portal.convertMessage(intent, source, &evt.Info, evt.Message)
 	if converted != nil {
 		var eventID id.EventID
+		if existingMsg != nil {
+			converted.Content.SetEdit(existingMsg.MXID)
+		}
 		resp, err := portal.sendMessage(converted.Intent, converted.Type, converted.Content, evt.Info.Timestamp.UnixMilli())
 		if err != nil {
 			portal.log.Errorln("Failed to send %s to Matrix: %v", msgID, err)
 		} else {
 			eventID = resp.EventID
 		}
-		if converted.Caption != nil {
+		// TODO figure out how to handle captions with undecryptable messages turning decryptable
+		if converted.Caption != nil && existingMsg == nil {
 			resp, err = portal.sendMessage(converted.Intent, converted.Type, converted.Content, evt.Info.Timestamp.UnixMilli())
 			if err != nil {
 				portal.log.Errorln("Failed to send caption of %s to Matrix: %v", msgID, err)
@@ -330,55 +378,65 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 			}
 		}
 		if len(eventID) != 0 {
-			portal.finishHandling(&evt.Info, resp.EventID)
+			portal.finishHandling(existingMsg, &evt.Info, resp.EventID, false)
 		}
 	} else if msgType == "revoke" {
 		portal.HandleMessageRevoke(source, evt.Message.GetProtocolMessage().GetKey())
+		if existingMsg != nil {
+			_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
+				Reason: "The undecryptable message was actually the deletion of another message",
+			})
+			existingMsg.UpdateMXID("net.maunium.whatsapp.fake::" + existingMsg.MXID, false)
+		}
 	} else {
 		portal.log.Warnln("Unhandled message:", evt.Info, evt.Message)
+		if existingMsg != nil {
+			_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
+				Reason: "The undecryptable message contained an unsupported message type",
+			})
+			existingMsg.UpdateMXID("net.maunium.whatsapp.fake::" + existingMsg.MXID, false)
+		}
 		return
 	}
 	portal.bridge.Metrics.TrackWhatsAppMessage(evt.Info.Timestamp, strings.Split(msgType, " ")[0])
 }
 
-func (portal *Portal) isRecentlyHandled(id types.MessageID) bool {
+func (portal *Portal) isRecentlyHandled(id types.MessageID, decryptionError bool) bool {
 	start := portal.recentlyHandledIndex
+	lookingForMsg := recentlyHandledWrapper{id, decryptionError}
 	for i := start; i != start; i = (i - 1) % recentlyHandledLength {
-		if portal.recentlyHandled[i] == id {
+		if portal.recentlyHandled[i] == lookingForMsg {
 			return true
 		}
 	}
 	return false
 }
 
-func (portal *Portal) isDuplicate(id types.MessageID) bool {
-	msg := portal.bridge.DB.Message.GetByJID(portal.Key, id)
-	if msg != nil {
-		return true
-	}
-	return false
-}
-
 func init() {
 	gob.Register(&waProto.Message{})
 }
 
-func (portal *Portal) markHandled(info *types.MessageInfo, mxid id.EventID, isSent, recent bool) *database.Message {
-	msg := portal.bridge.DB.Message.New()
-	msg.Chat = portal.Key
-	msg.JID = info.ID
-	msg.MXID = mxid
-	msg.Timestamp = info.Timestamp
-	msg.Sender = info.Sender
-	msg.Sent = isSent
-	msg.Insert()
+func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent, decryptionError bool) *database.Message {
+	if msg == nil {
+		msg = portal.bridge.DB.Message.New()
+		msg.Chat = portal.Key
+		msg.JID = info.ID
+		msg.MXID = mxid
+		msg.Timestamp = info.Timestamp
+		msg.Sender = info.Sender
+		msg.Sent = isSent
+		msg.DecryptionError = decryptionError
+		msg.Insert()
+	} else {
+		msg.UpdateMXID(mxid, decryptionError)
+	}
 
 	if recent {
 		portal.recentlyHandledLock.Lock()
 		index := portal.recentlyHandledIndex
 		portal.recentlyHandledIndex = (portal.recentlyHandledIndex + 1) % recentlyHandledLength
 		portal.recentlyHandledLock.Unlock()
-		portal.recentlyHandled[index] = msg.JID
+		portal.recentlyHandled[index] = recentlyHandledWrapper{msg.JID, decryptionError}
 	}
 	return msg
 }
@@ -406,10 +464,14 @@ func (portal *Portal) getMessageIntent(user *User, info *types.MessageInfo) *app
 	return puppet.IntentFor(portal)
 }
 
-func (portal *Portal) finishHandling(message *types.MessageInfo, mxid id.EventID) {
-	portal.markHandled(message, mxid, true, true)
+func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, decryptionError bool) {
+	portal.markHandled(existing, message, mxid, true, true, decryptionError)
 	portal.sendDeliveryReceipt(mxid)
-	portal.log.Debugln("Handled message", message.ID, "->", mxid)
+	if !decryptionError {
+		portal.log.Debugln("Handled message", message.ID, "->", mxid)
+	} else {
+		portal.log.Debugln("Handled message", message.ID, "->", mxid, "(undecryptable message error notice)")
+	}
 }
 
 func (portal *Portal) kickExtraUsers(participantMap map[types.JID]bool) {
@@ -896,13 +958,13 @@ func (portal *Portal) finishBatch(eventIDs []id.EventID, infos []*types.MessageI
 			} else if info, ok := infoMap[types.MessageID(msgID)]; !ok {
 				portal.log.Warnfln("Didn't find info of message %s (event %s) to register it in the database", msgID, eventID)
 			} else {
-				portal.markHandled(info, eventID, true, false)
+				portal.markHandled(nil, info, eventID, true, false, false)
 			}
 		}
 	} else {
 		for i := 0; i < len(infos); i++ {
 			if infos[i] != nil {
-				portal.markHandled(infos[i], eventIDs[i], true, false)
+				portal.markHandled(nil, infos[i], eventIDs[i], true, false, false)
 			}
 		}
 		portal.log.Infofln("Successfully sent %d events", len(eventIDs))
@@ -2358,7 +2420,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) {
 		return
 	}
 	info := portal.generateMessageInfo(sender)
-	dbMsg := portal.markHandled(info, evt.ID, false, true)
+	dbMsg := portal.markHandled(nil, info, evt.ID, false, true, false)
 	portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
 	err := sender.Client.SendMessage(portal.Key.JID, info.ID, msg)
 	if err != nil {

+ 6 - 1
user.go

@@ -436,7 +436,10 @@ func (user *User) HandleEvent(event interface{}) {
 		go user.handleReceipt(v)
 	case *events.Message:
 		portal := user.GetPortalByJID(v.Info.Chat)
-		portal.messages <- PortalMessage{v, user}
+		portal.messages <- PortalMessage{evt: v, source: user}
+	case *events.UndecryptableMessage:
+		portal := user.GetPortalByJID(v.Info.Chat)
+		portal.messages <- PortalMessage{undecryptable: v, source: user}
 	case *events.HistorySync:
 		user.historySyncs <- v
 	case *events.Mute:
@@ -458,6 +461,8 @@ func (user *User) HandleEvent(event interface{}) {
 		if portal != nil {
 			go user.updateChatTag(nil, portal, user.bridge.Config.Bridge.PinnedTag, v.Action.GetPinned())
 		}
+	case *events.AppState:
+		// Ignore
 	default:
 		user.log.Debugfln("Unknown type of event in HandleEvent: %T", v)
 	}