Эх сурвалжийг харах

Add initial support for requesting media retries from phone

Tulir Asokan 3 жил өмнө
parent
commit
528fbda53f

+ 21 - 13
database/message.go

@@ -43,27 +43,27 @@ func (mq *MessageQuery) New() *Message {
 
 const (
 	getAllMessagesQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2
 	`
 	getMessageByJIDQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
 	`
 	getMessageByMXIDQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, error, broadcast_list_jid FROM message
 		WHERE mxid=$1
 	`
 	getLastMessageInChatQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
 	`
 	getFirstMessageInChatQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
 	`
 	getMessagesBetweenQuery = `
-		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
+		SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, error, broadcast_list_jid FROM message
 		WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true ORDER BY timestamp ASC
 	`
 )
@@ -122,6 +122,14 @@ func (mq *MessageQuery) maybeScan(row *sql.Row) *Message {
 	return mq.New().Scan(row)
 }
 
+type MessageErrorType string
+
+const (
+	MsgNoError             MessageErrorType = ""
+	MsgErrDecryptionFailed MessageErrorType = "decryption_failed"
+	MsgErrMediaNotFound    MessageErrorType = "media_not_found"
+)
+
 type Message struct {
 	db  *Database
 	log log.Logger
@@ -133,7 +141,7 @@ type Message struct {
 	Timestamp time.Time
 	Sent      bool
 
-	DecryptionError  bool
+	Error            MessageErrorType
 	BroadcastListJID types.JID
 }
 
@@ -147,7 +155,7 @@ func (msg *Message) IsFakeJID() bool {
 
 func (msg *Message) Scan(row Scannable) *Message {
 	var ts int64
-	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError, &msg.BroadcastListJID)
+	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Error, &msg.BroadcastListJID)
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 			msg.log.Errorln("Database scan failed:", err)
@@ -167,9 +175,9 @@ func (msg *Message) Insert() {
 		sender = ""
 	}
 	_, err := msg.db.Exec(`INSERT INTO message
-			(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid)
+			(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, error, broadcast_list_jid)
 			VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
-		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError, msg.BroadcastListJID)
+		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.Error, msg.BroadcastListJID)
 	if err != nil {
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 	}
@@ -184,10 +192,10 @@ func (msg *Message) MarkSent(ts time.Time) {
 	}
 }
 
-func (msg *Message) UpdateMXID(mxid id.EventID, stillDecryptionError bool) {
+func (msg *Message) UpdateMXID(mxid id.EventID, newError MessageErrorType) {
 	msg.MXID = mxid
-	msg.DecryptionError = stillDecryptionError
-	_, err := msg.db.Exec("UPDATE message SET mxid=$1, decryption_error=$2 WHERE chat_jid=$3 AND chat_receiver=$4 AND jid=$5", mxid, stillDecryptionError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
+	msg.Error = newError
+	_, err := msg.db.Exec("UPDATE message SET mxid=$1, error=$2 WHERE chat_jid=$3 AND chat_receiver=$4 AND jid=$5", mxid, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
 	if err != nil {
 		msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
 	}

+ 30 - 0
database/upgrades/2022-02-10-message-error-string.go

@@ -0,0 +1,30 @@
+package upgrades
+
+import "database/sql"
+
+func init() {
+	upgrades[36] = upgrade{"Store message error type as string", func(tx *sql.Tx, ctx context) error {
+		if ctx.dialect == Postgres {
+			_, err := tx.Exec("CREATE TYPE error_type AS ENUM ('', 'decryption_failed', 'media_not_found')")
+			if err != nil {
+				return err
+			}
+		}
+		_, err := tx.Exec("ALTER TABLE message ADD COLUMN error error_type NOT NULL DEFAULT ''")
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec("UPDATE message SET error='decryption_failed' WHERE decryption_error=true")
+		if err != nil {
+			return err
+		}
+		if ctx.dialect == Postgres {
+			// TODO do this on sqlite at some point
+			_, err = tx.Exec("ALTER TABLE message DROP COLUMN decryption_error")
+			if err != nil {
+				return err
+			}
+		}
+		return nil
+	}}
+}

+ 1 - 1
database/upgrades/upgrades.go

@@ -40,7 +40,7 @@ type upgrade struct {
 	fn      upgradeFunc
 }
 
-const NumberOfUpgrades = 36
+const NumberOfUpgrades = 37
 
 var upgrades [NumberOfUpgrades]upgrade
 

+ 1 - 1
go.mod

@@ -10,7 +10,7 @@ require (
 	github.com/prometheus/client_golang v1.11.0
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/tidwall/gjson v1.13.0
-	go.mau.fi/whatsmeow v0.0.0-20220210104450-b05cf0cef136
+	go.mau.fi/whatsmeow v0.0.0-20220210171358-894bfaa70e7b
 	golang.org/x/image v0.0.0-20211028202545-6944b10bf410
 	google.golang.org/protobuf v1.27.1
 	gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b

+ 2 - 2
go.sum

@@ -140,8 +140,8 @@ github.com/tidwall/sjson v1.2.4 h1:cuiLzLnaMeBhRmEv00Lpk3tkYrcxpmbU81tAY4Dw0tc=
 github.com/tidwall/sjson v1.2.4/go.mod h1:098SZ494YoMWPmMO6ct4dcFnqxwj9r/gF0Etp19pSNM=
 go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910 h1:9FFhG0OmkuMau5UEaTgiUQ+7cSbtbOQ7hiWKdN8OI3I=
 go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910/go.mod h1:AufGrvVh+00Nc07Jm4hTquh7yleZyn20tKJI2wCPAKg=
-go.mau.fi/whatsmeow v0.0.0-20220210104450-b05cf0cef136 h1:AeE8izwMOwAbzNGC+GGOY821w6EWSpPxVmQcpRgyff4=
-go.mau.fi/whatsmeow v0.0.0-20220210104450-b05cf0cef136/go.mod h1:8jUjOAi3xtGubxcZgG8uSHpAdyQXBRbWAfxkctX/4y4=
+go.mau.fi/whatsmeow v0.0.0-20220210171358-894bfaa70e7b h1:IwkG1atB+tTeMXIaPTeoZEkfviFNmxgbXLfqY3NseMk=
+go.mau.fi/whatsmeow v0.0.0-20220210171358-894bfaa70e7b/go.mod h1:8jUjOAi3xtGubxcZgG8uSHpAdyQXBRbWAfxkctX/4y4=
 golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

+ 16 - 9
historysync.go

@@ -28,6 +28,8 @@ import (
 	"maunium.net/go/mautrix/appservice"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/id"
+
+	"maunium.net/go/mautrix-whatsapp/database"
 )
 
 // region User history sync handling
@@ -44,6 +46,11 @@ type portalToBackfill struct {
 	msgs   []*waProto.WebMessageInfo
 }
 
+type wrappedInfo struct {
+	*types.MessageInfo
+	Error database.MessageErrorType
+}
+
 type conversationList []*waProto.Conversation
 
 var _ sort.Interface = (conversationList)(nil)
@@ -310,7 +317,7 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo)
 	defer portal.backfillLock.Unlock()
 
 	var historyBatch, newBatch mautrix.ReqBatchSend
-	var historyBatchInfos, newBatchInfos []*types.MessageInfo
+	var historyBatchInfos, newBatchInfos []*wrappedInfo
 
 	firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].GetMessageTimestamp()), 0)
 
@@ -392,7 +399,7 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo)
 			continue
 		}
 		var batch *mautrix.ReqBatchSend
