浏览代码

Update mautrix-go

Tulir Asokan 5 年之前
父节点
当前提交
f40a91594d
共有 6 个文件被更改,包括 53 次插入15 次删除
  1. 20 8
      crypto.go
  2. 5 5
      database/cryptostore.go
  3. 13 0
      database/upgrades/2020-07-10-update-crypto-store.go
  4. 12 1
      database/upgrades/upgrades.go
  5. 1 1
      go.mod
  6. 2 0
      go.sum

+ 20 - 8
crypto.go

@@ -68,6 +68,10 @@ func NewCryptoHelper(bridge *Bridge) Crypto {
 
 func (helper *CryptoHelper) Init() error {
 	helper.log.Debugln("Initializing end-to-bridge encryption...")
+
+	helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.bridge.AS.BotMXID(),
+		fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain))
+
 	var err error
 	helper.client, err = helper.loginBot()
 	if err != nil {
@@ -77,8 +81,6 @@ func (helper *CryptoHelper) Init() error {
 	helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID)
 	logger := &cryptoLogger{helper.baseLog}
 	stateStore := &cryptoStateStore{helper.bridge}
-	helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.client.DeviceID, helper.client.UserID,
-		fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain))
 	helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore)
 
 	helper.client.Logger = logger.int.Sub("Bot")
@@ -89,27 +91,30 @@ func (helper *CryptoHelper) Init() error {
 }
 
 func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) {
-	deviceID := helper.bridge.DB.FindDeviceID()
+	deviceID := helper.store.FindDeviceID()
 	if len(deviceID) > 0 {
 		helper.log.Debugln("Found existing device ID for bot in database:", deviceID)
 	}
 	mac := hmac.New(sha512.New, []byte(helper.bridge.Config.Bridge.LoginSharedSecret))
 	mac.Write([]byte(helper.bridge.AS.BotMXID()))
-	resp, err := helper.bridge.AS.BotClient().Login(&mautrix.ReqLogin{
+	client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, "", "")
+	if err != nil {
+		return nil, err
+	}
+	resp, err := client.Login(&mautrix.ReqLogin{
 		Type:                     "m.login.password",
 		Identifier:               mautrix.UserIdentifier{Type: "m.id.user", User: string(helper.bridge.AS.BotMXID())},
 		Password:                 hex.EncodeToString(mac.Sum(nil)),
 		DeviceID:                 deviceID,
 		InitialDeviceDisplayName: "WhatsApp Bridge",
+		StoreCredentials:         true,
 	})
 	if err != nil {
 		return nil, err
 	}
-	client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, helper.bridge.AS.BotMXID(), resp.AccessToken)
-	if err != nil {
-		return nil, err
+	if len(deviceID) == 0 {
+		helper.store.DeviceID = resp.DeviceID
 	}
-	client.DeviceID = resp.DeviceID
 	return client, nil
 }
 
@@ -228,6 +233,8 @@ type cryptoStateStore struct {
 	bridge *Bridge
 }
 
+var _ crypto.StateStore = (*cryptoStateStore)(nil)
+
 func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
 	portal := c.bridge.GetPortalByMXID(id)
 	if portal != nil {
@@ -239,3 +246,8 @@ func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
 func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID {
 	return c.bridge.StateStore.FindSharedRooms(id)
 }
+
+func (c *cryptoStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent {
+	// TODO implement
+	return nil
+}

+ 5 - 5
database/cryptostore.go

@@ -35,9 +35,9 @@ type SQLCryptoStore struct {
 
 var _ crypto.Store = (*SQLCryptoStore)(nil)
 
-func NewSQLCryptoStore(db *Database, deviceID id.DeviceID, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
+func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
 	return &SQLCryptoStore{
-		SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, deviceID,
+		SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "",
 			[]byte("maunium.net/go/mautrix-whatsapp"),
 			&cryptoLogger{db.log.Sub("CryptoStore")}),
 		UserID:        userID,
@@ -45,10 +45,10 @@ func NewSQLCryptoStore(db *Database, deviceID id.DeviceID, userID id.UserID, gho
 	}
 }
 
-func (db *Database) FindDeviceID() (deviceID id.DeviceID) {
-	err := db.QueryRow("SELECT device_id FROM crypto_account LIMIT 1").Scan(&deviceID)
+func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
+	err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
 	if err != nil && err != sql.ErrNoRows {
-		db.log.Warnln("Failed to scan device ID:", err)
+		store.Log.Warn("Failed to scan device ID: %v", err)
 	}
 	return
 }

+ 13 - 0
database/upgrades/2020-07-10-update-crypto-store.go

@@ -0,0 +1,13 @@
+package upgrades
+
+import (
+	"database/sql"
+
+	"maunium.net/go/mautrix/crypto"
+)
+
+func init() {
+	upgrades[16] = upgrade{"Add account_id to crypto store", func(tx *sql.Tx, c context) error {
+		return crypto.SQLStoreMigrations[1](tx, c.dialect.String())
+	}}
+}

+ 12 - 1
database/upgrades/upgrades.go

@@ -15,6 +15,17 @@ const (
 	SQLite
 )
 
+func (dialect Dialect) String() string {
+	switch dialect {
+	case Postgres:
+		return "postgres"
+	case SQLite:
+		return "sqlite3"
+	default:
+		return ""
+	}
+}
+
 type upgradeFunc func(*sql.Tx, context) error
 
 type context struct {
@@ -28,7 +39,7 @@ type upgrade struct {
 	fn      upgradeFunc
 }
 
-const NumberOfUpgrades = 16
+const NumberOfUpgrades = 17
 
 var upgrades [NumberOfUpgrades]upgrade
 

+ 1 - 1
go.mod

@@ -16,7 +16,7 @@ require (
 	gopkg.in/yaml.v2 v2.3.0
 	maunium.net/go/mauflag v1.0.0
 	maunium.net/go/maulogger/v2 v2.1.1
-	maunium.net/go/mautrix v0.5.8
+	maunium.net/go/mautrix v0.6.0
 )
 
 replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.3.4

+ 2 - 0
go.sum

@@ -204,3 +204,5 @@ maunium.net/go/mautrix v0.5.7 h1:tyRwllz3SZvMfD2YjaJPWopxmUCxZgQ2hl5/3/loHTE=
 maunium.net/go/mautrix v0.5.7/go.mod h1:FLbMANzwqlsX2Fgm7SDe+E4I3wSa4UxJRKqS5wGkCwA=
 maunium.net/go/mautrix v0.5.8 h1:jOE3U8WYSIc4qbYvyVaDhOaQcB3sDPN5A2zQ93YixZ0=
 maunium.net/go/mautrix v0.5.8/go.mod h1:Va/74MijqaS0DQ3aUqxmFO54/PMfr1LVsCOcGRHbYmo=
+maunium.net/go/mautrix v0.6.0 h1:V32l4aygKk2XcH3fi8Yd0pFeSyYZJNRIvr8vdA2GtC8=
+maunium.net/go/mautrix v0.6.0/go.mod h1:Va/74MijqaS0DQ3aUqxmFO54/PMfr1LVsCOcGRHbYmo=