Эх сурвалжийг харах

Add database migrations and handle leaving private chat portals. Fixes #7

Tulir Asokan 6 жил өмнө
parent
commit
b0d774a5a5

+ 4 - 18
database/database.go

@@ -23,6 +23,8 @@ import (
 	_ "github.com/mattn/go-sqlite3"
 
 	log "maunium.net/go/maulogger/v2"
+
+	"maunium.net/go/mautrix-whatsapp/database/upgrades"
 )
 
 type Database struct {
@@ -64,24 +66,8 @@ func New(dbType string, uri string) (*Database, error) {
 	return db, nil
 }
 
-func (db *Database) CreateTables(dbType string) error {
-	err := db.User.CreateTable(dbType)
-	if err != nil {
-		return err
-	}
-	err = db.Portal.CreateTable(dbType)
-	if err != nil {
-		return err
-	}
-	err = db.Puppet.CreateTable(dbType)
-	if err != nil {
-		return err
-	}
-	err = db.Message.CreateTable(dbType)
-	if err != nil {
-		return err
-	}
-	return nil
+func (db *Database) Init(dialectName string) error {
+	return upgrades.Run(db.log.Sub("Upgrade"), dialectName, db.DB)
 }
 
 type Scannable interface {

+ 0 - 31
database/message.go

@@ -18,7 +18,6 @@ package database
 
 import (
 	"bytes"
-	"strings"
 	"database/sql"
 	"encoding/json"
 
@@ -34,36 +33,6 @@ type MessageQuery struct {
 	log log.Logger
 }
 
-func (mq *MessageQuery) CreateTable(dbType string) error {
-	if strings.ToLower(dbType) == "postgres" {
-		_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
-			chat_jid      VARCHAR(255),
-			chat_receiver VARCHAR(255),
-			jid           VARCHAR(255),
-			mxid          VARCHAR(255) NOT NULL UNIQUE,
-			sender        VARCHAR(255)  NOT NULL,
-			content       bytea         NOT NULL,
-
-			PRIMARY KEY (chat_jid, chat_receiver, jid),
-			FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
-		)`)
-		return err
-	} else {
-		_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
-			chat_jid      VARCHAR(255),
-			chat_receiver VARCHAR(255),
-			jid           VARCHAR(255),
-			mxid          VARCHAR(255) NOT NULL UNIQUE,
-			sender        VARCHAR(255)  NOT NULL,
-			content       BLOB          NOT NULL,
-
-			PRIMARY KEY (chat_jid, chat_receiver, jid),
-			FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
-		)`)
-	return err
-	}
-}
-
 func (mq *MessageQuery) New() *Message {
 	return &Message{
 		db:  mq.db,

+ 7 - 15
database/portal.go

@@ -59,21 +59,6 @@ type PortalQuery struct {
 	log log.Logger
 }
 
-func (pq *PortalQuery) CreateTable(dbType string) error {
-	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
-		jid      VARCHAR(255),
-		receiver VARCHAR(255),
-		mxid     VARCHAR(255) UNIQUE,
-
-		name   VARCHAR(255) NOT NULL,
-		topic  VARCHAR(255) NOT NULL,
-		avatar VARCHAR(255) NOT NULL,
-
-		PRIMARY KEY (jid, receiver)
-	)`)
-	return err
-}
-
 func (pq *PortalQuery) New() *Portal {
 	return &Portal{
 		db:  pq.db,
@@ -160,3 +145,10 @@ func (portal *Portal) Update() {
 		portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
 	}
 }
+
+func (portal *Portal) Delete() {
+	_, err := portal.db.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver)
+	if err != nil {
+		portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err)
+	}
+}

+ 0 - 10
database/puppet.go

@@ -29,16 +29,6 @@ type PuppetQuery struct {
 	log log.Logger
 }
 
-func (pq *PuppetQuery) CreateTable(dbType string) error {
-	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet (
-		jid          VARCHAR(255) PRIMARY KEY,
-		avatar       VARCHAR(255),
-		displayname  VARCHAR(255),
-		name_quality SMALLINT
-	)`)
-	return err
-}
-
 func (pq *PuppetQuery) New() *Puppet {
 	return &Puppet{
 		db:  pq.db,

+ 74 - 0
database/upgrades/2018-09-01-initial-schema.go

@@ -0,0 +1,74 @@
+package upgrades
+
+import (
+	"database/sql"
+	"fmt"
+)
+
+func init() {
+	upgrades[0] = upgrade{"Initial schema", func(dialect Dialect, tx *sql.Tx) error {
+		var byteType string
+		if dialect == SQLite {
+			byteType = "BLOB"
+		} else {
+			byteType = "bytea"
+		}
+		_, err := tx.Exec(`CREATE TABLE IF NOT EXISTS portal (
+			jid      VARCHAR(255),
+			receiver VARCHAR(255),
+			mxid     VARCHAR(255) UNIQUE,
+	
+			name   VARCHAR(255) NOT NULL,
+			topic  VARCHAR(255) NOT NULL,
+			avatar VARCHAR(255) NOT NULL,
+	
+			PRIMARY KEY (jid, receiver)
+		)`)
+		if err != nil {
+			return err
+		}
+
+		_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS puppet (
+			jid          VARCHAR(255) PRIMARY KEY,
+			avatar       VARCHAR(255),
+			displayname  VARCHAR(255),
+			name_quality SMALLINT
+		)`)
+		if err != nil {
+			return err
+		}
+
+		_, err = tx.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS "user" (
+			mxid VARCHAR(255) PRIMARY KEY,
+			jid  VARCHAR(255)  UNIQUE,
+
+			management_room VARCHAR(255),
+
+			client_id    VARCHAR(255),
+			client_token VARCHAR(255),
+			server_token VARCHAR(255),
+			enc_key      %[1]s,
+			mac_key      %[1]s
+		)`, byteType))
+		if err != nil {
+			return err
+		}
+
+		_, err = tx.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS message (
+			chat_jid      VARCHAR(255),
+			chat_receiver VARCHAR(255),
+			jid           VARCHAR(255),
+			mxid          VARCHAR(255) NOT NULL UNIQUE,
+			sender        VARCHAR(255) NOT NULL,
+			content       %[1]s        NOT NULL,
+
+			PRIMARY KEY (chat_jid, chat_receiver, jid),
+			FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
+		)`, byteType))
+		if err != nil {
+			return err
+		}
+
+		return nil
+	}}
+}

+ 25 - 0
database/upgrades/2019-05-16-message-delete-cascade.go

@@ -0,0 +1,25 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[1] = upgrade{"Add ON DELETE CASCADE to message table", func(dialect Dialect, tx *sql.Tx) error {
+		if dialect == SQLite {
+			// SQLite doesn't support constraint updates, but it isn't that careful with constraints anyway.
+			return nil
+		}
+		_, err := tx.Exec("ALTER TABLE message DROP CONSTRAINT message_chat_jid_fkey")
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`ALTER TABLE message ADD CONSTRAINT message_chat_jid_fkey
+				FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
+				ON DELETE CASCADE`)
+		if err != nil {
+			return err
+		}
+		return nil
+	}}
+}

+ 87 - 0
database/upgrades/upgrades.go

@@ -0,0 +1,87 @@
+package upgrades
+
+import (
+	"database/sql"
+	"fmt"
+	"strings"
+
+	log "maunium.net/go/maulogger/v2"
+)
+
+type Dialect int
+
+const (
+	Postgres Dialect = iota
+	SQLite
+)
+
+type upgradeFunc func(Dialect, *sql.Tx) error
+
+type upgrade struct {
+	message string
+	fn upgradeFunc
+}
+
+var upgrades [2]upgrade
+
+func getVersion(dialect Dialect, db *sql.DB) (int, error) {
+	_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
+	if err != nil {
+		return -1, err
+	}
+
+	version := 0
+	row := db.QueryRow("SELECT version FROM version LIMIT 1")
+	if row != nil {
+		_ = row.Scan(&version)
+	}
+	return version, nil
+}
+
+func setVersion(dialect Dialect, tx *sql.Tx, version int) error {
+	_, err := tx.Exec("DELETE FROM version")
+	if err != nil {
+		return err
+	}
+	_, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
+	return err
+}
+
+func Run(log log.Logger, dialectName string, db *sql.DB) error {
+	var dialect Dialect
+	switch strings.ToLower(dialectName) {
+	case "postgres":
+		dialect = Postgres
+	case "sqlite3":
+		dialect = SQLite
+	default:
+		return fmt.Errorf("unknown dialect %s", dialectName)
+	}
+
+	version, err := getVersion(dialect, db)
+	if err != nil {
+		return err
+	}
+
+	log.Infofln("Database currently on v%d, latest: v%d", version, len(upgrades))
+	for i, upgrade := range upgrades[version:] {
+		log.Infofln("Upgrading database to v%d: %s", i+1, upgrade.message)
+		tx, err := db.Begin()
+		if err != nil {
+			return err
+		}
+		err = upgrade.fn(dialect, tx)
+		if err != nil {
+			return err
+		}
+		err = setVersion(dialect, tx, i+1)
+		if err != nil {
+			return err
+		}
+		err = tx.Commit()
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}

+ 0 - 32
database/user.go

@@ -33,38 +33,6 @@ type UserQuery struct {
 	log log.Logger
 }
 
-func (uq *UserQuery) CreateTable(dbType string) error {
-	if strings.ToLower(dbType) == "postgres" {
-		_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
-			mxid VARCHAR(255) PRIMARY KEY,
-			jid  VARCHAR(255)  UNIQUE,
-
-			management_room VARCHAR(255),
-
-			client_id    VARCHAR(255),
-			client_token VARCHAR(255),
-			server_token VARCHAR(255),
-			enc_key      bytea,
-			mac_key      bytea
-		)`)
-		return err
-	} else {
-		_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
-			mxid VARCHAR(255) PRIMARY KEY,
-			jid  VARCHAR(255)  UNIQUE,
-
-			management_room VARCHAR(255),
-
-			client_id    VARCHAR(255),
-			client_token VARCHAR(255),
-			server_token VARCHAR(255),
-			enc_key      BLOB,
-			mac_key      BLOB
-		)`)
-		return err
-	}
-}
-
 func (uq *UserQuery) New() *User {
 	return &User{
 		db:  uq.db,

+ 6 - 0
go.mod

@@ -19,3 +19,9 @@ require (
 )
 
 replace gopkg.in/russross/blackfriday.v2 => github.com/russross/blackfriday/v2 v2.0.1
+
+replace maunium.net/go/mautrix-appservice => ../mautrix-appservice-go
+
+replace maunium.net/go/mautrix => ../mautrix-go
+
+replace github.com/Rhymen/go-whatsapp => ../../Go/go-whatsapp

+ 2 - 2
main.go

@@ -147,9 +147,9 @@ func (bridge *Bridge) Init() {
 }
 
 func (bridge *Bridge) Start() {
-	err := bridge.DB.CreateTables(bridge.Config.AppService.Database.Type)
+	err := bridge.DB.Init(bridge.Config.AppService.Database.Type)
 	if err != nil {
-		bridge.Log.Fatalln("Failed to create database tables:", err)
+		bridge.Log.Fatalln("Failed to initialize database:", err)
 		os.Exit(15)
 	}
 	bridge.Log.Debugln("Starting application service HTTP server")

+ 20 - 0
matrix.go

@@ -111,6 +111,26 @@ func (mx *MatrixHandler) HandleMembership(evt *mautrix.Event) {
 	if evt.Content.Membership == "invite" && evt.GetStateKey() == mx.as.BotMXID() {
 		mx.HandleBotInvite(evt)
 	}
+
+	portal := mx.bridge.GetPortalByMXID(evt.RoomID)
+	if portal == nil {
+		return
+	}
+
+	user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
+	if user == nil || !user.Whitelisted || !user.IsLoggedIn() {
+		return
+	}
+
+	if evt.Content.Membership == "leave" {
+		if evt.GetStateKey() == evt.Sender {
+			if portal.IsPrivateChat() || evt.Unsigned.PrevContent.Membership == "join" {
+				portal.HandleMatrixLeave(user)
+			}
+		} else {
+			portal.HandleMatrixKick(user, evt)
+		}
+	}
 }
 
 func (mx *MatrixHandler) HandleRoomMetadata(evt *mautrix.Event) {

+ 50 - 0
portal.go

@@ -991,3 +991,53 @@ func (portal *Portal) HandleMatrixRedaction(sender *User, evt *mautrix.Event) {
 		portal.log.Debugln("Handled Matrix redaction:", evt)
 	}
 }
+
+func (portal *Portal) Delete() {
+	portal.Portal.Delete()
+	delete(portal.bridge.portalsByJID, portal.Key)
+	if len(portal.MXID) > 0 {
+		delete(portal.bridge.portalsByMXID, portal.MXID)
+	}
+}
+
+func (portal *Portal) Cleanup(puppetsOnly bool) {
+	if len(portal.MXID) == 0 {
+		return
+	}
+	if portal.IsPrivateChat() {
+		_, err := portal.MainIntent().LeaveRoom(portal.MXID)
+		if err != nil {
+			portal.log.Warnln("Failed to leave private chat portal with main intent:", err)
+		}
+		return
+	}
+	intent := portal.MainIntent()
+	members, err := intent.JoinedMembers(portal.MXID)
+	if err != nil {
+		portal.log.Errorln("Failed to get portal members for cleanup:", err)
+		return
+	}
+	for member, _ := range members.Joined {
+		puppet := portal.bridge.GetPuppetByMXID(member)
+		if puppet != nil {
+			_, err = puppet.Intent().LeaveRoom(portal.MXID)
+			portal.log.Errorln("Error leaving as puppet while cleaning up portal:", err)
+		} else if !puppetsOnly {
+			_, err = intent.KickUser(portal.MXID, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"})
+			portal.log.Errorln("Error kicking user while cleaning up portal:", err)
+		}
+	}
+}
+
+func (portal *Portal) HandleMatrixLeave(sender *User) {
+	if portal.IsPrivateChat() {
+		portal.log.Debugln("User left private chat portal, cleaning up and deleting...")
+		portal.Delete()
+		portal.Cleanup(false)
+		return
+	}
+}
+
+func (portal *Portal) HandleMatrixKick(sender *User, event *mautrix.Event) {
+	// TODO
+}

+ 4 - 0
user.go

@@ -47,6 +47,10 @@ type User struct {
 }
 
 func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User {
+	_, isPuppet := bridge.ParsePuppetMXID(userID)
+	if isPuppet || userID == bridge.Bot.UserID {
+		return nil
+	}
 	bridge.usersLock.Lock()
 	defer bridge.usersLock.Unlock()
 	user, ok := bridge.usersByMXID[userID]