ソースを参照

Initial desegregation of users and automatic config updating

Tulir Asokan 6 年 前
コミット
c7348f29b0

+ 1 - 0
.gitignore

@@ -6,3 +6,4 @@
 *.session
 *.session
 *.json
 *.json
 *.db
 *.db
+*.log

+ 2 - 2
Gopkg.lock

@@ -123,7 +123,7 @@
     ".",
     ".",
     "format"
     "format"
   ]
   ]
-  revision = "ead1f970c8f56d1854cb9eb4a54c03aa6dafd753"
+  revision = "42a3133c4980e4b1ea5fb52329d977f592d67cf0"
 
 
 [[projects]]
 [[projects]]
   branch = "master"
   branch = "master"
@@ -141,7 +141,7 @@
   branch = "master"
   branch = "master"
   name = "maunium.net/go/mautrix-appservice"
   name = "maunium.net/go/mautrix-appservice"
   packages = ["."]
   packages = ["."]
-  revision = "269f2ab602126a2de94bc86a457392426cce1ab2"
+  revision = "37d4449056015cea5d0a4420bba595c61ad32007"
 
 
 [solve-meta]
 [solve-meta]
   analyzer-name = "dep"
   analyzer-name = "dep"

+ 2 - 2
commands.go

@@ -48,7 +48,7 @@ type CommandEvent struct {
 func (ce *CommandEvent) Reply(msg string) {
 func (ce *CommandEvent) Reply(msg string) {
 	_, err := ce.Bot.SendNotice(string(ce.RoomID), msg)
 	_, err := ce.Bot.SendNotice(string(ce.RoomID), msg)
 	if err != nil {
 	if err != nil {
-		ce.Handler.log.Warnfln("Failed to reply to command from %s: %v", ce.User.ID, err)
+		ce.Handler.log.Warnfln("Failed to reply to command from %s: %v", ce.User.MXID, err)
 	}
 	}
 }
 }
 
 
@@ -56,7 +56,7 @@ func (handler *CommandHandler) Handle(roomID types.MatrixRoomID, user *User, mes
 	args := strings.Split(message, " ")
 	args := strings.Split(message, " ")
 	cmd := strings.ToLower(args[0])
 	cmd := strings.ToLower(args[0])
 	ce := &CommandEvent{
 	ce := &CommandEvent{
-		Bot:     handler.bridge.AppService.BotIntent(),
+		Bot:     handler.bridge.AS.BotIntent(),
 		Bridge:  handler.bridge,
 		Bridge:  handler.bridge,
 		Handler: handler,
 		Handler: handler,
 		RoomID:  roomID,
 		RoomID:  roomID,

+ 3 - 13
config/bridge.go

@@ -56,12 +56,7 @@ func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
 	return err
 	return err
 }
 }
 
 
-type DisplaynameTemplateArgs struct {
-	Displayname string
-}
-
 type UsernameTemplateArgs struct {
 type UsernameTemplateArgs struct {
-	Receiver string
 	UserID   string
 	UserID   string
 }
 }
 
 
@@ -74,14 +69,9 @@ func (bc BridgeConfig) FormatDisplayname(contact whatsapp.Contact) string {
 	return buf.String()
 	return buf.String()
 }
 }
 
 
-func (bc BridgeConfig) FormatUsername(receiver types.MatrixUserID, userID types.WhatsAppID) string {
+func (bc BridgeConfig) FormatUsername(userID types.WhatsAppID) string {
 	var buf bytes.Buffer
 	var buf bytes.Buffer
-	receiver = strings.Replace(receiver, "@", "=40", 1)
-	receiver = strings.Replace(receiver, ":", "=3", 1)
-	bc.usernameTemplate.Execute(&buf, UsernameTemplateArgs{
-		Receiver: receiver,
-		UserID:   userID,
-	})
+	bc.usernameTemplate.Execute(&buf, userID)
 	return buf.String()
 	return buf.String()
 }
 }
 
 
@@ -92,7 +82,7 @@ func (bc BridgeConfig) MarshalYAML() (interface{}, error) {
 		Name:   "{{.Name}}",
 		Name:   "{{.Name}}",
 		Short:  "{{.Short}}",
 		Short:  "{{.Short}}",
 	})
 	})
-	bc.UsernameTemplate = bc.FormatUsername("{{.Receiver}}", "{{.UserID}}")
+	bc.UsernameTemplate = bc.FormatUsername("{{.}}")
 	return bc, nil
 	return bc, nil
 }
 }
 
 

+ 0 - 1
config/config.go

@@ -78,7 +78,6 @@ func (config *Config) Save(path string) error {
 
 
 func (config *Config) MakeAppService() (*appservice.AppService, error) {
 func (config *Config) MakeAppService() (*appservice.AppService, error) {
 	as := appservice.Create()
 	as := appservice.Create()
-	as.LogConfig = config.Logging
 	as.HomeserverDomain = config.Homeserver.Domain
 	as.HomeserverDomain = config.Homeserver.Domain
 	as.HomeserverURL = config.Homeserver.Address
 	as.HomeserverURL = config.Homeserver.Address
 	as.Host.Hostname = config.AppService.Hostname
 	as.Host.Hostname = config.AppService.Hostname

+ 115 - 0
config/recursivemap.go

@@ -0,0 +1,115 @@
+// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
+// Copyright (C) 2018 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+package config
+
+import (
+	"strings"
+)
+
+type RecursiveMap map[interface{}]interface{}
+
+func (rm RecursiveMap) GetDefault(path string, defVal interface{}) interface{} {
+	val, ok := rm.Get(path)
+	if !ok {
+		return defVal
+	}
+	return val
+}
+
+func (rm RecursiveMap) GetMap(path string) RecursiveMap {
+	val := rm.GetDefault(path, nil)
+	if val == nil {
+		return nil
+	}
+
+	newRM, ok := val.(map[interface{}]interface{})
+	if ok {
+		return RecursiveMap(newRM)
+	}
+	return nil
+}
+
+func (rm RecursiveMap) Get(path string) (interface{}, bool) {
+	if index := strings.IndexRune(path, '.'); index >= 0 {
+		key := path[:index]
+		path = path[index+1:]
+
+		submap := rm.GetMap(key)
+		if submap == nil {
+			return nil, false
+		}
+		return submap.Get(path)
+	}
+	val, ok := rm[path]
+	return val, ok
+}
+
+func (rm RecursiveMap) GetIntDefault(path string, defVal int) int {
+	val, ok := rm.GetInt(path)
+	if !ok {
+		return defVal
+	}
+	return val
+}
+
+func (rm RecursiveMap) GetInt(path string) (int, bool) {
+	val, ok := rm.Get(path)
+	if !ok {
+		return 0, ok
+	}
+	intVal, ok := val.(int)
+	return intVal, ok
+}
+
+func (rm RecursiveMap) GetStringDefault(path string, defVal string) string {
+	val, ok := rm.GetString(path)
+	if !ok {
+		return defVal
+	}
+	return val
+}
+
+func (rm RecursiveMap) GetString(path string) (string, bool) {
+	val, ok := rm.Get(path)
+	if !ok {
+		return "", ok
+	}
+	strVal, ok := val.(string)
+	return strVal, ok
+}
+
+func (rm RecursiveMap) Set(path string, value interface{}) {
+	if index := strings.IndexRune(path, '.'); index >= 0 {
+		key := path[:index]
+		path = path[index+1:]
+		nextRM := rm.GetMap(key)
+		if nextRM == nil {
+			nextRM = make(RecursiveMap)
+			rm[key] = nextRM
+		}
+		nextRM.Set(path, value)
+		return
+	}
+	rm[path] = value
+}
+
+func (rm RecursiveMap) CopyFrom(otherRM RecursiveMap, path string) {
+	val, ok := otherRM.Get(path)
+	if ok {
+		rm.Set(path, val)
+	}
+}

+ 1 - 1
config/registration.go

@@ -56,7 +56,7 @@ func (config *Config) copyToRegistration(registration *appservice.Registration)
 	registration.SenderLocalpart = config.AppService.Bot.Username
 	registration.SenderLocalpart = config.AppService.Bot.Username
 
 
 	userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$",
 	userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$",
-		config.Bridge.FormatUsername(".+", "[0-9]+"),
+		config.Bridge.FormatUsername("[0-9]+"),
 		config.Homeserver.Domain))
 		config.Homeserver.Domain))
 	if err != nil {
 	if err != nil {
 		return err
 		return err

+ 100 - 0
config/update.go

@@ -0,0 +1,100 @@
+// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
+// Copyright (C) 2018 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+package config
+
+import (
+	"io/ioutil"
+
+	"gopkg.in/yaml.v2"
+)
+
+func Update(path, basePath string) error {
+	oldCfgData, err := ioutil.ReadFile(path)
+	if err != nil {
+		return err
+	}
+
+	oldCfg := make(RecursiveMap)
+	err = yaml.Unmarshal(oldCfgData, &oldCfg)
+	if err != nil {
+		return err
+	}
+
+	baseCfgData, err := ioutil.ReadFile(basePath)
+	if err != nil {
+		return err
+	}
+
+	baseCfg := make(RecursiveMap)
+	err = yaml.Unmarshal(baseCfgData, &baseCfg)
+	if err != nil {
+		return err
+	}
+
+	err = runUpdate(oldCfg, baseCfg)
+	if err != nil {
+		return err
+	}
+
+	newCfgData, err := yaml.Marshal(&baseCfg)
+	if err != nil {
+		return err
+	}
+
+	return ioutil.WriteFile(path, newCfgData, 0600)
+}
+
+func runUpdate(oldCfg, newCfg RecursiveMap) error {
+	cp := func(path string) {
+		newCfg.CopyFrom(oldCfg, path)
+	}
+
+	cp("homeserver.address")
+	cp("homeserver.domain")
+
+	cp("appservice.address")
+	cp("appservice.hostname")
+	cp("appservice.port")
+
+	cp("appservice.database.type")
+	cp("appservice.database.uri")
+	cp("appservice.state_store_path")
+
+	cp("appservice.id")
+	cp("appservice.bot.username")
+	cp("appservice.bot.displayname")
+	cp("appservice.bot.avatar")
+
+	cp("appservice.bot.as_token")
+	cp("appservice.bot.hs_token")
+
+	cp("bridge.username_template")
+	cp("bridge.displayname_template")
+
+	cp("bridge.command_prefix")
+
+	cp("bridge.permissions")
+
+	cp("logging.directory")
+	cp("logging.file_name_format")
+	cp("logging.file_date_format")
+	cp("logging.file_mode")
+	cp("logging.timestamp_format")
+	cp("logging.print_level")
+
+	return nil
+}

+ 18 - 17
database/message.go

@@ -30,12 +30,13 @@ type MessageQuery struct {
 
 
 func (mq *MessageQuery) CreateTable() error {
 func (mq *MessageQuery) CreateTable() error {
 	_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
 	_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
-		owner VARCHAR(255),
-		jid   VARCHAR(255),
-		mxid  VARCHAR(255) NOT NULL UNIQUE,
+		chat_jid      VARCHAR(25) NOT NULL,
+		chat_receiver VARCHAR(25) NOT NULL,
+		jid  VARCHAR(255) NOT NULL,
+		mxid VARCHAR(255) NOT NULL UNIQUE,
 
 
-		PRIMARY KEY (owner, jid),
-		FOREIGN KEY (owner) REFERENCES user(mxid)
+		PRIMARY KEY (chat_jid, jid),
+		FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
 	)`)
 	)`)
 	return err
 	return err
 }
 }
@@ -47,8 +48,8 @@ func (mq *MessageQuery) New() *Message {
 	}
 	}
 }
 }
 
 
-func (mq *MessageQuery) GetAll(owner types.MatrixUserID) (messages []*Message) {
-	rows, err := mq.db.Query("SELECT * FROM message WHERE owner=?", owner)
+func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
+	rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=?", chat.JID, chat.Receiver)
 	if err != nil || rows == nil {
 	if err != nil || rows == nil {
 		return nil
 		return nil
 	}
 	}
@@ -59,8 +60,8 @@ func (mq *MessageQuery) GetAll(owner types.MatrixUserID) (messages []*Message) {
 	return
 	return
 }
 }
 
 
-func (mq *MessageQuery) GetByJID(owner types.MatrixUserID, jid types.WhatsAppMessageID) *Message {
-	return mq.get("SELECT * FROM message WHERE owner=? AND jid=?", owner, jid)
+func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message {
+	return mq.get("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=? AND jid=?", chat.JID, chat.Receiver, jid)
 }
 }
 
 
 func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
 func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