-		var infos *[]*types.MessageInfo
+		var infos *[]*wrappedInfo
 		if !historyMaxTs.IsZero() && info.Timestamp.Before(historyMaxTs) {
 			batch, infos = &historyBatch, &historyBatchInfos
 		} else if !newMinTs.IsZero() && info.Timestamp.After(newMinTs) {
@@ -498,7 +505,7 @@ func (portal *Portal) parseWebMessageInfo(webMsg *waProto.WebMessageInfo) *types
 	return &info
 }
 
-func (portal *Portal) appendBatchEvents(converted *ConvertedMessage, info *types.MessageInfo, eventsArray *[]*event.Event, infoArray *[]*types.MessageInfo) error {
+func (portal *Portal) appendBatchEvents(converted *ConvertedMessage, info *types.MessageInfo, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error {
 	mainEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Content, converted.Extra)
 	if err != nil {
 		return err
@@ -509,10 +516,10 @@ func (portal *Portal) appendBatchEvents(converted *ConvertedMessage, info *types
 			return err
 		}
 		*eventsArray = append(*eventsArray, mainEvt, captionEvt)
-		*infoArray = append(*infoArray, nil, info)
+		*infoArray = append(*infoArray, &wrappedInfo{info, converted.Error}, nil)
 	} else {
 		*eventsArray = append(*eventsArray, mainEvt)
-		*infoArray = append(*infoArray, info)
+		*infoArray = append(*infoArray, &wrappedInfo{info, converted.Error})
 	}
 	if converted.MultiEvent != nil {
 		for _, subEvtContent := range converted.MultiEvent {
@@ -553,10 +560,10 @@ func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice
 	}, nil
 }
 
-func (portal *Portal) finishBatch(eventIDs []id.EventID, infos []*types.MessageInfo) {
+func (portal *Portal) finishBatch(eventIDs []id.EventID, infos []*wrappedInfo) {
 	if len(eventIDs) != len(infos) {
 		portal.log.Errorfln("Length of event IDs (%d) and message infos (%d) doesn't match! Using slow path for mapping event IDs", len(eventIDs), len(infos))
-		infoMap := make(map[types.MessageID]*types.MessageInfo, len(infos))
+		infoMap := make(map[types.MessageID]*wrappedInfo, len(infos))
 		for _, info := range infos {
 			infoMap[info.ID] = info
 		}
@@ -568,13 +575,13 @@ func (portal *Portal) finishBatch(eventIDs []id.EventID, infos []*types.MessageI
 			} else if info, ok := infoMap[types.MessageID(msgID)]; !ok {
 				portal.log.Warnfln("Didn't find info of message %s (event %s) to register it in the database", msgID, eventID)
 			} else {
-				portal.markHandled(nil, info, eventID, true, false, false)
+				portal.markHandled(nil, info.MessageInfo, eventID, true, false, info.Error)
 			}
 		}
 	} else {
 		for i := 0; i < len(infos); i++ {
 			if infos[i] != nil {
-				portal.markHandled(nil, infos[i], eventIDs[i], true, false, false)
+				portal.markHandled(nil, infos[i].MessageInfo, eventIDs[i], true, false, infos[i].Error)
 			}
 		}
 		portal.log.Infofln("Successfully sent %d events", len(eventIDs))

+ 7 - 1
matrix.go

@@ -476,7 +476,13 @@ func (mx *MatrixHandler) HandleReaction(evt *event.Event) {
 	}
 
 	portal := mx.bridge.GetPortalByMXID(evt.RoomID)
-	if portal != nil && (user.Whitelisted || portal.HasRelaybot()) && mx.bridge.Config.Bridge.ReactionNotices {
+	if portal == nil || (!user.Whitelisted && !portal.HasRelaybot()) {
+		return
+	}
+	content := evt.Content.AsReaction()
+	if content.RelatesTo.Key == "click to retry" {
+		portal.requestMediaRetry(user, content.RelatesTo.EventID)
+	} else if mx.bridge.Config.Bridge.ReactionNotices {
 		_, _ = portal.sendMainIntentMessage(&event.MessageEventContent{
 			MsgType: event.MsgNotice,
 			Body:    fmt.Sprintf("\u26a0 Reactions are not yet supported by WhatsApp."),

+ 307 - 108
portal.go

@@ -20,6 +20,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/gob"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"html"
@@ -35,6 +36,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/tidwall/gjson"
 	"golang.org/x/image/draw"
 	"golang.org/x/image/webp"
 	"google.golang.org/protobuf/proto"
@@ -132,32 +134,32 @@ func (portal *Portal) GetUsers() []*User {
 	return nil
 }
 
-func (bridge *Bridge) NewManualPortal(key database.PortalKey) *Portal {
+func (bridge *Bridge) newBlankPortal(key database.PortalKey) *Portal {
 	portal := &Portal{
-		Portal: bridge.DB.Portal.New(),
 		bridge: bridge,
 		log:    bridge.Log.Sub(fmt.Sprintf("Portal/%s", key)),
 
 		messages:       make(chan PortalMessage, bridge.Config.Bridge.PortalMessageBuffer),
 		receipts:       make(chan PortalReceipt, bridge.Config.Bridge.PortalMessageBuffer),
 		matrixMessages: make(chan PortalMatrixMessage, bridge.Config.Bridge.PortalMessageBuffer),
+		mediaRetries:   make(chan PortalMediaRetry, bridge.Config.Bridge.PortalMessageBuffer),
+
+		mediaErrorCache: make(map[types.MessageID]*FailedMediaMeta),
 	}
-	portal.Key = key
 	go portal.handleMessageLoop()
 	return portal
 }
 
-func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
-	portal := &Portal{
-		Portal: dbPortal,
-		bridge: bridge,
-		log:    bridge.Log.Sub(fmt.Sprintf("Portal/%s", dbPortal.Key)),
+func (bridge *Bridge) NewManualPortal(key database.PortalKey) *Portal {
+	portal := bridge.newBlankPortal(key)
+	portal.Portal = bridge.DB.Portal.New()
+	portal.Key = key
+	return portal
+}
 
-		messages:       make(chan PortalMessage, bridge.Config.Bridge.PortalMessageBuffer),
-		receipts:       make(chan PortalReceipt, bridge.Config.Bridge.PortalMessageBuffer),
-		matrixMessages: make(chan PortalMatrixMessage, bridge.Config.Bridge.PortalMessageBuffer),
-	}
-	go portal.handleMessageLoop()
+func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
+	portal := bridge.newBlankPortal(dbPortal.Key)
+	portal.Portal = dbPortal
 	return portal
 }
 
@@ -188,9 +190,14 @@ type PortalMatrixMessage struct {
 	user *User
 }
 
+type PortalMediaRetry struct {
+	evt    *events.MediaRetry
+	source *User
+}
+
 type recentlyHandledWrapper struct {
 	id  types.MessageID
-	err bool
+	err database.MessageErrorType
 }
 
 type Portal struct {
@@ -216,6 +223,9 @@ type Portal struct {
 	messages       chan PortalMessage
 	receipts       chan PortalReceipt
 	matrixMessages chan PortalMatrixMessage
+	mediaRetries   chan PortalMediaRetry
+
+	mediaErrorCache map[types.MessageID]*FailedMediaMeta
 
 	relayUser *User
 }
@@ -303,6 +313,8 @@ func (portal *Portal) handleMessageLoop() {
 			portal.handleReceipt(receipt.evt, receipt.source)
 		case msg := <-portal.matrixMessages:
 			portal.handleMatrixMessageLoopItem(msg)
+		case retry := <-portal.mediaRetries:
+			portal.handleMediaRetry(retry.evt, retry.source)
 		}
 	}
 }
@@ -523,7 +535,7 @@ func (portal *Portal) handleUndecryptableMessage(source *User, evt *events.Undec
 	if len(portal.MXID) == 0 {
 		portal.log.Warnln("handleUndecryptableMessage called even though portal.MXID is empty")
 		return
-	} else if portal.isRecentlyHandled(evt.Info.ID, true) {
+	} else if portal.isRecentlyHandled(evt.Info.ID, database.MsgErrDecryptionFailed) {
 		portal.log.Debugfln("Not handling %s (undecryptable): message was recently handled", evt.Info.ID)
 		return
 	} else if existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, evt.Info.ID); existingMsg != nil {
@@ -540,11 +552,11 @@ func (portal *Portal) handleUndecryptableMessage(source *User, evt *events.Undec
 	if err != nil {
 		portal.log.Errorln("Failed to send decryption error of %s to Matrix: %v", evt.Info.ID, err)
 	}
-	portal.finishHandling(nil, &evt.Info, resp.EventID, true)
+	portal.finishHandling(nil, &evt.Info, resp.EventID, database.MsgErrDecryptionFailed)
 }
 
 func (portal *Portal) handleFakeMessage(msg fakeMessage) {
-	if portal.isRecentlyHandled(msg.ID, false) {
+	if portal.isRecentlyHandled(msg.ID, database.MsgNoError) {
 		portal.log.Debugfln("Not handling %s (fake): message was recently handled", msg.ID)
 		return
 	} else if existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, msg.ID); existingMsg != nil {
@@ -573,7 +585,7 @@ func (portal *Portal) handleFakeMessage(msg fakeMessage) {
 			MessageSource: types.MessageSource{
 				Sender: msg.Sender,
 			},
-		}, resp.EventID, false)
+		}, resp.EventID, database.MsgNoError)
 	}
 }
 
@@ -586,13 +598,13 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 	msgType := getMessageType(evt.Message)
 	if msgType == "ignore" {
 		return
-	} else if portal.isRecentlyHandled(msgID, false) {
+	} else if portal.isRecentlyHandled(msgID, database.MsgNoError) {
 		portal.log.Debugfln("Not handling %s (%s): message was recently handled", msgID, msgType)
 		return
 	}
 	existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, msgID)
 	if existingMsg != nil {
-		if existingMsg.DecryptionError {
+		if existingMsg.Error == database.MsgErrDecryptionFailed {
 			portal.log.Debugfln("Got decryptable version of previously undecryptable message %s (%s)", msgID, msgType)
 		} else {
 			portal.log.Debugfln("Not handling %s (%s): message is duplicate", msgID, msgType)
@@ -634,7 +646,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 				portal.log.Errorfln("Failed to send caption of %s to Matrix: %v", msgID, err)
 			} else {
 				portal.MarkDisappearing(resp.EventID, converted.ExpiresIn, false)
-				eventID = resp.EventID
+				//eventID = resp.EventID
 			}
 		}
 		if converted.MultiEvent != nil && existingMsg == nil {
@@ -648,7 +660,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 			}
 		}
 		if len(eventID) != 0 {
-			portal.finishHandling(existingMsg, &evt.Info, eventID, false)
+			portal.finishHandling(existingMsg, &evt.Info, eventID, converted.Error)
 		}
 	} else if msgType == "revoke" {
 		portal.HandleMessageRevoke(source, &evt.Info, evt.Message.GetProtocolMessage().GetKey())
@@ -656,7 +668,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 			_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
 				Reason: "The undecryptable message was actually the deletion of another message",
 			})
-			existingMsg.UpdateMXID("net.maunium.whatsapp.fake::"+existingMsg.MXID, false)
+			existingMsg.UpdateMXID("net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgNoError)
 		}
 	} else {
 		portal.log.Warnfln("Unhandled message: %+v (%s)", evt.Info, msgType)
@@ -664,16 +676,16 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 			_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
 				Reason: "The undecryptable message contained an unsupported message type",
 			})
