Forráskód Böngészése

Merge pull request #43 from RennerDev/master

Implemented postgres
Tulir Asokan 6 éve
szülő
commit
67a041c06d
7 módosított fájl, 93 hozzáadás és 59 törlés
  1. 8 7
      database/database.go
  2. 32 16
      database/message.go
  3. 8 9
      database/portal.go
  4. 6 6
      database/puppet.go
  5. 35 19
      database/user.go
  6. 1 0
      example-config.yaml
  7. 3 2
      main.go

+ 8 - 7
database/database.go

@@ -19,6 +19,7 @@ package database
 import (
 	"database/sql"
 
+	_ "github.com/lib/pq"
 	_ "github.com/mattn/go-sqlite3"
 
 	log "maunium.net/go/maulogger/v2"
@@ -34,8 +35,8 @@ type Database struct {
 	Message *MessageQuery
 }
 
-func New(file string) (*Database, error) {
-	conn, err := sql.Open("sqlite3", file)
+func New(dbType string, uri string) (*Database, error) {
+	conn, err := sql.Open(dbType, uri)
 	if err != nil {
 		return nil, err
 	}
@@ -63,20 +64,20 @@ func New(file string) (*Database, error) {
 	return db, nil
 }
 
-func (db *Database) CreateTables() error {
-	err := db.User.CreateTable()
+func (db *Database) CreateTables(dbType string) error {
+	err := db.User.CreateTable(dbType)
 	if err != nil {
 		return err
 	}
-	err = db.Portal.CreateTable()
+	err = db.Portal.CreateTable(dbType)
 	if err != nil {
 		return err
 	}
-	err = db.Puppet.CreateTable()
+	err = db.Puppet.CreateTable(dbType)
 	if err != nil {
 		return err
 	}
-	err = db.Message.CreateTable()
+	err = db.Message.CreateTable(dbType)
 	if err != nil {
 		return err
 	}

+ 32 - 16
database/message.go

@@ -18,6 +18,7 @@ package database
 
 import (
 	"bytes"
+	"strings"
 	"database/sql"
 	"encoding/json"
 
@@ -33,19 +34,34 @@ type MessageQuery struct {
 	log log.Logger
 }
 
-func (mq *MessageQuery) CreateTable() error {
-	_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
-		chat_jid      VARCHAR(25),
-		chat_receiver VARCHAR(25),
-		jid           VARCHAR(255),
-		mxid          VARCHAR(255) NOT NULL UNIQUE,
-		sender        VARCHAR(25)  NOT NULL,
-		content       BLOB         NOT NULL,
-
-		PRIMARY KEY (chat_jid, chat_receiver, jid),
-		FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
-	)`)
+func (mq *MessageQuery) CreateTable(dbType string) error {
+	if strings.ToLower(dbType) == "postgres" {
+		_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
+			chat_jid      VARCHAR(255),
+			chat_receiver VARCHAR(255),
+			jid           VARCHAR(255),
+			mxid          VARCHAR(255) NOT NULL UNIQUE,
+			sender        VARCHAR(255)  NOT NULL,
+			content       bytea         NOT NULL,
+
+			PRIMARY KEY (chat_jid, chat_receiver, jid),
+			FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
+		)`)
+		return err
+	} else {
+		_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
+			chat_jid      VARCHAR(255),
+			chat_receiver VARCHAR(255),
+			jid           VARCHAR(255),
+			mxid          VARCHAR(255) NOT NULL UNIQUE,
+			sender        VARCHAR(255)  NOT NULL,
+			content       BLOB          NOT NULL,
+
+			PRIMARY KEY (chat_jid, chat_receiver, jid),
+			FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
+		)`)
 	return err
+	}
 }
 
 func (mq *MessageQuery) New() *Message {
@@ -56,7 +72,7 @@ func (mq *MessageQuery) New() *Message {
 }
 
 func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
-	rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=?", chat.JID, chat.Receiver)
+	rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver)
 	if err != nil || rows == nil {
 		return nil
 	}
@@ -68,11 +84,11 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
 }
 
 func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message {
-	return mq.get("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=? AND jid=?", chat.JID, chat.Receiver, jid)
+	return mq.get("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
 }
 
 func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
-	return mq.get("SELECT * FROM message WHERE mxid=?", mxid)
+	return mq.get("SELECT * FROM message WHERE mxid=$1", mxid)
 }
 
 func (mq *MessageQuery) get(query string, args ...interface{}) *Message {
@@ -130,7 +146,7 @@ func (msg *Message) encodeBinaryContent() []byte {
 }
 
 func (msg *Message) Insert() {
-	_, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?, ?, ?, ?)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.encodeBinaryContent())
+	_, err := msg.db.Exec("INSERT INTO message VALUES ($1, $2, $3, $4, $5, $6)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.encodeBinaryContent())
 	if err != nil {
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 	}

+ 8 - 9
database/portal.go

@@ -59,18 +59,17 @@ type PortalQuery struct {
 	log log.Logger
 }
 
-func (pq *PortalQuery) CreateTable() error {
+func (pq *PortalQuery) CreateTable(dbType string) error {
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
-		jid      VARCHAR(25),
-		receiver VARCHAR(25),
+		jid      VARCHAR(255),
+		receiver VARCHAR(255),
 		mxid     VARCHAR(255) UNIQUE,
 
 		name   VARCHAR(255) NOT NULL,
 		topic  VARCHAR(255) NOT NULL,
 		avatar VARCHAR(255) NOT NULL,
 
-		PRIMARY KEY (jid, receiver),
-		FOREIGN KEY (receiver) REFERENCES user(mxid)
+		PRIMARY KEY (jid, receiver)
 	)`)
 	return err
 }
@@ -95,11 +94,11 @@ func (pq *PortalQuery) GetAll() (portals []*Portal) {
 }
 
 func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
-	return pq.get("SELECT * FROM portal WHERE jid=? AND receiver=?", key.JID, key.Receiver)
+	return pq.get("SELECT * FROM portal WHERE jid=$1 AND receiver=$2", key.JID, key.Receiver)
 }
 
 func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
-	return pq.get("SELECT * FROM portal WHERE mxid=?", mxid)
+	return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
 }
 
 func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