@@ -79,13 +80,13 @@ type Message struct {
 	db  *Database
 	db  *Database
 	log log.Logger
 	log log.Logger
 
 
-	Owner types.MatrixUserID
-	JID   types.WhatsAppMessageID
-	MXID  types.MatrixEventID
+	Chat PortalKey
+	JID     types.WhatsAppMessageID
+	MXID    types.MatrixEventID
 }
 }
 
 
 func (msg *Message) Scan(row Scannable) *Message {
 func (msg *Message) Scan(row Scannable) *Message {
-	err := row.Scan(&msg.Owner, &msg.JID, &msg.MXID)
+	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID)
 	if err != nil {
 	if err != nil {
 		if err != sql.ErrNoRows {
 		if err != sql.ErrNoRows {
 			msg.log.Errorln("Database scan failed:", err)
 			msg.log.Errorln("Database scan failed:", err)
@@ -96,17 +97,17 @@ func (msg *Message) Scan(row Scannable) *Message {
 }
 }
 
 
 func (msg *Message) Insert() error {
 func (msg *Message) Insert() error {
-	_, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?)", msg.Owner, msg.JID, msg.MXID)
+	_, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID)
 	if err != nil {
 	if err != nil {
-		msg.log.Warnfln("Failed to update %s->%s: %v", msg.Owner, msg.JID, err)
+		msg.log.Warnfln("Failed to update %s: %v", msg.Chat, msg.JID, err)
 	}
 	}
 	return err
 	return err
 }
 }
 
 
 func (msg *Message) Update() error {
 func (msg *Message) Update() error {
-	_, err := msg.db.Exec("UPDATE portal SET mxid=? WHERE owner=? AND jid=?", msg.MXID, msg.Owner, msg.JID)
+	_, err := msg.db.Exec("UPDATE portal SET mxid=? WHERE chat_jid=? AND chat_receiver=? AND jid=?", msg.MXID, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
 	if err != nil {
 	if err != nil {
-		msg.log.Warnfln("Failed to update %s->%s: %v", msg.Owner, msg.JID, err)
+		msg.log.Warnfln("Failed to update %s: %v", msg.Chat, msg.JID, err)
 	}
 	}
 	return err
 	return err
 }
 }

+ 53 - 21
database/portal.go

@@ -18,11 +18,41 @@ package database
 
 
 import (
 import (
 	"database/sql"
 	"database/sql"
+	"strings"
 
 
 	log "maunium.net/go/maulogger"
 	log "maunium.net/go/maulogger"
 	"maunium.net/go/mautrix-whatsapp/types"
 	"maunium.net/go/mautrix-whatsapp/types"
 )
 )
 
 
+type PortalKey struct {
+	JID      types.WhatsAppID
+	Receiver types.WhatsAppID
+}
+
+func GroupPortalKey(jid types.WhatsAppID) PortalKey {
+	return PortalKey{
+		JID:      jid,
+		Receiver: jid,
+	}
+}
+
+func NewPortalKey(jid, receiver types.WhatsAppID) PortalKey {
+	if strings.HasSuffix(jid, "@g.us") {
+		receiver = jid
+	}
+	return PortalKey{
+		JID: jid,
+		Receiver: receiver,
+	}
+}
+
+func (key PortalKey) String() string {
+	if key.Receiver == key.JID {
+		return key.JID
+	}
+	return key.JID + "-" + key.Receiver
+}
+
 type PortalQuery struct {
 type PortalQuery struct {
 	db  *Database
 	db  *Database
 	log log.Logger
 	log log.Logger
@@ -30,16 +60,16 @@ type PortalQuery struct {
 
 
 func (pq *PortalQuery) CreateTable() error {
 func (pq *PortalQuery) CreateTable() error {
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
-		jid   VARCHAR(255),
-		owner VARCHAR(255),
-		mxid  VARCHAR(255) UNIQUE,
+		jid      VARCHAR(25),
+		receiver VARCHAR(25),
+		mxid     VARCHAR(255) UNIQUE,
 
 
 		name   VARCHAR(255) NOT NULL,
 		name   VARCHAR(255) NOT NULL,
 		topic  VARCHAR(255) NOT NULL,
 		topic  VARCHAR(255) NOT NULL,
 		avatar VARCHAR(255) NOT NULL,
 		avatar VARCHAR(255) NOT NULL,
 
 
-		PRIMARY KEY (jid, owner),
-		FOREIGN KEY (owner) REFERENCES user(mxid)
+		PRIMARY KEY (jid, receiver),
+		FOREIGN KEY (receiver) REFERENCES user(mxid)
 	)`)
 	)`)
 	return err
 	return err
 }
 }
@@ -51,8 +81,8 @@ func (pq *PortalQuery) New() *Portal {
 	}
 	}
 }
 }
 
 
-func (pq *PortalQuery) GetAll(owner types.MatrixUserID) (portals []*Portal) {
-	rows, err := pq.db.Query("SELECT * FROM portal WHERE owner=?", owner)
+func (pq *PortalQuery) GetAll() (portals []*Portal) {
+	rows, err := pq.db.Query("SELECT * FROM portal")
 	if err != nil || rows == nil {
 	if err != nil || rows == nil {
 		return nil
 		return nil
 	}
 	}
@@ -63,8 +93,8 @@ func (pq *PortalQuery) GetAll(owner types.MatrixUserID) (portals []*Portal) {
 	return
 	return
 }
 }
 
 
-func (pq *PortalQuery) GetByJID(owner types.MatrixUserID, jid types.WhatsAppID) *Portal {
-	return pq.get("SELECT * FROM portal WHERE jid=? AND owner=?", jid, owner)
+func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
+	return pq.get("SELECT * FROM portal WHERE jid=? AND receiver=?", key.JID, key.Receiver)
 }
 }
 
 
 func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
 func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
@@ -83,9 +113,8 @@ type Portal struct {
 	db  *Database
 	db  *Database
 	log log.Logger
 	log log.Logger
 
 
-	JID   types.WhatsAppID
-	MXID  types.MatrixRoomID
-	Owner types.MatrixUserID
+	Key  PortalKey
+	MXID types.MatrixRoomID
 
 
 	Name   string
 	Name   string
 	Topic  string
 	Topic  string
@@ -93,7 +122,7 @@ type Portal struct {
 }
 }
 
 
 func (portal *Portal) Scan(row Scannable) *Portal {
 func (portal *Portal) Scan(row Scannable) *Portal {
-	err := row.Scan(&portal.JID, &portal.Owner, &portal.MXID, &portal.Name, &portal.Topic, &portal.Avatar)
+	err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &portal.MXID, &portal.Name, &portal.Topic, &portal.Avatar)
 	if err != nil {
 	if err != nil {
 		if err != sql.ErrNoRows {
 		if err != sql.ErrNoRows {
 			portal.log.Errorln("Database scan failed:", err)
 			portal.log.Errorln("Database scan failed:", err)
@@ -103,15 +132,18 @@ func (portal *Portal) Scan(row Scannable) *Portal {
 	return portal
 	return portal
 }
 }
 
 
-func (portal *Portal) Insert() error {
-	var mxid *string
+func (portal *Portal) mxidPtr() *string {
 	if len(portal.MXID) > 0 {
 	if len(portal.MXID) > 0 {
-		mxid = &portal.MXID
+		return &portal.MXID
 	}
 	}
+	return nil
+}
+
+func (portal *Portal) Insert() error {
 	_, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)",
 	_, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)",
-		portal.JID, portal.Owner, mxid, portal.Name, portal.Topic, portal.Avatar)
+		portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar)
 	if err != nil {
 	if err != nil {
-		portal.log.Warnfln("Failed to insert %s->%s: %v", portal.JID, portal.Owner, err)
+		portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
 	}
 	}
 	return err
 	return err
 }
 }
@@ -121,10 +153,10 @@ func (portal *Portal) Update() error {
 	if len(portal.MXID) > 0 {
 	if len(portal.MXID) > 0 {
 		mxid = &portal.MXID
 		mxid = &portal.MXID
 	}
 	}
-	_, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND owner=?",
-		mxid, portal.Name, portal.Topic, portal.Avatar, portal.JID, portal.Owner)
+	_, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND receiver=?",
+		mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver)
 	if err != nil {
 	if err != nil {
-		portal.log.Warnfln("Failed to update %s->%s: %v", portal.JID, portal.Owner, err)
+		portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
 	}
 	}
 	return err
 	return err
 }
 }

+ 14 - 21
database/puppet.go

@@ -30,13 +30,9 @@ type PuppetQuery struct {
 
 
 func (pq *PuppetQuery) CreateTable() error {
 func (pq *PuppetQuery) CreateTable() error {
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet (
 	_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet (
-		jid      VARCHAR(255),
-		receiver VARCHAR(255),
-
+		jid         VARCHAR(25) PRIMARY KEY,
 		displayname VARCHAR(255),
 		displayname VARCHAR(255),
-		avatar      VARCHAR(255),
-
-		PRIMARY KEY(jid, receiver)
+		avatar      VARCHAR(255)
 	)`)
 	)`)
 	return err
 	return err
 }
 }
@@ -48,8 +44,8 @@ func (pq *PuppetQuery) New() *Puppet {
 	}
 	}
 }
 }
 
 
-func (pq *PuppetQuery) GetAll(receiver types.MatrixUserID) (puppets []*Puppet) {
-	rows, err := pq.db.Query("SELECT * FROM puppet WHERE receiver=%s")
+func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
+	rows, err := pq.db.Query("SELECT * FROM puppet")
 	if err != nil || rows == nil {
 	if err != nil || rows == nil {
 		return nil
 		return nil
 	}
 	}
@@ -60,8 +56,8 @@ func (pq *PuppetQuery) GetAll(receiver types.MatrixUserID) (puppets []*Puppet) {
 	return
 	return
 }
 }
 
 
-func (pq *PuppetQuery) Get(jid types.WhatsAppID, receiver types.MatrixUserID) *Puppet {
-	row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=? AND receiver=?", jid, receiver)
+func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
+	row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=?", jid)
 	if row == nil {
 	if row == nil {
 		return nil
 		return nil
 	}
 	}
@@ -72,15 +68,13 @@ type Puppet struct {
 	db  *Database
 	db  *Database
 	log log.Logger
 	log log.Logger
 
 
-	JID      types.WhatsAppID
-	Receiver types.MatrixUserID
-
+	JID         types.WhatsAppID
 	Displayname string
 	Displayname string
 	Avatar      string
 	Avatar      string
 }
 }
 
 
 func (puppet *Puppet) Scan(row Scannable) *Puppet {
 func (puppet *Puppet) Scan(row Scannable) *Puppet {
-	err := row.Scan(&puppet.JID, &puppet.Receiver, &puppet.Displayname, &puppet.Avatar)
+	err := row.Scan(&puppet.JID, &puppet.Displayname, &puppet.Avatar)
 	if err != nil {
 	if err != nil {
 		if err != sql.ErrNoRows {
 		if err != sql.ErrNoRows {
 			puppet.log.Errorln("Database scan failed:", err)
 			puppet.log.Errorln("Database scan failed:", err)
@@ -91,20 +85,19 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
 }
 }
 
 
 func (puppet *Puppet) Insert() error {
 func (puppet *Puppet) Insert() error {
-	_, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?, ?)",
-		puppet.JID, puppet.Receiver, puppet.Displayname, puppet.Avatar)
+	_, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?)",
+		puppet.JID, puppet.Displayname, puppet.Avatar)
 	if err != nil {
 	if err != nil {
-		puppet.log.Errorfln("Failed to insert %s->%s: %v", puppet.JID, puppet.Receiver, err)
+		puppet.log.Errorfln("Failed to insert %s: %v", puppet.JID, err)
 	}
 	}
 	return err
 	return err
 }
 }
 
 
 func (puppet *Puppet) Update() error {
 func (puppet *Puppet) Update() error {
-	_, err := puppet.db.Exec("UPDATE puppet SET displayname=?, avatar=? WHERE jid=? AND receiver=?",
-		puppet.Displayname, puppet.Avatar,
-		puppet.JID, puppet.Receiver)
+	_, err := puppet.db.Exec("UPDATE puppet SET displayname=?, avatar=? WHERE jid=?",
+		puppet.Displayname, puppet.Avatar, puppet.JID)
 	if err != nil {
 	if err != nil {
-		puppet.log.Errorfln("Failed to update %s->%s: %v", puppet.JID, puppet.Receiver, err)
+		puppet.log.Errorfln("Failed to update %s->%s: %v", puppet.JID, err)
 	}
 	}
 	return err
 	return err
 }
 }

+ 32 - 13
database/user.go

