Browse Source

media backfill: send retry requests at the configured time

Only does the batch send of requests if the request method is 'local_time'
Sumner Evans 3 years ago
parent
commit
08e77fab29

+ 6 - 4
database/mediabackfillrequest.go

@@ -29,8 +29,8 @@ type MediaBackfillRequestStatus int
 
 const (
 	MediaBackfillRequestStatusNotRequested MediaBackfillRequestStatus = iota
-	MediaBackfillRequestStatusSuccess
-	MediaBackfillRequestStatusFailed
+	MediaBackfillRequestStatusRequested
+	MediaBackfillRequestStatusRequestFailed
 )
 
 type MediaBackfillRequestQuery struct {
@@ -64,14 +64,16 @@ func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID
 		UserID:    userID,
 		PortalKey: portalKey,
 		EventID:   eventID,
+		Status:    MediaBackfillRequestStatusNotRequested,
 	}
 }
 
 const (
 	getMediaBackfillRequestsForUser = `
 		SELECT user_mxid, portal_jid, portal_receiver, event_id, status, error
-		  FROM media_backfill_requests
-		 WHERE user_mxid=$1
+		FROM media_backfill_requests
+		WHERE user_mxid=$1
+			AND status=0
 	`
 )
 

+ 12 - 0
database/upgrades/2022-05-11-add-user-timestamp.go

@@ -0,0 +1,12 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[43] = upgrade{"Add timezone column to user table", func(tx *sql.Tx, ctx context) error {
+		_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN timezone TEXT`)
+		return err
+	}}
+}

+ 1 - 1
database/upgrades/upgrades.go

@@ -40,7 +40,7 @@ type upgrade struct {
 	fn      upgradeFunc
 }
 
-const NumberOfUpgrades = 43
+const NumberOfUpgrades = 44
 
 var upgrades [NumberOfUpgrades]upgrade
 

+ 9 - 8
database/user.go

@@ -44,7 +44,7 @@ func (uq *UserQuery) New() *User {
 }
 
 func (uq *UserQuery) GetAll() (users []*User) {
-	rows, err := uq.db.Query(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged FROM "user"`)
+	rows, err := uq.db.Query(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user"`)
 	if err != nil || rows == nil {
 		return nil
 	}
@@ -56,7 +56,7 @@ func (uq *UserQuery) GetAll() (users []*User) {
 }
 
 func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
-	row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged FROM "user" WHERE mxid=$1`, userID)
+	row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE mxid=$1`, userID)
 	if row == nil {
 		return nil
 	}
@@ -64,7 +64,7 @@ func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
 }
 
 func (uq *UserQuery) GetByUsername(username string) *User {
-	row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged FROM "user" WHERE username=$1`, username)
+	row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE username=$1`, username)
 	if row == nil {
 		return nil
 	}
