Kaynağa Gözat

Fix some bugs in missed message and initial history filling

Tulir Asokan 6 yıl önce
ebeveyn
işleme
14f96bd96f
3 değiştirilmiş dosya ile 48 ekleme ve 14 silme
  1. 4 2
      database/message.go
  2. 9 2
      portal.go
  3. 35 10
      user.go

+ 4 - 2
database/message.go

@@ -65,7 +65,7 @@ func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
 func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
 func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
 	msg := mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content " +
 	msg := mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content " +
 		"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 ORDER BY timestamp DESC LIMIT 1", chat.JID, chat.Receiver)
 		"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 ORDER BY timestamp DESC LIMIT 1", chat.JID, chat.Receiver)
-	if msg.Timestamp == 0 {
+	if msg == nil || msg.Timestamp == 0 {
 		// Old db, we don't know what the last message is.
 		// Old db, we don't know what the last message is.
 		return nil
 		return nil
 	}
 	}
@@ -128,7 +128,9 @@ func (msg *Message) encodeBinaryContent() []byte {
 }
 }
 
 
 func (msg *Message) Insert() {
 func (msg *Message) Insert() {
-	_, err := msg.db.Exec("INSERT INTO message VALUES ($1, $2, $3, $4, $5, $6)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.encodeBinaryContent())
+	_, err := msg.db.Exec("INSERT INTO message (chat_jid, chat_receiver, jid, mxid, sender, timestamp, content) " +
+		"VALUES ($1, $2, $3, $4, $5, $6, $7)",
+		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.Timestamp, msg.encodeBinaryContent())
 	if err != nil {
 	if err != nil {
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 	}
 	}

+ 9 - 2
portal.go

@@ -120,6 +120,7 @@ func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
 const recentlyHandledLength = 100
 const recentlyHandledLength = 100
 
 
 type PortalMessage struct {
 type PortalMessage struct {
+	chat      string
 	source    *User
 	source    *User
 	data      interface{}
 	data      interface{}
 	timestamp uint64
 	timestamp uint64
@@ -162,7 +163,9 @@ func (portal *Portal) handleMessageLoop() {
 				return
 				return
 			}
 			}
 		}
 		}
+		portal.backfillLock.Lock()
 		portal.handleMessage(msg)
 		portal.handleMessage(msg)
+		portal.backfillLock.Unlock()
 	}
 	}
 }
 }
 
 
@@ -531,7 +534,7 @@ func (portal *Portal) RestrictMetadataChanges(restrict bool) {
 	}
 	}
 }
 }
 
 
-func (portal *Portal) BackfillHistory(user *User) error {
+func (portal *Portal) BackfillHistory(user *User, lastMessageTime uint64) error {
 	if !portal.bridge.Config.Bridge.RecoverHistory {
 	if !portal.bridge.Config.Bridge.RecoverHistory {
 		return nil
 		return nil
 	}
 	}
@@ -541,6 +544,10 @@ func (portal *Portal) BackfillHistory(user *User) error {
 	if lastMessage == nil {
 	if lastMessage == nil {
 		return nil
 		return nil
 	}
 	}
+	if lastMessage.Timestamp <= lastMessageTime {
+		portal.log.Debugln("Not backfilling: no new messages")
+		return nil
+	}
 
 
 	lastMessageID := lastMessage.JID
 	lastMessageID := lastMessage.JID
 	portal.log.Infoln("Backfilling history since", lastMessageID, "for", user.MXID)
 	portal.log.Infoln("Backfilling history since", lastMessageID, "for", user.MXID)
@@ -619,7 +626,7 @@ func (portal *Portal) handleHistory(user *User, messages []interface{}) {
 			continue
 			continue
 		}
 		}
 		data := whatsapp.ParseProtoMessage(message)
 		data := whatsapp.ParseProtoMessage(message)
-		portal.handleMessage(PortalMessage{user, data, message.GetMessageTimestamp()})
+		portal.handleMessage(PortalMessage{portal.Key.JID, user, data, message.GetMessageTimestamp()})
 	}
 	}
 }
 }
 
 

+ 35 - 10
user.go

@@ -22,6 +22,7 @@ import (
 	"sort"
 	"sort"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
+	"sync"
 	"time"
 	"time"
 
 
 	"github.com/skip2/go-qrcode"
 	"github.com/skip2/go-qrcode"
@@ -50,6 +51,9 @@ type User struct {
 	Connected   bool
 	Connected   bool
 
 
 	ConnectionErrors int
 	ConnectionErrors int
+
+	messages chan PortalMessage
+	syncLock sync.Mutex
 }
 }
 
 
 func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User {
 func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User {
@@ -125,9 +129,12 @@ func (bridge *Bridge) NewUser(dbUser *database.User) *User {
 		User:   dbUser,
 		User:   dbUser,
 		bridge: bridge,
 		bridge: bridge,
 		log:    bridge.Log.Sub("User").Sub(string(dbUser.MXID)),
 		log:    bridge.Log.Sub("User").Sub(string(dbUser.MXID)),
+
+		messages: make(chan PortalMessage, 256),
 	}
 	}
 	user.Whitelisted = user.bridge.Config.Bridge.Permissions.IsWhitelisted(user.MXID)
 	user.Whitelisted = user.bridge.Config.Bridge.Permissions.IsWhitelisted(user.MXID)
 	user.Admin = user.bridge.Config.Bridge.Permissions.IsAdmin(user.MXID)
 	user.Admin = user.bridge.Config.Bridge.Permissions.IsAdmin(user.MXID)
+	go user.handleMessageLoop()
 	return user
 	return user
 }
 }
 
 
