Selaa lähdekoodia

Merge pull request #494 from mautrix/sumner/bri-3309

backfill: remove intermediate queues and unblock re-checking DB
Sumner Evans 3 vuotta sitten
vanhempi
sitoutus
5b73ba7efd

+ 54 - 24
backfillqueue.go

@@ -20,43 +20,73 @@ import (
 	"time"
 
 	log "maunium.net/go/maulogger/v2"
+	"maunium.net/go/mautrix/id"
 
 	"maunium.net/go/mautrix-whatsapp/database"
 )
 
 type BackfillQueue struct {
-	BackfillQuery             *database.BackfillQuery
-	ImmediateBackfillRequests chan *database.Backfill
-	DeferredBackfillRequests  chan *database.Backfill
-	ReCheckQueue              chan bool
+	BackfillQuery   *database.BackfillQuery
+	reCheckChannels []chan bool
+	log             log.Logger
+}
 
-	log log.Logger
+func (bq *BackfillQueue) ReCheck() {
+	bq.log.Info("Sending re-checks to %d channels", len(bq.reCheckChannels))
+	for _, channel := range bq.reCheckChannels {
+		go func(c chan bool) {
+			c <- true
+		}(channel)
+	}
 }
 
-// RunLoop fetches backfills from the database, prioritizing immediate and forward backfills
-func (bq *BackfillQueue) RunLoop(user *User) {
+func (bq *BackfillQueue) GetNextBackfill(userID id.UserID, backfillTypes []database.BackfillType, reCheckChannel chan bool) *database.Backfill {
 	for {
-		if backfill := bq.BackfillQuery.GetNext(user.MXID); backfill != nil {
-			if backfill.BackfillType == database.BackfillImmediate || backfill.BackfillType == database.BackfillForward {
-				bq.ImmediateBackfillRequests <- backfill
-			} else if backfill.BackfillType == database.BackfillDeferred {
-				select {
-				case <-bq.ReCheckQueue:
-					// If a queue re-check is requested, interrupt sending the
-					// backfill request to the deferred channel so that
-					// immediate backfills can happen ASAP.
-					continue
-				case bq.DeferredBackfillRequests <- backfill:
-				}
-			} else {
-				bq.log.Debugfln("Unrecognized backfill type %d in queue", backfill.BackfillType)
-			}
-			backfill.MarkDone()
+		if backfill := bq.BackfillQuery.GetNext(userID, backfillTypes); backfill != nil {
+			backfill.MarkDispatched()
+			return backfill
 		} else {
 			select {
-			case <-bq.ReCheckQueue:
+			case <-reCheckChannel:
 			case <-time.After(time.Minute):
 			}
 		}
 	}
 }
+
+func (user *User) HandleBackfillRequestsLoop(backfillTypes []database.BackfillType) {
+	reCheckChannel := make(chan bool)
+	user.BackfillQueue.reCheckChannels = append(user.BackfillQueue.reCheckChannels, reCheckChannel)
+
+	for {
+		req := user.BackfillQueue.GetNextBackfill(user.MXID, backfillTypes, reCheckChannel)
+		user.log.Infofln("Handling backfill request %s", req)
+
+		conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, req.Portal)
+		if conv == nil {
+			user.log.Debugfln("Could not find history sync conversation data for %s", req.Portal.String())
+			req.MarkDone()
+			continue
+		}
+		portal := user.GetPortalByJID(conv.PortalKey.JID)
+
+		// Update the client store with basic chat settings.
+		if conv.MuteEndTime.After(time.Now()) {
+			user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime)
+		}
+		if conv.Archived {
+			user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true)
+		}
+		if conv.Pinned > 0 {
+			user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true)
+		}
+
+		if conv.EphemeralExpiration != nil && portal.ExpirationTime != *conv.EphemeralExpiration {
+			portal.ExpirationTime = *conv.EphemeralExpiration
+			portal.Update(nil)
+		}
+
+		user.backfillInChunks(req, conv, portal)
+		req.MarkDone()
+	}
+}

+ 5 - 5
commands.go

@@ -223,7 +223,7 @@ func (handler *CommandHandler) CommandSetRelay(ce *CommandEvent) {
 		ce.Reply("Only admins are allowed to enable relay mode on this instance of the bridge")
 	} else {
 		ce.Portal.RelayUserID = ce.User.MXID
-		ce.Portal.Update()
+		ce.Portal.Update(nil)
 		ce.Reply("Messages from non-logged-in users in this room will now be bridged through your WhatsApp account")
 	}
 }
@@ -239,7 +239,7 @@ func (handler *CommandHandler) CommandUnsetRelay(ce *CommandEvent) {
 		ce.Reply("Only admins are allowed to enable relay mode on this instance of the bridge")
 	} else {
 		ce.Portal.RelayUserID = ""
-		ce.Portal.Update()
+		ce.Portal.Update(nil)
 		ce.Reply("Messages from non-logged-in users will no longer be bridged in this room")
 	}
 }