@@ -143,7 +142,7 @@ func (portal *Portal) mxidPtr() *string {
 }
 
 func (portal *Portal) Insert() {
-	_, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)",
+	_, err := portal.db.Exec("INSERT INTO portal VALUES ($1, $2, $3, $4, $5, $6)",
 		portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar)
 	if err != nil {
 		portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
@@ -155,7 +154,7 @@ func (portal *Portal) Update() {
 	if len(portal.MXID) > 0 {
 		mxid = &portal.MXID
 	}
-	_, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND receiver=?",
+	_, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4 WHERE jid=$5 AND receiver=$6",
 		mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver)
 	if err != nil {
 		portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)

+ 6 - 6
database/puppet.go

@@ -29,12 +29,12 @@ type PuppetQuery struct {
 	log log.Logger
 }
 
-func (pq *PuppetQuery) CreateTable() error {
+func (pq *PuppetQuery) CreateTable(dbType string) error {
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet (
-		jid          VARCHAR(25) PRIMARY KEY,
+		jid          VARCHAR(255) PRIMARY KEY,
 		avatar       VARCHAR(255),
 		displayname  VARCHAR(255),
-		name_quality TINYINT
+		name_quality SMALLINT
 	)`)
 	return err
 }
@@ -59,7 +59,7 @@ func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
 }
 
 func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
-	row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=?", jid)
+	row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=$1", jid)
 	if row == nil {
 		return nil
 	}
@@ -93,7 +93,7 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
 }
 
 func (puppet *Puppet) Insert() {
-	_, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?, ?)",
+	_, err := puppet.db.Exec("INSERT INTO puppet VALUES ($1, $2, $3, $4)",
 		puppet.JID, puppet.Avatar, puppet.Displayname, puppet.NameQuality)
 	if err != nil {
 		puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
@@ -101,7 +101,7 @@ func (puppet *Puppet) Insert() {
 }
 
 func (puppet *Puppet) Update() {
-	_, err := puppet.db.Exec("UPDATE puppet SET displayname=?, name_quality=?, avatar=? WHERE jid=?",
+	_, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3 WHERE jid=$4",
 		puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.JID)
 	if err != nil {
 		puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err)

+ 35 - 19
database/user.go

@@ -33,20 +33,36 @@ type UserQuery struct {
 	log log.Logger
 }
 
-func (uq *UserQuery) CreateTable() error {
-	_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user (
-		mxid VARCHAR(255) PRIMARY KEY,
-		jid  VARCHAR(25)  UNIQUE,
-
-		management_room VARCHAR(255),
-
-		client_id    VARCHAR(255),
-		client_token VARCHAR(255),
-		server_token VARCHAR(255),
-		enc_key      BLOB,
-		mac_key      BLOB
-	)`)
-	return err
+func (uq *UserQuery) CreateTable(dbType string) error {
+	if strings.ToLower(dbType) == "postgres" {
+		_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
+			mxid VARCHAR(255) PRIMARY KEY,
+			jid  VARCHAR(255)  UNIQUE,
+
+			management_room VARCHAR(255),
+
+			client_id    VARCHAR(255),
+			client_token VARCHAR(255),
+			server_token VARCHAR(255),
+			enc_key      bytea,
+			mac_key      bytea
+		)`)
+		return err
+	} else {
+		_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
+			mxid VARCHAR(255) PRIMARY KEY,
+			jid  VARCHAR(255)  UNIQUE,
+
+			management_room VARCHAR(255),
+
+			client_id    VARCHAR(255),
+			client_token VARCHAR(255),
+			server_token VARCHAR(255),
+			enc_key      BLOB,
+			mac_key      BLOB
+		)`)
+		return err
+	}
 }
 
 func (uq *UserQuery) New() *User {
@@ -57,7 +73,7 @@ func (uq *UserQuery) New() *User {
 }
 
 func (uq *UserQuery) GetAll() (users []*User) {
-	rows, err := uq.db.Query("SELECT * FROM user")
+	rows, err := uq.db.Query(`SELECT * FROM "user"`)
 	if err != nil || rows == nil {
 		return nil
 	}
@@ -69,7 +85,7 @@ func (uq *UserQuery) GetAll() (users []*User) {
 }
 
 func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
-	row := uq.db.QueryRow("SELECT * FROM user WHERE mxid=?", userID)
+	row := uq.db.QueryRow(`SELECT * FROM "user" WHERE mxid=$1`, userID)
 	if row == nil {
 		return nil
 	}
@@ -77,7 +93,7 @@ func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
 }
 
 func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User {
-	row := uq.db.QueryRow("SELECT * FROM user WHERE jid=?", stripSuffix(userID))
+	row := uq.db.QueryRow(`SELECT * FROM "user" WHERE jid=$1`, stripSuffix(userID))
 	if row == nil {
 		return nil
 	}
@@ -150,7 +166,7 @@ func (user *User) sessionUnptr() (sess whatsapp.Session) {
 
 func (user *User) Insert() {
 	sess := user.sessionUnptr()
-	_, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?)", user.MXID, user.jidPtr(),
+	_, err := user.db.Exec(`INSERT INTO "user" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, user.MXID, user.jidPtr(),
 		user.ManagementRoom,
 		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey)
 	if err != nil {
@@ -160,7 +176,7 @@ func (user *User) Insert() {
 
 func (user *User) Update() {
 	sess := user.sessionUnptr()
-	_, err := user.db.Exec("UPDATE user SET jid=?, management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=? WHERE mxid=?",
+	_, err := user.db.Exec(`UPDATE "user" SET jid=$1, management_room=$2, client_id=$3, client_token=$4, server_token=$5, enc_key=$6, mac_key=$7 WHERE mxid=$8`,
 		user.jidPtr(), user.ManagementRoom,
 		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey,
 		user.MXID)

+ 1 - 0
example-config.yaml

@@ -20,6 +20,7 @@ appservice:
     # The database type. Only "sqlite3" is supported.
     type: sqlite3
     # The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string
+    # postres example: postgres://synapse:changeme@db/whatsapp?sslmode=disable
     uri: mautrix-whatsapp.db
   # Path to the Matrix room state store.
   state_store_path: ./mx-state.json

+ 3 - 2
main.go

@@ -133,7 +133,7 @@ func (bridge *Bridge) Init() {
 	bridge.AS.StateStore = bridge.StateStore
 
 	bridge.Log.Debugln("Initializing database")
-	bridge.DB, err = database.New(bridge.Config.AppService.Database.URI)
+	bridge.DB, err = database.New(bridge.Config.AppService.Database.Type, bridge.Config.AppService.Database.URI)
 	if err != nil {
 		bridge.Log.Fatalln("Failed to initialize database:", err)
 		os.Exit(14)
@@ -147,7 +147,7 @@ func (bridge *Bridge) Init() {
 }
 
 func (bridge *Bridge) Start() {
-	err := bridge.DB.CreateTables()
+	err := bridge.DB.CreateTables(bridge.Config.AppService.Database.Type)
 	if err != nil {
 		bridge.Log.Fatalln("Failed to create database tables:", err)
 		os.Exit(15)
@@ -185,6 +185,7 @@ func (bridge *Bridge) UpdateBotProfile() {
 }
 
 func (bridge *Bridge) StartUsers() {
+	bridge.Log.Debugln("Starting users")
 	for _, user := range bridge.GetAllUsers() {
 		go user.Connect(false)
 	}