|
@@ -22,7 +22,6 @@ import (
|
|
|
"database/sql"
|
|
|
"fmt"
|
|
|
"strings"
|
|
|
- "sync"
|
|
|
|
|
|
"github.com/lib/pq"
|
|
|
"github.com/pkg/errors"
|
|
@@ -44,9 +43,6 @@ type SQLCryptoStore struct {
|
|
|
Account *crypto.OlmAccount
|
|
|
|
|
|
GhostIDFormat string
|
|
|
-
|
|
|
- OGSLock sync.RWMutex
|
|
|
- OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession
|
|
|
}
|
|
|
|
|
|
var _ crypto.Store = (*SQLCryptoStore)(nil)
|
|
@@ -57,8 +53,6 @@ func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
|
|
|
log: db.log.Sub("CryptoStore"),
|
|
|
PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
|
|
|
DeviceID: deviceID,
|
|
|
-
|
|
|
- OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -255,24 +249,46 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
-func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error {
|
|
|
- store.OGSLock.Lock()
|
|
|
- store.OutGroupSessions[roomID] = session
|
|
|
- store.OGSLock.Unlock()
|
|
|
- return nil
|
|
|
+func (store *SQLCryptoStore) AddOutboundGroupSession(session *crypto.OutboundGroupSession) error {
|
|
|
+ sessionBytes := session.Internal.Pickle(store.PickleKey)
|
|
|
+ _, err := store.db.Exec("INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
|
|
|
+ session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime)
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *crypto.OutboundGroupSession) error {
|
|
|
+ sessionBytes := session.Internal.Pickle(store.PickleKey)
|
|
|
+ _, err := store.db.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5",
|
|
|
+ sessionBytes, session.MessageCount, session.UseTime, session.RoomID, session.ID())
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
|
|
|
- store.OGSLock.RLock()
|
|
|
- defer store.OGSLock.RUnlock()
|
|
|
- return store.OutGroupSessions[roomID], nil
|
|
|
+ var ogs crypto.OutboundGroupSession
|
|
|
+ var sessionBytes []byte
|
|
|
+ err := store.db.QueryRow(`
|
|
|
+ SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
|
|
|
+ FROM crypto_megolm_outbound_session WHERE room_id=$1`,
|
|
|
+ roomID,
|
|
|
+ ).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.UseTime)
|
|
|
+ if err == sql.ErrNoRows {
|
|
|
+ return nil, nil
|
|
|
+ } else if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ intOGS := olm.NewBlankOutboundGroupSession()
|
|
|
+ err = intOGS.Unpickle(sessionBytes, store.PickleKey)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ ogs.Internal = *intOGS
|
|
|
+ ogs.RoomID = roomID
|
|
|
+ return &ogs, nil
|
|
|
}
|
|
|
|
|
|
-func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error {
|
|
|
- store.OGSLock.Lock()
|
|
|
- delete(store.OutGroupSessions, roomID)
|
|
|
- store.OGSLock.Unlock()
|
|
|
- return nil
|
|
|
+func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
|
|
|
+ _, err := store.db.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1", roomID)
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
|
|
@@ -389,7 +405,7 @@ func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
|
|
|
queryString[i] = fmt.Sprintf("$%d", i+1)
|
|
|
params[i] = user
|
|
|
}
|
|
|
- rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN (" + strings.Join(queryString, ",") + ")", params...)
|
|
|
+ rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
|
|
|
}
|
|
|
if err != nil {
|
|
|
store.log.Warnln("Failed to filter tracked users:", err)
|