Browse Source

backfill: perform batch finish in transaction

Sumner Evans 3 năm trước cách đây
mục cha
commit
f3f6d88e55
7 tập tin đã thay đổi với 105 bổ sung80 xóa
  1. 3 3
      commands.go
  2. 23 6
      database/message.go
  3. 15 3
      database/portal.go
  4. 38 43
      historysync.go
  5. 2 2
      matrix.go
  6. 22 21
      portal.go
  7. 2 2
      puppet.go

+ 3 - 3
commands.go

@@ -209,7 +209,7 @@ func (handler *CommandHandler) CommandSetRelay(ce *CommandEvent) {
 		ce.Reply("Only admins are allowed to enable relay mode on this instance of the bridge")
 	} else {
 		ce.Portal.RelayUserID = ce.User.MXID
-		ce.Portal.Update()
+		ce.Portal.Update(nil)
 		ce.Reply("Messages from non-logged-in users in this room will now be bridged through your WhatsApp account")
 	}
 }
@@ -225,7 +225,7 @@ func (handler *CommandHandler) CommandUnsetRelay(ce *CommandEvent) {
 		ce.Reply("Only admins are allowed to enable relay mode on this instance of the bridge")
 	} else {
 		ce.Portal.RelayUserID = ""
-		ce.Portal.Update()
+		ce.Portal.Update(nil)
 		ce.Reply("Messages from non-logged-in users will no longer be bridged in this room")
 	}
 }
@@ -447,7 +447,7 @@ func (handler *CommandHandler) CommandCreate(ce *CommandEvent) {
 		portal.Encrypted = true
 	}
 
-	portal.Update()
+	portal.Update(nil)
 	portal.UpdateBridgeInfo()
 
 	ce.Reply("Successfully created WhatsApp group %s", portal.Key.JID)

+ 23 - 6
database/message.go

@@ -178,16 +178,26 @@ func (msg *Message) Scan(row Scannable) *Message {
 	return msg
 }
 
