浏览代码

Check database owner and foreign tables before starting

Tulir Asokan 3 年之前
父节点
当前提交
caacac15c7
共有 4 个文件被更改,包括 68 次插入6 次删除
  1. 1 0
      CHANGELOG.md
  2. 1 2
      database/database.go
  3. 55 2
      database/upgrades/upgrades.go
  4. 11 2
      main.go

+ 1 - 0
CHANGELOG.md

@@ -15,6 +15,7 @@
   [@abmantis] in [#452]). This can be used to enable incoming typing
   [@abmantis] in [#452]). This can be used to enable incoming typing
   notifications without enabling Matrix presence (WhatsApp only sends typing
   notifications without enabling Matrix presence (WhatsApp only sends typing
   notifications if you're online).
   notifications if you're online).
+* Added checks to prevent sharing the database with unrelated software.
 * Exposed maximum database connection idle time and lifetime options.
 * Exposed maximum database connection idle time and lifetime options.
 * Fixed syncing group topics. To get topics into existing portals on Matrix,
 * Fixed syncing group topics. To get topics into existing portals on Matrix,
   you can use `!wa sync groups`.
   you can use `!wa sync groups`.

+ 1 - 2
database/database.go

@@ -23,10 +23,9 @@ import (
 
 
 	"github.com/lib/pq"
 	"github.com/lib/pq"
 	_ "github.com/mattn/go-sqlite3"
 	_ "github.com/mattn/go-sqlite3"
-
+	"go.mau.fi/whatsmeow/store/sqlstore"
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
 
 
-	"go.mau.fi/whatsmeow/store/sqlstore"
 	"maunium.net/go/mautrix-whatsapp/config"
 	"maunium.net/go/mautrix-whatsapp/config"
 	"maunium.net/go/mautrix-whatsapp/database/upgrades"
 	"maunium.net/go/mautrix-whatsapp/database/upgrades"
 )
 )

+ 55 - 2
database/upgrades/upgrades.go

@@ -44,7 +44,12 @@ const NumberOfUpgrades = 39
 
 
 var upgrades [NumberOfUpgrades]upgrade
 var upgrades [NumberOfUpgrades]upgrade
 
 
-var UnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
+var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
+var ErrForeignTables = fmt.Errorf("the database contains foreign tables")
+var ErrNotOwned = fmt.Errorf("the database is owned by")
+var IgnoreForeignTables = false
+
+const databaseOwner = "mautrix-whatsapp"
 
 
 func GetVersion(db *sql.DB) (int, error) {
 func GetVersion(db *sql.DB) (int, error) {
 	_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
 	_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
@@ -60,6 +65,49 @@ func GetVersion(db *sql.DB) (int, error) {
 	return version, nil
 	return version, nil
 }
 }
 
 
+const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
+const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND table_name=$1)"
+
+func tableExists(dialect Dialect, db *sql.DB, table string) (exists bool) {
+	if dialect == SQLite {
+		_ = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
+	} else if dialect == Postgres {
+		_ = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
+	}
+	return
+}
+
+const createOwnerTable = `
+CREATE TABLE IF NOT EXISTS database_owner (
+	key   INTEGER PRIMARY KEY DEFAULT 0,
+	owner TEXT NOT NULL
+)
+`
+
+func CheckDatabaseOwner(dialect Dialect, db *sql.DB) error {
+	var owner string
+	if !IgnoreForeignTables {
+		if tableExists(dialect, db, "state_groups_state") {
+			return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
+		} else if tableExists(dialect, db, "goose_db_version") {
+			return fmt.Errorf("%w (found goose_db_version, possibly belonging to Dendrite)", ErrForeignTables)
+		}
+	}
+	if _, err := db.Exec(createOwnerTable); err != nil {
+		return fmt.Errorf("failed to ensure database owner table exists: %w", err)
+	} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
+		_, err = db.Exec("INSERT INTO database_owner (owner) VALUES ($1)", databaseOwner)
+		if err != nil {
+			return fmt.Errorf("failed to insert database owner: %w", err)
+		}
+	} else if err != nil {
+		return fmt.Errorf("failed to check database owner: %w", err)
+	} else if owner != databaseOwner {
+		return fmt.Errorf("%w %s", ErrNotOwned, owner)
+	}
+	return nil
+}
+
 func SetVersion(tx *sql.Tx, version int) error {
 func SetVersion(tx *sql.Tx, version int) error {
 	_, err := tx.Exec("DELETE FROM version")
 	_, err := tx.Exec("DELETE FROM version")
 	if err != nil {
 	if err != nil {
@@ -90,13 +138,18 @@ func Run(log log.Logger, dialectName string, db *sql.DB) error {
 		return fmt.Errorf("unknown dialect %s", dialectName)
 		return fmt.Errorf("unknown dialect %s", dialectName)
 	}
 	}
 
 
+	err := CheckDatabaseOwner(dialect, db)
+	if err != nil {
+		return err
+	}
+
 	version, err := GetVersion(db)
 	version, err := GetVersion(db)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
 	if version > NumberOfUpgrades {
 	if version > NumberOfUpgrades {
-		return fmt.Errorf("%w: currently on v%d, latest known: v%d", UnsupportedDatabaseVersion, version, NumberOfUpgrades)
+		return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, NumberOfUpgrades)
 	}
 	}
 
 
 	log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
 	log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)

+ 11 - 2
main.go

@@ -103,7 +103,8 @@ var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config
 var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
 var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
 var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
 var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
 var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool()
 var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool()
-var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if database is too new").Default("false").Bool()
+var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool()
+var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool()
 var migrateFrom = flag.Make().LongKey("migrate-db").Usage("Source database type and URI to migrate from.").Bool()
 var migrateFrom = flag.Make().LongKey("migrate-db").Usage("Source database type and URI to migrate from.").Bool()
 var wantHelp, _ = flag.MakeHelpFlag()
 var wantHelp, _ = flag.MakeHelpFlag()
 
 
@@ -299,8 +300,15 @@ func (bridge *Bridge) Init() {
 func (bridge *Bridge) Start() {
 func (bridge *Bridge) Start() {
 	bridge.Log.Debugln("Running database upgrades")
 	bridge.Log.Debugln("Running database upgrades")
 	err := bridge.DB.Init()
 	err := bridge.DB.Init()
-	if err != nil && (err != upgrades.UnsupportedDatabaseVersion || !*ignoreUnsupportedDatabase) {
+	if err != nil && (!errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) || !*ignoreUnsupportedDatabase) {
 		bridge.Log.Fatalln("Failed to initialize database:", err)
 		bridge.Log.Fatalln("Failed to initialize database:", err)
+		if errors.Is(err, upgrades.ErrForeignTables) {
+			bridge.Log.Infoln("You can use --ignore-foreign-tables to ignore this error")
+		} else if errors.Is(err, upgrades.ErrNotOwned) {
+			bridge.Log.Infoln("Sharing the same database with different programs is not supported")
+		} else if errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) {
+			bridge.Log.Infoln("Downgrading the bridge is not supported")
+		}
 		os.Exit(15)
 		os.Exit(15)
 	}
 	}
 	bridge.Log.Debugln("Checking connection to homeserver")
 	bridge.Log.Debugln("Checking connection to homeserver")
@@ -517,6 +525,7 @@ func main() {
 		fmt.Println(VersionString)
 		fmt.Println(VersionString)
 		return
 		return
 	}
 	}
+	upgrades.IgnoreForeignTables = *ignoreForeignTables
 
 
 	(&Bridge{
 	(&Bridge{
 		usersByMXID:         make(map[id.UserID]*User),
 		usersByMXID:         make(map[id.UserID]*User),