Tulir Asokan 5 жил өмнө
parent
commit
8bb5407f98

+ 3 - 0
.editorconfig

@@ -10,3 +10,6 @@ insert_final_newline = true
 
 [*.{yaml,yml}]
 indent_style = space
+
+[.gitlab-ci.yml]
+indent_size = 2

+ 49 - 19
.gitlab-ci.yml

@@ -3,16 +3,15 @@ stages:
 - build docker
 - manifest
 
-build:
+.build: &build
   image: golang:1-alpine
   stage: build
-  tags:
-  - amd64
   cache:
     paths:
     - .cache
   before_script:
-  - apk add git build-base
+  - echo "@edge_community http://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
+  - apk add build-base olm-dev@edge_community
   - mkdir -p .cache
   - export GOPATH="$CI_PROJECT_DIR/.cache"
   script:
@@ -22,31 +21,62 @@ build:
     - mautrix-whatsapp
     - example-config.yaml
 
-build docker amd64:
+.build-docker: &build-docker
   image: docker:stable
   stage: build docker
-  tags:
-  - amd64
   before_script:
   - docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
   script:
   - docker pull $CI_REGISTRY_IMAGE:latest || true
-  - docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 . --file Dockerfile.ci
-  - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64
-  - docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64
+  - docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-$DOCKER_ARCH . --file Dockerfile.ci
+  - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-$DOCKER_ARCH
+  - docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-$DOCKER_ARCH
 
-build docker arm64:
-  image: docker:stable
-  stage: build docker
+build static amd64:
+  image: golang:1-alpine
+  stage: build
   tags:
-  - arm64
+  - amd64
+  cache:
+    paths:
+    - .cache
   before_script:
-  - docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
+  - mkdir -p .cache
+  - export GOPATH="$CI_PROJECT_DIR/.cache"
   script:
-  - docker pull $CI_REGISTRY_IMAGE:latest || true
-  - docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 .
-  - docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64
-  - docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64
+  - CGO_ENABLED=0 go build -o mautrix-whatsapp
+  artifacts:
+    paths:
+    - mautrix-whatsapp
+    - example-config.yaml
+
+build amd64:
+  <<: *build
+  tags:
+  - amd64
+
+build arm64:
+  <<: *build
+  tags:
+  - arm64
+
+build docker amd64:
+  <<: *build-docker
+  tags:
+  - amd64
+  dependencies:
+  - build amd64
+  variables:
+    DOCKER_ARCH: amd64
+
+build docker arm64:
+  <<: *build-docker
+  tags:
+  - arm64
+  dependencies:
+  - build arm64
+  variables:
+    DOCKER_ARCH: arm64
 
 manifest:
   stage: manifest

+ 5 - 3
Dockerfile

@@ -1,6 +1,7 @@
-FROM golang:1.12-alpine AS builder
+FROM golang:1-alpine AS builder
 
-RUN apk add --no-cache git ca-certificates build-base su-exec
+RUN echo "@edge_community http://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
+RUN apk add --no-cache git ca-certificates build-base su-exec olm-dev@edge_community
 
 WORKDIR /build
 COPY go.mod go.sum /build/
@@ -14,7 +15,8 @@ FROM alpine:latest
 ENV UID=1337 \
     GID=1337
 
-RUN apk add --no-cache su-exec ca-certificates
+RUN echo "@edge_community http://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
+RUN apk add --no-cache su-exec ca-certificates olm@edge_community
 
 COPY --from=builder /usr/bin/mautrix-whatsapp /usr/bin/mautrix-whatsapp
 COPY --from=builder /build/example-config.yaml /opt/mautrix-whatsapp/example-config.yaml

+ 2 - 1
Dockerfile.ci

@@ -3,7 +3,8 @@ FROM alpine:latest
 ENV UID=1337 \
     GID=1337
 
-RUN apk add --no-cache su-exec ca-certificates
+RUN echo "@edge_community http://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
+RUN apk add --no-cache su-exec ca-certificates olm@edge_community
 
 ARG EXECUTABLE=./mautrix-whatsapp
 COPY $EXECUTABLE /usr/bin/mautrix-whatsapp

