Răsfoiți Sursa

Possibly significantly improve how portals are created and synced

Tulir Asokan 6 ani în urmă
părinte
comite
b363547bdf

+ 1 - 0
commands.go

@@ -208,6 +208,7 @@ func (handler *CommandHandler) CommandReconnect(ce *CommandEvent) {
 	ce.User.Connected = true
 	ce.User.ConnectionErrors = 0
 	ce.Reply("Reconnected successfully.")
+	go ce.User.PostLogin()
 }
 
 const cmdDisconnectHelp = `disconnect - Disconnect from WhatsApp (without logging out)`

+ 10 - 0
config/bridge.go

@@ -37,6 +37,11 @@ type BridgeConfig struct {
 	MaxConnectionAttempts int  `yaml:"max_connection_attempts"`
 	ReportConnectionRetry bool `yaml:"report_connection_retry"`
 
+	InitialChatSync    int  `yaml:"initial_chat_sync_count"`
+	InitialHistoryFill int  `yaml:"initial_history_fill_count"`
+	RecoverChatSync    int  `yaml:"recovery_chat_sync_count"`
+	RecoverHistory     bool `yaml:"recovery_history_backfill"`
+
 	CommandPrefix string `yaml:"command_prefix"`
 
 	Permissions PermissionConfig `yaml:"permissions"`
@@ -49,6 +54,11 @@ func (bc *BridgeConfig) setDefaults() {
 	bc.ConnectionTimeout = 20
 	bc.MaxConnectionAttempts = 3
 	bc.ReportConnectionRetry = true
+
+	bc.InitialChatSync = 10
+	bc.InitialHistoryFill = 20
+	bc.RecoverChatSync = -1
+	bc.RecoverHistory = true
 }
 
 type umBridgeConfig BridgeConfig

+ 21 - 8
database/message.go

@@ -53,11 +53,23 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
 }
 
 func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message {
-	return mq.get("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
+	return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content " +
+		"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
 }
 
 func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
-	return mq.get("SELECT * FROM message WHERE mxid=$1", mxid)
+	return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content " +
+		"FROM message WHERE mxid=$1", mxid)
+}
+
+func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
+	msg := mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content " +
+		"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 ORDER BY timestamp DESC LIMIT 1", chat.JID, chat.Receiver)
+	if msg.Timestamp == 0 {
+		// Old db, we don't know what the last message is.
+		return nil
+	}
+	return msg
 }
 
 func (mq *MessageQuery) get(query string, args ...interface{}) *Message {
@@ -72,16 +84,17 @@ type Message struct {
 	db  *Database
 	log log.Logger
 
-	Chat    PortalKey
-	JID     types.WhatsAppMessageID
-	MXID    types.MatrixEventID
-	Sender  types.WhatsAppID
-	Content *waProto.Message
+	Chat      PortalKey
+	JID       types.WhatsAppMessageID
+	MXID      types.MatrixEventID
+	Sender    types.WhatsAppID
+	Timestamp uint64
+	Content   *waProto.Message
 }
 
 func (msg *Message) Scan(row Scannable) *Message {
 	var content []byte
-	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &content)
+	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.Timestamp, &content)
 	if err != nil {
 		if err != sql.ErrNoRows {
 			msg.log.Errorln("Database scan failed:", err)

+ 15 - 0
database/upgrades/2019-05-21-message-timestamp-column.go

@@ -0,0 +1,15 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[2] = upgrade{"Add timestamp column to messages", func(dialect Dialect, tx *sql.Tx) error {
+		_, err := tx.Exec("ALTER TABLE message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0")
+		if err != nil {
+			return err
+		}
+		return nil
+	}}
+}

+ 15 - 0
database/upgrades/2019-05-22-user-last-connection-column.go

@@ -0,0 +1,15 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[3] = upgrade{"Add last_connection column to users", func(dialect Dialect, tx *sql.Tx) error {
+		_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN last_connection BIGINT NOT NULL DEFAULT 0`)
+		if err != nil {
+			return err
+		}
+		return nil
+	}}
+}

+ 3 - 3
database/upgrades/upgrades.go

@@ -22,7 +22,7 @@ type upgrade struct {
 	fn upgradeFunc
 }
 
-var upgrades [2]upgrade
+var upgrades [4]upgrade
 
 func getVersion(dialect Dialect, db *sql.DB) (int, error) {
 	_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
@@ -65,7 +65,7 @@ func Run(log log.Logger, dialectName string, db *sql.DB) error {
 
 	log.Infofln("Database currently on v%d, latest: v%d", version, len(upgrades))
 	for i, upgrade := range upgrades[version:] {
-		log.Infofln("Upgrading database to v%d: %s", i+1, upgrade.message)
+		log.Infofln("Upgrading database to v%d: %s", version+i+1, upgrade.message)
 		tx, err := db.Begin()
 		if err != nil {
 			return err
@@ -74,7 +74,7 @@ func Run(log log.Logger, dialectName string, db *sql.DB) error {
 		if err != nil {
 			return err
 		}
-		err = setVersion(dialect, tx, i+1)
+		err = setVersion(dialect, tx, version+i+1)
 		if err != nil {
 			return err
 		}

+ 18 - 5
database/user.go

@@ -19,6 +19,7 @@ package database
 import (
 	"database/sql"
 	"strings"
+	"time"
 
 	"github.com/Rhymen/go-whatsapp"
 
@@ -76,12 +77,14 @@ type User struct {
 	JID            types.WhatsAppID
 	ManagementRoom types.MatrixRoomID
 	Session        *whatsapp.Session
+	LastConnection uint64
 }
 
 func (user *User) Scan(row Scannable) *User {
 	var jid, clientID, clientToken, serverToken sql.NullString
 	var encKey, macKey []byte
-	err := row.Scan(&user.MXID, &jid, &user.ManagementRoom, &clientID, &clientToken, &serverToken, &encKey, &macKey)
+	err := row.Scan(&user.MXID, &jid, &user.ManagementRoom, &clientID, &clientToken, &serverToken, &encKey, &macKey,
+		&user.LastConnection)
 	if err != nil {
 		if err != sql.ErrNoRows {
 			user.log.Errorln("Database scan failed:", err)
@@ -134,18 +137,28 @@ func (user *User) sessionUnptr() (sess whatsapp.Session) {
 
 func (user *User) Insert() {
 	sess := user.sessionUnptr()
-	_, err := user.db.Exec(`INSERT INTO "user" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, user.MXID, user.jidPtr(),
-		user.ManagementRoom,
+	_, err := user.db.Exec(`INSERT INTO "user" (mxid, jid, management_room, last_connection, client_id, client_token, server_token, enc_key, mac_key) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
+		user.MXID, user.jidPtr(),
+		user.ManagementRoom, user.LastConnection,
 		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey)
 	if err != nil {
 		user.log.Warnfln("Failed to insert %s: %v", user.MXID, err)
 	}
 }
 
+func (user *User) UpdateLastConnection() {
+	user.LastConnection = uint64(time.Now().Unix())
+	_, err := user.db.Exec(`UPDATE "user" SET last_connection=$1 WHERE mxid=$2`,
+		user.LastConnection, user.MXID)
+	if err != nil {
+		user.log.Warnfln("Failed to update last connection ts: %v", err)
+	}
+}
+
 func (user *User) Update() {
 	sess := user.sessionUnptr()
-	_, err := user.db.Exec(`UPDATE "user" SET jid=$1, management_room=$2, client_id=$3, client_token=$4, server_token=$5, enc_key=$6, mac_key=$7 WHERE mxid=$8`,
-		user.jidPtr(), user.ManagementRoom,
+	_, err := user.db.Exec(`UPDATE "user" SET jid=$1, management_room=$2, last_connection=$3, client_id=$4, client_token=$5, server_token=$6, enc_key=$7, mac_key=$8 WHERE mxid=$9`,
+		user.jidPtr(), user.ManagementRoom, user.LastConnection,
 		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey,
 		user.MXID)
 	if err != nil {

+ 10 - 0
example-config.yaml

@@ -66,6 +66,16 @@ bridge:
     # If false, it will only report when it stops retrying.
     report_connection_retry: true
 
+    # Number of chats to sync for new users.
+    initial_chat_sync_count: 10
+    # Number of old messages to fill when creating new portal rooms.
+    initial_history_fill_count: 20
+    # Maximum number of chats to sync when recovering from downtime.
+    # Set to -1 to sync all new chats during downtime.
+    recovery_chat_sync_limit: -1
+    # Whether or not to sync history when recovering from downtime.
+    recovery_history_backfill: true
+
     # The prefix for commands. Only required in non-management rooms.
     command_prefix: "!wa"
 

+ 58 - 10
portal.go

@@ -30,8 +30,10 @@ import (
 	"net/http"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/Rhymen/go-whatsapp"
+	"github.com/Rhymen/go-whatsapp/binary"
 	waProto "github.com/Rhymen/go-whatsapp/binary/proto"
 
 	log "maunium.net/go/maulogger/v2"
@@ -119,8 +121,9 @@ func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
 const recentlyHandledLength = 100
 
 type PortalMessage struct {
-	source *User
-	data   interface{}
+	source    *User
+	data      interface{}
+	timestamp uint64
 }
 
 type Portal struct {
@@ -137,6 +140,7 @@ type Portal struct {
 	recentlyHandledLock  sync.Mutex
 	recentlyHandledIndex uint8
 
+	backfillLock  sync.Mutex
 	lastMessageTs uint64
 
 	messages chan PortalMessage
@@ -144,11 +148,13 @@ type Portal struct {
 	isPrivate *bool
 }
 
+const MaxMessageAgeToCreatePortal = 5 * 60 // 5 minutes
+
 func (portal *Portal) handleMessageLoop() {
 	for msg := range portal.messages {
 		if len(portal.MXID) == 0 {
-			_, isRevocation := msg.data.(whatsappExt.MessageRevocation)
-			if isRevocation {
+			if msg.timestamp+MaxMessageAgeToCreatePortal < uint64(time.Now().Unix()) {
+				portal.log.Debugln("Not creating portal room for incoming message as the message is too old.")
 				continue
 			}
 			err := portal.CreateMatrixRoom(msg.source)
@@ -221,6 +227,7 @@ func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo,
 	msg.Chat = portal.Key
 	msg.JID = message.GetKey().GetId()
 	msg.MXID = mxid
+	msg.Timestamp = message.GetMessageTimestamp()
 	if message.GetKey().GetFromMe() {
 		msg.Sender = source.JID
 	} else if portal.IsPrivateChat() {
@@ -414,6 +421,7 @@ func (portal *Portal) Sync(user *User, contact whatsapp.Contact) {
 	if portal.IsPrivateChat() {
 		return
 	}
+	portal.log.Infoln("Syncing portal for", user.MXID)
 
 	if len(portal.MXID) == 0 {
 		portal.Name = contact.Name
@@ -524,15 +532,52 @@ func (portal *Portal) RestrictMetadataChanges(restrict bool) {
 	}
 }
 
-func (portal *Portal) FillHistory(user *User) error {
-	resp, err := user.Conn.LoadMessages(portal.Key.JID, "", 50)
+func (portal *Portal) BackfillHistory(user *User) error {
+	if !portal.bridge.Config.Bridge.RecoverHistory {
+		return nil
+	}
+	portal.backfillLock.Lock()
+	defer portal.backfillLock.Unlock()
+	lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key)
+	if lastMessage == nil {
+		return nil
+	}
+
+	lastMessageID := lastMessage.JID
+	portal.log.Infoln("Backfilling history since", lastMessageID, "for", user.MXID)
+	for len(lastMessageID) > 0 {
+		portal.log.Debugln("Backfilling history: 50 messages after", lastMessageID)
+		resp, err := user.Conn.LoadMessagesAfter(portal.Key.JID, lastMessageID, 50)
+		if err != nil {
+			return err
+		}
+		lastMessageID, err = portal.handleHistory(user, resp)
+		if err != nil {
+			return err
+		}
+	}
+	portal.log.Infoln("Backfilling finished")
+	return nil
+}
+
+func (portal *Portal) FillInitialHistory(user *User) error {
+	if portal.bridge.Config.Bridge.InitialHistoryFill == 0 {
+		return nil
+	}
+	resp, err := user.Conn.LoadMessages(portal.Key.JID, "", portal.bridge.Config.Bridge.InitialHistoryFill)
 	if err != nil {
 		return err
 	}
-	messages, ok := resp.Content.([]interface{})
+	_, err = portal.handleHistory(user, resp)
+	return err
+}
+
+func (portal *Portal) handleHistory(user *User, history *binary.Node) (string, error) {
+	messages, ok := history.Content.([]interface{})
 	if !ok {
-		return fmt.Errorf("history response not list")
+		return "", fmt.Errorf("history response not a list")
 	}
+	lastID := ""
 	for _, rawMessage := range messages {
 		message, ok := rawMessage.(*waProto.WebMessageInfo)
 		if !ok {
@@ -541,8 +586,9 @@ func (portal *Portal) FillHistory(user *User) error {
 		}
 		fmt.Println("Filling history", message.GetKey(), message.GetMessageTimestamp())
 		portal.handleMessage(PortalMessage{user, whatsapp.ParseProtoMessage(message)})
+		lastID = message.GetKey().GetId()
 	}
-	return nil
+	return lastID, nil
 }
 
 func (portal *Portal) CreateMatrixRoom(user *User) error {
@@ -557,6 +603,8 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
 		return err
 	}
 
+	portal.log.Infoln("Creating Matrix room. Info source:", user.MXID)
+
 	isPrivateChat := false
 	if portal.IsPrivateChat() {
 		portal.Name = ""
@@ -592,7 +640,7 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
 	}
 	portal.MXID = resp.RoomID
 	portal.Update()
-	err = portal.FillHistory(user)
+	err = portal.FillInitialHistory(user)
 	if err != nil {
 		portal.log.Errorln("Failed to fill history:", err)
 	}

+ 97 - 7
user.go

@@ -19,6 +19,8 @@ package main
 import (
 	"encoding/json"
 	"fmt"
+	"sort"
+	"strconv"
 	"strings"
 	"time"
 
@@ -29,6 +31,7 @@ import (
 	"maunium.net/go/mautrix/format"
 
 	"github.com/Rhymen/go-whatsapp"
+	waProto "github.com/Rhymen/go-whatsapp/binary/proto"
 
 	"maunium.net/go/mautrix-whatsapp/database"
 	"maunium.net/go/mautrix-whatsapp/types"
@@ -142,6 +145,9 @@ func (user *User) SetManagementRoom(roomID types.MatrixRoomID) {
 
 func (user *User) SetSession(session *whatsapp.Session) {
 	user.Session = session
+	if session == nil {
+		user.LastConnection = 0
+	}
 	user.Update()
 }
 
@@ -188,6 +194,7 @@ func (user *User) RestoreSession() bool {
 		user.ConnectionErrors = 0
 		user.SetSession(&sess)
 		user.log.Debugln("Session restored successfully")
+		go user.PostLogin()
 	}
 	return true
 }
@@ -243,7 +250,84 @@ func (user *User) Login(ce *CommandEvent) {
 	user.ConnectionErrors = 0
 	user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
 	user.SetSession(&session)
-	ce.Reply("Successfully logged in. Now, you may ask for `sync [--create]`.")
+	ce.Reply("Successfully logged in, synchronizing chats...")
+	go user.PostLogin()
+}
+
+type Chat struct {
+	Portal          *Portal
+	LastMessageTime uint64
+	Contact         whatsapp.Contact
+}
+
+type ChatList []Chat
+
+func (cl ChatList) Len() int {
+	return len(cl)
+}
+
+func (cl ChatList) Less(i, j int) bool {
+	return cl[i].LastMessageTime < cl[i].LastMessageTime
+}
+
+func (cl ChatList) Swap(i, j int) {
+	cl[i], cl[j] = cl[j], cl[i]
+}
+
+func (user *User) PostLogin() {
+	user.log.Debugln("Waiting for 3 seconds for contacts to arrive")
+	// Hacky way to wait for chats and contacts to arrive automatically
+	time.Sleep(3 * time.Second)
+	user.log.Debugln("Waited 3 seconds:", len(user.Conn.Store.Chats), len(user.Conn.Store.Contacts))
+
+	go user.syncPortals()
+	go user.syncPuppets()
+}
+
+func (user *User) syncPortals() {
+	var chats ChatList
+	for _, chat := range user.Conn.Store.Chats {
+		ts, err := strconv.ParseUint(chat.LastMessageTime, 10, 64)
+		if err != nil {
+			user.log.Warnfln("Non-integer last message time in %s: %s", chat.Jid, chat.LastMessageTime)
+			continue
+		}
+		chats = append(chats, Chat{
+			Portal:          user.GetPortalByJID(chat.Jid),
+			Contact:         user.Conn.Store.Contacts[chat.Jid],
+			LastMessageTime: ts,
+		})
+	}
+	sort.Sort(chats)
+	limit := user.bridge.Config.Bridge.InitialChatSync
+	if limit < 0 {
+		limit = len(chats)
+	}
+	for i, chat := range chats {
+		create := (chat.LastMessageTime >= user.LastConnection && user.LastConnection > 0) || i < limit
+		if len(chat.Portal.MXID) > 0 || create {
+			chat.Portal.Sync(user, chat.Contact)
+			err := chat.Portal.BackfillHistory(user)
+			if err != nil {
+				chat.Portal.log.Errorln("Error backfilling history:", err)
+			}
+		}
+	}
+}
+
+func (user *User) syncPuppets() {
+	for jid, contact := range user.Conn.Store.Contacts {
+		if strings.HasSuffix(jid, whatsappExt.NewUserSuffix) {
+			puppet := user.bridge.GetPuppetByJID(contact.Jid)
+			puppet.Sync(user, contact)
+		}
+	}
+}
+
+func (user *User) updateLastConnectionIfNecessary() {
+	if user.LastConnection+60 < uint64(time.Now().Unix()) {
+		user.UpdateLastConnection()
+	}
 }
 
 func (user *User) HandleError(err error) {
@@ -282,6 +366,7 @@ func (user *User) HandleError(err error) {
 			user.ConnectionErrors = 0
 			user.Connected = true
 			_, _ = user.bridge.Bot.SendNotice(user.ManagementRoom, "Reconnected successfully")
+			go user.PostLogin()
 			return
 		}
 		user.log.Errorln("Error while trying to reconnect after disconnection:", err)
@@ -324,27 +409,27 @@ func (user *User) GetPortalByJID(jid types.WhatsAppID) *Portal {
 }
 
 func (user *User) HandleTextMessage(message whatsapp.TextMessage) {
-	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message}
+	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
 }
 
 func (user *User) HandleImageMessage(message whatsapp.ImageMessage) {
-	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message}
+	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
 }
 
 func (user *User) HandleVideoMessage(message whatsapp.VideoMessage) {
-	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message}
+	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
 }
 
 func (user *User) HandleAudioMessage(message whatsapp.AudioMessage) {
-	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message}
+	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
 }
 
 func (user *User) HandleDocumentMessage(message whatsapp.DocumentMessage) {
-	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message}
+	user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
 }
 
 func (user *User) HandleMessageRevoke(message whatsappExt.MessageRevocation) {
-	user.GetPortalByJID(message.RemoteJid).messages <- PortalMessage{user, message}
+	user.GetPortalByJID(message.RemoteJid).messages <- PortalMessage{user, message, 0}
 }
 
 func (user *User) HandlePresence(info whatsappExt.Presence) {
@@ -457,4 +542,9 @@ func (user *User) HandleJsonMessage(message string) {
 		return
 	}
 	user.log.Debugln("JSON message:", message)
+	user.updateLastConnectionIfNecessary()
+}
+
+func (user *User) HandleRawMessage(message *waProto.WebMessageInfo) {
+	user.updateLastConnectionIfNecessary()
 }