ソースを参照

Store outbound group sessions in database

Tulir Asokan 5 年 前
コミット
c9adb3aba3

+ 36 - 20
database/cryptostore.go

@@ -22,7 +22,6 @@ import (
 	"database/sql"
 	"fmt"
 	"strings"
-	"sync"
 
 	"github.com/lib/pq"
 	"github.com/pkg/errors"
@@ -44,9 +43,6 @@ type SQLCryptoStore struct {
 	Account   *crypto.OlmAccount
 
 	GhostIDFormat string
-
-	OGSLock          sync.RWMutex
-	OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession
 }
 
 var _ crypto.Store = (*SQLCryptoStore)(nil)
@@ -57,8 +53,6 @@ func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
 		log:       db.log.Sub("CryptoStore"),
 		PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
 		DeviceID:  deviceID,
-
-		OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession),
 	}
 }
 
@@ -255,24 +249,46 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
 	}, nil
 }
 
-func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error {
-	store.OGSLock.Lock()
-	store.OutGroupSessions[roomID] = session
-	store.OGSLock.Unlock()
-	return nil
+func (store *SQLCryptoStore) AddOutboundGroupSession(session *crypto.OutboundGroupSession) error {
+	sessionBytes := session.Internal.Pickle(store.PickleKey)
+	_, err := store.db.Exec("INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
+		session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime)
+	return err
+}
+
+func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *crypto.OutboundGroupSession) error {
+	sessionBytes := session.Internal.Pickle(store.PickleKey)
+	_, err := store.db.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5",
+		sessionBytes, session.MessageCount, session.UseTime, session.RoomID, session.ID())
+	return err
 }
 
 func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
-	store.OGSLock.RLock()
-	defer store.OGSLock.RUnlock()
-	return store.OutGroupSessions[roomID], nil
+	var ogs crypto.OutboundGroupSession
+	var sessionBytes []byte
+	err := store.db.QueryRow(`
+		SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
+		FROM crypto_megolm_outbound_session WHERE room_id=$1`,
+		roomID,
+	).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.UseTime)
+	if err == sql.ErrNoRows {
+		return nil, nil
+	} else if err != nil {
+		return nil, err
+	}
+	intOGS := olm.NewBlankOutboundGroupSession()
+	err = intOGS.Unpickle(sessionBytes, store.PickleKey)
+	if err != nil {
+		return nil, err
+	}
+	ogs.Internal = *intOGS
+	ogs.RoomID = roomID
+	return &ogs, nil
 }
 
-func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error {
-	store.OGSLock.Lock()
-	delete(store.OutGroupSessions, roomID)
-	store.OGSLock.Unlock()
-	return nil
+func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
+	_, err := store.db.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1", roomID)
+	return err
 }
 
 func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
@@ -389,7 +405,7 @@ func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
 			queryString[i] = fmt.Sprintf("$%d", i+1)
 			params[i] = user
 		}
-		rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN (" + strings.Join(queryString, ",") + ")", params...)
+		rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
 	}
 	if err != nil {
 		store.log.Warnln("Failed to filter tracked users:", err)

+ 26 - 0
database/upgrades/2020-05-12-outbound-group-session-store.go

@@ -0,0 +1,26 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[14] = upgrade{"Add outbound group sessions to database", func(tx *sql.Tx, ctx context) error {
+		// TODO use DATETIME instead of timestamp and BLOB instead of bytea for sqlite
+		_, err := tx.Exec(`CREATE TABLE crypto_megolm_outbound_session (
+			room_id       VARCHAR(255) PRIMARY KEY,
+			session_id    CHAR(43)     NOT NULL UNIQUE,
+			session       bytea        NOT NULL,
+			shared        BOOLEAN      NOT NULL,
+			max_messages  INTEGER      NOT NULL,
+			message_count INTEGER      NOT NULL,
+			max_age       BIGINT       NOT NULL,
+			created_at    timestamp    NOT NULL,
+			last_used     timestamp    NOT NULL
+		)`)
+		if err != nil {
+			return err
+		}
+		return nil
+	}}
+}

+ 1 - 1
database/upgrades/upgrades.go

@@ -28,7 +28,7 @@ type upgrade struct {
 	fn      upgradeFunc
 }
 
-const NumberOfUpgrades = 14
+const NumberOfUpgrades = 15
 
 var upgrades [NumberOfUpgrades]upgrade
 

+ 1 - 1
go.mod

@@ -15,7 +15,7 @@ require (
 	gopkg.in/yaml.v2 v2.2.8
 	maunium.net/go/mauflag v1.0.0
 	maunium.net/go/maulogger/v2 v2.1.1
-	maunium.net/go/mautrix v0.4.3
+	maunium.net/go/mautrix v0.4.4
 )
 
 replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.2.6

+ 2 - 0
go.sum

@@ -92,3 +92,5 @@ maunium.net/go/mautrix v0.4.2 h1:GBU++Z7o/fLPcEsNMkNOUsnDknwV/MGPQ0BN4ikK6tw=
 maunium.net/go/mautrix v0.4.2/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
 maunium.net/go/mautrix v0.4.3 h1:fVoJy992TjBEvuK5NeO9fpBh+9JuSFsxaEdGjFp/7h4=
 maunium.net/go/mautrix v0.4.3/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
+maunium.net/go/mautrix v0.4.4 h1:C5yYDzUdRtJj/9Vot5YBPQUsWmn19sTySew7f4ACLhM=
+maunium.net/go/mautrix v0.4.4/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=