@@ -275,13 +282,15 @@ func (cl ChatList) Swap(i, j int) {
 }
 }
 
 
 func (user *User) PostLogin() {
 func (user *User) PostLogin() {
-	user.log.Debugln("Waiting for 3 seconds for contacts to arrive")
+	user.syncLock.Lock()
+	user.log.Debugln("Waiting a second for contacts to arrive")
 	// Hacky way to wait for chats and contacts to arrive automatically
 	// Hacky way to wait for chats and contacts to arrive automatically
-	time.Sleep(3 * time.Second)
+	time.Sleep(1 * time.Second)
 	user.log.Debugln("Waited 3 seconds:", len(user.Conn.Store.Chats), len(user.Conn.Store.Contacts))
 	user.log.Debugln("Waited 3 seconds:", len(user.Conn.Store.Chats), len(user.Conn.Store.Contacts))
 
 
-	go user.syncPortals()
 	go user.syncPuppets()
 	go user.syncPuppets()
+	user.syncPortals()
+	user.syncLock.Unlock()
 }
 }
 
 
 func (user *User) syncPortals() {
 func (user *User) syncPortals() {
@@ -307,7 +316,7 @@ func (user *User) syncPortals() {
 		create := (chat.LastMessageTime >= user.LastConnection && user.LastConnection > 0) || i < limit
 		create := (chat.LastMessageTime >= user.LastConnection && user.LastConnection > 0) || i < limit
 		if len(chat.Portal.MXID) > 0 || create {
 		if len(chat.Portal.MXID) > 0 || create {
 			chat.Portal.Sync(user, chat.Contact)
 			chat.Portal.Sync(user, chat.Contact)
-			err := chat.Portal.BackfillHistory(user)
+			err := chat.Portal.BackfillHistory(user, chat.LastMessageTime)
 			if err != nil {
 			if err != nil {
 				chat.Portal.log.Errorln("Error backfilling history:", err)
 				chat.Portal.log.Errorln("Error backfilling history:", err)
 			}
 			}
@@ -408,28 +417,44 @@ func (user *User) GetPortalByJID(jid types.WhatsAppID) *Portal {
 	return user.bridge.GetPortalByJID(user.PortalKey(jid))
 	return user.bridge.GetPortalByJID(user.PortalKey(jid))
 }
 }
 
 
+func (user *User) handleMessageLoop() {
+	for msg := range user.messages {
+		user.syncLock.Lock()
+		user.GetPortalByJID(msg.chat).messages <- msg
+		user.syncLock.Unlock()
+	}
+}
+
+func (user *User) putMessage(message PortalMessage) {
+	select {
+	case user.messages <- message:
+	default:
+		user.log.Warnln("Buffer is full, dropping message in", message.chat)
+	}
+}
+
 func (user *User) HandleTextMessage(message whatsapp.TextMessage) {
 func (user *User) HandleTextMessage(message whatsapp.TextMessage) {
-	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
+	user.putMessage(PortalMessage{message.Info.RemoteJid, user, message, message.Info.Timestamp})
 }
 }
 
 
 func (user *User) HandleImageMessage(message whatsapp.ImageMessage) {
 func (user *User) HandleImageMessage(message whatsapp.ImageMessage) {
-	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
+	user.putMessage(PortalMessage{message.Info.RemoteJid, user, message, message.Info.Timestamp})
 }
 }
 
 
 func (user *User) HandleVideoMessage(message whatsapp.VideoMessage) {
 func (user *User) HandleVideoMessage(message whatsapp.VideoMessage) {
-	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
+	user.putMessage(PortalMessage{message.Info.RemoteJid, user, message, message.Info.Timestamp})
 }
 }
 
 
 func (user *User) HandleAudioMessage(message whatsapp.AudioMessage) {
 func (user *User) HandleAudioMessage(message whatsapp.AudioMessage) {
-	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
+	user.putMessage(PortalMessage{message.Info.RemoteJid, user, message, message.Info.Timestamp})
 }
 }
 
 
 func (user *User) HandleDocumentMessage(message whatsapp.DocumentMessage) {
 func (user *User) HandleDocumentMessage(message whatsapp.DocumentMessage) {
-	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
+	user.putMessage(PortalMessage{message.Info.RemoteJid, user, message, message.Info.Timestamp})
 }
 }
 
 
 func (user *User) HandleMessageRevoke(message whatsappExt.MessageRevocation) {
 func (user *User) HandleMessageRevoke(message whatsappExt.MessageRevocation) {
-	user.GetPortalByJID(message.RemoteJid).messages <- PortalMessage{user, message, 0}
+	user.putMessage(PortalMessage{message.RemoteJid, user, message, 0})
 }
 }
 
 
 func (user *User) HandlePresence(info whatsappExt.Presence) {
 func (user *User) HandlePresence(info whatsappExt.Presence) {