Эх сурвалжийг харах

Move crypto store to main database

Tulir Asokan 5 жил өмнө
parent
commit
dfc5722a80

+ 73 - 51
crypto.go

@@ -28,6 +28,7 @@ import (
 	"maunium.net/go/maulogger/v2"
 
 	"maunium.net/go/mautrix"
+	"maunium.net/go/mautrix-whatsapp/database"
 	"maunium.net/go/mautrix/crypto"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/id"
@@ -40,13 +41,15 @@ var levelTrace = maulogger.Level{
 }
 
 type CryptoHelper struct {
-	bridge *Bridge
-	client *mautrix.Client
-	mach   *crypto.OlmMachine
-	log    maulogger.Logger
+	bridge  *Bridge
+	client  *mautrix.Client
+	mach    *crypto.OlmMachine
+	store   *database.SQLCryptoStore
+	log     maulogger.Logger
+	baseLog maulogger.Logger
 }
 
-func (bridge *Bridge) initCrypto() error {
+func NewCryptoHelper(bridge *Bridge) *CryptoHelper {
 	if !bridge.Config.Bridge.Encryption.Allow {
 		bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config")
 		return nil
@@ -54,74 +57,74 @@ func (bridge *Bridge) initCrypto() error {
 		bridge.Log.Warnln("End-to-bridge encryption enabled, but login_shared_secret not set")
 		return nil
 	}
-	bridge.Log.Debugln("Initializing end-to-bridge encryption...")
-	client, err := bridge.loginBot()
-	if err != nil {
-		return err
-	}
-	// TODO put this in the database
-	cryptoStore, err := crypto.NewGobStore("crypto.gob")
-	if err != nil {
-		return err
+	baseLog := bridge.Log.Sub("Crypto")
+	return &CryptoHelper{
+		bridge:  bridge,
+		log:     baseLog.Sub("Helper"),
+		baseLog: baseLog,
 	}
+}
 
-	log := bridge.Log.Sub("Crypto")
-	logger := &cryptoLogger{log}
-	stateStore := &cryptoStateStore{bridge}
-	helper := &CryptoHelper{
-		bridge: bridge,
-		client: client,
-		log: log.Sub("Helper"),
-		mach: crypto.NewOlmMachine(client, logger, cryptoStore, stateStore),
-	}
-
-	client.Logger = logger.int.Sub("Bot")
-	client.Syncer = &cryptoSyncer{helper.mach}
-	// TODO put this in the database too
-	client.Store = mautrix.NewInMemoryStore()
-
-	err = helper.mach.Load()
+func (helper *CryptoHelper) Init() error {
+	helper.log.Debugln("Initializing end-to-bridge encryption...")
+	var err error
+	helper.client, err = helper.loginBot()
 	if err != nil {
 		return err
 	}
 
-	bridge.Crypto = helper
-	return nil
-}
+	helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID)
+	logger := &cryptoLogger{helper.baseLog}
+	stateStore := &cryptoStateStore{helper.bridge}
+	helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.client.DeviceID)
+	helper.store.UserID = helper.client.UserID
+	helper.store.GhostIDFormat = helper.bridge.Config.Bridge.FormatUsername("%")
+	helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore)
 
-func (helper *CryptoHelper) Start() {
-	helper.log.Debugln("Starting syncer for receiving to-device messages")
-	err := helper.client.Sync()
-	if err != nil {
-		helper.log.Errorln("Fatal error syncing:", err)
-	}
-}
+	helper.client.Logger = logger.int.Sub("Bot")
+	helper.client.Syncer = &cryptoSyncer{helper.mach}
+	helper.client.Store = &cryptoClientStore{helper.store}
 
-func (helper *CryptoHelper) Stop() {
-	helper.client.StopSync()
+	return helper.mach.Load()
 }
 
-func (bridge *Bridge) loginBot() (*mautrix.Client, error) {
-	mac := hmac.New(sha512.New, []byte(bridge.Config.Bridge.LoginSharedSecret))
-	mac.Write([]byte(bridge.AS.BotMXID()))
-	resp, err := bridge.AS.BotClient().Login(&mautrix.ReqLogin{
+func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) {
+	deviceID := helper.bridge.DB.FindDeviceID()
+	if len(deviceID) > 0 {
+		helper.log.Debugln("Found existing device ID for bot in database:", deviceID)
+	}
+	mac := hmac.New(sha512.New, []byte(helper.bridge.Config.Bridge.LoginSharedSecret))
+	mac.Write([]byte(helper.bridge.AS.BotMXID()))
+	resp, err := helper.bridge.AS.BotClient().Login(&mautrix.ReqLogin{
 		Type:                     "m.login.password",
-		Identifier:               mautrix.UserIdentifier{Type: "m.id.user", User: string(bridge.AS.BotMXID())},
+		Identifier:               mautrix.UserIdentifier{Type: "m.id.user", User: string(helper.bridge.AS.BotMXID())},
 		Password:                 hex.EncodeToString(mac.Sum(nil)),
-		DeviceID:                 "WhatsApp Bridge",
+		DeviceID:                 deviceID,
 		InitialDeviceDisplayName: "WhatsApp Bridge",
 	})
 	if err != nil {
 		return nil, err
 	}
-	client, err := mautrix.NewClient(bridge.AS.HomeserverURL, bridge.AS.BotMXID(), resp.AccessToken)
+	client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, helper.bridge.AS.BotMXID(), resp.AccessToken)
 	if err != nil {
 		return nil, err
 	}
-	client.DeviceID = "WhatsApp Bridge"
+	client.DeviceID = resp.DeviceID
 	return client, nil
 }
 
+func (helper *CryptoHelper) Start() {
+	helper.log.Debugln("Starting syncer for receiving to-device messages")
+	err := helper.client.Sync()
+	if err != nil {
+		helper.log.Errorln("Fatal error syncing:", err)
+	}
+}
+
+func (helper *CryptoHelper) Stop() {
+	helper.client.StopSync()
+}
+
 func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
 	return helper.mach.DecryptMegolmEvent(evt)
 }
@@ -133,7 +136,7 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten
 			return nil, err
 		}
 		helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID)