-			existingMsg.UpdateMXID("net.maunium.whatsapp.fake::"+existingMsg.MXID, false)
+			existingMsg.UpdateMXID("net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgNoError)
 		}
 		return
 	}
 	portal.bridge.Metrics.TrackWhatsAppMessage(evt.Info.Timestamp, strings.Split(msgType, " ")[0])
 }
 
-func (portal *Portal) isRecentlyHandled(id types.MessageID, decryptionError bool) bool {
+func (portal *Portal) isRecentlyHandled(id types.MessageID, error database.MessageErrorType) bool {
 	start := portal.recentlyHandledIndex
-	lookingForMsg := recentlyHandledWrapper{id, decryptionError}
+	lookingForMsg := recentlyHandledWrapper{id, error}
 	for i := start; i != start; i = (i - 1) % recentlyHandledLength {
 		if portal.recentlyHandled[i] == lookingForMsg {
 			return true
@@ -686,7 +698,7 @@ func init() {
 	gob.Register(&waProto.Message{})
 }
 
-func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent, decryptionError bool) *database.Message {
+func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent bool, error database.MessageErrorType) *database.Message {
 	if msg == nil {
 		msg = portal.bridge.DB.Message.New()
 		msg.Chat = portal.Key
@@ -695,13 +707,13 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo
 		msg.Timestamp = info.Timestamp
 		msg.Sender = info.Sender
 		msg.Sent = isSent
-		msg.DecryptionError = decryptionError
+		msg.Error = error
 		if info.IsIncomingBroadcast() {
 			msg.BroadcastListJID = info.Chat
 		}
 		msg.Insert()
 	} else {
-		msg.UpdateMXID(mxid, decryptionError)
+		msg.UpdateMXID(mxid, error)
 	}
 
 	if recent {
@@ -709,7 +721,7 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo
 		index := portal.recentlyHandledIndex
 		portal.recentlyHandledIndex = (portal.recentlyHandledIndex + 1) % recentlyHandledLength
 		portal.recentlyHandledLock.Unlock()
-		portal.recentlyHandled[index] = recentlyHandledWrapper{msg.JID, decryptionError}
+		portal.recentlyHandled[index] = recentlyHandledWrapper{msg.JID, error}
 	}
 	return msg
 }
@@ -730,14 +742,16 @@ func (portal *Portal) getMessageIntent(user *User, info *types.MessageInfo) *app
 	return portal.getMessagePuppet(user, info).IntentFor(portal)
 }
 
-func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, decryptionError bool) {
-	portal.markHandled(existing, message, mxid, true, true, decryptionError)
+func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, error database.MessageErrorType) {
+	portal.markHandled(existing, message, mxid, true, true, error)
 	portal.sendDeliveryReceipt(mxid)
-	if !decryptionError {
-		portal.log.Debugln("Handled message", message.ID, "->", mxid)
-	} else {
-		portal.log.Debugln("Handled message", message.ID, "->", mxid, "(undecryptable message error notice)")
+	var suffix string
+	if error == database.MsgErrDecryptionFailed {
+		suffix = "(undecryptable message error notice)"
+	} else if error == database.MsgErrMediaNotFound {
+		suffix = "(media not found notice)"
 	}
+	portal.log.Debugln("Handled message", message.ID, "->", mxid, suffix)
 }
 
 func (portal *Portal) kickExtraUsers(participantMap map[types.JID]bool) {
@@ -1492,6 +1506,7 @@ type ConvertedMessage struct {
 
 	ReplyTo   types.MessageID
 	ExpiresIn uint32
+	Error     database.MessageErrorType
 }
 
 func (portal *Portal) convertTextMessage(intent *appservice.IntentAPI, source *User, msg *waProto.Message) *ConvertedMessage {
@@ -1786,12 +1801,49 @@ func (portal *Portal) HandleWhatsAppInvite(source *User, senderJID *types.JID, j
 	return
 }
 
-func (portal *Portal) makeMediaBridgeFailureMessage(intent *appservice.IntentAPI, info *types.MessageInfo, bridgeErr error, captionContent *event.MessageEventContent) *ConvertedMessage {
+const failedMediaField = "fi.mau.whatsapp.failed_media"
+
+type FailedMediaKeys struct {
+	Key       []byte              `json:"key"`
+	Length    int                 `json:"length"`
+	Type      whatsmeow.MediaType `json:"type"`
+	SHA256    []byte              `json:"sha256"`
+	EncSHA256 []byte              `json:"enc_sha256"`
+}
+
+type FailedMediaMeta struct {
+	Type         event.Type                 `json:"type"`
+	Content      *event.MessageEventContent `json:"content"`
+	ExtraContent map[string]interface{}     `json:"extra_content,omitempty"`
+	Media        FailedMediaKeys            `json:"whatsapp_media"`
+}
+
+func shallowCopyMap(data map[string]interface{}) map[string]interface{} {
+	newMap := make(map[string]interface{}, len(data))
+	for key, value := range data {
+		newMap[key] = value
+	}
+	return newMap
+}
+
+func (portal *Portal) makeMediaBridgeFailureMessage(info *types.MessageInfo, bridgeErr error, converted *ConvertedMessage, keys *FailedMediaKeys) *ConvertedMessage {
 	portal.log.Errorfln("Failed to bridge media for %s: %v", info.ID, bridgeErr)
-	return &ConvertedMessage{Intent: intent, Type: event.EventMessage, Content: &event.MessageEventContent{
+	if keys != nil {
+		meta := &FailedMediaMeta{
+			Type:         converted.Type,
+			Content:      converted.Content,
+			ExtraContent: shallowCopyMap(converted.Extra),
+			Media:        *keys,
+		}
+		converted.Extra[failedMediaField] = meta
+		portal.mediaErrorCache[info.ID] = meta
+	}
+	converted.Type = event.EventMessage
+	converted.Content = &event.MessageEventContent{
 		MsgType: event.MsgNotice,
-		Body:    "Failed to bridge media",
-	}, Caption: captionContent}
+		Body:    fmt.Sprintf("Failed to bridge media: %v", bridgeErr),
+	}
+	return converted
 }
 
 func (portal *Portal) encryptFile(data []byte, mimeType string) ([]byte, string, *event.EncryptedFileInfo) {
@@ -1809,6 +1861,7 @@ func (portal *Portal) encryptFile(data []byte, mimeType string) ([]byte, string,
 type MediaMessage interface {
 	whatsmeow.DownloadableMessage
 	GetContextInfo() *waProto.ContextInfo
+	GetFileLength() uint64
 	GetMimetype() string
 }
 
@@ -1838,72 +1891,20 @@ type MediaMessageWithDuration interface {
 	GetSeconds() uint32
 }
 
-func (portal *Portal) convertMediaMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg MediaMessage) *ConvertedMessage {
-	messageWithCaption, ok := msg.(MediaMessageWithCaption)
-	var captionContent *event.MessageEventContent
-	if ok && len(messageWithCaption.GetCaption()) > 0 {
-		captionContent = &event.MessageEventContent{
-			Body:    messageWithCaption.GetCaption(),
-			MsgType: event.MsgNotice,
-		}
-
-		portal.bridge.Formatter.ParseWhatsApp(captionContent, msg.GetContextInfo().GetMentionedJid())
-	}
-
-	data, err := source.Client.Download(msg)
-	// TODO can these errors still be handled?
-	//if errors.Is(err, whatsapp.ErrMediaDownloadFailedWith404) || errors.Is(err, whatsapp.ErrMediaDownloadFailedWith410) {
-	//	portal.log.Warnfln("Failed to download media for %s: %v. Calling LoadMediaInfo and retrying download...", msg.info.Id, err)
-	//	_, err = source.Conn.LoadMediaInfo(msg.info.RemoteJid, msg.info.Id, msg.info.FromMe)
-	//	if err != nil {
-	//		portal.sendMediaBridgeFailure(source, intent, msg.info, fmt.Errorf("failed to load media info: %w", err))
-	//		return true
-	//	}
-	//	data, err = msg.download()
-	//}
-	if errors.Is(err, whatsmeow.ErrNoURLPresent) {
-		portal.log.Debugfln("No URL present error for media message %s, ignoring...", info.ID)
-		return nil
-	} else if errors.Is(err, whatsmeow.ErrFileLengthMismatch) || errors.Is(err, whatsmeow.ErrInvalidMediaSHA256) {
-		portal.log.Warnfln("Mismatching media checksums in %s: %v. Ignoring because WhatsApp seems to ignore them too", info.ID, err)
-	} else if err != nil {
-		return portal.makeMediaBridgeFailureMessage(intent, info, err, captionContent)
-	}
-
-	var width, height int
-	messageWithDimensions, ok := msg.(MediaMessageWithDimensions)
-	if ok {
-		width = int(messageWithDimensions.GetWidth())
-		height = int(messageWithDimensions.GetHeight())
-	}
-	if width == 0 && height == 0 && strings.HasPrefix(msg.GetMimetype(), "image/") {
-		cfg, _, _ := image.DecodeConfig(bytes.NewReader(data))
-		width, height = cfg.Width, cfg.Height
-	}
-
-	data, uploadMimeType, file := portal.encryptFile(data, msg.GetMimetype())
-
-	uploaded, err := intent.UploadBytes(data, uploadMimeType)
-	if err != nil {
-		if errors.Is(err, mautrix.MTooLarge) {
-			return portal.makeMediaBridgeFailureMessage(intent, info, errors.New("homeserver rejected too large file"), captionContent)
-		} else if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.IsStatus(413) {
-			return portal.makeMediaBridgeFailureMessage(intent, info, errors.New("proxy rejected too large file"), captionContent)
-		} else {
-			return portal.makeMediaBridgeFailureMessage(intent, info, fmt.Errorf("failed to upload media: %w", err), captionContent)
-		}
-	}
-
+func (portal *Portal) convertMediaMessageContent(intent *appservice.IntentAPI, msg MediaMessage) *ConvertedMessage {
 	content := &event.MessageEventContent{
-		File: file,
 		Info: &event.FileInfo{
-			Size:     len(data),
 			MimeType: msg.GetMimetype(),
-			Width:    width,
-			Height:   height,
+			Size:     int(msg.GetFileLength()),
 		},
 	}
 
+	messageWithDimensions, ok := msg.(MediaMessageWithDimensions)
+	if ok {
+		content.Info.Width = int(messageWithDimensions.GetWidth())
+		content.Info.Height = int(messageWithDimensions.GetHeight())
+	}
+
 	msgWithName, ok := msg.(MediaMessageWithFileName)
 	if ok && len(msgWithName.GetFileName()) > 0 {
 		content.Body = msgWithName.GetFileName()
@@ -1924,12 +1925,6 @@ func (portal *Portal) convertMediaMessage(intent *appservice.IntentAPI, source *
 		content.Info.Duration = int(msgWithDuration.GetSeconds()) * 1000
 	}
 
-	if content.File != nil {
-		content.File.URL = uploaded.ContentURI.CUString()
-	} else {
-		content.URL = uploaded.ContentURI.CUString()
-	}
-
 	messageWithThumbnail, ok := msg.(MediaMessageWithThumbnail)
 	if ok && messageWithThumbnail.GetJpegThumbnail() != nil && portal.bridge.Config.Bridge.WhatsappThumbnail {
 		thumbnailData := messageWithThumbnail.GetJpegThumbnail()
@@ -1939,7 +1934,7 @@ func (portal *Portal) convertMediaMessage(intent *appservice.IntentAPI, source *
 		thumbnail, thumbnailUploadMime, thumbnailFile := portal.encryptFile(thumbnailData, thumbnailMime)
 		uploadedThumbnail, err := intent.UploadBytes(thumbnail, thumbnailUploadMime)
 		if err != nil {
-			portal.log.Warnfln("Failed to upload thumbnail for %s: %v", info.ID, err)
+			portal.log.Warnfln("Failed to upload thumbnail: %v", err)
 		} else if uploadedThumbnail != nil {
 			if thumbnailFile != nil {
 				thumbnailFile.URL = uploadedThumbnail.ContentURI.CUString()
@@ -1995,6 +1990,17 @@ func (portal *Portal) convertMediaMessage(intent *appservice.IntentAPI, source *
 		}
 	}
 
+	messageWithCaption, ok := msg.(MediaMessageWithCaption)
+	var captionContent *event.MessageEventContent
+	if ok && len(messageWithCaption.GetCaption()) > 0 {
+		captionContent = &event.MessageEventContent{
+			Body:    messageWithCaption.GetCaption(),
+			MsgType: event.MsgNotice,
+		}
+
+		portal.bridge.Formatter.ParseWhatsApp(captionContent, msg.GetContextInfo().GetMentionedJid())
+	}
+
 	return &ConvertedMessage{
 		Intent:    intent,
 		Type:      eventType,
@@ -2006,6 +2012,199 @@ func (portal *Portal) convertMediaMessage(intent *appservice.IntentAPI, source *
 	}
 }
 
+func (portal *Portal) uploadMedia(intent *appservice.IntentAPI, data []byte, content *event.MessageEventContent) error {
+	data, uploadMimeType, file := portal.encryptFile(data, content.Info.MimeType)
+
+	uploaded, err := intent.UploadBytes(data, uploadMimeType)
+	if err != nil {
+		return err
+	}
+
+	if file != nil {
+		file.URL = uploaded.ContentURI.CUString()
+		content.File = file
+	} else {
+		content.URL = uploaded.ContentURI.CUString()
+	}
+
+	content.Info.Size = len(data)
+	if content.Info.Width == 0 && content.Info.Height == 0 && strings.HasPrefix(content.Info.MimeType, "image/") {
+		cfg, _, _ := image.DecodeConfig(bytes.NewReader(data))
+		content.Info.Width, content.Info.Height = cfg.Width, cfg.Height
+	}
+	return nil
+}
+
+func (portal *Portal) convertMediaMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg MediaMessage) *ConvertedMessage {
+	converted := portal.convertMediaMessageContent(intent, msg)
+	data, err := source.Client.Download(msg)
+	if errors.Is(err, whatsmeow.ErrMediaDownloadFailedWith404) || errors.Is(err, whatsmeow.ErrMediaDownloadFailedWith410) {
+		//portal.log.Warnfln("Failed to download media for %s: %v. Requesting retry", info.ID, err)
+		//err = source.Client.SendMediaRetryReceipt(info, msg.GetMediaKey())
+		//if err != nil {
+		//	portal.log.Errorfln("Failed to send media retry receipt for %s: %v", info.ID, err)
+		//}
+		converted.Error = database.MsgErrMediaNotFound
+		return portal.makeMediaBridgeFailureMessage(info, err, converted, &FailedMediaKeys{
+			Key:       msg.GetMediaKey(),
+			Length:    int(msg.GetFileLength()),
+			Type:      whatsmeow.GetMediaType(msg),
+			SHA256:    msg.GetFileSha256(),
+			EncSHA256: msg.GetFileEncSha256(),
+		})
+	} else if errors.Is(err, whatsmeow.ErrNoURLPresent) {
+		portal.log.Debugfln("No URL present error for media message %s, ignoring...", info.ID)
+		return nil
+	} else if errors.Is(err, whatsmeow.ErrFileLengthMismatch) || errors.Is(err, whatsmeow.ErrInvalidMediaSHA256) {
+		portal.log.Warnfln("Mismatching media checksums in %s: %v. Ignoring because WhatsApp seems to ignore them too", info.ID, err)
+	} else if err != nil {
+		return portal.makeMediaBridgeFailureMessage(info, err, converted, nil)
+	}
+
+	err = portal.uploadMedia(intent, data, converted.Content)
+	if err != nil {
+		if errors.Is(err, mautrix.MTooLarge) {
+			return portal.makeMediaBridgeFailureMessage(info, errors.New("homeserver rejected too large file"), converted, nil)
+		} else if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.IsStatus(413) {
+			return portal.makeMediaBridgeFailureMessage(info, errors.New("proxy rejected too large file"), converted, nil)
+		} else {
+			return portal.makeMediaBridgeFailureMessage(info, fmt.Errorf("failed to upload media: %w", err), converted, nil)
+		}
+	}
+	return converted
+}
+
+func (portal *Portal) fetchMediaRetryEvent(msg *database.Message) (*FailedMediaMeta, error) {
+	errorMeta, ok := portal.mediaErrorCache[msg.JID]
+	if ok {
+		return errorMeta, nil
+	}
+	evt, err := portal.MainIntent().GetEvent(portal.MXID, msg.MXID)
+	if err != nil {
+		return nil, fmt.Errorf("failed to fetch event %s: %w", msg.MXID, err)
+	}
+	if evt.Type == event.EventEncrypted {
+		err = evt.Content.ParseRaw(evt.Type)
+		if err != nil {
+			return nil, fmt.Errorf("failed to parse encrypted content in %s: %w", msg.MXID, err)
+		}
+		evt, err = portal.bridge.Crypto.Decrypt(evt)
+		if err != nil {
+			return nil, fmt.Errorf("failed to decrypt event %s: %w", msg.MXID, err)
+		}
+	}
+	errorMetaResult := gjson.GetBytes(evt.Content.VeryRaw, strings.ReplaceAll(failedMediaField, ".", "\\."))
+	if !errorMetaResult.Exists() || !errorMetaResult.IsObject() {
+		return nil, fmt.Errorf("didn't find failed media metadata in %s", msg.MXID)
+	}
+	var errorMetaBytes []byte
+	if errorMetaResult.Index > 0 {
+		errorMetaBytes = evt.Content.VeryRaw[errorMetaResult.Index : errorMetaResult.Index+len(errorMetaResult.Raw)]
+	} else {
+		errorMetaBytes = []byte(errorMetaResult.Raw)
+	}
+	err = json.Unmarshal(errorMetaBytes, &errorMeta)
+	if err != nil {
+		return nil, fmt.Errorf("failed to unmarshal failed media metadata in %s: %w", msg.MXID, err)
+	}
+	return errorMeta, nil
+}
+
+func (portal *Portal) handleMediaRetry(retry *events.MediaRetry, source *User) {
+	msg := portal.bridge.DB.Message.GetByJID(portal.Key, retry.MessageID)
+	if msg == nil {
+		portal.log.Warnfln("Dropping media retry notification for unknown message %s", retry.MessageID)
+		return
+	} else if msg.Error != database.MsgErrMediaNotFound {
+		portal.log.Warnfln("Dropping media retry notification for non-errored message %s / %s", retry.MessageID, msg.MXID)
+		return
+	}
+
+	meta, err := portal.fetchMediaRetryEvent(msg)
+	if err != nil {
+		portal.log.Warnfln("Can't handle media retry notification for %s: %v", retry.MessageID, err)
+		return
+	}
+
+	retryData, err := whatsmeow.DecryptMediaRetryNotification(retry, meta.Media.Key)
+	if err != nil {
+		portal.log.Warnfln("Failed to handle media retry notification for %s: %v", retry.MessageID, err)
+		return
+	} else if retryData.GetResult() != waProto.MediaRetryNotification_SUCCESS {
+		portal.log.Warnfln("Got error response in media retry notification for %s: %s", retry.MessageID, waProto.MediaRetryNotification_MediaRetryNotificationResultType_name[int32(retryData.GetResult())])
+		return
+	}
+
+	var puppet *Puppet
+	if retry.FromMe {
+		puppet = portal.bridge.GetPuppetByJID(source.JID)
+	} else if retry.ChatID.Server == types.DefaultUserServer {
+		puppet = portal.bridge.GetPuppetByJID(retry.ChatID)
+	} else {
+		puppet = portal.bridge.GetPuppetByJID(retry.SenderID)
+	}
+	intent := puppet.IntentFor(portal)
+
+	data, err := source.Client.DownloadMediaWithPath(retryData.GetDirectPath(), meta.Media.EncSHA256, meta.Media.SHA256, meta.Media.Key, meta.Media.Length, meta.Media.Type, "")
+	if err != nil {
+		portal.log.Warnfln("Failed to download media in %s after retry notification: %v", retry.MessageID, err)
+		return
+	}
+	err = portal.uploadMedia(intent, data, meta.Content)
+	if err != nil {
+		portal.log.Warnfln("Failed to re-upload media for %s after retry notification: %v", retry.MessageID, err)
+		return
+	}
+	replaceContent := &event.MessageEventContent{
+		MsgType:    meta.Content.MsgType,
+		Body:       "* " + meta.Content.Body,
+		NewContent: meta.Content,
+		RelatesTo: &event.RelatesTo{
+			EventID: msg.MXID,
+			Type:    event.RelReplace,
+		},
+	}
+	resp, err := portal.sendMessage(intent, meta.Type, replaceContent, meta.ExtraContent, time.Now().UnixMilli())
+	if err != nil {
+		portal.log.Warnfln("Failed to edit %s after retry notification for %s: %v", msg.MXID, retry.MessageID, err)
+		return
+	}
+	portal.log.Debugfln("Successfully edited %s -> %s after retry notification for %s", msg.MXID, resp.EventID, retry.MessageID)
+	msg.UpdateMXID(resp.EventID, database.MsgNoError)
+}
+
+func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID) {
+	msg := portal.bridge.DB.Message.GetByMXID(eventID)
+	if msg == nil {
+		portal.log.Debugfln("%s requested a media retry for unknown event %s", user.MXID, eventID)
+		return
+	} else if msg.Error != database.MsgErrMediaNotFound {
+		portal.log.Debugfln("%s requested a media retry for non-errored event %s", user.MXID, eventID)
+		return
+	}
+
+	evt, err := portal.fetchMediaRetryEvent(msg)
+	if err != nil {
+		portal.log.Warnfln("Can't send media retry request for %s: %v", msg.JID, err)
+		return
+	}
+
+	err = user.Client.SendMediaRetryReceipt(&types.MessageInfo{
+		ID: msg.JID,
+		MessageSource: types.MessageSource{
+			IsFromMe: msg.Sender.User == user.JID.User,
+			IsGroup:  !portal.IsPrivateChat(),
+			Sender:   msg.Sender,
+			Chat:     portal.Key.JID,
+		},
+	}, evt.Media.Key)
+	if err != nil {
+		portal.log.Warnfln("Failed to send media retry request for %s: %v", msg.JID, err)
+	} else {
+		portal.log.Debugfln("Sent media retry request for %s", msg.JID)
+	}
+}
+
 const thumbnailMaxSize = 72
 const thumbnailMinSize = 24
 
@@ -2441,7 +2640,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) {
 	}
 	portal.MarkDisappearing(evt.ID, portal.ExpirationTime, true)
 	info := portal.generateMessageInfo(sender)
-	dbMsg := portal.markHandled(nil, info, evt.ID, false, true, false)
+	dbMsg := portal.markHandled(nil, info, evt.ID, false, true, database.MsgNoError)
 	portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
 	ts, err := sender.Client.SendMessage(portal.Key.JID, info.ID, msg)
 	if err != nil {

+ 4 - 0
user.go

@@ -603,6 +603,10 @@ func (user *User) HandleEvent(event interface{}) {
 	case *events.Message:
 		portal := user.GetPortalByMessageSource(v.Info.MessageSource)
 		portal.messages <- PortalMessage{evt: v, source: user}
+	case *events.MediaRetry:
+		user.phoneSeen(v.Timestamp)
+		portal := user.GetPortalByJID(v.ChatID)
+		portal.mediaRetries <- PortalMediaRetry{evt: v, source: user}
 	case *events.CallOffer:
 		user.handleCallStart(v.CallCreator, v.CallID, "", v.Timestamp)
 	case *events.CallOfferNotice: