Selaa lähdekoodia

Initial desegregation of users and automatic config updating

Tulir Asokan 6 vuotta sitten
vanhempi
sitoutus
c7348f29b0

+ 1 - 0
.gitignore

@@ -6,3 +6,4 @@
 *.session
 *.json
 *.db
+*.log

+ 2 - 2
Gopkg.lock

@@ -123,7 +123,7 @@
     ".",
     "format"
   ]
-  revision = "ead1f970c8f56d1854cb9eb4a54c03aa6dafd753"
+  revision = "42a3133c4980e4b1ea5fb52329d977f592d67cf0"
 
 [[projects]]
   branch = "master"
@@ -141,7 +141,7 @@
   branch = "master"
   name = "maunium.net/go/mautrix-appservice"
   packages = ["."]
-  revision = "269f2ab602126a2de94bc86a457392426cce1ab2"
+  revision = "37d4449056015cea5d0a4420bba595c61ad32007"
 
 [solve-meta]
   analyzer-name = "dep"

+ 2 - 2
commands.go

@@ -48,7 +48,7 @@ type CommandEvent struct {
 func (ce *CommandEvent) Reply(msg string) {
 	_, err := ce.Bot.SendNotice(string(ce.RoomID), msg)
 	if err != nil {
-		ce.Handler.log.Warnfln("Failed to reply to command from %s: %v", ce.User.ID, err)
+		ce.Handler.log.Warnfln("Failed to reply to command from %s: %v", ce.User.MXID, err)
 	}
 }
 
@@ -56,7 +56,7 @@ func (handler *CommandHandler) Handle(roomID types.MatrixRoomID, user *User, mes
 	args := strings.Split(message, " ")
 	cmd := strings.ToLower(args[0])
 	ce := &CommandEvent{
-		Bot:     handler.bridge.AppService.BotIntent(),
+		Bot:     handler.bridge.AS.BotIntent(),
 		Bridge:  handler.bridge,
 		Handler: handler,
 		RoomID:  roomID,

+ 3 - 13
config/bridge.go

@@ -56,12 +56,7 @@ func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
 	return err
 }
 
-type DisplaynameTemplateArgs struct {
-	Displayname string
-}
-
 type UsernameTemplateArgs struct {
-	Receiver string
 	UserID   string
 }
 
@@ -74,14 +69,9 @@ func (bc BridgeConfig) FormatDisplayname(contact whatsapp.Contact) string {
 	return buf.String()
 }
 
-func (bc BridgeConfig) FormatUsername(receiver types.MatrixUserID, userID types.WhatsAppID) string {
+func (bc BridgeConfig) FormatUsername(userID types.WhatsAppID) string {
 	var buf bytes.Buffer
-	receiver = strings.Replace(receiver, "@", "=40", 1)
-	receiver = strings.Replace(receiver, ":", "=3", 1)
-	bc.usernameTemplate.Execute(&buf, UsernameTemplateArgs{
-		Receiver: receiver,
-		UserID:   userID,
-	})
+	bc.usernameTemplate.Execute(&buf, userID)
 	return buf.String()
 }
 
@@ -92,7 +82,7 @@ func (bc BridgeConfig) MarshalYAML() (interface{}, error) {
 		Name:   "{{.Name}}",
 		Short:  "{{.Short}}",
 	})
-	bc.UsernameTemplate = bc.FormatUsername("{{.Receiver}}", "{{.UserID}}")
+	bc.UsernameTemplate = bc.FormatUsername("{{.}}")
 	return bc, nil
 }
 

+ 0 - 1
config/config.go

@@ -78,7 +78,6 @@ func (config *Config) Save(path string) error {
 
 func (config *Config) MakeAppService() (*appservice.AppService, error) {
 	as := appservice.Create()
-	as.LogConfig = config.Logging
 	as.HomeserverDomain = config.Homeserver.Domain
 	as.HomeserverURL = config.Homeserver.Address
 	as.Host.Hostname = config.AppService.Hostname

+ 115 - 0
config/recursivemap.go

@@ -0,0 +1,115 @@
+// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
+// Copyright (C) 2018 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+package config
+
+import (
+	"strings"
+)
+
+type RecursiveMap map[interface{}]interface{}
+
+func (rm RecursiveMap) GetDefault(path string, defVal interface{}) interface{} {
+	val, ok := rm.Get(path)
+	if !ok {
+		return defVal
+	}
+	return val
+}
+
+func (rm RecursiveMap) GetMap(path string) RecursiveMap {
+	val := rm.GetDefault(path, nil)
+	if val == nil {
+		return nil
+	}
+
+	newRM, ok := val.(map[interface{}]interface{})
+	if ok {
+		return RecursiveMap(newRM)
+	}
+	return nil
+}
+
+func (rm RecursiveMap) Get(path string) (interface{}, bool) {
+	if index := strings.IndexRune(path, '.'); index >= 0 {
+		key := path[:index]
+		path = path[index+1:]
+
+		submap := rm.GetMap(key)
+		if submap == nil {
+			return nil, false
+		}
+		return submap.Get(path)
+	}
+	val, ok := rm[path]
+	return val, ok
+}
+
+func (rm RecursiveMap) GetIntDefault(path string, defVal int) int {
+	val, ok := rm.GetInt(path)
+	if !ok {
+		return defVal
+	}
+	return val
+}
+
+func (rm RecursiveMap) GetInt(path string) (int, bool) {
+	val, ok := rm.Get(path)
+	if !ok {
+		return 0, ok
+	}
+	intVal, ok := val.(int)
+	return intVal, ok
+}
+
+func (rm RecursiveMap) GetStringDefault(path string, defVal string) string {
+	val, ok := rm.GetString(path)
+	if !ok {
+		return defVal
+	}
+	return val
+}
+
+func (rm RecursiveMap) GetString(path string) (string, bool) {
+	val, ok := rm.Get(path)
+	if !ok {
+		return "", ok
+	}
+	strVal, ok := val.(string)
+	return strVal, ok
+}
+
+func (rm RecursiveMap) Set(path string, value interface{}) {
+	if index := strings.IndexRune(path, '.'); index >= 0 {
+		key := path[:index]
+		path = path[index+1:]
+		nextRM := rm.GetMap(key)
+		if nextRM == nil {
+			nextRM = make(RecursiveMap)
+			rm[key] = nextRM
+		}
+		nextRM.Set(path, value)
+		return
+	}
+	rm[path] = value
+}
+
+func (rm RecursiveMap) CopyFrom(otherRM RecursiveMap, path string) {
+	val, ok := otherRM.Get(path)
+	if ok {
+		rm.Set(path, val)
+	}
+}

+ 1 - 1
config/registration.go

@@ -56,7 +56,7 @@ func (config *Config) copyToRegistration(registration *appservice.Registration)
 	registration.SenderLocalpart = config.AppService.Bot.Username
 
 	userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$",
-		config.Bridge.FormatUsername(".+", "[0-9]+"),
+		config.Bridge.FormatUsername("[0-9]+"),
 		config.Homeserver.Domain))
 	if err != nil {
 		return err

+ 100 - 0
config/update.go

@@ -0,0 +1,100 @@
+// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
+// Copyright (C) 2018 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+package config
+
+import (
+	"io/ioutil"
+
+	"gopkg.in/yaml.v2"
+)
+
+func Update(path, basePath string) error {
+	oldCfgData, err := ioutil.ReadFile(path)
+	if err != nil {
+		return err
+	}
+
+	oldCfg := make(RecursiveMap)
+	err = yaml.Unmarshal(oldCfgData, &oldCfg)
+	if err != nil {
+		return err
+	}
+
+	baseCfgData, err := ioutil.ReadFile(basePath)
+	if err != nil {
+		return err
+	}
+
+	baseCfg := make(RecursiveMap)
+	err = yaml.Unmarshal(baseCfgData, &baseCfg)
+	if err != nil {
+		return err
+	}
+
+	err = runUpdate(oldCfg, baseCfg)
+	if err != nil {
+		return err
+	}
+
+	newCfgData, err := yaml.Marshal(&baseCfg)
+	if err != nil {
+		return err
+	}
+
+	return ioutil.WriteFile(path, newCfgData, 0600)
+}
+
+func runUpdate(oldCfg, newCfg RecursiveMap) error {
+	cp := func(path string) {
+		newCfg.CopyFrom(oldCfg, path)
+	}
+
+	cp("homeserver.address")
+	cp("homeserver.domain")
+
+	cp("appservice.address")
+	cp("appservice.hostname")
+	cp("appservice.port")
+
+	cp("appservice.database.type")
+	cp("appservice.database.uri")
+	cp("appservice.state_store_path")
+
+	cp("appservice.id")
+	cp("appservice.bot.username")
+	cp("appservice.bot.displayname")
+	cp("appservice.bot.avatar")
+
+	cp("appservice.bot.as_token")
+	cp("appservice.bot.hs_token")
+
+	cp("bridge.username_template")
+	cp("bridge.displayname_template")
+
+	cp("bridge.command_prefix")
+
+	cp("bridge.permissions")
+
+	cp("logging.directory")
+	cp("logging.file_name_format")
+	cp("logging.file_date_format")
+	cp("logging.file_mode")
+	cp("logging.timestamp_format")
+	cp("logging.print_level")
+
+	return nil
+}

+ 18 - 17
database/message.go

@@ -30,12 +30,13 @@ type MessageQuery struct {
 
 func (mq *MessageQuery) CreateTable() error {
 	_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
-		owner VARCHAR(255),
-		jid   VARCHAR(255),
-		mxid  VARCHAR(255) NOT NULL UNIQUE,
+		chat_jid      VARCHAR(25) NOT NULL,
+		chat_receiver VARCHAR(25) NOT NULL,
+		jid  VARCHAR(255) NOT NULL,
+		mxid VARCHAR(255) NOT NULL UNIQUE,
 
-		PRIMARY KEY (owner, jid),
-		FOREIGN KEY (owner) REFERENCES user(mxid)
+		PRIMARY KEY (chat_jid, jid),
+		FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
 	)`)
 	return err
 }
@@ -47,8 +48,8 @@ func (mq *MessageQuery) New() *Message {
 	}
 }
 
-func (mq *MessageQuery) GetAll(owner types.MatrixUserID) (messages []*Message) {
-	rows, err := mq.db.Query("SELECT * FROM message WHERE owner=?", owner)
+func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
+	rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=?", chat.JID, chat.Receiver)
 	if err != nil || rows == nil {
 		return nil
 	}
@@ -59,8 +60,8 @@ func (mq *MessageQuery) GetAll(owner types.MatrixUserID) (messages []*Message) {
 	return
 }
 
-func (mq *MessageQuery) GetByJID(owner types.MatrixUserID, jid types.WhatsAppMessageID) *Message {
-	return mq.get("SELECT * FROM message WHERE owner=? AND jid=?", owner, jid)
+func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message {
+	return mq.get("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=? AND jid=?", chat.JID, chat.Receiver, jid)
 }
 
 func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
@@ -79,13 +80,13 @@ type Message struct {
 	db  *Database
 	log log.Logger
 
-	Owner types.MatrixUserID
-	JID   types.WhatsAppMessageID
-	MXID  types.MatrixEventID
+	Chat PortalKey
+	JID     types.WhatsAppMessageID
+	MXID    types.MatrixEventID
 }
 
 func (msg *Message) Scan(row Scannable) *Message {
-	err := row.Scan(&msg.Owner, &msg.JID, &msg.MXID)
+	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID)
 	if err != nil {
 		if err != sql.ErrNoRows {
 			msg.log.Errorln("Database scan failed:", err)
@@ -96,17 +97,17 @@ func (msg *Message) Scan(row Scannable) *Message {
 }
 
 func (msg *Message) Insert() error {
-	_, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?)", msg.Owner, msg.JID, msg.MXID)
+	_, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID)
 	if err != nil {
-		msg.log.Warnfln("Failed to update %s->%s: %v", msg.Owner, msg.JID, err)
+		msg.log.Warnfln("Failed to update %s: %v", msg.Chat, msg.JID, err)
 	}
 	return err
 }
 
 func (msg *Message) Update() error {
-	_, err := msg.db.Exec("UPDATE portal SET mxid=? WHERE owner=? AND jid=?", msg.MXID, msg.Owner, msg.JID)
+	_, err := msg.db.Exec("UPDATE portal SET mxid=? WHERE chat_jid=? AND chat_receiver=? AND jid=?", msg.MXID, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
 	if err != nil {
-		msg.log.Warnfln("Failed to update %s->%s: %v", msg.Owner, msg.JID, err)
+		msg.log.Warnfln("Failed to update %s: %v", msg.Chat, msg.JID, err)
 	}
 	return err
 }

+ 53 - 21
database/portal.go

@@ -18,11 +18,41 @@ package database
 
 import (
 	"database/sql"
+	"strings"
 
 	log "maunium.net/go/maulogger"
 	"maunium.net/go/mautrix-whatsapp/types"
 )
 
+type PortalKey struct {
+	JID      types.WhatsAppID
+	Receiver types.WhatsAppID
+}
+
+func GroupPortalKey(jid types.WhatsAppID) PortalKey {
+	return PortalKey{
+		JID:      jid,
+		Receiver: jid,
+	}
+}
+
+func NewPortalKey(jid, receiver types.WhatsAppID) PortalKey {
+	if strings.HasSuffix(jid, "@g.us") {
+		receiver = jid
+	}
+	return PortalKey{
+		JID: jid,
+		Receiver: receiver,
+	}
+}
+
+func (key PortalKey) String() string {
+	if key.Receiver == key.JID {
+		return key.JID
+	}
+	return key.JID + "-" + key.Receiver
+}
+
 type PortalQuery struct {
 	db  *Database
 	log log.Logger
@@ -30,16 +60,16 @@ type PortalQuery struct {
 
 func (pq *PortalQuery) CreateTable() error {
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
-		jid   VARCHAR(255),
-		owner VARCHAR(255),
-		mxid  VARCHAR(255) UNIQUE,
+		jid      VARCHAR(25),
+		receiver VARCHAR(25),
+		mxid     VARCHAR(255) UNIQUE,
 
 		name   VARCHAR(255) NOT NULL,
 		topic  VARCHAR(255) NOT NULL,
 		avatar VARCHAR(255) NOT NULL,
 
-		PRIMARY KEY (jid, owner),
-		FOREIGN KEY (owner) REFERENCES user(mxid)
+		PRIMARY KEY (jid, receiver),
+		FOREIGN KEY (receiver) REFERENCES user(mxid)
 	)`)
 	return err
 }
@@ -51,8 +81,8 @@ func (pq *PortalQuery) New() *Portal {
 	}
 }
 