-		users, err := helper.bridge.StateStore.GetRoomMemberList(roomID)
+		users, err := helper.store.GetRoomMembers(roomID)
 		if err != nil {
 			return nil, errors.Wrap(err, "failed to get room member list")
 		}
@@ -202,6 +205,25 @@ func (c *cryptoLogger) Trace(message string, args ...interface{}) {
 	c.int.Logfln(levelTrace, message, args...)
 }
 
+type cryptoClientStore struct {
+	int *database.SQLCryptoStore
+}
+
+func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {}
+func (c cryptoClientStore) LoadFilterID(_ id.UserID) string    { return "" }
+func (c cryptoClientStore) SaveRoom(_ *mautrix.Room)           {}
+func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil }
+
+func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
+	c.int.PutNextBatch(nextBatchToken)
+}
+
+func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string {
+	return c.int.GetNextBatch()
+}
+
+var _ mautrix.Storer = (*cryptoClientStore)(nil)
+
 type cryptoStateStore struct {
 	bridge *Bridge
 }

+ 393 - 0
database/cryptostore.go

@@ -0,0 +1,393 @@
+// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
+// Copyright (C) 2020 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+package database
+
+import (
+	"database/sql"
+	"fmt"
+	"strings"
+	"sync"
+
+	"github.com/lib/pq"
+	"github.com/pkg/errors"
+	log "maunium.net/go/maulogger/v2"
+
+	"maunium.net/go/mautrix/crypto"
+	"maunium.net/go/mautrix/crypto/olm"
+	"maunium.net/go/mautrix/id"
+)
+
+type SQLCryptoStore struct {
+	db  *Database
+	log log.Logger
+
+	UserID    id.UserID
+	DeviceID  id.DeviceID
+	SyncToken string
+	PickleKey []byte
+	Account   *crypto.OlmAccount
+
+	GhostIDFormat string
+
+	OGSLock          sync.RWMutex
+	OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession
+}
+
+var _ crypto.Store = (*SQLCryptoStore)(nil)
+
+func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
+	return &SQLCryptoStore{
+		db:        db,
+		log:       db.log.Sub("CryptoStore"),
+		PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
+		DeviceID:  deviceID,
+
+		OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession),
+	}
+}
+
+func (db *Database) FindDeviceID() (deviceID id.DeviceID) {
+	err := db.QueryRow("SELECT device_id FROM crypto_account LIMIT 1").Scan(&deviceID)
+	if err != nil && err != sql.ErrNoRows {
+		db.log.Warnln("Failed to scan device ID:", err)
+	}
+	return
+}
+
+func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.UserID, err error) {
+	var rows *sql.Rows
+	rows, err = store.db.Query(`
+		SELECT user_id FROM mx_user_profile
+		WHERE room_id=$1
+			AND (membership='join' OR membership='invite')
+			AND user_id<>$2
+			AND user_id NOT LIKE $3
+	`, roomID, store.UserID, store.GhostIDFormat)
+	if err != nil {
+		return
+	}
+	for rows.Next() {
+		var userID id.UserID
+		err := rows.Scan(&userID)
+		if err != nil {
+			store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
+		} else {
+			members = append(members, userID)
+		}
+	}
+	return
+}
+
+func (store *SQLCryptoStore) Flush() error {
+	return nil
+}
+
+func (store *SQLCryptoStore) PutNextBatch(nextBatch string) {
+	store.SyncToken = nextBatch
+	_, err := store.db.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE device_id=$2`, store.SyncToken, store.DeviceID)
+	if err != nil {
+		store.log.Warnln("Failed to store sync token:", err)
+	}
+}
+
+func (store *SQLCryptoStore) GetNextBatch() string {
+	if store.SyncToken == "" {
+		err := store.db.
+			QueryRow("SELECT sync_token FROM crypto_account WHERE device_id=$1", store.DeviceID).
+			Scan(&store.SyncToken)
+		if err != nil && err != sql.ErrNoRows {
+			store.log.Warnln("Failed to scan sync token:", err)
+		}
+	}
+	return store.SyncToken
+}
+
+func (store *SQLCryptoStore) PutAccount(account *crypto.OlmAccount) error {
+	store.Account = account
+	bytes := account.Internal.Pickle(store.PickleKey)
+	var err error
+	if store.db.dialect == "postgres" {
+		_, err = store.db.Exec(`
+			INSERT INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)
+			ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`,
+			store.DeviceID, account.Shared, store.SyncToken, bytes)
+	} else if store.db.dialect == "sqlite3" {
+		_, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (deivce_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)",
+			store.DeviceID, account.Shared, store.SyncToken, bytes)
+	} else {
+		err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
+	}
+	if err != nil {
+		store.log.Warnln("Failed to store account:", err)
+	}
+	return nil
+}
+
+func (store *SQLCryptoStore) GetAccount() (*crypto.OlmAccount, error) {
+	if store.Account == nil {
+		row := store.db.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE device_id=$1", store.DeviceID)
+		acc := &crypto.OlmAccount{Internal: *olm.NewBlankAccount()}
+		var accountBytes []byte
+		err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
+		if err == sql.ErrNoRows {
+			return nil, nil
+		} else if err != nil {
+			return nil, err
+		}
+		err = acc.Internal.Unpickle(accountBytes, store.PickleKey)
+		if err != nil {
+			return nil, err
+		}
+		store.Account = acc
+	}
+	return store.Account, nil
+}
+
+func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
+	// TODO this may need to be changed if olm sessions start expiring
+	var sessionID id.SessionID
+	err := store.db.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 LIMIT 1", key).Scan(&sessionID)
+	if err == sql.ErrNoRows {
+		return false
+	}
+	return len(sessionID) > 0
+}
+
+func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (crypto.OlmSessionList, error) {
+	rows, err := store.db.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id", key)
+	if err != nil {
+		return nil, err
+	}
+	list := crypto.OlmSessionList{}
+	for rows.Next() {
+		sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
+		var sessionBytes []byte
+		err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
+		if err != nil {
+			return nil, err
+		}
+		err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
+		if err != nil {
+			return nil, err
+		}
+		list = append(list, &sess)
+	}
+	return list, nil
+}
+
+func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*crypto.OlmSession, error) {
+	row := store.db.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id DESC LIMIT 1", key)
+	sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
+	var sessionBytes []byte
+	err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
+	if err == sql.ErrNoRows {
+		return nil, nil
+	} else if err != nil {
+		return nil, err
+	}
+	return &sess, sess.Internal.Unpickle(sessionBytes, store.PickleKey)
+}
+
+func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *crypto.OlmSession) error {
+	sessionBytes := session.Internal.Pickle(store.PickleKey)
+	_, err := store.db.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used) VALUES ($1, $2, $3, $4, $5)",
+		session.ID(), key, sessionBytes, session.CreationTime, session.UseTime)
+	return err
+}
+
+func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *crypto.InboundGroupSession) error {
+	sessionBytes := session.Internal.Pickle(store.PickleKey)
+	forwardingChains := strings.Join(session.ForwardingChains, ",")
+	_, err := store.db.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, signing_key, room_id, session, forwarding_chains) VALUES ($1, $2, $3, $4, $5, $6)",
+		sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains)
+	return err
+}
+
+func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
+	var signingKey id.Ed25519
+	var sessionBytes []byte
+	var forwardingChains string
+	err := store.db.QueryRow(`
+		SELECT signing_key, session, forwarding_chains
+		FROM crypto_megolm_inbound_session
+		WHERE room_id=$1 AND sender_key=$2 AND session_id=$3`,
+		roomID, senderKey, sessionID,
+	).Scan(&signingKey, &sessionBytes, &forwardingChains)
+	if err == sql.ErrNoRows {
+		return nil, nil
+	} else if err != nil {
+		return nil, err
+	}
+	igs := olm.NewBlankInboundGroupSession()
+	err = igs.Unpickle(sessionBytes, store.PickleKey)
+	if err != nil {
+		return nil, err
+	}
+	return &crypto.InboundGroupSession{
+		Internal:         *igs,
+		SigningKey:       signingKey,
+		SenderKey:        senderKey,
+		RoomID:           roomID,
+		ForwardingChains: strings.Split(forwardingChains, ","),
+	}, nil
+}
+
+func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error {
+	store.OGSLock.Lock()
+	store.OutGroupSessions[roomID] = session
+	store.OGSLock.Unlock()
+	return nil
+}
+
+func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
+	store.OGSLock.RLock()
+	defer store.OGSLock.RUnlock()
+	return store.OutGroupSessions[roomID], nil
+}
+
+func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error {
+	store.OGSLock.Lock()
+	delete(store.OutGroupSessions, roomID)
+	store.OGSLock.Unlock()
+	return nil
+}
+
+func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
+	var resultEventID id.EventID
+	var resultTimestamp int64
+	err := store.db.QueryRow(
+		"SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND index=$3",
+		senderKey, sessionID, index,
+	).Scan(&resultEventID, &resultTimestamp)
+	if err == sql.ErrNoRows {
+		_, err := store.db.Exec("INSERT INTO crypto_message_index (sender_key, session_id, index, event_id, timestamp) VALUES ($1, $2, $3, $4, $5)",
+			senderKey, sessionID, index, eventID, timestamp)
+		if err != nil {
+			store.log.Warnln("Failed to store message index:", err)
+		}
+		return true
+	} else if err != nil {
+		store.log.Warnln("Failed to scan message index:", err)
+		return true
+	}
+	if resultEventID != eventID || resultTimestamp != timestamp {
+		return false
+	}
+	return true
+}
+
+func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*crypto.DeviceIdentity, error) {
+	var ignore id.UserID
+	err := store.db.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
+	if err == sql.ErrNoRows {
+		return nil, nil
+	} else if err != nil {
+		return nil, err
+	}
+
+	rows, err := store.db.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1", userID)
+	if err != nil {
+		return nil, err
+	}
+	data := make(map[id.DeviceID]*crypto.DeviceIdentity)
+	for rows.Next() {
+		var identity crypto.DeviceIdentity
+		err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
+		if err != nil {
+			return nil, err
+		}
+		identity.UserID = userID
+		data[identity.DeviceID] = &identity
+	}
+	return data, nil
+}
+
+func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*crypto.DeviceIdentity) error {
+	tx, err := store.db.Begin()
+	if err != nil {
+		return err
+	}
+
+	if store.db.dialect == "postgres" {
+		_, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
+	} else if store.db.dialect == "sqlite3" {
+		_, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_users (user_id) VALUES ($1)", userID)
+	} else {
+		err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
+	}
+	if err != nil {
+		return errors.Wrap(err, "failed to add user to tracked users list")
+	}
+
+	_, err = tx.Exec("DELETE FROM crypto_device WHERE user_id=$1", userID)
+	if err != nil {
+		_ = tx.Rollback()
+		return errors.Wrap(err, "failed to delete old devices")
+	}
+	if len(devices) == 0 {
+		err = tx.Commit()
+		if err != nil {
+			return errors.Wrap(err, "failed to commit changes (no devices added)")
+		}
+		return nil
+	}
+	// TODO do this in batches to avoid too large db queries
+	values := make([]interface{}, 1, len(devices)*6+1)
+	values[0] = userID
+	valueStrings := make([]string, 0, len(devices))
+	i := 2
+	for deviceID, identity := range devices {
+		values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
+		valueStrings = append(valueStrings, fmt.Sprintf("($1, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5))
+		i += 6
+	}
+	valueString := strings.Join(valueStrings, ",")
+	_, err = tx.Exec("INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES "+valueString, values...)
+	if err != nil {
+		_ = tx.Rollback()
+		return errors.Wrap(err, "failed to insert new devices")
+	}
+	err = tx.Commit()
+	if err != nil {
+		return errors.Wrap(err, "failed to commit changes")
+	}
+	return nil
+}
+
+func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
+	var rows *sql.Rows
+	var err error
+	if store.db.dialect == "postgres" {
+		rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users))
+	} else {
+		rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ($1)", users)
+	}
+	if err != nil {
+		store.log.Warnln("Failed to filter tracked users:", err)
+		return users
+	}
+	var ptr int
+	for rows.Next() {
+		err = rows.Scan(&users[ptr])
+		if err != nil {
+			store.log.Warnln("Failed to tracked user ID:", err)
+		} else {
+			ptr++
+		}
+	}
+	return users[:ptr]
+}

+ 7 - 21
database/statestore.go

@@ -39,11 +39,13 @@ type SQLStateStore struct {
 	typingLock sync.RWMutex
 }
 
+var _ appservice.StateStore = (*SQLStateStore)(nil)
+
 func NewSQLStateStore(db *Database) *SQLStateStore {
 	return &SQLStateStore{
 		TypingStateStore: appservice.NewTypingStateStore(),
 		db:               db,
-		log:              log.Sub("StateStore"),
+		log:              db.log.Sub("StateStore"),
 	}
 }
 
@@ -90,24 +92,6 @@ func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*even
 	return members
 }
 
-func (store *SQLStateStore) GetRoomMemberList(roomID id.RoomID) (members []id.UserID, err error) {
-	var rows *sql.Rows
-	rows, err = store.db.Query("SELECT user_id FROM mx_user_profile WHERE room_id=$1", roomID)
-	if err != nil {
-		return
-	}
-	for rows.Next() {
-		var userID id.UserID
-		err := rows.Scan(&userID)
-		if err != nil {
-			store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
-		} else {
-			members = append(members, userID)
-		}
-	}
-	return
-}
-
 func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
 	row := store.db.QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID)
 	membership := event.MembershipLeave
@@ -138,8 +122,10 @@ func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*e
 
 func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
 	rows, err := store.db.Query(`