+ 68 - 11
commands.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -18,6 +18,7 @@ package main
 
 import (
 	"fmt"
+	"strconv"
 	"strings"
 
 	"github.com/Rhymen/go-whatsapp"
@@ -25,11 +26,12 @@ import (
 	"maunium.net/go/maulogger/v2"
 
 	"maunium.net/go/mautrix"
-	"maunium.net/go/mautrix-appservice"
+	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/format"
+	"maunium.net/go/mautrix/id"
 
 	"maunium.net/go/mautrix-whatsapp/database"
-	"maunium.net/go/mautrix-whatsapp/types"
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 )
 
@@ -51,7 +53,7 @@ type CommandEvent struct {
 	Bot     *appservice.IntentAPI
 	Bridge  *Bridge
 	Handler *CommandHandler
-	RoomID  types.MatrixRoomID
+	RoomID  id.RoomID
 	User    *User
 	Command string
 	Args    []string
@@ -59,20 +61,20 @@ type CommandEvent struct {
 
 // Reply sends a reply to command as notice
 func (ce *CommandEvent) Reply(msg string, args ...interface{}) {
-	content := format.RenderMarkdown(fmt.Sprintf(msg, args...))
-	content.MsgType = mautrix.MsgNotice
+	content := format.RenderMarkdown(fmt.Sprintf(msg, args...), true, false)
+	content.MsgType = event.MsgNotice
 	room := ce.User.ManagementRoom
 	if len(room) == 0 {
 		room = ce.RoomID
 	}
-	_, err := ce.Bot.SendMessageEvent(room, mautrix.EventMessage, content)
+	_, err := ce.Bot.SendMessageEvent(room, event.EventMessage, content)
 	if err != nil {
 		ce.Handler.log.Warnfln("Failed to reply to command from %s: %v", ce.User.MXID, err)
 	}
 }
 
 // Handle handles messages to the bridge
-func (handler *CommandHandler) Handle(roomID types.MatrixRoomID, user *User, message string) {
+func (handler *CommandHandler) Handle(roomID id.RoomID, user *User, message string) {
 	args := strings.Split(message, " ")
 	ce := &CommandEvent{
 		Bot:     handler.bridge.Bot,
@@ -117,7 +119,11 @@ func (handler *CommandHandler) CommandMux(ce *CommandEvent) {
 		handler.CommandDeleteAllPortals(ce)
 	case "dev-test":
 		handler.CommandDevTest(ce)
-	case "login-matrix", "logout", "sync", "list", "open", "pm":
+	case "set-pl":
+		handler.CommandSetPowerLevel(ce)
+	case "logout":
+		handler.CommandLogout(ce)
+	case "login-matrix", "sync", "list", "open", "pm":
 		if !ce.User.HasSession() {
 			ce.Reply("You are not logged in. Use the `login` command to log into WhatsApp.")
 			return
@@ -129,8 +135,6 @@ func (handler *CommandHandler) CommandMux(ce *CommandEvent) {
 		switch ce.Command {
 		case "login-matrix":
 			handler.CommandLoginMatrix(ce)
-		case "logout":
-			handler.CommandLogout(ce)
 		case "sync":
 			handler.CommandSync(ce)
 		case "list":
@@ -168,6 +172,45 @@ func (handler *CommandHandler) CommandDevTest(ce *CommandEvent) {
 
 }
 
+func (handler *CommandHandler) CommandSetPowerLevel(ce *CommandEvent) {
+	portal := ce.Bridge.GetPortalByMXID(ce.RoomID)
+	if portal == nil {
+		ce.Reply("Not a portal room")
+		return
+	}
+	var level int
+	var userID id.UserID
+	var err error
+	if len(ce.Args) == 1 {
+		level, err = strconv.Atoi(ce.Args[0])
+		if err != nil {
+			ce.Reply("Invalid power level \"%s\"", ce.Args[0])
+			return
+		}
+		userID = ce.User.MXID
+	} else if len(ce.Args) == 2 {
+		userID = id.UserID(ce.Args[0])
+		_, _, err := userID.Parse()
+		if err != nil {
+			ce.Reply("Invalid user ID \"%s\"", ce.Args[0])
+			return
+		}
+		level, err = strconv.Atoi(ce.Args[1])
+		if err != nil {
+			ce.Reply("Invalid power level \"%s\"", ce.Args[1])
+			return
+		}
+	} else {
+		ce.Reply("**Usage:** `set-pl [user] <level>`")
+		return
+	}
+	intent := portal.MainIntent()
+	_, err = intent.SetPowerLevel(ce.RoomID, userID, level)
+	if err != nil {
+		ce.Reply("Failed to set power levels: %v", err)
+	}
+}
+
 const cmdLoginHelp = `login - Authenticate this Bridge as WhatsApp Web Client`
 
 // CommandLogin handles login command
@@ -186,6 +229,16 @@ func (handler *CommandHandler) CommandLogout(ce *CommandEvent) {
 	if ce.User.Session == nil {
 		ce.Reply("You're not logged in.")
 		return
+	} else if !ce.User.IsConnected() {
+		ce.Reply("You are not connected to WhatsApp. Use the `reconnect` command to reconnect, or `delete-session` to forget all login information.")
+		return
+	}
+	puppet := handler.bridge.GetPuppetByJID(ce.User.JID)
+	if puppet.CustomMXID != "" {
+		err := puppet.SwitchCustomMXID("", "")
+		if err != nil {
+			ce.User.log.Warnln("Failed to logout-matrix while logging out of WhatsApp:", err)
+		}
 	}
 	err := ce.User.Conn.Logout()
 	if err != nil {
@@ -199,6 +252,9 @@ func (handler *CommandHandler) CommandLogout(ce *CommandEvent) {
 	}
 	ce.User.Conn.RemoveHandlers()
 	ce.User.Conn = nil
+	ce.User.removeFromJIDMap()
+	// TODO this causes a foreign key violation, which should be fixed
+	//ce.User.JID = ""
 	ce.User.SetSession(nil)
 	ce.Reply("Logged out successfully.")
 }
@@ -516,6 +572,7 @@ func (handler *CommandHandler) CommandOpen(ce *CommandEvent) {
 		portal.Sync(user, contact)
 		ce.Reply("Portal room created.")
 	}
+	_, _ = portal.MainIntent().InviteUser(portal.MXID, &mautrix.ReqInviteUser{UserID: user.MXID})
 }
 
 const cmdPMHelp = `pm [--force] <_international phone number_> - Open a private chat with the given phone number.`

+ 4 - 5
community.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -21,7 +21,6 @@ import (
 	"net/http"
 
 	"maunium.net/go/mautrix"
-	appservice "maunium.net/go/mautrix-appservice"
 )
 
 func (user *User) inviteToCommunity() {
@@ -51,7 +50,7 @@ func (user *User) createCommunity() {
 		return
 	}
 
-	localpart, server := appservice.ParseUserID(user.MXID)
+	localpart, server, _ := user.MXID.Parse()
 	community := user.bridge.Config.Bridge.FormatCommunity(localpart, server)
 	user.log.Debugln("Creating personal filtering community", community)
 	bot := user.bridge.Bot
@@ -100,8 +99,8 @@ func (user *User) addPuppetToCommunity(puppet *Puppet) bool {
 			"type": "private",
 		},
 	}
-	url = bot.BuildURLWithQuery([]string{"groups", user.CommunityID, "self", "accept_invite"}, map[string]string{
-		"user_id": puppet.MXID,
+	url = bot.BuildURLWithQuery(mautrix.URLPath{"groups", user.CommunityID, "self", "accept_invite"}, map[string]string{
+		"user_id": puppet.MXID.String(),
 	})
 	_, err = bot.MakeRequest(http.MethodPut, url, &reqBody, nil)
 	if err != nil {

+ 31 - 26
config/bridge.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -24,8 +24,8 @@ import (
 
 	"github.com/Rhymen/go-whatsapp"
 
-	"maunium.net/go/mautrix"
-	"maunium.net/go/mautrix-appservice"
+	"maunium.net/go/mautrix/event"
+	"maunium.net/go/mautrix/id"
 
 	"maunium.net/go/mautrix-whatsapp/types"
 )
@@ -54,8 +54,8 @@ type BridgeConfig struct {
 	RecoverHistory     bool   `yaml:"recovery_history_backfill"`
 	SyncChatMaxAge     uint64 `yaml:"sync_max_chat_age"`
 
-	SyncWithCustomPuppets bool `yaml:"sync_with_custom_puppets"`
-	LoginSharedSecret string `yaml:"login_shared_secret"`
+	SyncWithCustomPuppets bool   `yaml:"sync_with_custom_puppets"`
+	LoginSharedSecret     string `yaml:"login_shared_secret"`
 
 	InviteOwnPuppetForBackfilling bool `yaml:"invite_own_puppet_for_backfilling"`
 	PrivateChatPortalMeta         bool `yaml:"private_chat_portal_meta"`
@@ -64,6 +64,11 @@ type BridgeConfig struct {
 
 	CommandPrefix string `yaml:"command_prefix"`
 
+	Encryption struct {
+		Allow   bool `yaml:"allow"`
+		Default bool `yaml:"default"`
+	} `yaml:"encryption"`
+
 	Permissions PermissionConfig `yaml:"permissions"`
 
 	Relaybot RelaybotConfig `yaml:"relaybot"`
@@ -127,7 +132,7 @@ func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
 }
 
 type UsernameTemplateArgs struct {
-	UserID string
+	UserID id.UserID
 }
 
 func (bc BridgeConfig) FormatDisplayname(contact whatsapp.Contact) (string, int8) {
@@ -232,25 +237,25 @@ func (pc *PermissionConfig) MarshalYAML() (interface{}, error) {
 	return rawPC, nil
 }
 
-func (pc PermissionConfig) IsRelaybotWhitelisted(userID string) bool {
+func (pc PermissionConfig) IsRelaybotWhitelisted(userID id.UserID) bool {
 	return pc.GetPermissionLevel(userID) >= PermissionLevelRelaybot
 }
 
-func (pc PermissionConfig) IsWhitelisted(userID string) bool {
+func (pc PermissionConfig) IsWhitelisted(userID id.UserID) bool {
 	return pc.GetPermissionLevel(userID) >= PermissionLevelUser
 }
 
-func (pc PermissionConfig) IsAdmin(userID string) bool {
+func (pc PermissionConfig) IsAdmin(userID id.UserID) bool {
 	return pc.GetPermissionLevel(userID) >= PermissionLevelAdmin
 }
 
-func (pc PermissionConfig) GetPermissionLevel(userID string) PermissionLevel {
-	permissions, ok := pc[userID]
+func (pc PermissionConfig) GetPermissionLevel(userID id.UserID) PermissionLevel {
+	permissions, ok := pc[string(userID)]
 	if ok {
 		return permissions
 	}
 
-	_, homeserver := appservice.ParseUserID(userID)
+	_, homeserver, _ := userID.Parse()
 	permissions, ok = pc[homeserver]
 	if len(homeserver) > 0 && ok {
 		return permissions
@@ -265,12 +270,12 @@ func (pc PermissionConfig) GetPermissionLevel(userID string) PermissionLevel {
 }
 
 type RelaybotConfig struct {
-	Enabled        bool                 `yaml:"enabled"`
-	ManagementRoom string               `yaml:"management"`
-	InviteUsers    []types.MatrixUserID `yaml:"invites"`
+	Enabled        bool        `yaml:"enabled"`
+	ManagementRoom id.RoomID   `yaml:"management"`
+	InviteUsers    []id.UserID `yaml:"invites"`
 
-	MessageFormats   map[mautrix.MessageType]string `yaml:"message_formats"`
-	messageTemplates *template.Template             `yaml:"-"`
+	MessageFormats   map[event.MessageType]string `yaml:"message_formats"`
+	messageTemplates *template.Template           `yaml:"-"`
 }
 
 type umRelaybotConfig RelaybotConfig
@@ -293,25 +298,25 @@ func (rc *RelaybotConfig) UnmarshalYAML(unmarshal func(interface{}) error) error
 }
 
 type Sender struct {
-	UserID types.MatrixUserID
-	mautrix.Member
+	UserID id.UserID
+	*event.MemberEventContent
 }
 
 type formatData struct {
 	Sender  Sender
 	Message string
-	Content mautrix.Content
+	Content *event.MessageEventContent
 }
 
-func (rc *RelaybotConfig) FormatMessage(evt *mautrix.Event, member mautrix.Member) (string, error) {
+func (rc *RelaybotConfig) FormatMessage(content *event.MessageEventContent, sender id.UserID, member *event.MemberEventContent) (string, error) {
 	var output strings.Builder
-	err := rc.messageTemplates.ExecuteTemplate(&output, string(evt.Content.MsgType), formatData{
+	err := rc.messageTemplates.ExecuteTemplate(&output, string(content.MsgType), formatData{
 		Sender: Sender{
-			UserID: evt.Sender,
-			Member: member,
+			UserID:             sender,
+			MemberEventContent: member,
 		},
-		Content: evt.Content,
-		Message: evt.Content.FormattedBody,
+		Content: content,
+		Message: content.FormattedBody,
 	})
 	return output.String(), err
 }

+ 1 - 1
config/config.go

@@ -21,7 +21,7 @@ import (
 
 	"gopkg.in/yaml.v2"
 
-	"maunium.net/go/mautrix-appservice"
+	"maunium.net/go/mautrix/appservice"
 )
 
 type Config struct {

+ 1 - 1
config/registration.go

@@ -20,7 +20,7 @@ import (
 	"fmt"
 	"regexp"
 
-	"maunium.net/go/mautrix-appservice"
+	"maunium.net/go/mautrix/appservice"
 )
 
 func (config *Config) NewRegistration() (*appservice.Registration, error) {

+ 242 - 0
crypto.go

@@ -0,0 +1,242 @@
+// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
+// Copyright (C) 2020 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+// +build cgo
+
+package main
+
+import (
+	"crypto/hmac"
+	"crypto/sha512"
+	"encoding/hex"
+	"fmt"
+	"time"
+
+	"github.com/pkg/errors"
+	"maunium.net/go/maulogger/v2"
+
+	"maunium.net/go/mautrix"
+	"maunium.net/go/mautrix-whatsapp/database"
+	"maunium.net/go/mautrix/crypto"
+	"maunium.net/go/mautrix/event"
+	"maunium.net/go/mautrix/id"
+)
+
+var levelTrace = maulogger.Level{
+	Name:     "Trace",
+	Severity: -10,
+	Color:    -1,
+}
+
+type CryptoHelper struct {
+	bridge  *Bridge
+	client  *mautrix.Client
+	mach    *crypto.OlmMachine
+	store   *database.SQLCryptoStore
+	log     maulogger.Logger
+	baseLog maulogger.Logger
+}
+
+func NewCryptoHelper(bridge *Bridge) Crypto {
+	if !bridge.Config.Bridge.Encryption.Allow {
+		bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config")
+		return nil
+	} else if bridge.Config.Bridge.LoginSharedSecret == "" {
+		bridge.Log.Warnln("End-to-bridge encryption enabled, but login_shared_secret not set")
+		return nil
+	}
+	baseLog := bridge.Log.Sub("Crypto")
+	return &CryptoHelper{
+		bridge:  bridge,
+		log:     baseLog.Sub("Helper"),
+		baseLog: baseLog,
+	}
+}
+
+func (helper *CryptoHelper) Init() error {
+	helper.log.Debugln("Initializing end-to-bridge encryption...")
+	var err error
+	helper.client, err = helper.loginBot()
+	if err != nil {
+		return err
+	}
+
+	helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID)
+	logger := &cryptoLogger{helper.baseLog}
+	stateStore := &cryptoStateStore{helper.bridge}
+	helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.client.DeviceID)
+	helper.store.UserID = helper.client.UserID
+	helper.store.GhostIDFormat = fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain)
+	helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore)
+
+	helper.client.Logger = logger.int.Sub("Bot")
+	helper.client.Syncer = &cryptoSyncer{helper.mach}
+	helper.client.Store = &cryptoClientStore{helper.store}
+
+	return helper.mach.Load()
+}
+
+func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) {
+	deviceID := helper.bridge.DB.FindDeviceID()
+	if len(deviceID) > 0 {
+		helper.log.Debugln("Found existing device ID for bot in database:", deviceID)
+	}
+	mac := hmac.New(sha512.New, []byte(helper.bridge.Config.Bridge.LoginSharedSecret))
+	mac.Write([]byte(helper.bridge.AS.BotMXID()))
+	resp, err := helper.bridge.AS.BotClient().Login(&mautrix.ReqLogin{
+		Type:                     "m.login.password",
+		Identifier:               mautrix.UserIdentifier{Type: "m.id.user", User: string(helper.bridge.AS.BotMXID())},
+		Password:                 hex.EncodeToString(mac.Sum(nil)),
+		DeviceID:                 deviceID,
+		InitialDeviceDisplayName: "WhatsApp Bridge",
+	})
+	if err != nil {
+		return nil, err
+	}
+	client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, helper.bridge.AS.BotMXID(), resp.AccessToken)
+	if err != nil {
+		return nil, err
+	}
+	client.DeviceID = resp.DeviceID
+	return client, nil
+}
+
+func (helper *CryptoHelper) Start() {
+	helper.log.Debugln("Starting syncer for receiving to-device messages")
+	err := helper.client.Sync()
+	if err != nil {
+		helper.log.Errorln("Fatal error syncing:", err)
+	}
+}
+
+func (helper *CryptoHelper) Stop() {
+	helper.client.StopSync()
+}
+
+func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
+	return helper.mach.DecryptMegolmEvent(evt)
+}
+
+func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content event.Content) (*event.EncryptedEventContent, error) {
+	encrypted, err := helper.mach.EncryptMegolmEvent(roomID, evtType, content)
+	if err != nil {
+		if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
+			return nil, err
+		}
+		helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID)
+		users, err := helper.store.GetRoomMembers(roomID)
+		if err != nil {
+			return nil, errors.Wrap(err, "failed to get room member list")
+		}
+		err = helper.mach.ShareGroupSession(roomID, users)
+		if err != nil {
+			return nil, errors.Wrap(err, "failed to share group session")
+		}
+		encrypted, err = helper.mach.EncryptMegolmEvent(roomID, evtType, content)
+		if err != nil {
+			return nil, errors.Wrap(err, "failed to encrypt event after re-sharing group session")
+		}
+	}
+	return encrypted, nil
+}
+
+func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) {
+	helper.mach.HandleMemberEvent(evt)
+}
+
+type cryptoSyncer struct {
+	*crypto.OlmMachine
+}
+
+func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error {
+	syncer.ProcessSyncResponse(resp, since)
+	return nil
+}
+
+func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
+	syncer.Log.Error("Error /syncing, waiting 10 seconds: %v", err)
+	return 10 * time.Second, nil
+}
+
+func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
+	everything := []event.Type{{Type: "*"}}
+	return &mautrix.Filter{
+		Presence:    mautrix.FilterPart{NotTypes: everything},
+		AccountData: mautrix.FilterPart{NotTypes: everything},
+		Room: mautrix.RoomFilter{
+			IncludeLeave: false,
+			Ephemeral:    mautrix.FilterPart{NotTypes: everything},
+			AccountData:  mautrix.FilterPart{NotTypes: everything},
+			State:        mautrix.FilterPart{NotTypes: everything},
+			Timeline:     mautrix.FilterPart{NotTypes: everything},
+		},
+	}
+}
+
+type cryptoLogger struct {
+	int maulogger.Logger
+}
+
+func (c *cryptoLogger) Error(message string, args ...interface{}) {
+	c.int.Errorfln(message, args...)
+}
+
+func (c *cryptoLogger) Warn(message string, args ...interface{}) {
+	c.int.Warnfln(message, args...)
+}
+
+func (c *cryptoLogger) Debug(message string, args ...interface{}) {
+	c.int.Debugfln(message, args...)
+}
+
+func (c *cryptoLogger) Trace(message string, args ...interface{}) {
+	c.int.Logfln(levelTrace, message, args...)
+}
+
+type cryptoClientStore struct {
+	int *database.SQLCryptoStore
+}
+
+func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {}
+func (c cryptoClientStore) LoadFilterID(_ id.UserID) string    { return "" }
+func (c cryptoClientStore) SaveRoom(_ *mautrix.Room)           {}
+func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil }
+
+func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
+	c.int.PutNextBatch(nextBatchToken)
+}
+
+func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string {
+	return c.int.GetNextBatch()
+}
+
+var _ mautrix.Storer = (*cryptoClientStore)(nil)
+
+type cryptoStateStore struct {
+	bridge *Bridge
+}
+
+func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
+	portal := c.bridge.GetPortalByMXID(id)
+	if portal != nil {
+		return portal.Encrypted
+	}
+	return false
+}
+
+func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID {
+	return c.bridge.StateStore.FindSharedRooms(id)
+}

+ 60 - 98
custompuppet.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -20,17 +20,16 @@ import (
 	"crypto/hmac"
 	"crypto/sha512"
 	"encoding/hex"
-	"encoding/json"
-	"fmt"
-	"os"
-	"strings"
 	"time"
 
 	"github.com/pkg/errors"
 
 	"github.com/Rhymen/go-whatsapp"
+
 	"maunium.net/go/mautrix"
-	appservice "maunium.net/go/mautrix-appservice"
+	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/event"
+	"maunium.net/go/mautrix/id"
 )
 
 var (
@@ -38,7 +37,7 @@ var (
 	ErrMismatchingMXID = errors.New("whoami result does not match custom mxid")
 )
 
-func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid string) error {
+func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error {
 	prevCustomMXID := puppet.CustomMXID
 	if puppet.customIntent != nil {
 		puppet.stopSyncing()
@@ -63,12 +62,12 @@ func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid string) error {
 	return nil
 }
 
-func (puppet *Puppet) loginWithSharedSecret(mxid string) (string, error) {
+func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) {
 	mac := hmac.New(sha512.New, []byte(puppet.bridge.Config.Bridge.LoginSharedSecret))
 	mac.Write([]byte(mxid))
 	resp, err := puppet.bridge.AS.BotClient().Login(&mautrix.ReqLogin{
 		Type:                     "m.login.password",
-		Identifier:               mautrix.UserIdentifier{Type: "m.id.user", User: mxid},
+		Identifier:               mautrix.UserIdentifier{Type: "m.id.user", User: string(mxid)},
 		Password:                 hex.EncodeToString(mac.Sum(nil)),
 		DeviceID:                 "WhatsApp Bridge",
 		InitialDeviceDisplayName: "WhatsApp Bridge",
@@ -87,13 +86,13 @@ func (puppet *Puppet) newCustomIntent() (*appservice.IntentAPI, error) {
 	if err != nil {
 		return nil, err
 	}
-	client.Logger = puppet.bridge.AS.Log.Sub(puppet.CustomMXID)
+	client.Logger = puppet.bridge.AS.Log.Sub(string(puppet.CustomMXID))
 	client.Syncer = puppet
 	client.Store = puppet
 
 	ia := puppet.bridge.AS.NewIntentAPI("custom")
 	ia.Client = client
-	ia.Localpart = puppet.CustomMXID[1:strings.IndexRune(puppet.CustomMXID, ':')]
+	ia.Localpart, _, _ = puppet.CustomMXID.Parse()
 	ia.UserID = puppet.CustomMXID
 	ia.IsCustomPuppet = true
 	return ia, nil
@@ -117,11 +116,7 @@ func (puppet *Puppet) StartCustomMXID() error {
 		puppet.clearCustomMXID()
 		return err
 	}
-	urlPath := intent.BuildURL("account", "whoami")
-	var resp struct {
-		UserID string `json:"user_id"`
-	}
-	_, err = intent.MakeRequest("GET", urlPath, nil, &resp)
+	resp, err := intent.Whoami()
 	if err != nil {
 		puppet.clearCustomMXID()
 		return err
@@ -131,7 +126,7 @@ func (puppet *Puppet) StartCustomMXID() error {
 		return ErrMismatchingMXID
 	}
 	puppet.customIntent = intent
-	puppet.customTypingIn = make(map[string]bool)
+	puppet.customTypingIn = make(map[id.RoomID]bool)
 	puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID)
 	puppet.startSyncing()
 	return nil
@@ -158,28 +153,6 @@ func (puppet *Puppet) stopSyncing() {
 	puppet.customIntent.StopSync()
 }
 
-func parseEvent(roomID string, data json.RawMessage) *mautrix.Event {
-	event := &mautrix.Event{}
-	err := json.Unmarshal(data, event)
-	if err != nil {
-		// TODO add separate handler for these
-		_, _ = fmt.Fprintf(os.Stderr, "Failed to unmarshal event: %v\n%s\n", err, string(data))
-		return nil
-	}
-	return event
-}
-
-func parsePresenceEvent(data json.RawMessage) *mautrix.Event {
-	event := &mautrix.Event{}
-	err := json.Unmarshal(data, event)
-	if err != nil {
-		// TODO add separate handler for these
-		_, _ = fmt.Fprintf(os.Stderr, "Failed to unmarshal event: %v\n%s\n", err, string(data))
-		return nil
-	}
-	return event
-}
-
 func (puppet *Puppet) ProcessResponse(resp *mautrix.RespSync, since string) error {
 	if !puppet.customUser.IsConnected() {
 		puppet.log.Debugln("Skipping sync processing: custom user not connected to whatsapp")
@@ -190,31 +163,33 @@ func (puppet *Puppet) ProcessResponse(resp *mautrix.RespSync, since string) erro
 		if portal == nil {
 			continue
 		}
-		for _, data := range events.Ephemeral.Events {
-			event := parseEvent(roomID, data)
-			if event != nil {
-				switch event.Type {
-				case mautrix.EphemeralEventReceipt:
-					go puppet.handleReceiptEvent(portal, event)
-				case mautrix.EphemeralEventTyping:
-					go puppet.handleTypingEvent(portal, event)
-				}
+		for _, evt := range events.Ephemeral.Events {
+			err := evt.Content.ParseRaw(evt.Type)
+			if err != nil {
+				continue
+			}
+			switch evt.Type {
+			case event.EphemeralEventReceipt:
+				go puppet.handleReceiptEvent(portal, evt)
+			case event.EphemeralEventTyping:
+				go puppet.handleTypingEvent(portal, evt)
 			}
 		}
 	}
-	for _, data := range resp.Presence.Events {
-		event := parsePresenceEvent(data)
-		if event != nil {
-			if event.Sender != puppet.CustomMXID {
-				continue
-			}
-			go puppet.handlePresenceEvent(event)
+	for _, evt := range resp.Presence.Events {
+		if evt.Sender != puppet.CustomMXID {
+			continue
 		}
+		err := evt.Content.ParseRaw(evt.Type)
+		if err != nil {
+			continue
+		}
+		go puppet.handlePresenceEvent(evt)
 	}
 	return nil
 }
 
-func (puppet *Puppet) handlePresenceEvent(event *mautrix.Event) {
+func (puppet *Puppet) handlePresenceEvent(event *event.Event) {
 	presence := whatsapp.PresenceAvailable
 	if event.Content.Raw["presence"].(string) != "online" {
 		presence = whatsapp.PresenceUnavailable
@@ -228,13 +203,9 @@ func (puppet *Puppet) handlePresenceEvent(event *mautrix.Event) {
 	}
 }
 
-func (puppet *Puppet) handleReceiptEvent(portal *Portal, event *mautrix.Event) {
-	for eventID, rawReceipts := range event.Content.Raw {
-		if receipts, ok := rawReceipts.(map[string]interface{}); !ok {
-			continue
-		} else if readReceipt, ok := receipts["m.read"].(map[string]interface{}); !ok {
-			continue
-		} else if _, ok = readReceipt[puppet.CustomMXID].(map[string]interface{}); !ok {
+func (puppet *Puppet) handleReceiptEvent(portal *Portal, event *event.Event) {
+	for eventID, receipts := range *event.Content.AsReceipt() {
+		if _, ok := receipts.Read[puppet.CustomMXID]; !ok {
 			continue
 		}
 		message := puppet.bridge.DB.Message.GetByMXID(eventID)
@@ -249,16 +220,16 @@ func (puppet *Puppet) handleReceiptEvent(portal *Portal, event *mautrix.Event) {
 	}
 }
 
-func (puppet *Puppet) handleTypingEvent(portal *Portal, event *mautrix.Event) {
+func (puppet *Puppet) handleTypingEvent(portal *Portal, evt *event.Event) {
 	isTyping := false
-	for _, userID := range event.Content.TypingUserIDs {
+	for _, userID := range evt.Content.AsTyping().UserIDs {
 		if userID == puppet.CustomMXID {
 			isTyping = true
 			break
 		}
 	}
-	if puppet.customTypingIn[event.RoomID] != isTyping {
-		puppet.customTypingIn[event.RoomID] = isTyping
+	if puppet.customTypingIn[evt.RoomID] != isTyping {
+		puppet.customTypingIn[evt.RoomID] = isTyping
 		presence := whatsapp.PresenceComposing
 		if !isTyping {
 			puppet.customUser.log.Infofln("Marking not typing in %s/%s", portal.Key.JID, portal.MXID)
@@ -278,36 +249,27 @@ func (puppet *Puppet) OnFailedSync(res *mautrix.RespSync, err error) (time.Durat
 	return 10 * time.Second, nil
 }
 
-func (puppet *Puppet) GetFilterJSON(_ string) json.RawMessage {
-	mxid, _ := json.Marshal(puppet.CustomMXID)
-	return json.RawMessage(fmt.Sprintf(`{
-    "account_data": { "types": [] },
-    "presence": {
-        "senders": [
-            %s
-        ],
-        "types": [
-            "m.presence"
-        ]
-    },
-    "room": {
-        "ephemeral": {
-            "types": [
-                "m.typing",
-                "m.receipt"
-            ]
-        },
-        "include_leave": false,
-        "account_data": { "types": [] },
-        "state": { "types": [] },
-        "timeline": { "types": [] }
-    }
-}`, mxid))
+func (puppet *Puppet) GetFilterJSON(_ id.UserID) *mautrix.Filter {
+	everything := []event.Type{{Type: "*"}}
+	return &mautrix.Filter{
+		Presence: mautrix.FilterPart{
+			Senders: []id.UserID{puppet.CustomMXID},
+			Types:   []event.Type{event.EphemeralEventPresence},
+		},
+		AccountData: mautrix.FilterPart{NotTypes: everything},
+		Room: mautrix.RoomFilter{
+			Ephemeral:    mautrix.FilterPart{Types: []event.Type{event.EphemeralEventTyping, event.EphemeralEventReceipt}},
+			IncludeLeave: false,
+			AccountData:  mautrix.FilterPart{NotTypes: everything},
+			State:        mautrix.FilterPart{NotTypes: everything},
+			Timeline:     mautrix.FilterPart{NotTypes: everything},
+		},
+	}
 }
 
-func (puppet *Puppet) SaveFilterID(_, _ string)             {}
-func (puppet *Puppet) SaveNextBatch(_, nbt string)          { puppet.NextBatch = nbt; puppet.Update() }
-func (puppet *Puppet) SaveRoom(room *mautrix.Room)          {}
-func (puppet *Puppet) LoadFilterID(_ string) string         { return "" }
-func (puppet *Puppet) LoadNextBatch(_ string) string        { return puppet.NextBatch }
-func (puppet *Puppet) LoadRoom(roomID string) *mautrix.Room { return nil }
+func (puppet *Puppet) SaveFilterID(_ id.UserID, _ string)      {}
+func (puppet *Puppet) SaveNextBatch(_ id.UserID, nbt string)   { puppet.NextBatch = nbt; puppet.Update() }
+func (puppet *Puppet) SaveRoom(room *mautrix.Room)             {}
+func (puppet *Puppet) LoadFilterID(_ id.UserID) string         { return "" }
+func (puppet *Puppet) LoadNextBatch(_ id.UserID) string        { return puppet.NextBatch }
+func (puppet *Puppet) LoadRoom(roomID id.RoomID) *mautrix.Room { return nil }

+ 438 - 0
database/cryptostore.go

@@ -0,0 +1,438 @@
+// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
+// Copyright (C) 2020 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+// +build cgo
+
+package database
+
+import (
+	"database/sql"
+	"fmt"
+	"strings"
+
+	"github.com/lib/pq"
+	"github.com/pkg/errors"
+	log "maunium.net/go/maulogger/v2"
+
+	"maunium.net/go/mautrix/crypto"
+	"maunium.net/go/mautrix/crypto/olm"
+	"maunium.net/go/mautrix/id"
+)
+
+type SQLCryptoStore struct {
+	db  *Database
+	log log.Logger
+
+	UserID    id.UserID
+	DeviceID  id.DeviceID
+	SyncToken string
+	PickleKey []byte
+	Account   *crypto.OlmAccount
+
+	GhostIDFormat string
+}
+
+var _ crypto.Store = (*SQLCryptoStore)(nil)
+
+func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
+	return &SQLCryptoStore{
+		db:        db,
+		log:       db.log.Sub("CryptoStore"),
+		PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
+		DeviceID:  deviceID,
+	}
+}
+
+func (db *Database) FindDeviceID() (deviceID id.DeviceID) {
+	err := db.QueryRow("SELECT device_id FROM crypto_account LIMIT 1").Scan(&deviceID)
+	if err != nil && err != sql.ErrNoRows {
+		db.log.Warnln("Failed to scan device ID:", err)
+	}
+	return
+}
+
+func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.UserID, err error) {
+	var rows *sql.Rows
+	rows, err = store.db.Query(`
+		SELECT user_id FROM mx_user_profile
+		WHERE room_id=$1
+			AND (membership='join' OR membership='invite')
+			AND user_id<>$2
+			AND user_id NOT LIKE $3
+	`, roomID, store.UserID, store.GhostIDFormat)
+	if err != nil {
+		return
+	}
+	for rows.Next() {
+		var userID id.UserID
+		err := rows.Scan(&userID)
+		if err != nil {
+			store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
+		} else {
+			members = append(members, userID)
+		}
+	}
+	return
+}
+
+func (store *SQLCryptoStore) Flush() error {
+	return nil
+}
+
+func (store *SQLCryptoStore) PutNextBatch(nextBatch string) {
+	store.SyncToken = nextBatch
+	_, err := store.db.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE device_id=$2`, store.SyncToken, store.DeviceID)
+	if err != nil {
+		store.log.Warnln("Failed to store sync token:", err)
+	}
+}
+
+func (store *SQLCryptoStore) GetNextBatch() string {
+	if store.SyncToken == "" {
+		err := store.db.
+			QueryRow("SELECT sync_token FROM crypto_account WHERE device_id=$1", store.DeviceID).
+			Scan(&store.SyncToken)
+		if err != nil && err != sql.ErrNoRows {
+			store.log.Warnln("Failed to scan sync token:", err)
+		}
+	}
+	return store.SyncToken
+}
+
+func (store *SQLCryptoStore) PutAccount(account *crypto.OlmAccount) error {
+	store.Account = account
+	bytes := account.Internal.Pickle(store.PickleKey)
+	var err error
+	if store.db.dialect == "postgres" {
+		_, err = store.db.Exec(`
+			INSERT INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)
+			ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`,
+			store.DeviceID, account.Shared, store.SyncToken, bytes)
+	} else if store.db.dialect == "sqlite3" {
+		_, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)",
+			store.DeviceID, account.Shared, store.SyncToken, bytes)
+	} else {
+		err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
+	}
+	if err != nil {
+		store.log.Warnln("Failed to store account:", err)
+	}
+	return nil
+}
+
+func (store *SQLCryptoStore) GetAccount() (*crypto.OlmAccount, error) {
+	if store.Account == nil {
+		row := store.db.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE device_id=$1", store.DeviceID)
+		acc := &crypto.OlmAccount{Internal: *olm.NewBlankAccount()}
+		var accountBytes []byte
+		err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
+		if err == sql.ErrNoRows {
+			return nil, nil
+		} else if err != nil {
+			return nil, err
+		}
+		err = acc.Internal.Unpickle(accountBytes, store.PickleKey)
+		if err != nil {
+			return nil, err
+		}
+		store.Account = acc
+	}
+	return store.Account, nil
+}
+
+func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
+	// TODO this may need to be changed if olm sessions start expiring
+	var sessionID id.SessionID
+	err := store.db.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 LIMIT 1", key).Scan(&sessionID)
+	if err == sql.ErrNoRows {
+		return false
+	}
+	return len(sessionID) > 0
+}
+
+func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (crypto.OlmSessionList, error) {
+	rows, err := store.db.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id", key)
+	if err != nil {
+		return nil, err
+	}
+	list := crypto.OlmSessionList{}
+	for rows.Next() {
+		sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
+		var sessionBytes []byte
+		err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
+		if err != nil {
+			return nil, err
+		}
+		err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
+		if err != nil {
+			return nil, err
+		}
+		list = append(list, &sess)
+	}
+	return list, nil
+}
+
+func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*crypto.OlmSession, error) {
+	row := store.db.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id DESC LIMIT 1", key)
+	sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
+	var sessionBytes []byte
+	err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
+	if err == sql.ErrNoRows {
+		return nil, nil
+	} else if err != nil {
+		return nil, err
+	}
+	return &sess, sess.Internal.Unpickle(sessionBytes, store.PickleKey)
+}
+
+func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *crypto.OlmSession) error {
+	sessionBytes := session.Internal.Pickle(store.PickleKey)
+	_, err := store.db.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used) VALUES ($1, $2, $3, $4, $5)",
+		session.ID(), key, sessionBytes, session.CreationTime, session.UseTime)
+	return err
+}
+
+func (store *SQLCryptoStore) UpdateSession(key id.SenderKey, session *crypto.OlmSession) error {
+	sessionBytes := session.Internal.Pickle(store.PickleKey)
+	_, err := store.db.Exec("UPDATE crypto_olm_session SET session=$1, last_used=$2 WHERE session_id=$3",
+		sessionBytes, session.UseTime, session.ID())
+	return err
+}
+
+func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *crypto.InboundGroupSession) error {
+	sessionBytes := session.Internal.Pickle(store.PickleKey)
+	forwardingChains := strings.Join(session.ForwardingChains, ",")
+	_, err := store.db.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, signing_key, room_id, session, forwarding_chains) VALUES ($1, $2, $3, $4, $5, $6)",
+		sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains)
+	return err
+}
+
+func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
+	var signingKey id.Ed25519
+	var sessionBytes []byte
+	var forwardingChains string
+	err := store.db.QueryRow(`
+		SELECT signing_key, session, forwarding_chains
+		FROM crypto_megolm_inbound_session
+		WHERE room_id=$1 AND sender_key=$2 AND session_id=$3`,
+		roomID, senderKey, sessionID,
+	).Scan(&signingKey, &sessionBytes, &forwardingChains)
+	if err == sql.ErrNoRows {
+		return nil, nil
+	} else if err != nil {
+		return nil, err
+	}
+	igs := olm.NewBlankInboundGroupSession()
+	err = igs.Unpickle(sessionBytes, store.PickleKey)
+	if err != nil {
+		return nil, err
+	}
+	return &crypto.InboundGroupSession{
+		Internal:         *igs,
+		SigningKey:       signingKey,
+		SenderKey:        senderKey,
+		RoomID:           roomID,
+		ForwardingChains: strings.Split(forwardingChains, ","),
+	}, nil
+}
+
+func (store *SQLCryptoStore) AddOutboundGroupSession(session *crypto.OutboundGroupSession) (err error) {
+	sessionBytes := session.Internal.Pickle(store.PickleKey)
+	if store.db.dialect == "postgres" {
+		_, err = store.db.Exec(`
+			INSERT INTO crypto_megolm_outbound_session (
+				room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used
+			) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+			ON CONFLICT (room_id) DO UPDATE SET session_id=$2, session=$3, shared=$4, max_messages=$5, message_count=$6, max_age=$7, created_at=$8, last_used=$9`,
+			session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime)
+	} else if store.db.dialect == "sqlite3" {
+		_, err = store.db.Exec(`
+			INSERT OR REPLACE INTO crypto_megolm_outbound_session (
+				room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used
+			) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
+			session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime)
+	}  else {
+		err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
+	}
+	return
+}
+
+func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *crypto.OutboundGroupSession) error {
+	sessionBytes := session.Internal.Pickle(store.PickleKey)
+	_, err := store.db.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5",
+		sessionBytes, session.MessageCount, session.UseTime, session.RoomID, session.ID())
+	return err
+}
+
+func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
+	var ogs crypto.OutboundGroupSession
+	var sessionBytes []byte
+	err := store.db.QueryRow(`
+		SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
+		FROM crypto_megolm_outbound_session WHERE room_id=$1`,
+		roomID,
+	).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.UseTime)
+	if err == sql.ErrNoRows {
+		return nil, nil
+	} else if err != nil {
+		return nil, err
+	}
+	intOGS := olm.NewBlankOutboundGroupSession()
+	err = intOGS.Unpickle(sessionBytes, store.PickleKey)
+	if err != nil {
+		return nil, err
+	}
+	ogs.Internal = *intOGS
+	ogs.RoomID = roomID
+	return &ogs, nil
+}
+
+func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
+	_, err := store.db.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1", roomID)
+	return err
+}
+
+func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
+	var resultEventID id.EventID
+	var resultTimestamp int64
+	err := store.db.QueryRow(
+		`SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3`,
+		senderKey, sessionID, index,
+	).Scan(&resultEventID, &resultTimestamp)
+	if err == sql.ErrNoRows {
+		_, err := store.db.Exec(`INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5)`,
+			senderKey, sessionID, index, eventID, timestamp)
+		if err != nil {
+			store.log.Warnln("Failed to store message index:", err)
+		}
+		return true
+	} else if err != nil {
+		store.log.Warnln("Failed to scan message index:", err)
+		return true
+	}
+	if resultEventID != eventID || resultTimestamp != timestamp {
+		return false
+	}
+	return true
+}
+
+func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*crypto.DeviceIdentity, error) {
+	var ignore id.UserID
+	err := store.db.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
+	if err == sql.ErrNoRows {
+		return nil, nil
+	} else if err != nil {
+		return nil, err
+	}
+
+	rows, err := store.db.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1", userID)
+	if err != nil {
+		return nil, err
+	}
+	data := make(map[id.DeviceID]*crypto.DeviceIdentity)
+	for rows.Next() {
+		var identity crypto.DeviceIdentity
+		err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
+		if err != nil {
+			return nil, err
+		}
+		identity.UserID = userID
+		data[identity.DeviceID] = &identity
+	}
+	return data, nil
+}
+
+func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*crypto.DeviceIdentity) error {
+	tx, err := store.db.Begin()
+	if err != nil {
+		return err
+	}
+
+	if store.db.dialect == "postgres" {
+		_, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
+	} else if store.db.dialect == "sqlite3" {
+		_, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_user (user_id) VALUES ($1)", userID)
+	} else {
+		err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
+	}
+	if err != nil {
+		return errors.Wrap(err, "failed to add user to tracked users list")
+	}
+
+	_, err = tx.Exec("DELETE FROM crypto_device WHERE user_id=$1", userID)
+	if err != nil {
+		_ = tx.Rollback()
+		return errors.Wrap(err, "failed to delete old devices")
+	}
+	if len(devices) == 0 {
+		err = tx.Commit()
+		if err != nil {
+			return errors.Wrap(err, "failed to commit changes (no devices added)")
+		}
+		return nil
+	}
+	// TODO do this in batches to avoid too large db queries
+	values := make([]interface{}, 1, len(devices)*6+1)
+	values[0] = userID
+	valueStrings := make([]string, 0, len(devices))
+	i := 2
+	for deviceID, identity := range devices {
+		values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
+		valueStrings = append(valueStrings, fmt.Sprintf("($1, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5))
+		i += 6
+	}
+	valueString := strings.Join(valueStrings, ",")
+	_, err = tx.Exec("INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES "+valueString, values...)
+	if err != nil {
+		_ = tx.Rollback()
+		return errors.Wrap(err, "failed to insert new devices")
+	}
+	err = tx.Commit()
+	if err != nil {
+		return errors.Wrap(err, "failed to commit changes")
+	}
+	return nil
+}
+
+func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
+	var rows *sql.Rows
+	var err error
+	if store.db.dialect == "postgres" {
+		rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users))
+	} else {
+		queryString := make([]string, len(users))
+		params := make([]interface{}, len(users))
+		for i, user := range users {
+			queryString[i] = fmt.Sprintf("$%d", i+1)
+			params[i] = user
+		}
+		rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
+	}
+	if err != nil {
+		store.log.Warnln("Failed to filter tracked users:", err)
+		return users
+	}
+	var ptr int
+	for rows.Next() {
+		err = rows.Scan(&users[ptr])
+		if err != nil {
+			store.log.Warnln("Failed to tracked user ID:", err)
+		} else {
+			ptr++
+		}
+	}
+	return users[:ptr]
+}

+ 4 - 3
database/message.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -26,6 +26,7 @@ import (
 	log "maunium.net/go/maulogger/v2"
 
 	"maunium.net/go/mautrix-whatsapp/types"
+	"maunium.net/go/mautrix/id"
 )
 
 type MessageQuery struct {
@@ -57,7 +58,7 @@ func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *M
 		"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
 }
 
-func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
+func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message {
 	return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content " +
 		"FROM message WHERE mxid=$1", mxid)
 }
@@ -86,7 +87,7 @@ type Message struct {
 
 	Chat      PortalKey
 	JID       types.WhatsAppMessageID
-	MXID      types.MatrixEventID
+	MXID      id.EventID
 	Sender    types.WhatsAppID
 	Timestamp uint64
 	Content   *waProto.Message

+ 29 - 1
database/migrate.go

@@ -89,7 +89,7 @@ func migrateTable(old *Database, new *Database, table string, columns ...string)
 }
 
 func Migrate(old *Database, new *Database) {
-	err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url")
+	err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url", "encrypted")
 	if err != nil {
 		panic(err)
 	}
@@ -121,4 +121,32 @@ func Migrate(old *Database, new *Database) {
 	if err != nil {
 		panic(err)
 	}
+	err = migrateTable(old, new, "crypto_account", "device_id", "shared", "sync_token", "account")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "crypto_message_index", "sender_key", "session_id", `"index"`, "event_id", "timestamp")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "crypto_tracked_user", "user_id")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "crypto_device", "user_id", "device_id", "identity_key", "signing_key", "trust", "deleted", "name")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "crypto_olm_session", "session_id", "sender_key", "session", "created_at", "last_used")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "crypto_megolm_inbound_session", "session_id", "sender_key", "signing_key", "room_id", "session", "forwarding_chains")
+	if err != nil {
+		panic(err)
+	}
+	err = migrateTable(old, new, "crypto_megolm_outbound_session", "room_id", "session_id", "session", "shared", "max_messages", "message_count", "max_age", "created_at", "last_used")
+	if err != nil {
+		panic(err)
+	}
 }

+ 19 - 16
database/portal.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -22,6 +22,8 @@ import (
 
 	log "maunium.net/go/maulogger/v2"
 
+	"maunium.net/go/mautrix/id"
+
 	"maunium.net/go/mautrix-whatsapp/types"
 )
 
@@ -74,7 +76,7 @@ func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
 	return pq.get("SELECT * FROM portal WHERE jid=$1 AND receiver=$2", key.JID, key.Receiver)
 }
 
-func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
+func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
 	return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
 }
 
@@ -107,29 +109,30 @@ type Portal struct {
 	log log.Logger
 
 	Key  PortalKey
-	MXID types.MatrixRoomID
+	MXID id.RoomID
 
 	Name      string
 	Topic     string
 	Avatar    string
-	AvatarURL string
+	AvatarURL id.ContentURI
+	Encrypted bool
 }
 
 func (portal *Portal) Scan(row Scannable) *Portal {
 	var mxid, avatarURL sql.NullString
-	err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL)
+	err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL, &portal.Encrypted)
 	if err != nil {
 		if err != sql.ErrNoRows {
 			portal.log.Errorln("Database scan failed:", err)
 		}
 		return nil
 	}
-	portal.MXID = mxid.String
-	portal.AvatarURL = avatarURL.String
+	portal.MXID = id.RoomID(mxid.String)
+	portal.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
 	return portal
 }
 
-func (portal *Portal) mxidPtr() *string {
+func (portal *Portal) mxidPtr() *id.RoomID {
 	if len(portal.MXID) > 0 {
 		return &portal.MXID
 	}
@@ -137,20 +140,20 @@ func (portal *Portal) mxidPtr() *string {
 }
 
 func (portal *Portal) Insert() {
-	_, err := portal.db.Exec("INSERT INTO portal VALUES ($1, $2, $3, $4, $5, $6, $7)",
-		portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL)
+	_, err := portal.db.Exec("INSERT INTO portal (jid, receiver, mxid, name, topic, avatar, avatar_url, encrypted) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
+		portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted)
 	if err != nil {
 		portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
 	}
 }
 
 func (portal *Portal) Update() {
-	var mxid *string
+	var mxid *id.RoomID
 	if len(portal.MXID) > 0 {
 		mxid = &portal.MXID
 	}
-	_, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5 WHERE jid=$6 AND receiver=$7",
-		mxid, portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL, portal.Key.JID, portal.Key.Receiver)
+	_, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, encrypted=$6 WHERE jid=$7 AND receiver=$8",
+		mxid, portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.Key.JID, portal.Key.Receiver)
 	if err != nil {
 		portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
 	}
@@ -163,7 +166,7 @@ func (portal *Portal) Delete() {
 	}
 }
 
-func (portal *Portal) GetUserIDs() []types.MatrixUserID {
+func (portal *Portal) GetUserIDs() []id.UserID {
 	rows, err := portal.db.Query(`SELECT "user".mxid FROM "user", user_portal
 		WHERE "user".jid=user_portal.user_jid
 			AND user_portal.portal_jid=$1
@@ -173,9 +176,9 @@ func (portal *Portal) GetUserIDs() []types.MatrixUserID {
 		portal.log.Debugln("Failed to get portal user ids:", err)
 		return nil
 	}
-	var userIDs []types.MatrixUserID
+	var userIDs []id.UserID
 	for rows.Next() {
-		var userID types.MatrixUserID
+		var userID id.UserID
 		err = rows.Scan(&userID)
 		if err != nil {
 			portal.log.Warnln("Failed to scan row:", err)

+ 9 - 8
database/puppet.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -22,6 +22,7 @@ import (
 	log "maunium.net/go/maulogger/v2"
 
 	"maunium.net/go/mautrix-whatsapp/types"
+	"maunium.net/go/mautrix/id"
 )
 
 type PuppetQuery struct {
@@ -56,7 +57,7 @@ func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
 	return pq.New().Scan(row)
 }
 
-func (pq *PuppetQuery) GetByCustomMXID(mxid types.MatrixUserID) *Puppet {
+func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
 	row := pq.db.QueryRow("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch FROM puppet WHERE custom_mxid=$1", mxid)
 	if row == nil {
 		return nil
@@ -82,11 +83,11 @@ type Puppet struct {
 
 	JID         types.WhatsAppID
 	Avatar      string
-	AvatarURL   string
+	AvatarURL   id.ContentURI
 	Displayname string
 	NameQuality int8
 
-	CustomMXID  string
+	CustomMXID  id.UserID
 	AccessToken string
 	NextBatch   string
 }
@@ -103,9 +104,9 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
 	}
 	puppet.Displayname = displayname.String
 	puppet.Avatar = avatar.String
-	puppet.AvatarURL = avatarURL.String
+	puppet.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
 	puppet.NameQuality = int8(quality.Int64)
-	puppet.CustomMXID = customMXID.String
+	puppet.CustomMXID = id.UserID(customMXID.String)
 	puppet.AccessToken = accessToken.String
 	puppet.NextBatch = nextBatch.String
 	return puppet
@@ -113,7 +114,7 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
 
 func (puppet *Puppet) Insert() {
 	_, err := puppet.db.Exec("INSERT INTO puppet (jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
-		puppet.JID, puppet.Avatar, puppet.AvatarURL, puppet.Displayname, puppet.NameQuality, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch)
+		puppet.JID, puppet.Avatar, puppet.AvatarURL.String(), puppet.Displayname, puppet.NameQuality, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch)
 	if err != nil {
 		puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
 	}
@@ -121,7 +122,7 @@ func (puppet *Puppet) Insert() {
 
 func (puppet *Puppet) Update() {
 	_, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3, avatar_url=$4, custom_mxid=$5, access_token=$6, next_batch=$7 WHERE jid=$8",
-		puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.AvatarURL, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.JID)
+		puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.AvatarURL.String(), puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.JID)
 	if err != nil {
 		puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err)
 	}

+ 55 - 30
database/statestore.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -24,8 +24,9 @@ import (
 
 	log "maunium.net/go/maulogger/v2"
 
-	"maunium.net/go/mautrix"
-	"maunium.net/go/mautrix-appservice"
+	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/event"
+	"maunium.net/go/mautrix/id"
 )
 
 type SQLStateStore struct {
@@ -34,19 +35,21 @@ type SQLStateStore struct {
 	db  *Database
 	log log.Logger
 
-	Typing     map[string]map[string]int64
+	Typing     map[id.RoomID]map[id.UserID]int64
 	typingLock sync.RWMutex
 }
 
+var _ appservice.StateStore = (*SQLStateStore)(nil)
+
 func NewSQLStateStore(db *Database) *SQLStateStore {
 	return &SQLStateStore{
 		TypingStateStore: appservice.NewTypingStateStore(),
 		db:               db,
-		log:              log.Sub("StateStore"),
+		log:              db.log.Sub("StateStore"),
 	}
 }
 
-func (store *SQLStateStore) IsRegistered(userID string) bool {
+func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
 	row := store.db.QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID)
 	var isRegistered bool
 	err := row.Scan(&isRegistered)
@@ -56,7 +59,7 @@ func (store *SQLStateStore) IsRegistered(userID string) bool {
 	return isRegistered
 }
 
-func (store *SQLStateStore) MarkRegistered(userID string) {
+func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
 	var err error
 	if store.db.dialect == "postgres" {
 		_, err = store.db.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
@@ -70,28 +73,28 @@ func (store *SQLStateStore) MarkRegistered(userID string) {
 	}
 }
 
-func (store *SQLStateStore) GetRoomMembers(roomID string) map[string]mautrix.Member {
-	members := make(map[string]mautrix.Member)
+func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
+	members := make(map[id.UserID]*event.MemberEventContent)
 	rows, err := store.db.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID)
 	if err != nil {
 		return members
 	}
-	var userID string
-	var member mautrix.Member
+	var userID id.UserID
+	var member event.MemberEventContent
 	for rows.Next() {
 		err := rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
 		if err != nil {
 			store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
 		} else {
-			members[userID] = member
+			members[userID] = &member
 		}
 	}
 	return members
 }
 
-func (store *SQLStateStore) GetMembership(roomID, userID string) mautrix.Membership {
+func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
 	row := store.db.QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID)
-	membership := mautrix.MembershipLeave
+	membership := event.MembershipLeave
 	err := row.Scan(&membership)
 	if err != nil && err != sql.ErrNoRows {
 		store.log.Warnfln("Failed to scan membership of %s in %s: %v", userID, roomID, err)
@@ -99,33 +102,55 @@ func (store *SQLStateStore) GetMembership(roomID, userID string) mautrix.Members
 	return membership
 }
 
-func (store *SQLStateStore) GetMember(roomID, userID string) mautrix.Member {
+func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
 	member, ok := store.TryGetMember(roomID, userID)
 	if !ok {
-		member.Membership = mautrix.MembershipLeave
+		member.Membership = event.MembershipLeave
 	}
 	return member
 }
 
-func (store *SQLStateStore) TryGetMember(roomID, userID string) (mautrix.Member, bool) {
+func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
 	row := store.db.QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID)
-	var member mautrix.Member
+	var member event.MemberEventContent
 	err := row.Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
 	if err != nil && err != sql.ErrNoRows {
 		store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err)
 	}
-	return member, err == nil
+	return &member, err == nil
+}
+
+func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
+	rows, err := store.db.Query(`
+			SELECT room_id FROM mx_user_profile
+			LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
+			WHERE user_id=$1 AND portal.encrypted=true
+	`, userID)
+	if err != nil {
+		store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
+		return
+	}
+	for rows.Next() {
+		var roomID id.RoomID
+		err := rows.Scan(&roomID)
+		if err != nil {
+			store.log.Warnfln("Failed to scan room ID: %v", err)
+		} else {
+			rooms = append(rooms, roomID)
+		}
+	}
+	return
 }
 
-func (store *SQLStateStore) IsInRoom(roomID, userID string) bool {
+func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
 	return store.IsMembership(roomID, userID, "join")
 }
 
-func (store *SQLStateStore) IsInvited(roomID, userID string) bool {
+func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
 	return store.IsMembership(roomID, userID, "join", "invite")
 }
 
-func (store *SQLStateStore) IsMembership(roomID, userID string, allowedMemberships ...mautrix.Membership) bool {
+func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
 	membership := store.GetMembership(roomID, userID)
 	for _, allowedMembership := range allowedMemberships {
 		if allowedMembership == membership {
@@ -135,7 +160,7 @@ func (store *SQLStateStore) IsMembership(roomID, userID string, allowedMembershi
 	return false
 }
 
-func (store *SQLStateStore) SetMembership(roomID, userID string, membership mautrix.Membership) {
+func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
 	var err error
 	if store.db.dialect == "postgres" {
 		_, err = store.db.Exec(`INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3)
@@ -150,7 +175,7 @@ func (store *SQLStateStore) SetMembership(roomID, userID string, membership maut
 	}
 }
 
-func (store *SQLStateStore) SetMember(roomID, userID string, member mautrix.Member) {
+func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
 	var err error
 	if store.db.dialect == "postgres" {
 		_, err = store.db.Exec(`INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
@@ -166,7 +191,7 @@ func (store *SQLStateStore) SetMember(roomID, userID string, member mautrix.Memb
 	}
 }
 
-func (store *SQLStateStore) SetPowerLevels(roomID string, levels *mautrix.PowerLevels) {
+func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
 	levelsBytes, err := json.Marshal(levels)
 	if err != nil {
 		store.log.Errorfln("Failed to marshal power levels of %s: %v", roomID, err)
@@ -185,7 +210,7 @@ func (store *SQLStateStore) SetPowerLevels(roomID string, levels *mautrix.PowerL
 	}
 }
 
-func (store *SQLStateStore) GetPowerLevels(roomID string) (levels *mautrix.PowerLevels) {
+func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
 	row := store.db.QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID)
 	if row == nil {
 		return
@@ -196,7 +221,7 @@ func (store *SQLStateStore) GetPowerLevels(roomID string) (levels *mautrix.Power
 		store.log.Errorln("Failed to scan power levels of %s: %v", roomID, err)
 		return
 	}
-	levels = &mautrix.PowerLevels{}
+	levels = &event.PowerLevelsEventContent{}
 	err = json.Unmarshal(data, levels)
 	if err != nil {
 		store.log.Errorln("Failed to parse power levels of %s: %v", roomID, err)
@@ -205,7 +230,7 @@ func (store *SQLStateStore) GetPowerLevels(roomID string) (levels *mautrix.Power
 	return
 }
 
-func (store *SQLStateStore) GetPowerLevel(roomID, userID string) int {
+func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
 	if store.db.dialect == "postgres" {
 		row := store.db.QueryRow(`SELECT
 			COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
@@ -224,7 +249,7 @@ func (store *SQLStateStore) GetPowerLevel(roomID, userID string) int {
 	return store.GetPowerLevels(roomID).GetUserLevel(userID)
 }
 
-func (store *SQLStateStore) GetPowerLevelRequirement(roomID string, eventType mautrix.EventType) int {
+func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
 	if store.db.dialect == "postgres" {
 		defaultType := "events_default"
 		defaultValue := 0
@@ -249,7 +274,7 @@ func (store *SQLStateStore) GetPowerLevelRequirement(roomID string, eventType ma
 	return store.GetPowerLevels(roomID).GetEventLevel(eventType)
 }
 
-func (store *SQLStateStore) HasPowerLevel(roomID, userID string, eventType mautrix.EventType) bool {
+func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
 	if store.db.dialect == "postgres" {
 		defaultType := "events_default"
 		defaultValue := 0

+ 9 - 16
database/upgrades/2018-09-01-initial-schema.go

@@ -2,26 +2,19 @@ package upgrades
 
 import (
 	"database/sql"
-	"fmt"
 )
 
 func init() {
 	upgrades[0] = upgrade{"Initial schema", func(tx *sql.Tx, ctx context) error {
-		var byteType string
-		if ctx.dialect == SQLite {
-			byteType = "BLOB"
-		} else {
-			byteType = "bytea"
-		}
 		_, err := tx.Exec(`CREATE TABLE IF NOT EXISTS portal (
 			jid      VARCHAR(255),
 			receiver VARCHAR(255),
 			mxid     VARCHAR(255) UNIQUE,
-	
+
 			name   VARCHAR(255) NOT NULL,
 			topic  VARCHAR(255) NOT NULL,
 			avatar VARCHAR(255) NOT NULL,
-	
+
 			PRIMARY KEY (jid, receiver)
 		)`)
 		if err != nil {
@@ -38,7 +31,7 @@ func init() {
 			return err
 		}
 
-		_, err = tx.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS "user" (
+		_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS "user" (
 			mxid VARCHAR(255) PRIMARY KEY,
 			jid  VARCHAR(255) UNIQUE,
 
@@ -47,24 +40,24 @@ func init() {
 			client_id    VARCHAR(255),
 			client_token VARCHAR(255),
 			server_token VARCHAR(255),
-			enc_key      %[1]s,
-			mac_key      %[1]s
-		)`, byteType))
+			enc_key      bytea,
+			mac_key      bytea
+		)`)
 		if err != nil {
 			return err
 		}
 
-		_, err = tx.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS message (
+		_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message (
 			chat_jid      VARCHAR(255),
 			chat_receiver VARCHAR(255),
 			jid           VARCHAR(255),
 			mxid          VARCHAR(255) NOT NULL UNIQUE,
 			sender        VARCHAR(255) NOT NULL,
-			content       %[1]s        NOT NULL,
+			content       bytea        NOT NULL,
 
 			PRIMARY KEY (chat_jid, chat_receiver, jid),
 			FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
-		)`, byteType))
+		)`)
 		if err != nil {
 			return err
 		}

+ 6 - 6
database/upgrades/2019-08-25-move-state-store-to-db.go

@@ -8,7 +8,7 @@ import (
 	"os"
 	"strings"
 
-	"maunium.net/go/mautrix"
+	"maunium.net/go/mautrix/event"
 )
 
 func init() {
@@ -46,7 +46,7 @@ func init() {
 		return executeBatch(tx, valueStrings, values...)
 	}
 
-	migrateMemberships := func(tx *sql.Tx, rooms map[string]map[string]mautrix.Membership) error {
+	migrateMemberships := func(tx *sql.Tx, rooms map[string]map[string]event.Membership) error {
 		for roomID, members := range rooms {
 			if len(members) == 0 {
 				continue
@@ -68,7 +68,7 @@ func init() {
 		return nil
 	}
 
-	migratePowerLevels := func(tx *sql.Tx, rooms map[string]*mautrix.PowerLevels) error {
+	migratePowerLevels := func(tx *sql.Tx, rooms map[string]*event.PowerLevelsEventContent) error {
 		if len(rooms) == 0 {
 			return nil
 		}
@@ -106,9 +106,9 @@ func init() {
 	)`
 
 	type TempStateStore struct {
-		Registrations map[string]bool                          `json:"registrations"`
-		Members       map[string]map[string]mautrix.Membership `json:"memberships"`
-		PowerLevels   map[string]*mautrix.PowerLevels          `json:"power_levels"`
+		Registrations map[string]bool                           `json:"registrations"`
+		Members       map[string]map[string]event.Membership    `json:"memberships"`
+		PowerLevels   map[string]*event.PowerLevelsEventContent `json:"power_levels"`
 	}
 
 	upgrades[9] = upgrade{"Move state store to main DB", func(tx *sql.Tx, ctx context) error {

+ 12 - 0
database/upgrades/2020-05-09-add-portal-encrypted-field.go

@@ -0,0 +1,12 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[12] = upgrade{"Add encryption status to portal table", func(tx *sql.Tx, ctx context) error {
+		_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN encrypted BOOLEAN NOT NULL DEFAULT false`)
+		return err
+	}}
+}

+ 73 - 0
database/upgrades/2020-05-09-crypto-store.go

@@ -0,0 +1,73 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[13] = upgrade{"Add crypto store to database", func(tx *sql.Tx, ctx context) error {
+		_, err := tx.Exec(`CREATE TABLE crypto_account (
+			device_id  VARCHAR(255) PRIMARY KEY,
+			shared     BOOLEAN      NOT NULL,
+			sync_token TEXT         NOT NULL,
+			account    bytea        NOT NULL
+		)`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`CREATE TABLE crypto_message_index (
+			sender_key CHAR(43),
+			session_id CHAR(43),
+			"index"    INTEGER,
+			event_id   VARCHAR(255) NOT NULL,
+			timestamp  BIGINT       NOT NULL,
+
+			PRIMARY KEY (sender_key, session_id, "index")
+		)`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`CREATE TABLE crypto_tracked_user (
+			user_id VARCHAR(255) PRIMARY KEY
+		)`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`CREATE TABLE crypto_device (
+			user_id      VARCHAR(255),
+			device_id    VARCHAR(255),
+			identity_key CHAR(43)      NOT NULL,
+			signing_key  CHAR(43)      NOT NULL,
+			trust        SMALLINT      NOT NULL,
+			deleted      BOOLEAN       NOT NULL,
+			name         VARCHAR(255)  NOT NULL,
+
+			PRIMARY KEY (user_id, device_id)
+		)`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`CREATE TABLE crypto_olm_session (
+			session_id   CHAR(43)  PRIMARY KEY,
+			sender_key   CHAR(43)  NOT NULL,
+			session      bytea     NOT NULL,
+			created_at   timestamp NOT NULL,
+			last_used    timestamp NOT NULL
+		)`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session (
+			session_id   CHAR(43)     PRIMARY KEY,
+			sender_key   CHAR(43)     NOT NULL,
+			signing_key  CHAR(43)     NOT NULL,
+			room_id      VARCHAR(255) NOT NULL,
+			session      bytea        NOT NULL,
+			forwarding_chains bytea   NOT NULL
+		)`)
+		if err != nil {
+			return err
+		}
+		return nil
+	}}
+}

+ 25 - 0
database/upgrades/2020-05-12-outbound-group-session-store.go

@@ -0,0 +1,25 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[14] = upgrade{"Add outbound group sessions to database", func(tx *sql.Tx, ctx context) error {
+		_, err := tx.Exec(`CREATE TABLE crypto_megolm_outbound_session (
+			room_id       VARCHAR(255) PRIMARY KEY,
+			session_id    CHAR(43)     NOT NULL UNIQUE,
+			session       bytea        NOT NULL,
+			shared        BOOLEAN      NOT NULL,
+			max_messages  INTEGER      NOT NULL,
+			message_count INTEGER      NOT NULL,
+			max_age       BIGINT       NOT NULL,
+			created_at    timestamp    NOT NULL,
+			last_used     timestamp    NOT NULL
+		)`)
+		if err != nil {
+			return err
+		}
+		return nil
+	}}
+}

+ 1 - 1
database/upgrades/upgrades.go

@@ -28,7 +28,7 @@ type upgrade struct {
 	fn      upgradeFunc
 }
 
-const NumberOfUpgrades = 12
+const NumberOfUpgrades = 15
 
 var upgrades [NumberOfUpgrades]upgrade
 

+ 5 - 4
database/user.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -28,6 +28,7 @@ import (
 
 	"maunium.net/go/mautrix-whatsapp/types"
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
+	"maunium.net/go/mautrix/id"
 )
 
 type UserQuery struct {
@@ -54,7 +55,7 @@ func (uq *UserQuery) GetAll() (users []*User) {
 	return
 }
 
-func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
+func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
 	row := uq.db.QueryRow(`SELECT mxid, jid, management_room, last_connection, client_id, client_token, server_token, enc_key, mac_key FROM "user" WHERE mxid=$1`, userID)
 	if row == nil {
 		return nil
@@ -74,9 +75,9 @@ type User struct {
 	db  *Database
 	log log.Logger
 
-	MXID           types.MatrixUserID
+	MXID           id.UserID
 	JID            types.WhatsAppID
-	ManagementRoom types.MatrixRoomID
+	ManagementRoom id.RoomID
 	Session        *whatsapp.Session
 	LastConnection uint64
 }

+ 13 - 0
example-config.yaml

@@ -138,6 +138,19 @@ bridge:
     # The prefix for commands. Only required in non-management rooms.
     command_prefix: "!wa"
 
+    # End-to-bridge encryption support options. This requires login_shared_secret to be configured
+    # in order to get a device for the bridge bot.
+    #
+    # Additionally, https://github.com/matrix-org/synapse/pull/5758 is required if using a normal
+    # application service.
+    encryption:
+        # Allow encryption, work in group chat rooms with e2ee enabled
+        allow: false
+        # Default to encryption, force-enable encryption in all portals the bridge creates
+        # This will cause the bridge bot to be in private chats for the encryption to work properly.
+        # It is recommended to also set private_chat_portal_meta to true when using this.
+        default: false
+
     # Permissions for using the bridge.
     # Permitted values:
     # relaybot - Talk through the relaybot (if enabled), no access otherwise

+ 8 - 8
formatting.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -22,8 +22,9 @@ import (
 	"regexp"
 	"strings"
 
-	"maunium.net/go/mautrix"
+	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/format"
+	"maunium.net/go/mautrix/id"
 
 	"maunium.net/go/mautrix-whatsapp/types"
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
@@ -54,8 +55,7 @@ func NewFormatter(bridge *Bridge) *Formatter {
 
 			PillConverter: func(mxid, eventID string) string {
 				if mxid[0] == '@' {
-					puppet := bridge.GetPuppetByMXID(mxid)
-					fmt.Println(mxid, puppet)
+					puppet := bridge.GetPuppetByMXID(id.UserID(mxid))
 					if puppet != nil {
 						return "@" + puppet.PhoneNumber()
 					}
@@ -106,10 +106,10 @@ func NewFormatter(bridge *Bridge) *Formatter {
 	return formatter
 }
 
-func (formatter *Formatter) getMatrixInfoByJID(jid types.WhatsAppID) (mxid, displayname string) {
+func (formatter *Formatter) getMatrixInfoByJID(jid types.WhatsAppID) (mxid id.UserID, displayname string) {
 	if user := formatter.bridge.GetUserByJID(jid); user != nil {
 		mxid = user.MXID
-		displayname = user.MXID
+		displayname = string(user.MXID)
 	} else if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil {
 		mxid = puppet.MXID
 		displayname = puppet.Displayname
@@ -117,7 +117,7 @@ func (formatter *Formatter) getMatrixInfoByJID(jid types.WhatsAppID) (mxid, disp
 	return
 }
 
-func (formatter *Formatter) ParseWhatsApp(content *mautrix.Content) {
+func (formatter *Formatter) ParseWhatsApp(content *event.MessageEventContent) {
 	output := html.EscapeString(content.Body)
 	for regex, replacement := range formatter.waReplString {
 		output = regex.ReplaceAllString(output, replacement)
@@ -128,7 +128,7 @@ func (formatter *Formatter) ParseWhatsApp(content *mautrix.Content) {
 	if output != content.Body {
 		output = strings.Replace(output, "\n", "<br/>", -1)
 		content.FormattedBody = output
-		content.Format = mautrix.FormatHTML
+		content.Format = event.FormatHTML
 		for regex, replacer := range formatter.waReplFuncText {
 			content.Body = regex.ReplaceAllStringFunc(content.Body, replacer)
 		}

+ 4 - 4
go.mod

@@ -5,17 +5,17 @@ go 1.14
 require (
 	github.com/Rhymen/go-whatsapp v0.1.0
 	github.com/chai2010/webp v1.1.0
-	github.com/gorilla/websocket v1.4.1
-	github.com/lib/pq v1.3.0
+	github.com/gorilla/websocket v1.4.2
+	github.com/lib/pq v1.5.2
 	github.com/mattn/go-sqlite3 v2.0.3+incompatible
 	github.com/pkg/errors v0.9.1
 	github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
 	github.com/skip2/go-qrcode v0.0.0-20191027152451-9434209cb086
+	golang.org/x/image v0.0.0-20200430140353-33d19683fad8
 	gopkg.in/yaml.v2 v2.2.8
 	maunium.net/go/mauflag v1.0.0
 	maunium.net/go/maulogger/v2 v2.1.1
-	maunium.net/go/mautrix v0.1.0-beta.2
-	maunium.net/go/mautrix-appservice v0.1.0-alpha.6
+	maunium.net/go/mautrix v0.4.5
 )
 
 replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.2.6

+ 36 - 2
go.sum

@@ -1,7 +1,10 @@
 github.com/chai2010/webp v1.1.0 h1:4Ei0/BRroMF9FaXDG2e4OxwFcuW2vcXd+A6tyqTJUQQ=
 github.com/chai2010/webp v1.1.0/go.mod h1:LP12PG5IFmLGHUU26tBiCBKnghxx3toZFwDjOYvd3Ow=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s=
 github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU=
+github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
+github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
 github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
 github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc=
@@ -12,6 +15,8 @@ github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0U
 github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU=
 github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
+github.com/lib/pq v1.5.2 h1:yTSXVswvWUOQ3k1sd7vJfDrbSl8lKuscqFJRqjC0ifw=
+github.com/lib/pq v1.5.2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
 github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
 github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
 github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
@@ -22,12 +27,24 @@ github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW
 github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q=
 github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
 github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
 github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
 github.com/skip2/go-qrcode v0.0.0-20191027152451-9434209cb086 h1:RYiqpb2ii2Z6J4x0wxK46kvPBbFuZcdhS+CIztmYgZs=
 github.com/skip2/go-qrcode v0.0.0-20191027152451-9434209cb086/go.mod h1:PLPIyL7ikehBD1OAjmKKiOEhbvWyHGaNDjquXMcYABo=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
+github.com/tidwall/gjson v1.6.0 h1:9VEQWz6LLMUsUl6PueE49ir4Ka6CzLymOAZDxpFsTDc=
+github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls=
+github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc=
+github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
+github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
+github.com/tidwall/pretty v1.0.1 h1:WE4RBSZ1x6McVVC8S/Md+Qse8YUv6HRObAx6ke00NY8=
+github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
+github.com/tidwall/sjson v1.1.1 h1:7h1vk049Jnd5EH9NyzNiEuwYW4b5qgreBbqRC19AS3U=
+github.com/tidwall/sjson v1.1.1/go.mod h1:yvVuSnpEQv5cYIrO+AT6kw4QVfd5SDZoGIS7/5+fZFs=
 github.com/tulir/go-whatsapp v0.2.0 h1:JWK/Xxrc1qsZsVz6gYVX5AtvzYmqaHNjt34Ipnrgz88=
 github.com/tulir/go-whatsapp v0.2.0/go.mod h1:gyw9zGup1/Y3ZQUueZaqz3iR/WX9a2Lth4aqEbXjkok=
 github.com/tulir/go-whatsapp v0.2.1 h1:Owoss2AbvZMgt3nxoFlsG+bqLHDnO+PhXNhhoCmb/3M=
@@ -42,6 +59,8 @@ github.com/tulir/go-whatsapp v0.2.6 h1:d58cqz/iqcCDeT+uFjLso8oSgMTYqoxGhGhGOyyHB
 github.com/tulir/go-whatsapp v0.2.6/go.mod h1:gyw9zGup1/Y3ZQUueZaqz3iR/WX9a2Lth4aqEbXjkok=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/image v0.0.0-20200430140353-33d19683fad8 h1:6WW6V3x1P/jokJBpRQYUJnMHRP6isStQwCozxnU7XQw=
+golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20200301022130-244492dfa37a h1:GuSPYbZzB5/dcLNCwLQLsg3obCJtX9IJhpXkvY7kzk0=
 golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -52,6 +71,7 @@ golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193
 golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
 gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
@@ -60,5 +80,19 @@ maunium.net/go/maulogger/v2 v2.1.1 h1:NAZNc6XUFJzgzfewCzVoGkxNAsblLCSSEdtDuIjP0X
 maunium.net/go/maulogger/v2 v2.1.1/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
 maunium.net/go/mautrix v0.1.0-beta.2 h1:RxYTqTzW6iXu83gf8ucqGwYx8JLa+a17LWjiPkVV/fU=
 maunium.net/go/mautrix v0.1.0-beta.2/go.mod h1:YFMU9DBeXH7cqx7sJLg0DkVxwNPbih8QbpUTYf/IjMM=
-maunium.net/go/mautrix-appservice v0.1.0-alpha.6 h1:dNE+RykOC0UhSyRNbMHXEk3BzSOp3dj8aQwKuNMELWM=
-maunium.net/go/mautrix-appservice v0.1.0-alpha.6/go.mod h1:Dfiwiuicvn8s2VKrBDrZ9eCjlKUMbuCi91TE6xeEHRM=
+maunium.net/go/mautrix v0.3.6 h1:bXUo8WFdv7sUpvr7jgJ6TVMEQgVHtw1z1T3eUcLpPCA=
+maunium.net/go/mautrix v0.3.6/go.mod h1:SkGZzch8CvU2qKtNpYxtzZ0sQxfVEJ3IsVVLSUBUx9Y=
+maunium.net/go/mautrix v0.3.7 h1:N0czrZeAwjvBrw2a/B2G6U3EwIYaWpt7OuSslGp8DRc=
+maunium.net/go/mautrix v0.3.7/go.mod h1:SkGZzch8CvU2qKtNpYxtzZ0sQxfVEJ3IsVVLSUBUx9Y=
+maunium.net/go/mautrix v0.4.0 h1:IYfmxCoxR/6UMi92IncsSZeKQbZm8Xa35XIRX814KJ4=
+maunium.net/go/mautrix v0.4.0/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
+maunium.net/go/mautrix v0.4.1 h1:i2lJNT+TE4AAL3cVKUN4jKVRkujCE/oS8aIsj8+7iNE=
+maunium.net/go/mautrix v0.4.1/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
+maunium.net/go/mautrix v0.4.2 h1:GBU++Z7o/fLPcEsNMkNOUsnDknwV/MGPQ0BN4ikK6tw=
+maunium.net/go/mautrix v0.4.2/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
+maunium.net/go/mautrix v0.4.3 h1:fVoJy992TjBEvuK5NeO9fpBh+9JuSFsxaEdGjFp/7h4=
+maunium.net/go/mautrix v0.4.3/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
+maunium.net/go/mautrix v0.4.4 h1:C5yYDzUdRtJj/9Vot5YBPQUsWmn19sTySew7f4ACLhM=
+maunium.net/go/mautrix v0.4.4/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
+maunium.net/go/mautrix v0.4.5 h1:cQhlPURW0TGjlqEoac+4+J/aS5/Rg8x1b+fiFZZz6LI=
+maunium.net/go/mautrix v0.4.5/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=

+ 39 - 18
main.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -18,7 +18,6 @@ package main
 
 import (
 	"fmt"
-	"net/http"
 	"os"
 	"os/signal"
 	"sync"
@@ -29,7 +28,9 @@ import (
 	log "maunium.net/go/maulogger/v2"
 
 	"maunium.net/go/mautrix"
-	"maunium.net/go/mautrix-appservice"
+	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/event"
+	"maunium.net/go/mautrix/id"
 
 	"maunium.net/go/mautrix-whatsapp/config"
 	"maunium.net/go/mautrix-whatsapp/database"
@@ -106,29 +107,39 @@ type Bridge struct {
 	Bot            *appservice.IntentAPI
 	Formatter      *Formatter
 	Relaybot       *User
+	Crypto         Crypto
 
-	usersByMXID         map[types.MatrixUserID]*User
+	usersByMXID         map[id.UserID]*User
 	usersByJID          map[types.WhatsAppID]*User
 	usersLock           sync.Mutex
-	managementRooms     map[types.MatrixRoomID]*User
+	managementRooms     map[id.RoomID]*User
 	managementRoomsLock sync.Mutex
-	portalsByMXID       map[types.MatrixRoomID]*Portal
+	portalsByMXID       map[id.RoomID]*Portal
 	portalsByJID        map[database.PortalKey]*Portal
 	portalsLock         sync.Mutex
 	puppets             map[types.WhatsAppID]*Puppet
-	puppetsByCustomMXID map[types.MatrixUserID]*Puppet
+	puppetsByCustomMXID map[id.UserID]*Puppet
 	puppetsLock         sync.Mutex
 }
 
+type Crypto interface {
+	HandleMemberEvent(*event.Event)
+	Decrypt(*event.Event) (*event.Event, error)
+	Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error)
+	Init() error
+	Start()
+	Stop()
+}
+
 func NewBridge() *Bridge {
 	bridge := &Bridge{
-		usersByMXID:         make(map[types.MatrixUserID]*User),
+		usersByMXID:         make(map[id.UserID]*User),
 		usersByJID:          make(map[types.WhatsAppID]*User),
-		managementRooms:     make(map[types.MatrixRoomID]*User),
-		portalsByMXID:       make(map[types.MatrixRoomID]*Portal),
+		managementRooms:     make(map[id.RoomID]*User),
+		portalsByMXID:       make(map[id.RoomID]*Portal),
 		portalsByJID:        make(map[database.PortalKey]*Portal),
 		puppets:             make(map[types.WhatsAppID]*Puppet),
-		puppetsByCustomMXID: make(map[types.MatrixUserID]*Puppet),
+		puppetsByCustomMXID: make(map[id.UserID]*Puppet),
 	}
 
 	var err error
@@ -141,12 +152,8 @@ func NewBridge() *Bridge {
 }
 
 func (bridge *Bridge) ensureConnection() {
-	url := bridge.Bot.BuildURL("account", "whoami")
-	resp := struct {
-		UserID string `json:"user_id"`
-	}{}
 	for {
-		_, err := bridge.Bot.MakeRequest(http.MethodGet, url, nil, &resp)
+		resp, err := bridge.Bot.Whoami()
 		if err != nil {
 			if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.RespError != nil && httpErr.RespError.ErrCode == "M_UNKNOWN_ACCESS_TOKEN" {
 				bridge.Log.Fatalln("Access token invalid. Is the registration installed in your homeserver correctly?")
@@ -219,6 +226,7 @@ func (bridge *Bridge) Init() {
 	bridge.Log.Debugln("Initializing Matrix event handler")
 	bridge.MatrixHandler = NewMatrixHandler(bridge)
 	bridge.Formatter = NewFormatter(bridge)
+	bridge.Crypto = NewCryptoHelper(bridge)
 }
 
 func (bridge *Bridge) Start() {
@@ -227,6 +235,13 @@ func (bridge *Bridge) Start() {
 		bridge.Log.Fatalln("Failed to initialize database:", err)
 		os.Exit(15)
 	}
+	if bridge.Crypto != nil {
+		err := bridge.Crypto.Init()
+		if err != nil {
+			bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
+			os.Exit(19)
+		}
+	}
 	if bridge.Provisioning != nil {
 		bridge.Log.Debugln("Initializing provisioning API")
 		bridge.Provisioning.Init()
@@ -239,6 +254,7 @@ func (bridge *Bridge) Start() {
 	bridge.Log.Debugln("Starting event processor")
 	go bridge.EventProcessor.Start()
 	go bridge.UpdateBotProfile()
+	go bridge.Crypto.Start()
 	go bridge.StartUsers()
 }
 
@@ -262,10 +278,14 @@ func (bridge *Bridge) UpdateBotProfile() {
 	botConfig := bridge.Config.AppService.Bot
 
 	var err error
+	var mxc id.ContentURI
 	if botConfig.Avatar == "remove" {
-		err = bridge.Bot.SetAvatarURL("")
+		err = bridge.Bot.SetAvatarURL(mxc)
 	} else if len(botConfig.Avatar) > 0 {
-		err = bridge.Bot.SetAvatarURL(botConfig.Avatar)
+		mxc, err = id.ParseContentURI(botConfig.Avatar)
+		if err == nil {
+			err = bridge.Bot.SetAvatarURL(mxc)
+		}
 	}
 	if err != nil {
 		bridge.Log.Warnln("Failed to update bot avatar:", err)
@@ -299,6 +319,7 @@ func (bridge *Bridge) StartUsers() {
 }
 
 func (bridge *Bridge) Stop() {
+	bridge.Crypto.Stop()
 	bridge.AS.Stop()
 	bridge.EventProcessor.Stop()
 	for _, user := range bridge.usersByJID {

+ 98 - 51
matrix.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -22,11 +22,10 @@ import (
 
 	"maunium.net/go/maulogger/v2"
 
-	"maunium.net/go/mautrix"
-	"maunium.net/go/mautrix-appservice"
+	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/format"
-
-	"maunium.net/go/mautrix-whatsapp/types"
+	"maunium.net/go/mautrix/id"
 )
 
 type MatrixHandler struct {
@@ -43,17 +42,32 @@ func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
 		log:    bridge.Log.Sub("Matrix"),
 		cmd:    NewCommandHandler(bridge),
 	}
-	bridge.EventProcessor.On(mautrix.EventMessage, handler.HandleMessage)
-	bridge.EventProcessor.On(mautrix.EventSticker, handler.HandleMessage)
-	bridge.EventProcessor.On(mautrix.EventRedaction, handler.HandleRedaction)
-	bridge.EventProcessor.On(mautrix.StateMember, handler.HandleMembership)
-	bridge.EventProcessor.On(mautrix.StateRoomName, handler.HandleRoomMetadata)
-	bridge.EventProcessor.On(mautrix.StateRoomAvatar, handler.HandleRoomMetadata)
-	bridge.EventProcessor.On(mautrix.StateTopic, handler.HandleRoomMetadata)
+	bridge.EventProcessor.On(event.EventMessage, handler.HandleMessage)
+	bridge.EventProcessor.On(event.EventEncrypted, handler.HandleEncrypted)
+	bridge.EventProcessor.On(event.EventSticker, handler.HandleMessage)
+	bridge.EventProcessor.On(event.EventRedaction, handler.HandleRedaction)
+	bridge.EventProcessor.On(event.StateMember, handler.HandleMembership)
+	bridge.EventProcessor.On(event.StateRoomName, handler.HandleRoomMetadata)
+	bridge.EventProcessor.On(event.StateRoomAvatar, handler.HandleRoomMetadata)
+	bridge.EventProcessor.On(event.StateTopic, handler.HandleRoomMetadata)
+	bridge.EventProcessor.On(event.StateEncryption, handler.HandleEncryption)
 	return handler
 }
 
-func (mx *MatrixHandler) HandleBotInvite(evt *mautrix.Event) {
+func (mx *MatrixHandler) HandleEncryption(evt *event.Event) {
+	if evt.Content.AsEncryption().Algorithm != id.AlgorithmMegolmV1 {
+		return
+	}
+	portal := mx.bridge.GetPortalByMXID(evt.RoomID)
+	mx.log.Debugln(portal)
+	if portal != nil && !portal.Encrypted {
+		mx.log.Debugfln("%s enabled encryption in %s", evt.Sender, evt.RoomID)
+		portal.Encrypted = true
+		portal.Update()
+	}
+}
+
+func (mx *MatrixHandler) HandleBotInvite(evt *event.Event) {
 	intent := mx.as.BotIntent()
 
 	user := mx.bridge.GetUserByMXID(evt.Sender)
@@ -61,7 +75,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *mautrix.Event) {
 		return
 	}
 
-	resp, err := intent.JoinRoom(evt.RoomID, "", nil)
+	resp, err := intent.JoinRoomByID(evt.RoomID)
 	if err != nil {
 		mx.log.Debugln("Failed to join room", evt.RoomID, "with invite from", evt.Sender)
 		return
@@ -97,7 +111,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *mautrix.Event) {
 	for mxid, _ := range members.Joined {
 		if mxid == intent.UserID || mxid == evt.Sender {
 			continue
-		} else if _, ok := mx.bridge.ParsePuppetMXID(types.MatrixUserID(mxid)); ok {
+		} else if _, ok := mx.bridge.ParsePuppetMXID(mxid); ok {
 			hasPuppets = true
 			continue
 		}
@@ -108,15 +122,24 @@ func (mx *MatrixHandler) HandleBotInvite(evt *mautrix.Event) {
 	}
 
 	if !hasPuppets {
-		user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
-		user.SetManagementRoom(types.MatrixRoomID(resp.RoomID))
-		intent.SendNotice(string(user.ManagementRoom), "This room has been registered as your bridge management/status room. Send `help` to get a list of commands.")
+		user := mx.bridge.GetUserByMXID(evt.Sender)
+		user.SetManagementRoom(resp.RoomID)
+		intent.SendNotice(user.ManagementRoom, "This room has been registered as your bridge management/status room. Send `help` to get a list of commands.")
 		mx.log.Debugln(resp.RoomID, "registered as a management room with", evt.Sender)
 	}
 }
 
-func (mx *MatrixHandler) HandleMembership(evt *mautrix.Event) {
-	if evt.Content.Membership == "invite" && evt.GetStateKey() == mx.as.BotMXID() {
+func (mx *MatrixHandler) HandleMembership(evt *event.Event) {
+	if _, isPuppet := mx.bridge.ParsePuppetMXID(evt.Sender); evt.Sender == mx.bridge.Bot.UserID || isPuppet {
+		return
+	}
+
+	if mx.bridge.Crypto != nil {
+		mx.bridge.Crypto.HandleMemberEvent(evt)
+	}
+
+	content := evt.Content.AsMember()
+	if content.Membership == event.MembershipInvite && id.UserID(evt.GetStateKey()) == mx.as.BotMXID() {
 		mx.HandleBotInvite(evt)
 	}
 
@@ -125,15 +148,21 @@ func (mx *MatrixHandler) HandleMembership(evt *mautrix.Event) {
 		return
 	}
 
-	user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
+	user := mx.bridge.GetUserByMXID(evt.Sender)
 	if user == nil || !user.Whitelisted || !user.IsConnected() {
 		return
 	}
 
-	if evt.Content.Membership == "leave" {
-		if evt.GetStateKey() == evt.Sender {
-			if portal.IsPrivateChat() || evt.Unsigned.PrevContent.Membership == "join" {
-				portal.HandleMatrixLeave(user)
+	if content.Membership == event.MembershipLeave {
+		if id.UserID(evt.GetStateKey()) == evt.Sender {
+			if evt.Unsigned.PrevContent != nil {
+				_ = evt.Unsigned.PrevContent.ParseRaw(evt.Type)
+				prevContent, ok := evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent)
+				if ok {
+					if portal.IsPrivateChat() || prevContent.Membership == "join" {
+						portal.HandleMatrixLeave(user)
+					}
+				}
 			}
 		} else {
 			portal.HandleMatrixKick(user, evt)
@@ -141,8 +170,8 @@ func (mx *MatrixHandler) HandleMembership(evt *mautrix.Event) {
 	}
 }
 
-func (mx *MatrixHandler) HandleRoomMetadata(evt *mautrix.Event) {
-	user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
+func (mx *MatrixHandler) HandleRoomMetadata(evt *event.Event) {
+	user := mx.bridge.GetUserByMXID(evt.Sender)
 	if user == nil || !user.Whitelisted || !user.IsConnected() {
 		return
 	}
@@ -154,12 +183,12 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *mautrix.Event) {
 
 	var resp <-chan string
 	var err error
-	switch evt.Type {
-	case mautrix.StateRoomName:
-		resp, err = user.Conn.UpdateGroupSubject(evt.Content.Name, portal.Key.JID)
-	case mautrix.StateTopic:
-		resp, err = user.Conn.UpdateGroupDescription(portal.Key.JID, evt.Content.Topic)
-	case mautrix.StateRoomAvatar:
+	switch content := evt.Content.Parsed.(type) {
+	case *event.RoomNameEventContent:
+		resp, err = user.Conn.UpdateGroupSubject(content.Name, portal.Key.JID)
+	case *event.TopicEventContent:
+		resp, err = user.Conn.UpdateGroupDescription(portal.Key.JID, content.Topic)
+	case *event.RoomAvatarEventContent:
 		return
 	}
 	if err != nil {
@@ -170,47 +199,65 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *mautrix.Event) {
 	}
 }
 
-func (mx *MatrixHandler) HandleMessage(evt *mautrix.Event) {
+func (mx *MatrixHandler) shouldIgnoreEvent(evt *event.Event) bool {
 	if _, isPuppet := mx.bridge.ParsePuppetMXID(evt.Sender); evt.Sender == mx.bridge.Bot.UserID || isPuppet {
-		return
+		return true
 	}
 	isCustomPuppet, ok := evt.Content.Raw["net.maunium.whatsapp.puppet"].(bool)
 	if ok && isCustomPuppet && mx.bridge.GetPuppetByCustomMXID(evt.Sender) != nil {
+		return true
+	}
+	user := mx.bridge.GetUserByMXID(evt.Sender)
+	if !user.RelaybotWhitelisted {
+		return true
+	}
+	return false
+}
+
+func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
+	if mx.shouldIgnoreEvent(evt) || mx.bridge.Crypto == nil {
 		return
 	}
 
-	roomID := types.MatrixRoomID(evt.RoomID)
-	user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
+	decrypted, err := mx.bridge.Crypto.Decrypt(evt)
+	if err != nil {
+		mx.log.Warnfln("Failed to decrypt %s: %v", evt.ID, err)
+		return
+	}
+	mx.bridge.EventProcessor.Dispatch(decrypted)
+}
 
-	if !user.RelaybotWhitelisted {
+func (mx *MatrixHandler) HandleMessage(evt *event.Event) {
+	if mx.shouldIgnoreEvent(evt) {
 		return
 	}
 
-	if user.Whitelisted && evt.Content.MsgType == mautrix.MsgText {
+	user := mx.bridge.GetUserByMXID(evt.Sender)
+	content := evt.Content.AsMessage()
+	if user.Whitelisted && content.MsgType == event.MsgText {
 		commandPrefix := mx.bridge.Config.Bridge.CommandPrefix
-		hasCommandPrefix := strings.HasPrefix(evt.Content.Body, commandPrefix)
+		hasCommandPrefix := strings.HasPrefix(content.Body, commandPrefix)
 		if hasCommandPrefix {
-			evt.Content.Body = strings.TrimLeft(evt.Content.Body[len(commandPrefix):], " ")
+			content.Body = strings.TrimLeft(content.Body[len(commandPrefix):], " ")
 		}
-		if hasCommandPrefix || roomID == user.ManagementRoom {
-			mx.cmd.Handle(roomID, user, evt.Content.Body)
+		if hasCommandPrefix || evt.RoomID == user.ManagementRoom {
+			mx.cmd.Handle(evt.RoomID, user, content.Body)
 			return
 		}
 	}
 
-	portal := mx.bridge.GetPortalByMXID(roomID)
+	portal := mx.bridge.GetPortalByMXID(evt.RoomID)
 	if portal != nil && (user.Whitelisted || portal.HasRelaybot()) {
 		portal.HandleMatrixMessage(user, evt)
 	}
 }
 
-func (mx *MatrixHandler) HandleRedaction(evt *mautrix.Event) {
+func (mx *MatrixHandler) HandleRedaction(evt *event.Event) {
 	if _, isPuppet := mx.bridge.ParsePuppetMXID(evt.Sender); evt.Sender == mx.bridge.Bot.UserID || isPuppet {
 		return
 	}
 
-	roomID := types.MatrixRoomID(evt.RoomID)
-	user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
+	user := mx.bridge.GetUserByMXID(evt.Sender)
 
 	if !user.Whitelisted {
 		return
@@ -221,13 +268,13 @@ func (mx *MatrixHandler) HandleRedaction(evt *mautrix.Event) {
 	} else if !user.IsConnected() {
 		msg := format.RenderMarkdown(fmt.Sprintf("[%[1]s](https://matrix.to/#/%[1]s): \u26a0 "+
 			"You are not connected to WhatsApp, so your redaction was not bridged. "+
-			"Use `%[2]s reconnect` to reconnect.", user.MXID, mx.bridge.Config.Bridge.CommandPrefix))
-		msg.MsgType = mautrix.MsgNotice
-		_, _ = mx.bridge.Bot.SendMessageEvent(roomID, mautrix.EventMessage, msg)
+			"Use `%[2]s reconnect` to reconnect.", user.MXID, mx.bridge.Config.Bridge.CommandPrefix), true, false)
+		msg.MsgType = event.MsgNotice
+		_, _ = mx.bridge.Bot.SendMessageEvent(evt.RoomID, event.EventMessage, msg)
 		return
 	}
 
-	portal := mx.bridge.GetPortalByMXID(roomID)
+	portal := mx.bridge.GetPortalByMXID(evt.RoomID)
 	if portal != nil {
 		portal.HandleMatrixRedaction(user, evt)
 	}

+ 38 - 0
no-cgo.go

@@ -0,0 +1,38 @@
+// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
+// Copyright (C) 2020 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+// +build !cgo
+
+package main
+
+import (
+	"image"
+	"io"
+
+	"golang.org/x/image/webp"
+)
+
+func NewCryptoHelper(bridge *Bridge) Crypto {
+	if !bridge.Config.Bridge.Encryption.Allow {
+		bridge.Log.Warnln("Bridge built without end-to-bridge encryption, but encryption is enabled in config")
+	}
+	bridge.Log.Debugln("Bridge built without end-to-bridge encryption")
+	return nil
+}
+
+func decodeWebp(r io.Reader) (image.Image, error) {
+	return webp.Decode(r)
+}

+ 191 - 136
portal.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -20,7 +20,6 @@ import (
 	"bytes"
 	"encoding/gob"
 	"encoding/hex"
-	"encoding/json"
 	"fmt"
 	"html"
 	"image"
@@ -35,22 +34,24 @@ import (
 	"sync"
 	"time"
 
-	"github.com/chai2010/webp"
+	"github.com/pkg/errors"
 	log "maunium.net/go/maulogger/v2"
 
 	"github.com/Rhymen/go-whatsapp"
 	waProto "github.com/Rhymen/go-whatsapp/binary/proto"
 
 	"maunium.net/go/mautrix"
-	"maunium.net/go/mautrix-appservice"
+	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/format"
+	"maunium.net/go/mautrix/id"
 
 	"maunium.net/go/mautrix-whatsapp/database"
 	"maunium.net/go/mautrix-whatsapp/types"
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 )
 
-func (bridge *Bridge) GetPortalByMXID(mxid types.MatrixRoomID) *Portal {
+func (bridge *Bridge) GetPortalByMXID(mxid id.RoomID) *Portal {
 	bridge.portalsLock.Lock()
 	defer bridge.portalsLock.Unlock()
 	portal, ok := bridge.portalsByMXID[mxid]
@@ -233,7 +234,7 @@ func init() {
 	gob.Register(&waProto.Message{})
 }
 
-func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo, mxid types.MatrixEventID) {
+func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo, mxid id.EventID) {
 	msg := portal.bridge.DB.Message.New()
 	msg.Chat = portal.Key
 	msg.JID = message.GetKey().GetId()
@@ -269,7 +270,7 @@ func (portal *Portal) startHandling(info whatsapp.MessageInfo) bool {
 	return true
 }
 
-func (portal *Portal) finishHandling(source *User, message *waProto.WebMessageInfo, mxid types.MatrixEventID) {
+func (portal *Portal) finishHandling(source *User, message *waProto.WebMessageInfo, mxid id.EventID) {
 	portal.markHandled(source, message, mxid)
 	portal.log.Debugln("Handled message", message.GetKey().GetId(), "->", mxid)
 }
@@ -416,7 +417,7 @@ func (portal *Portal) UpdateMetadata(user *User) bool {
 	return update
 }
 
-func (portal *Portal) userMXIDAction(user *User, fn func(mxid types.MatrixUserID)) {
+func (portal *Portal) userMXIDAction(user *User, fn func(mxid id.UserID)) {
 	if user == nil {
 		return
 	}
@@ -430,7 +431,7 @@ func (portal *Portal) userMXIDAction(user *User, fn func(mxid types.MatrixUserID
 	}
 }
 
-func (portal *Portal) ensureMXIDInvited(mxid types.MatrixUserID) {
+func (portal *Portal) ensureMXIDInvited(mxid id.UserID) {
 	err := portal.MainIntent().EnsureInvited(portal.MXID, mxid)
 	if err != nil {
 		portal.log.Warnfln("Failed to ensure %s is invited to %s: %v", mxid, portal.MXID, err)
@@ -481,27 +482,27 @@ func (portal *Portal) Sync(user *User, contact whatsapp.Contact) {
 	}
 }
 
-func (portal *Portal) GetBasePowerLevels() *mautrix.PowerLevels {
+func (portal *Portal) GetBasePowerLevels() *event.PowerLevelsEventContent {
 	anyone := 0
 	nope := 99
 	invite := 99
 	if portal.bridge.Config.Bridge.AllowUserInvite {
 		invite = 0
 	}
-	return &mautrix.PowerLevels{
+	return &event.PowerLevelsEventContent{
 		UsersDefault:    anyone,
 		EventsDefault:   anyone,
 		RedactPtr:       &anyone,
 		StateDefaultPtr: &nope,
 		BanPtr:          &nope,
 		InvitePtr:       &invite,
-		Users: map[string]int{
+		Users: map[id.UserID]int{
 			portal.MainIntent().UserID: 100,
 		},
 		Events: map[string]int{
-			mautrix.StateRoomName.Type:   anyone,
-			mautrix.StateRoomAvatar.Type: anyone,
-			mautrix.StateTopic.Type:      anyone,
+			event.StateRoomName.Type:   anyone,
+			event.StateRoomAvatar.Type: anyone,
+			event.StateTopic.Type:      anyone,
 		},
 	}
 }
@@ -559,9 +560,9 @@ func (portal *Portal) RestrictMetadataChanges(restrict bool) {
 		newLevel = 50
 	}
 	changed := false
-	changed = levels.EnsureEventLevel(mautrix.StateRoomName, newLevel) || changed
-	changed = levels.EnsureEventLevel(mautrix.StateRoomAvatar, newLevel) || changed
-	changed = levels.EnsureEventLevel(mautrix.StateTopic, newLevel) || changed
+	changed = levels.EnsureEventLevel(event.StateRoomName, newLevel) || changed
+	changed = levels.EnsureEventLevel(event.StateRoomAvatar, newLevel) || changed
+	changed = levels.EnsureEventLevel(event.StateTopic, newLevel) || changed
 	if changed {
 		_, err = portal.MainIntent().SetPowerLevels(portal.MXID, levels)
 		if err != nil {
@@ -724,7 +725,6 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
 	portal.log.Infoln("Creating Matrix room. Info source:", user.MXID)
 
 	var metadata *whatsappExt.GroupInfo
-	isPrivateChat := false
 	if portal.IsPrivateChat() {
 		puppet := portal.bridge.GetPuppetByJID(portal.Key.JID)
 		if portal.bridge.Config.Bridge.PrivateChatPortalMeta {
@@ -735,7 +735,6 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
 			portal.Name = ""
 		}
 		portal.Topic = "WhatsApp private chat"
-		isPrivateChat = true
 	} else if portal.IsStatusBroadcastRoom() {
 		portal.Name = "WhatsApp Status Broadcast"
 		portal.Topic = "WhatsApp status updates from your contacts"
@@ -749,33 +748,46 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
 		portal.UpdateAvatar(user, nil)
 	}
 
-	initialState := []*mautrix.Event{{
-		Type: mautrix.StatePowerLevels,
-		Content: mautrix.Content{
-			PowerLevels: portal.GetBasePowerLevels(),
+	initialState := []*event.Event{{
+		Type: event.StatePowerLevels,
+		Content: event.Content{
+			Parsed: portal.GetBasePowerLevels(),
 		},
 	}}
-	if len(portal.AvatarURL) > 0 {
-		initialState = append(initialState, &mautrix.Event{
-			Type: mautrix.StateRoomAvatar,
-			Content: mautrix.Content{
-				URL: portal.AvatarURL,
+	if !portal.AvatarURL.IsEmpty() {
+		initialState = append(initialState, &event.Event{
+			Type: event.StateRoomAvatar,
+			Content: event.Content{
+				Parsed: event.RoomAvatarEventContent{URL: portal.AvatarURL},
 			},
 		})
 	}
 
-	invite := []string{user.MXID}
+	invite := []id.UserID{user.MXID}
 	if user.IsRelaybot {
 		invite = portal.bridge.Config.Bridge.Relaybot.InviteUsers
 	}
 
+	if portal.bridge.Config.Bridge.Encryption.Default {
+		initialState = append(initialState, &event.Event{
+			Type: event.StateEncryption,
+			Content: event.Content{
+				Parsed: event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1},
+			},
+		})
+		portal.Encrypted = true
+		if portal.IsPrivateChat() {
+			invite = append(invite, portal.bridge.Bot.UserID)
+		}
+	}
+
 	resp, err := intent.CreateRoom(&mautrix.ReqCreateRoom{
 		Visibility:   "private",
 		Name:         portal.Name,
 		Topic:        portal.Topic,
 		Invite:       invite,
 		Preset:       "private_chat",
-		IsDirect:     isPrivateChat,
+		IsDirect:     portal.IsPrivateChat(),
 		InitialState: initialState,
 	})
 	if err != nil {
@@ -783,6 +795,12 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
 	}
 	portal.MXID = resp.RoomID
 	portal.Update()
+
+	// We set the memberships beforehand to make sure the encryption key exchange in initial backfill knows the users are here.
+	for _, user := range invite {
+		portal.bridge.StateStore.SetMembership(portal.MXID, user, event.MembershipInvite)
+	}
+
 	if metadata != nil {
 		portal.SyncParticipants(metadata)
 	} else {
@@ -795,6 +813,13 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
 	if portal.IsPrivateChat() {
 		puppet := user.bridge.GetPuppetByJID(portal.Key.JID)
 		user.addPuppetToCommunity(puppet)
+
+		if portal.bridge.Config.Bridge.Encryption.Default {
+			err = portal.bridge.Bot.EnsureJoined(portal.MXID)
+			if err != nil {
+				portal.log.Errorln("Failed to join created portal with bridge bot for e2be:", err)
+			}
+		}
 	}
 	err = portal.FillInitialHistory(user)
 	if err != nil {
@@ -847,19 +872,18 @@ func (portal *Portal) GetMessageIntent(user *User, info whatsapp.MessageInfo) *a
 	return portal.bridge.GetPuppetByJID(info.SenderJid).IntentFor(portal)
 }
 
-func (portal *Portal) SetReply(content *mautrix.Content, info whatsapp.ContextInfo) {
+func (portal *Portal) SetReply(content *event.MessageEventContent, info whatsapp.ContextInfo) {
 	if len(info.QuotedMessageID) == 0 {
 		return
 	}
 	message := portal.bridge.DB.Message.GetByJID(portal.Key, info.QuotedMessageID)
 	if message != nil {
-		event, err := portal.MainIntent().GetEvent(portal.MXID, message.MXID)
+		evt, err := portal.MainIntent().GetEvent(portal.MXID, message.MXID)
 		if err != nil {
 			portal.log.Warnln("Failed to get reply target:", err)
 			return
 		}
-		event.Content.RemoveReplyFallback()
-		content.SetReply(event)
+		content.SetReply(evt)
 	}
 	return
 }
@@ -895,7 +919,14 @@ func (portal *Portal) HandleFakeMessage(source *User, message FakeMessage) {
 		return
 	}
 
-	_, err := portal.MainIntent().SendNotice(portal.MXID, message.Text)
+	content := event.MessageEventContent{
+		MsgType: event.MsgNotice,
+		Body:    message.Text,
+	}
+	if message.Alert {
+		content.MsgType = event.MsgText
+	}
+	_, err := portal.sendMainIntentMessage(content)
 	if err != nil {
 		portal.log.Errorfln("Failed to handle fake message %s: %v", message.ID, err)
 		return
@@ -908,30 +939,30 @@ func (portal *Portal) HandleFakeMessage(source *User, message FakeMessage) {
 	portal.recentlyHandled[index] = message.ID
 }
 
-type MessageContent struct {
-	*mautrix.Content
-	IsCustomPuppet bool `json:"net.maunium.whatsapp.puppet,omitempty"`
+func (portal *Portal) sendMainIntentMessage(content interface{}) (*mautrix.RespSendEvent, error) {
+	return portal.sendMessage(portal.MainIntent(), event.EventMessage, content, 0)
 }
 
-type serializableContent mautrix.Content
-
-type serializableMessageContent struct {
-	*serializableContent
-	IsCustomPuppet bool `json:"net.maunium.whatsapp.puppet,omitempty"`
-}
-
-// Hacky bypass for mautrix.Content's MarshalSJSON
-func (content *MessageContent) MarshalJSON() ([]byte, error) {
-	if mautrix.DisableFancyEventParsing {
-		if content.IsCustomPuppet {
-			content.Raw["net.maunium.whatsapp.puppet"] = content.IsCustomPuppet
+func (portal *Portal) sendMessage(intent *appservice.IntentAPI, eventType event.Type, content interface{}, timestamp int64) (*mautrix.RespSendEvent, error) {
+	wrappedContent := event.Content{Parsed: content}
+	if timestamp != 0 && intent.IsCustomPuppet {
+		wrappedContent.Raw = map[string]interface{}{
+			"net.maunium.whatsapp.puppet": intent.IsCustomPuppet,
 		}
-		return json.Marshal(content.Raw)
 	}
-	return json.Marshal(&serializableMessageContent{
-		serializableContent: (*serializableContent)(content.Content),
-		IsCustomPuppet:      content.IsCustomPuppet,
-	})
+	if portal.Encrypted && portal.bridge.Crypto != nil {
+		encrypted, err := portal.bridge.Crypto.Encrypt(portal.MXID, eventType, wrappedContent)
+		if err != nil {
+			return nil, errors.Wrap(err, "failed to encrypt event")
+		}
+		eventType = event.EventEncrypted
+		wrappedContent.Parsed = encrypted
+	}
+	if timestamp == 0 {
+		return intent.SendMessageEvent(portal.MXID, eventType, &wrappedContent)
+	} else {
+		return intent.SendMassagedMessageEvent(portal.MXID, eventType, &wrappedContent, timestamp)
+	}
 }
 
 func (portal *Portal) HandleTextMessage(source *User, message whatsapp.TextMessage) {
@@ -944,16 +975,16 @@ func (portal *Portal) HandleTextMessage(source *User, message whatsapp.TextMessa
 		return
 	}
 
-	content := &mautrix.Content{
+	content := &event.MessageEventContent{
 		Body:    message.Text,
-		MsgType: mautrix.MsgText,
+		MsgType: event.MsgText,
 	}
 
 	portal.bridge.Formatter.ParseWhatsApp(content)
 	portal.SetReply(content, message.ContextInfo)
 
 	_, _ = intent.UserTyping(portal.MXID, false, 0)
-	resp, err := intent.SendMassagedMessageEvent(portal.MXID, mautrix.EventMessage, &MessageContent{content, intent.IsCustomPuppet}, int64(message.Info.Timestamp*1000))
+	resp, err := portal.sendMessage(intent, event.EventMessage, content, int64(message.Info.Timestamp*1000))
 	if err != nil {
 		portal.log.Errorfln("Failed to handle message %s: %v", message.Info.Id, err)
 		return
@@ -977,7 +1008,10 @@ func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte,
 		return
 	} else if err != nil {
 		portal.log.Errorfln("Failed to download media for %s: %v", info.Id, err)
-		resp, err := portal.MainIntent().SendNotice(portal.MXID, "Failed to bridge media")
+		resp, err := portal.sendMainIntentMessage(event.MessageEventContent{
+			MsgType: event.MsgNotice,
+			Body:    "Failed to bridge media",
+		})
 		if err != nil {
 			portal.log.Errorfln("Failed to send media download error message for %s: %v", info.Id, err)
 		} else {
@@ -988,7 +1022,7 @@ func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte,
 
 	// synapse doesn't handle webp well, so we convert it. This can be dropped once https://github.com/matrix-org/synapse/issues/4382 is fixed
 	if mimeType == "image/webp" {
-		img, err := webp.Decode(bytes.NewReader(data))
+		img, err := decodeWebp(bytes.NewReader(data))
 		if err != nil {
 			portal.log.Errorfln("Failed to decode media for %s: %v", err)
 			return
@@ -1016,10 +1050,10 @@ func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte,
 		fileName += exts[0]
 	}
 
-	content := &mautrix.Content{
+	content := &event.MessageEventContent{
 		Body: fileName,
-		URL:  uploaded.ContentURI,
-		Info: &mautrix.FileInfo{
+		URL:  uploaded.ContentURI.CUString(),
+		Info: &event.FileInfo{
 			Size:     len(data),
 			MimeType: mimeType,
 		},
@@ -1030,9 +1064,9 @@ func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte,
 		thumbnailMime := http.DetectContentType(thumbnail)
 		uploadedThumbnail, _ := intent.UploadBytes(thumbnail, thumbnailMime)
 		if uploadedThumbnail != nil {
-			content.Info.ThumbnailURL = uploadedThumbnail.ContentURI
+			content.Info.ThumbnailURL = uploadedThumbnail.ContentURI.CUString()
 			cfg, _, _ := image.DecodeConfig(bytes.NewReader(data))
-			content.Info.ThumbnailInfo = &mautrix.FileInfo{
+			content.Info.ThumbnailInfo = &event.FileInfo{
 				Size:     len(thumbnail),
 				Width:    cfg.Width,
 				Height:   cfg.Height,
@@ -1044,40 +1078,40 @@ func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte,
 	switch strings.ToLower(strings.Split(mimeType, "/")[0]) {
 	case "image":
 		if !sendAsSticker {
-			content.MsgType = mautrix.MsgImage
+			content.MsgType = event.MsgImage
 		}
 		cfg, _, _ := image.DecodeConfig(bytes.NewReader(data))
 		content.Info.Width = cfg.Width
 		content.Info.Height = cfg.Height
 	case "video":
-		content.MsgType = mautrix.MsgVideo
+		content.MsgType = event.MsgVideo
 	case "audio":
-		content.MsgType = mautrix.MsgAudio
+		content.MsgType = event.MsgAudio
 	default:
-		content.MsgType = mautrix.MsgFile
+		content.MsgType = event.MsgFile
 	}
 
 	_, _ = intent.UserTyping(portal.MXID, false, 0)
 	ts := int64(info.Timestamp * 1000)
-	eventType := mautrix.EventMessage
+	eventType := event.EventMessage
 	if sendAsSticker {
-		eventType = mautrix.EventSticker
+		eventType = event.EventSticker
 	}
-	resp, err := intent.SendMassagedMessageEvent(portal.MXID, eventType, &MessageContent{content, intent.IsCustomPuppet}, ts)
+	resp, err := portal.sendMessage(intent, eventType, content, ts)
 	if err != nil {
 		portal.log.Errorfln("Failed to handle message %s: %v", info.Id, err)
 		return
 	}
 
 	if len(caption) > 0 {
-		captionContent := &mautrix.Content{
+		captionContent := &event.MessageEventContent{
 			Body:    caption,
-			MsgType: mautrix.MsgNotice,
+			MsgType: event.MsgNotice,
 		}
 
 		portal.bridge.Formatter.ParseWhatsApp(captionContent)
 
-		_, err := intent.SendMassagedMessageEvent(portal.MXID, mautrix.EventMessage, &MessageContent{captionContent, intent.IsCustomPuppet}, ts)
+		_, err := portal.sendMessage(intent, event.EventMessage, content, ts)
 		if err != nil {
 			portal.log.Warnfln("Failed to handle caption of message %s: %v", info.Id, err)
 		}
@@ -1094,14 +1128,17 @@ func makeMessageID() *string {
 	return &str
 }
 
-func (portal *Portal) downloadThumbnail(evt *mautrix.Event) []byte {
-	if evt.Content.Info == nil || len(evt.Content.Info.ThumbnailURL) == 0 {
+func (portal *Portal) downloadThumbnail(content *event.MessageEventContent, id id.EventID) []byte {
+	if len(content.GetInfo().ThumbnailURL) == 0 {
 		return nil
 	}
-
-	thumbnail, err := portal.MainIntent().DownloadBytes(evt.Content.Info.ThumbnailURL)
+	mxc, err := content.GetInfo().ThumbnailURL.Parse()
 	if err != nil {
-		portal.log.Errorln("Failed to download thumbnail in %s: %v", evt.ID, err)
+		portal.log.Errorln("Malformed thumbnail URL in %s: %v", id, err)
+	}
+	thumbnail, err := portal.MainIntent().DownloadBytes(mxc)
+	if err != nil {
+		portal.log.Errorln("Failed to download thumbnail in %s: %v", id, err)
 		return nil
 	}
 	thumbnailType := http.DetectContentType(thumbnail)
@@ -1121,30 +1158,44 @@ func (portal *Portal) downloadThumbnail(evt *mautrix.Event) []byte {
 		Quality: jpeg.DefaultQuality,
 	})
 	if err != nil {
-		portal.log.Errorln("Failed to re-encode thumbnail in %s: %v", evt.ID, err)
+		portal.log.Errorln("Failed to re-encode thumbnail in %s: %v", id, err)
 		return nil
 	}
 	return buf.Bytes()
 }
 
-func (portal *Portal) preprocessMatrixMedia(sender *User, relaybotFormatted bool, evt *mautrix.Event, mediaType whatsapp.MediaType) *MediaUpload {
-	if evt.Content.Info == nil {
-		evt.Content.Info = &mautrix.FileInfo{}
-	}
+func (portal *Portal) preprocessMatrixMedia(sender *User, relaybotFormatted bool, content *event.MessageEventContent, eventID id.EventID, mediaType whatsapp.MediaType) *MediaUpload {
 	var caption string
 	if relaybotFormatted {
-		caption = portal.bridge.Formatter.ParseMatrix(evt.Content.FormattedBody)
+		caption = portal.bridge.Formatter.ParseMatrix(content.FormattedBody)
 	}
 
-	content, err := portal.MainIntent().DownloadBytes(evt.Content.URL)
+	var file *event.EncryptedFileInfo
+	rawMXC := content.URL
+	if content.File != nil {
+		file = content.File
+		rawMXC = file.URL
+	}
+	mxc, err := rawMXC.Parse()
+	if err != nil {
+		portal.log.Errorln("Malformed content URL in %s: %v", eventID, err)
+	}
+	data, err := portal.MainIntent().DownloadBytes(mxc)
 	if err != nil {
-		portal.log.Errorfln("Failed to download media in %s: %v", evt.ID, err)
+		portal.log.Errorfln("Failed to download media in %s: %v", eventID, err)
 		return nil
 	}
+	if file != nil {
+		data, err = file.Decrypt(data)
+		if err != nil {
+			portal.log.Errorfln("Failed to decrypt media in %s: %v", eventID, err)
+			return nil
+		}
+	}
 
-	url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := sender.Conn.Upload(bytes.NewReader(content), mediaType)
+	url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := sender.Conn.Upload(bytes.NewReader(data), mediaType)
 	if err != nil {
-		portal.log.Errorfln("Failed to upload media in %s: %v", evt.ID, err)
+		portal.log.Errorfln("Failed to upload media in %s: %v", eventID, err)
 		return nil
 	}
 
@@ -1155,7 +1206,7 @@ func (portal *Portal) preprocessMatrixMedia(sender *User, relaybotFormatted bool
 		FileEncSHA256: fileEncSHA256,
 		FileSHA256:    fileSHA256,
 		FileLength:    fileLength,
-		Thumbnail:     portal.downloadThumbnail(evt),
+		Thumbnail:     portal.downloadThumbnail(content, eventID),
 	}
 }
 
@@ -1169,7 +1220,7 @@ type MediaUpload struct {
 	Thumbnail     []byte
 }
 
-func (portal *Portal) sendMatrixConnectionError(sender *User, eventID string) bool {
+func (portal *Portal) sendMatrixConnectionError(sender *User, eventID id.EventID) bool {
 	if !sender.HasSession() {
 		portal.log.Debugln("Ignoring event", eventID, "from", sender.MXID, "as user has no session")
 		return true
@@ -1183,9 +1234,9 @@ func (portal *Portal) sendMatrixConnectionError(sender *User, eventID string) bo
 		if sender.IsLoginInProgress() {
 			reconnect = "You have a login attempt in progress, please wait."
 		}
-		msg := format.RenderMarkdown("\u26a0 You are not connected to WhatsApp, so your message was not bridged. " + reconnect)
-		msg.MsgType = mautrix.MsgNotice
-		_, err := portal.MainIntent().SendMessageEvent(portal.MXID, mautrix.EventMessage, msg)
+		msg := format.RenderMarkdown("\u26a0 You are not connected to WhatsApp, so your message was not bridged. "+reconnect, true, false)
+		msg.MsgType = event.MsgNotice
+		_, err := portal.sendMainIntentMessage(msg)
 		if err != nil {
 			portal.log.Errorln("Failed to send bridging failure message:", err)
 		}
@@ -1194,30 +1245,34 @@ func (portal *Portal) sendMatrixConnectionError(sender *User, eventID string) bo
 	return false
 }
 
-func (portal *Portal) addRelaybotFormat(user *User, evt *mautrix.Event) bool {
-	member := portal.MainIntent().Member(portal.MXID, evt.Sender)
+func (portal *Portal) addRelaybotFormat(sender *User, content *event.MessageEventContent) bool {
+	member := portal.MainIntent().Member(portal.MXID, sender.MXID)
 	if len(member.Displayname) == 0 {
-		member.Displayname = evt.Sender
+		member.Displayname = string(sender.MXID)
 	}
 
-	if evt.Content.Format != mautrix.FormatHTML {
-		evt.Content.FormattedBody = strings.Replace(html.EscapeString(evt.Content.Body), "\n", "<br/>", -1)
-		evt.Content.Format = mautrix.FormatHTML
+	if content.Format != event.FormatHTML {
+		content.FormattedBody = strings.Replace(html.EscapeString(content.Body), "\n", "<br/>", -1)
+		content.Format = event.FormatHTML
 	}
-	data, err := portal.bridge.Config.Bridge.Relaybot.FormatMessage(evt, member)
+	data, err := portal.bridge.Config.Bridge.Relaybot.FormatMessage(content, sender.MXID, member)
 	if err != nil {
 		portal.log.Errorln("Failed to apply relaybot format:", err)
 	}
-	evt.Content.FormattedBody = data
+	content.FormattedBody = data
 	return true
 }
 
-func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
+func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) {
 	if !portal.HasRelaybot() && (
 		(portal.IsPrivateChat() && sender.JID != portal.Key.Receiver) ||
 			portal.sendMatrixConnectionError(sender, evt.ID)) {
 		return
 	}
+	content := evt.Content.AsMessage()
+	if content == nil {
+		return
+	}
 	portal.log.Debugfln("Received event %s", evt.ID)
 
 	ts := uint64(evt.Timestamp / 1000)
@@ -1234,9 +1289,9 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
 		Status:           &status,
 	}
 	ctxInfo := &waProto.ContextInfo{}
-	replyToID := evt.Content.GetReplyTo()
+	replyToID := content.GetReplyTo()
 	if len(replyToID) > 0 {
-		evt.Content.RemoveReplyFallback()
+		content.RemoveReplyFallback()
 		msg := portal.bridge.DB.Message.GetByMXID(replyToID)
 		if msg != nil && msg.Content != nil {
 			ctxInfo.StanzaId = &msg.JID
@@ -1254,21 +1309,21 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
 				return
 			}
 		} else {
-			relaybotFormatted = portal.addRelaybotFormat(sender, evt)
+			relaybotFormatted = portal.addRelaybotFormat(sender, content)
 			sender = portal.bridge.Relaybot
 		}
 	}
-	if evt.Type == mautrix.EventSticker {
-		evt.Content.MsgType = mautrix.MsgImage
+	if evt.Type == event.EventSticker {
+		content.MsgType = event.MsgImage
 	}
 	var err error
-	switch evt.Content.MsgType {
-	case mautrix.MsgText, mautrix.MsgEmote, mautrix.MsgNotice:
-		text := evt.Content.Body
-		if evt.Content.Format == mautrix.FormatHTML {
-			text = portal.bridge.Formatter.ParseMatrix(evt.Content.FormattedBody)
+	switch content.MsgType {
+	case event.MsgText, event.MsgEmote, event.MsgNotice:
+		text := content.Body
+		if content.Format == event.FormatHTML {
+			text = portal.bridge.Formatter.ParseMatrix(content.FormattedBody)
 		}
-		if evt.Content.MsgType == mautrix.MsgEmote && !relaybotFormatted {
+		if content.MsgType == event.MsgEmote && !relaybotFormatted {
 			text = "/me " + text
 		}
 		ctxInfo.MentionedJid = mentionRegex.FindAllString(text, -1)
@@ -1283,8 +1338,8 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
 		} else {
 			info.Message.Conversation = &text
 		}
-	case mautrix.MsgImage:
-		media := portal.preprocessMatrixMedia(sender, relaybotFormatted, evt, whatsapp.MediaImage)
+	case event.MsgImage:
+		media := portal.preprocessMatrixMedia(sender, relaybotFormatted, content, evt.ID, whatsapp.MediaImage)
 		if media == nil {
 			return
 		}
@@ -1293,53 +1348,53 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
 			JpegThumbnail: media.Thumbnail,
 			Url:           &media.URL,
 			MediaKey:      media.MediaKey,
-			Mimetype:      &evt.Content.GetInfo().MimeType,
+			Mimetype:      &content.GetInfo().MimeType,
 			FileEncSha256: media.FileEncSHA256,
 			FileSha256:    media.FileSHA256,
 			FileLength:    &media.FileLength,
 		}
-	case mautrix.MsgVideo:
-		media := portal.preprocessMatrixMedia(sender, relaybotFormatted, evt, whatsapp.MediaVideo)
+	case event.MsgVideo:
+		media := portal.preprocessMatrixMedia(sender, relaybotFormatted, content, evt.ID, whatsapp.MediaVideo)
 		if media == nil {
 			return
 		}
-		duration := uint32(evt.Content.GetInfo().Duration)
+		duration := uint32(content.GetInfo().Duration)
 		info.Message.VideoMessage = &waProto.VideoMessage{
 			Caption:       &media.Caption,
 			JpegThumbnail: media.Thumbnail,
 			Url:           &media.URL,
 			MediaKey:      media.MediaKey,
-			Mimetype:      &evt.Content.GetInfo().MimeType,
+			Mimetype:      &content.GetInfo().MimeType,
 			Seconds:       &duration,
 			FileEncSha256: media.FileEncSHA256,
 			FileSha256:    media.FileSHA256,
 			FileLength:    &media.FileLength,
 		}
-	case mautrix.MsgAudio:
-		media := portal.preprocessMatrixMedia(sender, relaybotFormatted, evt, whatsapp.MediaAudio)
+	case event.MsgAudio:
+		media := portal.preprocessMatrixMedia(sender, relaybotFormatted, content, evt.ID, whatsapp.MediaAudio)
 		if media == nil {
 			return
 		}
-		duration := uint32(evt.Content.GetInfo().Duration)
+		duration := uint32(content.GetInfo().Duration)
 		info.Message.AudioMessage = &waProto.AudioMessage{
 			Url:           &media.URL,
 			MediaKey:      media.MediaKey,
-			Mimetype:      &evt.Content.GetInfo().MimeType,
+			Mimetype:      &content.GetInfo().MimeType,
 			Seconds:       &duration,
 			FileEncSha256: media.FileEncSHA256,
 			FileSha256:    media.FileSHA256,
 			FileLength:    &media.FileLength,
 		}
-	case mautrix.MsgFile:
-		media := portal.preprocessMatrixMedia(sender, relaybotFormatted, evt, whatsapp.MediaDocument)
+	case event.MsgFile:
+		media := portal.preprocessMatrixMedia(sender, relaybotFormatted, content, evt.ID, whatsapp.MediaDocument)
 		if media == nil {
 			return
 		}
 		info.Message.DocumentMessage = &waProto.DocumentMessage{
 			Url:           &media.URL,
-			FileName:      &evt.Content.Body,
+			FileName:      &content.Body,
 			MediaKey:      media.MediaKey,
-			Mimetype:      &evt.Content.GetInfo().MimeType,
+			Mimetype:      &content.GetInfo().MimeType,
 			FileEncSha256: media.FileEncSHA256,
 			FileSha256:    media.FileSHA256,
 			FileLength:    &media.FileLength,
@@ -1353,9 +1408,9 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
 	_, err = sender.Conn.Send(info)
 	if err != nil {
 		portal.log.Errorfln("Error handling Matrix event %s: %v", evt.ID, err)
-		msg := format.RenderMarkdown(fmt.Sprintf("\u26a0 Your message may not have been bridged: %v", err))
-		msg.MsgType = mautrix.MsgNotice
-		_, err := portal.MainIntent().SendMessageEvent(portal.MXID, mautrix.EventMessage, msg)
+		msg := format.RenderMarkdown(fmt.Sprintf("\u26a0 Your message may not have been bridged: %v", err), false, false)
+		msg.MsgType = event.MsgNotice
+		_, err := portal.sendMainIntentMessage(msg)
 		if err != nil {
 			portal.log.Errorln("Failed to send bridging failure message:", err)
 		}
@@ -1364,7 +1419,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
 	}
 }
 
-func (portal *Portal) HandleMatrixRedaction(sender *User, evt *mautrix.Event) {
+func (portal *Portal) HandleMatrixRedaction(sender *User, evt *event.Event) {
 	if portal.IsPrivateChat() && sender.JID != portal.Key.Receiver {
 		return
 	}
@@ -1462,6 +1517,6 @@ func (portal *Portal) HandleMatrixLeave(sender *User) {
 	}
 }
 
-func (portal *Portal) HandleMatrixKick(sender *User, event *mautrix.Event) {
+func (portal *Portal) HandleMatrixKick(sender *User, event *event.Event) {
 	// TODO
 }

+ 7 - 3
provisioning.go

@@ -26,8 +26,8 @@ import (
 	"github.com/gorilla/websocket"
 	log "maunium.net/go/maulogger/v2"
 
-	"maunium.net/go/mautrix-whatsapp/types"
 	whatsappExt "maunium.net/go/mautrix-whatsapp/whatsapp-ext"
+	"maunium.net/go/mautrix/id"
 )
 
 type ProvisioningAPI struct {
@@ -61,7 +61,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
 			return
 		}
 		userID := r.URL.Query().Get("user_id")
-		user := prov.bridge.GetUserByMXID(types.MatrixUserID(userID))
+		user := prov.bridge.GetUserByMXID(id.UserID(userID))
 		h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "user", user)))
 	})
 }
@@ -292,6 +292,9 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
 	}
 	user.Conn.RemoveHandlers()
 	user.Conn = nil
+	user.removeFromJIDMap()
+	// TODO this causes a foreign key violation, which should be fixed
+	//ce.User.JID = ""
 	user.SetSession(nil)
 	jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
 }
@@ -300,7 +303,7 @@ var upgrader = websocket.Upgrader{}
 
 func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 	userID := r.URL.Query().Get("user_id")
-	user := prov.bridge.GetUserByMXID(types.MatrixUserID(userID))
+	user := prov.bridge.GetUserByMXID(id.UserID(userID))
 
 	c, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
@@ -351,6 +354,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 	}
 	user.ConnectionErrors = 0
 	user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
+	user.addToJIDMap()
 	user.SetSession(&session)
 	_ = c.WriteJSON(map[string]interface{}{
 		"success": true,

+ 16 - 12
puppet.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -25,14 +25,16 @@ import (
 	"github.com/Rhymen/go-whatsapp"
 
 	log "maunium.net/go/maulogger/v2"
-	"maunium.net/go/mautrix-appservice"
+
+	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/id"
 
 	"maunium.net/go/mautrix-whatsapp/database"
 	"maunium.net/go/mautrix-whatsapp/types"
 	"maunium.net/go/mautrix-whatsapp/whatsapp-ext"
 )
 
-func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.WhatsAppID, bool) {
+func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (types.WhatsAppID, bool) {
 	userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$",
 		bridge.Config.Bridge.FormatUsername("([0-9]+)"),
 		bridge.Config.Homeserver.Domain))
@@ -49,7 +51,7 @@ func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.WhatsAppID
 	return jid, true
 }
 
-func (bridge *Bridge) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet {
+func (bridge *Bridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
 	jid, ok := bridge.ParsePuppetMXID(mxid)
 	if !ok {
 		return nil
@@ -78,7 +80,7 @@ func (bridge *Bridge) GetPuppetByJID(jid types.WhatsAppID) *Puppet {
 	return puppet
 }
 
-func (bridge *Bridge) GetPuppetByCustomMXID(mxid types.MatrixUserID) *Puppet {
+func (bridge *Bridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
 	bridge.puppetsLock.Lock()
 	defer bridge.puppetsLock.Unlock()
 	puppet, ok := bridge.puppetsByCustomMXID[mxid]
@@ -129,7 +131,7 @@ func (bridge *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
 		bridge: bridge,
 		log:    bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
 
-		MXID: fmt.Sprintf("@%s:%s",
+		MXID: id.NewUserID(
 			bridge.Config.Bridge.FormatUsername(
 				strings.Replace(
 					dbPuppet.JID,
@@ -144,13 +146,13 @@ type Puppet struct {
 	bridge *Bridge
 	log    log.Logger
 
-	typingIn types.MatrixRoomID
+	typingIn id.RoomID
 	typingAt int64
 
-	MXID types.MatrixUserID
+	MXID id.UserID
 
 	customIntent   *appservice.IntentAPI
-	customTypingIn map[string]bool
+	customTypingIn map[id.RoomID]bool
 	customUser     *User
 }
 
@@ -159,7 +161,9 @@ func (puppet *Puppet) PhoneNumber() string {
 }
 
 func (puppet *Puppet) IntentFor(portal *Portal) *appservice.IntentAPI {
-	if (!portal.IsPrivateChat() && puppet.customIntent == nil) || portal.backfilling || portal.Key.JID == puppet.JID {
+	if (!portal.IsPrivateChat() && puppet.customIntent == nil) ||
+		(portal.backfilling && portal.bridge.Config.Bridge.InviteOwnPuppetForBackfilling) ||
+		portal.Key.JID == puppet.JID {
 		return puppet.DefaultIntent()
 	}
 	return puppet.customIntent
@@ -192,11 +196,11 @@ func (puppet *Puppet) UpdateAvatar(source *User, avatar *whatsappExt.ProfilePicI
 	}
 
 	if len(avatar.URL) == 0 {
-		err := puppet.DefaultIntent().SetAvatarURL("")
+		err := puppet.DefaultIntent().SetAvatarURL(id.ContentURI{})
 		if err != nil {
 			puppet.log.Warnln("Failed to remove avatar:", err)
 		}
-		puppet.AvatarURL = ""
+		puppet.AvatarURL = id.ContentURI{}
 		puppet.Avatar = avatar.Tag
 		go puppet.updatePortalAvatar()
 		return true

+ 0 - 9
types/types.go

@@ -21,12 +21,3 @@ type WhatsAppID = string
 
 // WhatsAppMessageID is the internal ID of a WhatsApp message.
 type WhatsAppMessageID = string
-
-// MatrixUserID is the ID of a Matrix user.
-type MatrixUserID = string
-
-// MatrixRoomID is the internal room ID of a Matrix room.
-type MatrixRoomID = string
-
-// MatrixEventID is the internal ID of a Matrix event.
-type MatrixEventID = string

+ 59 - 37
user.go

@@ -1,5 +1,5 @@
 // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
+// Copyright (C) 2020 Tulir Asokan
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as published by
@@ -32,8 +32,9 @@ import (
 	"github.com/Rhymen/go-whatsapp"
 	waProto "github.com/Rhymen/go-whatsapp/binary/proto"
 
-	"maunium.net/go/mautrix"
+	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/format"
+	"maunium.net/go/mautrix/id"
 
 	"maunium.net/go/mautrix-whatsapp/database"
 	"maunium.net/go/mautrix-whatsapp/types"
@@ -65,7 +66,7 @@ type User struct {
 	syncLock sync.Mutex
 }
 
-func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User {
+func (bridge *Bridge) GetUserByMXID(userID id.UserID) *User {
 	_, isPuppet := bridge.ParsePuppetMXID(userID)
 	if isPuppet || userID == bridge.Bot.UserID {
 		return nil
@@ -89,6 +90,18 @@ func (bridge *Bridge) GetUserByJID(userID types.WhatsAppID) *User {
 	return user
 }
 
+func (user *User) addToJIDMap() {
+	user.bridge.usersLock.Lock()
+	user.bridge.usersByJID[user.JID] = user
+	user.bridge.usersLock.Unlock()
+}
+
+func (user *User) removeFromJIDMap() {
+	user.bridge.usersLock.Lock()
+	delete(user.bridge.usersByJID, user.JID)
+	user.bridge.usersLock.Unlock()
+}
+
 func (bridge *Bridge) GetAllUsers() []*User {
 	bridge.usersLock.Lock()
 	defer bridge.usersLock.Unlock()
@@ -104,7 +117,7 @@ func (bridge *Bridge) GetAllUsers() []*User {
 	return output
 }
 
-func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *types.MatrixUserID) *User {
+func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
 	if dbUser == nil {
 		if mxid == nil {
 			return nil
@@ -160,7 +173,7 @@ func (bridge *Bridge) NewUser(dbUser *database.User) *User {
 	return user
 }
 
-func (user *User) SetManagementRoom(roomID types.MatrixRoomID) {
+func (user *User) SetManagementRoom(roomID id.RoomID) {
 	existingUser, ok := user.bridge.managementRooms[roomID]
 	if ok {
 		existingUser.ManagementRoom = ""
@@ -194,9 +207,9 @@ func (user *User) Connect(evenIfNoSession bool) bool {
 	conn, err := whatsapp.NewConn(timeout * time.Second)
 	if err != nil {
 		user.log.Errorln("Failed to connect to WhatsApp:", err)
-		msg := format.RenderMarkdown("\u26a0 Failed to connect to WhatsApp server. " +
-			"This indicates a network problem on the bridge server. See bridge logs for more info.")
-		_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, mautrix.EventMessage, msg)
+		msg := format.RenderMarkdown("\u26a0 Failed to connect to WhatsApp server. "+
+			"This indicates a network problem on the bridge server. See bridge logs for more info.", true, false)
+		_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, event.EventMessage, msg)
 		return false
 	}
 	user.Conn = whatsappExt.ExtendConn(conn)
@@ -213,9 +226,9 @@ func (user *User) RestoreSession() bool {
 			return true
 		} else if err != nil {
 			user.log.Errorln("Failed to restore session:", err)
-			msg := format.RenderMarkdown("\u26a0 Failed to connect to WhatsApp. Make sure WhatsApp " +
-				"on your phone is reachable and use `reconnect` to try connecting again.")
-			_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, mautrix.EventMessage, msg)
+			msg := format.RenderMarkdown("\u26a0 Failed to connect to WhatsApp. Make sure WhatsApp "+
+				"on your phone is reachable and use `reconnect` to try connecting again.", true, false)
+			_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, event.EventMessage, msg)
 			user.log.Debugln("Disconnecting due to failed session restore...")
 			_, err := user.Conn.Disconnect()
 			if err != nil {
@@ -243,8 +256,8 @@ func (user *User) IsLoginInProgress() bool {
 	return user.Conn != nil && user.Conn.IsLoginInProgress()
 }
 
-func (user *User) loginQrChannel(ce *CommandEvent, qrChan <-chan string, eventIDChan chan<- string) {
-	var qrEventID string
+func (user *User) loginQrChannel(ce *CommandEvent, qrChan <-chan string, eventIDChan chan<- id.EventID) {
+	var qrEventID id.EventID
 	for code := range qrChan {
 		if code == "stop" {
 			return
@@ -274,17 +287,17 @@ func (user *User) loginQrChannel(ce *CommandEvent, qrChan <-chan string, eventID
 			qrEventID = sendResp.EventID
 			eventIDChan <- qrEventID
 		} else {
-			_, err = bot.SendMessageEvent(ce.RoomID, mautrix.EventMessage, &mautrix.Content{
-				MsgType: mautrix.MsgImage,
+			_, err = bot.SendMessageEvent(ce.RoomID, event.EventMessage, &event.MessageEventContent{
+				MsgType: event.MsgImage,
 				Body:    code,
-				URL:     resp.ContentURI,
-				NewContent: &mautrix.Content{
-					MsgType: mautrix.MsgImage,
+				URL:     resp.ContentURI.CUString(),
+				NewContent: &event.MessageEventContent{
+					MsgType: event.MsgImage,
 					Body:    code,
-					URL:     resp.ContentURI,
+					URL:     resp.ContentURI.CUString(),
 				},
-				RelatesTo: &mautrix.RelatesTo{
-					Type:    mautrix.RelReplace,
+				RelatesTo: &event.RelatesTo{
+					Type:    event.RelReplace,
 					EventID: qrEventID,
 				},
 			})
@@ -297,18 +310,18 @@ func (user *User) loginQrChannel(ce *CommandEvent, qrChan <-chan string, eventID
 
 func (user *User) Login(ce *CommandEvent) {
 	qrChan := make(chan string, 3)
-	eventIDChan := make(chan string, 1)
+	eventIDChan := make(chan id.EventID, 1)
 	go user.loginQrChannel(ce, qrChan, eventIDChan)
 	session, err := user.Conn.LoginWithRetry(qrChan, user.bridge.Config.Bridge.LoginQRRegenCount)
 	qrChan <- "stop"
 	if err != nil {
-		var eventID string
+		var eventID id.EventID
 		select {
 		case eventID = <-eventIDChan:
 		default:
 		}
-		reply := mautrix.Content{
-			MsgType: mautrix.MsgText,
+		reply := event.MessageEventContent{
+			MsgType: event.MsgText,
 		}
 		if err == whatsapp.ErrAlreadyLoggedIn {
 			reply.Body = "You're already logged in"
@@ -323,16 +336,19 @@ func (user *User) Login(ce *CommandEvent) {
 		msg := reply
 		if eventID != "" {
 			msg.NewContent = &reply
-			msg.RelatesTo = &mautrix.RelatesTo{
-				Type:    mautrix.RelReplace,
+			msg.RelatesTo = &event.RelatesTo{
+				Type:    event.RelReplace,
 				EventID: eventID,
 			}
 		}
-		_, _ = ce.Bot.SendMessageEvent(ce.RoomID, mautrix.EventMessage, &msg)
+		_, _ = ce.Bot.SendMessageEvent(ce.RoomID, event.EventMessage, &msg)
 		return
 	}
+	// TODO there's a bit of duplication between this and the provisioning API login method
+	//      Also between the two logout methods (commands.go and provisioning.go)
 	user.ConnectionErrors = 0
 	user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
+	user.addToJIDMap()
 	user.SetSession(&session)
 	ce.Reply("Successfully logged in, synchronizing chats...")
 	user.PostLogin()
@@ -365,8 +381,11 @@ func (user *User) PostLogin() {
 }
 
 func (user *User) tryAutomaticDoublePuppeting() {
-	if len(user.bridge.Config.Bridge.LoginSharedSecret) == 0 || !strings.HasSuffix(user.MXID, user.bridge.Config.Homeserver.Domain) {
-		// Automatic login not enabled or user is on another homeserver
+	if len(user.bridge.Config.Bridge.LoginSharedSecret) == 0 {
+		// Automatic login not enabled
+		return
+	} else if _, homeserver, _ := user.MXID.Parse(); homeserver != user.bridge.Config.Homeserver.Domain {
+		// user is on another homeserver
 		return
 	}
 
@@ -535,8 +554,8 @@ func (user *User) HandleError(err error) {
 
 func (user *User) tryReconnect(msg string) {
 	if user.ConnectionErrors > user.bridge.Config.Bridge.MaxConnectionAttempts {
-		content := format.RenderMarkdown(fmt.Sprintf("%s. Use the `reconnect` command to reconnect.", msg))
-		_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, mautrix.EventMessage, content)
+		content := format.RenderMarkdown(fmt.Sprintf("%s. Use the `reconnect` command to reconnect.", msg), true, false)
+		_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, event.EventMessage, content)
 		return
 	}
 	if user.bridge.Config.Bridge.ReportConnectionRetry {
@@ -591,8 +610,8 @@ func (user *User) tryReconnect(msg string) {
 			"Use the `reconnect` command to try to reconnect.", msg, tries)
 	}
 
-	content := format.RenderMarkdown(msg)
-	_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, mautrix.EventMessage, content)
+	content := format.RenderMarkdown(msg, true, false)
+	_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, event.EventMessage, content)
 }
 
 func (user *User) ShouldCallSynchronously() bool {
@@ -656,8 +675,9 @@ func (user *User) HandleMessageRevoke(message whatsappExt.MessageRevocation) {
 }
 
 type FakeMessage struct {
-	Text string
-	ID   string
+	Text  string
+	ID    string
+	Alert bool
 }
 
 func (user *User) HandleCallInfo(info whatsappExt.CallInfo) {
@@ -673,11 +693,13 @@ func (user *User) HandleCallInfo(info whatsappExt.CallInfo) {
 			return
 		}
 		data.Text = "Incoming call"
+		data.Alert = true
 	case whatsappExt.CallOfferVideo:
 		if !user.bridge.Config.Bridge.CallNotices.Start {
 			return
 		}
 		data.Text = "Incoming video call"
+		data.Alert = true
 	case whatsappExt.CallTerminate:
 		if !user.bridge.Config.Bridge.CallNotices.End {
 			return
@@ -766,7 +788,7 @@ func (user *User) HandleCommand(cmd whatsappExt.Command) {
 				"Use the `reconnect` command to reconnect.", cmd.Kind)
 		}
 		user.cleanDisconnection = true
-		go user.bridge.Bot.SendMessageEvent(user.ManagementRoom, mautrix.EventMessage, format.RenderMarkdown(msg))
+		go user.bridge.Bot.SendMessageEvent(user.ManagementRoom, event.EventMessage, format.RenderMarkdown(msg, true, false))
 	}
 }
 

+ 14 - 0
webp.go

@@ -0,0 +1,14 @@
+// +build cgo
+
+package main
+
+import (
+	"image"
+	"io"
+
+	"github.com/chai2010/webp"
+)
+
+func decodeWebp(r io.Reader) (image.Image, error) {
+	return webp.Decode(r)
+}