123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408 |
- // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
- // Copyright (C) 2020 Tulir Asokan
- //
- // This program is free software: you can redistribute it and/or modify
- // it under the terms of the GNU Affero General Public License as published by
- // the Free Software Foundation, either version 3 of the License, or
- // (at your option) any later version.
- //
- // This program is distributed in the hope that it will be useful,
- // but WITHOUT ANY WARRANTY; without even the implied warranty of
- // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- // GNU Affero General Public License for more details.
- //
- // You should have received a copy of the GNU Affero General Public License
- // along with this program. If not, see <https://www.gnu.org/licenses/>.
- // +build cgo
- package database
- import (
- "database/sql"
- "fmt"
- "strings"
- "sync"
- "github.com/lib/pq"
- "github.com/pkg/errors"
- log "maunium.net/go/maulogger/v2"
- "maunium.net/go/mautrix/crypto"
- "maunium.net/go/mautrix/crypto/olm"
- "maunium.net/go/mautrix/id"
- )
// SQLCryptoStore is a crypto.Store implementation that keeps Olm accounts,
// Olm sessions, inbound Megolm sessions, message indexes, tracked users and
// device lists in the bridge's SQL database. SQL that differs between backends
// is chosen by db.dialect ("postgres" or "sqlite3").
//
// Outbound Megolm sessions are the exception: they are held only in the
// in-memory OutGroupSessions map and are not written to the database.
type SQLCryptoStore struct {
	db  *Database
	log log.Logger

	// UserID is excluded from the results of GetRoomMembers.
	UserID id.UserID
	// DeviceID selects this store's row in crypto_account.
	DeviceID id.DeviceID
	// SyncToken caches the sync_token column of crypto_account.
	SyncToken string
	// PickleKey is the key passed to Pickle/Unpickle for every stored Olm/Megolm object.
	PickleKey []byte
	// Account caches the Olm account once GetAccount has loaded it or PutAccount has stored it.
	Account *crypto.OlmAccount
	// GhostIDFormat is a SQL LIKE pattern; matching user IDs are excluded from GetRoomMembers.
	GhostIDFormat string

	// OGSLock guards OutGroupSessions.
	OGSLock sync.RWMutex
	// OutGroupSessions holds outbound group sessions keyed by room ID (memory only).
	OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession
}

