Forráskód Böngészése

Handle 404 errors when backfilling messages

Tulir Asokan 4 éve
szülő
commit
ca118e8678

+ 20 - 9
database/message.go

@@ -43,7 +43,7 @@ func (mq *MessageQuery) New() *Message {
 }
 
 func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
-	rows, err := mq.db.Query("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver)
+	rows, err := mq.db.Query("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, content FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver)
 	if err != nil || rows == nil {
 		return nil
 	}
@@ -55,18 +55,19 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
 }
 
 func (mq *MessageQuery) GetByJID(chat PortalKey, jid whatsapp.MessageID) *Message {
-	return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content "+
+	return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, content "+
 		"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
 }
 
 func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message {
-	return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content "+
+	return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, content "+
 		"FROM message WHERE mxid=$1", mxid)
 }
 
 func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
-	msg := mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content "+
-		"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 ORDER BY timestamp DESC LIMIT 1", chat.JID, chat.Receiver)
+	msg := mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, content "+
+		"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp DESC LIMIT 1",
+		chat.JID, chat.Receiver)
 	if msg == nil || msg.Timestamp == 0 {
 		// Old db, we don't know what the last message is.
 		return nil
@@ -91,6 +92,7 @@ type Message struct {
 	MXID      id.EventID
 	Sender    whatsapp.JID
 	Timestamp uint64
+	Sent      bool
 	Content   *waProto.Message
 }
 
@@ -100,7 +102,7 @@ func (msg *Message) IsFakeMXID() bool {
 
 func (msg *Message) Scan(row Scannable) *Message {
 	var content []byte
-	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.Timestamp, &content)
+	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.Timestamp, &msg.Sent, &content)
 	if err != nil {
 		if err != sql.ErrNoRows {
 			msg.log.Errorln("Database scan failed:", err)
@@ -134,14 +136,23 @@ func (msg *Message) encodeBinaryContent() []byte {
 }
 
 func (msg *Message) Insert() {
-	_, err := msg.db.Exec("INSERT INTO message (chat_jid, chat_receiver, jid, mxid, sender, timestamp, content) "+
-		"VALUES ($1, $2, $3, $4, $5, $6, $7)",
-		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.Timestamp, msg.encodeBinaryContent())
+	_, err := msg.db.Exec(`INSERT INTO message
+			(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, content)
+			VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
+		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.Timestamp, msg.Sent, msg.encodeBinaryContent())
 	if err != nil {
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 	}
 }
 
+func (msg *Message) MarkSent() {
+	msg.Sent = true
+	_, err := msg.db.Exec("UPDATE message SET sent=true WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID)
+	if err != nil {
+		msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
+	}
+}
+
 func (msg *Message) Delete() {
 	_, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID)
 	if err != nil {

+ 12 - 0
database/upgrades/2021-02-17-message-sent-status.go

@@ -0,0 +1,12 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[20] = upgrade{"Add sent column for messages", func(tx *sql.Tx, ctx context) error {
+		_, err := tx.Exec(`ALTER TABLE message ADD COLUMN sent BOOLEAN NOT NULL DEFAULT true`)
+		return err
+	}}
+}

+ 1 - 1
database/upgrades/upgrades.go

@@ -39,7 +39,7 @@ type upgrade struct {
 	fn      upgradeFunc
 }
 
-const NumberOfUpgrades = 20
+const NumberOfUpgrades = 21
 
 var upgrades [NumberOfUpgrades]upgrade
 

+ 14 - 13
portal.go

@@ -283,7 +283,7 @@ func init() {
 	gob.Register(&waProto.Message{})
 }
 
-func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo, mxid id.EventID) {
+func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo, mxid id.EventID, isSent bool) *database.Message {
 	msg := portal.bridge.DB.Message.New()
 	msg.Chat = portal.Key
 	msg.JID = message.GetKey().GetId()
@@ -300,6 +300,7 @@ func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo,
 		}
 	}
 	msg.Content = message.Message
+	msg.Sent = isSent
 	msg.Insert()
 
 	portal.recentlyHandledLock.Lock()
@@ -307,6 +308,7 @@ func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo,
 	portal.recentlyHandledIndex = (portal.recentlyHandledIndex + 1) % recentlyHandledLength
 	portal.recentlyHandledLock.Unlock()
 	portal.recentlyHandled[index] = msg.JID
+	return msg
 }
 
 func (portal *Portal) getMessageIntent(user *User, info whatsapp.MessageInfo) *appservice.IntentAPI {
@@ -346,7 +348,7 @@ func (portal *Portal) startHandling(source *User, info whatsapp.MessageInfo) *ap
 }
 
 func (portal *Portal) finishHandling(source *User, message *waProto.WebMessageInfo, mxid id.EventID) {
-	portal.markHandled(source, message, mxid)
+	portal.markHandled(source, message, mxid, true)
 	portal.sendDeliveryReceipt(mxid)
 	portal.log.Debugln("Handled message", message.GetKey().GetId(), "->", mxid)
 }
@@ -735,6 +737,10 @@ func (portal *Portal) BackfillHistory(user *User, lastMessageTime uint64) error
 	for len(lastMessageID) > 0 {
 		portal.log.Debugln("Fetching 50 messages of history after", lastMessageID)
 		resp, err := user.Conn.LoadMessagesAfter(portal.Key.JID, lastMessageID, lastMessageFromMe, 50)
+		if err == whatsapp.ErrServerRespondedWith404 {
+			portal.log.Warnln("Got 404 response trying to fetch messages to backfill. Fetching latest messages as fallback.")
+			resp, err = user.Conn.LoadMessagesBefore(portal.Key.JID, "", true, 50)
+		}
 		if err != nil {
 			return err
 		}
@@ -1322,7 +1328,7 @@ func (portal *Portal) HandleStubMessage(source *User, message whatsapp.StubMessa
 	if len(eventID) == 0 {
 		eventID = id.EventID(fmt.Sprintf("net.maunium.whatsapp.fake::%s", message.Info.Id))
 	}
-	portal.markHandled(source, message.Info.Source, eventID)
+	portal.markHandled(source, message.Info.Source, eventID, true)
 }
 
 func (portal *Portal) HandleLocationMessage(source *User, message whatsapp.LocationMessage) {
@@ -2087,12 +2093,12 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) {
 	if info == nil {
 		return
 	}
-	portal.markHandled(sender, info, evt.ID)
+	dbMsg := portal.markHandled(sender, info, evt.ID, false)
 	portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.Key.GetId())
-	portal.sendRaw(sender, evt, info)
+	portal.sendRaw(sender, evt, info, dbMsg)
 }
 
-func (portal *Portal) sendRaw(sender *User, evt *event.Event, info *waProto.WebMessageInfo) {
+func (portal *Portal) sendRaw(sender *User, evt *event.Event, info *waProto.WebMessageInfo, dbMsg *database.Message) {
 	errChan := make(chan error, 1)
 	go sender.Conn.SendRaw(info, errChan)
 
@@ -2112,16 +2118,11 @@ func (portal *Portal) sendRaw(sender *User, evt *event.Event, info *waProto.WebM
 	}
 	if err != nil {
 		portal.log.Errorfln("Error handling Matrix event %s: %v", evt.ID, err)
-		var statusResp whatsapp.StatusResponse
-		if errors.As(err, &statusResp) && statusResp.Status == 599 {
-			portal.log.Debugfln("599 status response extra data: %+v", statusResp.Extra)
-			portal.sendErrorMessage(fmt.Sprintf("%v. Please try again after a few minutes", err))
-		} else {
-			portal.sendErrorMessage(err.Error())
-		}
+		portal.sendErrorMessage(err.Error())
 	} else {
 		portal.log.Debugfln("Handled Matrix event %s", evt.ID)
 		portal.sendDeliveryReceipt(evt.ID)
+		dbMsg.MarkSent()
 	}
 	if errorEventID != "" {
 		_, err = portal.MainIntent().RedactEvent(portal.MXID, errorEventID)