@@ -461,7 +461,7 @@ func (handler *CommandHandler) CommandCreate(ce *CommandEvent) {
 		portal.Encrypted = true
 	}
 
-	portal.Update()
+	portal.Update(nil)
 	portal.UpdateBridgeInfo()
 
 	ce.Reply("Successfully created WhatsApp group %s", portal.Key.JID)
@@ -888,10 +888,10 @@ func (handler *CommandHandler) CommandBackfill(ce *CommandEvent) {
 			return
 		}
 	}
-	backfillMessages := ce.Portal.bridge.DB.Backfill.NewWithValues(ce.User.MXID, database.BackfillImmediate, 0, &ce.Portal.Key, nil, nil, batchSize, -1, batchDelay)
+	backfillMessages := ce.Portal.bridge.DB.Backfill.NewWithValues(ce.User.MXID, database.BackfillImmediate, 0, &ce.Portal.Key, nil, batchSize, -1, batchDelay)
 	backfillMessages.Insert()
 
-	ce.User.BackfillQueue.ReCheckQueue <- true
+	ce.User.BackfillQueue.ReCheck()
 }
 
 const cmdListHelp = `list <contacts|groups> [page] [items per page] - Get a list of all contacts and groups.`

+ 31 - 14
database/backfillqueue.go

@@ -20,6 +20,8 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
+	"strconv"
+	"strings"
 	"time"
 
 	log "maunium.net/go/maulogger/v2"
@@ -59,7 +61,7 @@ func (bq *BackfillQuery) New() *Backfill {
 	}
 }
 
