Просмотр исходного кода

Fix marking messages as disappearing while backfilling on SQLite

Tulir Asokan 2 лет назад
Родитель
Сommit
4bfd3bd644
4 измененных файлов с 15 добавлено и 11 удалено
  1. 5 2
      database/disappearingmessage.go
  2. 3 2
      disappear.go
  3. 2 2
      historysync.go
  4. 5 5
      portal.go

+ 5 - 2
database/disappearingmessage.go

@@ -112,13 +112,16 @@ func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage
 	return msg
 }
 
-func (msg *DisappearingMessage) Insert() {
+func (msg *DisappearingMessage) Insert(txn dbutil.Execable) {
+	if txn == nil {
+		txn = msg.db
+	}
 	var expireAt sql.NullInt64
 	if !msg.ExpireAt.IsZero() {
 		expireAt.Valid = true
 		expireAt.Int64 = msg.ExpireAt.UnixMilli()
 	}
-	_, err := msg.db.Exec(`INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`,
+	_, err := txn.Exec(`INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`,
 		msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), expireAt)
 	if err != nil {
 		msg.log.Warnfln("Failed to insert %s/%s: %v", msg.RoomID, msg.EventID, err)

+ 3 - 2
disappear.go

@@ -22,17 +22,18 @@ import (
 
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix/id"
+	"maunium.net/go/mautrix/util/dbutil"
 
 	"maunium.net/go/mautrix-whatsapp/database"
 )
 
-func (portal *Portal) MarkDisappearing(eventID id.EventID, expiresIn uint32, startNow bool) {
+func (portal *Portal) MarkDisappearing(txn dbutil.Execable, eventID id.EventID, expiresIn uint32, startNow bool) {
 	if expiresIn == 0 || (!portal.bridge.Config.Bridge.DisappearingMessagesInGroups && portal.IsGroupChat()) {
 		return
 	}
 
 	msg := portal.bridge.DB.DisappearingMessage.NewWithValues(portal.MXID, eventID, time.Duration(expiresIn)*time.Second, startNow)
-	msg.Insert()
+	msg.Insert(txn)
 	if startNow {
 		go portal.sleepAndDelete(msg)
 	}

+ 2 - 2
historysync.go

@@ -788,10 +788,10 @@ func (portal *Portal) finishBatch(txn dbutil.Transaction, eventIDs []id.EventID,
 			if info.ExpirationStart > 0 {
 				remainingSeconds := time.Unix(int64(info.ExpirationStart), 0).Add(time.Duration(info.ExpiresIn) * time.Second).Sub(time.Now()).Seconds()
 				portal.log.Debugfln("Disappearing history sync message: expires in %d, started at %d, remaining %d", info.ExpiresIn, info.ExpirationStart, int(remainingSeconds))
-				portal.MarkDisappearing(eventID, uint32(remainingSeconds), true)
+				portal.MarkDisappearing(txn, eventID, uint32(remainingSeconds), true)
 			} else {
 				portal.log.Debugfln("Disappearing history sync message: expires in %d (not started)", info.ExpiresIn)
-				portal.MarkDisappearing(eventID, info.ExpiresIn, false)
+				portal.MarkDisappearing(txn, eventID, info.ExpiresIn, false)
 			}
 		}
 	}

+ 5 - 5
portal.go

@@ -743,7 +743,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 		var eventID id.EventID
 		var lastEventID id.EventID
 		if existingMsg != nil {
-			portal.MarkDisappearing(existingMsg.MXID, converted.ExpiresIn, false)
+			portal.MarkDisappearing(nil, existingMsg.MXID, converted.ExpiresIn, false)
 			converted.Content.SetEdit(existingMsg.MXID)
 		} else if converted.ReplyTo != nil {
 			portal.SetReply(converted.Content, converted.ReplyTo, false)
@@ -758,7 +758,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 			portal.log.Errorfln("Failed to send %s to Matrix: %v", msgID, err)
 		} else {
 			if editTargetMsg == nil {
-				portal.MarkDisappearing(resp.EventID, converted.ExpiresIn, false)
+				portal.MarkDisappearing(nil, resp.EventID, converted.ExpiresIn, false)
 			}
 			eventID = resp.EventID
 			lastEventID = eventID
@@ -769,7 +769,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 			if err != nil {
 				portal.log.Errorfln("Failed to send caption of %s to Matrix: %v", msgID, err)
 			} else {
-				portal.MarkDisappearing(resp.EventID, converted.ExpiresIn, false)
+				portal.MarkDisappearing(nil, resp.EventID, converted.ExpiresIn, false)
 				lastEventID = resp.EventID
 			}
 		}
@@ -779,7 +779,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
 				if err != nil {
 					portal.log.Errorfln("Failed to send sub-event %d of %s to Matrix: %v", index+1, msgID, err)
 				} else {
-					portal.MarkDisappearing(resp.EventID, converted.ExpiresIn, false)
+					portal.MarkDisappearing(nil, resp.EventID, converted.ExpiresIn, false)
 					lastEventID = resp.EventID
 				}
 			}
@@ -3502,7 +3502,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
 	}
 	dbMsgType := database.MsgNormal
 	if msg.EditedMessage == nil {
-		portal.MarkDisappearing(origEvtID, portal.ExpirationTime, true)
+		portal.MarkDisappearing(nil, origEvtID, portal.ExpirationTime, true)
 	} else {
 		dbMsgType = database.MsgEdit
 	}