-func (msg *Message) Insert() {
+func (msg *Message) Insert(txn *sql.Tx) {
 	var sender interface{} = msg.Sender
 	// Slightly hacky hack to allow inserting empty senders (used for post-backfill dummy events)
 	if msg.Sender.IsEmpty() {
 		sender = ""
 	}
-	_, err := msg.db.Exec(`INSERT INTO message
+	query := `
+		INSERT INTO message
 			(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid)
-			VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`,
-		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
+		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
+	`
+	args := []interface{}{
+		msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID,
+	}
+	var err error
+	if txn != nil {
+		_, err = txn.Exec(query, args...)
+	} else {
+		_, err = msg.db.Exec(query, args...)
+	}
 	if err != nil {
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 	}
@@ -202,11 +212,18 @@ func (msg *Message) MarkSent(ts time.Time) {
 	}
 }
 
-func (msg *Message) UpdateMXID(mxid id.EventID, newType MessageType, newError MessageErrorType) {
+func (msg *Message) UpdateMXID(txn *sql.Tx, mxid id.EventID, newType MessageType, newError MessageErrorType) {
 	msg.MXID = mxid
 	msg.Type = newType
 	msg.Error = newError
-	_, err := msg.db.Exec("UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6", mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
+	query := "UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6"
+	args := []interface{}{mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID}
+	var err error
+	if txn != nil {
+		_, err = txn.Exec(query, args...)
+	} else {
+		_, err = msg.db.Exec(query, args...)
+	}
 	if err != nil {
 		msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
 	}

+ 15 - 3
database/portal.go

@@ -191,9 +191,21 @@ func (portal *Portal) Insert() {
 	}
 }
 
-func (portal *Portal) Update() {
-	_, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, encrypted=$6, first_event_id=$7, next_batch_id=$8, relay_user_id=$9, expiration_time=$10 WHERE jid=$11 AND receiver=$12",
-		portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime, portal.Key.JID, portal.Key.Receiver)
+func (portal *Portal) Update(txn *sql.Tx) {
+	query := `
+		UPDATE portal
+		SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, encrypted=$6, first_event_id=$7, next_batch_id=$8, relay_user_id=$9, expiration_time=$10
+		WHERE jid=$11 AND receiver=$12
+	`
+	args := []interface{}{
+		portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime, portal.Key.JID, portal.Key.Receiver,
+	}
+	var err error
+	if txn != nil {
+		_, err = txn.Exec(query, args...)
+	} else {
+		_, err = portal.db.Exec(query, args...)
+	}
 	if err != nil {
 		portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
 	}

+ 38 - 43
historysync.go

@@ -17,6 +17,7 @@
 package main
 
 import (
+	"database/sql"
 	"fmt"
 	"time"
 
@@ -154,7 +155,7 @@ func (user *User) handleBackfillRequestsLoop(backfillRequests chan *database.Bac
 
 		if conv.EphemeralExpiration != nil && portal.ExpirationTime != *conv.EphemeralExpiration {
 			portal.ExpirationTime = *conv.EphemeralExpiration
-			portal.Update()
+			portal.Update(nil)
 		}
 
 		user.backfillInChunks(req, conv, portal)
@@ -233,7 +234,7 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
 		msg.Timestamp = conv.LastMessageTimestamp
 		msg.Sent = true
 		msg.Type = database.MsgFake
-		msg.Insert()
+		msg.Insert(nil)
 		return
 	}
 
@@ -561,9 +562,24 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
 		portal.log.Errorln("Error batch sending messages:", err)
 		return nil
 	} else {
-		portal.finishBatch(resp.EventIDs, infos)
-		portal.NextBatchID = resp.NextBatchID
-		portal.Update()
+		txn, err := portal.bridge.DB.Begin()
+		if err != nil {
+			portal.log.Errorln("Failed to start transaction to save batch messages:", err)
+			return nil
+		}
+
+		// Do the following block in the transaction
+		{
+			portal.finishBatch(txn, resp.EventIDs, infos)
+			portal.NextBatchID = resp.NextBatchID
+			portal.Update(txn)
+		}
+
+		err = txn.Commit()
+		if err != nil {
+			portal.log.Errorln("Failed to commit transaction to save batch messages:", err)
+			return nil
+		}
 		if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia {
 			go portal.requestMediaRetries(source, resp.EventIDs, infos)
 		}
@@ -654,48 +670,27 @@ func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice
 	}, nil
 }
 
-func (portal *Portal) finishBatch(eventIDs []id.EventID, infos []*wrappedInfo) {
-	if len(eventIDs) != len(infos) {
-		portal.log.Errorfln("Length of event IDs (%d) and message infos (%d) doesn't match! Using slow path for mapping event IDs", len(eventIDs), len(infos))
-		infoMap := make(map[types.MessageID]*wrappedInfo, len(infos))
-		for _, info := range infos {
-			infoMap[info.ID] = info
-		}
-		for _, eventID := range eventIDs {
-			if evt, err := portal.MainIntent().GetEvent(portal.MXID, eventID); err != nil {
-				portal.log.Warnfln("Failed to get event %s to register it in the database: %v", eventID, err)
-			} else if msgID, ok := evt.Content.Raw[backfillIDField].(string); !ok {
-				portal.log.Warnfln("Event %s doesn't include the WhatsApp message ID", eventID)
-			} else if info, ok := infoMap[types.MessageID(msgID)]; !ok {
-				portal.log.Warnfln("Didn't find info of message %s (event %s) to register it in the database", msgID, eventID)
-			} else {
-				portal.finishBatchEvt(info, eventID)
-			}
-		}
-	} else {
-		for i := 0; i < len(infos); i++ {
-			portal.finishBatchEvt(infos[i], eventIDs[i])
+func (portal *Portal) finishBatch(txn *sql.Tx, eventIDs []id.EventID, infos []*wrappedInfo) {
+	for i, info := range infos {
+		if info == nil {
+			continue
 		}
-		portal.log.Infofln("Successfully sent %d events", len(eventIDs))
-	}
-}
 
-func (portal *Portal) finishBatchEvt(info *wrappedInfo, eventID id.EventID) {
-	if info == nil {
-		return
-	}
+		eventID := eventIDs[i]
+		portal.markHandled(txn, nil, info.MessageInfo, eventID, true, false, info.Type, info.Error)
 
-	portal.markHandled(nil, info.MessageInfo, eventID, true, false, info.Type, info.Error)
-	if info.ExpiresIn > 0 {
-		if info.ExpirationStart > 0 {
-			remainingSeconds := time.Unix(int64(info.ExpirationStart), 0).Add(time.Duration(info.ExpiresIn) * time.Second).Sub(time.Now()).Seconds()
-			portal.log.Debugfln("Disappearing history sync message: expires in %d, started at %d, remaining %d", info.ExpiresIn, info.ExpirationStart, int(remainingSeconds))
-			portal.MarkDisappearing(eventID, uint32(remainingSeconds), true)
-		} else {
-			portal.log.Debugfln("Disappearing history sync message: expires in %d (not started)", info.ExpiresIn)
-			portal.MarkDisappearing(eventID, info.ExpiresIn, false)
+		if info.ExpiresIn > 0 {
+			if info.ExpirationStart > 0 {
+				remainingSeconds := time.Unix(int64(info.ExpirationStart), 0).Add(time.Duration(info.ExpiresIn) * time.Second).Sub(time.Now()).Seconds()
+				portal.log.Debugfln("Disappearing history sync message: expires in %d, started at %d, remaining %d", info.ExpiresIn, info.ExpirationStart, int(remainingSeconds))
+				portal.MarkDisappearing(eventID, uint32(remainingSeconds), true)
+			} else {
+				portal.log.Debugfln("Disappearing history sync message: expires in %d (not started)", info.ExpiresIn)
+				portal.MarkDisappearing(eventID, info.ExpiresIn, false)
+			}
 		}
 	}
+	portal.log.Infofln("Successfully sent %d events", len(eventIDs))
 }
 
 func (portal *Portal) sendPostBackfillDummy(lastTimestamp time.Time, insertionEventId id.EventID) {
@@ -717,7 +712,7 @@ func (portal *Portal) sendPostBackfillDummy(lastTimestamp time.Time, insertionEv
 	msg.Timestamp = lastTimestamp.Add(1 * time.Second)
 	msg.Sent = true
 	msg.Type = database.MsgFake
-	msg.Insert()
+	msg.Insert(nil)
 }
 
 // endregion

+ 2 - 2
matrix.go

@@ -74,7 +74,7 @@ func (mx *MatrixHandler) HandleEncryption(evt *event.Event) {
 	if portal != nil && !portal.Encrypted {
 		mx.log.Debugfln("%s enabled encryption in %s", evt.Sender, evt.RoomID)
 		portal.Encrypted = true
-		portal.Update()
+		portal.Update(nil)
 		if portal.IsPrivateChat() {
 			err := mx.as.BotIntent().EnsureJoined(portal.MXID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client})
 			if err != nil {
@@ -211,7 +211,7 @@ func (mx *MatrixHandler) createPrivatePortalFromInvite(roomID id.RoomID, inviter
 		mx.as.StateStore.SetMembership(roomID, mx.bridge.Bot.UserID, event.MembershipJoin)
 		portal.Encrypted = true
 	}
-	portal.Update()
+	portal.Update(nil)
 	portal.UpdateBridgeInfo()
 	_, _ = intent.SendNotice(roomID, "Private chat portal created")
 }

+ 22 - 21
portal.go

@@ -19,6 +19,7 @@ package main
 import (
 	"bytes"
 	"context"
+	"database/sql"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -482,7 +483,7 @@ func (portal *Portal) convertMessage(intent *appservice.IntentAPI, source *User,
 		return portal.convertGroupInviteMessage(intent, info, waMsg.GetGroupInviteMessage())
 	case waMsg.ProtocolMessage != nil && waMsg.ProtocolMessage.GetType() == waProto.ProtocolMessage_EPHEMERAL_SETTING:
 		portal.ExpirationTime = waMsg.ProtocolMessage.GetEphemeralExpiration()
-		portal.Update()
+		portal.Update(nil)
 		return &ConvertedMessage{
 			Intent: intent,
 			Type:   event.EventMessage,
@@ -498,7 +499,7 @@ func (portal *Portal) convertMessage(intent *appservice.IntentAPI, source *User,
 
 func (portal *Portal) UpdateGroupDisappearingMessages(sender *types.JID, timestamp time.Time, timer uint32) {
 	portal.ExpirationTime = timer
-	portal.Update()
+	portal.Update(nil)
 	intent := portal.MainIntent()
 	if sender != nil {
 		intent = portal.bridge.GetPuppetByJID(sender.ToNonAD()).IntentFor(portal)
@@ -676,7 +677,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 			_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
 				Reason: "The undecryptable message was actually the deletion of another message",
 			})
-			existingMsg.UpdateMXID("net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
+			existingMsg.UpdateMXID(nil, "net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
 		}
 	} else {
 		portal.log.Warnfln("Unhandled message: %+v (%s)", evt.Info, msgType)
@@ -684,7 +685,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 			_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
 				Reason: "The undecryptable message contained an unsupported message type",
 			})
-			existingMsg.UpdateMXID("net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
+			existingMsg.UpdateMXID(nil, "net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
 		}
 		return
 	}
@@ -702,7 +703,7 @@ func (portal *Portal) isRecentlyHandled(id types.MessageID, error database.Messa
 	return false
 }
 
-func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent bool, msgType database.MessageType, error database.MessageErrorType) *database.Message {
+func (portal *Portal) markHandled(txn *sql.Tx, msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent bool, msgType database.MessageType, errType database.MessageErrorType) *database.Message {
 	if msg == nil {
 		msg = portal.bridge.DB.Message.New()
 		msg.Chat = portal.Key
@@ -712,13 +713,13 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo
 		msg.Sender = info.Sender
 		msg.Sent = isSent
 		msg.Type = msgType
-		msg.Error = error
+		msg.Error = errType
 		if info.IsIncomingBroadcast() {
 			msg.BroadcastListJID = info.Chat
 		}
-		msg.Insert()
+		msg.Insert(txn)
 	} else {
-		msg.UpdateMXID(mxid, msgType, error)
+		msg.UpdateMXID(txn, mxid, msgType, errType)
 	}
 
 	if recent {
@@ -726,7 +727,7 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo
 		index := portal.recentlyHandledIndex
 		portal.recentlyHandledIndex = (portal.recentlyHandledIndex + 1) % recentlyHandledLength
 		portal.recentlyHandledLock.Unlock()
-		portal.recentlyHandled[index] = recentlyHandledWrapper{msg.JID, error}
+		portal.recentlyHandled[index] = recentlyHandledWrapper{msg.JID, errType}
 	}
 	return msg
 }
@@ -747,13 +748,13 @@ func (portal *Portal) getMessageIntent(user *User, info *types.MessageInfo) *app
 	return portal.getMessagePuppet(user, info).IntentFor(portal)
 }
 
-func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, msgType database.MessageType, error database.MessageErrorType) {
-	portal.markHandled(existing, message, mxid, true, true, msgType, error)
+func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, msgType database.MessageType, errType database.MessageErrorType) {
+	portal.markHandled(nil, existing, message, mxid, true, true, msgType, errType)
 	portal.sendDeliveryReceipt(mxid)
 	var suffix string
-	if error == database.MsgErrDecryptionFailed {
+	if errType == database.MsgErrDecryptionFailed {
 		suffix = "(undecryptable message error notice)"
-	} else if error == database.MsgErrMediaNotFound {
+	} else if errType == database.MsgErrMediaNotFound {
 		suffix = "(media not found notice)"
 	}
 	portal.log.Debugfln("Handled message %s (%s) -> %s %s", message.ID, msgType, mxid, suffix)
@@ -1019,7 +1020,7 @@ func (portal *Portal) UpdateMatrixRoom(user *User, groupInfo *types.GroupInfo) b
 		update = portal.UpdateAvatar(user, types.EmptyJID, false) || update
 	}
 	if update {
-		portal.Update()
+		portal.Update(nil)
 		portal.UpdateBridgeInfo()
 	}
 	return true
@@ -1311,7 +1312,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, i
 		return err
 	}
 	portal.MXID = resp.RoomID
-	portal.Update()
+	portal.Update(nil)
 	portal.bridge.portalsLock.Lock()
 	portal.bridge.portalsByMXID[portal.MXID] = portal
 	portal.bridge.portalsLock.Unlock()
@@ -1329,7 +1330,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, i
 	if groupInfo != nil {
 		if groupInfo.IsEphemeral {
 			portal.ExpirationTime = groupInfo.DisappearingTimer
-			portal.Update()
+			portal.Update(nil)
 		}
 		portal.SyncParticipants(user, groupInfo)
 		if groupInfo.IsAnnounce {
@@ -1360,7 +1361,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, i
 		portal.log.Errorln("Failed to send dummy event to mark portal creation:", err)
 	} else {
 		portal.FirstEventID = firstEventResp.EventID
-		portal.Update()
+		portal.Update(nil)
 	}
 
 	if user.bridge.Config.Bridge.HistorySync.Backfill && backfill {
@@ -2358,7 +2359,7 @@ func (portal *Portal) handleMediaRetry(retry *events.MediaRetry, source *User) {
 		return
 	}
 	portal.log.Debugfln("Successfully edited %s -> %s after retry notification for %s", msg.MXID, resp.EventID, retry.MessageID)
-	msg.UpdateMXID(resp.EventID, database.MsgNormal, database.MsgNoError)
+	msg.UpdateMXID(nil, resp.EventID, database.MsgNormal, database.MsgNoError)
 }
 
 func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID, mediaKey []byte) (bool, error) {
@@ -2835,7 +2836,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) {
 	}
 	portal.MarkDisappearing(evt.ID, portal.ExpirationTime, true)
 	info := portal.generateMessageInfo(sender)
-	dbMsg := portal.markHandled(nil, info, evt.ID, false, true, database.MsgNormal, database.MsgNoError)
+	dbMsg := portal.markHandled(nil, nil, info, evt.ID, false, true, database.MsgNormal, database.MsgNoError)
 	portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
 	ts, err := sender.Client.SendMessage(portal.Key.JID, info.ID, msg)
 	if err != nil {
@@ -2879,7 +2880,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error
 		return fmt.Errorf("unknown target event %s", content.RelatesTo.EventID)
 	}
 	info := portal.generateMessageInfo(sender)
-	dbMsg := portal.markHandled(nil, info, evt.ID, false, true, database.MsgReaction, database.MsgNoError)
+	dbMsg := portal.markHandled(nil, nil, info, evt.ID, false, true, database.MsgReaction, database.MsgNoError)
 	portal.upsertReaction(nil, target.JID, sender.JID, evt.ID, info.ID)
 	portal.log.Debugln("Sending reaction", evt.ID, "to WhatsApp", info.ID)
 	ts, err := portal.sendReactionToWhatsApp(sender, info.ID, target, content.RelatesTo.Key, evt.Timestamp)
@@ -3293,6 +3294,6 @@ func (portal *Portal) HandleMatrixMeta(sender *User, evt *event.Event) {
 		portal.Avatar = newID
 		portal.AvatarURL = content.URL
 		portal.UpdateBridgeInfo()
-		portal.Update()
+		portal.Update(nil)
 	}
 }

+ 2 - 2
puppet.go

@@ -280,7 +280,7 @@ func (puppet *Puppet) updatePortalAvatar() {
 		}
 		portal.AvatarURL = puppet.AvatarURL
 		portal.Avatar = puppet.Avatar
-		portal.Update()
+		portal.Update(nil)
 	})
 }
 
@@ -293,7 +293,7 @@ func (puppet *Puppet) updatePortalName() {
 			}
 		}
 		portal.Name = puppet.Displayname
-		portal.Update()
+		portal.Update(nil)
 	})
 }