@@ -32,6 +32,7 @@ type UserQuery struct {
 func (uq *UserQuery) CreateTable() error {
 func (uq *UserQuery) CreateTable() error {
 	_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user (
 	_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user (
 		mxid VARCHAR(255) PRIMARY KEY,
 		mxid VARCHAR(255) PRIMARY KEY,
+		jid  VARCHAR(25)  UNIQUE,
 
 
 		management_room VARCHAR(255),
 		management_room VARCHAR(255),
 
 
@@ -64,7 +65,7 @@ func (uq *UserQuery) GetAll() (users []*User) {
 	return
 	return
 }
 }
 
 
-func (uq *UserQuery) Get(userID types.MatrixUserID) *User {
+func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
 	row := uq.db.QueryRow("SELECT * FROM user WHERE mxid=?", userID)
 	row := uq.db.QueryRow("SELECT * FROM user WHERE mxid=?", userID)
 	if row == nil {
 	if row == nil {
 		return nil
 		return nil
@@ -72,18 +73,27 @@ func (uq *UserQuery) Get(userID types.MatrixUserID) *User {
 	return uq.New().Scan(row)
 	return uq.New().Scan(row)
 }
 }
 
 
+func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User {
+	row := uq.db.QueryRow("SELECT * FROM user WHERE jid=?", userID)
+	if row == nil {
+		return nil
+	}
+	return uq.New().Scan(row)
+}
+
 type User struct {
 type User struct {
 	db  *Database
 	db  *Database
 	log log.Logger
 	log log.Logger
 
 
-	ID             types.MatrixUserID
+	MXID           types.MatrixUserID
+	JID            types.WhatsAppID
 	ManagementRoom types.MatrixRoomID
 	ManagementRoom types.MatrixRoomID
 	Session        *whatsapp.Session
 	Session        *whatsapp.Session
 }
 }
 
 
 func (user *User) Scan(row Scannable) *User {
 func (user *User) Scan(row Scannable) *User {
 	sess := whatsapp.Session{}
 	sess := whatsapp.Session{}
-	err := row.Scan(&user.ID, &user.ManagementRoom, &sess.ClientId, &sess.ClientToken, &sess.ServerToken,
+	err := row.Scan(&user.MXID, &user.JID, &user.ManagementRoom, &sess.ClientId, &sess.ClientToken, &sess.ServerToken,
 		&sess.EncKey, &sess.MacKey, &sess.Wid)
 		&sess.EncKey, &sess.MacKey, &sess.Wid)
 	if err != nil {
 	if err != nil {
 		if err != sql.ErrNoRows {
 		if err != sql.ErrNoRows {
@@ -99,23 +109,32 @@ func (user *User) Scan(row Scannable) *User {
 	return user
 	return user
 }
 }
 
 
-func (user *User) Insert() error {
-	var sess whatsapp.Session
+func (user *User) jidPtr() *string {
+	if len(user.JID) > 0 {
+		return &user.JID
+	}
+	return nil
+}
+
+func (user *User) sessionUnptr() (sess whatsapp.Session) {
 	if user.Session != nil {
 	if user.Session != nil {
 		sess = *user.Session
 		sess = *user.Session
 	}
 	}
-	_, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?)", user.ID, user.ManagementRoom,
+	return
+}
+
+func (user *User) Insert() error {
+	sess := user.sessionUnptr()
+	_, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", user.MXID, user.jidPtr(), user.ManagementRoom,
 		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid)
 		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid)
 	return err
 	return err
 }
 }
 
 
 func (user *User) Update() error {
 func (user *User) Update() error {
-	var sess whatsapp.Session
-	if user.Session != nil {
-		sess = *user.Session
-	}
-	_, err := user.db.Exec("UPDATE user SET management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=?, wid=? WHERE mxid=?",
-		user.ManagementRoom,
-		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid, user.ID)
+	sess := user.sessionUnptr()
+	_, err := user.db.Exec("UPDATE user SET jid=?, management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=?, wid=? WHERE mxid=?",
+		user.jidPtr(), user.ManagementRoom,
+		sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid,
+		user.MXID)
 	return err
 	return err
 }
 }

+ 9 - 10
example-config.yaml

@@ -21,7 +21,6 @@ appservice:
     type: sqlite3
     type: sqlite3
     # The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string
     # The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string
     uri: mautrix-whatsapp.db
     uri: mautrix-whatsapp.db
-
   # Path to the Matrix room state store.
   # Path to the Matrix room state store.
   state_store_path: ./mx-state.json
   state_store_path: ./mx-state.json
 
 
@@ -43,15 +42,15 @@ appservice:
 # Bridge config. Currently unused.
 # Bridge config. Currently unused.
 bridge:
 bridge:
   # Localpart template of MXIDs for WhatsApp users.
   # Localpart template of MXIDs for WhatsApp users.
-  # {{.Receiver}} is replaced with the WhatsApp user ID of the Matrix user receiving messages.
-  # {{.UserID}} is replaced with the user ID of the WhatsApp user.
-  username_template: "whatsapp_{{.Receiver}}_{{.UserID}}"
+  # {{.}} is replaced with the phone number of the WhatsApp user.
+  username_template: whatsapp_{{.}}
   # Displayname template for WhatsApp users.
   # Displayname template for WhatsApp users.
-  # {{.Name}}   - display name
-  # {{.Short}}  - short display name (usually first name)
-  # {{.Notify}} - nickname (maybe set by the target WhatsApp user)
+  # {{.Notify}} - nickname set by the WhatsApp user
   # {{.Jid}}    - phone number (international format)
   # {{.Jid}}    - phone number (international format)
-  displayname_template: "{{if .Name}}{{.Name}}{{else if .Notify}}{{.Notify}}{{else if .Short}}{{.Short}}{{else}}{{.Jid}}{{end}}"
+  # The following variables are also available, but will cause problems on multi-user instances:
+  # {{.Name}}   - display name from contact list
+  # {{.Short}}  - short display name from contact list
+  displayname_template: "{{if .Notify}}{{.Notify}}{{else}}{{.Jid}}{{end}} (WA)"
 
 
   # The prefix for commands. Only required in non-management rooms.
   # The prefix for commands. Only required in non-management rooms.
   command_prefix: "!wa"
   command_prefix: "!wa"
@@ -72,8 +71,8 @@ bridge:
 logging:
 logging:
   # The directory for log files. Will be created if not found.
   # The directory for log files. Will be created if not found.
   directory: ./logs
   directory: ./logs
-  # Available variables: .date for the file date and .index for different log files on the same day.
-  file_name_format: "{{.date}}-{{.index}.log"
+  # Available variables: .Date for the file date and .Index for different log files on the same day.
+  file_name_format: "{{.Date}}-{{.Index}}.log"
   # Date format for file names in the Go time format: https://golang.org/pkg/time/#pkg-constants
   # Date format for file names in the Go time format: https://golang.org/pkg/time/#pkg-constants
   file_date_format: 2006-01-02
   file_date_format: 2006-01-02
   # Log file permissions.
   # Log file permissions.

+ 91 - 49
formatting.go

@@ -18,58 +18,71 @@ package main
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"html"
 	"regexp"
 	"regexp"
 	"strings"
 	"strings"
 
 
+	"maunium.net/go/gomatrix"
 	"maunium.net/go/gomatrix/format"
 	"maunium.net/go/gomatrix/format"
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 )
 )
 
 
-func (user *User) newHTMLParser() *format.HTMLParser {
-	return &format.HTMLParser{
-		TabsToSpaces: 4,
-		Newline:      "\n",
-
-		PillConverter: func(mxid, eventID string) string {
-			if mxid[0] == '@' {
-				puppet := user.GetPuppetByMXID(mxid)
-				fmt.Println(mxid, puppet)
-				if puppet != nil {
-					return "@" + puppet.PhoneNumber()
-				}
-			}
-			return mxid
-		},
-		BoldConverter: func(text string) string {
-			return fmt.Sprintf("*%s*", text)
-		},
-		ItalicConverter: func(text string) string {
-			return fmt.Sprintf("_%s_", text)
-		},
-		StrikethroughConverter: func(text string) string {
-			return fmt.Sprintf("~%s~", text)
-		},
-		MonospaceConverter: func(text string) string {
-			return fmt.Sprintf("```%s```", text)
-		},
-		MonospaceBlockConverter: func(text string) string {
-			return fmt.Sprintf("```%s```", text)
-		},
-	}
-}
-
 var italicRegex = regexp.MustCompile("([\\s>~*]|^)_(.+?)_([^a-zA-Z\\d]|$)")
 var italicRegex = regexp.MustCompile("([\\s>~*]|^)_(.+?)_([^a-zA-Z\\d]|$)")
 var boldRegex = regexp.MustCompile("([\\s>_~]|^)\\*(.+?)\\*([^a-zA-Z\\d]|$)")
 var boldRegex = regexp.MustCompile("([\\s>_~]|^)\\*(.+?)\\*([^a-zA-Z\\d]|$)")
 var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$)")
 var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$)")
 var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
 var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
 var mentionRegex = regexp.MustCompile("@[0-9]+")
 var mentionRegex = regexp.MustCompile("@[0-9]+")
 
 
-func (user *User) newWhatsAppFormatMaps() (map[*regexp.Regexp]string, map[*regexp.Regexp]func(string) string, map[*regexp.Regexp]func(string) string) {
-	return map[*regexp.Regexp]string{
-		italicRegex:        "$1<em>$2</em>$3",
-		boldRegex:          "$1<strong>$2</strong>$3",
-		strikethroughRegex: "$1<del>$2</del>$3",
-	}, map[*regexp.Regexp]func(string) string{
+type Formatter struct {
+	bridge *Bridge
+
+	matrixHTMLParser *format.HTMLParser
+
+	waReplString   map[*regexp.Regexp]string
+	waReplFunc     map[*regexp.Regexp]func(string) string
+	waReplFuncText map[*regexp.Regexp]func(string) string
+}
+
+func NewFormatter(bridge *Bridge) *Formatter {
+	formatter := &Formatter{
+		bridge: bridge,
+		matrixHTMLParser: &format.HTMLParser{
+			TabsToSpaces: 4,
+			Newline:      "\n",
+
+			PillConverter: func(mxid, eventID string) string {
+				if mxid[0] == '@' {
+					puppet := bridge.GetPuppetByMXID(mxid)
+					fmt.Println(mxid, puppet)
+					if puppet != nil {
+						return "@" + puppet.PhoneNumber()
+					}
+				}
+				return mxid
+			},
+			BoldConverter: func(text string) string {
+				return fmt.Sprintf("*%s*", text)
+			},
+			ItalicConverter: func(text string) string {
+				return fmt.Sprintf("_%s_", text)
+			},
+			StrikethroughConverter: func(text string) string {
+				return fmt.Sprintf("~%s~", text)
+			},
+			MonospaceConverter: func(text string) string {
+				return fmt.Sprintf("```%s```", text)
+			},
+			MonospaceBlockConverter: func(text string) string {
+				return fmt.Sprintf("```%s```", text)
+			},
+		},
+		waReplString: map[*regexp.Regexp]string{
+			italicRegex:        "$1<em>$2</em>$3",
+			boldRegex:          "$1<strong>$2</strong>$3",
+			strikethroughRegex: "$1<del>$2</del>$3",
+		},
+	}
+	formatter.waReplFunc = map[*regexp.Regexp]func(string) string{
 		codeBlockRegex: func(str string) string {
 		codeBlockRegex: func(str string) string {
 			str = str[3 : len(str)-3]
 			str = str[3 : len(str)-3]
 			if strings.ContainsRune(str, '\n') {
 			if strings.ContainsRune(str, '\n') {
@@ -78,18 +91,47 @@ func (user *User) newWhatsAppFormatMaps() (map[*regexp.Regexp]string, map[*regex
 			return fmt.Sprintf("<code>%s</code>", str)
 			return fmt.Sprintf("<code>%s</code>", str)
 		},
 		},
 		mentionRegex: func(str string) string {
 		mentionRegex: func(str string) string {
-			jid := str[1:] + whatsappExt.NewUserSuffix
-			puppet := user.GetPuppetByJID(jid)
-			mxid := puppet.MXID
-			if jid == user.JID() {
-				mxid = user.ID
-			}
-			return fmt.Sprintf(`<a href="https://matrix.to/#/%s">%s</a>`, mxid, puppet.Displayname)
+			mxid, displayname := formatter.getMatrixInfoByJID(str[1:] + whatsappExt.NewUserSuffix)
+			return fmt.Sprintf(`<a href="https://matrix.to/#/%s">%s</a>`, mxid, displayname)
 		},
 		},
-	}, map[*regexp.Regexp]func(string)string {
+	}
+	formatter.waReplFuncText = map[*regexp.Regexp]func(string) string{
 		mentionRegex: func(str string) string {
 		mentionRegex: func(str string) string {
-			puppet := user.GetPuppetByJID(str[1:] + whatsappExt.NewUserSuffix)
-			return puppet.Displayname
+			_, displayname := formatter.getMatrixInfoByJID(str[1:] + whatsappExt.NewUserSuffix)
+			return displayname
 		},
 		},
 	}
 	}
+	return formatter
+}
+
+func (formatter *Formatter) getMatrixInfoByJID(jid string) (mxid, displayname string) {
+	if user := formatter.bridge.GetUserByJID(jid); user != nil {
+		mxid = user.MXID
+		displayname = user.MXID
+	} else if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil {
+		mxid = puppet.MXID
+		displayname = puppet.Displayname
+	}
+	return
+}
+
+func (formatter *Formatter) ParseWhatsApp(content *gomatrix.Content) {
+	output := html.EscapeString(content.Body)
+	for regex, replacement := range formatter.waReplString {
+		output = regex.ReplaceAllString(output, replacement)
+	}
+	for regex, replacer := range formatter.waReplFunc {
+		output = regex.ReplaceAllStringFunc(output, replacer)
+	}
+	if output != content.Body {
+		content.FormattedBody = output
+		content.Format = gomatrix.FormatHTML
+		for regex, replacer := range formatter.waReplFuncText {
+			content.Body = regex.ReplaceAllStringFunc(content.Body, replacer)
+		}
+	}
+}
+
+func (formatter *Formatter) ParseMatrix(html string) string {
+	return formatter.matrixHTMLParser.Parse(html)
 }
 }

+ 54 - 25
main.go

@@ -20,6 +20,7 @@ import (
 	"fmt"
 	"fmt"
 	"os"
 	"os"
 	"os/signal"
 	"os/signal"
+	"sync"
 	"syscall"
 	"syscall"
 
 
 	flag "maunium.net/go/mauflag"
 	flag "maunium.net/go/mauflag"
@@ -31,6 +32,7 @@ import (
 )
 )
 
 
 var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String()
 var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String()
+var baseConfigPath = flag.MakeFull("b", "base-config", "The path to the example config file.", "example-config.yaml").String()
 var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
 var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
 var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
 var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
 var wantHelp, _ = flag.MakeHelpFlag()
 var wantHelp, _ = flag.MakeHelpFlag()
@@ -58,29 +60,47 @@ func (bridge *Bridge) GenerateRegistration() {
 }
 }
 
 
 type Bridge struct {
 type Bridge struct {
-	AppService     *appservice.AppService
+	AS             *appservice.AppService
 	EventProcessor *appservice.EventProcessor
 	EventProcessor *appservice.EventProcessor
 	MatrixHandler  *MatrixHandler
 	MatrixHandler  *MatrixHandler
 	Config         *config.Config
 	Config         *config.Config
 	DB             *database.Database
 	DB             *database.Database
 	Log            log.Logger
 	Log            log.Logger
-
-	StateStore *AutosavingStateStore
-
-	users           map[types.MatrixUserID]*User
-	managementRooms map[types.MatrixRoomID]*User
+	StateStore     *AutosavingStateStore
+	Bot            *appservice.IntentAPI
+	Formatter      *Formatter
+
+	usersByMXID         map[types.MatrixUserID]*User
+	usersByJID          map[types.WhatsAppID]*User
+	usersLock           sync.Mutex
+	managementRooms     map[types.MatrixRoomID]*User
+	managementRoomsLock sync.Mutex
+	portalsByMXID       map[types.MatrixRoomID]*Portal
+	portalsByJID        map[database.PortalKey]*Portal
+	portalsLock         sync.Mutex
+	puppets             map[types.WhatsAppID]*Puppet
+	puppetsLock         sync.Mutex
 }
 }
 
 
 func NewBridge() *Bridge {
 func NewBridge() *Bridge {
 	bridge := &Bridge{
 	bridge := &Bridge{
-		users:           make(map[types.MatrixUserID]*User),
+		usersByMXID:     make(map[types.MatrixUserID]*User),
+		usersByJID:      make(map[types.WhatsAppID]*User),
 		managementRooms: make(map[types.MatrixRoomID]*User),
 		managementRooms: make(map[types.MatrixRoomID]*User),
+		portalsByMXID:   make(map[types.MatrixRoomID]*Portal),
+		portalsByJID:    make(map[database.PortalKey]*Portal),
+		puppets:         make(map[types.WhatsAppID]*Puppet),
 	}
 	}
-	var err error
+	err := config.Update(*configPath, *baseConfigPath)
+	if err != nil {
+		fmt.Fprintln(os.Stderr, "Failed to update config:", err)
+		os.Exit(10)
+	}
+
 	bridge.Config, err = config.Load(*configPath)
 	bridge.Config, err = config.Load(*configPath)
 	if err != nil {
 	if err != nil {
 		fmt.Fprintln(os.Stderr, "Failed to load config:", err)
 		fmt.Fprintln(os.Stderr, "Failed to load config:", err)
-		os.Exit(10)
+		os.Exit(11)
 	}
 	}
 	return bridge
 	return bridge
 }
 }
@@ -88,46 +108,55 @@ func NewBridge() *Bridge {
 func (bridge *Bridge) Init() {
 func (bridge *Bridge) Init() {
 	var err error
 	var err error
 
 
-	bridge.AppService, err = bridge.Config.MakeAppService()
+	bridge.AS, err = bridge.Config.MakeAppService()
 	if err != nil {
 	if err != nil {
 		fmt.Fprintln(os.Stderr, "Failed to initialize AppService:", err)
 		fmt.Fprintln(os.Stderr, "Failed to initialize AppService:", err)
-		os.Exit(11)
+		os.Exit(12)
 	}
 	}
-	bridge.AppService.Init()
-	bridge.Log = bridge.AppService.Log
+	bridge.AS.Init()
+	bridge.Bot = bridge.AS.BotIntent()
+
+	bridge.Log = log.Create()
+	bridge.Config.Logging.Configure(bridge.Log)
 	log.DefaultLogger = bridge.Log.(*log.BasicLogger)
 	log.DefaultLogger = bridge.Log.(*log.BasicLogger)
-	bridge.AppService.Log = log.Sub("Matrix")
+	err = log.OpenFile()
+	if err != nil {
+		fmt.Fprintln(os.Stderr, "Failed to open log file:", err)
+		os.Exit(13)
+	}
+	bridge.AS.Log = log.Sub("Matrix")
 
 
 	bridge.Log.Debugln("Initializing state store")
 	bridge.Log.Debugln("Initializing state store")
 	bridge.StateStore = NewAutosavingStateStore(bridge.Config.AppService.StateStore)
 	bridge.StateStore = NewAutosavingStateStore(bridge.Config.AppService.StateStore)
 	err = bridge.StateStore.Load()
 	err = bridge.StateStore.Load()
 	if err != nil {
 	if err != nil {
 		bridge.Log.Fatalln("Failed to load state store:", err)
 		bridge.Log.Fatalln("Failed to load state store:", err)
-		os.Exit(12)
+		os.Exit(14)
 	}
 	}
-	bridge.AppService.StateStore = bridge.StateStore
+	bridge.AS.StateStore = bridge.StateStore
 
 
 	bridge.Log.Debugln("Initializing database")
 	bridge.Log.Debugln("Initializing database")
 	bridge.DB, err = database.New(bridge.Config.AppService.Database.URI)
 	bridge.DB, err = database.New(bridge.Config.AppService.Database.URI)
 	if err != nil {
 	if err != nil {
 		bridge.Log.Fatalln("Failed to initialize database:", err)
 		bridge.Log.Fatalln("Failed to initialize database:", err)
-		os.Exit(13)
+		os.Exit(15)
 	}
 	}
 
 
 	bridge.Log.Debugln("Initializing Matrix event processor")
 	bridge.Log.Debugln("Initializing Matrix event processor")
-	bridge.EventProcessor = appservice.NewEventProcessor(bridge.AppService)
+	bridge.EventProcessor = appservice.NewEventProcessor(bridge.AS)
 	bridge.Log.Debugln("Initializing Matrix event handler")
 	bridge.Log.Debugln("Initializing Matrix event handler")
 	bridge.MatrixHandler = NewMatrixHandler(bridge)
 	bridge.MatrixHandler = NewMatrixHandler(bridge)
+	bridge.Formatter = NewFormatter(bridge)
 }
 }
 
 
 func (bridge *Bridge) Start() {
 func (bridge *Bridge) Start() {
 	err := bridge.DB.CreateTables()
 	err := bridge.DB.CreateTables()
 	if err != nil {
 	if err != nil {
 		bridge.Log.Fatalln("Failed to create database tables:", err)
 		bridge.Log.Fatalln("Failed to create database tables:", err)
-		os.Exit(14)
+		os.Exit(16)
 	}
 	}
 	bridge.Log.Debugln("Starting application service HTTP server")
 	bridge.Log.Debugln("Starting application service HTTP server")
-	go bridge.AppService.Start()
+	go bridge.AS.Start()
 	bridge.Log.Debugln("Starting event processor")
 	bridge.Log.Debugln("Starting event processor")
 	go bridge.EventProcessor.Start()
 	go bridge.EventProcessor.Start()
 	go bridge.UpdateBotProfile()
 	go bridge.UpdateBotProfile()
@@ -140,18 +169,18 @@ func (bridge *Bridge) UpdateBotProfile() {
 
 
 	var err error
 	var err error
 	if botConfig.Avatar == "remove" {
 	if botConfig.Avatar == "remove" {
-		err = bridge.AppService.BotIntent().SetAvatarURL("")
+		err = bridge.AS.BotIntent().SetAvatarURL("")
 	} else if len(botConfig.Avatar) > 0 {
 	} else if len(botConfig.Avatar) > 0 {
-		err = bridge.AppService.BotIntent().SetAvatarURL(botConfig.Avatar)
+		err = bridge.AS.BotIntent().SetAvatarURL(botConfig.Avatar)
 	}
 	}
 	if err != nil {
 	if err != nil {
 		bridge.Log.Warnln("Failed to update bot avatar:", err)
 		bridge.Log.Warnln("Failed to update bot avatar:", err)
 	}
 	}
 
 
 	if botConfig.Displayname == "remove" {
 	if botConfig.Displayname == "remove" {
-		err = bridge.AppService.BotIntent().SetDisplayName("")
+		err = bridge.AS.BotIntent().SetDisplayName("")
 	} else if len(botConfig.Avatar) > 0 {
 	} else if len(botConfig.Avatar) > 0 {
-		err = bridge.AppService.BotIntent().SetDisplayName(botConfig.Displayname)
+		err = bridge.AS.BotIntent().SetDisplayName(botConfig.Displayname)
 	}
 	}
 	if err != nil {
 	if err != nil {
 		bridge.Log.Warnln("Failed to update bot displayname:", err)
 		bridge.Log.Warnln("Failed to update bot displayname:", err)
@@ -165,7 +194,7 @@ func (bridge *Bridge) StartUsers() {
 }
 }
 
 
 func (bridge *Bridge) Stop() {
 func (bridge *Bridge) Stop() {
-	bridge.AppService.Stop()
+	bridge.AS.Stop()
 	bridge.EventProcessor.Stop()
 	bridge.EventProcessor.Stop()
 	err := bridge.StateStore.Save()
 	err := bridge.StateStore.Save()
 	if err != nil {
 	if err != nil {

+ 15 - 11
matrix.go

@@ -35,7 +35,7 @@ type MatrixHandler struct {
 func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
 func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
 	handler := &MatrixHandler{
 	handler := &MatrixHandler{
 		bridge: bridge,
 		bridge: bridge,
-		as:     bridge.AppService,
+		as:     bridge.AS,
 		log:    bridge.Log.Sub("Matrix"),
 		log:    bridge.Log.Sub("Matrix"),
 		cmd:    NewCommandHandler(bridge),
 		cmd:    NewCommandHandler(bridge),
 	}
 	}
@@ -50,7 +50,7 @@ func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
 func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) {
 func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) {
 	intent := mx.as.BotIntent()
 	intent := mx.as.BotIntent()
 
 
-	user := mx.bridge.GetUser(evt.Sender)
+	user := mx.bridge.GetUserByMXID(evt.Sender)
 	if user == nil {
 	if user == nil {
 		return
 		return
 	}
 	}
@@ -85,7 +85,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) {
 	for mxid, _ := range members.Joined {
 	for mxid, _ := range members.Joined {
 		if mxid == intent.UserID || mxid == evt.Sender {
 		if mxid == intent.UserID || mxid == evt.Sender {
 			continue
 			continue
-		} else if _, _, ok := mx.bridge.ParsePuppetMXID(types.MatrixUserID(mxid)); ok {
+		} else if _, ok := mx.bridge.ParsePuppetMXID(types.MatrixUserID(mxid)); ok {
 			hasPuppets = true
 			hasPuppets = true
 			continue
 			continue
 		}
 		}
@@ -96,7 +96,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) {
 	}
 	}
 
 
 	if !hasPuppets {
 	if !hasPuppets {
-		user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender))
+		user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
 		user.SetManagementRoom(types.MatrixRoomID(resp.RoomID))
 		user.SetManagementRoom(types.MatrixRoomID(resp.RoomID))
 		intent.SendNotice(string(user.ManagementRoom), "This room has been registered as your bridge management/status room.")
 		intent.SendNotice(string(user.ManagementRoom), "This room has been registered as your bridge management/status room.")
 		mx.log.Debugln(resp.RoomID, "registered as a management room with", evt.Sender)
 		mx.log.Debugln(resp.RoomID, "registered as a management room with", evt.Sender)
@@ -110,12 +110,12 @@ func (mx *MatrixHandler) HandleMembership(evt *gomatrix.Event) {
 }
 }
 
 
 func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) {
 func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) {
-	user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender))
-	if user == nil || !user.Whitelisted {
+	user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
+	if user == nil || !user.Whitelisted || !user.IsLoggedIn() {
 		return
 		return
 	}
 	}
 
 
-	portal := user.GetPortalByMXID(evt.RoomID)
+	portal := mx.bridge.GetPortalByMXID(evt.RoomID)
 	if portal == nil || portal.IsPrivateChat() {
 	if portal == nil || portal.IsPrivateChat() {
 		return
 		return
 	}
 	}
@@ -124,7 +124,7 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) {
 	var err error
 	var err error
 	switch evt.Type {
 	switch evt.Type {
 	case gomatrix.StateRoomName:
 	case gomatrix.StateRoomName:
-		resp, err = user.Conn.UpdateGroupSubject(evt.Content.Name, portal.JID)
+		resp, err = user.Conn.UpdateGroupSubject(evt.Content.Name, portal.Key.JID)
 	case gomatrix.StateRoomAvatar:
 	case gomatrix.StateRoomAvatar:
 		return
 		return
 	case gomatrix.StateTopic:
 	case gomatrix.StateTopic:
@@ -140,7 +140,7 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) {
 
 
 func (mx *MatrixHandler) HandleMessage(evt *gomatrix.Event) {
 func (mx *MatrixHandler) HandleMessage(evt *gomatrix.Event) {
 	roomID := types.MatrixRoomID(evt.RoomID)
 	roomID := types.MatrixRoomID(evt.RoomID)
-	user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender))
+	user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
 
 
 	if !user.Whitelisted {
 	if !user.Whitelisted {
 		return
 		return
@@ -158,8 +158,12 @@ func (mx *MatrixHandler) HandleMessage(evt *gomatrix.Event) {
 		}
 		}
 	}
 	}
 
 
-	portal := user.GetPortalByMXID(roomID)
+	if !user.IsLoggedIn() {
+		return
+	}
+
+	portal := mx.bridge.GetPortalByMXID(roomID)
 	if portal != nil {
 	if portal != nil {
-		portal.HandleMatrixMessage(evt)
+		portal.HandleMatrixMessage(user, evt)
 	}
 	}
 }
 }

+ 96 - 109
portal.go

@@ -20,7 +20,6 @@ import (
 	"bytes"
 	"bytes"
 	"encoding/hex"
 	"encoding/hex"
 	"fmt"
 	"fmt"
-	"html"
 	"image"
 	"image"
 	"image/gif"
 	"image/gif"
 	"image/jpeg"
 	"image/jpeg"
@@ -41,57 +40,56 @@ import (
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 )
 )
 
 
-func (user *User) GetPortalByMXID(mxid types.MatrixRoomID) *Portal {
-	user.portalsLock.Lock()
-	defer user.portalsLock.Unlock()
-	portal, ok := user.portalsByMXID[mxid]
+func (bridge *Bridge) GetPortalByMXID(mxid types.MatrixRoomID) *Portal {
+	bridge.portalsLock.Lock()
+	defer bridge.portalsLock.Unlock()
+	portal, ok := bridge.portalsByMXID[mxid]
 	if !ok {
 	if !ok {
-		dbPortal := user.bridge.DB.Portal.GetByMXID(mxid)
-		if dbPortal == nil || dbPortal.Owner != user.ID {
+		dbPortal := bridge.DB.Portal.GetByMXID(mxid)
+		if dbPortal == nil {
 			return nil
 			return nil
 		}
 		}
-		portal = user.NewPortal(dbPortal)
-		user.portalsByJID[portal.JID] = portal
+		portal = bridge.NewPortal(dbPortal)
+		bridge.portalsByJID[portal.Key] = portal
 		if len(portal.MXID) > 0 {
 		if len(portal.MXID) > 0 {
-			user.portalsByMXID[portal.MXID] = portal
+			bridge.portalsByMXID[portal.MXID] = portal
 		}
 		}
 	}
 	}
 	return portal
 	return portal
 }
 }
 
 
-func (user *User) GetPortalByJID(jid types.WhatsAppID) *Portal {
-	user.portalsLock.Lock()
-	defer user.portalsLock.Unlock()
-	portal, ok := user.portalsByJID[jid]
+func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal {
+	bridge.portalsLock.Lock()
+	defer bridge.portalsLock.Unlock()
+	portal, ok := bridge.portalsByJID[key]
 	if !ok {
 	if !ok {
-		dbPortal := user.bridge.DB.Portal.GetByJID(user.ID, jid)
+		dbPortal := bridge.DB.Portal.GetByJID(key)
 		if dbPortal == nil {
 		if dbPortal == nil {
-			dbPortal = user.bridge.DB.Portal.New()
-			dbPortal.JID = jid
-			dbPortal.Owner = user.ID
+			dbPortal = bridge.DB.Portal.New()
+			dbPortal.Key = key
 			dbPortal.Insert()
 			dbPortal.Insert()
 		}
 		}
-		portal = user.NewPortal(dbPortal)
-		user.portalsByJID[portal.JID] = portal
+		portal = bridge.NewPortal(dbPortal)
+		bridge.portalsByJID[portal.Key] = portal
 		if len(portal.MXID) > 0 {
 		if len(portal.MXID) > 0 {
-			user.portalsByMXID[portal.MXID] = portal
+			bridge.portalsByMXID[portal.MXID] = portal
 		}
 		}
 	}
 	}
 	return portal
 	return portal
 }
 }
 
 
-func (user *User) GetAllPortals() []*Portal {
-	user.portalsLock.Lock()
-	defer user.portalsLock.Unlock()
-	dbPortals := user.bridge.DB.Portal.GetAll(user.ID)
+func (bridge *Bridge) GetAllPortals() []*Portal {
+	bridge.portalsLock.Lock()
+	defer bridge.portalsLock.Unlock()
+	dbPortals := bridge.DB.Portal.GetAll()
 	output := make([]*Portal, len(dbPortals))
 	output := make([]*Portal, len(dbPortals))
 	for index, dbPortal := range dbPortals {
 	for index, dbPortal := range dbPortals {
-		portal, ok := user.portalsByJID[dbPortal.JID]
+		portal, ok := bridge.portalsByJID[dbPortal.Key]
 		if !ok {
 		if !ok {
-			portal = user.NewPortal(dbPortal)
-			user.portalsByJID[dbPortal.JID] = portal
+			portal = bridge.NewPortal(dbPortal)
+			bridge.portalsByJID[portal.Key] = portal
 			if len(dbPortal.MXID) > 0 {
 			if len(dbPortal.MXID) > 0 {
-				user.portalsByMXID[dbPortal.MXID] = portal
+				bridge.portalsByMXID[dbPortal.MXID] = portal
 			}
 			}
 		}
 		}
 		output[index] = portal
 		output[index] = portal
@@ -99,19 +97,17 @@ func (user *User) GetAllPortals() []*Portal {
 	return output
 	return output
 }
 }
 
 
-func (user *User) NewPortal(dbPortal *database.Portal) *Portal {
+func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
 	return &Portal{
 	return &Portal{
 		Portal: dbPortal,
 		Portal: dbPortal,
-		user:   user,
-		bridge: user.bridge,
-		log:    user.log.Sub(fmt.Sprintf("Portal/%s", dbPortal.JID)),
+		bridge: bridge,
+		log:    bridge.Log.Sub(fmt.Sprintf("Portal/%s", dbPortal.Key)),
 	}
 	}
 }
 }
 
 
 type Portal struct {
 type Portal struct {
 	*database.Portal
 	*database.Portal
 
 
-	user   *User
 	bridge *Bridge
 	bridge *Bridge
 	log    log.Logger
 	log    log.Logger
 
 
@@ -126,9 +122,16 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) {
 		changed = true
 		changed = true
 	}
 	}
 	for _, participant := range metadata.Participants {
 	for _, participant := range metadata.Participants {
-		puppet := portal.user.GetPuppetByJID(participant.JID)
+		puppet := portal.bridge.GetPuppetByJID(participant.JID)
 		puppet.Intent().EnsureJoined(portal.MXID)
 		puppet.Intent().EnsureJoined(portal.MXID)
 
 
+		user := portal.bridge.GetUserByJID(participant.JID)
+		if !portal.bridge.AS.StateStore.IsInvited(portal.MXID, user.MXID) {
+			portal.MainIntent().InviteUser(portal.MXID, &gomatrix.ReqInviteUser{
+				UserID: user.MXID,
+			})
+		}
+
 		expectedLevel := 0
 		expectedLevel := 0
 		if participant.IsSuperAdmin {
 		if participant.IsSuperAdmin {
 			expectedLevel = 95
 			expectedLevel = 95
@@ -136,9 +139,8 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) {
 			expectedLevel = 50
 			expectedLevel = 50
 		}
 		}
 		changed = levels.EnsureUserLevel(puppet.MXID, expectedLevel) || changed
 		changed = levels.EnsureUserLevel(puppet.MXID, expectedLevel) || changed
-
-		if participant.JID == portal.user.JID() {
-			changed = levels.EnsureUserLevel(portal.user.ID, expectedLevel) || changed
+		if user != nil {
+			changed = levels.EnsureUserLevel(user.MXID, expectedLevel) || changed
 		}
 		}
 	}
 	}
 	if changed {
 	if changed {
@@ -146,10 +148,10 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) {
 	}
 	}
 }
 }
 
 
-func (portal *Portal) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool {
+func (portal *Portal) UpdateAvatar(user *User, avatar *whatsappExt.ProfilePicInfo) bool {
 	if avatar == nil {
 	if avatar == nil {
 		var err error
 		var err error
-		avatar, err = portal.user.Conn.GetProfilePicThumb(portal.JID)
+		avatar, err = user.Conn.GetProfilePicThumb(portal.Key.JID)
 		if err != nil {
 		if err != nil {
 			portal.log.Errorln(err)
 			portal.log.Errorln(err)
 			return false
 			return false
@@ -184,7 +186,7 @@ func (portal *Portal) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool {
 
 
 func (portal *Portal) UpdateName(name string, setBy types.WhatsAppID) bool {
 func (portal *Portal) UpdateName(name string, setBy types.WhatsAppID) bool {
 	if portal.Name != name {
 	if portal.Name != name {
-		intent := portal.user.GetPuppetByJID(setBy).Intent()
+		intent := portal.bridge.GetPuppetByJID(setBy).Intent()
 		_, err := intent.SetRoomName(portal.MXID, name)
 		_, err := intent.SetRoomName(portal.MXID, name)
 		if err == nil {
 		if err == nil {
 			portal.Name = name
 			portal.Name = name
@@ -197,7 +199,7 @@ func (portal *Portal) UpdateName(name string, setBy types.WhatsAppID) bool {
 
 
 func (portal *Portal) UpdateTopic(topic string, setBy types.WhatsAppID) bool {
 func (portal *Portal) UpdateTopic(topic string, setBy types.WhatsAppID) bool {
 	if portal.Topic != topic {
 	if portal.Topic != topic {
-		intent := portal.user.GetPuppetByJID(setBy).Intent()
+		intent := portal.bridge.GetPuppetByJID(setBy).Intent()
 		_, err := intent.SetRoomTopic(portal.MXID, topic)
 		_, err := intent.SetRoomTopic(portal.MXID, topic)
 		if err == nil {
 		if err == nil {
 			portal.Topic = topic
 			portal.Topic = topic
@@ -208,8 +210,8 @@ func (portal *Portal) UpdateTopic(topic string, setBy types.WhatsAppID) bool {
 	return false
 	return false
 }
 }
 
 
-func (portal *Portal) UpdateMetadata() bool {
-	metadata, err := portal.user.Conn.GetGroupMetaData(portal.JID)
+func (portal *Portal) UpdateMetadata(user *User) bool {
+	metadata, err := user.Conn.GetGroupMetaData(portal.Key.JID)
 	if err != nil {
 	if err != nil {
 		portal.log.Errorln(err)
 		portal.log.Errorln(err)
 		return false
 		return false
@@ -221,25 +223,23 @@ func (portal *Portal) UpdateMetadata() bool {
 	return update
 	return update
 }
 }
 
 
-func (portal *Portal) Sync(contact whatsapp.Contact) {
+func (portal *Portal) Sync(user *User, contact whatsapp.Contact) {
+	if portal.IsPrivateChat() {
+		return
+	}
+
 	if len(portal.MXID) == 0 {
 	if len(portal.MXID) == 0 {
-		if !portal.IsPrivateChat() {
-			portal.Name = contact.Name
-		}
-		err := portal.CreateMatrixRoom()
+		portal.Name = contact.Name
+		err := portal.CreateMatrixRoom([]string{user.MXID})
 		if err != nil {
 		if err != nil {
 			portal.log.Errorln("Failed to create portal room:", err)
 			portal.log.Errorln("Failed to create portal room:", err)
 			return
 			return
 		}
 		}
 	}
 	}
 
 
-	if portal.IsPrivateChat() {
-		return
-	}
-
 	update := false
 	update := false
-	update = portal.UpdateMetadata() || update
-	update = portal.UpdateAvatar(nil) || update
+	update = portal.UpdateMetadata(user) || update
+	update = portal.UpdateAvatar(user, nil) || update
 	if update {
 	if update {
 		portal.Update()
 		portal.Update()
 	}
 	}
@@ -277,11 +277,12 @@ func (portal *Portal) ChangeAdminStatus(jids []string, setAdmin bool) {
 	}
 	}
 	changed := false
 	changed := false
 	for _, jid := range jids {
 	for _, jid := range jids {
-		puppet := portal.user.GetPuppetByJID(jid)
+		puppet := portal.bridge.GetPuppetByJID(jid)
 		changed = levels.EnsureUserLevel(puppet.MXID, newLevel) || changed
 		changed = levels.EnsureUserLevel(puppet.MXID, newLevel) || changed
 
 
-		if jid == portal.user.JID() {
-			changed = levels.EnsureUserLevel(portal.user.ID, newLevel) || changed
+		user := portal.bridge.GetUserByJID(jid)
+		if user != nil {
+			changed = levels.EnsureUserLevel(user.MXID, newLevel) || changed
 		}
 		}
 	}
 	}
 	if changed {
 	if changed {
@@ -312,15 +313,15 @@ func (portal *Portal) RestrictMetadataChanges(restrict bool) {
 		newLevel = 50
 		newLevel = 50
 	}
 	}
 	changed := false
 	changed := false
-	changed = levels.EnsureEventLevel(gomatrix.StateRoomName, true, newLevel) || changed
-	changed = levels.EnsureEventLevel(gomatrix.StateRoomAvatar, true, newLevel) || changed
-	changed = levels.EnsureEventLevel(gomatrix.StateTopic, true, newLevel) || changed
+	changed = levels.EnsureEventLevel(gomatrix.StateRoomName, newLevel) || changed
+	changed = levels.EnsureEventLevel(gomatrix.StateRoomAvatar, newLevel) || changed
+	changed = levels.EnsureEventLevel(gomatrix.StateTopic, newLevel) || changed
 	if changed {
 	if changed {
 		portal.MainIntent().SetPowerLevels(portal.MXID, levels)
 		portal.MainIntent().SetPowerLevels(portal.MXID, levels)
 	}
 	}
 }
 }
 
 
-func (portal *Portal) CreateMatrixRoom() error {
+func (portal *Portal) CreateMatrixRoom(invite []string) error {
 	portal.roomCreateLock.Lock()
 	portal.roomCreateLock.Lock()
 	defer portal.roomCreateLock.Unlock()
 	defer portal.roomCreateLock.Unlock()
 	if len(portal.MXID) > 0 {
 	if len(portal.MXID) > 0 {
@@ -330,7 +331,6 @@ func (portal *Portal) CreateMatrixRoom() error {
 	name := portal.Name
 	name := portal.Name
 	topic := portal.Topic
 	topic := portal.Topic
 	isPrivateChat := false
 	isPrivateChat := false
-	invite := []string{portal.user.ID}
 	if portal.IsPrivateChat() {
 	if portal.IsPrivateChat() {
 		name = ""
 		name = ""
 		topic = "WhatsApp private chat"
 		topic = "WhatsApp private chat"
@@ -360,18 +360,18 @@ func (portal *Portal) CreateMatrixRoom() error {
 }
 }
 
 
 func (portal *Portal) IsPrivateChat() bool {
 func (portal *Portal) IsPrivateChat() bool {
-	return strings.HasSuffix(portal.JID, whatsappExt.NewUserSuffix)
+	return strings.HasSuffix(portal.Key.JID, whatsappExt.NewUserSuffix)
 }
 }
 
 
 func (portal *Portal) MainIntent() *appservice.IntentAPI {
 func (portal *Portal) MainIntent() *appservice.IntentAPI {
 	if portal.IsPrivateChat() {
 	if portal.IsPrivateChat() {
-		return portal.user.GetPuppetByJID(portal.JID).Intent()
+		return portal.bridge.GetPuppetByJID(portal.Key.JID).Intent()
 	}
 	}
-	return portal.bridge.AppService.BotIntent()
+	return portal.bridge.AS.BotIntent()
 }
 }
 
 
 func (portal *Portal) IsDuplicate(id types.WhatsAppMessageID) bool {
 func (portal *Portal) IsDuplicate(id types.WhatsAppMessageID) bool {
-	msg := portal.bridge.DB.Message.GetByJID(portal.Owner, id)
+	msg := portal.bridge.DB.Message.GetByJID(portal.Key, id)
 	if msg != nil {
 	if msg != nil {
 		return true
 		return true
 	}
 	}
@@ -380,7 +380,7 @@ func (portal *Portal) IsDuplicate(id types.WhatsAppMessageID) bool {
 
 
 func (portal *Portal) MarkHandled(jid types.WhatsAppMessageID, mxid types.MatrixEventID) {
 func (portal *Portal) MarkHandled(jid types.WhatsAppMessageID, mxid types.MatrixEventID) {
 	msg := portal.bridge.DB.Message.New()
 	msg := portal.bridge.DB.Message.New()
-	msg.Owner = portal.Owner
+	msg.Chat = portal.Key
 	msg.JID = jid
 	msg.JID = jid
 	msg.MXID = mxid
 	msg.MXID = mxid
 	msg.Insert()
 	msg.Insert()
@@ -392,7 +392,7 @@ func (portal *Portal) GetMessageIntent(info whatsapp.MessageInfo) *appservice.In
 			// TODO handle own messages in private chats properly
 			// TODO handle own messages in private chats properly
 			return nil
 			return nil
 		}
 		}
-		return portal.user.GetPuppetByJID(portal.user.JID()).Intent()
+		return portal.bridge.GetPuppetByJID(portal.Key.Receiver).Intent()
 	} else if portal.IsPrivateChat() {
 	} else if portal.IsPrivateChat() {
 		return portal.MainIntent()
 		return portal.MainIntent()
 	} else if len(info.SenderJid) == 0 {
 	} else if len(info.SenderJid) == 0 {
@@ -402,14 +402,14 @@ func (portal *Portal) GetMessageIntent(info whatsapp.MessageInfo) *appservice.In
 			return nil
 			return nil
 		}
 		}
 	}
 	}
-	return portal.user.GetPuppetByJID(info.SenderJid).Intent()
+	return portal.bridge.GetPuppetByJID(info.SenderJid).Intent()
 }
 }
 
 
 func (portal *Portal) SetReply(content *gomatrix.Content, info whatsapp.MessageInfo) {
 func (portal *Portal) SetReply(content *gomatrix.Content, info whatsapp.MessageInfo) {
 	if len(info.QuotedMessageID) == 0 {
 	if len(info.QuotedMessageID) == 0 {
 		return
 		return
 	}
 	}
-	message := portal.bridge.DB.Message.GetByJID(portal.Owner, info.QuotedMessageID)
+	message := portal.bridge.DB.Message.GetByJID(portal.Key, info.QuotedMessageID)
 	if message != nil {
 	if message != nil {
 		event, err := portal.MainIntent().GetEvent(portal.MXID, message.MXID)
 		event, err := portal.MainIntent().GetEvent(portal.MXID, message.MXID)
 		if err != nil {
 		if err != nil {
@@ -421,29 +421,12 @@ func (portal *Portal) SetReply(content *gomatrix.Content, info whatsapp.MessageI
 	return
 	return
 }
 }
 
 
-func (portal *Portal) FormatWhatsAppMessage(content *gomatrix.Content) {
-	output := html.EscapeString(content.Body)
-	for regex, replacement := range portal.user.waReplString {
-		output = regex.ReplaceAllString(output, replacement)
-	}
-	for regex, replacer := range portal.user.waReplFunc {
-		output = regex.ReplaceAllStringFunc(output, replacer)
-	}
-	if output != content.Body {
-		content.FormattedBody = output
-		content.Format = gomatrix.FormatHTML
-		for regex, replacer := range portal.user.waReplFuncText {
-			content.Body = regex.ReplaceAllStringFunc(content.Body, replacer)
-		}
-	}
-}
-
-func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) {
+func (portal *Portal) HandleTextMessage(source *User, message whatsapp.TextMessage) {
 	if portal.IsDuplicate(message.Info.Id) {
 	if portal.IsDuplicate(message.Info.Id) {
 		return
 		return
 	}
 	}
 
 
-	err := portal.CreateMatrixRoom()
+	err := portal.CreateMatrixRoom([]string{source.MXID})
 	if err != nil {
 	if err != nil {
 		portal.log.Errorln("Failed to create portal room:", err)
 		portal.log.Errorln("Failed to create portal room:", err)
 		return
 		return
@@ -459,7 +442,7 @@ func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) {
 		MsgType: gomatrix.MsgText,
 		MsgType: gomatrix.MsgText,
 	}
 	}
 
 
-	portal.FormatWhatsAppMessage(content)
+	portal.bridge.Formatter.ParseWhatsApp(content)
 	portal.SetReply(content, message.Info)
 	portal.SetReply(content, message.Info)
 
 
 	intent.UserTyping(portal.MXID, false, 0)
 	intent.UserTyping(portal.MXID, false, 0)
@@ -472,12 +455,12 @@ func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) {
 	portal.log.Debugln("Handled message", message.Info.Id, "->", resp.EventID)
 	portal.log.Debugln("Handled message", message.Info.Id, "->", resp.EventID)
 }
 }
 
 
-func (portal *Portal) HandleMediaMessage(download func() ([]byte, error), thumbnail []byte, info whatsapp.MessageInfo, mimeType, caption string) {
+func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte, error), thumbnail []byte, info whatsapp.MessageInfo, mimeType, caption string) {
 	if portal.IsDuplicate(info.Id) {
 	if portal.IsDuplicate(info.Id) {
 		return
 		return
 	}
 	}
 
 
-	err := portal.CreateMatrixRoom()
+	err := portal.CreateMatrixRoom([]string{source.MXID})
 	if err != nil {
 	if err != nil {
 		portal.log.Errorln("Failed to create portal room:", err)
 		portal.log.Errorln("Failed to create portal room:", err)
 		return
 		return
@@ -559,7 +542,7 @@ func (portal *Portal) HandleMediaMessage(download func() ([]byte, error), thumbn
 			MsgType: gomatrix.MsgNotice,
 			MsgType: gomatrix.MsgNotice,
 		}
 		}
 
 
-		portal.FormatWhatsAppMessage(captionContent)
+		portal.bridge.Formatter.ParseWhatsApp(captionContent)
 
 
 		_, err := intent.SendMassagedMessageEvent(portal.MXID, gomatrix.EventMessage, captionContent, ts)
 		_, err := intent.SendMassagedMessageEvent(portal.MXID, gomatrix.EventMessage, captionContent, ts)
 		if err != nil {
 		if err != nil {
@@ -612,7 +595,7 @@ func (portal *Portal) downloadThumbnail(evt *gomatrix.Event) []byte {
 	return buf.Bytes()
 	return buf.Bytes()
 }
 }
 
 
-func (portal *Portal) preprocessMatrixMedia(evt *gomatrix.Event, mediaType whatsapp.MediaType) *MediaUpload {
+func (portal *Portal) preprocessMatrixMedia(sender *User, evt *gomatrix.Event, mediaType whatsapp.MediaType) *MediaUpload {
 	if evt.Content.Info == nil {
 	if evt.Content.Info == nil {
 		evt.Content.Info = &gomatrix.FileInfo{}
 		evt.Content.Info = &gomatrix.FileInfo{}
 	}
 	}
@@ -630,7 +613,7 @@ func (portal *Portal) preprocessMatrixMedia(evt *gomatrix.Event, mediaType whats
 		return nil
 		return nil
 	}
 	}
 
 
-	url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := portal.user.Conn.Upload(bytes.NewReader(content), mediaType)
+	url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := sender.Conn.Upload(bytes.NewReader(content), mediaType)
 	if err != nil {
 	if err != nil {
 		portal.log.Errorfln("Failed to upload media in %s: %v", evt.ID, err)
 		portal.log.Errorfln("Failed to upload media in %s: %v", evt.ID, err)
 		return nil
 		return nil
@@ -657,8 +640,8 @@ type MediaUpload struct {
 	Thumbnail     []byte
 	Thumbnail     []byte
 }
 }
 
 
-func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessageInfo {
-	node, err := portal.user.Conn.LoadMessagesBefore(portal.JID, jid, 1)
+func (portal *Portal) GetMessage(user *User, jid types.WhatsAppMessageID) *waProto.WebMessageInfo {
+	node, err := user.Conn.LoadMessagesBefore(portal.Key.JID, jid, 1)
 	if err != nil {
 	if err != nil {
 		return nil
 		return nil
 	}
 	}
@@ -670,7 +653,7 @@ func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessag
 	if !ok {
 	if !ok {
 		return nil
 		return nil
 	}
 	}
-	node, err = portal.user.Conn.LoadMessagesAfter(portal.JID, msg.GetKey().GetId(), 1)
+	node, err = user.Conn.LoadMessagesAfter(portal.Key.JID, msg.GetKey().GetId(), 1)
 	if err != nil {
 	if err != nil {
 		return nil
 		return nil
 	}
 	}
@@ -682,7 +665,11 @@ func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessag
 	return msg
 	return msg
 }
 }
 
 
-func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
+func (portal *Portal) HandleMatrixMessage(sender *User, evt *gomatrix.Event) {
+	if portal.IsPrivateChat() && sender.JID != portal.Key.Receiver {
+		return
+	}
+
 	ts := uint64(evt.Timestamp / 1000)
 	ts := uint64(evt.Timestamp / 1000)
 	status := waProto.WebMessageInfo_ERROR
 	status := waProto.WebMessageInfo_ERROR
 	fromMe := true
 	fromMe := true
@@ -690,7 +677,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 		Key: &waProto.MessageKey{
 		Key: &waProto.MessageKey{
 			FromMe:    &fromMe,
 			FromMe:    &fromMe,
 			Id:        makeMessageID(),
 			Id:        makeMessageID(),
-			RemoteJid: &portal.JID,
+			RemoteJid: &portal.Key.JID,
 		},
 		},
 		MessageTimestamp: &ts,
 		MessageTimestamp: &ts,
 		Message:          &waProto.Message{},
 		Message:          &waProto.Message{},
@@ -702,12 +689,12 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 		evt.Content.RemoveReplyFallback()
 		evt.Content.RemoveReplyFallback()
 		msg := portal.bridge.DB.Message.GetByMXID(replyToID)
 		msg := portal.bridge.DB.Message.GetByMXID(replyToID)
 		if msg != nil {
 		if msg != nil {
-			origMsg := portal.GetMessage(msg.JID)
+			origMsg := portal.GetMessage(sender, msg.JID)
 			if origMsg != nil {
 			if origMsg != nil {
 				ctxInfo.StanzaId = &msg.JID
 				ctxInfo.StanzaId = &msg.JID
 				replyMsgSender := origMsg.GetParticipant()
 				replyMsgSender := origMsg.GetParticipant()
 				if origMsg.GetKey().GetFromMe() {
 				if origMsg.GetKey().GetFromMe() {
-					replyMsgSender = portal.user.JID()
+					replyMsgSender = sender.JID
 				}
 				}
 				ctxInfo.Participant = &replyMsgSender
 				ctxInfo.Participant = &replyMsgSender
 				ctxInfo.QuotedMessage = []*waProto.Message{origMsg.Message}
 				ctxInfo.QuotedMessage = []*waProto.Message{origMsg.Message}
@@ -719,7 +706,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 	case gomatrix.MsgText, gomatrix.MsgEmote:
 	case gomatrix.MsgText, gomatrix.MsgEmote:
 		text := evt.Content.Body
 		text := evt.Content.Body
 		if evt.Content.Format == gomatrix.FormatHTML {
 		if evt.Content.Format == gomatrix.FormatHTML {
-			text = portal.user.htmlParser.Parse(evt.Content.FormattedBody)
+			text = portal.bridge.Formatter.ParseMatrix(evt.Content.FormattedBody)
 		}
 		}
 		if evt.Content.MsgType == gomatrix.MsgEmote {
 		if evt.Content.MsgType == gomatrix.MsgEmote {
 			text = "/me " + text
 			text = "/me " + text
@@ -737,7 +724,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 			info.Message.Conversation = &text
 			info.Message.Conversation = &text
 		}
 		}
 	case gomatrix.MsgImage:
 	case gomatrix.MsgImage:
-		media := portal.preprocessMatrixMedia(evt, whatsapp.MediaImage)
+		media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaImage)
 		if media == nil {
 		if media == nil {
 			return
 			return
 		}
 		}
@@ -752,7 +739,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 			FileLength:    &media.FileLength,
 			FileLength:    &media.FileLength,
 		}
 		}
 	case gomatrix.MsgVideo:
 	case gomatrix.MsgVideo:
-		media := portal.preprocessMatrixMedia(evt, whatsapp.MediaVideo)
+		media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaVideo)
 		if media == nil {
 		if media == nil {
 			return
 			return
 		}
 		}
@@ -769,7 +756,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 			FileLength:    &media.FileLength,
 			FileLength:    &media.FileLength,
 		}
 		}
 	case gomatrix.MsgAudio:
 	case gomatrix.MsgAudio:
-		media := portal.preprocessMatrixMedia(evt, whatsapp.MediaAudio)
+		media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaAudio)
 		if media == nil {
 		if media == nil {
 			return
 			return
 		}
 		}
@@ -784,7 +771,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 			FileLength:    &media.FileLength,
 			FileLength:    &media.FileLength,
 		}
 		}
 	case gomatrix.MsgFile:
 	case gomatrix.MsgFile:
-		media := portal.preprocessMatrixMedia(evt, whatsapp.MediaDocument)
+		media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaDocument)
 		if media == nil {
 		if media == nil {
 			return
 			return
 		}
 		}
@@ -800,7 +787,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
 		portal.log.Debugln("Unhandled Matrix event:", evt)
 		portal.log.Debugln("Unhandled Matrix event:", evt)
 		return
 		return
 	}
 	}
-	err = portal.user.Conn.Send(info)
+	err = sender.Conn.Send(info)
 	portal.MarkHandled(info.GetKey().GetId(), evt.ID)
 	portal.MarkHandled(info.GetKey().GetId(), evt.ID)
 	if err != nil {
 	if err != nil {
 		portal.log.Errorfln("Error handling Matrix event %s: %v", evt.ID, err)
 		portal.log.Errorfln("Error handling Matrix event %s: %v", evt.ID, err)

+ 35 - 57
puppet.go

@@ -30,105 +30,83 @@ import (
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 )
 )
 
 
-func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.MatrixUserID, types.WhatsAppID, bool) {
+func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.WhatsAppID, bool) {
 	userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$",
 	userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$",
-		bridge.Config.Bridge.FormatUsername("(.+)", "([0-9]+)"),
+		bridge.Config.Bridge.FormatUsername("([0-9]+)"),
 		bridge.Config.Homeserver.Domain))
 		bridge.Config.Homeserver.Domain))
 	if err != nil {
 	if err != nil {
 		bridge.Log.Warnln("Failed to compile puppet user ID regex:", err)
 		bridge.Log.Warnln("Failed to compile puppet user ID regex:", err)
-		return "", "", false
+		return "", false
 	}
 	}
 	match := userIDRegex.FindStringSubmatch(string(mxid))
 	match := userIDRegex.FindStringSubmatch(string(mxid))
-	if match == nil || len(match) != 3 {
-		return "", "", false
+	if match == nil || len(match) != 2 {
+		return "", false
 	}
 	}
 
 
-	receiver := types.MatrixUserID(match[1])
-	receiver = strings.Replace(receiver, "=40", "@", 1)
-	colonIndex := strings.LastIndex(receiver, "=3")
-	receiver = receiver[:colonIndex] + ":" + receiver[colonIndex+len("=3"):]
 	jid := types.WhatsAppID(match[2] + whatsappExt.NewUserSuffix)
 	jid := types.WhatsAppID(match[2] + whatsappExt.NewUserSuffix)
-	return receiver, jid, true
+	return jid, true
 }
 }
 
 
 func (bridge *Bridge) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet {
 func (bridge *Bridge) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet {
-	receiver, jid, ok := bridge.ParsePuppetMXID(mxid)
+	jid, ok := bridge.ParsePuppetMXID(mxid)
 	if !ok {
 	if !ok {
 		return nil
 		return nil
 	}
 	}
 
 
-	user := bridge.GetUser(receiver)
-	if user == nil {
-		return nil
-	}
-
-	return user.GetPuppetByJID(jid)
-}
-
-func (user *User) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet {
-	receiver, jid, ok := user.bridge.ParsePuppetMXID(mxid)
-	if !ok || receiver != user.ID {
-		return nil
-	}
-
-	return user.GetPuppetByJID(jid)
+	return bridge.GetPuppetByJID(jid)
 }
 }
 
 
-func (user *User) GetPuppetByJID(jid types.WhatsAppID) *Puppet {
-	user.puppetsLock.Lock()
-	defer user.puppetsLock.Unlock()
-	puppet, ok := user.puppets[jid]
+func (bridge *Bridge) GetPuppetByJID(jid types.WhatsAppID) *Puppet {
+	bridge.puppetsLock.Lock()
+	defer bridge.puppetsLock.Unlock()
+	puppet, ok := bridge.puppets[jid]
 	if !ok {
 	if !ok {
-		dbPuppet := user.bridge.DB.Puppet.Get(jid, user.ID)
+		dbPuppet := bridge.DB.Puppet.Get(jid)
 		if dbPuppet == nil {
 		if dbPuppet == nil {
-			dbPuppet = user.bridge.DB.Puppet.New()
+			dbPuppet = bridge.DB.Puppet.New()
 			dbPuppet.JID = jid
 			dbPuppet.JID = jid
-			dbPuppet.Receiver = user.ID
 			dbPuppet.Insert()
 			dbPuppet.Insert()
 		}
 		}
-		puppet = user.NewPuppet(dbPuppet)
-		user.puppets[puppet.JID] = puppet
+		puppet = bridge.NewPuppet(dbPuppet)
+		bridge.puppets[puppet.JID] = puppet
 	}
 	}
 	return puppet
 	return puppet
 }
 }
 
 
-func (user *User) GetAllPuppets() []*Puppet {
-	user.puppetsLock.Lock()
-	defer user.puppetsLock.Unlock()
-	dbPuppets := user.bridge.DB.Puppet.GetAll(user.ID)
+func (bridge *Bridge) GetAllPuppets() []*Puppet {
+	bridge.puppetsLock.Lock()
+	defer bridge.puppetsLock.Unlock()
+	dbPuppets := bridge.DB.Puppet.GetAll()
 	output := make([]*Puppet, len(dbPuppets))
 	output := make([]*Puppet, len(dbPuppets))
 	for index, dbPuppet := range dbPuppets {
 	for index, dbPuppet := range dbPuppets {
-		puppet, ok := user.puppets[dbPuppet.JID]
+		puppet, ok := bridge.puppets[dbPuppet.JID]
 		if !ok {
 		if !ok {
-			puppet = user.NewPuppet(dbPuppet)
-			user.puppets[dbPuppet.JID] = puppet
+			puppet = bridge.NewPuppet(dbPuppet)
+			bridge.puppets[dbPuppet.JID] = puppet
 		}
 		}
 		output[index] = puppet
 		output[index] = puppet
 	}
 	}
 	return output
 	return output
 }
 }
 
 
-func (user *User) NewPuppet(dbPuppet *database.Puppet) *Puppet {
+func (bridge *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
 	return &Puppet{
 	return &Puppet{
 		Puppet: dbPuppet,
 		Puppet: dbPuppet,
-		user:   user,
-		bridge: user.bridge,
-		log:    user.log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
+		bridge: bridge,
+		log:    bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
 
 
 		MXID: fmt.Sprintf("@%s:%s",
 		MXID: fmt.Sprintf("@%s:%s",
-			user.bridge.Config.Bridge.FormatUsername(
-				dbPuppet.Receiver,
+			bridge.Config.Bridge.FormatUsername(
 				strings.Replace(
 				strings.Replace(
 					dbPuppet.JID,
 					dbPuppet.JID,
 					whatsappExt.NewUserSuffix, "", 1)),
 					whatsappExt.NewUserSuffix, "", 1)),
-			user.bridge.Config.Homeserver.Domain),
+			bridge.Config.Homeserver.Domain),
 	}
 	}
 }
 }
 
 
 type Puppet struct {
 type Puppet struct {
 	*database.Puppet
 	*database.Puppet
 
 
-	user   *User
 	bridge *Bridge
 	bridge *Bridge
 	log    log.Logger
 	log    log.Logger
 
 
@@ -143,13 +121,13 @@ func (puppet *Puppet) PhoneNumber() string {
 }
 }
 
 
 func (puppet *Puppet) Intent() *appservice.IntentAPI {
 func (puppet *Puppet) Intent() *appservice.IntentAPI {
-	return puppet.bridge.AppService.Intent(puppet.MXID)
+	return puppet.bridge.AS.Intent(puppet.MXID)
 }
 }
 
 
-func (puppet *Puppet) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool {
+func (puppet *Puppet) UpdateAvatar(source *User, avatar *whatsappExt.ProfilePicInfo) bool {
 	if avatar == nil {
 	if avatar == nil {
 		var err error
 		var err error
-		avatar, err = puppet.user.Conn.GetProfilePicThumb(puppet.JID)
+		avatar, err = source.Conn.GetProfilePicThumb(puppet.JID)
 		if err != nil {
 		if err != nil {
 			puppet.log.Errorln(err)
 			puppet.log.Errorln(err)
 			return false
 			return false
@@ -184,11 +162,11 @@ func (puppet *Puppet) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool {
 	return true
 	return true
 }
 }
 
 
-func (puppet *Puppet) Sync(contact whatsapp.Contact) {
+func (puppet *Puppet) Sync(source *User, contact whatsapp.Contact) {
 	puppet.Intent().EnsureRegistered()
 	puppet.Intent().EnsureRegistered()
 
 
-	if contact.Jid == puppet.user.JID() {
-		contact.Notify = puppet.user.Conn.Info.Pushname
+	if contact.Jid == source.JID {
+		contact.Notify = source.Conn.Info.Pushname
 	}
 	}
 	newName := puppet.bridge.Config.Bridge.FormatDisplayname(contact)
 	newName := puppet.bridge.Config.Bridge.FormatDisplayname(contact)
 	if puppet.Displayname != newName {
 	if puppet.Displayname != newName {
@@ -201,7 +179,7 @@ func (puppet *Puppet) Sync(contact whatsapp.Contact) {
 		}
 		}
 	}
 	}
 
 
-	if puppet.UpdateAvatar(nil) {
+	if puppet.UpdateAvatar(source, nil) {
 		puppet.Update()
 		puppet.Update()
 	}
 	}
 }
 }

+ 70 - 62
user.go

@@ -17,14 +17,11 @@
 package main
 package main
 
 
 import (
 import (
-	"regexp"
 	"strings"
 	"strings"
-	"sync"
 	"time"
 	"time"
 
 
 	"github.com/Rhymen/go-whatsapp"
 	"github.com/Rhymen/go-whatsapp"
 	"github.com/skip2/go-qrcode"
 	"github.com/skip2/go-qrcode"
-	"maunium.net/go/gomatrix/format"
 	log "maunium.net/go/maulogger"
 	log "maunium.net/go/maulogger"
 	"maunium.net/go/mautrix-whatsapp/database"
 	"maunium.net/go/mautrix-whatsapp/database"
 	"maunium.net/go/mautrix-whatsapp/types"
 	"maunium.net/go/mautrix-whatsapp/types"
@@ -41,31 +38,42 @@ type User struct {
 	Admin       bool
 	Admin       bool
 	Whitelisted bool
 	Whitelisted bool
 	jid         string
 	jid         string
+}
 
 
-	portalsByMXID map[types.MatrixRoomID]*Portal
-	portalsByJID  map[types.WhatsAppID]*Portal
-	portalsLock   sync.Mutex
-	puppets       map[types.WhatsAppID]*Puppet
-	puppetsLock   sync.Mutex
-
-	htmlParser *format.HTMLParser
-
-	waReplString   map[*regexp.Regexp]string
-	waReplFunc     map[*regexp.Regexp]func(string) string
-	waReplFuncText map[*regexp.Regexp]func(string) string
+func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User {
+	bridge.usersLock.Lock()
+	defer bridge.usersLock.Unlock()
+	user, ok := bridge.usersByMXID[userID]
+	if !ok {
+		dbUser := bridge.DB.User.GetByMXID(userID)
+		if dbUser == nil {
+			dbUser = bridge.DB.User.New()
+			dbUser.MXID = userID
+			dbUser.Insert()
+		}
+		user = bridge.NewUser(dbUser)
+		bridge.usersByMXID[user.MXID] = user
+		if len(user.ManagementRoom) > 0 {
+			bridge.managementRooms[user.ManagementRoom] = user
+		}
+	}
+	return user
 }
 }
 
 
-func (bridge *Bridge) GetUser(userID types.MatrixUserID) *User {
-	user, ok := bridge.users[userID]
+
+func (bridge *Bridge) GetUserByJID(userID types.WhatsAppID) *User {
+	bridge.usersLock.Lock()
+	defer bridge.usersLock.Unlock()
+	user, ok := bridge.usersByJID[userID]
 	if !ok {
 	if !ok {
-		dbUser := bridge.DB.User.Get(userID)
+		dbUser := bridge.DB.User.GetByMXID(userID)
 		if dbUser == nil {
 		if dbUser == nil {
 			dbUser = bridge.DB.User.New()
 			dbUser = bridge.DB.User.New()
-			dbUser.ID = userID
+			dbUser.MXID = userID
 			dbUser.Insert()
 			dbUser.Insert()
 		}
 		}
 		user = bridge.NewUser(dbUser)
 		user = bridge.NewUser(dbUser)
-		bridge.users[user.ID] = user
+		bridge.usersByJID[user.JID] = user
 		if len(user.ManagementRoom) > 0 {
 		if len(user.ManagementRoom) > 0 {
 			bridge.managementRooms[user.ManagementRoom] = user
 			bridge.managementRooms[user.ManagementRoom] = user
 		}
 		}
@@ -74,13 +82,15 @@ func (bridge *Bridge) GetUser(userID types.MatrixUserID) *User {
 }
 }
 
 
 func (bridge *Bridge) GetAllUsers() []*User {
 func (bridge *Bridge) GetAllUsers() []*User {
+	bridge.usersLock.Lock()
+	defer bridge.usersLock.Unlock()
 	dbUsers := bridge.DB.User.GetAll()
 	dbUsers := bridge.DB.User.GetAll()
 	output := make([]*User, len(dbUsers))
 	output := make([]*User, len(dbUsers))
 	for index, dbUser := range dbUsers {
 	for index, dbUser := range dbUsers {
-		user, ok := bridge.users[dbUser.ID]
+		user, ok := bridge.usersByMXID[dbUser.MXID]
 		if !ok {
 		if !ok {
 			user = bridge.NewUser(dbUser)
 			user = bridge.NewUser(dbUser)
-			bridge.users[user.ID] = user
+			bridge.usersByMXID[user.MXID] = user
 			if len(user.ManagementRoom) > 0 {
 			if len(user.ManagementRoom) > 0 {
 				bridge.managementRooms[user.ManagementRoom] = user
 				bridge.managementRooms[user.ManagementRoom] = user
 			}
 			}
@@ -94,15 +104,10 @@ func (bridge *Bridge) NewUser(dbUser *database.User) *User {
 	user := &User{
 	user := &User{
 		User:          dbUser,
 		User:          dbUser,
 		bridge:        bridge,
 		bridge:        bridge,
-		log:           bridge.Log.Sub("User").Sub(string(dbUser.ID)),
-		portalsByMXID: make(map[types.MatrixRoomID]*Portal),
-		portalsByJID:  make(map[types.WhatsAppID]*Portal),
-		puppets:       make(map[types.WhatsAppID]*Puppet),
+		log:           bridge.Log.Sub("User").Sub(string(dbUser.MXID)),
 	}
 	}
-	user.Whitelisted = user.bridge.Config.Bridge.Permissions.IsWhitelisted(user.ID)
-	user.Admin = user.bridge.Config.Bridge.Permissions.IsAdmin(user.ID)
-	user.htmlParser = user.newHTMLParser()
-	user.waReplString, user.waReplFunc, user.waReplFuncText = user.newWhatsAppFormatMaps()
+	user.Whitelisted = user.bridge.Config.Bridge.Permissions.IsWhitelisted(user.MXID)
+	user.Admin = user.bridge.Config.Bridge.Permissions.IsAdmin(user.MXID)
 	return user
 	return user
 }
 }
 
 
@@ -152,7 +157,6 @@ func (user *User) RestoreSession() bool {
 		sess, err := user.Conn.RestoreSession(*user.Session)
 		sess, err := user.Conn.RestoreSession(*user.Session)
 		if err != nil {
 		if err != nil {
 			user.log.Errorln("Failed to restore session:", err)
 			user.log.Errorln("Failed to restore session:", err)
-			//user.SetSession(nil)
 			return false
 			return false
 		}
 		}
 		user.SetSession(&sess)
 		user.SetSession(&sess)
@@ -162,8 +166,12 @@ func (user *User) RestoreSession() bool {
 	return false
 	return false
 }
 }
 
 
+func (user *User) IsLoggedIn() bool {
+	return user.Conn != nil
+}
+
 func (user *User) Login(roomID types.MatrixRoomID) {
 func (user *User) Login(roomID types.MatrixRoomID) {
-	bot := user.bridge.AppService.BotClient()
+	bot := user.bridge.AS.BotClient()
 
 
 	qrChan := make(chan string, 2)
 	qrChan := make(chan string, 2)
 	go func() {
 	go func() {
@@ -194,38 +202,24 @@ func (user *User) Login(roomID types.MatrixRoomID) {
 		qrChan <- "error"
 		qrChan <- "error"
 		return
 		return
 	}
 	}
+	user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
 	user.Session = &session
 	user.Session = &session
 	user.Update()
 	user.Update()
 	bot.SendNotice(roomID, "Successfully logged in. Synchronizing chats...")
 	bot.SendNotice(roomID, "Successfully logged in. Synchronizing chats...")
 	go user.Sync()
 	go user.Sync()
 }
 }
 
 
-func (user *User) JID() string {
-	if user.Conn == nil {
-		return ""
-	}
-	if len(user.jid) == 0 {
-		user.jid = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
-	}
-	return user.jid
-}
-
 func (user *User) Sync() {
 func (user *User) Sync() {
 	user.log.Debugln("Syncing...")
 	user.log.Debugln("Syncing...")
 	user.Conn.Contacts()
 	user.Conn.Contacts()
 	for jid, contact := range user.Conn.Store.Contacts {
 	for jid, contact := range user.Conn.Store.Contacts {
 		if strings.HasSuffix(jid, whatsappExt.NewUserSuffix) {
 		if strings.HasSuffix(jid, whatsappExt.NewUserSuffix) {
-			puppet := user.GetPuppetByJID(contact.Jid)
-			puppet.Sync(contact)
-		}
-
-		if len(contact.Notify) == 0 && !strings.HasSuffix(jid, "@g.us") {
-			// No messages sent -> don't bridge
-			continue
+			puppet := user.bridge.GetPuppetByJID(contact.Jid)
+			puppet.Sync(user, contact)
+		} else {
+			portal := user.bridge.GetPortalByJID(database.GroupPortalKey(contact.Jid))
+			portal.Sync(user, contact)
 		}
 		}
-
-		portal := user.GetPortalByJID(contact.Jid)
-		portal.Sync(contact)
 	}
 	}
 }
 }
 
 
@@ -237,33 +231,41 @@ func (user *User) HandleJSONParseError(err error) {
 	user.log.Errorln("WhatsApp JSON parse error:", err)
 	user.log.Errorln("WhatsApp JSON parse error:", err)
 }
 }
 
 
+func (user *User) PortalKey(jid types.WhatsAppID) database.PortalKey {
+	return database.NewPortalKey(jid, user.JID)
+}
+
+func (user *User) GetPortalByJID(jid types.WhatsAppID) *Portal {
+	return user.bridge.GetPortalByJID(user.PortalKey(jid))
+}
+
 func (user *User) HandleTextMessage(message whatsapp.TextMessage) {
 func (user *User) HandleTextMessage(message whatsapp.TextMessage) {
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
-	portal.HandleTextMessage(message)
+	portal.HandleTextMessage(user, message)
 }
 }
 
 
 func (user *User) HandleImageMessage(message whatsapp.ImageMessage) {
 func (user *User) HandleImageMessage(message whatsapp.ImageMessage) {
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
-	portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Caption)
+	portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Caption)
 }
 }
 
 
 func (user *User) HandleVideoMessage(message whatsapp.VideoMessage) {
 func (user *User) HandleVideoMessage(message whatsapp.VideoMessage) {
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
-	portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Caption)
+	portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Caption)
 }
 }
 
 
 func (user *User) HandleAudioMessage(message whatsapp.AudioMessage) {
 func (user *User) HandleAudioMessage(message whatsapp.AudioMessage) {
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
-	portal.HandleMediaMessage(message.Download, nil, message.Info, message.Type, "")
+	portal.HandleMediaMessage(user, message.Download, nil, message.Info, message.Type, "")
 }
 }
 
 
 func (user *User) HandleDocumentMessage(message whatsapp.DocumentMessage) {
 func (user *User) HandleDocumentMessage(message whatsapp.DocumentMessage) {
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
 	portal := user.GetPortalByJID(message.Info.RemoteJid)
-	portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Title)
+	portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Title)
 }
 }
 
 
 func (user *User) HandlePresence(info whatsappExt.Presence) {
 func (user *User) HandlePresence(info whatsappExt.Presence) {
-	puppet := user.GetPuppetByJID(info.SenderJID)
+	puppet := user.bridge.GetPuppetByJID(info.SenderJID)
 	switch info.Status {
 	switch info.Status {
 	case whatsappExt.PresenceUnavailable:
 	case whatsappExt.PresenceUnavailable:
 		puppet.Intent().SetPresence("offline")
 		puppet.Intent().SetPresence("offline")
@@ -277,6 +279,12 @@ func (user *User) HandlePresence(info whatsappExt.Presence) {
 		}
 		}
 	case whatsappExt.PresenceComposing:
 	case whatsappExt.PresenceComposing:
 		portal := user.GetPortalByJID(info.JID)
 		portal := user.GetPortalByJID(info.JID)
+		if len(puppet.typingIn) > 0 && puppet.typingAt+15 > time.Now().Unix() {
+			if puppet.typingIn == portal.MXID {
+				return
+			}
+			puppet.Intent().UserTyping(puppet.typingIn, false, 0)
+		}
 		puppet.typingIn = portal.MXID
 		puppet.typingIn = portal.MXID
 		puppet.typingAt = time.Now().Unix()
 		puppet.typingAt = time.Now().Unix()
 		puppet.Intent().UserTyping(portal.MXID, true, 15*1000)
 		puppet.Intent().UserTyping(portal.MXID, true, 15*1000)
@@ -290,9 +298,9 @@ func (user *User) HandleMsgInfo(info whatsappExt.MsgInfo) {
 			return
 			return
 		}
 		}
 
 
-		intent := user.GetPuppetByJID(info.SenderJID).Intent()
+		intent := user.bridge.GetPuppetByJID(info.SenderJID).Intent()
 		for _, id := range info.IDs {
 		for _, id := range info.IDs {
-			msg := user.bridge.DB.Message.GetByJID(user.ID, id)
+			msg := user.bridge.DB.Message.GetByJID(portal.Key, id)
 			if msg == nil {
 			if msg == nil {
 				continue
 				continue
 			}
 			}
@@ -308,11 +316,11 @@ func (user *User) HandleCommand(cmd whatsappExt.Command) {
 	switch cmd.Type {
 	switch cmd.Type {
 	case whatsappExt.CommandPicture:
 	case whatsappExt.CommandPicture:
 		if strings.HasSuffix(cmd.JID, whatsappExt.NewUserSuffix) {
 		if strings.HasSuffix(cmd.JID, whatsappExt.NewUserSuffix) {
-			puppet := user.GetPuppetByJID(cmd.JID)
-			puppet.UpdateAvatar(cmd.ProfilePicInfo)
+			puppet := user.bridge.GetPuppetByJID(cmd.JID)
+			puppet.UpdateAvatar(user, cmd.ProfilePicInfo)
 		} else {
 		} else {
 			portal := user.GetPortalByJID(cmd.JID)
 			portal := user.GetPortalByJID(cmd.JID)
-			portal.UpdateAvatar(cmd.ProfilePicInfo)
+			portal.UpdateAvatar(user, cmd.ProfilePicInfo)
 		}
 		}
 	}
 	}
 }
 }

+ 9 - 9
vendor/maunium.net/go/gomatrix/client.go

@@ -463,7 +463,7 @@ func (cli *Client) SetAvatarURL(url string) (err error) {
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 func (cli *Client) SendMessageEvent(roomID string, eventType EventType, contentJSON interface{}) (resp *RespSendEvent, err error) {
 func (cli *Client) SendMessageEvent(roomID string, eventType EventType, contentJSON interface{}) (resp *RespSendEvent, err error) {
 	txnID := txnID()
 	txnID := txnID()
-	urlPath := cli.BuildURL("rooms", roomID, "send", string(eventType), txnID)
+	urlPath := cli.BuildURL("rooms", roomID, "send", eventType.String(), txnID)
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
 	return
 	return
 }
 }
@@ -472,7 +472,7 @@ func (cli *Client) SendMessageEvent(roomID string, eventType EventType, contentJ
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 func (cli *Client) SendMassagedMessageEvent(roomID string, eventType EventType, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
 func (cli *Client) SendMassagedMessageEvent(roomID string, eventType EventType, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
 	txnID := txnID()
 	txnID := txnID()
-	urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "send", string(eventType), txnID}, map[string]string{
+	urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "send", eventType.String(), txnID}, map[string]string{
 		"ts": strconv.FormatInt(ts, 10),
 		"ts": strconv.FormatInt(ts, 10),
 	})
 	})
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
@@ -482,7 +482,7 @@ func (cli *Client) SendMassagedMessageEvent(roomID string, eventType EventType,
 // SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey
 // SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 func (cli *Client) SendStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) {
 func (cli *Client) SendStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) {
-	urlPath := cli.BuildURL("rooms", roomID, "state", string(eventType), stateKey)
+	urlPath := cli.BuildURL("rooms", roomID, "state", eventType.String(), stateKey)
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
 	return
 	return
 }
 }
@@ -490,7 +490,7 @@ func (cli *Client) SendStateEvent(roomID string, eventType EventType, stateKey s
 // SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey
 // SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
 func (cli *Client) SendMassagedStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
 func (cli *Client) SendMassagedStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
-	urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "state", string(eventType), stateKey}, map[string]string{
+	urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{
 		"ts": strconv.FormatInt(ts, 10),
 		"ts": strconv.FormatInt(ts, 10),
 	})
 	})
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
 	_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
@@ -500,7 +500,7 @@ func (cli *Client) SendMassagedStateEvent(roomID string, eventType EventType, st
 // SendText sends an m.room.message event into the given room with a msgtype of m.text
 // SendText sends an m.room.message event into the given room with a msgtype of m.text
 // See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-text
 // See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-text
 func (cli *Client) SendText(roomID, text string) (*RespSendEvent, error) {
 func (cli *Client) SendText(roomID, text string) (*RespSendEvent, error) {
-	return cli.SendMessageEvent(roomID, "m.room.message", Content{
+	return cli.SendMessageEvent(roomID, EventMessage, Content{
 		MsgType: MsgText,
 		MsgType: MsgText,
 		Body:    text,
 		Body:    text,
 	})
 	})
@@ -509,7 +509,7 @@ func (cli *Client) SendText(roomID, text string) (*RespSendEvent, error) {
 // SendImage sends an m.room.message event into the given room with a msgtype of m.image
 // SendImage sends an m.room.message event into the given room with a msgtype of m.image
 // See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-image
 // See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-image
 func (cli *Client) SendImage(roomID, body, url string) (*RespSendEvent, error) {
 func (cli *Client) SendImage(roomID, body, url string) (*RespSendEvent, error) {
-	return cli.SendMessageEvent(roomID, "m.room.message", Content{
+	return cli.SendMessageEvent(roomID, EventMessage, Content{
 		MsgType: MsgImage,
 		MsgType: MsgImage,
 		Body:    body,
 		Body:    body,
 		URL:     url,
 		URL:     url,
@@ -519,7 +519,7 @@ func (cli *Client) SendImage(roomID, body, url string) (*RespSendEvent, error) {
 // SendVideo sends an m.room.message event into the given room with a msgtype of m.video
 // SendVideo sends an m.room.message event into the given room with a msgtype of m.video
 // See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-video
 // See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-video
 func (cli *Client) SendVideo(roomID, body, url string) (*RespSendEvent, error) {
 func (cli *Client) SendVideo(roomID, body, url string) (*RespSendEvent, error) {
-	return cli.SendMessageEvent(roomID, "m.room.message", Content{
+	return cli.SendMessageEvent(roomID, EventMessage, Content{
 		MsgType: MsgVideo,
 		MsgType: MsgVideo,
 		Body:    body,
 		Body:    body,
 		URL:     url,
 		URL:     url,
@@ -529,7 +529,7 @@ func (cli *Client) SendVideo(roomID, body, url string) (*RespSendEvent, error) {
 // SendNotice sends an m.room.message event into the given room with a msgtype of m.notice
 // SendNotice sends an m.room.message event into the given room with a msgtype of m.notice
 // See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-notice
 // See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-notice
 func (cli *Client) SendNotice(roomID, text string) (*RespSendEvent, error) {
 func (cli *Client) SendNotice(roomID, text string) (*RespSendEvent, error) {
-	return cli.SendMessageEvent(roomID, "m.room.message", Content{
+	return cli.SendMessageEvent(roomID, EventMessage, Content{
 		MsgType: MsgNotice,
 		MsgType: MsgNotice,
 		Body:    text,
 		Body:    text,
 	})
 	})
@@ -622,7 +622,7 @@ func (cli *Client) SetPresence(status string) (err error) {
 // the HTTP response body, or return an error.
 // the HTTP response body, or return an error.
 // See http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-rooms-roomid-state-eventtype-statekey
 // See http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-rooms-roomid-state-eventtype-statekey
 func (cli *Client) StateEvent(roomID string, eventType EventType, stateKey string, outContent interface{}) (err error) {
 func (cli *Client) StateEvent(roomID string, eventType EventType, stateKey string, outContent interface{}) (err error) {
-	u := cli.BuildURL("rooms", roomID, "state", string(eventType), stateKey)
+	u := cli.BuildURL("rooms", roomID, "state", eventType.String(), stateKey)
 	_, err = cli.MakeRequest("GET", u, nil, outContent)
 	_, err = cli.MakeRequest("GET", u, nil, outContent)
 	return
 	return
 }
 }

+ 39 - 23
vendor/maunium.net/go/gomatrix/events.go

@@ -5,28 +5,44 @@ import (
 	"sync"
 	"sync"
 )
 )
 
 
-type EventType string
+type EventType struct {
+	Type    string
+	IsState bool
+}
+
+func (et *EventType) UnmarshalJSON(data []byte) error {
+	return json.Unmarshal(data, &et.Type)
+}
+
+func (et *EventType) MarshalJSON() ([]byte, error) {
+	return json.Marshal(&et.Type)
+}
+
+func (et *EventType) String() string {
+	return et.Type
+}
+
 type MessageType string
 type MessageType string
 
 
 // State events
 // State events
-const (
-	StateAliases        EventType = "m.room.aliases"
-	StateCanonicalAlias           = "m.room.canonical_alias"
-	StateCreate                   = "m.room.create"
-	StateJoinRules                = "m.room.join_rules"
-	StateMember                   = "m.room.member"
-	StatePowerLevels              = "m.room.power_levels"
-	StateRoomName                 = "m.room.name"
-	StateTopic                    = "m.room.topic"
-	StateRoomAvatar               = "m.room.avatar"
-	StatePinnedEvents             = "m.room.pinned_events"
+var (
+	StateAliases        = EventType{"m.room.aliases", true}
+	StateCanonicalAlias = EventType{"m.room.canonical_alias", true}
+	StateCreate         = EventType{"m.room.create", true}
+	StateJoinRules      = EventType{"m.room.join_rules", true}
+	StateMember         = EventType{"m.room.member", true}
+	StatePowerLevels    = EventType{"m.room.power_levels", true}
+	StateRoomName       = EventType{"m.room.name", true}
+	StateTopic          = EventType{"m.room.topic", true}
+	StateRoomAvatar     = EventType{"m.room.avatar", true}
+	StatePinnedEvents   = EventType{"m.room.pinned_events", true}
 )
 )
 
 
 // Message events
 // Message events
-const (
-	EventRedaction EventType = "m.room.redaction"
-	EventMessage             = "m.room.message"
-	EventSticker             = "m.sticker"
+var (
+	EventRedaction = EventType{"m.room.redaction", false}
+	EventMessage   = EventType{"m.room.message", false}
+	EventSticker   = EventType{"m.sticker", false}
 )
 )
 
 
 // Msgtypes
 // Msgtypes
@@ -258,12 +274,12 @@ func (pl *PowerLevels) EnsureUserLevel(userID string, level int) bool {
 	return false
 	return false
 }
 }
 
 
-func (pl *PowerLevels) GetEventLevel(eventType EventType, isState bool) int {
+func (pl *PowerLevels) GetEventLevel(eventType EventType) int {
 	pl.eventsLock.RLock()
 	pl.eventsLock.RLock()
 	defer pl.eventsLock.RUnlock()
 	defer pl.eventsLock.RUnlock()
 	level, ok := pl.Events[eventType]
 	level, ok := pl.Events[eventType]
 	if !ok {
 	if !ok {
-		if isState {
+		if eventType.IsState {
 			return pl.StateDefault()
 			return pl.StateDefault()
 		}
 		}
 		return pl.EventsDefault
 		return pl.EventsDefault
@@ -271,20 +287,20 @@ func (pl *PowerLevels) GetEventLevel(eventType EventType, isState bool) int {
 	return level
 	return level
 }
 }
 
 
-func (pl *PowerLevels) SetEventLevel(eventType EventType, isState bool, level int) {
+func (pl *PowerLevels) SetEventLevel(eventType EventType, level int) {
 	pl.eventsLock.Lock()
 	pl.eventsLock.Lock()
 	defer pl.eventsLock.Unlock()
 	defer pl.eventsLock.Unlock()
-	if (isState && level == pl.StateDefault()) || (!isState && level == pl.EventsDefault) {
+	if (eventType.IsState && level == pl.StateDefault()) || (!eventType.IsState && level == pl.EventsDefault) {
 		delete(pl.Events, eventType)
 		delete(pl.Events, eventType)
 	} else {
 	} else {
 		pl.Events[eventType] = level
 		pl.Events[eventType] = level
 	}
 	}
 }
 }
 
 
-func (pl *PowerLevels) EnsureEventLevel(eventType EventType, isState bool, level int) bool {
-	existingLevel := pl.GetEventLevel(eventType, isState)
+func (pl *PowerLevels) EnsureEventLevel(eventType EventType, level int) bool {
+	existingLevel := pl.GetEventLevel(eventType)
 	if existingLevel != level {
 	if existingLevel != level {
-		pl.SetEventLevel(eventType, isState, level)
+		pl.SetEventLevel(eventType, level)
 		return true
 		return true
 	}
 	}
 	return false
 	return false

+ 19 - 8
vendor/maunium.net/go/mautrix-appservice/appservice.go

@@ -2,17 +2,19 @@ package appservice
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"html/template"
 	"io/ioutil"
 	"io/ioutil"
 	"os"
 	"os"
+	"path/filepath"
 
 
 	"gopkg.in/yaml.v2"
 	"gopkg.in/yaml.v2"
 
 
-	"maunium.net/go/maulogger"
-	"strings"
-	"net/http"
 	"errors"
 	"errors"
 	"maunium.net/go/gomatrix"
 	"maunium.net/go/gomatrix"
+	"maunium.net/go/maulogger"
+	"net/http"
 	"regexp"
 	"regexp"
+	"strings"
 )
 )
 
 
 // EventChannelSize is the size for the Events channel in Appservice instances.
 // EventChannelSize is the size for the Events channel in Appservice instances.
@@ -263,15 +265,24 @@ func CreateLogConfig() LogConfig {
 	}
 	}
 }
 }
 
 
+type FileFormatData struct {
+	Date string
+	Index int
+}
+
 // GetFileFormat returns a mauLogger-compatible logger file format based on the data in the struct.
 // GetFileFormat returns a mauLogger-compatible logger file format based on the data in the struct.
 func (lc LogConfig) GetFileFormat() maulogger.LoggerFileFormat {
 func (lc LogConfig) GetFileFormat() maulogger.LoggerFileFormat {
-	path := lc.FileNameFormat
-	if len(lc.Directory) > 0 {
-		path = lc.Directory + "/" + path
-	}
+	os.MkdirAll(lc.Directory, 0700)
+	path := filepath.Join(lc.Directory, lc.FileNameFormat)
+	tpl, _ := template.New("fileformat").Parse(path)
 
 
 	return func(now string, i int) string {
 	return func(now string, i int) string {
-		return fmt.Sprintf(path, now, i)
+		var buf strings.Builder
+		tpl.Execute(&buf, FileFormatData{
+			Date: now,
+			Index: i,
+		})
+		return buf.String()
 	}
 	}
 }
 }
 
 

+ 3 - 3
vendor/maunium.net/go/mautrix-appservice/intent.go

@@ -201,19 +201,19 @@ func (intent *IntentAPI) RedactEvent(roomID, eventID string, req *gomatrix.ReqRe
 }
 }
 
 
 func (intent *IntentAPI) SetRoomName(roomID, roomName string) (*gomatrix.RespSendEvent, error) {
 func (intent *IntentAPI) SetRoomName(roomID, roomName string) (*gomatrix.RespSendEvent, error) {
-	return intent.SendStateEvent(roomID, "m.room.name", "", map[string]interface{}{
+	return intent.SendStateEvent(roomID, gomatrix.StateRoomName, "", map[string]interface{}{
 		"name": roomName,
 		"name": roomName,
 	})
 	})
 }
 }
 
 
 func (intent *IntentAPI) SetRoomAvatar(roomID, avatarURL string) (*gomatrix.RespSendEvent, error) {
 func (intent *IntentAPI) SetRoomAvatar(roomID, avatarURL string) (*gomatrix.RespSendEvent, error) {
-	return intent.SendStateEvent(roomID, "m.room.avatar", "", map[string]interface{}{
+	return intent.SendStateEvent(roomID, gomatrix.StateRoomAvatar, "", map[string]interface{}{
 		"url": avatarURL,
 		"url": avatarURL,
 	})
 	})
 }
 }
 
 
 func (intent *IntentAPI) SetRoomTopic(roomID, topic string) (*gomatrix.RespSendEvent, error) {
 func (intent *IntentAPI) SetRoomTopic(roomID, topic string) (*gomatrix.RespSendEvent, error) {
-	return intent.SendStateEvent(roomID, "m.room.topic", "", map[string]interface{}{
+	return intent.SendStateEvent(roomID, gomatrix.StateTopic, "", map[string]interface{}{
 		"topic": topic,
 		"topic": topic,
 	})
 	})
 }
 }

+ 23 - 16
vendor/maunium.net/go/mautrix-appservice/statestore.go

@@ -15,13 +15,15 @@ type StateStore interface {
 	SetTyping(roomID, userID string, timeout int64)
 	SetTyping(roomID, userID string, timeout int64)
 
 
 	IsInRoom(roomID, userID string) bool
 	IsInRoom(roomID, userID string) bool
+	IsInvited(roomID, userID string) bool
+	IsMembership(roomID, userID string, allowedMemberships ...string) bool
 	SetMembership(roomID, userID, membership string)
 	SetMembership(roomID, userID, membership string)
 
 
 	SetPowerLevels(roomID string, levels *gomatrix.PowerLevels)
 	SetPowerLevels(roomID string, levels *gomatrix.PowerLevels)
 	GetPowerLevels(roomID string) *gomatrix.PowerLevels
 	GetPowerLevels(roomID string) *gomatrix.PowerLevels
 	GetPowerLevel(roomID, userID string) int
 	GetPowerLevel(roomID, userID string) int
-	GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType, isState bool) int
-	HasPowerLevel(roomID, userID string, eventType gomatrix.EventType, isState bool) bool
+	GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType) int
+	HasPowerLevel(roomID, userID string, eventType gomatrix.EventType) bool
 }
 }
 
 
 func (as *AppService) UpdateState(evt *gomatrix.Event) {
 func (as *AppService) UpdateState(evt *gomatrix.Event) {
@@ -126,7 +128,21 @@ func (store *BasicStateStore) GetMembership(roomID, userID string) string {
 }
 }
 
 
 func (store *BasicStateStore) IsInRoom(roomID, userID string) bool {
 func (store *BasicStateStore) IsInRoom(roomID, userID string) bool {
-	return store.GetMembership(roomID, userID) == "join"
+	return store.IsMembership(roomID, userID, "join")
+}
+
+func (store *BasicStateStore) IsInvited(roomID, userID string) bool {
+	return store.IsMembership(roomID, userID, "join", "invite")
+}
+
+func (store *BasicStateStore) IsMembership(roomID, userID string, allowedMemberships ...string) bool {
+	membership := store.GetMembership(roomID, userID)
+	for _, allowedMembership := range allowedMemberships {
+		if allowedMembership == membership {
+			return true
+		}
+	}
+	return false
 }
 }
 
 
 func (store *BasicStateStore) SetMembership(roomID, userID, membership string) {
 func (store *BasicStateStore) SetMembership(roomID, userID, membership string) {
@@ -160,19 +176,10 @@ func (store *BasicStateStore) GetPowerLevel(roomID, userID string) int {
 	return store.GetPowerLevels(roomID).GetUserLevel(userID)
 	return store.GetPowerLevels(roomID).GetUserLevel(userID)
 }
 }
 
 
-func (store *BasicStateStore) GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType, isState bool) int {
-	levels := store.GetPowerLevels(roomID)
-	switch eventType {
-	case "kick":
-		return levels.Kick()
-	case "invite":
-		return levels.Invite()
-	case "redact":
-		return levels.Redact()
-	}
-	return levels.GetEventLevel(eventType, isState)
+func (store *BasicStateStore) GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType) int {
+	return store.GetPowerLevels(roomID).GetEventLevel(eventType)
 }
 }
 
 
-func (store *BasicStateStore) HasPowerLevel(roomID, userID string, eventType gomatrix.EventType, isState bool) bool {
-	return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType, isState)
+func (store *BasicStateStore) HasPowerLevel(roomID, userID string, eventType gomatrix.EventType) bool {
+	return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
 }
 }