-func (bq *BackfillQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal *PortalKey, timeStart *time.Time, timeEnd *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *Backfill {
+func (bq *BackfillQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal *PortalKey, timeStart *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *Backfill {
 	return &Backfill{
 		db:             bq.db,
 		log:            bq.log,
@@ -68,7 +70,6 @@ func (bq *BackfillQuery) NewWithValues(userID id.UserID, backfillType BackfillTy
 		Priority:       priority,
 		Portal:         portal,
 		TimeStart:      timeStart,
-		TimeEnd:        timeEnd,
 		MaxBatchEvents: maxBatchEvents,
 		MaxTotalEvents: maxTotalEvents,
 		BatchDelay:     batchDelay,
@@ -77,23 +78,28 @@ func (bq *BackfillQuery) NewWithValues(userID id.UserID, backfillType BackfillTy
 
 const (
 	getNextBackfillQuery = `
-		SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, time_end, max_batch_events, max_total_events, batch_delay
+		SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay
 		FROM backfill_queue
 		WHERE user_mxid=$1
-			AND completed_at IS NULL
+			AND type IN (%s)
+			AND dispatch_time IS NULL
 		ORDER BY type, priority, queue_id
 		LIMIT 1
 	`
 )
 
 // GetNext returns the next backfill to perform
-func (bq *BackfillQuery) GetNext(userID id.UserID) (backfill *Backfill) {
-	rows, err := bq.db.Query(getNextBackfillQuery, userID)
-	defer rows.Close()
+func (bq *BackfillQuery) GetNext(userID id.UserID, backfillTypes []BackfillType) (backfill *Backfill) {
+	types := []string{}
+	for _, backfillType := range backfillTypes {
+		types = append(types, strconv.Itoa(int(backfillType)))
+	}
+	rows, err := bq.db.Query(fmt.Sprintf(getNextBackfillQuery, strings.Join(types, ",")), userID)
 	if err != nil || rows == nil {
 		bq.log.Error(err)
 		return
 	}
+	defer rows.Close()
 	if rows.Next() {
 		backfill = bq.New().Scan(rows)
 	}
@@ -126,21 +132,21 @@ type Backfill struct {
 	Priority       int
 	Portal         *PortalKey
 	TimeStart      *time.Time
-	TimeEnd        *time.Time
 	MaxBatchEvents int
 	MaxTotalEvents int
 	BatchDelay     int
+	DispatchTime   *time.Time
 	CompletedAt    *time.Time
 }
 
 func (b *Backfill) String() string {
-	return fmt.Sprintf("Backfill{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, TimeEnd: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, CompletedAt: %s}",
-		b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.TimeEnd, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt,
+	return fmt.Sprintf("Backfill{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, DispatchTime: %s, CompletedAt: %s}",
+		b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt, b.DispatchTime,
 	)
 }
 
 func (b *Backfill) Scan(row Scannable) *Backfill {
-	err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.TimeEnd, &b.MaxBatchEvents, &b.MaxTotalEvents, &b.BatchDelay)
+	err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &b.MaxTotalEvents, &b.BatchDelay)
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 			b.log.Errorln("Database scan failed:", err)
@@ -153,10 +159,10 @@ func (b *Backfill) Scan(row Scannable) *Backfill {
 func (b *Backfill) Insert() {
 	rows, err := b.db.Query(`
 		INSERT INTO backfill_queue
-			(user_mxid, type, priority, portal_jid, portal_receiver, time_start, time_end, max_batch_events, max_total_events, batch_delay, completed_at)
+			(user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay, dispatch_time, completed_at)
 		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
 		RETURNING queue_id
-	`, b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart, b.TimeEnd, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt)
+	`, b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.DispatchTime, b.CompletedAt)
 	defer rows.Close()
 	if err != nil || !rows.Next() {
 		b.log.Warnfln("Failed to insert %v/%s with priority %d: %v", b.BackfillType, b.Portal.JID, b.Priority, err)
@@ -168,9 +174,20 @@ func (b *Backfill) Insert() {
 	}
 }
 
+func (b *Backfill) MarkDispatched() {
+	if b.QueueID == 0 {
+		b.log.Errorf("Cannot mark backfill as dispatched without queue_id. Maybe it wasn't actually inserted in the database?")
+		return
+	}
+	_, err := b.db.Exec("UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2", time.Now(), b.QueueID)
+	if err != nil {
+		b.log.Warnfln("Failed to mark %s/%s as dispatched: %v", b.BackfillType, b.Priority, err)
+	}
+}
+
 func (b *Backfill) MarkDone() {
 	if b.QueueID == 0 {
-		b.log.Errorf("Cannot delete backfill without queue_id. Maybe it wasn't actually inserted in the database?")
+		b.log.Errorf("Cannot mark backfill done without queue_id. Maybe it wasn't actually inserted in the database?")
 		return
 	}
 	_, err := b.db.Exec("UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2", time.Now(), b.QueueID)

+ 3 - 3
database/historysync.go

@@ -249,10 +249,10 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
 
 func (hsm *HistorySyncMessage) Insert() {
 	_, err := hsm.db.Exec(`
-		INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data)
-		VALUES ($1, $2, $3, $4, $5)
+		INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time)
+		VALUES ($1, $2, $3, $4, $5, $6)
 		ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
-	`, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data)
+	`, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now())
 	if err != nil {
 		hsm.log.Warnfln("Failed to insert history sync message %s/%s: %v", hsm.ConversationID, hsm.Timestamp, err)
 	}

+ 23 - 6
database/message.go

@@ -178,16 +178,26 @@ func (msg *Message) Scan(row Scannable) *Message {
 	return msg
 }
 
-func (msg *Message) Insert() {
+func (msg *Message) Insert(txn *sql.Tx) {
 	var sender interface{} = msg.Sender
 	// Slightly hacky hack to allow inserting empty senders (used for post-backfill dummy events)
 	if msg.Sender.IsEmpty() {
 		sender = ""
 	}
-	_, err := msg.db.Exec(`INSERT INTO message
+	query := `
+		INSERT INTO message
 			(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid)
-			VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`,
-		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
+		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
+	`
+	args := []interface{}{
+		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID,
+	}
+	var err error
+	if txn != nil {
+		_, err = txn.Exec(query, args...)
+	} else {
+		_, err = msg.db.Exec(query, args...)
+	}
 	if err != nil {
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 	}
@@ -202,11 +212,18 @@ func (msg *Message) MarkSent(ts time.Time) {
 	}
 }
 
-func (msg *Message) UpdateMXID(mxid id.EventID, newType MessageType, newError MessageErrorType) {
+func (msg *Message) UpdateMXID(txn *sql.Tx, mxid id.EventID, newType MessageType, newError MessageErrorType) {
 	msg.MXID = mxid
 	msg.Type = newType
 	msg.Error = newError
-	_, err := msg.db.Exec("UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6", mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
+	query := "UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6"
+	args := []interface{}{mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID}
+	var err error
+	if txn != nil {
+		_, err = txn.Exec(query, args...)
+	} else {
+		_, err = msg.db.Exec(query, args...)
+	}
 	if err != nil {
 		msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
 	}

+ 15 - 3
database/portal.go

@@ -191,9 +191,21 @@ func (portal *Portal) Insert() {
 	}
 }
 
-func (portal *Portal) Update() {
-	_, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, encrypted=$6, first_event_id=$7, next_batch_id=$8, relay_user_id=$9, expiration_time=$10 WHERE jid=$11 AND receiver=$12",
-		portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime, portal.Key.JID, portal.Key.Receiver)
+func (portal *Portal) Update(txn *sql.Tx) {
+	query := `
+		UPDATE portal
+		SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, encrypted=$6, first_event_id=$7, next_batch_id=$8, relay_user_id=$9, expiration_time=$10
+		WHERE jid=$11 AND receiver=$12
+	`
+	args := []interface{}{
+		portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime, portal.Key.JID, portal.Key.Receiver,
+	}
+	var err error
+	if txn != nil {
+		_, err = txn.Exec(query, args...)
+	} else {
+		_, err = portal.db.Exec(query, args...)
+	}
 	if err != nil {
 		portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
 	}

+ 34 - 0
database/upgrades/2022-05-12-backfillqueue-dispatch-time.go

@@ -0,0 +1,34 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[44] = upgrade{"Add dispatch time to backfill queue", func(tx *sql.Tx, ctx context) error {
+		// First, add dispatch_time TIMESTAMP column
+		_, err := tx.Exec(`
+			ALTER TABLE backfill_queue
+			ADD COLUMN dispatch_time TIMESTAMP
+		`)
+		if err != nil {
+			return err
+		}
+
+		// For all previous jobs, set dispatch time to the completed time.
+		_, err = tx.Exec(`
+			UPDATE backfill_queue
+				SET dispatch_time=completed_at
+		`)
+		if err != nil {
+			return err
+		}
+
+		// Remove time_end from the backfill queue
+		_, err = tx.Exec(`
+			ALTER TABLE backfill_queue
+			DROP COLUMN time_end
+		`)
+		return err
+	}}
+}

+ 16 - 0
database/upgrades/2022-05-12-history-sync-message-add-added-timestamp.go

@@ -0,0 +1,16 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[45] = upgrade{"Add inserted time to history sync message", func(tx *sql.Tx, ctx context) error {
+		// Add the inserted time TIMESTAMP column to history_sync_message
+		_, err := tx.Exec(`
+			ALTER TABLE history_sync_message
+			ADD COLUMN inserted_time TIMESTAMP
+		`)
+		return err
+	}}
+}

+ 1 - 1
database/upgrades/upgrades.go

@@ -40,7 +40,7 @@ type upgrade struct {
 	fn      upgradeFunc
 }
 
-const NumberOfUpgrades = 44
+const NumberOfUpgrades = 46
 
 var upgrades [NumberOfUpgrades]upgrade
 

+ 52 - 91
historysync.go

@@ -17,6 +17,7 @@
 package main
 
 import (
+	"database/sql"
 	"fmt"
 	"time"
 
@@ -50,27 +51,22 @@ func (user *User) handleHistorySyncsLoop() {
 		return
 	}
 
-	reCheckQueue := make(chan bool, 1)
 	// Start the backfill queue.
 	user.BackfillQueue = &BackfillQueue{
-		BackfillQuery:             user.bridge.DB.Backfill,
-		ImmediateBackfillRequests: make(chan *database.Backfill, 1),
-		DeferredBackfillRequests:  make(chan *database.Backfill, 1),
-		ReCheckQueue:              make(chan bool, 1),
-		log:                       user.log.Sub("BackfillQueue"),
+		BackfillQuery:   user.bridge.DB.Backfill,
+		reCheckChannels: []chan bool{},
+		log:             user.log.Sub("BackfillQueue"),
 	}
-	reCheckQueue = user.BackfillQueue.ReCheckQueue
 
 	// Immediate backfills can be done in parallel
 	for i := 0; i < user.bridge.Config.Bridge.HistorySync.Immediate.WorkerCount; i++ {
-		go user.handleBackfillRequestsLoop(user.BackfillQueue.ImmediateBackfillRequests)
+		go user.HandleBackfillRequestsLoop([]database.BackfillType{database.BackfillImmediate, database.BackfillForward})
 	}
 
 	// Deferred backfills should be handled synchronously so as not to
 	// overload the homeserver. Users can configure their backfill stages
 	// to be more or less aggressive with backfilling at this stage.
-	go user.handleBackfillRequestsLoop(user.BackfillQueue.DeferredBackfillRequests)
-	go user.BackfillQueue.RunLoop(user)
+	go user.HandleBackfillRequestsLoop([]database.BackfillType{database.BackfillDeferred})
 
 	if user.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia &&
 		user.bridge.Config.Bridge.HistorySync.MediaRequests.RequestMethod == config.MediaRequestMethodLocalTime {
@@ -80,7 +76,7 @@ func (user *User) handleHistorySyncsLoop() {
 	// Always save the history syncs for the user. If they want to enable
 	// backfilling in the future, we will have it in the database.
 	for evt := range user.historySyncs {
-		user.handleHistorySync(reCheckQueue, evt.Data)
+		user.handleHistorySync(user.BackfillQueue, evt.Data)
 	}
 }
 
@@ -130,36 +126,6 @@ func (user *User) dailyMediaRequestLoop() {
 	}
 }
 
-func (user *User) handleBackfillRequestsLoop(backfillRequests chan *database.Backfill) {
-	for req := range backfillRequests {
-		user.log.Infofln("Handling backfill request %s", req)
-		conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, req.Portal)
-		if conv == nil {
-			user.log.Debugfln("Could not find history sync conversation data for %s", req.Portal.String())
-			continue
-		}
-		portal := user.GetPortalByJID(conv.PortalKey.JID)
-
-		// Update the client store with basic chat settings.
-		if conv.MuteEndTime.After(time.Now()) {
-			user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime)
-		}
-		if conv.Archived {
-			user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true)
-		}
-		if conv.Pinned > 0 {
-			user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true)
-		}
-
-		if conv.EphemeralExpiration != nil && portal.ExpirationTime != *conv.EphemeralExpiration {
-			portal.ExpirationTime = *conv.EphemeralExpiration
-			portal.Update()
-		}
-
-		user.backfillInChunks(req, conv, portal)
-	}
-}
-
 func (user *User) backfillInChunks(req *database.Backfill, conv *database.HistorySyncConversation, portal *Portal) {
 	portal.backfillLock.Lock()
 	defer portal.backfillLock.Unlock()
@@ -169,6 +135,7 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
 	}
 
 	var forwardPrevID id.EventID
+	var timeEnd *time.Time
 	if req.BackfillType == database.BackfillForward {
 		// TODO this overrides the TimeStart set when enqueuing the backfill
 		//      maybe the enqueue should instead include the prev event ID
@@ -178,13 +145,13 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
 		req.TimeStart = &start
 	} else {
 		firstMessage := portal.bridge.DB.Message.GetFirstInChat(portal.Key)
-		if firstMessage != nil && (req.TimeEnd == nil || firstMessage.Timestamp.Before(*req.TimeEnd)) {
+		if firstMessage != nil {
 			end := firstMessage.Timestamp.Add(-1 * time.Second)
-			req.TimeEnd = &end
+			timeEnd = &end
 			user.log.Debugfln("Limiting backfill to end at %v", end)
 		}
 	}
-	allMsgs := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, conv.ConversationID, req.TimeStart, req.TimeEnd, req.MaxTotalEvents)
+	allMsgs := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, conv.ConversationID, req.TimeStart, timeEnd, req.MaxTotalEvents)
 
 	sendDisappearedNotice := false
 	// If expired messages are on, and a notice has not been sent to this chat
@@ -230,7 +197,7 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
 		msg.Timestamp = conv.LastMessageTimestamp
 		msg.Sent = true
 		msg.Type = database.MsgFake
-		msg.Insert()
+		msg.Insert(nil)
 		return
 	}
 
@@ -290,7 +257,7 @@ func (user *User) shouldCreatePortalForHistorySync(conv *database.HistorySyncCon
 	return false
 }
 
-func (user *User) handleHistorySync(reCheckQueue chan bool, evt *waProto.HistorySync) {
+func (user *User) handleHistorySync(backfillQueue *BackfillQueue, evt *waProto.HistorySync) {
 	if evt == nil || evt.SyncType == nil || evt.GetSyncType() == waProto.HistorySync_INITIAL_STATUS_V3 || evt.GetSyncType() == waProto.HistorySync_PUSH_NAME {
 		return
 	}
@@ -381,7 +348,7 @@ func (user *User) handleHistorySync(reCheckQueue chan bool, evt *waProto.History
 			}
 
 			// Tell the queue to check for new backfill requests.
-			reCheckQueue <- true
+			backfillQueue.ReCheck()
 		}
 	}
 }
@@ -397,7 +364,7 @@ func getConversationTimestamp(conv *waProto.Conversation) uint64 {
 func (user *User) EnqueueImmedateBackfills(portals []*Portal) {
 	for priority, portal := range portals {
 		maxMessages := user.bridge.Config.Bridge.HistorySync.Immediate.MaxEvents
-		initialBackfill := user.bridge.DB.Backfill.NewWithValues(user.MXID, database.BackfillImmediate, priority, &portal.Key, nil, nil, maxMessages, maxMessages, 0)
+		initialBackfill := user.bridge.DB.Backfill.NewWithValues(user.MXID, database.BackfillImmediate, priority, &portal.Key, nil, maxMessages, maxMessages, 0)
 		initialBackfill.Insert()
 	}
 }
@@ -412,7 +379,7 @@ func (user *User) EnqueueDeferredBackfills(portals []*Portal) {
 				startDate = &startDaysAgo
 			}
 			backfillMessages := user.bridge.DB.Backfill.NewWithValues(
-				user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, &portal.Key, startDate, nil, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay)
+				user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, &portal.Key, startDate, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay)
 			backfillMessages.Insert()
 		}
 	}
@@ -425,7 +392,7 @@ func (user *User) EnqueueForwardBackfills(portals []*Portal) {
 			continue
 		}
 		backfill := user.bridge.DB.Backfill.NewWithValues(
-			user.MXID, database.BackfillForward, priority, &portal.Key, &lastMsg.Timestamp, nil, -1, -1, 0)
+			user.MXID, database.BackfillForward, priority, &portal.Key, &lastMsg.Timestamp, -1, -1, 0)
 		backfill.Insert()
 	}
 }