@@ -81,6 +81,7 @@ type User struct {
 	SpaceRoom       id.RoomID
 	PhoneLastSeen   time.Time
 	PhoneLastPinged time.Time
+	Timezone        string
 
 	lastReadCache     map[PortalKey]time.Time
 	lastReadCacheLock sync.Mutex
@@ -92,7 +93,7 @@ func (user *User) Scan(row Scannable) *User {
 	var username sql.NullString
 	var device, agent sql.NullByte
 	var phoneLastSeen, phoneLastPinged sql.NullInt64
-	err := row.Scan(&user.MXID, &username, &agent, &device, &user.ManagementRoom, &user.SpaceRoom, &phoneLastSeen, &phoneLastPinged)
+	err := row.Scan(&user.MXID, &username, &agent, &device, &user.ManagementRoom, &user.SpaceRoom, &phoneLastSeen, &phoneLastPinged, &user.Timezone)
 	if err != nil {
 		if err != sql.ErrNoRows {
 			user.log.Errorln("Database scan failed:", err)
@@ -149,16 +150,16 @@ func (user *User) phoneLastPingedPtr() *int64 {
 }
 
 func (user *User) Insert() {
-	_, err := user.db.Exec(`INSERT INTO "user" (mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
-		user.MXID, user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr())
+	_, err := user.db.Exec(`INSERT INTO "user" (mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
+		user.MXID, user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone)
 	if err != nil {
 		user.log.Warnfln("Failed to insert %s: %v", user.MXID, err)
 	}
 }
 
 func (user *User) Update() {
-	_, err := user.db.Exec(`UPDATE "user" SET username=$1, agent=$2, device=$3, management_room=$4, space_room=$5, phone_last_seen=$6, phone_last_pinged=$7 WHERE mxid=$8`,
-		user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.MXID)
+	_, err := user.db.Exec(`UPDATE "user" SET username=$1, agent=$2, device=$3, management_room=$4, space_room=$5, phone_last_seen=$6, phone_last_pinged=$7, timezone=$8 WHERE mxid=$9`,
+		user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone, user.MXID)
 	if err != nil {
 		user.log.Warnfln("Failed to update %s: %v", user.MXID, err)
 	}

+ 50 - 0
historysync.go

@@ -72,6 +72,11 @@ func (user *User) handleHistorySyncsLoop() {
 	go user.handleBackfillRequestsLoop(user.BackfillQueue.DeferredBackfillRequests)
 	go user.BackfillQueue.RunLoop(user)
 
+	if user.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia &&
+		user.bridge.Config.Bridge.HistorySync.MediaRequests.RequestMethod == config.MediaRequestMethodLocalTime {
+		go user.dailyMediaRequestLoop()
+	}
+
 	// Always save the history syncs for the user. If they want to enable
 	// backfilling in the future, we will have it in the database.
 	for evt := range user.historySyncs {
@@ -79,6 +84,51 @@ func (user *User) handleHistorySyncsLoop() {
 	}
 }
 
+func (user *User) dailyMediaRequestLoop() {
+	// Calculate when to do the first set of media retry requests
+	now := time.Now()
+	userTz, err := time.LoadLocation(user.Timezone)
+	if err != nil {
+		userTz = now.Local().Location()
+	}
+	tonightMidnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, userTz)
+	midnightOffset := time.Duration(user.bridge.Config.Bridge.HistorySync.MediaRequests.RequestLocalTime) * time.Minute
+	requestStartTime := tonightMidnight.Add(midnightOffset)
+
+	// If the request time for today has already happened, we need to start the
+	// request loop tomorrow instead.
+	if requestStartTime.Before(now) {
+		requestStartTime = requestStartTime.AddDate(0, 0, 1)
+	}
+
+	// Wait to start the loop
+	user.log.Infof("Waiting until %s to do media retry requests", requestStartTime)
+	time.Sleep(time.Until(requestStartTime))
+
+	for {
+		mediaBackfillRequests := user.bridge.DB.MediaBackfillRequest.GetMediaBackfillRequestsForUser(user.MXID)
+		user.log.Infof("Sending %d media retry requests", len(mediaBackfillRequests))
+
+		// Send all of the media backfill requests for the user at once
+		for _, req := range mediaBackfillRequests {
+			portal := user.GetPortalByJID(req.PortalKey.JID)
+			_, err := portal.requestMediaRetry(user, req.EventID)
+			if err != nil {
+				user.log.Warnf("Failed to send media retry request for %s / %s", req.PortalKey.String(), req.EventID)
+				req.Status = database.MediaBackfillRequestStatusRequestFailed
+				req.Error = err.Error()
+			} else {
+				user.log.Debugfln("Sent media retry request for %s / %s", req.PortalKey.String(), req.EventID)
+				req.Status = database.MediaBackfillRequestStatusRequested
+			}
+			req.Upsert()
+		}
+
+		// Wait for 24 hours before making requests again
+		time.Sleep(24 * time.Hour)
+	}
+}
+
 func (user *User) handleBackfillRequestsLoop(backfillRequests chan *database.Backfill) {
 	for req := range backfillRequests {
 		user.log.Infofln("Handling backfill request %s", req)

+ 1 - 1
matrix.go

@@ -491,7 +491,7 @@ func (mx *MatrixHandler) HandleReaction(evt *event.Event) {
 
 	content := evt.Content.AsReaction()
 	if strings.Contains(content.RelatesTo.Key, "retry") || strings.HasPrefix(content.RelatesTo.Key, "\u267b") { // ♻️
-		if portal.requestMediaRetry(user, content.RelatesTo.EventID) {
+		if retryRequested, _ := portal.requestMediaRetry(user, content.RelatesTo.EventID); retryRequested {
 			_, _ = portal.MainIntent().RedactEvent(portal.MXID, evt.ID, mautrix.ReqRedact{
 				Reason: "requested media from phone",
 			})

+ 9 - 7
portal.go

@@ -2346,20 +2346,22 @@ func (portal *Portal) handleMediaRetry(retry *events.MediaRetry, source *User) {
 	msg.UpdateMXID(resp.EventID, database.MsgNormal, database.MsgNoError)
 }
 
-func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID) bool {
+func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID) (bool, error) {
 	msg := portal.bridge.DB.Message.GetByMXID(eventID)
 	if msg == nil {
-		portal.log.Debugfln("%s requested a media retry for unknown event %s", user.MXID, eventID)
-		return false
+		err := errors.New(fmt.Sprintf("%s requested a media retry for unknown event %s", user.MXID, eventID))
+		portal.log.Debugfln(err.Error())
+		return false, err
 	} else if msg.Error != database.MsgErrMediaNotFound {
-		portal.log.Debugfln("%s requested a media retry for non-errored event %s", user.MXID, eventID)
-		return false
+		err := errors.New(fmt.Sprintf("%s requested a media retry for non-errored event %s", user.MXID, eventID))
+		portal.log.Debugfln(err.Error())
+		return false, err
 	}
 
 	evt, err := portal.fetchMediaRetryEvent(msg)
 	if err != nil {
 		portal.log.Warnfln("Can't send media retry request for %s: %v", msg.JID, err)
-		return true
+		return true, nil
 	}
 
 	err = user.Client.SendMediaRetryReceipt(&types.MessageInfo{
@@ -2376,7 +2378,7 @@ func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID) bool {
 	} else {
 		portal.log.Debugfln("Sent media retry request for %s", msg.JID)
 	}
-	return true
+	return true, err
 }
 
 const thumbnailMaxSize = 72

+ 5 - 0
provisioning.go

@@ -513,6 +513,11 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 	userID := r.URL.Query().Get("user_id")
 	user := prov.bridge.GetUserByMXID(id.UserID(userID))
 
+	if userTimezone := r.URL.Query().Get("tz"); userTimezone != "" {
+		user.Timezone = userTimezone
+		user.Update()
+	}
+
 	c, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		prov.log.Errorln("Failed to upgrade connection to websocket:", err)

+ 1 - 0
user.go

@@ -431,6 +431,7 @@ func (user *User) DeleteSession() {
 	user.bridge.DB.Backfill.DeleteAll(user.MXID)
 	user.bridge.DB.HistorySync.DeleteAllConversations(user.MXID)
 	user.bridge.DB.HistorySync.DeleteAllMessages(user.MXID)
+	user.bridge.DB.MediaBackfillRequest.DeleteAllMediaBackfillRequests(user.MXID)
 }
 
 func (user *User) IsConnected() bool {