Răsfoiți Sursa

Check database owner and foreign tables before starting

Tulir Asokan 3 ani în urmă
părinte
comite
caacac15c7
4 a modificat fișierele cu 68 adăugiri și 6 ștergeri
  1. 1 0
      CHANGELOG.md
  2. 1 2
      database/database.go
  3. 55 2
      database/upgrades/upgrades.go
  4. 11 2
      main.go

+ 1 - 0
CHANGELOG.md

@@ -15,6 +15,7 @@
   [@abmantis] in [#452]). This can be used to enable incoming typing
   notifications without enabling Matrix presence (WhatsApp only sends typing
   notifications if you're online).
+* Added checks to prevent sharing the database with unrelated software.
 * Exposed maximum database connection idle time and lifetime options.
 * Fixed syncing group topics. To get topics into existing portals on Matrix,
   you can use `!wa sync groups`.

+ 1 - 2
database/database.go

@@ -23,10 +23,9 @@ import (
 
 	"github.com/lib/pq"
 	_ "github.com/mattn/go-sqlite3"
-
+	"go.mau.fi/whatsmeow/store/sqlstore"
 	log "maunium.net/go/maulogger/v2"
 
-	"go.mau.fi/whatsmeow/store/sqlstore"
 	"maunium.net/go/mautrix-whatsapp/config"
 	"maunium.net/go/mautrix-whatsapp/database/upgrades"
 )

+ 55 - 2
database/upgrades/upgrades.go

@@ -44,7 +44,12 @@ const NumberOfUpgrades = 39
 
 var upgrades [NumberOfUpgrades]upgrade
 
-var UnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
+var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
+var ErrForeignTables = fmt.Errorf("the database contains foreign tables")
+var ErrNotOwned = fmt.Errorf("the database is owned by")
+var IgnoreForeignTables = false
+
+const databaseOwner = "mautrix-whatsapp"
 
 func GetVersion(db *sql.DB) (int, error) {
 	_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
@@ -60,6 +65,49 @@ func GetVersion(db *sql.DB) (int, error) {
 	return version, nil
 }
 
+const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
+const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND table_name=$1)"
+
+func tableExists(dialect Dialect, db *sql.DB, table string) (exists bool) {
+	if dialect == SQLite {
+		_ = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
+	} else if dialect == Postgres {
+		_ = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
+	}
+	return
+}
+
+const createOwnerTable = `
+CREATE TABLE IF NOT EXISTS database_owner (
+	key   INTEGER PRIMARY KEY DEFAULT 0,
+	owner TEXT NOT NULL
+)
+`
+
+func CheckDatabaseOwner(dialect Dialect, db *sql.DB) error {
+	var owner string
+	if !IgnoreForeignTables {
+		if tableExists(dialect, db, "state_groups_state") {
+			return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
+		} else if tableExists(dialect, db, "goose_db_version") {
+			return fmt.Errorf("%w (found goose_db_version, possibly belonging to Dendrite)", ErrForeignTables)
+		}
+	}
+	if _, err := db.Exec(createOwnerTable); err != nil {
+		return fmt.Errorf("failed to ensure database owner table exists: %w", err)
+	} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
+		_, err = db.Exec("INSERT INTO database_owner (owner) VALUES ($1)", databaseOwner)
+		if err != nil {
+			return fmt.Errorf("failed to insert database owner: %w", err)
+		}
+	} else if err != nil {
+		return fmt.Errorf("failed to check database owner: %w", err)
+	} else if owner != databaseOwner {
+		return fmt.Errorf("%w %s", ErrNotOwned, owner)
+	}
+	return nil
+}
+
 func SetVersion(tx *sql.Tx, version int) error {
 	_, err := tx.Exec("DELETE FROM version")
 	if err != nil {
@@ -90,13 +138,18 @@ func Run(log log.Logger, dialectName string, db *sql.DB) error {
 		return fmt.Errorf("unknown dialect %s", dialectName)
 	}
 
+	err := CheckDatabaseOwner(dialect, db)
+	if err != nil {
+		return err
+	}
+
 	version, err := GetVersion(db)
 	if err != nil {
 		return err
 	}
 
 	if version > NumberOfUpgrades {
-		return fmt.Errorf("%w: currently on v%d, latest known: v%d", UnsupportedDatabaseVersion, version, NumberOfUpgrades)
+		return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, NumberOfUpgrades)
 	}
 
 	log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)

+ 11 - 2
main.go

@@ -103,7 +103,8 @@ var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config
 var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
 var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
 var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool()
-var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if database is too new").Default("false").Bool()
+var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool()
+var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool()
 var migrateFrom = flag.Make().LongKey("migrate-db").Usage("Source database type and URI to migrate from.").Bool()
 var wantHelp, _ = flag.MakeHelpFlag()
 
@@ -299,8 +300,15 @@ func (bridge *Bridge) Init() {
 func (bridge *Bridge) Start() {
 	bridge.Log.Debugln("Running database upgrades")
 	err := bridge.DB.Init()
-	if err != nil && (err != upgrades.UnsupportedDatabaseVersion || !*ignoreUnsupportedDatabase) {
+	if err != nil && (!errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) || !*ignoreUnsupportedDatabase) {
 		bridge.Log.Fatalln("Failed to initialize database:", err)
+		if errors.Is(err, upgrades.ErrForeignTables) {
+			bridge.Log.Infoln("You can use --ignore-foreign-tables to ignore this error")
+		} else if errors.Is(err, upgrades.ErrNotOwned) {
+			bridge.Log.Infoln("Sharing the same database with different programs is not supported")
+		} else if errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) {
+			bridge.Log.Infoln("Downgrading the bridge is not supported")
+		}
 		os.Exit(15)
 	}
 	bridge.Log.Debugln("Checking connection to homeserver")
@@ -517,6 +525,7 @@ func main() {
 		fmt.Println(VersionString)
 		return
 	}
+	upgrades.IgnoreForeignTables = *ignoreForeignTables
 
 	(&Bridge{
 		usersByMXID:         make(map[id.UserID]*User),