@@ -558,9 +525,24 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
 		portal.log.Errorln("Error batch sending messages:", err)
 		return nil
 	} else {
-		portal.finishBatch(resp.EventIDs, infos)
-		portal.NextBatchID = resp.NextBatchID
-		portal.Update()
+		txn, err := portal.bridge.DB.Begin()
+		if err != nil {
+			portal.log.Errorln("Failed to start transaction to save batch messages:", err)
+			return nil
+		}
+
+		// Do the following block in the transaction
+		{
+			portal.finishBatch(txn, resp.EventIDs, infos)
+			portal.NextBatchID = resp.NextBatchID
+			portal.Update(txn)
+		}
+
+		err = txn.Commit()
+		if err != nil {
+			portal.log.Errorln("Failed to commit transaction to save batch messages:", err)
+			return nil
+		}
 		if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia {
 			go portal.requestMediaRetries(source, resp.EventIDs, infos)
 		}
@@ -651,48 +633,27 @@ func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice
 	}, nil
 }
 
-func (portal *Portal) finishBatch(eventIDs []id.EventID, infos []*wrappedInfo) {
-	if len(eventIDs) != len(infos) {
-		portal.log.Errorfln("Length of event IDs (%d) and message infos (%d) doesn't match! Using slow path for mapping event IDs", len(eventIDs), len(infos))
-		infoMap := make(map[types.MessageID]*wrappedInfo, len(infos))
-		for _, info := range infos {
-			infoMap[info.ID] = info
-		}
-		for _, eventID := range eventIDs {
-			if evt, err := portal.MainIntent().GetEvent(portal.MXID, eventID); err != nil {
-				portal.log.Warnfln("Failed to get event %s to register it in the database: %v", eventID, err)
-			} else if msgID, ok := evt.Content.Raw[backfillIDField].(string); !ok {
-				portal.log.Warnfln("Event %s doesn't include the WhatsApp message ID", eventID)
-			} else if info, ok := infoMap[types.MessageID(msgID)]; !ok {
-				portal.log.Warnfln("Didn't find info of message %s (event %s) to register it in the database", msgID, eventID)
-			} else {
-				portal.finishBatchEvt(info, eventID)
-			}
-		}
-	} else {
-		for i := 0; i < len(infos); i++ {
-			portal.finishBatchEvt(infos[i], eventIDs[i])
+func (portal *Portal) finishBatch(txn *sql.Tx, eventIDs []id.EventID, infos []*wrappedInfo) {
+	for i, info := range infos {
+		if info == nil {
+			continue
 		}
-		portal.log.Infofln("Successfully sent %d events", len(eventIDs))
-	}
-}
 
-func (portal *Portal) finishBatchEvt(info *wrappedInfo, eventID id.EventID) {
-	if info == nil {
-		return
-	}
+		eventID := eventIDs[i]
+		portal.markHandled(txn, nil, info.MessageInfo, eventID, true, false, info.Type, info.Error)
 
-	portal.markHandled(nil, info.MessageInfo, eventID, true, false, info.Type, info.Error)
-	if info.ExpiresIn > 0 {
-		if info.ExpirationStart > 0 {
-			remainingSeconds := time.Unix(int64(info.ExpirationStart), 0).Add(time.Duration(info.ExpiresIn) * time.Second).Sub(time.Now()).Seconds()
-			portal.log.Debugfln("Disappearing history sync message: expires in %d, started at %d, remaining %d", info.ExpiresIn, info.ExpirationStart, int(remainingSeconds))
-			portal.MarkDisappearing(eventID, uint32(remainingSeconds), true)
-		} else {
-			portal.log.Debugfln("Disappearing history sync message: expires in %d (not started)", info.ExpiresIn)
-			portal.MarkDisappearing(eventID, info.ExpiresIn, false)
+		if info.ExpiresIn > 0 {
+			if info.ExpirationStart > 0 {
+				remainingSeconds := time.Unix(int64(info.ExpirationStart), 0).Add(time.Duration(info.ExpiresIn) * time.Second).Sub(time.Now()).Seconds()
+				portal.log.Debugfln("Disappearing history sync message: expires in %d, started at %d, remaining %d", info.ExpiresIn, info.ExpirationStart, int(remainingSeconds))
+				portal.MarkDisappearing(eventID, uint32(remainingSeconds), true)
+			} else {
+				portal.log.Debugfln("Disappearing history sync message: expires in %d (not started)", info.ExpiresIn)
+				portal.MarkDisappearing(eventID, info.ExpiresIn, false)
+			}
 		}
 	}
+	portal.log.Infofln("Successfully sent %d events", len(eventIDs))
 }
 
 func (portal *Portal) sendPostBackfillDummy(lastTimestamp time.Time, insertionEventId id.EventID) {
@@ -714,7 +675,7 @@ func (portal *Portal) sendPostBackfillDummy(lastTimestamp time.Time, insertionEv
 	msg.Timestamp = lastTimestamp.Add(1 * time.Second)
 	msg.Sent = true
 	msg.Type = database.MsgFake
-	msg.Insert()
+	msg.Insert(nil)
 }
 
 // endregion

+ 2 - 2
matrix.go

@@ -74,7 +74,7 @@ func (mx *MatrixHandler) HandleEncryption(evt *event.Event) {
 	if portal != nil && !portal.Encrypted {
 		mx.log.Debugfln("%s enabled encryption in %s", evt.Sender, evt.RoomID)
 		portal.Encrypted = true
-		portal.Update()
+		portal.Update(nil)
 		if portal.IsPrivateChat() {
 			err := mx.as.BotIntent().EnsureJoined(portal.MXID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client})
 			if err != nil {
@@ -211,7 +211,7 @@ func (mx *MatrixHandler) createPrivatePortalFromInvite(roomID id.RoomID, inviter
 		mx.as.StateStore.SetMembership(roomID, mx.bridge.Bot.UserID, event.MembershipJoin)
 		portal.Encrypted = true
 	}
-	portal.Update()
+	portal.Update(nil)
 	portal.UpdateBridgeInfo()
 	_, _ = intent.SendNotice(roomID, "Private chat portal created")
 }

+ 23 - 22
portal.go

@@ -19,6 +19,7 @@ package main
 import (
 	"bytes"
 	"context"
+	"database/sql"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -482,7 +483,7 @@ func (portal *Portal) convertMessage(intent *appservice.IntentAPI, source *User,
 		return portal.convertGroupInviteMessage(intent, info, waMsg.GetGroupInviteMessage())
 	case waMsg.ProtocolMessage != nil && waMsg.ProtocolMessage.GetType() == waProto.ProtocolMessage_EPHEMERAL_SETTING:
 		portal.ExpirationTime = waMsg.ProtocolMessage.GetEphemeralExpiration()
-		portal.Update()
+		portal.Update(nil)
 		return &ConvertedMessage{
 			Intent: intent,
 			Type:   event.EventMessage,
@@ -501,7 +502,7 @@ func (portal *Portal) UpdateGroupDisappearingMessages(sender *types.JID, timesta
 		return
 	}
 	portal.ExpirationTime = timer
-	portal.Update()
+	portal.Update(nil)
 	intent := portal.MainIntent()
 	if sender != nil {
 		intent = portal.bridge.GetPuppetByJID(sender.ToNonAD()).IntentFor(portal)
@@ -685,7 +686,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 			_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
 				Reason: "The undecryptable message was actually the deletion of another message",
 			})
-			existingMsg.UpdateMXID("net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
+			existingMsg.UpdateMXID(nil, "net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
 		}
 	} else {
 		portal.log.Warnfln("Unhandled message: %+v (%s)", evt.Info, msgType)
@@ -693,7 +694,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 			_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
 				Reason: "The undecryptable message contained an unsupported message type",
 			})
-			existingMsg.UpdateMXID("net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
+			existingMsg.UpdateMXID(nil, "net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
 		}
 		return
 	}
@@ -711,7 +712,7 @@ func (portal *Portal) isRecentlyHandled(id types.MessageID, error database.Messa
 	return false
 }
 
-func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent bool, msgType database.MessageType, error database.MessageErrorType) *database.Message {
+func (portal *Portal) markHandled(txn *sql.Tx, msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent bool, msgType database.MessageType, errType database.MessageErrorType) *database.Message {
 	if msg == nil {
 		msg = portal.bridge.DB.Message.New()
 		msg.Chat = portal.Key
@@ -721,13 +722,13 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo
 		msg.Sender = info.Sender
 		msg.Sent = isSent
 		msg.Type = msgType
-		msg.Error = error
+		msg.Error = errType
 		if info.IsIncomingBroadcast() {
 			msg.BroadcastListJID = info.Chat
 		}
-		msg.Insert()
+		msg.Insert(txn)
 	} else {
-		msg.UpdateMXID(mxid, msgType, error)
+		msg.UpdateMXID(txn, mxid, msgType, errType)
 	}
 
 	if recent {
@@ -735,7 +736,7 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo
 		index := portal.recentlyHandledIndex
 		portal.recentlyHandledIndex = (portal.recentlyHandledIndex + 1) % recentlyHandledLength
 		portal.recentlyHandledLock.Unlock()
-		portal.recentlyHandled[index] = recentlyHandledWrapper{msg.JID, error}
+		portal.recentlyHandled[index] = recentlyHandledWrapper{msg.JID, errType}
 	}
 	return msg
 }
@@ -756,13 +757,13 @@ func (portal *Portal) getMessageIntent(user *User, info *types.MessageInfo) *app
 	return portal.getMessagePuppet(user, info).IntentFor(portal)
 }
 
-func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, msgType database.MessageType, error database.MessageErrorType) {
-	portal.markHandled(existing, message, mxid, true, true, msgType, error)
+func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, msgType database.MessageType, errType database.MessageErrorType) {
+	portal.markHandled(nil, existing, message, mxid, true, true, msgType, errType)
 	portal.sendDeliveryReceipt(mxid)
 	var suffix string
-	if error == database.MsgErrDecryptionFailed {
+	if errType == database.MsgErrDecryptionFailed {
 		suffix = "(undecryptable message error notice)"
-	} else if error == database.MsgErrMediaNotFound {
+	} else if errType == database.MsgErrMediaNotFound {
 		suffix = "(media not found notice)"
 	}
 	portal.log.Debugfln("Handled message %s (%s) -> %s %s", message.ID, msgType, mxid, suffix)
@@ -1028,7 +1029,7 @@ func (portal *Portal) UpdateMatrixRoom(user *User, groupInfo *types.GroupInfo) b
 		update = portal.UpdateAvatar(user, types.EmptyJID, false) || update
 	}
 	if update {
-		portal.Update()
+		portal.Update(nil)
 		portal.UpdateBridgeInfo()
 	}
 	return true
@@ -1320,7 +1321,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, i
 		return err
 	}
 	portal.MXID = resp.RoomID
-	portal.Update()
+	portal.Update(nil)
 	portal.bridge.portalsLock.Lock()
 	portal.bridge.portalsByMXID[portal.MXID] = portal
 	portal.bridge.portalsLock.Unlock()
@@ -1338,7 +1339,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, i
 	if groupInfo != nil {
 		if groupInfo.IsEphemeral {
 			portal.ExpirationTime = groupInfo.DisappearingTimer
-			portal.Update()
+			portal.Update(nil)
 		}
 		portal.SyncParticipants(user, groupInfo)
 		if groupInfo.IsAnnounce {
@@ -1369,14 +1370,14 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, i
 		portal.log.Errorln("Failed to send dummy event to mark portal creation:", err)
 	} else {
 		portal.FirstEventID = firstEventResp.EventID
-		portal.Update()
+		portal.Update(nil)
 	}
 
 	if user.bridge.Config.Bridge.HistorySync.Backfill && backfill {
 		portals := []*Portal{portal}
 		user.EnqueueImmedateBackfills(portals)
 		user.EnqueueDeferredBackfills(portals)
-		user.BackfillQueue.ReCheckQueue <- true
+		user.BackfillQueue.ReCheck()
 	}
 	return nil
 }
@@ -2367,7 +2368,7 @@ func (portal *Portal) handleMediaRetry(retry *events.MediaRetry, source *User) {
 		return
 	}
 	portal.log.Debugfln("Successfully edited %s -> %s after retry notification for %s", msg.MXID, resp.EventID, retry.MessageID)
-	msg.UpdateMXID(resp.EventID, database.MsgNormal, database.MsgNoError)
+	msg.UpdateMXID(nil, resp.EventID, database.MsgNormal, database.MsgNoError)
 }
 
 func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID, mediaKey []byte) (bool, error) {
@@ -2844,7 +2845,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) {
 	}
 	portal.MarkDisappearing(evt.ID, portal.ExpirationTime, true)
 	info := portal.generateMessageInfo(sender)
-	dbMsg := portal.markHandled(nil, info, evt.ID, false, true, database.MsgNormal, database.MsgNoError)
+	dbMsg := portal.markHandled(nil, nil, info, evt.ID, false, true, database.MsgNormal, database.MsgNoError)
 	portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
 	ts, err := sender.Client.SendMessage(portal.Key.JID, info.ID, msg)
 	if err != nil {
@@ -2888,7 +2889,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error
 		return fmt.Errorf("unknown target event %s", content.RelatesTo.EventID)
 	}
 	info := portal.generateMessageInfo(sender)
-	dbMsg := portal.markHandled(nil, info, evt.ID, false, true, database.MsgReaction, database.MsgNoError)
+	dbMsg := portal.markHandled(nil, nil, info, evt.ID, false, true, database.MsgReaction, database.MsgNoError)
 	portal.upsertReaction(nil, target.JID, sender.JID, evt.ID, info.ID)
 	portal.log.Debugln("Sending reaction", evt.ID, "to WhatsApp", info.ID)
 	ts, err := portal.sendReactionToWhatsApp(sender, info.ID, target, content.RelatesTo.Key, evt.Timestamp)
@@ -3302,6 +3303,6 @@ func (portal *Portal) HandleMatrixMeta(sender *User, evt *event.Event) {
 		portal.Avatar = newID
 		portal.AvatarURL = content.URL
 		portal.UpdateBridgeInfo()
-		portal.Update()
+		portal.Update(nil)
 	}
 }

+ 2 - 2
puppet.go

@@ -280,7 +280,7 @@ func (puppet *Puppet) updatePortalAvatar() {
 		}
 		portal.AvatarURL = puppet.AvatarURL
 		portal.Avatar = puppet.Avatar
-		portal.Update()
+		portal.Update(nil)
 	})
 }
 
@@ -293,7 +293,7 @@ func (puppet *Puppet) updatePortalName() {
 			}
 		}
 		portal.Name = puppet.Displayname
-		portal.Update()
+		portal.Update(nil)
 	})
 }