// Compile-time assertion that *SQLCryptoStore satisfies crypto.Store.
var _ crypto.Store = (*SQLCryptoStore)(nil)
- func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
- return &SQLCryptoStore{
- db: db,
- log: db.log.Sub("CryptoStore"),
- PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
- DeviceID: deviceID,
- OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession),
- }
- }
// FindDeviceID returns the device_id of a row in crypto_account (the query
// uses LIMIT 1), or an empty ID when the table has no rows. Any other query
// error is logged and not returned.
func (db *Database) FindDeviceID() (deviceID id.DeviceID) {
	err := db.QueryRow("SELECT device_id FROM crypto_account LIMIT 1").Scan(&deviceID)
	if err == nil || err == sql.ErrNoRows {
		return deviceID
	}
	db.log.Warnln("Failed to scan device ID:", err)
	return deviceID
}
// GetRoomMembers returns the IDs of users whose mx_user_profile membership in
// roomID is 'join' or 'invite'. The result excludes the store's own UserID and
// any user ID matching the GhostIDFormat LIKE pattern.
//
// Rows that fail to scan are logged and skipped. An error (with nil members)
// is returned if the query fails or if iterating over the result set aborts.
func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.UserID, err error) {
	var rows *sql.Rows
	rows, err = store.db.Query(`
		SELECT user_id FROM mx_user_profile
		WHERE room_id=$1
			AND (membership='join' OR membership='invite')
			AND user_id<>$2
			AND user_id NOT LIKE $3
	`, roomID, store.UserID, store.GhostIDFormat)
	if err != nil {
		return nil, err
	}
	// Always release the underlying connection, even on early return.
	defer rows.Close()
	for rows.Next() {
		var userID id.UserID
		if scanErr := rows.Scan(&userID); scanErr != nil {
			store.log.Warnfln("Failed to scan member in %s: %v", roomID, scanErr)
			continue
		}
		members = append(members, userID)
	}
	// rows.Next returning false can mean an error rather than the end of results.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return members, nil
}
- func (store *SQLCryptoStore) Flush() error {
- return nil
- }
- func (store *SQLCryptoStore) PutNextBatch(nextBatch string) {
- store.SyncToken = nextBatch
- _, err := store.db.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE device_id=$2`, store.SyncToken, store.DeviceID)
- if err != nil {
- store.log.Warnln("Failed to store sync token:", err)
- }
- }
// GetNextBatch returns the cached sync token. While the cache is empty, each
// call tries to load it from this device's crypto_account row. Load errors
// other than a missing row are logged and not returned.
func (store *SQLCryptoStore) GetNextBatch() string {
	if store.SyncToken != "" {
		return store.SyncToken
	}
	row := store.db.QueryRow("SELECT sync_token FROM crypto_account WHERE device_id=$1", store.DeviceID)
	if err := row.Scan(&store.SyncToken); err != nil && err != sql.ErrNoRows {
		store.log.Warnln("Failed to scan sync token:", err)
	}
	return store.SyncToken
}
// PutAccount caches account in the store and upserts its pickled form into
// this device's crypto_account row. The row also receives the account's shared
// flag and the currently cached sync token.
//
// It returns an error if the database dialect is unsupported or the write
// fails. Previously such failures were logged and nil was returned, so callers
// could not tell that the account was not persisted. The in-memory cache is
// updated even when the write fails.
func (store *SQLCryptoStore) PutAccount(account *crypto.OlmAccount) error {
	store.Account = account
	bytes := account.Internal.Pickle(store.PickleKey)
	var err error
	switch store.db.dialect {
	case "postgres":
		_, err = store.db.Exec(`
			INSERT INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)
			ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`,
			store.DeviceID, account.Shared, store.SyncToken, bytes)
	case "sqlite3":
		_, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)",
			store.DeviceID, account.Shared, store.SyncToken, bytes)
	default:
		err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
	}
	if err != nil {
		return errors.Wrap(err, "failed to store account")
	}
	return nil
}
// GetAccount returns the cached Olm account. On first use it loads and
// unpickles the account from this device's crypto_account row. It returns
// (nil, nil) when no such row exists. Loading also overwrites the cached sync
// token with the stored one.
func (store *SQLCryptoStore) GetAccount() (*crypto.OlmAccount, error) {
	if store.Account != nil {
		return store.Account, nil
	}
	row := store.db.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE device_id=$1", store.DeviceID)
	loaded := &crypto.OlmAccount{Internal: *olm.NewBlankAccount()}
	var pickled []byte
	err := row.Scan(&loaded.Shared, &store.SyncToken, &pickled)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}
	if err = loaded.Internal.Unpickle(pickled, store.PickleKey); err != nil {
		return nil, err
	}
	store.Account = loaded
	return store.Account, nil
}
// HasSession reports whether any Olm session row exists for the given sender
// key. Query errors other than sql.ErrNoRows are not reported. In that case
// the answer depends only on whether a non-empty session ID was scanned.
func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
	// TODO this may need to be changed if olm sessions start expiring
	var sessionID id.SessionID
	err := store.db.
		QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 LIMIT 1", key).
		Scan(&sessionID)
	return err != sql.ErrNoRows && len(sessionID) > 0
}
- func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (crypto.OlmSessionList, error) {
- rows, err := store.db.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id", key)
- if err != nil {
- return nil, err
- }
- list := crypto.OlmSessionList{}
- for rows.Next() {
- sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
- var sessionBytes []byte
- err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
- if err != nil {
- return nil, err
- }
- err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
- if err != nil {
- return nil, err
- }
- list = append(list, &sess)
- }
- return list, nil
- }
// GetLatestSession returns the Olm session for the given sender key with the
// greatest session_id, or (nil, nil) if none is stored. The returned session
// is nil whenever an error is returned.
//
// NOTE(review): ordering by session_id is not obviously chronological; confirm
// whether "latest" should instead order by created_at or last_used.
func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*crypto.OlmSession, error) {
	row := store.db.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id DESC LIMIT 1", key)
	sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
	var sessionBytes []byte
	err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
	if err == sql.ErrNoRows {
		return nil, nil
	} else if err != nil {
		return nil, err
	}
	// Never hand back a partially initialized session together with an error.
	if err = sess.Internal.Unpickle(sessionBytes, store.PickleKey); err != nil {
		return nil, err
	}
	return &sess, nil
}
- func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *crypto.OlmSession) error {
- sessionBytes := session.Internal.Pickle(store.PickleKey)
- _, err := store.db.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used) VALUES ($1, $2, $3, $4, $5)",
- session.ID(), key, sessionBytes, session.CreationTime, session.UseTime)
- return err
- }
- func (store *SQLCryptoStore) UpdateSession(key id.SenderKey, session *crypto.OlmSession) error {
- sessionBytes := session.Internal.Pickle(store.PickleKey)
- _, err := store.db.Exec("UPDATE crypto_olm_session SET session=$1, last_used=$2 WHERE session_id=$3",
- sessionBytes, session.UseTime, session.ID())
- return err
- }
- func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *crypto.InboundGroupSession) error {
- sessionBytes := session.Internal.Pickle(store.PickleKey)
- forwardingChains := strings.Join(session.ForwardingChains, ",")
- _, err := store.db.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, signing_key, room_id, session, forwarding_chains) VALUES ($1, $2, $3, $4, $5, $6)",
- sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains)
- return err
- }
// GetGroupSession loads and unpickles the inbound Megolm session identified by
// room, sender key and session ID. It returns (nil, nil) if no such session is
// stored. The comma-separated forwarding chain column is split back into a
// slice; an empty column yields an empty (non-nil) chain.
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
	var signingKey id.Ed25519
	var sessionBytes []byte
	var forwardingChains string
	err := store.db.QueryRow(`
		SELECT signing_key, session, forwarding_chains
		FROM crypto_megolm_inbound_session
		WHERE room_id=$1 AND sender_key=$2 AND session_id=$3`,
		roomID, senderKey, sessionID,
	).Scan(&signingKey, &sessionBytes, &forwardingChains)
	if err == sql.ErrNoRows {
		return nil, nil
	} else if err != nil {
		return nil, err
	}
	igs := olm.NewBlankInboundGroupSession()
	if err = igs.Unpickle(sessionBytes, store.PickleKey); err != nil {
		return nil, err
	}
	// strings.Split("", ",") returns [""], which would turn "no forwarding
	// chain" (stored as "") into a chain containing one empty key.
	chains := []string{}
	if forwardingChains != "" {
		chains = strings.Split(forwardingChains, ",")
	}
	return &crypto.InboundGroupSession{
		Internal:         *igs,
		SigningKey:       signingKey,
		SenderKey:        senderKey,
		RoomID:           roomID,
		ForwardingChains: chains,
	}, nil
}
- func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error {
- store.OGSLock.Lock()
- store.OutGroupSessions[roomID] = session
- store.OGSLock.Unlock()
- return nil
- }
- func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
- store.OGSLock.RLock()
- defer store.OGSLock.RUnlock()
- return store.OutGroupSessions[roomID], nil
- }
- func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error {
- store.OGSLock.Lock()
- delete(store.OutGroupSessions, roomID)
- store.OGSLock.Unlock()
- return nil
- }
// ValidateMessageIndex records the first event ID and timestamp seen for each
// (sender key, session ID, message index) triple. It returns false only when
// a later call for the same triple carries a different event ID or timestamp.
// It returns true for first sightings and on database errors, which are
// logged rather than returned.
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
	var (
		knownEventID   id.EventID
		knownTimestamp int64
	)
	err := store.db.QueryRow(
		`SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3`,
		senderKey, sessionID, index,
	).Scan(&knownEventID, &knownTimestamp)
	switch {
	case err == sql.ErrNoRows:
		// First time this index is seen: remember which event used it.
		_, insertErr := store.db.Exec(`INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5)`,
			senderKey, sessionID, index, eventID, timestamp)
		if insertErr != nil {
			store.log.Warnln("Failed to store message index:", insertErr)
		}
		return true
	case err != nil:
		store.log.Warnln("Failed to scan message index:", err)
		return true
	default:
		return knownEventID == eventID && knownTimestamp == timestamp
	}
}
// GetDevices returns the stored device identities of userID, keyed by device
// ID.
//
// It returns (nil, nil) if userID is not in crypto_tracked_user. A tracked
// user with no stored devices yields an empty, non-nil map. Any query, scan or
// iteration error is returned with a nil map.
func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*crypto.DeviceIdentity, error) {
	var ignore id.UserID
	err := store.db.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
	if err == sql.ErrNoRows {
		return nil, nil
	} else if err != nil {
		return nil, err
	}
	rows, err := store.db.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1", userID)
	if err != nil {
		return nil, err
	}
	// Release the connection on every return path, including mid-loop errors.
	defer rows.Close()
	data := make(map[id.DeviceID]*crypto.DeviceIdentity)
	for rows.Next() {
		var identity crypto.DeviceIdentity
		if err = rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name); err != nil {
			return nil, err
		}
		identity.UserID = userID
		data[identity.DeviceID] = &identity
	}
	// rows.Next returning false can mean an error rather than the end of results.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return data, nil
}
- func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*crypto.DeviceIdentity) error {
- tx, err := store.db.Begin()
- if err != nil {
- return err
- }
- if store.db.dialect == "postgres" {
- _, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
- } else if store.db.dialect == "sqlite3" {
- _, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_user (user_id) VALUES ($1)", userID)
- } else {
- err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
- }
- if err != nil {
- return errors.Wrap(err, "failed to add user to tracked users list")
- }
- _, err = tx.Exec("DELETE FROM crypto_device WHERE user_id=$1", userID)
- if err != nil {
- _ = tx.Rollback()
- return errors.Wrap(err, "failed to delete old devices")
- }
- if len(devices) == 0 {
- err = tx.Commit()
- if err != nil {
- return errors.Wrap(err, "failed to commit changes (no devices added)")
- }
- return nil
- }
- // TODO do this in batches to avoid too large db queries
- values := make([]interface{}, 1, len(devices)*6+1)
- values[0] = userID
- valueStrings := make([]string, 0, len(devices))
- i := 2
- for deviceID, identity := range devices {
- values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
- valueStrings = append(valueStrings, fmt.Sprintf("($1, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5))
- i += 6
- }
- valueString := strings.Join(valueStrings, ",")
- _, err = tx.Exec("INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES "+valueString, values...)
- if err != nil {
- _ = tx.Rollback()
- return errors.Wrap(err, "failed to insert new devices")
- }
- err = tx.Commit()
- if err != nil {
- return errors.Wrap(err, "failed to commit changes")
- }
- return nil
- }
// FilterTrackedUsers returns the subset of users that appear in
// crypto_tracked_user, in the order the database returns them. The input
// slice is not modified; a new slice is returned.
//
// If the query fails, the error is logged and users is returned unfiltered.
// Rows that fail to scan are logged and skipped. An iteration error is logged
// and the users collected so far are returned.
func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
	if len(users) == 0 {
		return users
	}
	var rows *sql.Rows
	var err error
	if store.db.dialect == "postgres" {
		rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users))
	} else {
		placeholders := make([]string, len(users))
		params := make([]interface{}, len(users))
		for i, user := range users {
			placeholders[i] = fmt.Sprintf("$%d", i+1)
			params[i] = user
		}
		rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(placeholders, ",")+")", params...)
	}
	if err != nil {
		store.log.Warnln("Failed to filter tracked users:", err)
		return users
	}
	// Release the connection on every return path.
	defer rows.Close()
	// Scan into a fresh slice; scanning into users[i] would overwrite the
	// caller's backing array.
	tracked := make([]id.UserID, 0, len(users))
	for rows.Next() {
		var userID id.UserID
		if err = rows.Scan(&userID); err != nil {
			store.log.Warnln("Failed to scan tracked user ID:", err)
			continue
		}
		tracked = append(tracked, userID)
	}
	// rows.Next returning false can mean an error rather than the end of results.
	if err = rows.Err(); err != nil {
		store.log.Warnln("Failed to iterate tracked users:", err)
	}
	return tracked
}
|