Browse Source

Send blank protocol message if phone is offline for too long

Tulir Asokan 3 years ago
parent
commit
b389354bcc
8 changed files with 69 additions and 23 deletions
  1. 10 0
      database/upgrades/2022-02-18-phone-ping-ts.go
  2. 1 1
      database/upgrades/upgrades.go
  3. 27 15
      database/user.go
  4. 1 1
      go.mod
  5. 2 2
      go.sum
  6. 1 1
      main.go
  7. 3 0
      portal.go
  8. 24 3
      user.go

+ 10 - 0
database/upgrades/2022-02-18-phone-ping-ts.go

@@ -0,0 +1,10 @@
+package upgrades
+
+import "database/sql"
+
+func init() {
+	upgrades[37] = upgrade{"Store timestamp for previous phone ping", func(tx *sql.Tx, ctx context) error {
+		_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN phone_last_pinged BIGINT`)
+		return err
+	}}
+}

+ 1 - 1
database/upgrades/upgrades.go

@@ -40,7 +40,7 @@ type upgrade struct {
 	fn      upgradeFunc
 	fn      upgradeFunc
 }
 }
 
 
-const NumberOfUpgrades = 37
+const NumberOfUpgrades = 38
 
 
 var upgrades [NumberOfUpgrades]upgrade
 var upgrades [NumberOfUpgrades]upgrade
 
 

+ 27 - 15
database/user.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2021 Tulir Asokan
+// Copyright (C) 2022 Tulir Asokan
 //
 //
 // This program is free software: you can redistribute it and/or modify
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
 // it under the terms of the GNU Affero General Public License as published by
@@ -44,7 +44,7 @@ func (uq *UserQuery) New() *User {
 }
 }
 
 
 func (uq *UserQuery) GetAll() (users []*User) {
 func (uq *UserQuery) GetAll() (users []*User) {
-	rows, err := uq.db.Query(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen FROM "user"`)
+	rows, err := uq.db.Query(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged FROM "user"`)
 	if err != nil || rows == nil {
 	if err != nil || rows == nil {
 		return nil
 		return nil
 	}
 	}
@@ -56,7 +56,7 @@ func (uq *UserQuery) GetAll() (users []*User) {
 }
 }
 
 
 func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
 func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
-	row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen FROM "user" WHERE mxid=$1`, userID)
+	row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged FROM "user" WHERE mxid=$1`, userID)
 	if row == nil {
 	if row == nil {
 		return nil
 		return nil
 	}
 	}
@@ -64,7 +64,7 @@ func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
 }
 }
 
 
 func (uq *UserQuery) GetByUsername(username string) *User {
 func (uq *UserQuery) GetByUsername(username string) *User {
-	row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen FROM "user" WHERE username=$1`, username)
+	row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged FROM "user" WHERE username=$1`, username)
 	if row == nil {
 	if row == nil {
 		return nil
 		return nil
 	}
 	}
@@ -75,11 +75,12 @@ type User struct {
 	db  *Database
 	db  *Database
 	log log.Logger
 	log log.Logger
 
 
-	MXID           id.UserID
-	JID            types.JID
-	ManagementRoom id.RoomID
-	SpaceRoom      id.RoomID
-	PhoneLastSeen  time.Time
+	MXID            id.UserID
+	JID             types.JID
+	ManagementRoom  id.RoomID
+	SpaceRoom       id.RoomID
+	PhoneLastSeen   time.Time
+	PhoneLastPinged time.Time
 
 
 	lastReadCache     map[PortalKey]time.Time
 	lastReadCache     map[PortalKey]time.Time
 	lastReadCacheLock sync.Mutex
 	lastReadCacheLock sync.Mutex
@@ -90,8 +91,8 @@ type User struct {
 func (user *User) Scan(row Scannable) *User {
 func (user *User) Scan(row Scannable) *User {
 	var username sql.NullString
 	var username sql.NullString
 	var device, agent sql.NullByte
 	var device, agent sql.NullByte
-	var phoneLastSeen sql.NullInt64
-	err := row.Scan(&user.MXID, &username, &agent, &device, &user.ManagementRoom, &user.SpaceRoom, &phoneLastSeen)
+	var phoneLastSeen, phoneLastPinged sql.NullInt64
+	err := row.Scan(&user.MXID, &username, &agent, &device, &user.ManagementRoom, &user.SpaceRoom, &phoneLastSeen, &phoneLastPinged)
 	if err != nil {
 	if err != nil {
 		if err != sql.ErrNoRows {
 		if err != sql.ErrNoRows {
 			user.log.Errorln("Database scan failed:", err)
 			user.log.Errorln("Database scan failed:", err)
@@ -104,6 +105,9 @@ func (user *User) Scan(row Scannable) *User {
 	if phoneLastSeen.Valid {
 	if phoneLastSeen.Valid {
 		user.PhoneLastSeen = time.Unix(phoneLastSeen.Int64, 0)
 		user.PhoneLastSeen = time.Unix(phoneLastSeen.Int64, 0)
 	}
 	}
+	if phoneLastPinged.Valid {
+		user.PhoneLastPinged = time.Unix(phoneLastPinged.Int64, 0)
+	}
 	return user
 	return user
 }
 }
 
 
@@ -136,17 +140,25 @@ func (user *User) phoneLastSeenPtr() *int64 {
 	return &ts
 	return &ts
 }
 }
 
 
+func (user *User) phoneLastPingedPtr() *int64 {
+	if user.PhoneLastPinged.IsZero() {
+		return nil
+	}
+	ts := user.PhoneLastPinged.Unix()
+	return &ts
+}
+
 func (user *User) Insert() {
 func (user *User) Insert() {
-	_, err := user.db.Exec(`INSERT INTO "user" (mxid, username, agent, device, management_room, space_room, phone_last_seen) VALUES ($1, $2, $3, $4, $5, $6, $7)`,
-		user.MXID, user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr())
+	_, err := user.db.Exec(`INSERT INTO "user" (mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
+		user.MXID, user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr())
 	if err != nil {
 	if err != nil {
 		user.log.Warnfln("Failed to insert %s: %v", user.MXID, err)
 		user.log.Warnfln("Failed to insert %s: %v", user.MXID, err)
 	}
 	}
 }
 }
 
 
 func (user *User) Update() {
 func (user *User) Update() {
-	_, err := user.db.Exec(`UPDATE "user" SET username=$1, agent=$2, device=$3, management_room=$4, space_room=$5, phone_last_seen=$6 WHERE mxid=$7`,
-		user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.MXID)
+	_, err := user.db.Exec(`UPDATE "user" SET username=$1, agent=$2, device=$3, management_room=$4, space_room=$5, phone_last_seen=$6, phone_last_pinged=$7 WHERE mxid=$8`,
+		user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.MXID)
 	if err != nil {
 	if err != nil {
 		user.log.Warnfln("Failed to update %s: %v", user.MXID, err)
 		user.log.Warnfln("Failed to update %s: %v", user.MXID, err)
 	}
 	}

+ 1 - 1
go.mod

@@ -10,7 +10,7 @@ require (
 	github.com/prometheus/client_golang v1.11.1
 	github.com/prometheus/client_golang v1.11.1
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/tidwall/gjson v1.14.0
 	github.com/tidwall/gjson v1.14.0
-	go.mau.fi/whatsmeow v0.0.0-20220217133111-7d4c399d0640
+	go.mau.fi/whatsmeow v0.0.0-20220218100006-2613ad3a11a2
 	golang.org/x/image v0.0.0-20211028202545-6944b10bf410
 	golang.org/x/image v0.0.0-20211028202545-6944b10bf410
 	golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
 	golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
 	google.golang.org/protobuf v1.27.1
 	google.golang.org/protobuf v1.27.1

+ 2 - 2
go.sum

@@ -120,8 +120,8 @@ github.com/tidwall/sjson v1.2.4 h1:cuiLzLnaMeBhRmEv00Lpk3tkYrcxpmbU81tAY4Dw0tc=
 github.com/tidwall/sjson v1.2.4/go.mod h1:098SZ494YoMWPmMO6ct4dcFnqxwj9r/gF0Etp19pSNM=
 github.com/tidwall/sjson v1.2.4/go.mod h1:098SZ494YoMWPmMO6ct4dcFnqxwj9r/gF0Etp19pSNM=
 go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910 h1:9FFhG0OmkuMau5UEaTgiUQ+7cSbtbOQ7hiWKdN8OI3I=
 go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910 h1:9FFhG0OmkuMau5UEaTgiUQ+7cSbtbOQ7hiWKdN8OI3I=
 go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910/go.mod h1:AufGrvVh+00Nc07Jm4hTquh7yleZyn20tKJI2wCPAKg=
 go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910/go.mod h1:AufGrvVh+00Nc07Jm4hTquh7yleZyn20tKJI2wCPAKg=
-go.mau.fi/whatsmeow v0.0.0-20220217133111-7d4c399d0640 h1:8WEXxj18qt6B8KhCW510qtNZjQUiqV2u3nvhNy8HV30=
-go.mau.fi/whatsmeow v0.0.0-20220217133111-7d4c399d0640/go.mod h1:NNI4Ah/B27mfQNChJMD1iSO8+HS+fQ4WqNuQ8Mh2/XI=
+go.mau.fi/whatsmeow v0.0.0-20220218100006-2613ad3a11a2 h1:KPN+bsDm9EQtHFph1rd4h+0UNK0fJTI4ilWIfytK278=
+go.mau.fi/whatsmeow v0.0.0-20220218100006-2613ad3a11a2/go.mod h1:NNI4Ah/B27mfQNChJMD1iSO8+HS+fQ4WqNuQ8Mh2/XI=
 golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=

+ 1 - 1
main.go

@@ -348,7 +348,7 @@ func (bridge *Bridge) Loop() {
 func (bridge *Bridge) WarnUsersAboutDisconnection() {
 func (bridge *Bridge) WarnUsersAboutDisconnection() {
 	bridge.usersLock.Lock()
 	bridge.usersLock.Lock()
 	for _, user := range bridge.usersByUsername {
 	for _, user := range bridge.usersByUsername {
-		if user.IsConnected() && !user.PhoneRecentlySeen() {
+		if user.IsConnected() && !user.PhoneRecentlySeen(true) {
 			go user.sendPhoneOfflineWarning()
 			go user.sendPhoneOfflineWarning()
 		}
 		}
 	}
 	}

+ 3 - 0
portal.go

@@ -359,6 +359,9 @@ func getMessageType(waMsg *waProto.Message) string {
 	case waMsg.ProtocolMessage != nil:
 	case waMsg.ProtocolMessage != nil:
 		switch waMsg.GetProtocolMessage().GetType() {
 		switch waMsg.GetProtocolMessage().GetType() {
 		case waProto.ProtocolMessage_REVOKE:
 		case waProto.ProtocolMessage_REVOKE:
+			if waMsg.GetProtocolMessage().GetKey() == nil {
+				return "ignore"
+			}
 			return "revoke"
 			return "revoke"
 		case waProto.ProtocolMessage_EPHEMERAL_SETTING:
 		case waProto.ProtocolMessage_EPHEMERAL_SETTING:
 			return "disappearing timer change"
 			return "disappearing timer change"

+ 24 - 3
user.go

@@ -475,8 +475,29 @@ func (user *User) handleCallStart(sender types.JID, id, callType string, ts time
 }
 }
 
 
 const PhoneDisconnectWarningTime = 12 * 24 * time.Hour // 12 days
 const PhoneDisconnectWarningTime = 12 * 24 * time.Hour // 12 days
+const PhoneDisconnectPingTime = 10 * 24 * time.Hour
+const PhoneMinPingInterval = 24 * time.Hour
+
+func (user *User) sendHackyPhonePing() {
+	msgID := whatsmeow.GenerateMessageID()
+	user.PhoneLastPinged = time.Now()
+	ts, err := user.Client.SendMessage(user.JID.ToNonAD(), msgID, &waProto.Message{
+		ProtocolMessage: &waProto.ProtocolMessage{},
+	})
+	if err != nil {
+		user.log.Warnfln("Failed to send hacky phone ping: %v", err)
+	} else {
+		user.log.Debugfln("Sent hacky phone ping %s/%s because phone has been offline for >10 days", msgID, ts)
+		user.PhoneLastPinged = ts
+		user.Update()
+	}
+}
 
 
-func (user *User) PhoneRecentlySeen() bool {
+func (user *User) PhoneRecentlySeen(doPing bool) bool {
+	if doPing && !user.PhoneLastSeen.IsZero() && user.PhoneLastSeen.Add(PhoneDisconnectPingTime).Before(time.Now()) && user.PhoneLastPinged.Add(PhoneMinPingInterval).Before(time.Now()) {
+		// Over 10 days since the phone was seen and over a day since the last somewhat hacky ping, send a new ping.
+		go user.sendHackyPhonePing()
+	}
 	return user.PhoneLastSeen.IsZero() || user.PhoneLastSeen.Add(PhoneDisconnectWarningTime).After(time.Now())
 	return user.PhoneLastSeen.IsZero() || user.PhoneLastSeen.Add(PhoneDisconnectWarningTime).After(time.Now())
 }
 }
 
 
@@ -487,7 +508,7 @@ func (user *User) phoneSeen(ts time.Time) {
 		// The last seen timestamp isn't going to be perfectly accurate in any case,
 		// The last seen timestamp isn't going to be perfectly accurate in any case,
 		// so don't spam the database with an update every time there's an event.
 		// so don't spam the database with an update every time there's an event.
 		return
 		return
-	} else if !user.PhoneRecentlySeen() && user.GetPrevBridgeState().Error == WAPhoneOffline && user.IsConnected() {
+	} else if !user.PhoneRecentlySeen(false) && user.GetPrevBridgeState().Error == WAPhoneOffline && user.IsConnected() {
 		user.log.Debugfln("Saw phone after current bridge state said it has been offline, switching state back to connected")
 		user.log.Debugfln("Saw phone after current bridge state said it has been offline, switching state back to connected")
 		go user.sendBridgeState(BridgeState{StateEvent: StateConnected})
 		go user.sendBridgeState(BridgeState{StateEvent: StateConnected})
 	}
 	}
@@ -543,7 +564,7 @@ func (user *User) HandleEvent(event interface{}) {
 			Message:    fmt.Sprintf("backfilling %d messages and %d receipts", v.Messages, v.Receipts),
 			Message:    fmt.Sprintf("backfilling %d messages and %d receipts", v.Messages, v.Receipts),
 		})
 		})
 	case *events.OfflineSyncCompleted:
 	case *events.OfflineSyncCompleted:
-		if !user.PhoneRecentlySeen() {
+		if !user.PhoneRecentlySeen(true) {
 			user.log.Infofln("Offline sync completed, but phone last seen date is still %s - sending phone offline bridge status", user.PhoneLastSeen)
 			user.log.Infofln("Offline sync completed, but phone last seen date is still %s - sending phone offline bridge status", user.PhoneLastSeen)
 			go user.sendBridgeState(BridgeState{StateEvent: StateTransientDisconnect, Error: WAPhoneOffline})
 			go user.sendBridgeState(BridgeState{StateEvent: StateTransientDisconnect, Error: WAPhoneOffline})
 		} else if user.GetPrevBridgeState().StateEvent == StateBackfilling {
 		} else if user.GetPrevBridgeState().StateEvent == StateBackfilling {