-			SELECT room_id FROM mx_user_profile WHERE user_id=$2 AND portal.encrypted=true
-			LEFT JOIN portal WHEN portal.mxid=mx_user_profile.room_id`, userID)
+			SELECT room_id FROM mx_user_profile
+			LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
+			WHERE user_id=$1 AND portal.encrypted=true
+	`, userID)
 	if err != nil {
 		store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
 		return

+ 74 - 0
database/upgrades/2020-05-09-crypto-store.go

@@ -0,0 +1,74 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[13] = upgrade{"Add crypto store to database", func(tx *sql.Tx, ctx context) error {
+		// TODO use DATETIME instead of timestamp and BLOB instead of bytea for sqlite
+		_, err := tx.Exec(`CREATE TABLE crypto_account (
+			device_id  VARCHAR(255) PRIMARY KEY,
+			shared     BOOLEAN      NOT NULL,
+			sync_token TEXT         NOT NULL,
+			account    bytea        NOT NULL
+		)`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`CREATE TABLE crypto_message_index (
+			sender_key CHAR(43),
+			session_id VARCHAR(255),
+			index      INTEGER,
+			event_id   VARCHAR(255) NOT NULL,
+			timestamp  BIGINT       NOT NULL,
+
+			PRIMARY KEY (sender_key, session_id, index)
+		)`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`CREATE TABLE crypto_tracked_user (
+			user_id VARCHAR(255) PRIMARY KEY
+		)`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`CREATE TABLE crypto_device (
+			user_id      VARCHAR(255),
+			device_id    VARCHAR(255),
+			identity_key CHAR(43)      NOT NULL,
+			signing_key  CHAR(43)      NOT NULL,
+			trust        SMALLINT      NOT NULL,
+			deleted      BOOLEAN       NOT NULL,
+			name         VARCHAR(255)  NOT NULL,
+
+			PRIMARY KEY (user_id, device_id)
+		)`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`CREATE TABLE crypto_olm_session (
+			session_id   CHAR(43)     PRIMARY KEY,
+			sender_key   VARCHAR(255) NOT NULL,
+			session      bytea        NOT NULL,
+			created_at   timestamp    NOT NULL,
+			last_used    timestamp    NOT NULL
+		)`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session (
+			session_id   CHAR(43)     PRIMARY KEY,
+			sender_key   CHAR(43)     NOT NULL,
+			signing_key  CHAR(43)     NOT NULL,
+			room_id      VARCHAR(255) NOT NULL,
+			session      bytea        NOT NULL,
+			forwarding_chains bytea   NOT NULL
+		)`)
+		if err != nil {
+			return err
+		}
+		return nil
+	}}
+}

+ 1 - 1
database/upgrades/upgrades.go

@@ -28,7 +28,7 @@ type upgrade struct {
 	fn      upgradeFunc
 }
 
-const NumberOfUpgrades = 13
+const NumberOfUpgrades = 14
 
 var upgrades [NumberOfUpgrades]upgrade
 

+ 1 - 1
go.mod

@@ -15,7 +15,7 @@ require (
 	gopkg.in/yaml.v2 v2.2.8
 	maunium.net/go/mauflag v1.0.0
 	maunium.net/go/maulogger/v2 v2.1.1
-	maunium.net/go/mautrix v0.4.0
+	maunium.net/go/mautrix v0.4.1
 )
 
 replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.2.6

+ 2 - 0
go.sum

@@ -86,3 +86,5 @@ maunium.net/go/mautrix v0.3.7 h1:N0czrZeAwjvBrw2a/B2G6U3EwIYaWpt7OuSslGp8DRc=
 maunium.net/go/mautrix v0.3.7/go.mod h1:SkGZzch8CvU2qKtNpYxtzZ0sQxfVEJ3IsVVLSUBUx9Y=
 maunium.net/go/mautrix v0.4.0 h1:IYfmxCoxR/6UMi92IncsSZeKQbZm8Xa35XIRX814KJ4=
 maunium.net/go/mautrix v0.4.0/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
+maunium.net/go/mautrix v0.4.1 h1:i2lJNT+TE4AAL3cVKUN4jKVRkujCE/oS8aIsj8+7iNE=
+maunium.net/go/mautrix v0.4.1/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=

+ 9 - 5
main.go

@@ -126,6 +126,7 @@ type Crypto interface {
 	HandleMemberEvent(*event.Event)
 	Decrypt(*event.Event) (*event.Event, error)
 	Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error)
+	Init() error
 	Start()
 	Stop()
 }
@@ -225,11 +226,7 @@ func (bridge *Bridge) Init() {
 	bridge.Log.Debugln("Initializing Matrix event handler")
 	bridge.MatrixHandler = NewMatrixHandler(bridge)
 	bridge.Formatter = NewFormatter(bridge)
-	err = bridge.initCrypto()
-	if err != nil {
-		bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
-		os.Exit(19)
-	}
+	bridge.Crypto = NewCryptoHelper(bridge)
 }
 
 func (bridge *Bridge) Start() {
@@ -238,6 +235,13 @@ func (bridge *Bridge) Start() {
 		bridge.Log.Fatalln("Failed to initialize database:", err)
 		os.Exit(15)
 	}
+	if bridge.Crypto != nil {
+		err := bridge.Crypto.Init()
+		if err != nil {
+			bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
+			os.Exit(19)
+		}
+	}
 	if bridge.Provisioning != nil {
 		bridge.Log.Debugln("Initializing provisioning API")
 		bridge.Provisioning.Init()