-func (pq *PortalQuery) GetAll(owner types.MatrixUserID) (portals []*Portal) {
-	rows, err := pq.db.Query("SELECT * FROM portal WHERE owner=?", owner)
+func (pq *PortalQuery) GetAll() (portals []*Portal) {
+	rows, err := pq.db.Query("SELECT * FROM portal")
 	if err != nil || rows == nil {
 		return nil
 	}
@@ -63,8 +93,8 @@ func (pq *PortalQuery) GetAll(owner types.MatrixUserID) (portals []*Portal) {
 	return
 }
 
-func (pq *PortalQuery) GetByJID(owner types.MatrixUserID, jid types.WhatsAppID) *Portal {
-	return pq.get("SELECT * FROM portal WHERE jid=? AND owner=?", jid, owner)
+func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
+	return pq.get("SELECT * FROM portal WHERE jid=? AND receiver=?", key.JID, key.Receiver)
 }
 
 func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
@@ -83,9 +113,8 @@ type Portal struct {
 	db  *Database
 	log log.Logger
 
-	JID   types.WhatsAppID
-	MXID  types.MatrixRoomID
-	Owner types.MatrixUserID
+	Key  PortalKey
+	MXID types.MatrixRoomID
 
 	Name   string
 	Topic  string
@@ -93,7 +122,7 @@ type Portal struct {
 }
 
 func (portal *Portal) Scan(row Scannable) *Portal {
-	err := row.Scan(&portal.JID, &portal.Owner, &portal.MXID, &portal.Name, &portal.Topic, &portal.Avatar)
+	err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &portal.MXID, &portal.Name, &portal.Topic, &portal.Avatar)
 	if err != nil {
 		if err != sql.ErrNoRows {
 			portal.log.Errorln("Database scan failed:", err)
@@ -103,15 +132,18 @@ func (portal *Portal) Scan(row Scannable) *Portal {
 	return portal
 }
 
-func (portal *Portal) Insert() error {
-	var mxid *string
+func (portal *Portal) mxidPtr() *string {
 	if len(portal.MXID) > 0 {
-		mxid = &portal.MXID
+		return &portal.MXID
 	}
+	return nil
+}
+
+func (portal *Portal) Insert() error {
 	_, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)",
-		portal.JID, portal.Owner, mxid, portal.Name, portal.Topic, portal.Avatar)
+		portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar)
 	if err != nil {
-		portal.log.Warnfln("Failed to insert %s->%s: %v", portal.JID, portal.Owner, err)
+		portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
 	}
 	return err
 }
@@ -121,10 +153,10 @@ func (portal *Portal) Update() error {
 	if len(portal.MXID) > 0 {
 		mxid = &portal.MXID
 	}
-	_, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND owner=?",
-		mxid, portal.Name, portal.Topic, portal.Avatar, portal.JID, portal.Owner)
+	_, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND receiver=?",
+		mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver)
 	if err != nil {
-		portal.log.Warnfln("Failed to update %s->%s: %v", portal.JID, portal.Owner, err)
+		portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
 	}
 	return err
 }

+ 14 - 21
database/puppet.go

@@ -30,13 +30,9 @@ type PuppetQuery struct {
 
 func (pq *PuppetQuery) CreateTable() error {
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet (
-		jid      VARCHAR(255),
-		receiver VARCHAR(255),
-
+		jid         VARCHAR(25) PRIMARY KEY,
 		displayname VARCHAR(255),
-		avatar      VARCHAR(255),
-
-		PRIMARY KEY(jid, receiver)
+		avatar      VARCHAR(255)
 	)`)
 	return err
 }
@@ -48,8 +44,8 @@ func (pq *PuppetQuery) New() *Puppet {
 	}
 }
 
-func (pq *PuppetQuery) GetAll(receiver types.MatrixUserID) (puppets []*Puppet) {
-	rows, err := pq.db.Query("SELECT * FROM puppet WHERE receiver=%s")
+func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
+	rows, err := pq.db.Query("SELECT * FROM puppet")
 	if err != nil || rows == nil {
 		return nil
 	}
@@ -60,8 +56,8 @@ func (pq *PuppetQuery) GetAll(receiver types.MatrixUserID) (puppets []*Puppet) {
 	return
 }
 
-func (pq *PuppetQuery) Get(jid types.WhatsAppID, receiver types.MatrixUserID) *Puppet {
-	row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=? AND receiver=?", jid, receiver)
+func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
+	row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=?", jid)
 	if row == nil {
 		return nil
 	}
@@ -72,15 +68,13 @@ type Puppet struct {
 	db  *Database
 	log log.Logger
 
-	JID      types.WhatsAppID
-	Receiver types.MatrixUserID
-
+	JID         types.WhatsAppID
 	Displayname string
 	Avatar      string
 }
 
 func (puppet *Puppet) Scan(row Scannable) *Puppet {
-	err := row.Scan(&puppet.JID, &puppet.Receiver, &puppet.Displayname, &puppet.Avatar)
+	err := row.Scan(&puppet.JID, &puppet.Displayname, &puppet.Avatar)
 	if err != nil {
 		if err != sql.ErrNoRows {
 			puppet.log.Errorln("Database scan failed:", err)
@@ -91,20 +85,19 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
 }
 
 func (puppet *Puppet) Insert() error {
-	_, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?, ?)",
-		puppet.JID, puppet.Receiver, puppet.Displayname, puppet.Avatar)
+	_, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?)",
+		puppet.JID, puppet.Displayname, puppet.Avatar)
 	if err != nil {
-		puppet.log.Errorfln("Failed to insert %s->%s: %v", puppet.JID, puppet.Receiver, err)
+		puppet.log.Errorfln("Failed to insert %s: %v", puppet.JID, err)
 	}
 	return err
 }
 
 func (puppet *Puppet) Update() error {
-	_, err := puppet.db.Exec("UPDATE puppet SET displayname=?, avatar=? WHERE jid=? AND receiver=?",
-		puppet.Displayname, puppet.Avatar,
-		puppet.JID, puppet.Receiver)
+	_, err := puppet.db.Exec("UPDATE puppet SET displayname=?, avatar=? WHERE jid=?",
+		puppet.Displayname, puppet.Avatar, puppet.JID)
 	if err != nil {
-		puppet.log.Errorfln("Failed to update %s->%s: %v", puppet.JID, puppet.Receiver, err)
+		puppet.log.Errorfln("Failed to update %s->%s: %v", puppet.JID, err)
 	}
 	return err
 }

+ 32 - 13
database/user.go

@@ -32,6 +32,7 @@ type UserQuery struct {
 func (uq *UserQuery) CreateTable() error {
 	_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user (
 		mxid VARCHAR(255) PRIMARY KEY,
+		jid  VARCHAR(25)  UNIQUE,
 
 		management_room VARCHAR(255),
 
@@ -64,7 +65,7 @@ func (uq *UserQuery) GetAll() (users []*User) {
 	return
 }
 
-func (uq *UserQuery) Get(userID types.MatrixUserID) *User {
+func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
 	row := uq.db.QueryRow("SELECT * FROM user WHERE mxid=?", userID)
 	if row == nil {
 		return nil
@@ -72,18 +73,27 @@ func (uq *UserQuery) Get(userID types.MatrixUserID) *User {
 	return uq.New().Scan(row)
 }
 
+func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User {
+	row := uq.db.QueryRow("SELECT * FROM user WHERE jid=?", userID)
+	if row == nil {
+		return nil
+	}
+	return uq.New().Scan(row)
+}
+
 type User struct {
 	db  *Database
 	log log.Logger
 
-	ID             types.MatrixUserID
+	MXID           types.MatrixUserID
+	JID            types.WhatsAppID
 	ManagementRoom types.MatrixRoomID
 	Session        *whatsapp.Session
 }
 
 func (user *User) Scan(row Scannable) *User {
 	sess := whatsapp.Session{}
-	err := row.Scan(&user.ID, &user.ManagementRoom, &sess.ClientId, &sess.ClientToken, &sess.ServerToken,
+	err := row.Scan(&user.MXID, &user.JID, &user.ManagementRoom, &sess.ClientId, &sess.ClientToken, &sess.ServerToken,
 		&sess.EncKey, &sess.MacKey, &sess.Wid)
 	if err != nil {
 		if err != sql.ErrNoRows {
@@ -99,23 +109,32 @@ func (user *User) Scan(row Scannable) *User {
 	return user
 }
 
-func (user *User) Insert() error {
-	var sess whatsapp.Session
+func (user *User) jidPtr() *string {
+	if len(user.JID) > 0 {
+		return &user.JID
+	}
+	return nil
+}
+
+func (user *User) sessionUnptr() (sess whatsapp.Session) {
 	if user.Session != nil {
 		sess = *user.Session
 	}
-	_, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?)", user.ID, user.ManagementRoom,
+	return
+}
+
+func (user *User) Insert() error {
+	sess := user.sessionUnptr()
+	_, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", user.MXID, user.jidPtr(), user.ManagementRoom,
 		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid)
 	return err
 }
 
 func (user *User) Update() error {
-	var sess whatsapp.Session
-	if user.Session != nil {
-		sess = *user.Session
-	}
-	_, err := user.db.Exec("UPDATE user SET management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=?, wid=? WHERE mxid=?",
-		user.ManagementRoom,
-		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid, user.ID)
+	sess := user.sessionUnptr()
+	_, err := user.db.Exec("UPDATE user SET jid=?, management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=?, wid=? WHERE mxid=?",
+		user.jidPtr(), user.ManagementRoom,
+		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid,
+		user.MXID)
 	return err
 }

+ 9 - 10
example-config.yaml

@@ -21,7 +21,6 @@ appservice:
     type: sqlite3
     # The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string
     uri: mautrix-whatsapp.db
-
   # Path to the Matrix room state store.
   state_store_path: ./mx-state.json
 
@@ -43,15 +42,15 @@ appservice:
 # Bridge config. Currently unused.
 bridge:
   # Localpart template of MXIDs for WhatsApp users.
-  # {{.Receiver}} is replaced with the WhatsApp user ID of the Matrix user receiving messages.
-  # {{.UserID}} is replaced with the user ID of the WhatsApp user.
-  username_template: "whatsapp_{{.Receiver}}_{{.UserID}}"
+  # {{.}} is replaced with the phone number of the WhatsApp user.
+  username_template: whatsapp_{{.}}
   # Displayname template for WhatsApp users.
-  # {{.Name}}   - display name
-  # {{.Short}}  - short display name (usually first name)
-  # {{.Notify}} - nickname (maybe set by the target WhatsApp user)
+  # {{.Notify}} - nickname set by the WhatsApp user
   # {{.Jid}}    - phone number (international format)
-  displayname_template: "{{if .Name}}{{.Name}}{{else if .Notify}}{{.Notify}}{{else if .Short}}{{.Short}}{{else}}{{.Jid}}{{end}}"
+  # The following variables are also available, but will cause problems on multi-user instances:
+  # {{.Name}}   - display name from contact list
+  # {{.Short}}  - short display name from contact list
+  displayname_template: "{{if .Notify}}{{.Notify}}{{else}}{{.Jid}}{{end}} (WA)"
 
   # The prefix for commands. Only required in non-management rooms.
   command_prefix: "!wa"
@@ -72,8 +71,8 @@ bridge:
 logging:
   # The directory for log files. Will be created if not found.
   directory: ./logs
-  # Available variables: .date for the file date and .index for different log files on the same day.
-  file_name_format: "{{.date}}-{{.index}.log"
+  # Available variables: .Date for the file date and .Index for different log files on the same day.
+  file_name_format: "{{.Date}}-{{.Index}}.log"
   # Date format for file names in the Go time format: https://golang.org/pkg/time/#pkg-constants
   file_date_format: 2006-01-02
   # Log file permissions.

+ 91 - 49
formatting.go

@@ -18,58 +18,71 @@ package main
 
 import (
 	"fmt"
+	"html"
 	"regexp"
 	"strings"
 
+	"maunium.net/go/gomatrix"
 	"maunium.net/go/gomatrix/format"
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 )
 
-func (user *User) newHTMLParser() *format.HTMLParser {
-	return &format.HTMLParser{
-		TabsToSpaces: 4,
-		Newline:      "\n",
-
-		PillConverter: func(mxid, eventID string) string {
-			if mxid[0] == '@' {
-				puppet := user.GetPuppetByMXID(mxid)
-				fmt.Println(mxid, puppet)
-				if puppet != nil {
-					return "@" + puppet.PhoneNumber()
-				}
-			}
-			return mxid
-		},
-		BoldConverter: func(text string) string {
-			return fmt.Sprintf("*%s*", text)
-		},
-		ItalicConverter: func(text string) string {
-			return fmt.Sprintf("_%s_", text)
-		},
-		StrikethroughConverter: func(text string) string {
-			return fmt.Sprintf("~%s~", text)
-		},
-		MonospaceConverter: func(text string) string {
-			return fmt.Sprintf("```%s```", text)
-		},
-		MonospaceBlockConverter: func(text string) string {
-			return fmt.Sprintf("```%s```", text)
-		},
-	}
-}
-
 var italicRegex = regexp.MustCompile("([\\s>~*]|^)_(.+?)_([^a-zA-Z\\d]|$)")
 var boldRegex = regexp.MustCompile("([\\s>_~]|^)\\*(.+?)\\*([^a-zA-Z\\d]|$)")
 var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$)")
 var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
 var mentionRegex = regexp.MustCompile("@[0-9]+")
 
-func (user *User) newWhatsAppFormatMaps() (map[*regexp.Regexp]string, map[*regexp.Regexp]func(string) string, map[*regexp.Regexp]func(string) string) {
-	return map[*regexp.Regexp]string{
-		italicRegex:        "$1<em>$2</em>$3",
-		boldRegex:          "$1<strong>$2</strong>$3",
-		strikethroughRegex: "$1<del>$2</del>$3",
-	}, map[*regexp.Regexp]func(string) string{
+type Formatter struct {
+	bridge *Bridge
+
+	matrixHTMLParser *format.HTMLParser
+
+	waReplString   map[*regexp.Regexp]string
+	waReplFunc     map[*regexp.Regexp]func(string) string
+	waReplFuncText map[*regexp.Regexp]func(string) string
+}
+
+func NewFormatter(bridge *Bridge) *Formatter {
+	formatter := &Formatter{
+		bridge: bridge,
+		matrixHTMLParser: &format.HTMLParser{
+			TabsToSpaces: 4,
+			Newline:      "\n",
+
+			PillConverter: func(mxid, eventID string) string {
+				if mxid[0] == '@' {
+					puppet := bridge.GetPuppetByMXID(mxid)
+					fmt.Println(mxid, puppet)
+					if puppet != nil {
+						return "@" + puppet.PhoneNumber()
+					}
+				}
+				return mxid
+			},
+			BoldConverter: func(text string) string {
+				return fmt.Sprintf("*%s*", text)
+			},
+			ItalicConverter: func(text string) string {
+				return fmt.Sprintf("_%s_", text)
+			},
+			StrikethroughConverter: func(text string) string {
+				return fmt.Sprintf("~%s~", text)
+			},
+			MonospaceConverter: func(text string) string {
+				return fmt.Sprintf("```%s```", text)
+			},
+			MonospaceBlockConverter: func(text string) string {
+				return fmt.Sprintf("```%s```", text)
+			},
+		},
+		waReplString: map[*regexp.Regexp]string{
+			italicRegex:        "$1<em>$2</em>$3",
+			boldRegex:          "$1<strong>$2</strong>$3",
+			strikethroughRegex: "$1<del>$2</del>$3",
+		},
+	}
+	formatter.waReplFunc = map[*regexp.Regexp]func(string) string{
 		codeBlockRegex: func(str string) string {
 			str = str[3 : len(str)-3]
 			if strings.ContainsRune(str, '\n') {
@@ -78,18 +91,47 @@ func (user *User) newWhatsAppFormatMaps() (map[*regexp.Regexp]string, map[*regex
 			return fmt.Sprintf("<code>%s</code>", str)
 		},
 		mentionRegex: func(str string) string {
-			jid := str[1:] + whatsappExt.NewUserSuffix
-			puppet := user.GetPuppetByJID(jid)
-			mxid := puppet.MXID
-			if jid == user.JID() {
-				mxid = user.ID
-			}
-			return fmt.Sprintf(`<a href="https://matrix.to/#/%s">%s</a>`, mxid, puppet.Displayname)
+			mxid, displayname := formatter.getMatrixInfoByJID(str[1:] + whatsappExt.NewUserSuffix)
+			return fmt.Sprintf(`<a href="https://matrix.to/#/%s">%s</a>`, mxid, displayname)
 		},
-	}, map[*regexp.Regexp]func(string)string {
+	}
+	formatter.waReplFuncText = map[*regexp.Regexp]func(string) string{
 		mentionRegex: func(str string) string {
-			puppet := user.GetPuppetByJID(str[1:] + whatsappExt.NewUserSuffix)
-			return puppet.Displayname
+			_, displayname := formatter.getMatrixInfoByJID(str[1:] + whatsappExt.NewUserSuffix)
+			return displayname
 		},
 	}
+	return formatter
+}
+
+func (formatter *Formatter) getMatrixInfoByJID(jid string) (mxid, displayname string) {
+	if user := formatter.bridge.GetUserByJID(jid); user != nil {
+		mxid = user.MXID
+		displayname = user.MXID
+	} else if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil {
+		mxid = puppet.MXID
+		displayname = puppet.Displayname
+	}
+	return
+}
+
+func (formatter *Formatter) ParseWhatsApp(content *gomatrix.Content) {
+	output := html.EscapeString(content.Body)
+	for regex, replacement := range formatter.waReplString {
+		output = regex.ReplaceAllString(output, replacement)
+	}
+	for regex, replacer := range formatter.waReplFunc {
+		output = regex.ReplaceAllStringFunc(output, replacer)
+	}
+	if output != content.Body {
+		content.FormattedBody = output
+		content.Format = gomatrix.FormatHTML
+		for regex, replacer := range formatter.waReplFuncText {
+			content.Body = regex.ReplaceAllStringFunc(content.Body, replacer)
+		}
+	}
+}
+
+func (formatter *Formatter) ParseMatrix(html string) string {
+	return formatter.matrixHTMLParser.Parse(html)
 }

+ 54 - 25
main.go

@@ -20,6 +20,7 @@ import (
 	"fmt"
 	"os"
 	"os/signal"
+	"sync"
 	"syscall"
 
 	flag "maunium.net/go/mauflag"
@@ -31,6 +32,7 @@ import (
 )
 
 var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String()
+var baseConfigPath = flag.MakeFull("b", "base-config", "The path to the example config file.", "example-config.yaml").String()
 var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
 var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
 var wantHelp, _ = flag.MakeHelpFlag()
@@ -58,29 +60,47 @@ func (bridge *Bridge) GenerateRegistration() {
 }
 
 type Bridge struct {
-	AppService     *appservice.AppService
+	AS             *appservice.AppService
 	EventProcessor *appservice.EventProcessor
 	MatrixHandler  *MatrixHandler
 	Config         *config.Config
 	DB             *database.Database
 	Log            log.Logger
-
-	StateStore *AutosavingStateStore
-
-	users           map[types.MatrixUserID]*User
-	managementRooms map[types.MatrixRoomID]*User
+	StateStore     *AutosavingStateStore
+	Bot            *appservice.IntentAPI
+	Formatter      *Formatter
+
+	usersByMXID         map[types.MatrixUserID]*User
+	usersByJID          map[types.WhatsAppID]*User
+	usersLock           sync.Mutex
+	managementRooms     map[types.MatrixRoomID]*User
+	managementRoomsLock sync.Mutex
+	portalsByMXID       map[types.MatrixRoomID]*Portal
+	portalsByJID        map[database.PortalKey]*Portal
+	portalsLock         sync.Mutex
+	puppets             map[types.WhatsAppID]*Puppet
+	puppetsLock         sync.Mutex
 }
 
 func NewBridge() *Bridge {
 	bridge := &Bridge{
-		users:           make(map[types.MatrixUserID]*User),
+		usersByMXID:     make(map[types.MatrixUserID]*User),
+		usersByJID:      make(map[types.WhatsAppID]*User),
 		managementRooms: make(map[types.MatrixRoomID]*User),
+		portalsByMXID:   make(map[types.MatrixRoomID]*Portal),
+		portalsByJID:    make(map[database.PortalKey]*Portal),
+		puppets:         make(map[types.WhatsAppID]*Puppet),
 	}
-	var err error
+	err := config.Update(*configPath, *baseConfigPath)
+	if err != nil {
+		fmt.Fprintln(os.Stderr, "Failed to update config:", err)
+		os.Exit(10)
+	}
+
 	bridge.Config, err = config.Load(*configPath)
 	if err != nil {
 		fmt.Fprintln(os.Stderr, "Failed to load config:", err)
-		os.Exit(10)
+		os.Exit(11)
 	}
 	return bridge
 }
@@ -88,46 +108,55 @@ func NewBridge() *Bridge {
 func (bridge *Bridge) Init() {
 	var err error
 
-	bridge.AppService, err = bridge.Config.MakeAppService()
+	bridge.AS, err = bridge.Config.MakeAppService()
 	if err != nil {
 		fmt.Fprintln(os.Stderr, "Failed to initialize AppService:", err)
-		os.Exit(11)
+		os.Exit(12)
 	}
-	bridge.AppService.Init()
-	bridge.Log = bridge.AppService.Log
+	bridge.AS.Init()
+	bridge.Bot = bridge.AS.BotIntent()
+
+	bridge.Log = log.Create()
+	bridge.Config.Logging.Configure(bridge.Log)
 	log.DefaultLogger = bridge.Log.(*log.BasicLogger)
-	bridge.AppService.Log = log.Sub("Matrix")
+	err = log.OpenFile()
+	if err != nil {
+		fmt.Fprintln(os.Stderr, "Failed to open log file:", err)
+		os.Exit(13)
+	}
+	bridge.AS.Log = log.Sub("Matrix")
 
 	bridge.Log.Debugln("Initializing state store")
 	bridge.StateStore = NewAutosavingStateStore(bridge.Config.AppService.StateStore)
 	err = bridge.StateStore.Load()
 	if err != nil {
 		bridge.Log.Fatalln("Failed to load state store:", err)
-		os.Exit(12)
+		os.Exit(14)
 	}
-	bridge.AppService.StateStore = bridge.StateStore
+	bridge.AS.StateStore = bridge.StateStore
 
 	bridge.Log.Debugln("Initializing database")
 	bridge.DB, err = database.New(bridge.Config.AppService.Database.URI)
 	if err != nil {
 		bridge.Log.Fatalln("Failed to initialize database:", err)
-		os.Exit(13)
+		os.Exit(15)
 	}
 
 	bridge.Log.Debugln("Initializing Matrix event processor")
-	bridge.EventProcessor = appservice.NewEventProcessor(bridge.AppService)
+	bridge.EventProcessor = appservice.NewEventProcessor(bridge.AS)
 	bridge.Log.Debugln("Initializing Matrix event handler")
 	bridge.MatrixHandler = NewMatrixHandler(bridge)
+	bridge.Formatter = NewFormatter(bridge)
 }
 
 func (bridge *Bridge) Start() {
 	err := bridge.DB.CreateTables()
 	if err != nil {
 		bridge.Log.Fatalln("Failed to create database tables:", err)
-		os.Exit(14)
+		os.Exit(16)
 	}
 	bridge.Log.Debugln("Starting application service HTTP server")
-	go bridge.AppService.Start()
+	go bridge.AS.Start()
 	bridge.Log.Debugln("Starting event processor")
 	go bridge.EventProcessor.Start()
 	go bridge.UpdateBotProfile()
@@ -140,18 +169,18 @@ func (bridge *Bridge) UpdateBotProfile() {
 
 	var err error
 	if botConfig.Avatar == "remove" {
-		err = bridge.AppService.BotIntent().SetAvatarURL("")
+		err = bridge.AS.BotIntent().SetAvatarURL("")
 	} else if len(botConfig.Avatar) > 0 {
-		err = bridge.AppService.BotIntent().SetAvatarURL(botConfig.Avatar)
+		err = bridge.AS.BotIntent().SetAvatarURL(botConfig.Avatar)
 	}
 	if err != nil {
 		bridge.Log.Warnln("Failed to update bot avatar:", err)
 	}
 
 	if botConfig.Displayname == "remove" {
-		err = bridge.AppService.BotIntent().SetDisplayName("")
+		err = bridge.AS.BotIntent().SetDisplayName("")
 	} else if len(botConfig.Avatar) > 0 {
-		err = bridge.AppService.BotIntent().SetDisplayName(botConfig.Displayname)
+		err = bridge.AS.BotIntent().SetDisplayName(botConfig.Displayname)
 	}
 	if err != nil {
 		bridge.Log.Warnln("Failed to update bot displayname:", err)
@@ -165,7 +194,7 @@ func (bridge *Bridge) StartUsers() {
 }
 
 func (bridge *Bridge) Stop() {
-	bridge.AppService.Stop()
+	bridge.AS.Stop()
 	bridge.EventProcessor.Stop()
 	err := bridge.StateStore.Save()
 	if err != nil {

+ 15 - 11
matrix.go

@@ -35,7 +35,7 @@ type MatrixHandler struct {
 func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
 	handler := &MatrixHandler{
 		bridge: bridge,
-		as:     bridge.AppService,
+		as:     bridge.AS,
 		log:    bridge.Log.Sub("Matrix"),
 		cmd:    NewCommandHandler(bridge),
 	}
@@ -50,7 +50,7 @@ func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
 func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) {
 	intent := mx.as.BotIntent()
 
-	user := mx.bridge.GetUser(evt.Sender)
+	user := mx.bridge.GetUserByMXID(evt.Sender)
 	if user == nil {
 		return
 	}
@@ -85,7 +85,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) {
 	for mxid, _ := range members.Joined {
 		if mxid == intent.UserID || mxid == evt.Sender {
 			continue
-		} else if _, _, ok := mx.bridge.ParsePuppetMXID(types.MatrixUserID(mxid)); ok {
+		} else if _, ok := mx.bridge.ParsePuppetMXID(types.MatrixUserID(mxid)); ok {
 			hasPuppets = true
 			continue
 		}
@@ -96,7 +96,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) {
 	}
 
 	if !hasPuppets {
-		user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender))
+		user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
 		user.SetManagementRoom(types.MatrixRoomID(resp.RoomID))
 		intent.SendNotice(string(user.ManagementRoom), "This room has been registered as your bridge management/status room.")
 		mx.log.Debugln(resp.RoomID, "registered as a management room with", evt.Sender)
@@ -110,12 +110,12 @@ func (mx *MatrixHandler) HandleMembership(evt *gomatrix.Event) {
 }
 
 func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) {
-	user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender))
-	if user == nil || !user.Whitelisted {
+	user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
+	if user == nil || !user.Whitelisted || !user.IsLoggedIn() {
 		return
 	}
 
-	portal := user.GetPortalByMXID(evt.RoomID)
+	portal := mx.bridge.GetPortalByMXID(evt.RoomID)
 	if portal == nil || portal.IsPrivateChat() {
 		return
 	}
@@ -124,7 +124,7 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) {
 	var err error
 	switch evt.Type {
 	case gomatrix.StateRoomName:
-		resp, err = user.Conn.UpdateGroupSubject(evt.Content.Name, portal.JID)
+		resp, err = user.Conn.UpdateGroupSubject(evt.Content.Name, portal.Key.JID)
 	case gomatrix.StateRoomAvatar:
 		return
 	case gomatrix.StateTopic:
@@ -140,7 +140,7 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) {
 
 func (mx *MatrixHandler) HandleMessage(evt *gomatrix.Event) {
 	roomID := types.MatrixRoomID(evt.RoomID)
-	user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender))
+	user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
 
 	if !user.Whitelisted {
 		return
@@ -158,8 +158,12 @@ func (mx *MatrixHandler) HandleMessage(evt *gomatrix.Event) {
 		}
 	}
 
-	portal := user.GetPortalByMXID(roomID)
+	if !user.IsLoggedIn() {
+		return
+	}
+
+	portal := mx.bridge.GetPortalByMXID(roomID)
 	if portal != nil {
-		portal.HandleMatrixMessage(evt)
+		portal.HandleMatrixMessage(user, evt)
 	}
 }

+ 96 - 109
portal.go

@@ -20,7 +20,6 @@ import (
 	"bytes"
 	"encoding/hex"
 	"fmt"
-	"html"
 	"image"
 	"image/gif"
 	"image/jpeg"
@@ -41,57 +40,56 @@ import (
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 )
 
-func (user *User) GetPortalByMXID(mxid types.MatrixRoomID) *Portal {
-	user.portalsLock.Lock()
-	defer user.portalsLock.Unlock()
-	portal, ok := user.portalsByMXID[mxid]
+func (bridge *Bridge) GetPortalByMXID(mxid types.MatrixRoomID) *Portal {
+	bridge.portalsLock.Lock()
+	defer bridge.portalsLock.Unlock()
+	portal, ok := bridge.portalsByMXID[mxid]
 	if !ok {
-		dbPortal := user.bridge.DB.Portal.GetByMXID(mxid)
-		if dbPortal == nil || dbPortal.Owner != user.ID {
+		dbPortal := bridge.DB.Portal.GetByMXID(mxid)
+		if dbPortal == nil {
 			return nil
 		}
-		portal = user.NewPortal(dbPortal)
-		user.portalsByJID[portal.JID] = portal
+		portal = bridge.NewPortal(dbPortal)
+		bridge.portalsByJID[portal.Key] = portal
 		if len(portal.MXID) > 0 {
-			user.portalsByMXID[portal.MXID] = portal
+			bridge.portalsByMXID[portal.MXID] = portal
 		}
 	}
 	return portal
 }
 
-func (user *User) GetPortalByJID(jid types.WhatsAppID) *Portal {
-	user.portalsLock.Lock()
-	defer user.portalsLock.Unlock()
-	portal, ok := user.portalsByJID[jid]
+func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal {
+	bridge.portalsLock.Lock()
+	defer bridge.portalsLock.Unlock()
+	portal, ok := bridge.portalsByJID[key]
 	if !ok {
-		dbPortal := user.bridge.DB.Portal.GetByJID(user.ID, jid)
+		dbPortal := bridge.DB.Portal.GetByJID(key)
 		if dbPortal == nil {
-			dbPortal = user.bridge.DB.Portal.New()
-			dbPortal.JID = jid
-			dbPortal.Owner = user.ID
+			dbPortal = bridge.DB.Portal.New()
+			dbPortal.Key = key
 			dbPortal.Insert()
 		}
-		portal = user.NewPortal(dbPortal)
-		user.portalsByJID[portal.JID] = portal
+		portal = bridge.NewPortal(dbPortal)
+		bridge.portalsByJID[portal.Key] = portal
 		if len(portal.MXID) > 0 {
-			user.portalsByMXID[portal.MXID] = portal
+			bridge.portalsByMXID[portal.MXID] = portal
 		}
 	}
 	return portal
 }
 
-func (user *User) GetAllPortals() []*Portal {
-	user.portalsLock.Lock()
-	defer user.portalsLock.Unlock()
-	dbPortals := user.bridge.DB.Portal.GetAll(user.ID)
+func (bridge *Bridge) GetAllPortals() []*Portal {
+	bridge.portalsLock.Lock()
+	defer bridge.portalsLock.Unlock()
+	dbPortals := bridge.DB.Portal.GetAll()
 	output := make([]*Portal, len(dbPortals))
 	for index, dbPortal := range dbPortals {
-		portal, ok := user.portalsByJID[dbPortal.JID]
+		portal, ok := bridge.portalsByJID[dbPortal.Key]
 		if !ok {
-			portal = user.NewPortal(dbPortal)
-			user.portalsByJID[dbPortal.JID] = portal
+			portal = bridge.NewPortal(dbPortal)
+			bridge.portalsByJID[portal.Key] = portal
 			if len(dbPortal.MXID) > 0 {
-				user.portalsByMXID[dbPortal.MXID] = portal
+				bridge.portalsByMXID[dbPortal.MXID] = portal
 			}
 		}
 		output[index] = portal
@@ -99,19 +97,17 @@ func (user *User) GetAllPortals() []*Portal {
 	return output
 }
 
-func (user *User) NewPortal(dbPortal *database.Portal) *Portal {
+func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
 	return &Portal{
 		Portal: dbPortal,
-		user:   user,
-		bridge: user.bridge,
-		log:    user.log.Sub(fmt.Sprintf("Portal/%s", dbPortal.JID)),
+		bridge: bridge,
+		log:    bridge.Log.Sub(fmt.Sprintf("Portal/%s", dbPortal.Key)),
 	}
 }
 
 type Portal struct {
 	*database.Portal
 
-	user   *User
 	bridge *Bridge
 	log    log.Logger
 
@@ -126,9 +122,16 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) {
 		changed = true
 	}
 	for _, participant := range metadata.Participants {
-		puppet := portal.user.GetPuppetByJID(participant.JID)
+		puppet := portal.bridge.GetPuppetByJID(participant.JID)
 		puppet.Intent().EnsureJoined(portal.MXID)
 
+		user := portal.bridge.GetUserByJID(participant.JID)
+		if !portal.bridge.AS.StateStore.IsInvited(portal.MXID, user.MXID) {
+			portal.MainIntent().InviteUser(portal.MXID, &gomatrix.ReqInviteUser{
+				UserID: user.MXID,
+			})
+		}
+
 		expectedLevel := 0
 		if participant.IsSuperAdmin {
 			expectedLevel = 95
@@ -136,9 +139,8 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) {
 			expectedLevel = 50
 		}
 		changed = levels.EnsureUserLevel(puppet.MXID, expectedLevel) || changed
-
-		if participant.JID == portal.user.JID() {
-			changed = levels.EnsureUserLevel(portal.user.ID, expectedLevel) || changed
+		if user != nil {
+			changed = levels.EnsureUserLevel(user.MXID, expectedLevel) || changed
 		}
 	}
 	if changed {
@@ -146,10 +148,10 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) {
 	}
 }
 
-func (portal *Portal) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool {
+func (portal *Portal) UpdateAvatar(user *User, avatar *whatsappExt.ProfilePicInfo) bool {
 	if avatar == nil {
 		var err error
-		avatar, err = portal.user.Conn.GetProfilePicThumb(portal.JID)
+		avatar, err = user.Conn.GetProfilePicThumb(portal.Key.JID)
 		if err != nil {
 			portal.log.Errorln(err)
 			return false
@@ -184,7 +186,7 @@ func (portal *Portal) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool {
 
 func (portal *Portal) UpdateName(name string, setBy types.WhatsAppID) bool {
 	if portal.Name != name {
-		intent := portal.user.GetPuppetByJID(setBy).Intent()
+		intent := portal.bridge.GetPuppetByJID(setBy).Intent()
 		_, err := intent.SetRoomName(portal.MXID, name)
 		if err == nil {
 			portal.Name = name
@@ -197,7 +199,7 @@ func (portal *Portal) UpdateName(name string, setBy types.WhatsAppID) bool {
 
 func (portal *Portal) UpdateTopic(topic string, setBy types.WhatsAppID) bool {
 	if portal.Topic != topic {
-		intent := portal.user.GetPuppetByJID(setBy).Intent()
+		intent := portal.bridge.GetPuppetByJID(setBy).Intent()
 		_, err := intent.SetRoomTopic(portal.MXID, topic)
 		if err == nil {
 			portal.Topic = topic
@@ -208,8 +210,8 @@ func (portal *Portal) UpdateTopic(topic string, setBy types.WhatsAppID) bool {
 	return false
 }
 
-func (portal *Portal) UpdateMetadata() bool {
-	metadata, err := portal.user.Conn.GetGroupMetaData(portal.JID)
+func (portal *Portal) UpdateMetadata(user *User) bool {
+	metadata, err := user.Conn.GetGroupMetaData(portal.Key.JID)
 	if err != nil {
 		portal.log.Errorln(err)
 		return false
@@ -221,25 +223,23 @@ func (portal *Portal) UpdateMetadata() bool {
 	return update
 }
 
-func (portal *Portal) Sync(contact whatsapp.Contact) {
+func (portal *Portal) Sync(user *User, contact whatsapp.Contact) {
+	if portal.IsPrivateChat() {
+		return
+	}
+
 	if len(portal.MXID) == 0 {
-		if !portal.IsPrivateChat() {
-			portal.Name = contact.Name
-		}
-		err := portal.CreateMatrixRoom()
+		portal.Name = contact.Name
+		err := portal.CreateMatrixRoom([]string{user.MXID})
 		if err != nil {
 			portal.log.Errorln("Failed to create portal room:", err)
 			return
 		}
 	}
 
-	if portal.IsPrivateChat() {
-		return
-	}
-
 	update := false
-	update = portal.UpdateMetadata() || update
-	update = portal.UpdateAvatar(nil) || update
+	update = portal.UpdateMetadata(user) || update
+	update = portal.UpdateAvatar(user, nil) || update
 	if update {
 		portal.Update()
 	}
@@ -277,11 +277,12 @@ func (portal *Portal) ChangeAdminStatus(jids []string, setAdmin bool) {
 	}
 	changed := false
 	for _, jid := range jids {
-		puppet := portal.user.GetPuppetByJID(jid)
+		puppet := portal.bridge.GetPuppetByJID(jid)
 		changed = levels.EnsureUserLevel(puppet.MXID, newLevel) || changed
 
-		if jid == portal.user.JID() {
-			changed = levels.EnsureUserLevel(portal.user.ID, newLevel) || changed
+		user := portal.bridge.GetUserByJID(jid)
+		if user != nil {
+			changed = levels.EnsureUserLevel(user.MXID, newLevel) || changed
 		}
 	}
 	if changed {
@@ -312,15 +313,15 @@ func (portal *Portal) RestrictMetadataChanges(restrict bool) {
 		newLevel = 50
 	}
 	changed := false
-	changed = levels.EnsureEventLevel(gomatrix.StateRoomName, true, newLevel) || changed
-	changed = levels.EnsureEventLevel(gomatrix.StateRoomAvatar, true, newLevel) || changed
-	changed = levels.EnsureEventLevel(gomatrix.StateTopic, true, newLevel) || changed
+	changed = levels.EnsureEventLevel(gomatrix.StateRoomName, newLevel) || changed
+	changed = levels.EnsureEventLevel(gomatrix.StateRoomAvatar, newLevel) || changed
+	changed = levels.EnsureEventLevel(gomatrix.StateTopic, newLevel) || changed
 	if changed {
 		portal.MainIntent().SetPowerLevels(portal.MXID, levels)
 	}
 }
 
-func (portal *Portal) CreateMatrixRoom() error {
+func (portal *Portal) CreateMatrixRoom(invite []string) error {
 	portal.roomCreateLock.Lock()
 	defer portal.roomCreateLock.Unlock()
 	if len(portal.MXID) > 0 {
@@ -330,7 +331,6 @@ func (portal *Portal) CreateMatrixRoom() error {
 	name := portal.Name
 	topic := portal.Topic
 	isPrivateChat := false
-	invite := []string{portal.user.ID}
 	if portal.IsPrivateChat() {
 		name = ""
 		topic = "WhatsApp private chat"
@@ -360,18 +360,18 @@ func (portal *Portal) CreateMatrixRoom() error {
 }
 
 func (portal *Portal) IsPrivateChat() bool {
-	return strings.HasSuffix(portal.JID, whatsappExt.NewUserSuffix)
+	return strings.HasSuffix(portal.Key.JID, whatsappExt.NewUserSuffix)
 }
 
 func (portal *Portal) MainIntent() *appservice.IntentAPI {
 	if portal.IsPrivateChat() {
-		return portal.user.GetPuppetByJID(portal.JID).Intent()
+		return portal.bridge.GetPuppetByJID(portal.Key.JID).Intent()
 	}
-	return portal.bridge.AppService.BotIntent()
+	return portal.bridge.AS.BotIntent()
 }
 
 func (portal *Portal) IsDuplicate(id types.WhatsAppMessageID) bool {
-	msg := portal.bridge.DB.Message.GetByJID(portal.Owner, id)
+	msg := portal.bridge.DB.Message.GetByJID(portal.Key, id)
 	if msg != nil {
 		return true
 	}
@@ -380,7 +380,7 @@ func (portal *Portal) IsDuplicate(id types.WhatsAppMessageID) bool {
 
 func (portal *Portal) MarkHandled(jid types.WhatsAppMessageID, mxid types.MatrixEventID) {
 	msg := portal.bridge.DB.Message.New()
-	msg.Owner = portal.Owner
+	msg.Chat = portal.Key
 	msg.JID = jid
 	msg.MXID = mxid
 	msg.Insert()
@@ -392,7 +392,7 @@ func (portal *Portal) GetMessageIntent(info whatsapp.MessageInfo) *appservice.In
 			// TODO handle own messages in private chats properly
 			return nil
 		}
-		return portal.user.GetPuppetByJID(portal.user.JID()).Intent()
+		return portal.bridge.GetPuppetByJID(portal.Key.Receiver).Intent()
 	} else if portal.IsPrivateChat() {
 		return portal.MainIntent()
 	} else if len(info.SenderJid) == 0 {
@@ -402,14 +402,14 @@ func (portal *Portal) GetMessageIntent(info whatsapp.MessageInfo) *appservice.In
 			return nil
 		}
 	}
-	return portal.user.GetPuppetByJID(info.SenderJid).Intent()
+	return portal.bridge.GetPuppetByJID(info.SenderJid).Intent()
 }
 
 func (portal *Portal) SetReply(content *gomatrix.Content, info whatsapp.MessageInfo) {
 	if len(info.QuotedMessageID) == 0 {
 		return
 	}
-	message := portal.bridge.DB.Message.GetByJID(portal.Owner, info.QuotedMessageID)
+	message := portal.bridge.DB.Message.GetByJID(portal.Key, info.QuotedMessageID)
 	if message != nil {
 		event, err := portal.MainIntent().GetEvent(portal.MXID, message.MXID)
 		if err != nil {
@@ -421,29 +421,12 @@ func (portal *Portal) SetReply(content *gomatrix.Content, info whatsapp.MessageI
 	return
 }
 
-func (portal *Portal) FormatWhatsAppMessage(content *gomatrix.Content) {
-	output := html.EscapeString(content.Body)
-	for regex, replacement := range portal.user.waReplString {
-		output = regex.ReplaceAllString(output, replacement)
-	}
-	for regex, replacer := range portal.user.waReplFunc {
-		output = regex.ReplaceAllStringFunc(output, replacer)
-	}
-	if output != content.Body {
-		content.FormattedBody = output
-		content.Format = gomatrix.FormatHTML
-		for regex, replacer := range portal.user.waReplFuncText {
-			content.Body = regex.ReplaceAllStringFunc(content.Body, replacer)
-		}
-	}
-}
-
-func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) {
+func (portal *Portal) HandleTextMessage(source *User, message whatsapp.TextMessage) {
 	if portal.IsDuplicate(message.Info.Id) {
 		return
 	}
 
-	err := portal.CreateMatrixRoom()
+	err := portal.CreateMatrixRoom([]string{source.MXID})
 	if err != nil {
 		portal.log.Errorln("Failed to create portal room:", err)
 		return
@@ -459,7 +442,7 @@ func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) {
 		MsgType: gomatrix.MsgText,
 	}
 
-	portal.FormatWhatsAppMessage(content)
+	portal.bridge.Formatter.ParseWhatsApp(content)
 	portal.SetReply(content, message.Info)
 
 	intent.UserTyping(portal.MXID, false, 0)
@@ -472,12 +455,12 @@ func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) {
 	portal.log.Debugln("Handled message", message.Info.Id, "->", resp.EventID)
 }
 
-func (portal *Portal) HandleMediaMessage(download func() ([]byte, error), thumbnail []byte, info whatsapp.MessageInfo, mimeType, caption string) {
+func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte, error), thumbnail []byte, info whatsapp.MessageInfo, mimeType, caption string) {
 	if portal.IsDuplicate(info.Id) {
 		return
 	}
 
-	err := portal.CreateMatrixRoom()
+	err := portal.CreateMatrixRoom([]string{source.MXID})
 	if err != nil {
 		portal.log.Errorln("Failed to create portal room:", err)
 		return
@@ -559,7 +542,7 @@ func (portal *Portal) HandleMediaMessage(download func() ([]byte, error), thumbn
 			MsgType: gomatrix.MsgNotice,
 		}
 
-		portal.FormatWhatsAppMessage(captionContent)
+		portal.bridge.Formatter.ParseWhatsApp(captionContent)
 
 		_, err := intent.SendMassagedMessageEvent(portal.MXID, gomatrix.EventMessage, captionContent, ts)
 		if err != nil {
@@ -612,7 +595,7 @@ func (portal *Portal) downloadThumbnail(evt *gomatrix.Event) []byte {
 	return buf.Bytes()
 }
 
-func (portal *Portal) preprocessMatrixMedia(evt *gomatrix.Event, mediaType whatsapp.MediaType) *MediaUpload {
+func (portal *Portal) preprocessMatrixMedia(sender *User, evt *gomatrix.Event, mediaType whatsapp.MediaType) *MediaUpload {
 	if evt.Content.Info == nil {
 		evt.Content.Info = &gomatrix.FileInfo{}
 	}
@@ -630,7 +613,7 @@ func (portal *Portal) preprocessMatrixMedia(evt *gomatrix.Event, mediaType whats
 		return nil
 	}
 
-	url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := portal.user.Conn.Upload(bytes.NewReader(content), mediaType)
+	url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := sender.Conn.Upload(bytes.NewReader(content), mediaType)
 	if err != nil {
 		portal.log.Errorfln("Failed to upload media in %s: %v", evt.ID, err)
 		return nil
@@ -657,8 +640,8 @@ type MediaUpload struct {
 	Thumbnail     []byte
 }
 
-func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessageInfo {
-	node, err := portal.user.Conn.LoadMessagesBefore(portal.JID, jid, 1)
+func (portal *Portal) GetMessage(user *User, jid types.WhatsAppMessageID) *waProto.WebMessageInfo {
+	node, err := user.Conn.LoadMessagesBefore(portal.Key.JID, jid, 1)
 	if err != nil {
 		return nil
 	}
@@ -670,7 +653,7 @@ func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessag
 	if !ok {
 		return nil
 	}
-	node, err = portal.user.Conn.LoadMessagesAfter(portal.JID, msg.GetKey().GetId(), 1)
+	node, err = user.Conn.LoadMessagesAfter(portal.Key.JID, msg.GetKey().GetId(), 1)
 	if err != nil {
 		return nil
 	}
@@ -682,7 +665,11 @@ func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessag
 	return msg
 }
 
-func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
+func (portal *Portal) HandleMatrixMessage(sender *User, evt *gomatrix.Event) {
+	if portal.IsPrivateChat() && sender.JID != portal.Key.Receiver {
+		return
+	}
+
 	ts := uint64(evt.Timestamp / 1000)
 	status := waProto.WebMessageInfo_ERROR
 	fromMe := true
@@ -690,7 +677,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 		Key: &waProto.MessageKey{
 			FromMe:    &fromMe,
 			Id:        makeMessageID(),
-			RemoteJid: &portal.JID,
+			RemoteJid: &portal.Key.JID,
 		},
 		MessageTimestamp: &ts,
 		Message:          &waProto.Message{},
@@ -702,12 +689,12 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 		evt.Content.RemoveReplyFallback()
 		msg := portal.bridge.DB.Message.GetByMXID(replyToID)
 		if msg != nil {
-			origMsg := portal.GetMessage(msg.JID)
+			origMsg := portal.GetMessage(sender, msg.JID)
 			if origMsg != nil {
 				ctxInfo.StanzaId = &msg.JID
 				replyMsgSender := origMsg.GetParticipant()
 				if origMsg.GetKey().GetFromMe() {
-					replyMsgSender = portal.user.JID()
+					replyMsgSender = sender.JID
 				}
 				ctxInfo.Participant = &replyMsgSender
 				ctxInfo.QuotedMessage = []*waProto.Message{origMsg.Message}
@@ -719,7 +706,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 	case gomatrix.MsgText, gomatrix.MsgEmote:
 		text := evt.Content.Body
 		if evt.Content.Format == gomatrix.FormatHTML {
-			text = portal.user.htmlParser.Parse(evt.Content.FormattedBody)
+			text = portal.bridge.Formatter.ParseMatrix(evt.Content.FormattedBody)
 		}
 		if evt.Content.MsgType == gomatrix.MsgEmote {
 			text = "/me " + text
@@ -737,7 +724,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 			info.Message.Conversation = &text
 		}
 	case gomatrix.MsgImage:
-		media := portal.preprocessMatrixMedia(evt, whatsapp.MediaImage)
+		media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaImage)
 		if media == nil {
 			return
 		}
@@ -752,7 +739,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 			FileLength:    &media.FileLength,
 		}
 	case gomatrix.MsgVideo:
-		media := portal.preprocessMatrixMedia(evt, whatsapp.MediaVideo)
+		media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaVideo)
 		if media == nil {
 			return
 		}
@@ -769,7 +756,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 			FileLength:    &media.FileLength,
 		}
 	case gomatrix.MsgAudio:
-		media := portal.preprocessMatrixMedia(evt, whatsapp.MediaAudio)
+		media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaAudio)
 		if media == nil {
 			return
 		}
@@ -784,7 +771,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 			FileLength:    &media.FileLength,
 		}
 	case gomatrix.MsgFile:
-		media := portal.preprocessMatrixMedia(evt, whatsapp.MediaDocument)
+		media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaDocument)
 		if media == nil {
 			return
 		}
@@ -800,7 +787,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 		portal.log.Debugln("Unhandled Matrix event:", evt)
 		return
 	}
-	err = portal.user.Conn.Send(info)
+	err = sender.Conn.Send(info)
 	portal.MarkHandled(info.GetKey().GetId(), evt.ID)
 	if err != nil {
 		portal.log.Errorfln("Error handling Matrix event %s: %v", evt.ID, err)

+ 35 - 57
puppet.go

@@ -30,105 +30,83 @@ import (
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 )
 
-func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.MatrixUserID, types.WhatsAppID, bool) {
+func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.WhatsAppID, bool) {
 	userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$",
-		bridge.Config.Bridge.FormatUsername("(.+)", "([0-9]+)"),
+		bridge.Config.Bridge.FormatUsername("([0-9]+)"),
 		bridge.Config.Homeserver.Domain))
 	if err != nil {
 		bridge.Log.Warnln("Failed to compile puppet user ID regex:", err)
-		return "", "", false
+		return "", false
 	}
 	match := userIDRegex.FindStringSubmatch(string(mxid))
-	if match == nil || len(match) != 3 {
-		return "", "", false
+	if match == nil || len(match) != 2 {
+		return "", false
 	}
 
-	receiver := types.MatrixUserID(match[1])
-	receiver = strings.Replace(receiver, "=40", "@", 1)
-	colonIndex := strings.LastIndex(receiver, "=3")
-	receiver = receiver[:colonIndex] + ":" + receiver[colonIndex+len("=3"):]
 	jid := types.WhatsAppID(match[2] + whatsappExt.NewUserSuffix)
-	return receiver, jid, true
+	return jid, true
 }
 
 func (bridge *Bridge) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet {
-	receiver, jid, ok := bridge.ParsePuppetMXID(mxid)
+	jid, ok := bridge.ParsePuppetMXID(mxid)
 	if !ok {
 		return nil
 	}
 
-	user := bridge.GetUser(receiver)
-	if user == nil {
-		return nil
-	}
-
-	return user.GetPuppetByJID(jid)
-}
-
-func (user *User) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet {
-	receiver, jid, ok := user.bridge.ParsePuppetMXID(mxid)
-	if !ok || receiver != user.ID {
-		return nil
-	}
-
-	return user.GetPuppetByJID(jid)
+	return bridge.GetPuppetByJID(jid)
 }
 
-func (user *User) GetPuppetByJID(jid types.WhatsAppID) *Puppet {
-	user.puppetsLock.Lock()
-	defer user.puppetsLock.Unlock()
-	puppet, ok := user.puppets[jid]
+func (bridge *Bridge) GetPuppetByJID(jid types.WhatsAppID) *Puppet {
+	bridge.puppetsLock.Lock()
+	defer bridge.puppetsLock.Unlock()
+	puppet, ok := bridge.puppets[jid]
 	if !ok {
-		dbPuppet := user.bridge.DB.Puppet.Get(jid, user.ID)
+		dbPuppet := bridge.DB.Puppet.Get(jid)
 		if dbPuppet == nil {
-			dbPuppet = user.bridge.DB.Puppet.New()
+			dbPuppet = bridge.DB.Puppet.New()
 			dbPuppet.JID = jid
-			dbPuppet.Receiver = user.ID
 			dbPuppet.Insert()
 		}
-		puppet = user.NewPuppet(dbPuppet)
-		user.puppets[puppet.JID] = puppet
+		puppet = bridge.NewPuppet(dbPuppet)
+		bridge.puppets[puppet.JID] = puppet
 	}
 	return puppet
 }
 
-func (user *User) GetAllPuppets() []*Puppet {
-	user.puppetsLock.Lock()
-	defer user.puppetsLock.Unlock()
-	dbPuppets := user.bridge.DB.Puppet.GetAll(user.ID)
+func (bridge *Bridge) GetAllPuppets() []*Puppet {
+	bridge.puppetsLock.Lock()
+	defer bridge.puppetsLock.Unlock()
+	dbPuppets := bridge.DB.Puppet.GetAll()
 	output := make([]*Puppet, len(dbPuppets))
 	for index, dbPuppet := range dbPuppets {
-		puppet, ok := user.puppets[dbPuppet.JID]
+		puppet, ok := bridge.puppets[dbPuppet.JID]
 		if !ok {
-			puppet = user.NewPuppet(dbPuppet)
-			user.puppets[dbPuppet.JID] = puppet
+			puppet = bridge.NewPuppet(dbPuppet)
+			bridge.puppets[dbPuppet.JID] = puppet
 		}
 		output[index] = puppet
 	}
 	return output
 }
 
-func (user *User) NewPuppet(dbPuppet *database.Puppet) *Puppet {
+func (bridge *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
 	return &Puppet{
 		Puppet: dbPuppet,
-		user:   user,
-		bridge: user.bridge,
-		log:    user.log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
+		bridge: bridge,
+		log:    bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
 
 		MXID: fmt.Sprintf("@%s:%s",
-			user.bridge.Config.Bridge.FormatUsername(
-				dbPuppet.Receiver,
+			bridge.Config.Bridge.FormatUsername(
 				strings.Replace(
 					dbPuppet.JID,
 					whatsappExt.NewUserSuffix, "", 1)),
-			user.bridge.Config.Homeserver.Domain),
+			bridge.Config.Homeserver.Domain),
 	}
 }
 
 type Puppet struct {
 	*database.Puppet
 
-	user   *User
 	bridge *Bridge
 	log    log.Logger
 
@@ -143,13 +121,13 @@ func (puppet *Puppet) PhoneNumber() string {
 }
 
 func (puppet *Puppet) Intent() *appservice.IntentAPI {
-	return puppet.bridge.AppService.Intent(puppet.MXID)
+	return puppet.bridge.AS.Intent(puppet.MXID)
 }
 
-func (puppet *Puppet) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool {
+func (puppet *Puppet) UpdateAvatar(source *User, avatar *whatsappExt.ProfilePicInfo) bool {
 	if avatar == nil {
 		var err error
-		avatar, err = puppet.user.Conn.GetProfilePicThumb(puppet.JID)
+		avatar, err = source.Conn.GetProfilePicThumb(puppet.JID)
 		if err != nil {
 			puppet.log.Errorln(err)
 			return false
@@ -184,11 +162,11 @@ func (puppet *Puppet) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool {
 	return true
 }
 
-func (puppet *Puppet) Sync(contact whatsapp.Contact) {
+func (puppet *Puppet) Sync(source *User, contact whatsapp.Contact) {
 	puppet.Intent().EnsureRegistered()
 
-	if contact.Jid == puppet.user.JID() {
-		contact.Notify = puppet.user.Conn.Info.Pushname
+	if contact.Jid == source.JID {
+		contact.Notify = source.Conn.Info.Pushname
 	}
 	newName := puppet.bridge.Config.Bridge.FormatDisplayname(contact)
 	if puppet.Displayname != newName {
@@ -201,7 +179,7 @@ func (puppet *Puppet) Sync(contact whatsapp.Contact) {
 		}
 	}
 
-	if puppet.UpdateAvatar(nil) {
+	if puppet.UpdateAvatar(source, nil) {
 		puppet.Update()
 	}
 }

+ 70 - 62
user.go

@@ -17,14 +17,11 @@
 package main
 
 import (
-	"regexp"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/Rhymen/go-whatsapp"
 	"github.com/skip2/go-qrcode"
-	"maunium.net/go/gomatrix/format"
 	log "maunium.net/go/maulogger"
 	"maunium.net/go/mautrix-whatsapp/database"
 	"maunium.net/go/mautrix-whatsapp/types"
@@ -41,31 +38,42 @@ type User struct {
 	Admin       bool
 	Whitelisted bool
 	jid         string
+}
 
-	portalsByMXID map[types.MatrixRoomID]*Portal
-	portalsByJID  map[types.WhatsAppID]*Portal
-	portalsLock   sync.Mutex
-	puppets       map[types.WhatsAppID]*Puppet
-	puppetsLock   sync.Mutex
-
-	htmlParser *format.HTMLParser
-
-	waReplString   map[*regexp.Regexp]string
-	waReplFunc     map[*regexp.Regexp]func(string) string
-	waReplFuncText map[*regexp.Regexp]func(string) string
+func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User {
+	bridge.usersLock.Lock()
+	defer bridge.usersLock.Unlock()
+	user, ok := bridge.usersByMXID[userID]
+	if !ok {
+		dbUser := bridge.DB.User.GetByMXID(userID)
+		if dbUser == nil {
+			dbUser = bridge.DB.User.New()
+			dbUser.MXID = userID
+			dbUser.Insert()
+		}
+		user = bridge.NewUser(dbUser)
+		bridge.usersByMXID[user.MXID] = user
+		if len(user.ManagementRoom) > 0 {
+			bridge.managementRooms[user.ManagementRoom] = user
+		}
+	}
+	return user
 }
 
-func (bridge *Bridge) GetUser(userID types.MatrixUserID) *User {
-	user, ok := bridge.users[userID]
+
+func (bridge *Bridge) GetUserByJID(userID types.WhatsAppID) *User {
+	bridge.usersLock.Lock()
+	defer bridge.usersLock.Unlock()
+	user, ok := bridge.usersByJID[userID]
 	if !ok {
-		dbUser := bridge.DB.User.Get(userID)
+		dbUser := bridge.DB.User.GetByMXID(userID)
 		if dbUser == nil {
 			dbUser = bridge.DB.User.New()
-			dbUser.ID = userID
+			dbUser.MXID = userID
 			dbUser.Insert()
 		}
 		user = bridge.NewUser(dbUser)
-		bridge.users[user.ID] = user
+		bridge.usersByJID[user.JID] = user
 		if len(user.ManagementRoom) > 0 {
 			bridge.managementRooms[user.ManagementRoom] = user
 		}
@@ -74,13 +82,15 @@ func (bridge *Bridge) GetUser(userID types.MatrixUserID) *User {
 }
 
 func (bridge *Bridge) GetAllUsers() []*User {
+	bridge.usersLock.Lock()
+	defer bridge.usersLock.Unlock()
 	dbUsers := bridge.DB.User.GetAll()
 	output := make([]*User, len(dbUsers))
 	for index, dbUser := range dbUsers {
-		user, ok := bridge.users[dbUser.ID]
+		user, ok := bridge.usersByMXID[dbUser.MXID]
 		if !ok {
 			user = bridge.NewUser(dbUser)
-			bridge.users[user.ID] = user
+			bridge.usersByMXID[user.MXID] = user
 			if len(user.ManagementRoom) > 0 {
 				bridge.managementRooms[user.ManagementRoom] = user
 			}
@@ -94,15 +104,10 @@ func (bridge *Bridge) NewUser(dbUser *database.User) *User {
 	user := &User{
 		User:          dbUser,
 		bridge:        bridge,
-		log:           bridge.Log.Sub("User").Sub(string(dbUser.ID)),
-		portalsByMXID: make(map[types.MatrixRoomID]*Portal),
-		portalsByJID:  make(map[types.WhatsAppID]*Portal),
-		puppets:       make(map[types.WhatsAppID]*Puppet),
+		log:           bridge.Log.Sub("User").Sub(string(dbUser.MXID)),
 	}
-	user.Whitelisted = user.bridge.Config.Bridge.Permissions.IsWhitelisted(user.ID)
-	user.Admin = user.bridge.Config.Bridge.Permissions.IsAdmin(user.ID)
-	user.htmlParser = user.newHTMLParser()
-	user.waReplString, user.waReplFunc, user.waReplFuncText = user.newWhatsAppFormatMaps()
+	user.Whitelisted = user.bridge.Config.Bridge.Permissions.IsWhitelisted(user.MXID)
+	user.Admin = user.bridge.Config.Bridge.Permissions.IsAdmin(user.MXID)
 	return user
 }
 
@@ -152,7 +157,6 @@ func (user *User) RestoreSession() bool {
 		sess, err := user.Conn.RestoreSession(*user.Session)
 		if err != nil {
 			user.log.Errorln("Failed to restore session:", err)
-			//user.SetSession(nil)
 			return false
 		}
 		user.SetSession(&sess)
@@ -162,8 +166,12 @@ func (user *User) RestoreSession() bool {
 	return false
 }
 
+func (user *User) IsLoggedIn() bool {
+	return user.Conn != nil
+}
+
 func (user *User) Login(roomID types.MatrixRoomID) {
-	bot := user.bridge.AppService.BotClient()
+	bot := user.bridge.AS.BotClient()
 
 	qrChan := make(chan string, 2)
 	go func() {
@@ -194,38 +202,24 @@ func (user *User) Login(roomID types.MatrixRoomID) {
 		qrChan <- "error"
 		return
 	}
+	user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
 	user.Session = &session
 	user.Update()
 	bot.SendNotice(roomID, "Successfully logged in. Synchronizing chats...")
 	go user.Sync()
 }
 
-func (user *User) JID() string {
-	if user.Conn == nil {
-		return ""
-	}
-	if len(user.jid) == 0 {
-		user.jid = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
-	}
-	return user.jid
-}
-
 func (user *User) Sync() {
 	user.log.Debugln("Syncing...")
 	user.Conn.Contacts()
 	for jid, contact := range user.Conn.Store.Contacts {
 		if strings.HasSuffix(jid, whatsappExt.NewUserSuffix) {
-			puppet := user.GetPuppetByJID(contact.Jid)
-			puppet.Sync(contact)
-		}
-
-		if len(contact.Notify) == 0 && !strings.HasSuffix(jid, "@g.us") {
-			// No messages sent -> don't bridge
-			continue
+			puppet := user.bridge.GetPuppetByJID(contact.Jid)
+			puppet.Sync(user, contact)
+		} else {
+			portal := user.bridge.GetPortalByJID(database.GroupPortalKey(contact.Jid))
+			portal.Sync(user, contact)
 		}
-
-		portal := user.GetPortalByJID(contact.Jid)
-		portal.Sync(contact)
 	}
 }
 
@@ -237,33 +231,41 @@ func (user *User) HandleJSONParseError(err error) {
 	user.log.Errorln("WhatsApp JSON parse error:", err)
 }
 
+func (user *User) PortalKey(jid types.WhatsAppID) database.PortalKey {
+	return database.NewPortalKey(jid, user.JID)
+}
+
+func (user *User) GetPortalByJID(jid types.WhatsAppID) *Portal {
+	return user.bridge.GetPortalByJID(user.PortalKey(jid))
+}
+
 func (user *User) HandleTextMessage(message whatsapp.TextMessage) {
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
-	portal.HandleTextMessage(message)
+	portal.HandleTextMessage(user, message)
 }
 
 func (user *User) HandleImageMessage(message whatsapp.ImageMessage) {
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
-	portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Caption)
+	portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Caption)
 }
 
 func (user *User) HandleVideoMessage(message whatsapp.VideoMessage) {
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
-	portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Caption)
+	portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Caption)
 }
 
 func (user *User) HandleAudioMessage(message whatsapp.AudioMessage) {
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
-	portal.HandleMediaMessage(message.Download, nil, message.Info, message.Type, "")
+	portal.HandleMediaMessage(user, message.Download, nil, message.Info, message.Type, "")
 }
 
 func (user *User) HandleDocumentMessage(message whatsapp.DocumentMessage) {
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
-	portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Title)
+	portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Title)
 }
 
 func (user *User) HandlePresence(info whatsappExt.Presence) {
-	puppet := user.GetPuppetByJID(info.SenderJID)
+	puppet := user.bridge.GetPuppetByJID(info.SenderJID)
 	switch info.Status {
 	case whatsappExt.PresenceUnavailable:
 		puppet.Intent().SetPresence("offline")
@@ -277,6 +279,12 @@ func (user *User) HandlePresence(info whatsappExt.Presence) {
 		}
 	case whatsappExt.PresenceComposing:
 		portal := user.GetPortalByJID(info.JID)
+		if len(puppet.typingIn) > 0 && puppet.typingAt+15 > time.Now().Unix() {
+			if puppet.typingIn == portal.MXID {
+				return
+			}
+			puppet.Intent().UserTyping(puppet.typingIn, false, 0)
+		}
 		puppet.typingIn = portal.MXID
 		puppet.typingAt = time.Now().Unix()
 		puppet.Intent().UserTyping(portal.MXID, true, 15*1000)
@@ -290,9 +298,9 @@ func (user *User) HandleMsgInfo(info whatsappExt.MsgInfo) {
 			return
 		}
 
-		intent := user.GetPuppetByJID(info.SenderJID).Intent()
+		intent := user.bridge.GetPuppetByJID(info.SenderJID).Intent()
 		for _, id := range info.IDs {
-			msg := user.bridge.DB.Message.GetByJID(user.ID, id)
+			msg := user.bridge.DB.Message.GetByJID(portal.Key, id)
 			if msg == nil {
 				continue
 			}
@@ -308,11 +316,11 @@ func (user *User) HandleCommand(cmd whatsappExt.Command) {
 	switch cmd.Type {
 	case whatsappExt.CommandPicture:
 		if strings.HasSuffix(cmd.JID, whatsappExt.NewUserSuffix) {
-			puppet := user.GetPuppetByJID(cmd.JID)
-			puppet.UpdateAvatar(cmd.ProfilePicInfo)
+			puppet := user.bridge.GetPuppetByJID(cmd.JID)
+			puppet.UpdateAvatar(user, cmd.ProfilePicInfo)
 		} else {
 			portal := user.GetPortalByJID(cmd.JID)
-			portal.UpdateAvatar(cmd.ProfilePicInfo)
+			portal.UpdateAvatar(user, cmd.ProfilePicInfo)
 		}
 	}
 }

+ 9 - 9
vendor/maunium.net/go/gomatrix/client.go

@@ -463,7 +463,7 @@ func (cli *Client) SetAvatarURL(url string) (err error) {
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 func (cli *Client) SendMessageEvent(roomID string, eventType EventType, contentJSON interface{}) (resp *RespSendEvent, err error) {
 	txnID := txnID()
-	urlPath := cli.BuildURL("rooms", roomID, "send", string(eventType), txnID)
+	urlPath := cli.BuildURL("rooms", roomID, "send", eventType.String(), txnID)
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
 	return
 }
@@ -472,7 +472,7 @@ func (cli *Client) SendMessageEvent(roomID string, eventType EventType, contentJ
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 func (cli *Client) SendMassagedMessageEvent(roomID string, eventType EventType, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
 	txnID := txnID()
-	urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "send", string(eventType), txnID}, map[string]string{
+	urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "send", eventType.String(), txnID}, map[string]string{
 		"ts": strconv.FormatInt(ts, 10),
 	})
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
@@ -482,7 +482,7 @@ func (cli *Client) SendMassagedMessageEvent(roomID string, eventType EventType,
 // SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 func (cli *Client) SendStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) {
-	urlPath := cli.BuildURL("rooms", roomID, "state", string(eventType), stateKey)
+	urlPath := cli.BuildURL("rooms", roomID, "state", eventType.String(), stateKey)
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
 	return
 }
@@ -490,7 +490,7 @@ func (cli *Client) SendStateEvent(roomID string, eventType EventType, stateKey s
 // SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 func (cli *Client) SendMassagedStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
-	urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "state", string(eventType), stateKey}, map[string]string{
+	urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{
 		"ts": strconv.FormatInt(ts, 10),
 	})
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
@@ -500,7 +500,7 @@ func (cli *Client) SendMassagedStateEvent(roomID string, eventType EventType, st
 // SendText sends an m.room.message event into the given room with a msgtype of m.text
 // See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-text
 func (cli *Client) SendText(roomID, text string) (*RespSendEvent, error) {
-	return cli.SendMessageEvent(roomID, "m.room.message", Content{
+	return cli.SendMessageEvent(roomID, EventMessage, Content{
 		MsgType: MsgText,
 		Body:    text,
 	})
@@ -509,7 +509,7 @@ func (cli *Client) SendText(roomID, text string) (*RespSendEvent, error) {
 // SendImage sends an m.room.message event into the given room with a msgtype of m.image
 // See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-image
 func (cli *Client) SendImage(roomID, body, url string) (*RespSendEvent, error) {
-	return cli.SendMessageEvent(roomID, "m.room.message", Content{
+	return cli.SendMessageEvent(roomID, EventMessage, Content{
 		MsgType: MsgImage,
 		Body:    body,
 		URL:     url,
@@ -519,7 +519,7 @@ func (cli *Client) SendImage(roomID, body, url string) (*RespSendEvent, error) {
 // SendVideo sends an m.room.message event into the given room with a msgtype of m.video
 // See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-video
 func (cli *Client) SendVideo(roomID, body, url string) (*RespSendEvent, error) {
-	return cli.SendMessageEvent(roomID, "m.room.message", Content{
+	return cli.SendMessageEvent(roomID, EventMessage, Content{
 		MsgType: MsgVideo,
 		Body:    body,
 		URL:     url,
@@ -529,7 +529,7 @@ func (cli *Client) SendVideo(roomID, body, url string) (*RespSendEvent, error) {
 // SendNotice sends an m.room.message event into the given room with a msgtype of m.notice
 // See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-notice
 func (cli *Client) SendNotice(roomID, text string) (*RespSendEvent, error) {
-	return cli.SendMessageEvent(roomID, "m.room.message", Content{
+	return cli.SendMessageEvent(roomID, EventMessage, Content{
 		MsgType: MsgNotice,
 		Body:    text,
 	})
@@ -622,7 +622,7 @@ func (cli *Client) SetPresence(status string) (err error) {
 // the HTTP response body, or return an error.
 // See http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-rooms-roomid-state-eventtype-statekey
 func (cli *Client) StateEvent(roomID string, eventType EventType, stateKey string, outContent interface{}) (err error) {
-	u := cli.BuildURL("rooms", roomID, "state", string(eventType), stateKey)
+	u := cli.BuildURL("rooms", roomID, "state", eventType.String(), stateKey)
 	_, err = cli.MakeRequest("GET", u, nil, outContent)
 	return
 }

+ 39 - 23
vendor/maunium.net/go/gomatrix/events.go

@@ -5,28 +5,44 @@ import (
 	"sync"
 )
 
-type EventType string
+type EventType struct {
+	Type    string
+	IsState bool
+}
+
+func (et *EventType) UnmarshalJSON(data []byte) error {
+	return json.Unmarshal(data, &et.Type)
+}
+
+func (et *EventType) MarshalJSON() ([]byte, error) {
+	return json.Marshal(&et.Type)
+}
+
+func (et *EventType) String() string {
+	return et.Type
+}
+
 type MessageType string
 
 // State events
-const (
-	StateAliases        EventType = "m.room.aliases"
-	StateCanonicalAlias           = "m.room.canonical_alias"
-	StateCreate                   = "m.room.create"
-	StateJoinRules                = "m.room.join_rules"
-	StateMember                   = "m.room.member"
-	StatePowerLevels              = "m.room.power_levels"
-	StateRoomName                 = "m.room.name"
-	StateTopic                    = "m.room.topic"
-	StateRoomAvatar               = "m.room.avatar"
-	StatePinnedEvents             = "m.room.pinned_events"
+var (
+	StateAliases        = EventType{"m.room.aliases", true}
+	StateCanonicalAlias = EventType{"m.room.canonical_alias", true}
+	StateCreate         = EventType{"m.room.create", true}
+	StateJoinRules      = EventType{"m.room.join_rules", true}
+	StateMember         = EventType{"m.room.member", true}
+	StatePowerLevels    = EventType{"m.room.power_levels", true}
+	StateRoomName       = EventType{"m.room.name", true}
+	StateTopic          = EventType{"m.room.topic", true}
+	StateRoomAvatar     = EventType{"m.room.avatar", true}
+	StatePinnedEvents   = EventType{"m.room.pinned_events", true}
 )
 
 // Message events
-const (
-	EventRedaction EventType = "m.room.redaction"
-	EventMessage             = "m.room.message"
-	EventSticker             = "m.sticker"
+var (
+	EventRedaction = EventType{"m.room.redaction", false}
+	EventMessage   = EventType{"m.room.message", false}
+	EventSticker   = EventType{"m.sticker", false}
 )
 
 // Msgtypes
@@ -258,12 +274,12 @@ func (pl *PowerLevels) EnsureUserLevel(userID string, level int) bool {
 	return false
 }
 
-func (pl *PowerLevels) GetEventLevel(eventType EventType, isState bool) int {
+func (pl *PowerLevels) GetEventLevel(eventType EventType) int {
 	pl.eventsLock.RLock()
 	defer pl.eventsLock.RUnlock()
 	level, ok := pl.Events[eventType]
 	if !ok {
-		if isState {
+		if eventType.IsState {
 			return pl.StateDefault()
 		}
 		return pl.EventsDefault
@@ -271,20 +287,20 @@ func (pl *PowerLevels) GetEventLevel(eventType EventType, isState bool) int {
 	return level
 }
 
-func (pl *PowerLevels) SetEventLevel(eventType EventType, isState bool, level int) {
+func (pl *PowerLevels) SetEventLevel(eventType EventType, level int) {
 	pl.eventsLock.Lock()
 	defer pl.eventsLock.Unlock()
-	if (isState && level == pl.StateDefault()) || (!isState && level == pl.EventsDefault) {
+	if (eventType.IsState && level == pl.StateDefault()) || (!eventType.IsState && level == pl.EventsDefault) {
 		delete(pl.Events, eventType)
 	} else {
 		pl.Events[eventType] = level
 	}
 }
 
-func (pl *PowerLevels) EnsureEventLevel(eventType EventType, isState bool, level int) bool {
-	existingLevel := pl.GetEventLevel(eventType, isState)
+func (pl *PowerLevels) EnsureEventLevel(eventType EventType, level int) bool {
+	existingLevel := pl.GetEventLevel(eventType)
 	if existingLevel != level {
-		pl.SetEventLevel(eventType, isState, level)
+		pl.SetEventLevel(eventType, level)
 		return true
 	}
 	return false

+ 19 - 8
vendor/maunium.net/go/mautrix-appservice/appservice.go

@@ -2,17 +2,19 @@ package appservice
 
 import (
 	"fmt"
+	"html/template"
 	"io/ioutil"
 	"os"
+	"path/filepath"
 
 	"gopkg.in/yaml.v2"
 
-	"maunium.net/go/maulogger"
-	"strings"
-	"net/http"
 	"errors"
 	"maunium.net/go/gomatrix"
+	"maunium.net/go/maulogger"
+	"net/http"
 	"regexp"
+	"strings"
 )
 
 // EventChannelSize is the size for the Events channel in Appservice instances.
@@ -263,15 +265,24 @@ func CreateLogConfig() LogConfig {
 	}
 }
 
+type FileFormatData struct {
+	Date string
+	Index int
+}
+
 // GetFileFormat returns a mauLogger-compatible logger file format based on the data in the struct.
 func (lc LogConfig) GetFileFormat() maulogger.LoggerFileFormat {
-	path := lc.FileNameFormat
-	if len(lc.Directory) > 0 {
-		path = lc.Directory + "/" + path
-	}
+	os.MkdirAll(lc.Directory, 0700)
+	path := filepath.Join(lc.Directory, lc.FileNameFormat)
+	tpl, _ := template.New("fileformat").Parse(path)
 
 	return func(now string, i int) string {
-		return fmt.Sprintf(path, now, i)
+		var buf strings.Builder
+		tpl.Execute(&buf, FileFormatData{
+			Date: now,
+			Index: i,
+		})
+		return buf.String()
 	}
 }
 

+ 3 - 3
vendor/maunium.net/go/mautrix-appservice/intent.go

@@ -201,19 +201,19 @@ func (intent *IntentAPI) RedactEvent(roomID, eventID string, req *gomatrix.ReqRe
 }
 
 func (intent *IntentAPI) SetRoomName(roomID, roomName string) (*gomatrix.RespSendEvent, error) {
-	return intent.SendStateEvent(roomID, "m.room.name", "", map[string]interface{}{
+	return intent.SendStateEvent(roomID, gomatrix.StateRoomName, "", map[string]interface{}{
 		"name": roomName,
 	})
 }
 
 func (intent *IntentAPI) SetRoomAvatar(roomID, avatarURL string) (*gomatrix.RespSendEvent, error) {
-	return intent.SendStateEvent(roomID, "m.room.avatar", "", map[string]interface{}{
+	return intent.SendStateEvent(roomID, gomatrix.StateRoomAvatar, "", map[string]interface{}{
 		"url": avatarURL,
 	})
 }
 
 func (intent *IntentAPI) SetRoomTopic(roomID, topic string) (*gomatrix.RespSendEvent, error) {
-	return intent.SendStateEvent(roomID, "m.room.topic", "", map[string]interface{}{
+	return intent.SendStateEvent(roomID, gomatrix.StateTopic, "", map[string]interface{}{
 		"topic": topic,
 	})
 }

+ 23 - 16
vendor/maunium.net/go/mautrix-appservice/statestore.go

@@ -15,13 +15,15 @@ type StateStore interface {
 	SetTyping(roomID, userID string, timeout int64)
 
 	IsInRoom(roomID, userID string) bool
+	IsInvited(roomID, userID string) bool
+	IsMembership(roomID, userID string, allowedMemberships ...string) bool
 	SetMembership(roomID, userID, membership string)
 
 	SetPowerLevels(roomID string, levels *gomatrix.PowerLevels)
 	GetPowerLevels(roomID string) *gomatrix.PowerLevels
 	GetPowerLevel(roomID, userID string) int
-	GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType, isState bool) int
-	HasPowerLevel(roomID, userID string, eventType gomatrix.EventType, isState bool) bool
+	GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType) int
+	HasPowerLevel(roomID, userID string, eventType gomatrix.EventType) bool
 }
 
 func (as *AppService) UpdateState(evt *gomatrix.Event) {
@@ -126,7 +128,21 @@ func (store *BasicStateStore) GetMembership(roomID, userID string) string {
 }
 
 func (store *BasicStateStore) IsInRoom(roomID, userID string) bool {
-	return store.GetMembership(roomID, userID) == "join"
+	return store.IsMembership(roomID, userID, "join")
+}
+
+func (store *BasicStateStore) IsInvited(roomID, userID string) bool {
+	return store.IsMembership(roomID, userID, "join", "invite")
+}
+
+func (store *BasicStateStore) IsMembership(roomID, userID string, allowedMemberships ...string) bool {
+	membership := store.GetMembership(roomID, userID)
+	for _, allowedMembership := range allowedMemberships {
+		if allowedMembership == membership {
+			return true
+		}
+	}
+	return false
 }
 
 func (store *BasicStateStore) SetMembership(roomID, userID, membership string) {
@@ -160,19 +176,10 @@ func (store *BasicStateStore) GetPowerLevel(roomID, userID string) int {
 	return store.GetPowerLevels(roomID).GetUserLevel(userID)
 }
 
-func (store *BasicStateStore) GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType, isState bool) int {
-	levels := store.GetPowerLevels(roomID)
-	switch eventType {
-	case "kick":
-		return levels.Kick()
-	case "invite":
-		return levels.Invite()
-	case "redact":
-		return levels.Redact()
-	}
-	return levels.GetEventLevel(eventType, isState)
+func (store *BasicStateStore) GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType) int {
+	return store.GetPowerLevels(roomID).GetEventLevel(eventType)
 }
 
-func (store *BasicStateStore) HasPowerLevel(roomID, userID string, eventType gomatrix.EventType, isState bool) bool {
-	return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType, isState)
+func (store *BasicStateStore) HasPowerLevel(roomID, userID string, eventType gomatrix.EventType) bool {
+	return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
 }