Browse Source

Merge pull request #43 from RennerDev/master

Implemented postgres
Tulir Asokan 6 years ago
parent
commit
67a041c06d
7 changed files with 93 additions and 59 deletions
  1. 8 7
      database/database.go
  2. 32 16
      database/message.go
  3. 8 9
      database/portal.go
  4. 6 6
      database/puppet.go
  5. 35 19
      database/user.go
  6. 1 0
      example-config.yaml
  7. 3 2
      main.go

+ 8 - 7
database/database.go

@@ -19,6 +19,7 @@ package database
 import (
 import (
 	"database/sql"
 	"database/sql"
 
 
+	_ "github.com/lib/pq"
 	_ "github.com/mattn/go-sqlite3"
 	_ "github.com/mattn/go-sqlite3"
 
 
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
@@ -34,8 +35,8 @@ type Database struct {
 	Message *MessageQuery
 	Message *MessageQuery
 }
 }
 
 
-func New(file string) (*Database, error) {
-	conn, err := sql.Open("sqlite3", file)
+func New(dbType string, uri string) (*Database, error) {
+	conn, err := sql.Open(dbType, uri)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -63,20 +64,20 @@ func New(file string) (*Database, error) {
 	return db, nil
 	return db, nil
 }
 }
 
 
-func (db *Database) CreateTables() error {
-	err := db.User.CreateTable()
+func (db *Database) CreateTables(dbType string) error {
+	err := db.User.CreateTable(dbType)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	err = db.Portal.CreateTable()
+	err = db.Portal.CreateTable(dbType)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	err = db.Puppet.CreateTable()
+	err = db.Puppet.CreateTable(dbType)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	err = db.Message.CreateTable()
+	err = db.Message.CreateTable(dbType)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 32 - 16
database/message.go

@@ -18,6 +18,7 @@ package database
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"strings"
 	"database/sql"
 	"database/sql"
 	"encoding/json"
 	"encoding/json"
 
 
@@ -33,19 +34,34 @@ type MessageQuery struct {
 	log log.Logger
 	log log.Logger
 }
 }
 
 
-func (mq *MessageQuery) CreateTable() error {
-	_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
-		chat_jid      VARCHAR(25),
-		chat_receiver VARCHAR(25),
-		jid           VARCHAR(255),
-		mxid          VARCHAR(255) NOT NULL UNIQUE,
-		sender        VARCHAR(25)  NOT NULL,
-		content       BLOB         NOT NULL,
-
-		PRIMARY KEY (chat_jid, chat_receiver, jid),
-		FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
-	)`)
+func (mq *MessageQuery) CreateTable(dbType string) error {
+	if strings.ToLower(dbType) == "postgres" {
+		_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
+			chat_jid      VARCHAR(255),
+			chat_receiver VARCHAR(255),
+			jid           VARCHAR(255),
+			mxid          VARCHAR(255) NOT NULL UNIQUE,
+			sender        VARCHAR(255)  NOT NULL,
+			content       bytea         NOT NULL,
+
+			PRIMARY KEY (chat_jid, chat_receiver, jid),
+			FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
+		)`)
+		return err
+	} else {
+		_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
+			chat_jid      VARCHAR(255),
+			chat_receiver VARCHAR(255),
+			jid           VARCHAR(255),
+			mxid          VARCHAR(255) NOT NULL UNIQUE,
+			sender        VARCHAR(255)  NOT NULL,
+			content       BLOB          NOT NULL,
+
+			PRIMARY KEY (chat_jid, chat_receiver, jid),
+			FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
+		)`)
 	return err
 	return err
+	}
 }
 }
 
 
 func (mq *MessageQuery) New() *Message {
 func (mq *MessageQuery) New() *Message {
@@ -56,7 +72,7 @@ func (mq *MessageQuery) New() *Message {
 }
 }
 
 
 func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
 func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
-	rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=?", chat.JID, chat.Receiver)
+	rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver)
 	if err != nil || rows == nil {
 	if err != nil || rows == nil {
 		return nil
 		return nil
 	}
 	}
@@ -68,11 +84,11 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
 }
 }
 
 
 func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message {
 func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message {
-	return mq.get("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=? AND jid=?", chat.JID, chat.Receiver, jid)
+	return mq.get("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
 }
 }
 
 
 func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
 func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
-	return mq.get("SELECT * FROM message WHERE mxid=?", mxid)
+	return mq.get("SELECT * FROM message WHERE mxid=$1", mxid)
 }
 }
 
 
 func (mq *MessageQuery) get(query string, args ...interface{}) *Message {
 func (mq *MessageQuery) get(query string, args ...interface{}) *Message {
@@ -130,7 +146,7 @@ func (msg *Message) encodeBinaryContent() []byte {
 }
 }
 
 
 func (msg *Message) Insert() {
 func (msg *Message) Insert() {
-	_, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?, ?, ?, ?)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.encodeBinaryContent())
+	_, err := msg.db.Exec("INSERT INTO message VALUES ($1, $2, $3, $4, $5, $6)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.encodeBinaryContent())
 	if err != nil {
 	if err != nil {
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 		msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
 	}
 	}

+ 8 - 9
database/portal.go

@@ -59,18 +59,17 @@ type PortalQuery struct {
 	log log.Logger
 	log log.Logger
 }
 }
 
 
-func (pq *PortalQuery) CreateTable() error {
+func (pq *PortalQuery) CreateTable(dbType string) error {
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
-		jid      VARCHAR(25),
-		receiver VARCHAR(25),
+		jid      VARCHAR(255),
+		receiver VARCHAR(255),
 		mxid     VARCHAR(255) UNIQUE,
 		mxid     VARCHAR(255) UNIQUE,
 
 
 		name   VARCHAR(255) NOT NULL,
 		name   VARCHAR(255) NOT NULL,
 		topic  VARCHAR(255) NOT NULL,
 		topic  VARCHAR(255) NOT NULL,
 		avatar VARCHAR(255) NOT NULL,
 		avatar VARCHAR(255) NOT NULL,
 
 
-		PRIMARY KEY (jid, receiver),
-		FOREIGN KEY (receiver) REFERENCES user(mxid)
+		PRIMARY KEY (jid, receiver)
 	)`)
 	)`)
 	return err
 	return err
 }
 }
@@ -95,11 +94,11 @@ func (pq *PortalQuery) GetAll() (portals []*Portal) {
 }
 }
 
 
 func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
 func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
-	return pq.get("SELECT * FROM portal WHERE jid=? AND receiver=?", key.JID, key.Receiver)
+	return pq.get("SELECT * FROM portal WHERE jid=$1 AND receiver=$2", key.JID, key.Receiver)
 }
 }
 
 
 func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
 func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
-	return pq.get("SELECT * FROM portal WHERE mxid=?", mxid)
+	return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
 }
 }
 
 
 func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
 func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
@@ -143,7 +142,7 @@ func (portal *Portal) mxidPtr() *string {
 }
 }
 
 
 func (portal *Portal) Insert() {
 func (portal *Portal) Insert() {
-	_, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)",
+	_, err := portal.db.Exec("INSERT INTO portal VALUES ($1, $2, $3, $4, $5, $6)",
 		portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar)
 		portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar)
 	if err != nil {
 	if err != nil {
 		portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
 		portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
@@ -155,7 +154,7 @@ func (portal *Portal) Update() {
 	if len(portal.MXID) > 0 {
 	if len(portal.MXID) > 0 {
 		mxid = &portal.MXID
 		mxid = &portal.MXID
 	}
 	}
-	_, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND receiver=?",
+	_, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4 WHERE jid=$5 AND receiver=$6",
 		mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver)
 		mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver)
 	if err != nil {
 	if err != nil {
 		portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
 		portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)

+ 6 - 6
database/puppet.go

@@ -29,12 +29,12 @@ type PuppetQuery struct {
 	log log.Logger
 	log log.Logger
 }
 }
 
 
-func (pq *PuppetQuery) CreateTable() error {
+func (pq *PuppetQuery) CreateTable(dbType string) error {
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet (
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet (
-		jid          VARCHAR(25) PRIMARY KEY,
+		jid          VARCHAR(255) PRIMARY KEY,
 		avatar       VARCHAR(255),
 		avatar       VARCHAR(255),
 		displayname  VARCHAR(255),
 		displayname  VARCHAR(255),
-		name_quality TINYINT
+		name_quality SMALLINT
 	)`)
 	)`)
 	return err
 	return err
 }
 }
@@ -59,7 +59,7 @@ func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
 }
 }
 
 
 func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
 func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
-	row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=?", jid)
+	row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=$1", jid)
 	if row == nil {
 	if row == nil {
 		return nil
 		return nil
 	}
 	}
@@ -93,7 +93,7 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
 }
 }
 
 
 func (puppet *Puppet) Insert() {
 func (puppet *Puppet) Insert() {
-	_, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?, ?)",
+	_, err := puppet.db.Exec("INSERT INTO puppet VALUES ($1, $2, $3, $4)",
 		puppet.JID, puppet.Avatar, puppet.Displayname, puppet.NameQuality)
 		puppet.JID, puppet.Avatar, puppet.Displayname, puppet.NameQuality)
 	if err != nil {
 	if err != nil {
 		puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
 		puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
@@ -101,7 +101,7 @@ func (puppet *Puppet) Insert() {
 }
 }
 
 
 func (puppet *Puppet) Update() {
 func (puppet *Puppet) Update() {
-	_, err := puppet.db.Exec("UPDATE puppet SET displayname=?, name_quality=?, avatar=? WHERE jid=?",
+	_, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3 WHERE jid=$4",
 		puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.JID)
 		puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.JID)
 	if err != nil {
 	if err != nil {
 		puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err)
 		puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err)

+ 35 - 19
database/user.go

@@ -33,20 +33,36 @@ type UserQuery struct {
 	log log.Logger
 	log log.Logger
 }
 }
 
 
-func (uq *UserQuery) CreateTable() error {
-	_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user (
-		mxid VARCHAR(255) PRIMARY KEY,
-		jid  VARCHAR(25)  UNIQUE,
-
-		management_room VARCHAR(255),
-
-		client_id    VARCHAR(255),
-		client_token VARCHAR(255),
-		server_token VARCHAR(255),
-		enc_key      BLOB,
-		mac_key      BLOB
-	)`)
-	return err
+func (uq *UserQuery) CreateTable(dbType string) error {
+	if strings.ToLower(dbType) == "postgres" {
+		_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
+			mxid VARCHAR(255) PRIMARY KEY,
+			jid  VARCHAR(255)  UNIQUE,
+
+			management_room VARCHAR(255),
+
+			client_id    VARCHAR(255),
+			client_token VARCHAR(255),
+			server_token VARCHAR(255),
+			enc_key      bytea,
+			mac_key      bytea
+		)`)
+		return err
+	} else {
+		_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
+			mxid VARCHAR(255) PRIMARY KEY,
+			jid  VARCHAR(255)  UNIQUE,
+
+			management_room VARCHAR(255),
+
+			client_id    VARCHAR(255),
+			client_token VARCHAR(255),
+			server_token VARCHAR(255),
+			enc_key      BLOB,
+			mac_key      BLOB
+		)`)
+		return err
+	}
 }
 }
 
 
 func (uq *UserQuery) New() *User {
 func (uq *UserQuery) New() *User {
@@ -57,7 +73,7 @@ func (uq *UserQuery) New() *User {
 }
 }
 
 
 func (uq *UserQuery) GetAll() (users []*User) {
 func (uq *UserQuery) GetAll() (users []*User) {
-	rows, err := uq.db.Query("SELECT * FROM user")
+	rows, err := uq.db.Query(`SELECT * FROM "user"`)
 	if err != nil || rows == nil {
 	if err != nil || rows == nil {
 		return nil
 		return nil
 	}
 	}
@@ -69,7 +85,7 @@ func (uq *UserQuery) GetAll() (users []*User) {
 }
 }
 
 
 func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
 func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
-	row := uq.db.QueryRow("SELECT * FROM user WHERE mxid=?", userID)
+	row := uq.db.QueryRow(`SELECT * FROM "user" WHERE mxid=$1`, userID)
 	if row == nil {
 	if row == nil {
 		return nil
 		return nil
 	}
 	}
@@ -77,7 +93,7 @@ func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
 }
 }
 
 
 func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User {
 func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User {
-	row := uq.db.QueryRow("SELECT * FROM user WHERE jid=?", stripSuffix(userID))
+	row := uq.db.QueryRow(`SELECT * FROM "user" WHERE jid=$1`, stripSuffix(userID))
 	if row == nil {
 	if row == nil {
 		return nil
 		return nil
 	}
 	}
@@ -150,7 +166,7 @@ func (user *User) sessionUnptr() (sess whatsapp.Session) {
 
 
 func (user *User) Insert() {
 func (user *User) Insert() {
 	sess := user.sessionUnptr()
 	sess := user.sessionUnptr()
-	_, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?)", user.MXID, user.jidPtr(),
+	_, err := user.db.Exec(`INSERT INTO "user" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, user.MXID, user.jidPtr(),
 		user.ManagementRoom,
 		user.ManagementRoom,
 		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey)
 		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey)
 	if err != nil {
 	if err != nil {
@@ -160,7 +176,7 @@ func (user *User) Insert() {
 
 
 func (user *User) Update() {
 func (user *User) Update() {
 	sess := user.sessionUnptr()
 	sess := user.sessionUnptr()
-	_, err := user.db.Exec("UPDATE user SET jid=?, management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=? WHERE mxid=?",
+	_, err := user.db.Exec(`UPDATE "user" SET jid=$1, management_room=$2, client_id=$3, client_token=$4, server_token=$5, enc_key=$6, mac_key=$7 WHERE mxid=$8`,
 		user.jidPtr(), user.ManagementRoom,
 		user.jidPtr(), user.ManagementRoom,
 		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey,
 		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey,
 		user.MXID)
 		user.MXID)

+ 1 - 0
example-config.yaml

@@ -20,6 +20,7 @@ appservice:
     # The database type. Only "sqlite3" is supported.
     # The database type. Only "sqlite3" is supported.
     type: sqlite3
     type: sqlite3
     # The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string
     # The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string
+    # postres example: postgres://synapse:changeme@db/whatsapp?sslmode=disable
     uri: mautrix-whatsapp.db
     uri: mautrix-whatsapp.db
   # Path to the Matrix room state store.
   # Path to the Matrix room state store.
   state_store_path: ./mx-state.json
   state_store_path: ./mx-state.json

+ 3 - 2
main.go

@@ -133,7 +133,7 @@ func (bridge *Bridge) Init() {
 	bridge.AS.StateStore = bridge.StateStore
 	bridge.AS.StateStore = bridge.StateStore
 
 
 	bridge.Log.Debugln("Initializing database")
 	bridge.Log.Debugln("Initializing database")
-	bridge.DB, err = database.New(bridge.Config.AppService.Database.URI)
+	bridge.DB, err = database.New(bridge.Config.AppService.Database.Type, bridge.Config.AppService.Database.URI)
 	if err != nil {
 	if err != nil {
 		bridge.Log.Fatalln("Failed to initialize database:", err)
 		bridge.Log.Fatalln("Failed to initialize database:", err)
 		os.Exit(14)
 		os.Exit(14)
@@ -147,7 +147,7 @@ func (bridge *Bridge) Init() {
 }
 }
 
 
 func (bridge *Bridge) Start() {
 func (bridge *Bridge) Start() {
-	err := bridge.DB.CreateTables()
+	err := bridge.DB.CreateTables(bridge.Config.AppService.Database.Type)
 	if err != nil {
 	if err != nil {
 		bridge.Log.Fatalln("Failed to create database tables:", err)
 		bridge.Log.Fatalln("Failed to create database tables:", err)
 		os.Exit(15)
 		os.Exit(15)
@@ -185,6 +185,7 @@ func (bridge *Bridge) UpdateBotProfile() {
 }
 }
 
 
 func (bridge *Bridge) StartUsers() {
 func (bridge *Bridge) StartUsers() {
+	bridge.Log.Debugln("Starting users")
 	for _, user := range bridge.GetAllUsers() {
 	for _, user := range bridge.GetAllUsers() {
 		go user.Connect(false)
 		go user.Connect(false)
 	}
 	}