فهرست منبع

Move a bunch of stuff to mautrix-go

See https://github.com/mautrix/go/commit/d578d1a610d57a90277aa85e910c48a6627f115d

Database upgrades from before v0.4.0 were squashed, users must update
to at least v0.4.0 before updating beyond this commit.
Tulir Asokan 3 سال پیش
والد
کامیت
a948ea0146
83فایلهای تغییر یافته به همراه626 افزوده شده و 2837 حذف شده
  1. 8 8
      bridgestate.go
  2. 7 12
      commands.go
  3. 14 13
      config/bridge.go
  4. 2 85
      config/config.go
  5. 0 82
      config/registration.go
  6. 34 77
      config/upgrade.go
  7. 0 327
      crypto.go
  8. 11 11
      custompuppet.go
  9. 4 2
      database/backfill.go
  10. 7 48
      database/cryptostore.go
  11. 16 55
      database/database.go
  12. 2 1
      database/disappearingmessage.go
  13. 4 2
      database/historysync.go
  14. 3 1
      database/mediabackfillrequest.go
  15. 2 1
      database/message.go
  16. 2 1
      database/portal.go
  17. 3 1
      database/puppet.go
  18. 2 1
      database/reaction.go
  19. 0 282
      database/statestore.go
  20. 181 0
      database/upgrades/00-latest-revision.sql
  21. 0 67
      database/upgrades/2018-09-01-initial-schema.go
  22. 0 15
      database/upgrades/2019-05-21-message-timestamp-column.go
  23. 0 15
      database/upgrades/2019-05-22-user-last-connection-column.go
  24. 0 23
      database/upgrades/2019-05-23-puppet-custom-mxid-columns.go
  25. 0 19
      database/upgrades/2019-05-28-user-portal-table.go
  26. 0 19
      database/upgrades/2019-06-01-avatar-url-fields.go
  27. 0 12
      database/upgrades/2019-08-10-portal-in-community-field.go
  28. 0 39
      database/upgrades/2019-08-25-move-state-store-to-db.go
  29. 0 16
      database/upgrades/2019-11-10-full-member-state-store.go
  30. 0 16
      database/upgrades/2019-11-12-fix-room-topic-length.go
  31. 0 12
      database/upgrades/2020-05-09-add-portal-encrypted-field.go
  32. 0 73
      database/upgrades/2020-05-09-crypto-store.go
  33. 0 25
      database/upgrades/2020-05-12-outbound-group-session-store.go
  34. 0 12
      database/upgrades/2020-07-10-custom-puppet-presence-toggle.go
  35. 0 13
      database/upgrades/2020-07-10-update-crypto-store.go
  36. 0 12
      database/upgrades/2020-07-10-x-custom-puppet-receipts-toggle.go
  37. 0 13
      database/upgrades/2020-08-03-update-crypto-store.go
  38. 0 13
      database/upgrades/2020-10-28-crypto-store-cross-signing.go
  39. 0 12
      database/upgrades/2021-02-17-message-sent-status.go
  40. 0 44
      database/upgrades/2021-08-19-remove-message-content.go
  41. 0 13
      database/upgrades/2021-08-19-varchar-to-text-crypto.go
  42. 0 48
      database/upgrades/2021-08-19-varchar-to-text.go
  43. 0 13
      database/upgrades/2021-10-21-add-whatsmeow-store.go
  44. 0 93
      database/upgrades/2021-10-21-multidevice-updates.go
  45. 0 19
      database/upgrades/2021-10-26-portal-origin-event-id.go
  46. 0 12
      database/upgrades/2021-10-27-message-decryption-errors.go
  47. 0 12
      database/upgrades/2021-10-28-portal-relay-user.go
  48. 0 22
      database/upgrades/2021-10-30-varchar-to-text-state-store.go
  49. 0 22
      database/upgrades/2021-11-30-store-last-read-state.go
  50. 0 13
      database/upgrades/2021-12-22-crypto-store-last-used.go
  51. 0 12
      database/upgrades/2021-12-25-broadcast-list-message-source.go
  52. 0 16
      database/upgrades/2021-12-29-personal-filtering-spaces.go
  53. 0 20
      database/upgrades/2022-01-07-disappearing-messages.go
  54. 0 10
      database/upgrades/2022-01-24-phone-last-seen-ts.go
  55. 0 30
      database/upgrades/2022-02-10-message-error-string.go
  56. 0 10
      database/upgrades/2022-02-18-phone-ping-ts.go
  57. 0 39
      database/upgrades/2022-03-05-reactions.go
  58. 0 45
      database/upgrades/2022-03-15-prioritized-backfill.go
  59. 0 52
      database/upgrades/2022-03-18-historysync-store.go
  60. 0 20
      database/upgrades/2022-04-29-backfillqueue-type-order.go
  61. 0 26
      database/upgrades/2022-05-09-media-backfill-requests-queue-table.go
  62. 0 12
      database/upgrades/2022-05-11-add-user-timezone.go
  63. 0 34
      database/upgrades/2022-05-12-backfillqueue-dispatch-time.go
  64. 0 16
      database/upgrades/2022-05-12-history-sync-message-add-added-timestamp.go
  65. 0 25
      database/upgrades/2022-05-16-room-backfill-state.go
  66. 5 0
      database/upgrades/45-backfillqueue-dispatch-time.sql
  67. 3 0
      database/upgrades/46-history-sync-message-added-timestamp.sql
  68. 13 0
      database/upgrades/47-room-backfill-state.sql
  69. 7 0
      database/upgrades/48-crypto-store-handling-split.sql
  70. 16 169
      database/upgrades/upgrades.go
  71. 2 1
      database/user.go
  72. 3 3
      disappear.go
  73. 8 8
      example-config.yaml
  74. 2 2
      formatting.go
  75. 3 4
      go.mod
  76. 4 5
      go.sum
  77. 119 394
      main.go
  78. 4 3
      matrix.go
  79. 0 17
      no-crypto.go
  80. 47 38
      portal.go
  81. 4 4
      provisioning.go
  82. 42 42
      puppet.go
  83. 42 33
      user.go

+ 8 - 8
bridgestate.go

@@ -114,18 +114,18 @@ func (pong *BridgeState) shouldDeduplicate(newPong *BridgeState) bool {
 	return pong.Timestamp+int64(pong.TTL/5) > time.Now().Unix()
 	return pong.Timestamp+int64(pong.TTL/5) > time.Now().Unix()
 }
 }
 
 
-func (bridge *Bridge) sendBridgeState(ctx context.Context, state *BridgeState) error {
+func (br *WABridge) sendBridgeState(ctx context.Context, state *BridgeState) error {
 	var body bytes.Buffer
 	var body bytes.Buffer
 	if err := json.NewEncoder(&body).Encode(&state); err != nil {
 	if err := json.NewEncoder(&body).Encode(&state); err != nil {
 		return fmt.Errorf("failed to encode bridge state JSON: %w", err)
 		return fmt.Errorf("failed to encode bridge state JSON: %w", err)
 	}
 	}
 
 
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, bridge.Config.Homeserver.StatusEndpoint, &body)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, br.Config.Homeserver.StatusEndpoint, &body)
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("failed to prepare request: %w", err)
 		return fmt.Errorf("failed to prepare request: %w", err)
 	}
 	}
 
 
-	req.Header.Set("Authorization", "Bearer "+bridge.Config.AppService.ASToken)
+	req.Header.Set("Authorization", "Bearer "+br.Config.AppService.ASToken)
 	req.Header.Set("Content-Type", "application/json")
 	req.Header.Set("Content-Type", "application/json")
 
 
 	resp, err := http.DefaultClient.Do(req)
 	resp, err := http.DefaultClient.Do(req)
@@ -143,17 +143,17 @@ func (bridge *Bridge) sendBridgeState(ctx context.Context, state *BridgeState) e
 	return nil
 	return nil
 }
 }
 
 
-func (bridge *Bridge) sendGlobalBridgeState(state BridgeState) {
-	if len(bridge.Config.Homeserver.StatusEndpoint) == 0 {
+func (br *WABridge) sendGlobalBridgeState(state BridgeState) {
+	if len(br.Config.Homeserver.StatusEndpoint) == 0 {
 		return
 		return
 	}
 	}
 
 
 	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()
 	defer cancel()
-	if err := bridge.sendBridgeState(ctx, &state); err != nil {
-		bridge.Log.Warnln("Failed to update global bridge state:", err)
+	if err := br.sendBridgeState(ctx, &state); err != nil {
+		br.Log.Warnln("Failed to update global bridge state:", err)
 	} else {
 	} else {
-		bridge.Log.Debugfln("Sent new global bridge state %+v", state)
+		br.Log.Debugfln("Sent new global bridge state %+v", state)
 	}
 	}
 }
 }
 
 

+ 7 - 12
commands.go

@@ -39,6 +39,7 @@ import (
 
 
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix/appservice"
 	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/bridge"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/format"
 	"maunium.net/go/mautrix/format"
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
@@ -47,12 +48,12 @@ import (
 )
 )
 
 
 type CommandHandler struct {
 type CommandHandler struct {
-	bridge *Bridge
+	bridge *WABridge
 	log    maulogger.Logger
 	log    maulogger.Logger
 }
 }
 
 
 // NewCommandHandler creates a CommandHandler
 // NewCommandHandler creates a CommandHandler
-func NewCommandHandler(bridge *Bridge) *CommandHandler {
+func NewCommandHandler(bridge *WABridge) *CommandHandler {
 	return &CommandHandler{
 	return &CommandHandler{
 		bridge: bridge,
 		bridge: bridge,
 		log:    bridge.Log.Sub("Command handler"),
 		log:    bridge.Log.Sub("Command handler"),
@@ -62,7 +63,7 @@ func NewCommandHandler(bridge *Bridge) *CommandHandler {
 // CommandEvent stores all data which might be used to handle commands
 // CommandEvent stores all data which might be used to handle commands
 type CommandEvent struct {
 type CommandEvent struct {
 	Bot     *appservice.IntentAPI
 	Bot     *appservice.IntentAPI
-	Bridge  *Bridge
+	Bridge  *WABridge
 	Portal  *Portal
 	Portal  *Portal
 	Handler *CommandHandler
 	Handler *CommandHandler
 	RoomID  id.RoomID
 	RoomID  id.RoomID
@@ -251,13 +252,7 @@ func (handler *CommandHandler) CommandDevTest(_ *CommandEvent) {
 const cmdVersionHelp = `version - View the bridge version`
 const cmdVersionHelp = `version - View the bridge version`
 
 
 func (handler *CommandHandler) CommandVersion(ce *CommandEvent) {
 func (handler *CommandHandler) CommandVersion(ce *CommandEvent) {
-	linkifiedVersion := fmt.Sprintf("v%s", Version)
-	if Tag == Version {
-		linkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", Version, URL, Tag)
-	} else if len(Commit) > 8 {
-		linkifiedVersion = strings.Replace(linkifiedVersion, Commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", Commit[:8], URL, Commit), 1)
-	}
-	ce.Reply(fmt.Sprintf("[%s](%s) %s (%s)", Name, URL, linkifiedVersion, BuildTime))
+	ce.Reply(fmt.Sprintf("[%s](%s) %s (%s)", ce.Bridge.Name, ce.Bridge.URL, ce.Bridge.LinkifiedVersion, BuildTime))
 }
 }
 
 
 const cmdInviteLinkHelp = `invite-link [--reset] - Get an invite link to the current group chat, optionally regenerating the link and revoking the old link.`
 const cmdInviteLinkHelp = `invite-link [--reset] - Get an invite link to the current group chat, optionally regenerating the link and revoking the old link.`
@@ -331,7 +326,7 @@ func (handler *CommandHandler) CommandJoin(ce *CommandEvent) {
 	ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid)
 	ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid)
 }
 }
 
 
-func tryDecryptEvent(crypto Crypto, evt *event.Event) (json.RawMessage, error) {
+func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, error) {
 	var data json.RawMessage
 	var data json.RawMessage
 	if evt.Type != event.EventEncrypted {
 	if evt.Type != event.EventEncrypted {
 		data = evt.Content.VeryRaw
 		data = evt.Content.VeryRaw
@@ -903,7 +898,7 @@ func matchesQuery(str string, query string) bool {
 	return strings.Contains(strings.ToLower(str), query)
 	return strings.Contains(strings.ToLower(str), query)
 }
 }
 
 
-func formatContacts(bridge *Bridge, input map[types.JID]types.ContactInfo, query string) (result []string) {
+func formatContacts(bridge *WABridge, input map[types.JID]types.ContactInfo, query string) (result []string) {
 	hasQuery := len(query) > 0
 	hasQuery := len(query) > 0
 	for jid, contact := range input {
 	for jid, contact := range input {
 		if len(contact.FullName) == 0 {
 		if len(contact.FullName) == 0 {

+ 14 - 13
config/bridge.go

@@ -24,6 +24,7 @@ import (
 
 
 	"go.mau.fi/whatsmeow/types"
 	"go.mau.fi/whatsmeow/types"
 
 
+	"maunium.net/go/mautrix/bridge/bridgeconfig"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
 )
 )
@@ -118,23 +119,23 @@ type BridgeConfig struct {
 		AdditionalHelp     string `yaml:"additional_help"`
 		AdditionalHelp     string `yaml:"additional_help"`
 	} `yaml:"management_room_text"`
 	} `yaml:"management_room_text"`
 
 
-	Encryption struct {
-		Allow   bool `yaml:"allow"`
-		Default bool `yaml:"default"`
+	Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"`
 
 
-		KeySharing struct {
-			Allow               bool `yaml:"allow"`
-			RequireCrossSigning bool `yaml:"require_cross_signing"`
-			RequireVerification bool `yaml:"require_verification"`
-		} `yaml:"key_sharing"`
-	} `yaml:"encryption"`
+	Provisioning struct {
+		Prefix       string `yaml:"prefix"`
+		SharedSecret string `yaml:"shared_secret"`
+	} `yaml:"provisioning"`
 
 
 	Permissions PermissionConfig `yaml:"permissions"`
 	Permissions PermissionConfig `yaml:"permissions"`
 
 
 	Relay RelaybotConfig `yaml:"relay"`
 	Relay RelaybotConfig `yaml:"relay"`
 
 
-	usernameTemplate    *template.Template `yaml:"-"`
-	displaynameTemplate *template.Template `yaml:"-"`
+	ParsedUsernameTemplate *template.Template `yaml:"-"`
+	displaynameTemplate    *template.Template `yaml:"-"`
+}
+
+func (bc BridgeConfig) GetEncryptionConfig() bridgeconfig.EncryptionConfig {
+	return bc.Encryption
 }
 }
 
 
 type umBridgeConfig BridgeConfig
 type umBridgeConfig BridgeConfig
@@ -145,7 +146,7 @@ func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
 		return err
 		return err
 	}
 	}
 
 
-	bc.usernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate)
+	bc.ParsedUsernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	} else if !strings.Contains(bc.FormatUsername("1234567890"), "1234567890") {
 	} else if !strings.Contains(bc.FormatUsername("1234567890"), "1234567890") {
@@ -206,7 +207,7 @@ func (bc BridgeConfig) FormatDisplayname(jid types.JID, contact types.ContactInf
 
 
 func (bc BridgeConfig) FormatUsername(username string) string {
 func (bc BridgeConfig) FormatUsername(username string) string {
 	var buf strings.Builder
 	var buf strings.Builder
-	_ = bc.usernameTemplate.Execute(&buf, username)
+	_ = bc.ParsedUsernameTemplate.Execute(&buf, username)
 	return buf.String()
 	return buf.String()
 }
 }
 
 

+ 2 - 85
config/config.go

@@ -17,52 +17,12 @@
 package config
 package config
 
 
 import (
 import (
-	"fmt"
-
-	"gopkg.in/yaml.v3"
-
-	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/bridge/bridgeconfig"
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
 )
 )
 
 
-var ExampleConfig string
-
 type Config struct {
 type Config struct {
-	Homeserver struct {
-		Address                       string `yaml:"address"`
-		Domain                        string `yaml:"domain"`
-		Asmux                         bool   `yaml:"asmux"`
-		StatusEndpoint                string `yaml:"status_endpoint"`
-		MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"`
-		AsyncMedia                    bool   `yaml:"async_media"`
-	} `yaml:"homeserver"`
-
-	AppService struct {
-		Address  string `yaml:"address"`
-		Hostname string `yaml:"hostname"`
-		Port     uint16 `yaml:"port"`
-
-		Database DatabaseConfig `yaml:"database"`
-
-		Provisioning struct {
-			Prefix       string `yaml:"prefix"`
-			SharedSecret string `yaml:"shared_secret"`
-		} `yaml:"provisioning"`
-
-		ID  string `yaml:"id"`
-		Bot struct {
-			Username    string `yaml:"username"`
-			Displayname string `yaml:"displayname"`
-			Avatar      string `yaml:"avatar"`
-
-			ParsedAvatar id.ContentURI `yaml:"-"`
-		} `yaml:"bot"`
-
-		EphemeralEvents bool `yaml:"ephemeral_events"`
-
-		ASToken string `yaml:"as_token"`
-		HSToken string `yaml:"hs_token"`
-	} `yaml:"appservice"`
+	*bridgeconfig.BaseConfig `yaml:",inline"`
 
 
 	SegmentKey string `yaml:"segment_key"`
 	SegmentKey string `yaml:"segment_key"`
 
 
@@ -77,8 +37,6 @@ type Config struct {
 	} `yaml:"whatsapp"`
 	} `yaml:"whatsapp"`
 
 
 	Bridge BridgeConfig `yaml:"bridge"`
 	Bridge BridgeConfig `yaml:"bridge"`
-
-	Logging appservice.LogConfig `yaml:"logging"`
 }
 }
 
 
 func (config *Config) CanAutoDoublePuppet(userID id.UserID) bool {
 func (config *Config) CanAutoDoublePuppet(userID id.UserID) bool {
@@ -98,44 +56,3 @@ func (config *Config) CanDoublePuppetBackfill(userID id.UserID) bool {
 	}
 	}
 	return true
 	return true
 }
 }
-
-func Load(data []byte, upgraded bool) (*Config, error) {
-	var config = &Config{}
-	if !upgraded {
-		// Fallback: if config upgrading failed, load example config for base values
-		err := yaml.Unmarshal([]byte(ExampleConfig), config)
-		if err != nil {
-			return nil, fmt.Errorf("failed to unmarshal example config: %w", err)
-		}
-	}
-	err := yaml.Unmarshal(data, config)
-	if err != nil {
-		return nil, err
-	}
-
-	return config, err
-}
-
-func (config *Config) MakeAppService() (*appservice.AppService, error) {
-	as := appservice.Create()
-	as.HomeserverDomain = config.Homeserver.Domain
-	as.HomeserverURL = config.Homeserver.Address
-	as.Host.Hostname = config.AppService.Hostname
-	as.Host.Port = config.AppService.Port
-	as.MessageSendCheckpointEndpoint = config.Homeserver.MessageSendCheckpointEndpoint
-	as.DefaultHTTPRetries = 4
-	var err error
-	as.Registration, err = config.GetRegistration()
-	return as, err
-}
-
-type DatabaseConfig struct {
-	Type string `yaml:"type"`
-	URI  string `yaml:"uri"`
-
-	MaxOpenConns int `yaml:"max_open_conns"`
-	MaxIdleConns int `yaml:"max_idle_conns"`
-
-	ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
-	ConnMaxLifetime string `yaml:"conn_max_lifetime"`
-}

+ 0 - 82
config/registration.go

@@ -1,82 +0,0 @@
-// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2019 Tulir Asokan
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program.  If not, see <https://www.gnu.org/licenses/>.
-
-package config
-
-import (
-	"fmt"
-	"regexp"
-	"strings"
-
-	"maunium.net/go/mautrix/appservice"
-)
-
-func (config *Config) NewRegistration() (*appservice.Registration, error) {
-	registration := appservice.CreateRegistration()
-
-	err := config.copyToRegistration(registration)
-	if err != nil {
-		return nil, err
-	}
-
-	config.AppService.ASToken = registration.AppToken
-	config.AppService.HSToken = registration.ServerToken
-
-	// Workaround for https://github.com/matrix-org/synapse/pull/5758
-	registration.SenderLocalpart = appservice.RandomString(32)
-	botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
-		regexp.QuoteMeta(config.AppService.Bot.Username),
-		regexp.QuoteMeta(config.Homeserver.Domain)))
-	registration.Namespaces.RegisterUserIDs(botRegex, true)
-
-	return registration, nil
-}
-
-func (config *Config) GetRegistration() (*appservice.Registration, error) {
-	registration := appservice.CreateRegistration()
-
-	err := config.copyToRegistration(registration)
-	if err != nil {
-		return nil, err
-	}
-
-	registration.AppToken = config.AppService.ASToken
-	registration.ServerToken = config.AppService.HSToken
-	return registration, nil
-}
-
-func (config *Config) copyToRegistration(registration *appservice.Registration) error {
-	registration.ID = config.AppService.ID
-	registration.URL = config.AppService.Address
-	falseVal := false
-	registration.RateLimited = &falseVal
-	registration.SenderLocalpart = config.AppService.Bot.Username
-	registration.EphemeralEvents = config.AppService.EphemeralEvents
-
-	usernamePlaceholder := appservice.RandomString(16)
-	usernameTemplate := fmt.Sprintf("@%s:%s",
-		config.Bridge.FormatUsername(usernamePlaceholder),
-		config.Homeserver.Domain)
-	usernameTemplate = regexp.QuoteMeta(usernameTemplate)
-	usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, "[0-9]+", 1)
-	usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate)
-	userIDRegex, err := regexp.Compile(usernameTemplate)
-	if err != nil {
-		return err
-	}
-	registration.Namespaces.RegisterUserIDs(userIDRegex, true)
-	return nil
-}

+ 34 - 77
config/upgrade.go

@@ -20,50 +20,12 @@ import (
 	"strings"
 	"strings"
 
 
 	"maunium.net/go/mautrix/appservice"
 	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/bridge/bridgeconfig"
 	up "maunium.net/go/mautrix/util/configupgrade"
 	up "maunium.net/go/mautrix/util/configupgrade"
 )
 )
 
 
-type waUpgrader struct{}
-
-func (wau waUpgrader) GetBase() string {
-	return ExampleConfig
-}
-
-func (wau waUpgrader) DoUpgrade(helper *up.Helper) {
-	helper.Copy(up.Str, "homeserver", "address")
-	helper.Copy(up.Str, "homeserver", "domain")
-	helper.Copy(up.Bool, "homeserver", "asmux")
-	helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint")
-	helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint")
-	helper.Copy(up.Bool, "homeserver", "async_media")
-
-	helper.Copy(up.Str, "appservice", "address")
-	helper.Copy(up.Str, "appservice", "hostname")
-	helper.Copy(up.Int, "appservice", "port")
-	helper.Copy(up.Str, "appservice", "database", "type")
-	helper.Copy(up.Str, "appservice", "database", "uri")
-	helper.Copy(up.Int, "appservice", "database", "max_open_conns")
-	helper.Copy(up.Int, "appservice", "database", "max_idle_conns")
-	helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time")
-	helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime")
-	if prefix, ok := helper.Get(up.Str, "appservice", "provisioning", "prefix"); ok && strings.HasSuffix(prefix, "/v1") {
-		helper.Set(up.Str, strings.TrimSuffix(prefix, "/v1"), "appservice", "provisioning", "prefix")
-	} else {
-		helper.Copy(up.Str, "appservice", "provisioning", "prefix")
-	}
-	if secret, ok := helper.Get(up.Str, "appservice", "provisioning", "shared_secret"); !ok || secret == "generate" {
-		sharedSecret := appservice.RandomString(64)
-		helper.Set(up.Str, sharedSecret, "appservice", "provisioning", "shared_secret")
-	} else {
-		helper.Copy(up.Str, "appservice", "provisioning", "shared_secret")
-	}
-	helper.Copy(up.Str, "appservice", "id")
-	helper.Copy(up.Str, "appservice", "bot", "username")
-	helper.Copy(up.Str, "appservice", "bot", "displayname")
-	helper.Copy(up.Str, "appservice", "bot", "avatar")
-	helper.Copy(up.Bool, "appservice", "ephemeral_events")
-	helper.Copy(up.Str, "appservice", "as_token")
-	helper.Copy(up.Str, "appservice", "hs_token")
+func DoUpgrade(helper *up.Helper) {
+	bridgeconfig.Upgrader.DoUpgrade(helper)
 
 
 	helper.Copy(up.Str|up.Null, "segment_key")
 	helper.Copy(up.Str|up.Null, "segment_key")
 
 
@@ -134,46 +96,41 @@ func (wau waUpgrader) DoUpgrade(helper *up.Helper) {
 	helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "allow")
 	helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "allow")
 	helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_cross_signing")
 	helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_cross_signing")
 	helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_verification")
 	helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_verification")
+	if prefix, ok := helper.Get(up.Str, "appservice", "provisioning", "prefix"); ok {
+		helper.Set(up.Str, strings.TrimSuffix(prefix, "/v1"), "bridge", "provisioning", "prefix")
+	} else {
+		helper.Copy(up.Str, "bridge", "provisioning", "prefix")
+	}
+	if secret, ok := helper.Get(up.Str, "appservice", "provisioning", "shared_secret"); ok && secret != "generate" {
+		helper.Set(up.Str, secret, "bridge", "provisioning", "shared_secret")
+	} else if secret, ok = helper.Get(up.Str, "bridge", "provisioning", "shared_secret"); !ok || secret == "generate" {
+		sharedSecret := appservice.RandomString(64)
+		helper.Set(up.Str, sharedSecret, "bridge", "provisioning", "shared_secret")
+	} else {
+		helper.Copy(up.Str, "bridge", "provisioning", "shared_secret")
+	}
 	helper.Copy(up.Map, "bridge", "permissions")
 	helper.Copy(up.Map, "bridge", "permissions")
 	helper.Copy(up.Bool, "bridge", "relay", "enabled")
 	helper.Copy(up.Bool, "bridge", "relay", "enabled")
 	helper.Copy(up.Bool, "bridge", "relay", "admin_only")
 	helper.Copy(up.Bool, "bridge", "relay", "admin_only")
 	helper.Copy(up.Map, "bridge", "relay", "message_formats")
 	helper.Copy(up.Map, "bridge", "relay", "message_formats")
-
-	helper.Copy(up.Str, "logging", "directory")
-	helper.Copy(up.Str|up.Null, "logging", "file_name_format")
-	helper.Copy(up.Str|up.Timestamp, "logging", "file_date_format")
-	helper.Copy(up.Int, "logging", "file_mode")
-	helper.Copy(up.Str|up.Timestamp, "logging", "timestamp_format")
-	helper.Copy(up.Str, "logging", "print_level")
-}
-
-func (wau waUpgrader) SpacedBlocks() [][]string {
-	return [][]string{
-		{"homeserver", "asmux"},
-		{"appservice"},
-		{"appservice", "hostname"},
-		{"appservice", "database"},
-		{"appservice", "provisioning"},
-		{"appservice", "id"},
-		{"appservice", "as_token"},
-		{"segment_key"},
-		{"metrics"},
-		{"whatsapp"},
-		{"bridge"},
-		{"bridge", "command_prefix"},
-		{"bridge", "management_room_text"},
-		{"bridge", "encryption"},
-		{"bridge", "permissions"},
-		{"bridge", "relay"},
-		{"logging"},
-	}
-}
-
-func Mutate(path string, mutate func(helper *up.Helper)) error {
-	_, _, err := up.Do(path, true, waUpgrader{}, up.SimpleUpgrader(mutate))
-	return err
 }
 }
 
 
-func Upgrade(path string, save bool) ([]byte, bool, error) {
-	return up.Do(path, save, waUpgrader{})
+var SpacedBlocks = [][]string{
+	{"homeserver", "asmux"},
+	{"appservice"},
+	{"appservice", "hostname"},
+	{"appservice", "database"},
+	{"appservice", "id"},
+	{"appservice", "as_token"},
+	{"segment_key"},
+	{"metrics"},
+	{"whatsapp"},
+	{"bridge"},
+	{"bridge", "command_prefix"},
+	{"bridge", "management_room_text"},
+	{"bridge", "encryption"},
+	{"bridge", "provisioning"},
+	{"bridge", "permissions"},
+	{"bridge", "relay"},
+	{"logging"},
 }
 }

+ 0 - 327
crypto.go

@@ -1,327 +0,0 @@
-// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2020 Tulir Asokan
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program.  If not, see <https://www.gnu.org/licenses/>.
-
-//go:build cgo && !nocrypto
-
-package main
-
-import (
-	"fmt"
-	"runtime/debug"
-	"time"
-
-	"github.com/lib/pq"
-
-	"maunium.net/go/maulogger/v2"
-
-	"maunium.net/go/mautrix"
-	"maunium.net/go/mautrix/crypto"
-	"maunium.net/go/mautrix/event"
-	"maunium.net/go/mautrix/id"
-
-	"maunium.net/go/mautrix-whatsapp/database"
-)
-
-var NoSessionFound = crypto.NoSessionFound
-
-var levelTrace = maulogger.Level{
-	Name:     "TRACE",
-	Severity: -10,
-	Color:    -1,
-}
-
-type CryptoHelper struct {
-	bridge  *Bridge
-	client  *mautrix.Client
-	mach    *crypto.OlmMachine
-	store   *database.SQLCryptoStore
-	log     maulogger.Logger
-	baseLog maulogger.Logger
-}
-
-func init() {
-	crypto.PostgresArrayWrapper = pq.Array
-}
-
-func NewCryptoHelper(bridge *Bridge) Crypto {
-	if !bridge.Config.Bridge.Encryption.Allow {
-		bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config")
-		return nil
-	}
-	baseLog := bridge.Log.Sub("Crypto")
-	return &CryptoHelper{
-		bridge:  bridge,
-		log:     baseLog.Sub("Helper"),
-		baseLog: baseLog,
-	}
-}
-
-func (helper *CryptoHelper) Init() error {
-	helper.log.Debugln("Initializing end-to-bridge encryption...")
-
-	helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.bridge.AS.BotMXID(),
-		fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain))
-
-	var err error
-	helper.client, err = helper.loginBot()
-	if err != nil {
-		return err
-	}
-
-	helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID)
-	logger := &cryptoLogger{helper.baseLog}
-	stateStore := &cryptoStateStore{helper.bridge}
-	helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore)
-	helper.mach.AllowKeyShare = helper.allowKeyShare
-
-	helper.client.Syncer = &cryptoSyncer{helper.mach}
-	helper.client.Store = &cryptoClientStore{helper.store}
-
-	return helper.mach.Load()
-}
-
-func (helper *CryptoHelper) allowKeyShare(device *crypto.DeviceIdentity, info event.RequestedKeyInfo) *crypto.KeyShareRejection {
-	cfg := helper.bridge.Config.Bridge.Encryption.KeySharing
-	if !cfg.Allow {
-		return &crypto.KeyShareRejectNoResponse
-	} else if device.Trust == crypto.TrustStateBlacklisted {
-		return &crypto.KeyShareRejectBlacklisted
-	} else if device.Trust == crypto.TrustStateVerified || !cfg.RequireVerification {
-		portal := helper.bridge.GetPortalByMXID(info.RoomID)
-		if portal == nil {
-			helper.log.Debugfln("Rejecting key request for %s from %s/%s: room is not a portal", info.SessionID, device.UserID, device.DeviceID)
-			return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"}
-		}
-		user := helper.bridge.GetUserByMXID(device.UserID)
-		// FIXME reimplement IsInPortal
-		if !user.Admin /*&& !user.IsInPortal(portal.Key)*/ {
-			helper.log.Debugfln("Rejecting key request for %s from %s/%s: user is not in portal", info.SessionID, device.UserID, device.DeviceID)
-			return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"}
-		}
-		helper.log.Debugfln("Accepting key request for %s from %s/%s", info.SessionID, device.UserID, device.DeviceID)
-		return nil
-	} else {
-		return &crypto.KeyShareRejectUnverified
-	}
-}
-
-func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) {
-	deviceID := helper.store.FindDeviceID()
-	if len(deviceID) > 0 {
-		helper.log.Debugln("Found existing device ID for bot in database:", deviceID)
-	}
-	client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, "", "")
-	if err != nil {
-		return nil, fmt.Errorf("failed to initialize client: %w", err)
-	}
-	client.Logger = helper.baseLog.Sub("Bot")
-	client.Client = helper.bridge.AS.HTTPClient
-	client.DefaultHTTPRetries = helper.bridge.AS.DefaultHTTPRetries
-	flows, err := client.GetLoginFlows()
-	if err != nil {
-		return nil, fmt.Errorf("failed to get supported login flows: %w", err)
-	} else if !flows.HasFlow(mautrix.AuthTypeAppservice) {
-		return nil, fmt.Errorf("homeserver does not support appservice login")
-	}
-	// We set the API token to the AS token here to authenticate the appservice login
-	// It'll get overridden after the login
-	client.AccessToken = helper.bridge.AS.Registration.AppToken
-	resp, err := client.Login(&mautrix.ReqLogin{
-		Type:                     mautrix.AuthTypeAppservice,
-		Identifier:               mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(helper.bridge.AS.BotMXID())},
-		DeviceID:                 deviceID,
-		InitialDeviceDisplayName: "WhatsApp Bridge",
-		StoreCredentials:         true,
-	})
-	if err != nil {
-		return nil, fmt.Errorf("failed to log in as bridge bot: %w", err)
-	}
-	helper.store.DeviceID = resp.DeviceID
-	return client, nil
-}
-
-func (helper *CryptoHelper) Start() {
-	helper.log.Debugln("Starting syncer for receiving to-device messages")
-	err := helper.client.Sync()
-	if err != nil {
-		helper.log.Errorln("Fatal error syncing:", err)
-	} else {
-		helper.log.Infoln("Bridge bot to-device syncer stopped without error")
-	}
-}
-
-func (helper *CryptoHelper) Stop() {
-	helper.log.Debugln("CryptoHelper.Stop() called, stopping bridge bot sync")
-	helper.client.StopSync()
-}
-
-func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
-	return helper.mach.DecryptMegolmEvent(evt)
-}
-
-func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content event.Content) (*event.EncryptedEventContent, error) {
-	encrypted, err := helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
-	if err != nil {
-		if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
-			return nil, err
-		}
-		helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID)
-		users, err := helper.store.GetRoomMembers(roomID)
-		if err != nil {
-			return nil, fmt.Errorf("failed to get room member list: %w", err)
-		}
-		err = helper.mach.ShareGroupSession(roomID, users)
-		if err != nil {
-			return nil, fmt.Errorf("failed to share group session: %w", err)
-		}
-		encrypted, err = helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
-		if err != nil {
-			return nil, fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
-		}
-	}
-	return encrypted, nil
-}
-
-func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
-	return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
-}
-
-func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
-	err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}})
-	if err != nil {
-		helper.log.Warnfln("Failed to send key request to %s/%s for %s in %s: %v", userID, deviceID, sessionID, roomID, err)
-	} else {
-		helper.log.Debugfln("Sent key request to %s/%s for %s in %s", userID, deviceID, sessionID, roomID)
-	}
-}
-
-func (helper *CryptoHelper) ResetSession(roomID id.RoomID) {
-	err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID)
-	if err != nil {
-		helper.log.Debugfln("Error manually removing outbound group session in %s: %v", roomID, err)
-	}
-}
-
-func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) {
-	helper.mach.HandleMemberEvent(evt)
-}
-
-type cryptoSyncer struct {
-	*crypto.OlmMachine
-}
-
-func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error {
-	done := make(chan struct{})
-	go func() {
-		defer func() {
-			if err := recover(); err != nil {
-				syncer.Log.Error("Processing sync response (%s) panicked: %v\n%s", since, err, debug.Stack())
-			}
-			done <- struct{}{}
-		}()
-		syncer.Log.Trace("Starting sync response handling (%s)", since)
-		syncer.ProcessSyncResponse(resp, since)
-		syncer.Log.Trace("Successfully handled sync response (%s)", since)
-	}()
-	select {
-	case <-done:
-	case <-time.After(30 * time.Second):
-		syncer.Log.Warn("Handling sync response (%s) is taking unusually long", since)
-	}
-	return nil
-}
-
-func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
-	syncer.Log.Error("Error /syncing, waiting 10 seconds: %v", err)
-	return 10 * time.Second, nil
-}
-
-func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
-	everything := []event.Type{{Type: "*"}}
-	return &mautrix.Filter{
-		Presence:    mautrix.FilterPart{NotTypes: everything},
-		AccountData: mautrix.FilterPart{NotTypes: everything},
-		Room: mautrix.RoomFilter{
-			IncludeLeave: false,
-			Ephemeral:    mautrix.FilterPart{NotTypes: everything},
-			AccountData:  mautrix.FilterPart{NotTypes: everything},
-			State:        mautrix.FilterPart{NotTypes: everything},
-			Timeline:     mautrix.FilterPart{NotTypes: everything},
-		},
-	}
-}
-
-type cryptoLogger struct {
-	int maulogger.Logger
-}
-
-func (c *cryptoLogger) Error(message string, args ...interface{}) {
-	c.int.Errorfln(message, args...)
-}
-
-func (c *cryptoLogger) Warn(message string, args ...interface{}) {
-	c.int.Warnfln(message, args...)
-}
-
-func (c *cryptoLogger) Debug(message string, args ...interface{}) {
-	c.int.Debugfln(message, args...)
-}
-
-func (c *cryptoLogger) Trace(message string, args ...interface{}) {
-	c.int.Logfln(levelTrace, message, args...)
-}
-
-type cryptoClientStore struct {
-	int *database.SQLCryptoStore
-}
-
-func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {}
-func (c cryptoClientStore) LoadFilterID(_ id.UserID) string    { return "" }
-func (c cryptoClientStore) SaveRoom(_ *mautrix.Room)           {}
-func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil }
-
-func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
-	c.int.PutNextBatch(nextBatchToken)
-}
-
-func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string {
-	return c.int.GetNextBatch()
-}
-
-var _ mautrix.Storer = (*cryptoClientStore)(nil)
-
-type cryptoStateStore struct {
-	bridge *Bridge
-}
-
-var _ crypto.StateStore = (*cryptoStateStore)(nil)
-
-func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
-	portal := c.bridge.GetPortalByMXID(id)
-	if portal != nil {
-		return portal.Encrypted
-	}
-	return false
-}
-
-func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID {
-	return c.bridge.StateStore.FindSharedRooms(id)
-}
-
-func (c *cryptoStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent {
-	// TODO implement
-	return nil
-}

+ 11 - 11
custompuppet.go

@@ -75,8 +75,8 @@ func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) {
 		Type:                     mautrix.AuthTypePassword,
 		Type:                     mautrix.AuthTypePassword,
 		Identifier:               mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)},
 		Identifier:               mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)},
 		Password:                 hex.EncodeToString(mac.Sum(nil)),
 		Password:                 hex.EncodeToString(mac.Sum(nil)),
-		DeviceID:                 "WhatsApp Bridge",
-		InitialDeviceDisplayName: "WhatsApp Bridge",
+		DeviceID:                 "WhatsApp bridge",
+		InitialDeviceDisplayName: "WhatsApp bridge",
 	})
 	})
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
@@ -84,22 +84,22 @@ func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) {
 	return resp.AccessToken, nil
 	return resp.AccessToken, nil
 }
 }
 
 
-func (bridge *Bridge) newDoublePuppetClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) {
+func (br *WABridge) newDoublePuppetClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) {
 	_, homeserver, err := mxid.Parse()
 	_, homeserver, err := mxid.Parse()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	homeserverURL, found := bridge.Config.Bridge.DoublePuppetServerMap[homeserver]
+	homeserverURL, found := br.Config.Bridge.DoublePuppetServerMap[homeserver]
 	if !found {
 	if !found {
-		if homeserver == bridge.AS.HomeserverDomain {
-			homeserverURL = bridge.AS.HomeserverURL
-		} else if bridge.Config.Bridge.DoublePuppetAllowDiscovery {
+		if homeserver == br.AS.HomeserverDomain {
+			homeserverURL = br.AS.HomeserverURL
+		} else if br.Config.Bridge.DoublePuppetAllowDiscovery {
 			resp, err := mautrix.DiscoverClientAPI(homeserver)
 			resp, err := mautrix.DiscoverClientAPI(homeserver)
 			if err != nil {
 			if err != nil {
 				return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err)
 				return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err)
 			}
 			}
 			homeserverURL = resp.Homeserver.BaseURL
 			homeserverURL = resp.Homeserver.BaseURL
-			bridge.Log.Debugfln("Discovered URL %s for %s to enable double puppeting for %s", homeserverURL, homeserver, mxid)
+			br.Log.Debugfln("Discovered URL %s for %s to enable double puppeting for %s", homeserverURL, homeserver, mxid)
 		} else {
 		} else {
 			return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver)
 			return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver)
 		}
 		}
@@ -108,9 +108,9 @@ func (bridge *Bridge) newDoublePuppetClient(mxid id.UserID, accessToken string)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	client.Logger = bridge.AS.Log.Sub(mxid.String())
-	client.Client = bridge.AS.HTTPClient
-	client.DefaultHTTPRetries = bridge.AS.DefaultHTTPRetries
+	client.Logger = br.AS.Log.Sub(mxid.String())
+	client.Client = br.AS.HTTPClient
+	client.DefaultHTTPRetries = br.AS.DefaultHTTPRetries
 	return client, nil
 	return client, nil
 }
 }
 
 

+ 4 - 2
database/backfill.go

@@ -26,7 +26,9 @@ import (
 	"time"
 	"time"
 
 
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
+
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
+	"maunium.net/go/mautrix/util/dbutil"
 )
 )
 
 
 type BackfillType int
 type BackfillType int
@@ -165,7 +167,7 @@ func (b *Backfill) String() string {
 	)
 	)
 }
 }
 
 
-func (b *Backfill) Scan(row Scannable) *Backfill {
+func (b *Backfill) Scan(row dbutil.Scannable) *Backfill {
 	err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &b.MaxTotalEvents, &b.BatchDelay)
 	err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &b.MaxTotalEvents, &b.BatchDelay)
 	if err != nil {
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 		if !errors.Is(err, sql.ErrNoRows) {
@@ -256,7 +258,7 @@ type BackfillState struct {
 	FirstExpectedTimestamp uint64
 	FirstExpectedTimestamp uint64
 }
 }
 
 
-func (b *BackfillState) Scan(row Scannable) *BackfillState {
+func (b *BackfillState) Scan(row dbutil.Scannable) *BackfillState {
 	err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp)
 	err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp)
 	if err != nil {
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 		if !errors.Is(err, sql.ErrNoRows) {

+ 7 - 48
database/cryptostore.go

@@ -1,18 +1,8 @@
-// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2020 Tulir Asokan
+// Copyright (c) 2022 Tulir Asokan
 //
 //
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
 
 //go:build cgo && !nocrypto
 //go:build cgo && !nocrypto
 
 
@@ -21,8 +11,6 @@ package database
 import (
 import (
 	"database/sql"
 	"database/sql"
 
 
-	log "maunium.net/go/maulogger/v2"
-
 	"maunium.net/go/mautrix/crypto"
 	"maunium.net/go/mautrix/crypto"
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
 )
 )
@@ -37,11 +25,9 @@ var _ crypto.Store = (*SQLCryptoStore)(nil)
 
 
 func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
 func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
 	return &SQLCryptoStore{
 	return &SQLCryptoStore{
-		SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "",
-			[]byte("maunium.net/go/mautrix-whatsapp"),
-			&cryptoLogger{db.log.Sub("CryptoStore")}),
-		UserID:        userID,
-		GhostIDFormat: ghostIDFormat,
+		SQLCryptoStore: crypto.NewSQLCryptoStore(db.Database, "", "", []byte("maunium.net/go/mautrix-whatsapp")),
+		UserID:         userID,
+		GhostIDFormat:  ghostIDFormat,
 	}
 	}
 }
 }
 
 
@@ -76,30 +62,3 @@ func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.User
 	}
 	}
 	return
 	return
 }
 }
-
-// TODO merge this with the one in the parent package
-type cryptoLogger struct {
-	int log.Logger
-}
-
-var levelTrace = log.Level{
-	Name:     "TRACE",
-	Severity: -10,
-	Color:    -1,
-}
-
-func (c *cryptoLogger) Error(message string, args ...interface{}) {
-	c.int.Errorfln(message, args...)
-}
-
-func (c *cryptoLogger) Warn(message string, args ...interface{}) {
-	c.int.Warnfln(message, args...)
-}
-
-func (c *cryptoLogger) Debug(message string, args ...interface{}) {
-	c.int.Debugfln(message, args...)
-}
-
-func (c *cryptoLogger) Trace(message string, args ...interface{}) {
-	c.int.Logfln(levelTrace, message, args...)
-}

+ 16 - 55
database/database.go

@@ -17,21 +17,17 @@
 package database
 package database
 
 
 import (
 import (
-	"database/sql"
 	"errors"
 	"errors"
-	"fmt"
 	"net"
 	"net"
 	"time"
 	"time"
 
 
 	"github.com/lib/pq"
 	"github.com/lib/pq"
 	_ "github.com/mattn/go-sqlite3"
 	_ "github.com/mattn/go-sqlite3"
-	log "maunium.net/go/maulogger/v2"
-
 	"go.mau.fi/whatsmeow/store"
 	"go.mau.fi/whatsmeow/store"
 	"go.mau.fi/whatsmeow/store/sqlstore"
 	"go.mau.fi/whatsmeow/store/sqlstore"
 
 
-	"maunium.net/go/mautrix-whatsapp/config"
 	"maunium.net/go/mautrix-whatsapp/database/upgrades"
 	"maunium.net/go/mautrix-whatsapp/database/upgrades"
+	"maunium.net/go/mautrix/util/dbutil"
 )
 )
 
 
 func init() {
 func init() {
@@ -39,9 +35,7 @@ func init() {
 }
 }
 
 
 type Database struct {
 type Database struct {
-	*sql.DB
-	log     log.Logger
-	dialect string
+	*dbutil.Database
 
 
 	User     *UserQuery
 	User     *UserQuery
 	Portal   *PortalQuery
 	Portal   *PortalQuery
@@ -55,79 +49,46 @@ type Database struct {
 	MediaBackfillRequest *MediaBackfillRequestQuery
 	MediaBackfillRequest *MediaBackfillRequestQuery
 }
 }
 
 
-func New(cfg config.DatabaseConfig, baseLog log.Logger) (*Database, error) {
-	conn, err := sql.Open(cfg.Type, cfg.URI)
-	if err != nil {
-		return nil, err
-	}
-
-	db := &Database{
-		DB:      conn,
-		log:     baseLog.Sub("Database"),
-		dialect: cfg.Type,
-	}
+func New(baseDB *dbutil.Database) *Database {
+	db := &Database{Database: baseDB}
+	db.UpgradeTable = upgrades.Table
 	db.User = &UserQuery{
 	db.User = &UserQuery{
 		db:  db,
 		db:  db,
-		log: db.log.Sub("User"),
+		log: db.Log.Sub("User"),
 	}
 	}
 	db.Portal = &PortalQuery{
 	db.Portal = &PortalQuery{
 		db:  db,
 		db:  db,
-		log: db.log.Sub("Portal"),
+		log: db.Log.Sub("Portal"),
 	}
 	}
 	db.Puppet = &PuppetQuery{
 	db.Puppet = &PuppetQuery{
 		db:  db,
 		db:  db,
-		log: db.log.Sub("Puppet"),
+		log: db.Log.Sub("Puppet"),
 	}
 	}
 	db.Message = &MessageQuery{
 	db.Message = &MessageQuery{
 		db:  db,
 		db:  db,
-		log: db.log.Sub("Message"),
+		log: db.Log.Sub("Message"),
 	}
 	}
 	db.Reaction = &ReactionQuery{
 	db.Reaction = &ReactionQuery{
 		db:  db,
 		db:  db,
-		log: db.log.Sub("Reaction"),
+		log: db.Log.Sub("Reaction"),
 	}
 	}
 	db.DisappearingMessage = &DisappearingMessageQuery{
 	db.DisappearingMessage = &DisappearingMessageQuery{
 		db:  db,
 		db:  db,
-		log: db.log.Sub("DisappearingMessage"),
+		log: db.Log.Sub("DisappearingMessage"),
 	}
 	}
 	db.Backfill = &BackfillQuery{
 	db.Backfill = &BackfillQuery{
 		db:  db,
 		db:  db,
-		log: db.log.Sub("Backfill"),
+		log: db.Log.Sub("Backfill"),
 	}
 	}
 	db.HistorySync = &HistorySyncQuery{
 	db.HistorySync = &HistorySyncQuery{
 		db:  db,
 		db:  db,
-		log: db.log.Sub("HistorySync"),
+		log: db.Log.Sub("HistorySync"),
 	}
 	}
 	db.MediaBackfillRequest = &MediaBackfillRequestQuery{
 	db.MediaBackfillRequest = &MediaBackfillRequestQuery{
 		db:  db,
 		db:  db,
-		log: db.log.Sub("MediaBackfillRequest"),
-	}
-
-	db.SetMaxOpenConns(cfg.MaxOpenConns)
-	db.SetMaxIdleConns(cfg.MaxIdleConns)
-	if len(cfg.ConnMaxIdleTime) > 0 {
-		maxIdleTimeDuration, err := time.ParseDuration(cfg.ConnMaxIdleTime)
-		if err != nil {
-			return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
-		}
-		db.SetConnMaxIdleTime(maxIdleTimeDuration)
+		log: db.Log.Sub("MediaBackfillRequest"),
 	}
 	}
-	if len(cfg.ConnMaxLifetime) > 0 {
-		maxLifetimeDuration, err := time.ParseDuration(cfg.ConnMaxLifetime)
-		if err != nil {
-			return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
-		}
-		db.SetConnMaxLifetime(maxLifetimeDuration)
-	}
-	return db, nil
-}
-
-func (db *Database) Init() error {
-	return upgrades.Run(db.log.Sub("Upgrade"), db.dialect, db.DB)
-}
-
-type Scannable interface {
-	Scan(...interface{}) error
+	return db
 }
 }
 
 
 func isRetryableError(err error) bool {
 func isRetryableError(err error) bool {
@@ -145,7 +106,7 @@ func isRetryableError(err error) bool {
 }
 }
 
 
 func (db *Database) HandleSignalStoreError(device *store.Device, action string, attemptIndex int, err error) (retry bool) {
 func (db *Database) HandleSignalStoreError(device *store.Device, action string, attemptIndex int, err error) (retry bool) {
-	if db.dialect != "sqlite" && isRetryableError(err) {
+	if db.Dialect != dbutil.SQLite && isRetryableError(err) {
 		sleepTime := time.Duration(attemptIndex*2) * time.Second
 		sleepTime := time.Duration(attemptIndex*2) * time.Second
 		device.Log.Warnf("Failed to %s (attempt #%d): %v - retrying in %v", action, attemptIndex+1, err, sleepTime)
 		device.Log.Warnf("Failed to %s (attempt #%d): %v - retrying in %v", action, attemptIndex+1, err, sleepTime)
 		time.Sleep(sleepTime)
 		time.Sleep(sleepTime)

+ 2 - 1
database/disappearingmessage.go

@@ -24,6 +24,7 @@ import (
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
 
 
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
+	"maunium.net/go/mautrix/util/dbutil"
 )
 )
 
 
 type DisappearingMessageQuery struct {
 type DisappearingMessageQuery struct {
@@ -94,7 +95,7 @@ type DisappearingMessage struct {
 	ExpireAt time.Time
 	ExpireAt time.Time
 }
 }
 
 
-func (msg *DisappearingMessage) Scan(row Scannable) *DisappearingMessage {
+func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage {
 	var expireIn int64
 	var expireIn int64
 	var expireAt sql.NullInt64
 	var expireAt sql.NullInt64
 	err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt)
 	err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt)

+ 4 - 2
database/historysync.go

@@ -27,7 +27,9 @@ import (
 
 
 	_ "github.com/mattn/go-sqlite3"
 	_ "github.com/mattn/go-sqlite3"
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
+
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
+	"maunium.net/go/mautrix/util/dbutil"
 )
 )
 
 
 type HistorySyncQuery struct {
 type HistorySyncQuery struct {
@@ -139,7 +141,7 @@ func (hsc *HistorySyncConversation) Upsert() {
 	}
 	}
 }
 }
 
 
-func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation {
+func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConversation {
 	err := row.Scan(
 	err := row.Scan(
 		&hsc.UserID,
 		&hsc.UserID,
 		&hsc.ConversationID,
 		&hsc.ConversationID,
@@ -166,7 +168,7 @@ func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation
 func (hsq *HistorySyncQuery) GetNMostRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) {
 func (hsq *HistorySyncQuery) GetNMostRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) {
 	nPtr := &n
 	nPtr := &n
 	// Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit.
 	// Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit.
-	if n < 0 && hsq.db.dialect == "postgres" {
+	if n < 0 && hsq.db.Dialect == dbutil.Postgres {
 		nPtr = nil
 		nPtr = nil
 	}
 	}
 	rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr)
 	rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr)

+ 3 - 1
database/mediabackfillrequest.go

@@ -22,7 +22,9 @@ import (
 
 
 	_ "github.com/mattn/go-sqlite3"
 	_ "github.com/mattn/go-sqlite3"
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
+
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
+	"maunium.net/go/mautrix/util/dbutil"
 )
 )
 
 
 type MediaBackfillRequestStatus int
 type MediaBackfillRequestStatus int
@@ -100,7 +102,7 @@ func (mbr *MediaBackfillRequest) Upsert() {
 	}
 	}
 }
 }
 
 
-func (mbr *MediaBackfillRequest) Scan(row Scannable) *MediaBackfillRequest {
+func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) *MediaBackfillRequest {
 	err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error)
 	err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error)
 	if err != nil {
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 		if !errors.Is(err, sql.ErrNoRows) {

+ 2 - 1
database/message.go

@@ -25,6 +25,7 @@ import (
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
 
 
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
+	"maunium.net/go/mautrix/util/dbutil"
 
 
 	"go.mau.fi/whatsmeow/types"
 	"go.mau.fi/whatsmeow/types"
 )
 )
@@ -163,7 +164,7 @@ func (msg *Message) IsFakeJID() bool {
 	return strings.HasPrefix(msg.JID, "FAKE::") || msg.JID == string(msg.MXID)
 	return strings.HasPrefix(msg.JID, "FAKE::") || msg.JID == string(msg.MXID)
 }
 }
 
 
-func (msg *Message) Scan(row Scannable) *Message {
+func (msg *Message) Scan(row dbutil.Scannable) *Message {
 	var ts int64
 	var ts int64
 	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
 	err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
 	if err != nil {
 	if err != nil {

+ 2 - 1
database/portal.go

@@ -22,6 +22,7 @@ import (
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
 
 
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
+	"maunium.net/go/mautrix/util/dbutil"
 
 
 	"go.mau.fi/whatsmeow/types"
 	"go.mau.fi/whatsmeow/types"
 )
 )
@@ -152,7 +153,7 @@ type Portal struct {
 	ExpirationTime uint32
 	ExpirationTime uint32
 }
 }
 
 
-func (portal *Portal) Scan(row Scannable) *Portal {
+func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
 	var mxid, avatarURL, firstEventID, nextBatchID, relayUserID sql.NullString
 	var mxid, avatarURL, firstEventID, nextBatchID, relayUserID sql.NullString
 	err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL, &portal.Encrypted, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
 	err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL, &portal.Encrypted, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
 	if err != nil {
 	if err != nil {

+ 3 - 1
database/puppet.go

@@ -20,7 +20,9 @@ import (
 	"database/sql"
 	"database/sql"
 
 
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
+
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
+	"maunium.net/go/mautrix/util/dbutil"
 
 
 	"go.mau.fi/whatsmeow/types"
 	"go.mau.fi/whatsmeow/types"
 )
 )
@@ -97,7 +99,7 @@ type Puppet struct {
 	EnableReceipts bool
 	EnableReceipts bool
 }
 }
 
 
-func (puppet *Puppet) Scan(row Scannable) *Puppet {
+func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet {
 	var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
 	var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
 	var quality sql.NullInt64
 	var quality sql.NullInt64
 	var enablePresence, enableReceipts sql.NullBool
 	var enablePresence, enableReceipts sql.NullBool

+ 2 - 1
database/reaction.go

@@ -23,6 +23,7 @@ import (
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
 
 
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
+	"maunium.net/go/mautrix/util/dbutil"
 
 
 	"go.mau.fi/whatsmeow/types"
 	"go.mau.fi/whatsmeow/types"
 )
 )
@@ -85,7 +86,7 @@ type Reaction struct {
 	JID       types.MessageID
 	JID       types.MessageID
 }
 }
 
 
-func (reaction *Reaction) Scan(row Scannable) *Reaction {
+func (reaction *Reaction) Scan(row dbutil.Scannable) *Reaction {
 	err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID)
 	err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID)
 	if err != nil {
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 		if !errors.Is(err, sql.ErrNoRows) {

+ 0 - 282
database/statestore.go

@@ -1,282 +0,0 @@
-// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
-// Copyright (C) 2022 Tulir Asokan
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program.  If not, see <https://www.gnu.org/licenses/>.
-
-package database
-
-import (
-	"database/sql"
-	"encoding/json"
-	"errors"
-	"sync"
-
-	log "maunium.net/go/maulogger/v2"
-
-	"maunium.net/go/mautrix/appservice"
-	"maunium.net/go/mautrix/event"
-	"maunium.net/go/mautrix/id"
-)
-
-type SQLStateStore struct {
-	*appservice.TypingStateStore
-
-	db  *Database
-	log log.Logger
-
-	Typing     map[id.RoomID]map[id.UserID]int64
-	typingLock sync.RWMutex
-}
-
-var _ appservice.StateStore = (*SQLStateStore)(nil)
-
-func NewSQLStateStore(db *Database) *SQLStateStore {
-	return &SQLStateStore{
-		TypingStateStore: appservice.NewTypingStateStore(),
-		db:               db,
-		log:              db.log.Sub("StateStore"),
-	}
-}
-
-func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
-	var isRegistered bool
-	err := store.db.
-		QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
-		Scan(&isRegistered)
-	if err != nil {
-		store.log.Warnfln("Failed to scan registration existence for %s: %v", userID, err)
-	}
-	return isRegistered
-}
-
-func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
-	_, err := store.db.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
-	if err != nil {
-		store.log.Warnfln("Failed to mark %s as registered: %v", userID, err)
-	}
-}
-
-func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
-	members := make(map[id.UserID]*event.MemberEventContent)
-	rows, err := store.db.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID)
-	if err != nil {
-		return members
-	}
-	var userID id.UserID
-	var member event.MemberEventContent
-	for rows.Next() {
-		err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
-		if err != nil {
-			store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
-		} else {
-			members[userID] = &member
-		}
-	}
-	return members
-}
-
-func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
-	membership := event.MembershipLeave
-	err := store.db.
-		QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
-		Scan(&membership)
-	if err != nil && err != sql.ErrNoRows {
-		store.log.Warnfln("Failed to scan membership of %s in %s: %v", userID, roomID, err)
-	}
-	return membership
-}
-
-func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
-	member, ok := store.TryGetMember(roomID, userID)
-	if !ok {
-		member.Membership = event.MembershipLeave
-	}
-	return member
-}
-
-func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
-	var member event.MemberEventContent
-	err := store.db.
-		QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
-		Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
-	if err != nil && err != sql.ErrNoRows {
-		store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err)
-	}
-	return &member, err == nil
-}
-
-func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
-	rows, err := store.db.Query(`
-			SELECT room_id FROM mx_user_profile
-			LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
-			WHERE user_id=$1 AND portal.encrypted=true
-	`, userID)
-	if err != nil {
-		store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
-		return
-	}
-	for rows.Next() {
-		var roomID id.RoomID
-		err = rows.Scan(&roomID)
-		if err != nil {
-			store.log.Warnfln("Failed to scan room ID: %v", err)
-		} else {
-			rooms = append(rooms, roomID)
-		}
-	}
-	return
-}
-
-func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
-	return store.IsMembership(roomID, userID, "join")
-}
-
-func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
-	return store.IsMembership(roomID, userID, "join", "invite")
-}
-
-func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
-	membership := store.GetMembership(roomID, userID)
-	for _, allowedMembership := range allowedMemberships {
-		if allowedMembership == membership {
-			return true
-		}
-	}
-	return false
-}
-
-func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
-	_, err := store.db.Exec(`
-		INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3)
-		ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
-	`, roomID, userID, membership)
-	if err != nil {
-		store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
-	}
-}
-
-func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
-	_, err := store.db.Exec(`
-		INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
-		ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
-	`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
-	if err != nil {
-		store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
-	}
-}
-
-func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
-	levelsBytes, err := json.Marshal(levels)
-	if err != nil {
-		store.log.Errorfln("Failed to marshal power levels of %s: %v", roomID, err)
-		return
-	}
-	_, err = store.db.Exec(`
-		INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
-		ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
-	`, roomID, levelsBytes)
-	if err != nil {
-		store.log.Warnfln("Failed to store power levels of %s: %v", roomID, err)
-	}
-}
-
-func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
-	var data []byte
-	err := store.db.
-		QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
-		Scan(&data)
-	if err != nil {
-		if !errors.Is(err, sql.ErrNoRows) {
-			store.log.Errorfln("Failed to scan power levels of %s: %v", roomID, err)
-		}
-		return
-	}
-	levels = &event.PowerLevelsEventContent{}
-	err = json.Unmarshal(data, levels)
-	if err != nil {
-		store.log.Errorfln("Failed to parse power levels of %s: %v", roomID, err)
-		return nil
-	}
-	return
-}
-
-func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
-	if store.db.dialect == "postgres" {
-		var powerLevel int
-		err := store.db.
-			QueryRow(`
-				SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
-				FROM mx_room_state WHERE room_id=$1
-			`, roomID, userID).
-			Scan(&powerLevel)
-		if err != nil && !errors.Is(err, sql.ErrNoRows) {
-			store.log.Errorfln("Failed to scan power level of %s in %s: %v", userID, roomID, err)
-		}
-		return powerLevel
-	}
-	return store.GetPowerLevels(roomID).GetUserLevel(userID)
-}
-
-func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
-	if store.db.dialect == "postgres" {
-		defaultType := "events_default"
-		defaultValue := 0
-		if eventType.IsState() {
-			defaultType = "state_default"
-			defaultValue = 50
-		}
-		var powerLevel int
-		err := store.db.
-			QueryRow(`
-				SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
-				FROM mx_room_state WHERE room_id=$1
-			`, roomID, eventType.Type, defaultType, defaultValue).
-			Scan(&powerLevel)
-		if err != nil {
-			if !errors.Is(err, sql.ErrNoRows) {
-				store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
-			}
-			return defaultValue
-		}
-		return powerLevel
-	}
-	return store.GetPowerLevels(roomID).GetEventLevel(eventType)
-}
-
-func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
-	if store.db.dialect == "postgres" {
-		defaultType := "events_default"
-		defaultValue := 0
-		if eventType.IsState() {
-			defaultType = "state_default"
-			defaultValue = 50
-		}
-		var hasPower bool
-		err := store.db.
-			QueryRow(`SELECT
-				COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
-				>=
-				COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
-				FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
-			Scan(&hasPower)
-		if err != nil {
-			if !errors.Is(err, sql.ErrNoRows) {
-				store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
-			}
-			return defaultValue == 0
-		}
-		return hasPower
-	}
-	return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
-}

+ 181 - 0
database/upgrades/00-latest-revision.sql

@@ -0,0 +1,181 @@
+-- v0 -> v48: Latest revision
+
+CREATE TABLE "user" (
+    mxid     TEXT PRIMARY KEY,
+    username TEXT UNIQUE,
+	agent    SMALLINT,
+	device   SMALLINT,
+
+    management_room TEXT,
+    space_room      TEXT,
+
+    phone_last_seen   BIGINT,
+    phone_last_pinged BIGINT,
+
+    timezone TEXT
+);
+
+CREATE TABLE portal (
+    jid        TEXT,
+    receiver   TEXT,
+    mxid       TEXT UNIQUE,
+    name       TEXT NOT NULL,
+    topic      TEXT NOT NULL,
+    avatar     TEXT NOT NULL,
+    avatar_url TEXT,
+    encrypted  BOOLEAN NOT NULL DEFAULT false,
+
+    first_event_id  TEXT,
+    next_batch_id   TEXT,
+    relay_user_id   TEXT,
+	expiration_time BIGINT NOT NULL DEFAULT 0,
+
+	PRIMARY KEY (jid, receiver)
+);
+
+CREATE TABLE puppet (
+    username     TEXT PRIMARY KEY,
+    displayname  TEXT,
+	name_quality SMALLINT,
+	avatar       TEXT,
+	avatar_url   TEXT,
+
+    custom_mxid  TEXT,
+    access_token TEXT,
+    next_batch   TEXT,
+
+	enable_presence BOOLEAN NOT NULL DEFAULT true,
+	enable_receipts BOOLEAN NOT NULL DEFAULT true
+);
+
+-- only: postgres
+CREATE TYPE error_type AS ENUM ('', 'decryption_failed', 'media_not_found');
+
+CREATE TABLE message (
+    chat_jid      TEXT,
+    chat_receiver TEXT,
+    jid           TEXT,
+    mxid          TEXT UNIQUE,
+    sender        TEXT,
+    timestamp     BIGINT,
+    sent          BOOLEAN,
+	error         error_type,
+	type          TEXT,
+
+    broadcast_list_jid TEXT,
+
+	PRIMARY KEY (chat_jid, chat_receiver, jid),
+	FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
+);
+
+CREATE TABLE reaction (
+	chat_jid      TEXT,
+	chat_receiver TEXT,
+	target_jid    TEXT,
+	sender        TEXT,
+
+	mxid TEXT NOT NULL,
+	jid  TEXT NOT NULL,
+
+	PRIMARY KEY (chat_jid, chat_receiver, target_jid, sender),
+	FOREIGN KEY (chat_jid, chat_receiver, target_jid) REFERENCES message(chat_jid, chat_receiver, jid)
+	    ON DELETE CASCADE ON UPDATE CASCADE
+);
+
+CREATE TABLE disappearing_message (
+	room_id   TEXT,
+	event_id  TEXT,
+	expire_in BIGINT NOT NULL,
+	expire_at BIGINT,
+	PRIMARY KEY (room_id, event_id)
+);
+
+CREATE TABLE user_portal (
+    user_mxid       TEXT,
+    portal_jid      TEXT,
+    portal_receiver TEXT,
+    last_read_ts    BIGINT  NOT NULL DEFAULT 0,
+    in_space        BOOLEAN NOT NULL DEFAULT false,
+    PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
+    FOREIGN KEY (user_mxid)                   REFERENCES "user"(mxid)          ON UPDATE CASCADE ON DELETE CASCADE,
+    FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
+);
+
+CREATE TABLE backfill_queue (
+	queue_id INTEGER PRIMARY KEY
+		-- only: postgres
+		GENERATED ALWAYS AS IDENTITY
+		,
+	user_mxid        TEXT,
+	type             INTEGER NOT NULL,
+	priority         INTEGER NOT NULL,
+	portal_jid       TEXT,
+	portal_receiver  TEXT,
+	time_start       TIMESTAMP,
+	dispatch_time    TIMESTAMP,
+	completed_at     TIMESTAMP,
+	batch_delay      INTEGER,
+	max_batch_events INTEGER NOT NULL,
+	max_total_events INTEGER,
+
+	FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
+);
+
+CREATE TABLE backfill_state (
+	user_mxid         TEXT,
+	portal_jid        TEXT,
+	portal_receiver   TEXT,
+	processing_batch  BOOLEAN,
+	backfill_complete BOOLEAN,
+	first_expected_ts TIMESTAMP,
+	PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
+	FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ON UPDATE CASCADE,
+	FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal (jid, receiver) ON DELETE CASCADE
+);
+
+CREATE TABLE media_backfill_requests (
+	user_mxid       TEXT,
+	portal_jid      TEXT,
+	portal_receiver TEXT,
+	event_id        TEXT,
+	media_key       bytea,
+	status          INTEGER,
+	error           TEXT,
+	PRIMARY KEY (user_mxid, portal_jid, portal_receiver, event_id),
+	FOREIGN KEY (user_mxid)                   REFERENCES "user"(mxid)          ON UPDATE CASCADE ON DELETE CASCADE,
+	FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
+);
+
+CREATE TABLE history_sync_conversation (
+	user_mxid       TEXT,
+	conversation_id TEXT,
+	portal_jid      TEXT,
+	portal_receiver TEXT,
+
+	last_message_timestamp       TIMESTAMP,
+	archived                     BOOLEAN,
+	pinned                       INTEGER,
+	mute_end_time                TIMESTAMP,
+	disappearing_mode            INTEGER,
+	end_of_history_transfer_type INTEGER,
+	ephemeral_Expiration         INTEGER,
+	marked_as_unread             BOOLEAN,
+	unread_count                 INTEGER,
+
+	PRIMARY KEY (user_mxid, conversation_id),
+	FOREIGN KEY (user_mxid)                   REFERENCES "user"(mxid)          ON UPDATE CASCADE ON DELETE CASCADE,
+	FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
+);
+
+CREATE TABLE history_sync_message (
+	user_mxid       TEXT,
+	conversation_id TEXT,
+	message_id      TEXT,
+	timestamp       TIMESTAMP,
+	data            bytea,
+	inserted_time   TIMESTAMP,
+
+	PRIMARY KEY (user_mxid, conversation_id, message_id),
+	FOREIGN KEY (user_mxid)                  REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
+	FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
+);

+ 0 - 67
database/upgrades/2018-09-01-initial-schema.go

@@ -1,67 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[0] = upgrade{"Initial schema", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`CREATE TABLE IF NOT EXISTS portal (
-			jid      VARCHAR(255),
-			receiver VARCHAR(255),
-			mxid     VARCHAR(255) UNIQUE,
-
-			name   VARCHAR(255) NOT NULL,
-			topic  VARCHAR(255) NOT NULL,
-			avatar VARCHAR(255) NOT NULL,
-
-			PRIMARY KEY (jid, receiver)
-		)`)
-		if err != nil {
-			return err
-		}
-
-		_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS puppet (
-			jid          VARCHAR(255) PRIMARY KEY,
-			avatar       VARCHAR(255),
-			displayname  VARCHAR(255),
-			name_quality SMALLINT
-		)`)
-		if err != nil {
-			return err
-		}
-
-		_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS "user" (
-			mxid VARCHAR(255) PRIMARY KEY,
-			jid  VARCHAR(255) UNIQUE,
-
-			management_room VARCHAR(255),
-
-			client_id    VARCHAR(255),
-			client_token VARCHAR(255),
-			server_token VARCHAR(255),
-			enc_key      bytea,
-			mac_key      bytea
-		)`)
-		if err != nil {
-			return err
-		}
-
-		_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message (
-			chat_jid      VARCHAR(255),
-			chat_receiver VARCHAR(255),
-			jid           VARCHAR(255),
-			mxid          VARCHAR(255) NOT NULL UNIQUE,
-			sender        VARCHAR(255) NOT NULL,
-			content       bytea        NOT NULL,
-
-			PRIMARY KEY (chat_jid, chat_receiver, jid),
-			FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
-		)`)
-		if err != nil {
-			return err
-		}
-
-		return nil
-	}}
-}

+ 0 - 15
database/upgrades/2019-05-21-message-timestamp-column.go

@@ -1,15 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[2] = upgrade{"Add timestamp column to messages", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec("ALTER TABLE message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0")
-		if err != nil {
-			return err
-		}
-		return nil
-	}}
-}

+ 0 - 15
database/upgrades/2019-05-22-user-last-connection-column.go

@@ -1,15 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[3] = upgrade{"Add last_connection column to users", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN last_connection BIGINT NOT NULL DEFAULT 0`)
-		if err != nil {
-			return err
-		}
-		return nil
-	}}
-}

+ 0 - 23
database/upgrades/2019-05-23-puppet-custom-mxid-columns.go

@@ -1,23 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[5] = upgrade{"Add columns to store custom puppet info", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN custom_mxid VARCHAR(255)`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`ALTER TABLE puppet ADD COLUMN access_token VARCHAR(1023)`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`ALTER TABLE puppet ADD COLUMN next_batch VARCHAR(255)`)
-		if err != nil {
-			return err
-		}
-		return nil
-	}}
-}

+ 0 - 19
database/upgrades/2019-05-28-user-portal-table.go

@@ -1,19 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[6] = upgrade{"Add user-portal mapping table", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`CREATE TABLE user_portal (
-			user_jid        VARCHAR(255),
-			portal_jid      VARCHAR(255),
-			portal_receiver VARCHAR(255),
-			PRIMARY KEY (user_jid, portal_jid, portal_receiver),
-			FOREIGN KEY (user_jid)                    REFERENCES "user"(jid)           ON DELETE CASCADE,
-			FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
-		)`)
-		return err
-	}}
-}

+ 0 - 19
database/upgrades/2019-06-01-avatar-url-fields.go

@@ -1,19 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[7] = upgrade{"Add columns to store avatar MXC URIs", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN avatar_url VARCHAR(255)`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`ALTER TABLE portal ADD COLUMN avatar_url VARCHAR(255)`)
-		if err != nil {
-			return err
-		}
-		return nil
-	}}
-}

+ 0 - 12
database/upgrades/2019-08-10-portal-in-community-field.go

@@ -1,12 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[8] = upgrade{"Add columns to store portal in filtering community meta", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE user_portal ADD COLUMN in_community BOOLEAN NOT NULL DEFAULT FALSE`)
-		return err
-	}}
-}

+ 0 - 39
database/upgrades/2019-08-25-move-state-store-to-db.go

@@ -1,39 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-	"strings"
-)
-
-func init() {
-	userProfileTable := `CREATE TABLE mx_user_profile (
-		room_id     VARCHAR(255),
-		user_id     VARCHAR(255),
-		membership  VARCHAR(15) NOT NULL,
-		PRIMARY KEY (room_id, user_id)
-	)`
-
-	roomStateTable := `CREATE TABLE mx_room_state (
-		room_id      VARCHAR(255) PRIMARY KEY,
-		power_levels TEXT
-	)`
-
-	registrationsTable := `CREATE TABLE mx_registrations (
-		user_id VARCHAR(255) PRIMARY KEY
-	)`
-
-	upgrades[9] = upgrade{"Move state store to main DB", func(tx *sql.Tx, ctx context) error {
-		if ctx.dialect == Postgres {
-			roomStateTable = strings.Replace(roomStateTable, "TEXT", "JSONB", 1)
-		}
-
-		if _, err := tx.Exec(userProfileTable); err != nil {
-			return err
-		} else if _, err = tx.Exec(roomStateTable); err != nil {
-			return err
-		} else if _, err = tx.Exec(registrationsTable); err != nil {
-			return err
-		}
-		return nil
-	}}
-}

+ 0 - 16
database/upgrades/2019-11-10-full-member-state-store.go

@@ -1,16 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[10] = upgrade{"Add columns to store full member info in state store", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN displayname TEXT`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN avatar_url VARCHAR(255)`)
-		return err
-	}}
-}

+ 0 - 16
database/upgrades/2019-11-12-fix-room-topic-length.go

@@ -1,16 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[11] = upgrade{"Adjust the length of column topic in portal", func(tx *sql.Tx, ctx context) error {
-		if ctx.dialect == SQLite {
-			// SQLite doesn't support constraint updates, but it isn't that careful with constraints anyway.
-			return nil
-		}
-		_, err := tx.Exec(`ALTER TABLE portal ALTER COLUMN topic TYPE VARCHAR(512)`)
-		return err
-	}}
-}

+ 0 - 12
database/upgrades/2020-05-09-add-portal-encrypted-field.go

@@ -1,12 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[12] = upgrade{"Add encryption status to portal table", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN encrypted BOOLEAN NOT NULL DEFAULT false`)
-		return err
-	}}
-}

+ 0 - 73
database/upgrades/2020-05-09-crypto-store.go

@@ -1,73 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[13] = upgrade{"Add crypto store to database", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`CREATE TABLE crypto_account (
-			device_id  VARCHAR(255) PRIMARY KEY,
-			shared     BOOLEAN      NOT NULL,
-			sync_token TEXT         NOT NULL,
-			account    bytea        NOT NULL
-		)`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`CREATE TABLE crypto_message_index (
-			sender_key CHAR(43),
-			session_id CHAR(43),
-			"index"    INTEGER,
-			event_id   VARCHAR(255) NOT NULL,
-			timestamp  BIGINT       NOT NULL,
-
-			PRIMARY KEY (sender_key, session_id, "index")
-		)`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`CREATE TABLE crypto_tracked_user (
-			user_id VARCHAR(255) PRIMARY KEY
-		)`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`CREATE TABLE crypto_device (
-			user_id      VARCHAR(255),
-			device_id    VARCHAR(255),
-			identity_key CHAR(43)      NOT NULL,
-			signing_key  CHAR(43)      NOT NULL,
-			trust        SMALLINT      NOT NULL,
-			deleted      BOOLEAN       NOT NULL,
-			name         VARCHAR(255)  NOT NULL,
-
-			PRIMARY KEY (user_id, device_id)
-		)`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`CREATE TABLE crypto_olm_session (
-			session_id   CHAR(43)  PRIMARY KEY,
-			sender_key   CHAR(43)  NOT NULL,
-			session      bytea     NOT NULL,
-			created_at   timestamp NOT NULL,
-			last_used    timestamp NOT NULL
-		)`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session (
-			session_id   CHAR(43)     PRIMARY KEY,
-			sender_key   CHAR(43)     NOT NULL,
-			signing_key  CHAR(43)     NOT NULL,
-			room_id      VARCHAR(255) NOT NULL,
-			session      bytea        NOT NULL,
-			forwarding_chains bytea   NOT NULL
-		)`)
-		if err != nil {
-			return err
-		}
-		return nil
-	}}
-}

+ 0 - 25
database/upgrades/2020-05-12-outbound-group-session-store.go

@@ -1,25 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[14] = upgrade{"Add outbound group sessions to database", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`CREATE TABLE crypto_megolm_outbound_session (
-			room_id       VARCHAR(255) PRIMARY KEY,
-			session_id    CHAR(43)     NOT NULL UNIQUE,
-			session       bytea        NOT NULL,
-			shared        BOOLEAN      NOT NULL,
-			max_messages  INTEGER      NOT NULL,
-			message_count INTEGER      NOT NULL,
-			max_age       BIGINT       NOT NULL,
-			created_at    timestamp    NOT NULL,
-			last_used     timestamp    NOT NULL
-		)`)
-		if err != nil {
-			return err
-		}
-		return nil
-	}}
-}

+ 0 - 12
database/upgrades/2020-07-10-custom-puppet-presence-toggle.go

@@ -1,12 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[15] = upgrade{"Add enable_presence column for puppets", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN enable_presence BOOLEAN NOT NULL DEFAULT true`)
-		return err
-	}}
-}

+ 0 - 13
database/upgrades/2020-07-10-update-crypto-store.go

@@ -1,13 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-
-	"maunium.net/go/mautrix/crypto/sql_store_upgrade"
-)
-
-func init() {
-	upgrades[16] = upgrade{"Add account_id to crypto store", func(tx *sql.Tx, c context) error {
-		return sql_store_upgrade.Upgrades[1](tx, c.dialect.String())
-	}}
-}

+ 0 - 12
database/upgrades/2020-07-10-x-custom-puppet-receipts-toggle.go

@@ -1,12 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[17] = upgrade{"Add enable_receipts column for puppets", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN enable_receipts BOOLEAN NOT NULL DEFAULT true`)
-		return err
-	}}
-}

+ 0 - 13
database/upgrades/2020-08-03-update-crypto-store.go

@@ -1,13 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-
-	"maunium.net/go/mautrix/crypto/sql_store_upgrade"
-)
-
-func init() {
-	upgrades[18] = upgrade{"Add megolm withheld data to crypto store", func(tx *sql.Tx, c context) error {
-		return sql_store_upgrade.Upgrades[2](tx, c.dialect.String())
-	}}
-}

+ 0 - 13
database/upgrades/2020-10-28-crypto-store-cross-signing.go

@@ -1,13 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-
-	"maunium.net/go/mautrix/crypto/sql_store_upgrade"
-)
-
-func init() {
-	upgrades[19] = upgrade{"Add cross-signing keys to crypto store", func(tx *sql.Tx, c context) error {
-		return sql_store_upgrade.Upgrades[3](tx, c.dialect.String())
-	}}
-}

+ 0 - 12
database/upgrades/2021-02-17-message-sent-status.go

@@ -1,12 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[20] = upgrade{"Add sent column for messages", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE message ADD COLUMN sent BOOLEAN NOT NULL DEFAULT true`)
-		return err
-	}}
-}

+ 0 - 44
database/upgrades/2021-08-19-remove-message-content.go

@@ -1,44 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[21] = upgrade{"Remove message content from local database", func(tx *sql.Tx, ctx context) error {
-		if ctx.dialect == SQLite {
-			_, err := tx.Exec("ALTER TABLE message RENAME TO old_message")
-			if err != nil {
-				return err
-			}
-			_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message (
-				chat_jid      TEXT,
-				chat_receiver TEXT,
-				jid           TEXT,
-				mxid          TEXT    NOT NULL UNIQUE,
-				sender        TEXT    NOT NULL,
-				timestamp     BIGINT  NOT NULL,
-				sent          BOOLEAN NOT NULL,
-
-				PRIMARY KEY (chat_jid, chat_receiver, jid),
-				FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
-			)`)
-			if err != nil {
-				return err
-			}
-			_, err = tx.Exec("INSERT INTO message SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent FROM old_message")
-			return err
-		} else {
-			_, err := tx.Exec(`ALTER TABLE message DROP COLUMN content`)
-			if err != nil {
-				return err
-			}
-			_, err = tx.Exec(`ALTER TABLE message ALTER COLUMN timestamp DROP DEFAULT`)
-			if err != nil {
-				return err
-			}
-			_, err = tx.Exec(`ALTER TABLE message ALTER COLUMN sent DROP DEFAULT`)
-			return err
-		}
-	}}
-}

+ 0 - 13
database/upgrades/2021-08-19-varchar-to-text-crypto.go

@@ -1,13 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-
-	"maunium.net/go/mautrix/crypto/sql_store_upgrade"
-)
-
-func init() {
-	upgrades[23] = upgrade{"Replace VARCHAR(255) with TEXT in the crypto database", func(tx *sql.Tx, ctx context) error {
-		return sql_store_upgrade.Upgrades[4](tx, ctx.dialect.String())
-	}}
-}

+ 0 - 48
database/upgrades/2021-08-19-varchar-to-text.go

@@ -1,48 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[22] = upgrade{"Replace VARCHAR(255) with TEXT in the database", func(tx *sql.Tx, ctx context) error {
-		if ctx.dialect == SQLite {
-			// SQLite doesn't enforce varchar sizes anyway
-			return nil
-		}
-		return execMany(tx,
-			`ALTER TABLE message ALTER COLUMN chat_jid TYPE TEXT`,
-			`ALTER TABLE message ALTER COLUMN chat_receiver TYPE TEXT`,
-			`ALTER TABLE message ALTER COLUMN jid TYPE TEXT`,
-			`ALTER TABLE message ALTER COLUMN mxid TYPE TEXT`,
-			`ALTER TABLE message ALTER COLUMN sender TYPE TEXT`,
-
-			`ALTER TABLE portal ALTER COLUMN jid TYPE TEXT`,
-			`ALTER TABLE portal ALTER COLUMN receiver TYPE TEXT`,
-			`ALTER TABLE portal ALTER COLUMN mxid TYPE TEXT`,
-			`ALTER TABLE portal ALTER COLUMN name TYPE TEXT`,
-			`ALTER TABLE portal ALTER COLUMN topic TYPE TEXT`,
-			`ALTER TABLE portal ALTER COLUMN avatar TYPE TEXT`,
-			`ALTER TABLE portal ALTER COLUMN avatar_url TYPE TEXT`,
-
-			`ALTER TABLE puppet ALTER COLUMN jid TYPE TEXT`,
-			`ALTER TABLE puppet ALTER COLUMN avatar TYPE TEXT`,
-			`ALTER TABLE puppet ALTER COLUMN displayname TYPE TEXT`,
-			`ALTER TABLE puppet ALTER COLUMN custom_mxid TYPE TEXT`,
-			`ALTER TABLE puppet ALTER COLUMN access_token TYPE TEXT`,
-			`ALTER TABLE puppet ALTER COLUMN next_batch TYPE TEXT`,
-			`ALTER TABLE puppet ALTER COLUMN avatar_url TYPE TEXT`,
-
-			`ALTER TABLE "user" ALTER COLUMN mxid TYPE TEXT`,
-			`ALTER TABLE "user" ALTER COLUMN jid TYPE TEXT`,
-			`ALTER TABLE "user" ALTER COLUMN management_room TYPE TEXT`,
-			`ALTER TABLE "user" ALTER COLUMN client_id TYPE TEXT`,
-			`ALTER TABLE "user" ALTER COLUMN client_token TYPE TEXT`,
-			`ALTER TABLE "user" ALTER COLUMN server_token TYPE TEXT`,
-
-			`ALTER TABLE user_portal ALTER COLUMN user_jid TYPE TEXT`,
-			`ALTER TABLE user_portal ALTER COLUMN portal_jid TYPE TEXT`,
-			`ALTER TABLE user_portal ALTER COLUMN portal_receiver TYPE TEXT`,
-		)
-	}}
-}

+ 0 - 13
database/upgrades/2021-10-21-add-whatsmeow-store.go

@@ -1,13 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-
-	"go.mau.fi/whatsmeow/store/sqlstore"
-)
-
-func init() {
-	upgrades[24] = upgrade{"Add whatsmeow state store", func(tx *sql.Tx, ctx context) error {
-		return sqlstore.Upgrades[0](tx, sqlstore.NewWithDB(ctx.db, ctx.dialect.String(), nil))
-	}}
-}

+ 0 - 93
database/upgrades/2021-10-21-multidevice-updates.go

@@ -1,93 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[25] = upgrade{"Update things for multidevice", func(tx *sql.Tx, ctx context) error {
-		// This is probably not necessary
-		_, err := tx.Exec("DROP TABLE user_portal")
-		if err != nil {
-			return err
-		}
-
-		// Remove invalid puppet rows
-		_, err = tx.Exec("DELETE FROM puppet WHERE jid LIKE '%@g.us' OR jid LIKE '%@broadcast'")
-		if err != nil {
-			return err
-		}
-		// Remove the suffix from puppets since they'll all have the same suffix
-		_, err = tx.Exec("UPDATE puppet SET jid=REPLACE(jid, '@s.whatsapp.net', '')")
-		if err != nil {
-			return err
-		}
-		// Rename column to correctly represent the new content
-		_, err = tx.Exec("ALTER TABLE puppet RENAME COLUMN jid TO username")
-		if err != nil {
-			return err
-		}
-
-		if ctx.dialect == SQLite {
-			// Message content was removed from the main message table earlier, but the backup table still exists for SQLite
-			_, err = tx.Exec("DROP TABLE IF EXISTS old_message")
-
-			_, err = tx.Exec(`ALTER TABLE "user" RENAME TO old_user`)
-			if err != nil {
-				return err
-			}
-			_, err = tx.Exec(`CREATE TABLE "user" (
-				mxid     TEXT PRIMARY KEY,
-				username TEXT UNIQUE,
-				agent    SMALLINT,
-				device   SMALLINT,
-				management_room TEXT
-			)`)
-			if err != nil {
-				return err
-			}
-
-			// No need to copy auth data, users need to relogin anyway
-			_, err = tx.Exec(`INSERT INTO "user" (mxid, management_room) SELECT mxid, management_room FROM old_user`)
-			if err != nil {
-				return err
-			}
-
-			_, err = tx.Exec("DROP TABLE old_user")
-			if err != nil {
-				return err
-			}
-		} else {
-			// The jid column never actually contained the full JID, so let's rename it.
-			_, err = tx.Exec(`ALTER TABLE "user" RENAME COLUMN jid TO username`)
-			if err != nil {
-				return err
-			}
-
-			// The auth data is now in the whatsmeow_device table.
-			for _, column := range []string{"last_connection", "client_id", "client_token", "server_token", "enc_key", "mac_key"} {
-				_, err = tx.Exec(`ALTER TABLE "user" DROP COLUMN ` + column)
-				if err != nil {
-					return err
-				}
-			}
-
-			// The whatsmeow_device table is keyed by the full JID, so we need to store the other parts of the JID here too.
-			_, err = tx.Exec(`ALTER TABLE "user" ADD COLUMN agent SMALLINT`)
-			if err != nil {
-				return err
-			}
-			_, err = tx.Exec(`ALTER TABLE "user" ADD COLUMN device SMALLINT`)
-			if err != nil {
-				return err
-			}
-
-			// Clear all usernames, the users need to relogin anyway.
-			_, err = tx.Exec(`UPDATE "user" SET username=null`)
-			if err != nil {
-				return err
-			}
-		}
-		return nil
-	}}
-}

+ 0 - 19
database/upgrades/2021-10-26-portal-origin-event-id.go

@@ -1,19 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[26] = upgrade{"Add columns to store infinite backfill pointers for portals", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN first_event_id TEXT NOT NULL DEFAULT ''`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`ALTER TABLE portal ADD COLUMN next_batch_id TEXT NOT NULL DEFAULT ''`)
-		if err != nil {
-			return err
-		}
-		return nil
-	}}
-}

+ 0 - 12
database/upgrades/2021-10-27-message-decryption-errors.go

@@ -1,12 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[27] = upgrade{"Add marker for WhatsApp decryption errors in message table", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE message ADD COLUMN decryption_error BOOLEAN NOT NULL DEFAULT false`)
-		return err
-	}}
-}

+ 0 - 12
database/upgrades/2021-10-28-portal-relay-user.go

@@ -1,12 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[28] = upgrade{"Add relay user field to portal table", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN relay_user_id TEXT`)
-		return err
-	}}
-}

+ 0 - 22
database/upgrades/2021-10-30-varchar-to-text-state-store.go

@@ -1,22 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[29] = upgrade{"Replace VARCHAR(255) with TEXT in the Matrix state store", func(tx *sql.Tx, ctx context) error {
-		if ctx.dialect == SQLite {
-			// SQLite doesn't enforce varchar sizes anyway
-			return nil
-		}
-		return execMany(tx,
-			`ALTER TABLE mx_registrations ALTER COLUMN user_id TYPE TEXT`,
-			`ALTER TABLE mx_room_state ALTER COLUMN room_id TYPE TEXT`,
-			`ALTER TABLE mx_user_profile ALTER COLUMN room_id TYPE TEXT`,
-			`ALTER TABLE mx_user_profile ALTER COLUMN user_id TYPE TEXT`,
-			`ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT`,
-			`ALTER TABLE mx_user_profile ALTER COLUMN avatar_url TYPE TEXT`,
-		)
-	}}
-}

+ 0 - 22
database/upgrades/2021-11-30-store-last-read-state.go

@@ -1,22 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[30] = upgrade{"Store last read message timestamp in database", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`CREATE TABLE user_portal (
-			user_mxid       TEXT,
-			portal_jid      TEXT,
-			portal_receiver TEXT,
-
-			last_read_ts    BIGINT NOT NULL DEFAULT 0,
-
-			PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
-			FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
-			FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
-		)`)
-		return err
-	}}
-}

+ 0 - 13
database/upgrades/2021-12-22-crypto-store-last-used.go

@@ -1,13 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-
-	"maunium.net/go/mautrix/crypto/sql_store_upgrade"
-)
-
-func init() {
-	upgrades[31] = upgrade{"Split last_used into last_encrypted and last_decrypted in crypto store", func(tx *sql.Tx, c context) error {
-		return sql_store_upgrade.Upgrades[5](tx, c.dialect.String())
-	}}
-}

+ 0 - 12
database/upgrades/2021-12-25-broadcast-list-message-source.go

@@ -1,12 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[32] = upgrade{"Store source broadcast list in message table", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE message ADD COLUMN broadcast_list_jid TEXT`)
-		return err
-	}}
-}

+ 0 - 16
database/upgrades/2021-12-29-personal-filtering-spaces.go

@@ -1,16 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[33] = upgrade{"Add personal filtering space info to user tables", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN space_room TEXT NOT NULL DEFAULT ''`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`ALTER TABLE user_portal ADD COLUMN in_space BOOLEAN NOT NULL DEFAULT false`)
-		return err
-	}}
-}

+ 0 - 20
database/upgrades/2022-01-07-disappearing-messages.go

@@ -1,20 +0,0 @@
-package upgrades
-
-import "database/sql"
-
-func init() {
-	upgrades[34] = upgrade{"Add support for disappearing messages", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN expiration_time BIGINT NOT NULL DEFAULT 0 CHECK (expiration_time >= 0 AND expiration_time < 4294967296)`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`CREATE TABLE disappearing_message (
-			room_id   TEXT,
-			event_id  TEXT,
-			expire_in BIGINT NOT NULL,
-			expire_at BIGINT,
-			PRIMARY KEY (room_id, event_id)
-		)`)
-		return err
-	}}
-}

+ 0 - 10
database/upgrades/2022-01-24-phone-last-seen-ts.go

@@ -1,10 +0,0 @@
-package upgrades
-
-import "database/sql"
-
-func init() {
-	upgrades[35] = upgrade{"Store approximate last seen timestamp of the main device", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN phone_last_seen BIGINT`)
-		return err
-	}}
-}

+ 0 - 30
database/upgrades/2022-02-10-message-error-string.go

@@ -1,30 +0,0 @@
-package upgrades
-
-import "database/sql"
-
-func init() {
-	upgrades[36] = upgrade{"Store message error type as string", func(tx *sql.Tx, ctx context) error {
-		if ctx.dialect == Postgres {
-			_, err := tx.Exec("CREATE TYPE error_type AS ENUM ('', 'decryption_failed', 'media_not_found')")
-			if err != nil {
-				return err
-			}
-		}
-		_, err := tx.Exec("ALTER TABLE message ADD COLUMN error error_type NOT NULL DEFAULT ''")
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec("UPDATE message SET error='decryption_failed' WHERE decryption_error=true")
-		if err != nil {
-			return err
-		}
-		if ctx.dialect == Postgres {
-			// TODO do this on sqlite at some point
-			_, err = tx.Exec("ALTER TABLE message DROP COLUMN decryption_error")
-			if err != nil {
-				return err
-			}
-		}
-		return nil
-	}}
-}

+ 0 - 10
database/upgrades/2022-02-18-phone-ping-ts.go

@@ -1,10 +0,0 @@
-package upgrades
-
-import "database/sql"
-
-func init() {
-	upgrades[37] = upgrade{"Store timestamp for previous phone ping", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN phone_last_pinged BIGINT`)
-		return err
-	}}
-}

+ 0 - 39
database/upgrades/2022-03-05-reactions.go

@@ -1,39 +0,0 @@
-package upgrades
-
-import "database/sql"
-
-func init() {
-	upgrades[38] = upgrade{"Add support for reactions", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE message ADD COLUMN type TEXT NOT NULL DEFAULT 'message'`)
-		if err != nil {
-			return err
-		}
-		if ctx.dialect == Postgres {
-			_, err = tx.Exec("ALTER TABLE message ALTER COLUMN type DROP DEFAULT")
-			if err != nil {
-				return err
-			}
-		}
-		_, err = tx.Exec("UPDATE message SET type='' WHERE error='decryption_failed'")
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec("UPDATE message SET type='fake' WHERE jid LIKE 'FAKE::%' OR mxid LIKE 'net.maunium.whatsapp.fake::%' OR jid=mxid")
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`CREATE TABLE reaction (
-			chat_jid      TEXT,
-			chat_receiver TEXT,
-			target_jid    TEXT,
-			sender        TEXT,
-			mxid          TEXT NOT NULL,
-			jid           TEXT NOT NULL,
-			PRIMARY KEY (chat_jid, chat_receiver, target_jid, sender),
-			CONSTRAINT target_message_fkey FOREIGN KEY (chat_jid, chat_receiver, target_jid)
-				REFERENCES message(chat_jid, chat_receiver, jid)
-				ON DELETE CASCADE ON UPDATE CASCADE
-		)`)
-		return err
-	}}
-}

+ 0 - 45
database/upgrades/2022-03-15-prioritized-backfill.go

@@ -1,45 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-	"fmt"
-)
-
-func init() {
-	upgrades[39] = upgrade{"Add backfill queue", func(tx *sql.Tx, ctx context) error {
-		// The queue_id needs to auto-increment every insertion. For SQLite,
-		// INTEGER PRIMARY KEY is an alias for the ROWID, so it will
-		// auto-increment. See https://sqlite.org/lang_createtable.html#rowid
-		// For Postgres, we need to add GENERATED ALWAYS AS IDENTITY for the
-		// same functionality.
-		queueIDColumnTypeModifier := ""
-		if ctx.dialect == Postgres {
-			queueIDColumnTypeModifier = "GENERATED ALWAYS AS IDENTITY"
-		}
-
-		_, err := tx.Exec(fmt.Sprintf(`
-			CREATE TABLE backfill_queue (
-				queue_id            INTEGER PRIMARY KEY %s,
-				user_mxid           TEXT,
-				type                INTEGER NOT NULL,
-				priority            INTEGER NOT NULL,
-				portal_jid          TEXT,
-				portal_receiver     TEXT,
-				time_start          TIMESTAMP,
-				time_end            TIMESTAMP,
-				max_batch_events    INTEGER NOT NULL,
-				max_total_events    INTEGER,
-				batch_delay         INTEGER,
-				completed_at        TIMESTAMP,
-
-				FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
-				FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
-			)
-		`, queueIDColumnTypeModifier))
-		if err != nil {
-			return err
-		}
-
-		return err
-	}}
-}

+ 0 - 52
database/upgrades/2022-03-18-historysync-store.go

@@ -1,52 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[40] = upgrade{"Store history syncs for later backfills", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`
-			CREATE TABLE history_sync_conversation (
-				user_mxid                       TEXT,
-				conversation_id                 TEXT,
-				portal_jid                      TEXT,
-				portal_receiver                 TEXT,
-				last_message_timestamp          TIMESTAMP,
-				archived                        BOOLEAN,
-				pinned                          INTEGER,
-				mute_end_time                   TIMESTAMP,
-				disappearing_mode               INTEGER,
-				end_of_history_transfer_type    INTEGER,
-				ephemeral_expiration            INTEGER,
-				marked_as_unread                BOOLEAN,
-				unread_count                    INTEGER,
-
-				PRIMARY KEY (user_mxid, conversation_id),
-				FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
-				FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
-			)
-		`)
-		if err != nil {
-			return err
-		}
-		_, err = tx.Exec(`
-			CREATE TABLE history_sync_message (
-				user_mxid                TEXT,
-				conversation_id          TEXT,
-				message_id               TEXT,
-				timestamp                TIMESTAMP,
-				data                     BYTEA,
-
-				PRIMARY KEY (user_mxid, conversation_id, message_id),
-				FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
-				FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
-			)
-		`)
-		if err != nil {
-			return err
-		}
-
-		return nil
-	}}
-}

+ 0 - 20
database/upgrades/2022-04-29-backfillqueue-type-order.go

@@ -1,20 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[41] = upgrade{"Update backfill queue tables to be sortable by priority", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`
-			UPDATE backfill_queue
-				SET type=CASE
-					WHEN type=1 THEN 200
-					WHEN type=2 THEN 300
-					ELSE type
-				END
-				WHERE type=1 OR type=2
-		`)
-		return err
-	}}
-}

+ 0 - 26
database/upgrades/2022-05-09-media-backfill-requests-queue-table.go

@@ -1,26 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[42] = upgrade{"Add table of media to request from the user's phone", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`
-			CREATE TABLE media_backfill_requests (
-				user_mxid           TEXT,
-				portal_jid          TEXT,
-				portal_receiver     TEXT,
-				event_id            TEXT,
-				media_key           BYTEA,
-				status              INTEGER,
-				error               TEXT,
-
-				PRIMARY KEY (user_mxid, portal_jid, portal_receiver, event_id),
-				FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
-				FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
-			)
-		`)
-		return err
-	}}
-}

+ 0 - 12
database/upgrades/2022-05-11-add-user-timezone.go

@@ -1,12 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[43] = upgrade{"Add timezone column to user table", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN timezone TEXT`)
-		return err
-	}}
-}

+ 0 - 34
database/upgrades/2022-05-12-backfillqueue-dispatch-time.go

@@ -1,34 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[44] = upgrade{"Add dispatch time to backfill queue", func(tx *sql.Tx, ctx context) error {
-		// First, add dispatch_time TIMESTAMP column
-		_, err := tx.Exec(`
-			ALTER TABLE backfill_queue
-			ADD COLUMN dispatch_time TIMESTAMP
-		`)
-		if err != nil {
-			return err
-		}
-
-		// For all previous jobs, set dispatch time to the completed time.
-		_, err = tx.Exec(`
-			UPDATE backfill_queue
-				SET dispatch_time=completed_at
-		`)
-		if err != nil {
-			return err
-		}
-
-		// Remove time_end from the backfill queue
-		_, err = tx.Exec(`
-			ALTER TABLE backfill_queue
-			DROP COLUMN time_end
-		`)
-		return err
-	}}
-}

+ 0 - 16
database/upgrades/2022-05-12-history-sync-message-add-added-timestamp.go

@@ -1,16 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[45] = upgrade{"Add inserted time to history sync message", func(tx *sql.Tx, ctx context) error {
-		// Add the inserted time TIMESTAMP column to history_sync_message
-		_, err := tx.Exec(`
-			ALTER TABLE history_sync_message
-			ADD COLUMN inserted_time TIMESTAMP
-		`)
-		return err
-	}}
-}

+ 0 - 25
database/upgrades/2022-05-16-room-backfill-state.go

@@ -1,25 +0,0 @@
-package upgrades
-
-import (
-	"database/sql"
-)
-
-func init() {
-	upgrades[46] = upgrade{"Create the backfill state table", func(tx *sql.Tx, ctx context) error {
-		_, err := tx.Exec(`
-			CREATE TABLE backfill_state (
-				user_mxid           TEXT,
-				portal_jid          TEXT,
-				portal_receiver     TEXT,
-				processing_batch    BOOLEAN,
-				backfill_complete   BOOLEAN,
-				first_expected_ts   INTEGER,
-
-				PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
-				FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
-				FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
-			)
-		`)
-		return err
-	}}
-}

+ 5 - 0
database/upgrades/45-backfillqueue-dispatch-time.sql

@@ -0,0 +1,5 @@
+-- v45: Add dispatch time to backfill queue
+
+ALTER TABLE backfill_queue ADD COLUMN dispatch_time TIMESTAMP;
+UPDATE backfill_queue SET dispatch_time=completed_at;
+ALTER TABLE backfill_queue DROP COLUMN time_end;

+ 3 - 0
database/upgrades/46-history-sync-message-added-timestamp.sql

@@ -0,0 +1,3 @@
+-- v46: Add inserted time to history sync message
+
+ALTER TABLE history_sync_message ADD COLUMN inserted_time TIMESTAMP;

+ 13 - 0
database/upgrades/47-room-backfill-state.sql

@@ -0,0 +1,13 @@
+-- v47: Add table for keeping track of backfill state
+
+CREATE TABLE backfill_state (
+    user_mxid         TEXT,
+    portal_jid        TEXT,
+    portal_receiver   TEXT,
+    processing_batch  BOOLEAN,
+    backfill_complete BOOLEAN,
+    first_expected_ts TIMESTAMP,
+    PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
+    FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ON UPDATE CASCADE,
+    FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal (jid, receiver) ON DELETE CASCADE
+);

+ 7 - 0
database/upgrades/48-crypto-store-handling-split.sql

@@ -0,0 +1,7 @@
+-- v48: Move crypto/state/whatsmeow store upgrade handling to separate systems
+CREATE TABLE crypto_version (version INTEGER PRIMARY KEY);
+INSERT INTO crypto_version VALUES (6);
+CREATE TABLE whatsmeow_version (version INTEGER PRIMARY KEY);
+INSERT INTO whatsmeow_version VALUES (1);
+CREATE TABLE mx_version (version INTEGER PRIMARY KEY);
+INSERT INTO mx_version VALUES (1);

+ 16 - 169
database/upgrades/upgrades.go

@@ -1,180 +1,27 @@
+// Copyright (c) 2022 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
 package upgrades
 package upgrades
 
 
 import (
 import (
 	"database/sql"
 	"database/sql"
+	"embed"
 	"errors"
 	"errors"
-	"fmt"
-	"strings"
-
-	log "maunium.net/go/maulogger/v2"
-)
-
-type Dialect int
 
 
-const (
-	Postgres Dialect = iota
-	SQLite
+	"maunium.net/go/mautrix/util/dbutil"
 )
 )
 
 
-func (dialect Dialect) String() string {
-	switch dialect {
-	case Postgres:
-		return "postgres"
-	case SQLite:
-		return "sqlite3"
-	default:
-		return ""
-	}
-}
-
-type upgradeFunc func(*sql.Tx, context) error
-
-type context struct {
-	dialect Dialect
-	db      *sql.DB
-	log     log.Logger
-}
-
-type upgrade struct {
-	message string
-	fn      upgradeFunc
-}
-
-const NumberOfUpgrades = 47
-
-var upgrades [NumberOfUpgrades]upgrade
-
-var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
-var ErrForeignTables = fmt.Errorf("the database contains foreign tables")
-var ErrNotOwned = fmt.Errorf("the database is owned by")
-var IgnoreForeignTables = false
-
-const databaseOwner = "mautrix-whatsapp"
-
-func GetVersion(db *sql.DB) (int, error) {
-	_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
-	if err != nil {
-		return -1, err
-	}
-
-	version := 0
-	err = db.QueryRow("SELECT version FROM version LIMIT 1").Scan(&version)
-	if err != nil && !errors.Is(err, sql.ErrNoRows) {
-		return -1, err
-	}
-	return version, nil
-}
-
-const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
-const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND table_name=$1)"
-
-func tableExists(dialect Dialect, db *sql.DB, table string) (exists bool) {
-	if dialect == SQLite {
-		_ = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
-	} else if dialect == Postgres {
-		_ = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
-	}
-	return
-}
-
-const createOwnerTable = `
-CREATE TABLE IF NOT EXISTS database_owner (
-	key   INTEGER PRIMARY KEY DEFAULT 0,
-	owner TEXT NOT NULL
-)
-`
-
-func CheckDatabaseOwner(dialect Dialect, db *sql.DB) error {
-	var owner string
-	if !IgnoreForeignTables {
-		if tableExists(dialect, db, "state_groups_state") {
-			return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
-		} else if tableExists(dialect, db, "goose_db_version") {
-			return fmt.Errorf("%w (found goose_db_version, possibly belonging to Dendrite)", ErrForeignTables)
-		}
-	}
-	if _, err := db.Exec(createOwnerTable); err != nil {
-		return fmt.Errorf("failed to ensure database owner table exists: %w", err)
-	} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
-		_, err = db.Exec("INSERT INTO database_owner (owner) VALUES ($1)", databaseOwner)
-		if err != nil {
-			return fmt.Errorf("failed to insert database owner: %w", err)
-		}
-	} else if err != nil {
-		return fmt.Errorf("failed to check database owner: %w", err)
-	} else if owner != databaseOwner {
-		return fmt.Errorf("%w %s", ErrNotOwned, owner)
-	}
-	return nil
-}
-
-func SetVersion(tx *sql.Tx, version int) error {
-	_, err := tx.Exec("DELETE FROM version")
-	if err != nil {
-		return err
-	}
-	_, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
-	return err
-}
-
-func execMany(tx *sql.Tx, queries ...string) error {
-	for _, query := range queries {
-		_, err := tx.Exec(query)
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-func Run(log log.Logger, dialectName string, db *sql.DB) error {
-	var dialect Dialect
-	switch strings.ToLower(dialectName) {
-	case "postgres":
-		dialect = Postgres
-	case "sqlite3":
-		dialect = SQLite
-	default:
-		return fmt.Errorf("unknown dialect %s", dialectName)
-	}
-
-	err := CheckDatabaseOwner(dialect, db)
-	if err != nil {
-		return err
-	}
-
-	version, err := GetVersion(db)
-	if err != nil {
-		return err
-	}
+var Table dbutil.UpgradeTable
 
 
-	if version > NumberOfUpgrades {
-		return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, NumberOfUpgrades)
-	}
+//go:embed *.sql
+var rawUpgrades embed.FS
 
 
-	log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
-	for i, upgradeItem := range upgrades[version:] {
-		if upgradeItem.fn == nil {
-			continue
-		}
-		log.Infofln("Upgrading database to v%d: %s", version+i+1, upgradeItem.message)
-		var tx *sql.Tx
-		tx, err = db.Begin()
-		if err != nil {
-			return err
-		}
-		err = upgradeItem.fn(tx, context{dialect, db, log})
-		if err != nil {
-			return err
-		}
-		err = SetVersion(tx, version+i+1)
-		if err != nil {
-			return err
-		}
-		err = tx.Commit()
-		if err != nil {
-			return err
-		}
-	}
-	return nil
+func init() {
+	Table.Register(-1, 43, "Unsupported version", func(tx *sql.Tx, database *dbutil.Database) error {
+		return errors.New("please upgrade to mautrix-whatsapp v0.4.0 before upgrading to a newer version")
+	})
+	Table.RegisterFS(rawUpgrades)
 }
 }

+ 2 - 1
database/user.go

@@ -24,6 +24,7 @@ import (
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
 
 
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
+	"maunium.net/go/mautrix/util/dbutil"
 
 
 	"go.mau.fi/whatsmeow/types"
 	"go.mau.fi/whatsmeow/types"
 )
 )
@@ -89,7 +90,7 @@ type User struct {
 	inSpaceCacheLock  sync.Mutex
 	inSpaceCacheLock  sync.Mutex
 }
 }
 
 
-func (user *User) Scan(row Scannable) *User {
+func (user *User) Scan(row dbutil.Scannable) *User {
 	var username, timezone sql.NullString
 	var username, timezone sql.NullString
 	var device, agent sql.NullByte
 	var device, agent sql.NullByte
 	var phoneLastSeen, phoneLastPinged sql.NullInt64
 	var phoneLastSeen, phoneLastPinged sql.NullInt64

+ 3 - 3
disappear.go

@@ -50,9 +50,9 @@ func (portal *Portal) ScheduleDisappearing() {
 	}
 	}
 }
 }
 
 
-func (bridge *Bridge) SleepAndDeleteUpcoming() {
-	for _, msg := range bridge.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) {
-		portal := bridge.GetPortalByMXID(msg.RoomID)
+func (br *WABridge) SleepAndDeleteUpcoming() {
+	for _, msg := range br.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) {
+		portal := br.GetPortalByMXID(msg.RoomID)
 		if portal == nil {
 		if portal == nil {
 			msg.Delete()
 			msg.Delete()
 		} else {
 		} else {

+ 8 - 8
example-config.yaml

@@ -43,14 +43,6 @@ appservice:
         max_conn_idle_time: null
         max_conn_idle_time: null
         max_conn_lifetime: null
         max_conn_lifetime: null
 
 
-    # Settings for provisioning API
-    provisioning:
-        # Prefix for the provisioning API paths.
-        prefix: /_matrix/provision
-        # Shared secret for authentication. If set to "generate", a random secret will be generated,
-        # or if set to "disable", the provisioning API will be disabled.
-        shared_secret: generate
-
     # The unique ID of this appservice.
     # The unique ID of this appservice.
     id: whatsapp
     id: whatsapp
     # Appservice bot details.
     # Appservice bot details.
@@ -317,6 +309,14 @@ bridge:
             # Verification by the bridge is not yet implemented.
             # Verification by the bridge is not yet implemented.
             require_verification: true
             require_verification: true
 
 
+    # Settings for provisioning API
+    provisioning:
+        # Prefix for the provisioning API paths.
+        prefix: /_matrix/provision
+        # Shared secret for authentication. If set to "generate", a random secret will be generated,
+        # or if set to "disable", the provisioning API will be disabled.
+        shared_secret: generate
+
     # Permissions for using the bridge.
     # Permissions for using the bridge.
     # Permitted values:
     # Permitted values:
     #    relay - Talk through the relaybot (if enabled), no access otherwise
     #    relay - Talk through the relaybot (if enabled), no access otherwise

+ 2 - 2
formatting.go

@@ -37,7 +37,7 @@ var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
 const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids"
 const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids"
 
 
 type Formatter struct {
 type Formatter struct {
-	bridge *Bridge
+	bridge *WABridge
 
 
 	matrixHTMLParser *format.HTMLParser
 	matrixHTMLParser *format.HTMLParser
 
 
@@ -46,7 +46,7 @@ type Formatter struct {
 	waReplFuncText map[*regexp.Regexp]func(string) string
 	waReplFuncText map[*regexp.Regexp]func(string) string
 }
 }
 
 
-func NewFormatter(bridge *Bridge) *Formatter {
+func NewFormatter(bridge *WABridge) *Formatter {
 	formatter := &Formatter{
 	formatter := &Formatter{
 		bridge: bridge,
 		bridge: bridge,
 		matrixHTMLParser: &format.HTMLParser{
 		matrixHTMLParser: &format.HTMLParser{

+ 3 - 4
go.mod

@@ -14,10 +14,8 @@ require (
 	golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9
 	golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9
 	golang.org/x/net v0.0.0-20220513224357-95641704303c
 	golang.org/x/net v0.0.0-20220513224357-95641704303c
 	google.golang.org/protobuf v1.28.0
 	google.golang.org/protobuf v1.28.0
-	gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99
-	maunium.net/go/mauflag v1.0.0
 	maunium.net/go/maulogger/v2 v2.3.2
 	maunium.net/go/maulogger/v2 v2.3.2
-	maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1
+	maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5
 )
 )
 
 
 require (
 require (
@@ -37,7 +35,8 @@ require (
 	golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 // indirect
 	golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 // indirect
 	golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886 // indirect
 	golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886 // indirect
 	golang.org/x/text v0.3.7 // indirect
 	golang.org/x/text v0.3.7 // indirect
-	gopkg.in/yaml.v2 v2.4.0 // indirect
+	gopkg.in/yaml.v3 v3.0.0 // indirect
+	maunium.net/go/mauflag v1.0.0 // indirect
 )
 )
 
 
 // Exclude some things that cause go.sum to explode
 // Exclude some things that cause go.sum to explode

+ 4 - 5
go.sum

@@ -99,14 +99,13 @@ google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
 gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99 h1:dbuHpmKjkDzSOMKAWl10QNlgaZUd3V1q99xc81tt2Kc=
-gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA=
+gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
 maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
 maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
 maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
 maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
 maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
 maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
 maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
-maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1 h1:+KEF+nSuBfHWsfQRz92YP/DdSLbComLoXCXgcrH6WRU=
-maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1/go.mod h1:K29EcHwsNg6r7fMfwvi0GHQ9o5wSjqB9+Q8RjCIQEjA=
+maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5 h1:7ZORg2h+lflc1HwjTKCXZnykauXD+wzbW+VDknbv6SU=
+maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5/go.mod h1:oma8o6Y/5jcViBlDbX7tp1ajP2XP+b78h8twdI+zKI0=

+ 119 - 394
main.go

@@ -18,43 +18,26 @@ package main
 
 
 import (
 import (
 	_ "embed"
 	_ "embed"
-	"errors"
-	"fmt"
 	"net/http"
 	"net/http"
 	"os"
 	"os"
-	"os/signal"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
-	"syscall"
 	"time"
 	"time"
 
 
-	"google.golang.org/protobuf/proto"
-
 	"go.mau.fi/whatsmeow"
 	"go.mau.fi/whatsmeow"
 	waProto "go.mau.fi/whatsmeow/binary/proto"
 	waProto "go.mau.fi/whatsmeow/binary/proto"
 	"go.mau.fi/whatsmeow/store"
 	"go.mau.fi/whatsmeow/store"
 	"go.mau.fi/whatsmeow/store/sqlstore"
 	"go.mau.fi/whatsmeow/store/sqlstore"
 	"go.mau.fi/whatsmeow/types"
 	"go.mau.fi/whatsmeow/types"
+	"google.golang.org/protobuf/proto"
 
 
-	flag "maunium.net/go/mauflag"
-	log "maunium.net/go/maulogger/v2"
-
-	"maunium.net/go/mautrix"
-	"maunium.net/go/mautrix/appservice"
-	"maunium.net/go/mautrix/event"
+	"maunium.net/go/mautrix/bridge"
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/util/configupgrade"
 	"maunium.net/go/mautrix/util/configupgrade"
 
 
 	"maunium.net/go/mautrix-whatsapp/config"
 	"maunium.net/go/mautrix-whatsapp/config"
 	"maunium.net/go/mautrix-whatsapp/database"
 	"maunium.net/go/mautrix-whatsapp/database"
-	"maunium.net/go/mautrix-whatsapp/database/upgrades"
-)
-
-// The name and repo URL of the bridge.
-var (
-	Name = "mautrix-whatsapp"
-	URL  = "https://github.com/mautrix/whatsapp"
 )
 )
 
 
 // Information to find out exactly which commit the bridge was built from.
 // Information to find out exactly which commit the bridge was built from.
@@ -65,120 +48,19 @@ var (
 	BuildTime = "unknown"
 	BuildTime = "unknown"
 )
 )
 
 
-var (
-	// Version is the version number of the bridge. Changed manually when making a release.
-	Version = "0.4.0"
-	// WAVersion is the version number exposed to WhatsApp. Filled in init()
-	WAVersion = ""
-	// VersionString is the bridge version, plus commit information. Filled in init() using the build-time values.
-	VersionString = ""
-)
-
 //go:embed example-config.yaml
 //go:embed example-config.yaml
 var ExampleConfig string
 var ExampleConfig string
 
 
-func init() {
-	if len(Tag) > 0 && Tag[0] == 'v' {
-		Tag = Tag[1:]
-	}
-	if Tag != Version {
-		suffix := ""
-		if !strings.HasSuffix(Version, "+dev") {
-			suffix = "+dev"
-		}
-		if len(Commit) > 8 {
-			Version = fmt.Sprintf("%s%s.%s", Version, suffix, Commit[:8])
-		} else {
-			Version = fmt.Sprintf("%s%s.unknown", Version, suffix)
-		}
-	}
-	mautrix.DefaultUserAgent = fmt.Sprintf("mautrix-whatsapp/%s %s", Version, mautrix.DefaultUserAgent)
-	WAVersion = strings.FieldsFunc(Version, func(r rune) bool { return r == '-' || r == '+' })[0]
-	VersionString = fmt.Sprintf("%s %s (%s)", Name, Version, BuildTime)
-
-	config.ExampleConfig = ExampleConfig
-}
-
-var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String()
-var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool()
-var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
-var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
-var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool()
-var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool()
-var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool()
-var migrateFrom = flag.Make().LongKey("migrate-db").Usage("Source database type and URI to migrate from.").Bool()
-var wantHelp, _ = flag.MakeHelpFlag()
-
-func (bridge *Bridge) GenerateRegistration() {
-	if *dontSaveConfig {
-		// We need to save the generated as_token and hs_token in the config
-		_, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration")
-		os.Exit(5)
-	}
-	reg, err := bridge.Config.NewRegistration()
-	if err != nil {
-		_, _ = fmt.Fprintln(os.Stderr, "Failed to generate registration:", err)
-		os.Exit(20)
-	}
-
-	err = reg.Save(*registrationPath)
-	if err != nil {
-		_, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err)
-		os.Exit(21)
-	}
-
-	err = config.Mutate(*configPath, func(helper *configupgrade.Helper) {
-		helper.Set(configupgrade.Str, bridge.Config.AppService.ASToken, "appservice", "as_token")
-		helper.Set(configupgrade.Str, bridge.Config.AppService.HSToken, "appservice", "hs_token")
-	})
-	if err != nil {
-		_, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err)
-		os.Exit(22)
-	}
-	fmt.Println("Registration generated. Add the path to the registration to your Synapse config, restart it, then start the bridge.")
-	os.Exit(0)
-}
-
-func (bridge *Bridge) MigrateDatabase() {
-	oldDB, err := database.New(config.DatabaseConfig{Type: flag.Arg(0), URI: flag.Arg(1)}, log.DefaultLogger)
-	if err != nil {
-		fmt.Println("Failed to open old database:", err)
-		os.Exit(30)
-	}
-	err = oldDB.Init()
-	if err != nil {
-		fmt.Println("Failed to upgrade old database:", err)
-		os.Exit(31)
-	}
-
-	newDB, err := database.New(bridge.Config.AppService.Database, log.DefaultLogger)
-	if err != nil {
-		fmt.Println("Failed to open new database:", err)
-		os.Exit(32)
-	}
-	err = newDB.Init()
-	if err != nil {
-		fmt.Println("Failed to upgrade new database:", err)
-		os.Exit(33)
-	}
-
-	database.Migrate(oldDB, newDB)
-}
-
-type Bridge struct {
-	AS             *appservice.AppService
-	EventProcessor *appservice.EventProcessor
-	MatrixHandler  *MatrixHandler
-	Config         *config.Config
-	DB             *database.Database
-	Log            log.Logger
-	StateStore     *database.SQLStateStore
-	Provisioning   *ProvisioningAPI
-	Bot            *appservice.IntentAPI
-	Formatter      *Formatter
-	Crypto         Crypto
-	Metrics        *MetricsHandler
-	WAContainer    *sqlstore.Container
+type WABridge struct {
+	bridge.Bridge
+	MatrixHandler *MatrixHandler
+	Config        *config.Config
+	DB            *database.Database
+	Provisioning  *ProvisioningAPI
+	Formatter     *Formatter
+	Metrics       *MetricsHandler
+	WAContainer   *sqlstore.Container
+	WAVersion     string
 
 
 	usersByMXID         map[id.UserID]*User
 	usersByMXID         map[id.UserID]*User
 	usersByUsername     map[string]*User
 	usersByUsername     map[string]*User
@@ -195,111 +77,32 @@ type Bridge struct {
 	puppetsLock         sync.Mutex
 	puppetsLock         sync.Mutex
 }
 }
 
 
-type Crypto interface {
-	HandleMemberEvent(*event.Event)
-	Decrypt(*event.Event) (*event.Event, error)
-	Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error)
-	WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
-	RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
-	ResetSession(id.RoomID)
-	Init() error
-	Start()
-	Stop()
-}
-
-func (bridge *Bridge) ensureConnection() {
-	for {
-		versions, err := bridge.Bot.Versions()
-		if err != nil {
-			bridge.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
-			time.Sleep(10 * time.Second)
-			continue
-		}
-		if !versions.ContainsGreaterOrEqual(mautrix.SpecV11) {
-			bridge.Log.Warnfln("Server isn't advertising modern spec versions")
-		}
-		resp, err := bridge.Bot.Whoami()
-		if err != nil {
-			if errors.Is(err, mautrix.MUnknownToken) {
-				bridge.Log.Fatalln("The as_token was not accepted. Is the registration file installed in your homeserver correctly?")
-				os.Exit(16)
-			} else if errors.Is(err, mautrix.MExclusive) {
-				bridge.Log.Fatalln("The as_token was accepted, but the /register request was not. Are the homeserver domain and username template in the config correct, and do they match the values in the registration?")
-				os.Exit(16)
-			}
-			bridge.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
-			time.Sleep(10 * time.Second)
-		} else if resp.UserID != bridge.Bot.UserID {
-			bridge.Log.Fatalln("Unexpected user ID in whoami call: got %s, expected %s", resp.UserID, bridge.Bot.UserID)
-			os.Exit(17)
-		} else {
-			break
-		}
-	}
-}
-
-func (bridge *Bridge) Init() {
-	var err error
-
-	bridge.AS, err = bridge.Config.MakeAppService()
-	if err != nil {
-		_, _ = fmt.Fprintln(os.Stderr, "Failed to initialize AppService:", err)
-		os.Exit(11)
-	}
-	_, _ = bridge.AS.Init()
-
-	bridge.Log = log.Create()
-	bridge.Config.Logging.Configure(bridge.Log)
-	log.DefaultLogger = bridge.Log.(*log.BasicLogger)
-	if len(bridge.Config.Logging.FileNameFormat) > 0 {
-		err = log.OpenFile()
-		if err != nil {
-			_, _ = fmt.Fprintln(os.Stderr, "Failed to open log file:", err)
-			os.Exit(12)
-		}
-	}
-	bridge.AS.Log = log.Sub("Matrix")
-	bridge.Bot = bridge.AS.BotIntent()
-	bridge.Log.Infoln("Initializing", VersionString)
-
-	bridge.Log.Debugln("Initializing database connection")
-	bridge.DB, err = database.New(bridge.Config.AppService.Database, bridge.Log)
-	if err != nil {
-		bridge.Log.Fatalln("Failed to initialize database connection:", err)
-		os.Exit(14)
-	}
-
-	bridge.Log.Debugln("Initializing state store")
-	bridge.StateStore = database.NewSQLStateStore(bridge.DB)
-	bridge.AS.StateStore = bridge.StateStore
-
-	Segment.log = bridge.Log.Sub("Segment")
-	Segment.key = bridge.Config.SegmentKey
+func (br *WABridge) Init() {
+	Segment.log = br.Log.Sub("Segment")
+	Segment.key = br.Config.SegmentKey
 	if Segment.IsEnabled() {
 	if Segment.IsEnabled() {
 		Segment.log.Infoln("Segment metrics are enabled")
 		Segment.log.Infoln("Segment metrics are enabled")
 	}
 	}
 
 
-	bridge.WAContainer = sqlstore.NewWithDB(bridge.DB.DB, bridge.Config.AppService.Database.Type, nil)
-	bridge.WAContainer.DatabaseErrorHandler = bridge.DB.HandleSignalStoreError
+	br.DB = database.New(br.Bridge.DB)
+	br.WAContainer = sqlstore.NewWithDB(br.DB.DB, br.DB.Dialect.String(), nil)
+	br.WAContainer.DatabaseErrorHandler = br.DB.HandleSignalStoreError
 
 
-	ss := bridge.Config.AppService.Provisioning.SharedSecret
+	ss := br.Config.Bridge.Provisioning.SharedSecret
 	if len(ss) > 0 && ss != "disable" {
 	if len(ss) > 0 && ss != "disable" {
-		bridge.Provisioning = &ProvisioningAPI{bridge: bridge}
+		br.Provisioning = &ProvisioningAPI{bridge: br}
 	}
 	}
 
 
-	bridge.Log.Debugln("Initializing Matrix event processor")
-	bridge.EventProcessor = appservice.NewEventProcessor(bridge.AS)
-	bridge.Log.Debugln("Initializing Matrix event handler")
-	bridge.MatrixHandler = NewMatrixHandler(bridge)
-	bridge.Formatter = NewFormatter(bridge)
-	bridge.Crypto = NewCryptoHelper(bridge)
-	bridge.Metrics = NewMetricsHandler(bridge.Config.Metrics.Listen, bridge.Log.Sub("Metrics"), bridge.DB)
+	br.Log.Debugln("Initializing Matrix event handler")
+	br.MatrixHandler = NewMatrixHandler(br)
+	br.Formatter = NewFormatter(br)
+	br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB)
 
 
-	store.BaseClientPayload.UserAgent.OsVersion = proto.String(WAVersion)
-	store.BaseClientPayload.UserAgent.OsBuildNumber = proto.String(WAVersion)
-	store.CompanionProps.Os = proto.String(bridge.Config.WhatsApp.OSName)
-	store.CompanionProps.RequireFullSync = proto.Bool(bridge.Config.Bridge.HistorySync.RequestFullSync)
-	versionParts := strings.Split(WAVersion, ".")
+	store.BaseClientPayload.UserAgent.OsVersion = proto.String(br.WAVersion)
+	store.BaseClientPayload.UserAgent.OsBuildNumber = proto.String(br.WAVersion)
+	store.CompanionProps.Os = proto.String(br.Config.WhatsApp.OSName)
+	store.CompanionProps.RequireFullSync = proto.Bool(br.Config.Bridge.HistorySync.RequestFullSync)
+	versionParts := strings.Split(br.WAVersion, ".")
 	if len(versionParts) > 2 {
 	if len(versionParts) > 2 {
 		primary, _ := strconv.Atoi(versionParts[0])
 		primary, _ := strconv.Atoi(versionParts[0])
 		secondary, _ := strconv.Atoi(versionParts[1])
 		secondary, _ := strconv.Atoi(versionParts[1])
@@ -308,161 +111,107 @@ func (bridge *Bridge) Init() {
 		store.CompanionProps.Version.Secondary = proto.Uint32(uint32(secondary))
 		store.CompanionProps.Version.Secondary = proto.Uint32(uint32(secondary))
 		store.CompanionProps.Version.Tertiary = proto.Uint32(uint32(tertiary))
 		store.CompanionProps.Version.Tertiary = proto.Uint32(uint32(tertiary))
 	}
 	}
-	platformID, ok := waProto.CompanionProps_CompanionPropsPlatformType_value[strings.ToUpper(bridge.Config.WhatsApp.BrowserName)]
+	platformID, ok := waProto.CompanionProps_CompanionPropsPlatformType_value[strings.ToUpper(br.Config.WhatsApp.BrowserName)]
 	if ok {
 	if ok {
 		store.CompanionProps.PlatformType = waProto.CompanionProps_CompanionPropsPlatformType(platformID).Enum()
 		store.CompanionProps.PlatformType = waProto.CompanionProps_CompanionPropsPlatformType(platformID).Enum()
 	}
 	}
 }
 }
 
 
-func (bridge *Bridge) Start() {
-	bridge.Log.Debugln("Running database upgrades")
-	err := bridge.DB.Init()
-	if err != nil && (!errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) || !*ignoreUnsupportedDatabase) {
-		bridge.Log.Fatalln("Failed to initialize database:", err)
-		if errors.Is(err, upgrades.ErrForeignTables) {
-			bridge.Log.Infoln("You can use --ignore-foreign-tables to ignore this error")
-		} else if errors.Is(err, upgrades.ErrNotOwned) {
-			bridge.Log.Infoln("Sharing the same database with different programs is not supported")
-		} else if errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) {
-			bridge.Log.Infoln("Downgrading the bridge is not supported")
-		}
+func (br *WABridge) Start() {
+	err := br.WAContainer.Upgrade()
+	if err != nil {
+		br.Log.Fatalln("Failed to upgrade whatsmeow database: %v", err)
 		os.Exit(15)
 		os.Exit(15)
 	}
 	}
-	bridge.Log.Debugln("Checking connection to homeserver")
-	bridge.ensureConnection()
-	if bridge.Crypto != nil {
-		err = bridge.Crypto.Init()
-		if err != nil {
-			bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
-			os.Exit(19)
-		}
-	}
-	if bridge.Provisioning != nil {
-		bridge.Log.Debugln("Initializing provisioning API")
-		bridge.Provisioning.Init()
+	if br.Provisioning != nil {
+		br.Log.Debugln("Initializing provisioning API")
+		br.Provisioning.Init()
 	}
 	}
-	bridge.Log.Debugln("Starting application service HTTP server")
-	go bridge.AS.Start()
-	bridge.Log.Debugln("Starting event processor")
-	go bridge.EventProcessor.Start()
-	go bridge.CheckWhatsAppUpdate()
-	go bridge.UpdateBotProfile()
-	if bridge.Crypto != nil {
-		go bridge.Crypto.Start()
-	}
-	go bridge.StartUsers()
-	if bridge.Config.Metrics.Enabled {
-		go bridge.Metrics.Start()
+	go br.CheckWhatsAppUpdate()
+	go br.StartUsers()
+	if br.Config.Metrics.Enabled {
+		go br.Metrics.Start()
 	}
 	}
 
 
-	if bridge.Config.Bridge.ResendBridgeInfo {
-		go bridge.ResendBridgeInfo()
+	if br.Config.Bridge.ResendBridgeInfo {
+		go br.ResendBridgeInfo()
 	}
 	}
-	go bridge.Loop()
-	bridge.AS.Ready = true
+	go br.Loop()
 }
 }
 
 
-func (bridge *Bridge) CheckWhatsAppUpdate() {
-	bridge.Log.Debugfln("Checking for WhatsApp web update")
+func (br *WABridge) CheckWhatsAppUpdate() {
+	br.Log.Debugfln("Checking for WhatsApp web update")
 	resp, err := whatsmeow.CheckUpdate(http.DefaultClient)
 	resp, err := whatsmeow.CheckUpdate(http.DefaultClient)
 	if err != nil {
 	if err != nil {
-		bridge.Log.Warnfln("Failed to check for WhatsApp web update: %v", err)
+		br.Log.Warnfln("Failed to check for WhatsApp web update: %v", err)
 		return
 		return
 	}
 	}
 	if store.GetWAVersion() == resp.ParsedVersion {
 	if store.GetWAVersion() == resp.ParsedVersion {
-		bridge.Log.Debugfln("Bridge is using latest WhatsApp web protocol")
+		br.Log.Debugfln("Bridge is using latest WhatsApp web protocol")
 	} else if store.GetWAVersion().LessThan(resp.ParsedVersion) {
 	} else if store.GetWAVersion().LessThan(resp.ParsedVersion) {
 		if resp.IsBelowHard || resp.IsBroken {
 		if resp.IsBelowHard || resp.IsBroken {
-			bridge.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
+			br.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
 		} else if resp.IsBelowSoft {
 		} else if resp.IsBelowSoft {
-			bridge.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
+			br.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
 		} else {
 		} else {
-			bridge.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
+			br.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
 		}
 		}
 	} else {
 	} else {
-		bridge.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol")
+		br.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol")
 	}
 	}
 }
 }
 
 
-func (bridge *Bridge) Loop() {
+func (br *WABridge) Loop() {
 	for {
 	for {
-		bridge.SleepAndDeleteUpcoming()
+		br.SleepAndDeleteUpcoming()
 		time.Sleep(1 * time.Hour)
 		time.Sleep(1 * time.Hour)
-		bridge.WarnUsersAboutDisconnection()
+		br.WarnUsersAboutDisconnection()
 	}
 	}
 }
 }
 
 
-func (bridge *Bridge) WarnUsersAboutDisconnection() {
-	bridge.usersLock.Lock()
-	for _, user := range bridge.usersByUsername {
+func (br *WABridge) WarnUsersAboutDisconnection() {
+	br.usersLock.Lock()
+	for _, user := range br.usersByUsername {
 		if user.IsConnected() && !user.PhoneRecentlySeen(true) {
 		if user.IsConnected() && !user.PhoneRecentlySeen(true) {
 			go user.sendPhoneOfflineWarning()
 			go user.sendPhoneOfflineWarning()
 		}
 		}
 	}
 	}
-	bridge.usersLock.Unlock()
+	br.usersLock.Unlock()
 }
 }
 
 
-func (bridge *Bridge) ResendBridgeInfo() {
-	if *dontSaveConfig {
-		bridge.Log.Warnln("Not setting resend_bridge_info to false in config due to --no-update flag")
-	} else {
-		err := config.Mutate(*configPath, func(helper *configupgrade.Helper) {
-			helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info")
-		})
-		if err != nil {
-			bridge.Log.Errorln("Failed to save config after setting resend_bridge_info to false:", err)
-		}
-	}
-	bridge.Log.Infoln("Re-sending bridge info state event to all portals")
-	for _, portal := range bridge.GetAllPortals() {
-		portal.UpdateBridgeInfo()
-	}
-	bridge.Log.Infoln("Finished re-sending bridge info state events")
+func (br *WABridge) ResendBridgeInfo() {
+	// FIXME
+	//if *dontSaveConfig {
+	//	br.Log.Warnln("Not setting resend_bridge_info to false in config due to --no-update flag")
+	//} else {
+	//	err := config.Mutate(*configPath, func(helper *configupgrade.Helper) {
+	//		helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info")
+	//	})
+	//	if err != nil {
+	//		br.Log.Errorln("Failed to save config after setting resend_bridge_info to false:", err)
+	//	}
+	//}
+	//br.Log.Infoln("Re-sending bridge info state event to all portals")
+	//for _, portal := range br.GetAllPortals() {
+	//	portal.UpdateBridgeInfo()
+	//}
+	//br.Log.Infoln("Finished re-sending bridge info state events")
 }
 }
 
 
-func (bridge *Bridge) UpdateBotProfile() {
-	bridge.Log.Debugln("Updating bot profile")
-	botConfig := &bridge.Config.AppService.Bot
-
-	var err error
-	var mxc id.ContentURI
-	if botConfig.Avatar == "remove" {
-		err = bridge.Bot.SetAvatarURL(mxc)
-	} else if len(botConfig.Avatar) > 0 {
-		mxc, err = id.ParseContentURI(botConfig.Avatar)
-		if err == nil {
-			err = bridge.Bot.SetAvatarURL(mxc)
-		}
-		botConfig.ParsedAvatar = mxc
-	}
-	if err != nil {
-		bridge.Log.Warnln("Failed to update bot avatar:", err)
-	}
-
-	if botConfig.Displayname == "remove" {
-		err = bridge.Bot.SetDisplayName("")
-	} else if len(botConfig.Displayname) > 0 {
-		err = bridge.Bot.SetDisplayName(botConfig.Displayname)
-	}
-	if err != nil {
-		bridge.Log.Warnln("Failed to update bot displayname:", err)
-	}
-}
-
-func (bridge *Bridge) StartUsers() {
-	bridge.Log.Debugln("Starting users")
+func (br *WABridge) StartUsers() {
+	br.Log.Debugln("Starting users")
 	foundAnySessions := false
 	foundAnySessions := false
-	for _, user := range bridge.GetAllUsers() {
+	for _, user := range br.GetAllUsers() {
 		if !user.JID.IsEmpty() {
 		if !user.JID.IsEmpty() {
 			foundAnySessions = true
 			foundAnySessions = true
 		}
 		}
 		go user.Connect()
 		go user.Connect()
 	}
 	}
 	if !foundAnySessions {
 	if !foundAnySessions {
-		bridge.sendGlobalBridgeState(BridgeState{StateEvent: StateUnconfigured}.fill(nil))
+		br.sendGlobalBridgeState(BridgeState{StateEvent: StateUnconfigured}.fill(nil))
 	}
 	}
-	bridge.Log.Debugln("Starting custom puppets")
-	for _, loopuppet := range bridge.GetAllPuppetsWithCustomMXID() {
+	br.Log.Debugln("Starting custom puppets")
+	for _, loopuppet := range br.GetAllPuppetsWithCustomMXID() {
 		go func(puppet *Puppet) {
 		go func(puppet *Puppet) {
 			puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID)
 			puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID)
 			err := puppet.StartCustomMXID(true)
 			err := puppet.StartCustomMXID(true)
@@ -473,80 +222,37 @@ func (bridge *Bridge) StartUsers() {
 	}
 	}
 }
 }
 
 
-func (bridge *Bridge) Stop() {
-	if bridge.Crypto != nil {
-		bridge.Crypto.Stop()
+func (br *WABridge) Stop() {
+	if br.Crypto != nil {
+		br.Crypto.Stop()
 	}
 	}
-	bridge.AS.Stop()
-	bridge.Metrics.Stop()
-	bridge.EventProcessor.Stop()
-	for _, user := range bridge.usersByUsername {
+	br.AS.Stop()
+	br.Metrics.Stop()
+	br.EventProcessor.Stop()
+	for _, user := range br.usersByUsername {
 		if user.Client == nil {
 		if user.Client == nil {
 			continue
 			continue
 		}
 		}
-		bridge.Log.Debugln("Disconnecting", user.MXID)
+		br.Log.Debugln("Disconnecting", user.MXID)
 		user.Client.Disconnect()
 		user.Client.Disconnect()
 		close(user.historySyncs)
 		close(user.historySyncs)
 	}
 	}
 }
 }
 
 
-func (bridge *Bridge) Main() {
-	configData, upgraded, err := config.Upgrade(*configPath, !*dontSaveConfig)
-	if err != nil {
-		_, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err)
-		if configData == nil {
-			os.Exit(10)
-		}
-	}
-
-	bridge.Config, err = config.Load(configData, upgraded)
-	if err != nil {
-		_, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err)
-		os.Exit(10)
-	}
+func (br *WABridge) GetExampleConfig() string {
+	return ExampleConfig
+}
 
 
-	if *generateRegistration {
-		bridge.GenerateRegistration()
-		return
-	} else if *migrateFrom {
-		bridge.MigrateDatabase()
-		return
+func (br *WABridge) GetConfigPtr() interface{} {
+	br.Config = &config.Config{
+		BaseConfig: &br.Bridge.Config,
 	}
 	}
-
-	bridge.Init()
-	bridge.Log.Infoln("Bridge initialization complete, starting...")
-	bridge.Start()
-	bridge.Log.Infoln("Bridge started!")
-
-	c := make(chan os.Signal)
-	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
-	<-c
-
-	bridge.Log.Infoln("Interrupt received, stopping...")
-	bridge.Stop()
-	bridge.Log.Infoln("Bridge stopped.")
-	os.Exit(0)
+	br.Config.BaseConfig.Bridge = &br.Config.Bridge
+	return br.Config
 }
 }
 
 
 func main() {
 func main() {
-	flag.SetHelpTitles(
-		"mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.",
-		"mautrix-whatsapp [-h] [-c <path>] [-r <path>] [-g] [--migrate-db <source type> <source uri>]")
-	err := flag.Parse()
-	if err != nil {
-		_, _ = fmt.Fprintln(os.Stderr, err)
-		flag.PrintHelp()
-		os.Exit(1)
-	} else if *wantHelp {
-		flag.PrintHelp()
-		os.Exit(0)
-	} else if *version {
-		fmt.Println(VersionString)
-		return
-	}
-	upgrades.IgnoreForeignTables = *ignoreForeignTables
-
-	(&Bridge{
+	br := &WABridge{
 		usersByMXID:         make(map[id.UserID]*User),
 		usersByMXID:         make(map[id.UserID]*User),
 		usersByUsername:     make(map[string]*User),
 		usersByUsername:     make(map[string]*User),
 		spaceRooms:          make(map[id.RoomID]*User),
 		spaceRooms:          make(map[id.RoomID]*User),
@@ -555,5 +261,24 @@ func main() {
 		portalsByJID:        make(map[database.PortalKey]*Portal),
 		portalsByJID:        make(map[database.PortalKey]*Portal),
 		puppets:             make(map[types.JID]*Puppet),
 		puppets:             make(map[types.JID]*Puppet),
 		puppetsByCustomMXID: make(map[id.UserID]*Puppet),
 		puppetsByCustomMXID: make(map[id.UserID]*Puppet),
-	}).Main()
+	}
+	br.Bridge = bridge.Bridge{
+		Name:         "mautrix-whatsapp",
+		URL:          "https://github.com/mautrix/whatsapp",
+		Description:  "A Matrix-WhatsApp puppeting bridge.",
+		Version:      "0.4.0",
+		ProtocolName: "WhatsApp",
+
+		ConfigUpgrader: &configupgrade.StructUpgrader{
+			SimpleUpgrader: configupgrade.SimpleUpgrader(config.DoUpgrade),
+			Blocks:         config.SpacedBlocks,
+			Base:           ExampleConfig,
+		},
+
+		Child: br,
+	}
+	br.InitVersion(Tag, Commit, BuildTime)
+	br.WAVersion = strings.FieldsFunc(br.Version, func(r rune) bool { return r == '-' || r == '+' })[0]
+
+	br.Main()
 }
 }

+ 4 - 3
matrix.go

@@ -28,6 +28,7 @@ import (
 
 
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix/appservice"
 	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/bridge"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/format"
 	"maunium.net/go/mautrix/format"
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
@@ -36,13 +37,13 @@ import (
 )
 )
 
 
 type MatrixHandler struct {
 type MatrixHandler struct {
-	bridge *Bridge
+	bridge *WABridge
 	as     *appservice.AppService
 	as     *appservice.AppService
 	log    maulogger.Logger
 	log    maulogger.Logger
 	cmd    *CommandHandler
 	cmd    *CommandHandler
 }
 }
 
 
-func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
+func NewMatrixHandler(bridge *WABridge) *MatrixHandler {
 	handler := &MatrixHandler{
 	handler := &MatrixHandler{
 		bridge: bridge,
 		bridge: bridge,
 		as:     bridge.AS,
 		as:     bridge.AS,
@@ -362,7 +363,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
 
 
 	decrypted, err := mx.bridge.Crypto.Decrypt(evt)
 	decrypted, err := mx.bridge.Crypto.Decrypt(evt)
 	decryptionRetryCount := 0
 	decryptionRetryCount := 0
-	if errors.Is(err, NoSessionFound) {
+	if errors.Is(err, bridge.NoSessionFound) {
 		content := evt.Content.AsEncrypted()
 		content := evt.Content.AsEncrypted()
 		mx.log.Debugfln("Couldn't find session %s trying to decrypt %s, waiting %d seconds...", content.SessionID, evt.ID, int(sessionWaitTimeout.Seconds()))
 		mx.log.Debugfln("Couldn't find session %s trying to decrypt %s, waiting %d seconds...", content.SessionID, evt.ID, int(sessionWaitTimeout.Seconds()))
 		mx.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, err, false, decryptionRetryCount)
 		mx.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, err, false, decryptionRetryCount)

+ 0 - 17
no-crypto.go

@@ -1,17 +0,0 @@
-//go:build !cgo || nocrypto
-
-package main
-
-import (
-	"errors"
-)
-
-func NewCryptoHelper(bridge *Bridge) Crypto {
-	if !bridge.Config.Bridge.Encryption.Allow {
-		bridge.Log.Warnln("Bridge built without end-to-bridge encryption, but encryption is enabled in config")
-	}
-	bridge.Log.Debugln("Bridge built without end-to-bridge encryption")
-	return nil
-}
-
-var NoSessionFound = errors.New("nil")

+ 47 - 38
portal.go

@@ -45,6 +45,7 @@ import (
 
 
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix/appservice"
 	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/bridge"
 	"maunium.net/go/mautrix/crypto/attachment"
 	"maunium.net/go/mautrix/crypto/attachment"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/format"
 	"maunium.net/go/mautrix/format"
@@ -69,68 +70,72 @@ const PrivateChatTopic = "WhatsApp private chat"
 
 
 var ErrStatusBroadcastDisabled = errors.New("status bridging is disabled")
 var ErrStatusBroadcastDisabled = errors.New("status bridging is disabled")
 
 
-func (bridge *Bridge) GetPortalByMXID(mxid id.RoomID) *Portal {
-	bridge.portalsLock.Lock()
-	defer bridge.portalsLock.Unlock()
-	portal, ok := bridge.portalsByMXID[mxid]
+func (br *WABridge) GetPortalByMXID(mxid id.RoomID) *Portal {
+	br.portalsLock.Lock()
+	defer br.portalsLock.Unlock()
+	portal, ok := br.portalsByMXID[mxid]
 	if !ok {
 	if !ok {
-		return bridge.loadDBPortal(bridge.DB.Portal.GetByMXID(mxid), nil)
+		return br.loadDBPortal(br.DB.Portal.GetByMXID(mxid), nil)
 	}
 	}
 	return portal
 	return portal
 }
 }
 
 
-func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal {
-	bridge.portalsLock.Lock()
-	defer bridge.portalsLock.Unlock()
-	portal, ok := bridge.portalsByJID[key]
+func (br *WABridge) GetIPortalByMXID(mxid id.RoomID) bridge.Portal {
+	return br.GetPortalByMXID(mxid)
+}
+
+func (br *WABridge) GetPortalByJID(key database.PortalKey) *Portal {
+	br.portalsLock.Lock()
+	defer br.portalsLock.Unlock()
+	portal, ok := br.portalsByJID[key]
 	if !ok {
 	if !ok {
-		return bridge.loadDBPortal(bridge.DB.Portal.GetByJID(key), &key)
+		return br.loadDBPortal(br.DB.Portal.GetByJID(key), &key)
 	}
 	}
 	return portal
 	return portal
 }
 }
 
 
-func (bridge *Bridge) GetAllPortals() []*Portal {
-	return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAll())
+func (br *WABridge) GetAllPortals() []*Portal {
+	return br.dbPortalsToPortals(br.DB.Portal.GetAll())
 }
 }
 
 
-func (bridge *Bridge) GetAllPortalsForUser(userID id.UserID) []*Portal {
-	return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllForUser(userID))
+func (br *WABridge) GetAllPortalsForUser(userID id.UserID) []*Portal {
+	return br.dbPortalsToPortals(br.DB.Portal.GetAllForUser(userID))
 }
 }
 
 
-func (bridge *Bridge) GetAllPortalsByJID(jid types.JID) []*Portal {
-	return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllByJID(jid))
+func (br *WABridge) GetAllPortalsByJID(jid types.JID) []*Portal {
+	return br.dbPortalsToPortals(br.DB.Portal.GetAllByJID(jid))
 }
 }
 
 
-func (bridge *Bridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal {
-	bridge.portalsLock.Lock()
-	defer bridge.portalsLock.Unlock()
+func (br *WABridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal {
+	br.portalsLock.Lock()
+	defer br.portalsLock.Unlock()
 	output := make([]*Portal, len(dbPortals))
 	output := make([]*Portal, len(dbPortals))
 	for index, dbPortal := range dbPortals {
 	for index, dbPortal := range dbPortals {
 		if dbPortal == nil {
 		if dbPortal == nil {
 			continue
 			continue
 		}
 		}
-		portal, ok := bridge.portalsByJID[dbPortal.Key]
+		portal, ok := br.portalsByJID[dbPortal.Key]
 		if !ok {
 		if !ok {
-			portal = bridge.loadDBPortal(dbPortal, nil)
+			portal = br.loadDBPortal(dbPortal, nil)
 		}
 		}
 		output[index] = portal
 		output[index] = portal
 	}
 	}
 	return output
 	return output
 }
 }
 
 
-func (bridge *Bridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal {
+func (br *WABridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal {
 	if dbPortal == nil {
 	if dbPortal == nil {
 		if key == nil {
 		if key == nil {
 			return nil
 			return nil
 		}
 		}
-		dbPortal = bridge.DB.Portal.New()
+		dbPortal = br.DB.Portal.New()
 		dbPortal.Key = *key
 		dbPortal.Key = *key
 		dbPortal.Insert()
 		dbPortal.Insert()
 	}
 	}
-	portal := bridge.NewPortal(dbPortal)
-	bridge.portalsByJID[portal.Key] = portal
+	portal := br.NewPortal(dbPortal)
+	br.portalsByJID[portal.Key] = portal
 	if len(portal.MXID) > 0 {
 	if len(portal.MXID) > 0 {
-		bridge.portalsByMXID[portal.MXID] = portal
+		br.portalsByMXID[portal.MXID] = portal
 	}
 	}
 	return portal
 	return portal
 }
 }
@@ -139,14 +144,14 @@ func (portal *Portal) GetUsers() []*User {
 	return nil
 	return nil
 }
 }
 
 
-func (bridge *Bridge) newBlankPortal(key database.PortalKey) *Portal {
+func (br *WABridge) newBlankPortal(key database.PortalKey) *Portal {
 	portal := &Portal{
 	portal := &Portal{
-		bridge: bridge,
-		log:    bridge.Log.Sub(fmt.Sprintf("Portal/%s", key)),
+		bridge: br,
+		log:    br.Log.Sub(fmt.Sprintf("Portal/%s", key)),
 
 
-		messages:       make(chan PortalMessage, bridge.Config.Bridge.PortalMessageBuffer),
-		matrixMessages: make(chan PortalMatrixMessage, bridge.Config.Bridge.PortalMessageBuffer),
-		mediaRetries:   make(chan PortalMediaRetry, bridge.Config.Bridge.PortalMessageBuffer),
+		messages:       make(chan PortalMessage, br.Config.Bridge.PortalMessageBuffer),
+		matrixMessages: make(chan PortalMatrixMessage, br.Config.Bridge.PortalMessageBuffer),
+		mediaRetries:   make(chan PortalMediaRetry, br.Config.Bridge.PortalMessageBuffer),
 
 
 		mediaErrorCache: make(map[types.MessageID]*FailedMediaMeta),
 		mediaErrorCache: make(map[types.MessageID]*FailedMediaMeta),
 	}
 	}
@@ -154,15 +159,15 @@ func (bridge *Bridge) newBlankPortal(key database.PortalKey) *Portal {
 	return portal
 	return portal
 }
 }
 
 
-func (bridge *Bridge) NewManualPortal(key database.PortalKey) *Portal {
-	portal := bridge.newBlankPortal(key)
-	portal.Portal = bridge.DB.Portal.New()
+func (br *WABridge) NewManualPortal(key database.PortalKey) *Portal {
+	portal := br.newBlankPortal(key)
+	portal.Portal = br.DB.Portal.New()
 	portal.Key = key
 	portal.Key = key
 	return portal
 	return portal
 }
 }
 
 
-func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
-	portal := bridge.newBlankPortal(dbPortal.Key)
+func (br *WABridge) NewPortal(dbPortal *database.Portal) *Portal {
+	portal := br.newBlankPortal(dbPortal.Key)
 	portal.Portal = dbPortal
 	portal.Portal = dbPortal
 	return portal
 	return portal
 }
 }
@@ -203,7 +208,7 @@ type recentlyHandledWrapper struct {
 type Portal struct {
 type Portal struct {
 	*database.Portal
 	*database.Portal
 
 
-	bridge *Bridge
+	bridge *WABridge
 	log    log.Logger
 	log    log.Logger
 
 
 	roomCreateLock sync.Mutex
 	roomCreateLock sync.Mutex
@@ -229,6 +234,10 @@ type Portal struct {
 	relayUser *User
 	relayUser *User
 }
 }
 
 
+func (portal *Portal) IsEncrypted() bool {
+	return portal.Encrypted
+}
+
 func (portal *Portal) handleMessageLoopItem(msg PortalMessage) {
 func (portal *Portal) handleMessageLoopItem(msg PortalMessage) {
 	if len(portal.MXID) == 0 {
 	if len(portal.MXID) == 0 {
 		if msg.fake == nil && msg.undecryptable == nil && (msg.evt == nil || !containsSupportedMessage(msg.evt.Message)) {
 		if msg.fake == nil && msg.undecryptable == nil && (msg.evt == nil || !containsSupportedMessage(msg.evt.Message)) {

+ 4 - 4
provisioning.go

@@ -43,15 +43,15 @@ import (
 )
 )
 
 
 type ProvisioningAPI struct {
 type ProvisioningAPI struct {
-	bridge *Bridge
+	bridge *WABridge
 	log    log.Logger
 	log    log.Logger
 }
 }
 
 
 func (prov *ProvisioningAPI) Init() {
 func (prov *ProvisioningAPI) Init() {
 	prov.log = prov.bridge.Log.Sub("Provisioning")
 	prov.log = prov.bridge.Log.Sub("Provisioning")
 
 
-	prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.AppService.Provisioning.Prefix)
-	r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.AppService.Provisioning.Prefix).Subrouter()
+	prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.Bridge.Provisioning.Prefix)
+	r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter()
 	r.Use(prov.AuthMiddleware)
 	r.Use(prov.AuthMiddleware)
 	r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet)
 	r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet)
 	r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet)
 	r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet)
@@ -109,7 +109,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
 		} else if strings.HasPrefix(auth, "Bearer ") {
 		} else if strings.HasPrefix(auth, "Bearer ") {
 			auth = auth[len("Bearer "):]
 			auth = auth[len("Bearer "):]
 		}
 		}
-		if auth != prov.bridge.Config.AppService.Provisioning.SharedSecret {
+		if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret {
 			jsonResponse(w, http.StatusForbidden, map[string]interface{}{
 			jsonResponse(w, http.StatusForbidden, map[string]interface{}{
 				"error":   "Invalid auth token",
 				"error":   "Invalid auth token",
 				"errcode": "M_FORBIDDEN",
 				"errcode": "M_FORBIDDEN",

+ 42 - 42
puppet.go

@@ -39,11 +39,11 @@ import (
 
 
 var userIDRegex *regexp.Regexp
 var userIDRegex *regexp.Regexp
 
 
-func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) {
+func (br *WABridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) {
 	if userIDRegex == nil {
 	if userIDRegex == nil {
 		userIDRegex = regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
 		userIDRegex = regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
-			bridge.Config.Bridge.FormatUsername("([0-9]+)"),
-			bridge.Config.Homeserver.Domain))
+			br.Config.Bridge.FormatUsername("([0-9]+)"),
+			br.Config.Homeserver.Domain))
 	}
 	}
 	match := userIDRegex.FindStringSubmatch(string(mxid))
 	match := userIDRegex.FindStringSubmatch(string(mxid))
 	if len(match) == 2 {
 	if len(match) == 2 {
@@ -53,79 +53,79 @@ func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) {
 	return
 	return
 }
 }
 
 
-func (bridge *Bridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
-	jid, ok := bridge.ParsePuppetMXID(mxid)
+func (br *WABridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
+	jid, ok := br.ParsePuppetMXID(mxid)
 	if !ok {
 	if !ok {
 		return nil
 		return nil
 	}
 	}
 
 
-	return bridge.GetPuppetByJID(jid)
+	return br.GetPuppetByJID(jid)
 }
 }
 
 
-func (bridge *Bridge) GetPuppetByJID(jid types.JID) *Puppet {
+func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
 	jid = jid.ToNonAD()
 	jid = jid.ToNonAD()
 	if jid.Server == types.LegacyUserServer {
 	if jid.Server == types.LegacyUserServer {
 		jid.Server = types.DefaultUserServer
 		jid.Server = types.DefaultUserServer
 	} else if jid.Server != types.DefaultUserServer {
 	} else if jid.Server != types.DefaultUserServer {
 		return nil
 		return nil
 	}
 	}
-	bridge.puppetsLock.Lock()
-	defer bridge.puppetsLock.Unlock()
-	puppet, ok := bridge.puppets[jid]
+	br.puppetsLock.Lock()
+	defer br.puppetsLock.Unlock()
+	puppet, ok := br.puppets[jid]
 	if !ok {
 	if !ok {
-		dbPuppet := bridge.DB.Puppet.Get(jid)
+		dbPuppet := br.DB.Puppet.Get(jid)
 		if dbPuppet == nil {
 		if dbPuppet == nil {
-			dbPuppet = bridge.DB.Puppet.New()
+			dbPuppet = br.DB.Puppet.New()
 			dbPuppet.JID = jid
 			dbPuppet.JID = jid
 			dbPuppet.Insert()
 			dbPuppet.Insert()
 		}
 		}
-		puppet = bridge.NewPuppet(dbPuppet)
-		bridge.puppets[puppet.JID] = puppet
+		puppet = br.NewPuppet(dbPuppet)
+		br.puppets[puppet.JID] = puppet
 		if len(puppet.CustomMXID) > 0 {
 		if len(puppet.CustomMXID) > 0 {
-			bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet
+			br.puppetsByCustomMXID[puppet.CustomMXID] = puppet
 		}
 		}
 	}
 	}
 	return puppet
 	return puppet
 }
 }
 
 
-func (bridge *Bridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
-	bridge.puppetsLock.Lock()
-	defer bridge.puppetsLock.Unlock()
-	puppet, ok := bridge.puppetsByCustomMXID[mxid]
+func (br *WABridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
+	br.puppetsLock.Lock()
+	defer br.puppetsLock.Unlock()
+	puppet, ok := br.puppetsByCustomMXID[mxid]
 	if !ok {
 	if !ok {
-		dbPuppet := bridge.DB.Puppet.GetByCustomMXID(mxid)
+		dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid)
 		if dbPuppet == nil {
 		if dbPuppet == nil {
 			return nil
 			return nil
 		}
 		}
-		puppet = bridge.NewPuppet(dbPuppet)
-		bridge.puppets[puppet.JID] = puppet
-		bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet
+		puppet = br.NewPuppet(dbPuppet)
+		br.puppets[puppet.JID] = puppet
+		br.puppetsByCustomMXID[puppet.CustomMXID] = puppet
 	}
 	}
 	return puppet
 	return puppet
 }
 }
 
 
-func (bridge *Bridge) GetAllPuppetsWithCustomMXID() []*Puppet {
-	return bridge.dbPuppetsToPuppets(bridge.DB.Puppet.GetAllWithCustomMXID())
+func (br *WABridge) GetAllPuppetsWithCustomMXID() []*Puppet {
+	return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID())
 }
 }
 
 
-func (bridge *Bridge) GetAllPuppets() []*Puppet {
-	return bridge.dbPuppetsToPuppets(bridge.DB.Puppet.GetAll())
+func (br *WABridge) GetAllPuppets() []*Puppet {
+	return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll())
 }
 }
 
 
-func (bridge *Bridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet {
-	bridge.puppetsLock.Lock()
-	defer bridge.puppetsLock.Unlock()
+func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet {
+	br.puppetsLock.Lock()
+	defer br.puppetsLock.Unlock()
 	output := make([]*Puppet, len(dbPuppets))
 	output := make([]*Puppet, len(dbPuppets))
 	for index, dbPuppet := range dbPuppets {
 	for index, dbPuppet := range dbPuppets {
 		if dbPuppet == nil {
 		if dbPuppet == nil {
 			continue
 			continue
 		}
 		}
-		puppet, ok := bridge.puppets[dbPuppet.JID]
+		puppet, ok := br.puppets[dbPuppet.JID]
 		if !ok {
 		if !ok {
-			puppet = bridge.NewPuppet(dbPuppet)
-			bridge.puppets[dbPuppet.JID] = puppet
+			puppet = br.NewPuppet(dbPuppet)
+			br.puppets[dbPuppet.JID] = puppet
 			if len(dbPuppet.CustomMXID) > 0 {
 			if len(dbPuppet.CustomMXID) > 0 {
-				bridge.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet
+				br.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet
 			}
 			}
 		}
 		}
 		output[index] = puppet
 		output[index] = puppet
@@ -133,26 +133,26 @@ func (bridge *Bridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet
 	return output
 	return output
 }
 }
 
 
-func (bridge *Bridge) FormatPuppetMXID(jid types.JID) id.UserID {
+func (br *WABridge) FormatPuppetMXID(jid types.JID) id.UserID {
 	return id.NewUserID(
 	return id.NewUserID(
-		bridge.Config.Bridge.FormatUsername(jid.User),
-		bridge.Config.Homeserver.Domain)
+		br.Config.Bridge.FormatUsername(jid.User),
+		br.Config.Homeserver.Domain)
 }
 }
 
 
-func (bridge *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
+func (br *WABridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
 	return &Puppet{
 	return &Puppet{
 		Puppet: dbPuppet,
 		Puppet: dbPuppet,
-		bridge: bridge,
-		log:    bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
+		bridge: br,
+		log:    br.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
 
 
-		MXID: bridge.FormatPuppetMXID(dbPuppet.JID),
+		MXID: br.FormatPuppetMXID(dbPuppet.JID),
 	}
 	}
 }
 }
 
 
 type Puppet struct {
 type Puppet struct {
 	*database.Puppet
 	*database.Puppet
 
 
-	bridge *Bridge
+	bridge *WABridge
 	log    log.Logger
 	log    log.Logger
 
 
 	typingIn id.RoomID
 	typingIn id.RoomID

+ 42 - 33
user.go

@@ -35,6 +35,7 @@ import (
 
 
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix/appservice"
 	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/bridge"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/format"
 	"maunium.net/go/mautrix/format"
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
@@ -56,7 +57,7 @@ type User struct {
 	Client  *whatsmeow.Client
 	Client  *whatsmeow.Client
 	Session *store.Device
 	Session *store.Device
 
 
-	bridge *Bridge
+	bridge *WABridge
 	log    log.Logger
 	log    log.Logger
 
 
 	Admin            bool
 	Admin            bool
@@ -84,38 +85,46 @@ type User struct {
 	BackfillQueue *BackfillQueue
 	BackfillQueue *BackfillQueue
 }
 }
 
 
-func (bridge *Bridge) getUserByMXID(userID id.UserID, onlyIfExists bool) *User {
-	_, isPuppet := bridge.ParsePuppetMXID(userID)
-	if isPuppet || userID == bridge.Bot.UserID {
+func (br *WABridge) getUserByMXID(userID id.UserID, onlyIfExists bool) *User {
+	_, isPuppet := br.ParsePuppetMXID(userID)
+	if isPuppet || userID == br.Bot.UserID {
 		return nil
 		return nil
 	}
 	}
-	bridge.usersLock.Lock()
-	defer bridge.usersLock.Unlock()
-	user, ok := bridge.usersByMXID[userID]
+	br.usersLock.Lock()
+	defer br.usersLock.Unlock()
+	user, ok := br.usersByMXID[userID]
 	if !ok {
 	if !ok {
 		userIDPtr := &userID
 		userIDPtr := &userID
 		if onlyIfExists {
 		if onlyIfExists {
 			userIDPtr = nil
 			userIDPtr = nil
 		}
 		}
-		return bridge.loadDBUser(bridge.DB.User.GetByMXID(userID), userIDPtr)
+		return br.loadDBUser(br.DB.User.GetByMXID(userID), userIDPtr)
 	}
 	}
 	return user
 	return user
 }
 }
 
 
-func (bridge *Bridge) GetUserByMXID(userID id.UserID) *User {
-	return bridge.getUserByMXID(userID, false)
+func (br *WABridge) GetUserByMXID(userID id.UserID) *User {
+	return br.getUserByMXID(userID, false)
 }
 }
 
 
-func (bridge *Bridge) GetUserByMXIDIfExists(userID id.UserID) *User {
-	return bridge.getUserByMXID(userID, true)
+func (br *WABridge) GetIUserByMXID(userID id.UserID) bridge.User {
+	return br.getUserByMXID(userID, false)
 }
 }
 
 
-func (bridge *Bridge) GetUserByJID(jid types.JID) *User {
-	bridge.usersLock.Lock()
-	defer bridge.usersLock.Unlock()
-	user, ok := bridge.usersByUsername[jid.User]
+func (user *User) IsAdmin() bool {
+	return user.Admin
+}
+
+func (br *WABridge) GetUserByMXIDIfExists(userID id.UserID) *User {
+	return br.getUserByMXID(userID, true)
+}
+
+func (br *WABridge) GetUserByJID(jid types.JID) *User {
+	br.usersLock.Lock()
+	defer br.usersLock.Unlock()
+	user, ok := br.usersByUsername[jid.User]
 	if !ok {
 	if !ok {
-		return bridge.loadDBUser(bridge.DB.User.GetByUsername(jid.User), nil)
+		return br.loadDBUser(br.DB.User.GetByUsername(jid.User), nil)
 	}
 	}
 	return user
 	return user
 }
 }
@@ -137,35 +146,35 @@ func (user *User) removeFromJIDMap(state BridgeState) {
 	user.sendBridgeState(state)
 	user.sendBridgeState(state)
 }
 }
 
 
-func (bridge *Bridge) GetAllUsers() []*User {
-	bridge.usersLock.Lock()
-	defer bridge.usersLock.Unlock()
-	dbUsers := bridge.DB.User.GetAll()
+func (br *WABridge) GetAllUsers() []*User {
+	br.usersLock.Lock()
+	defer br.usersLock.Unlock()
+	dbUsers := br.DB.User.GetAll()
 	output := make([]*User, len(dbUsers))
 	output := make([]*User, len(dbUsers))
 	for index, dbUser := range dbUsers {
 	for index, dbUser := range dbUsers {
-		user, ok := bridge.usersByMXID[dbUser.MXID]
+		user, ok := br.usersByMXID[dbUser.MXID]
 		if !ok {
 		if !ok {
-			user = bridge.loadDBUser(dbUser, nil)
+			user = br.loadDBUser(dbUser, nil)
 		}
 		}
 		output[index] = user
 		output[index] = user
 	}
 	}
 	return output
 	return output
 }
 }
 
 
-func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
+func (br *WABridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
 	if dbUser == nil {
 	if dbUser == nil {
 		if mxid == nil {
 		if mxid == nil {
 			return nil
 			return nil
 		}
 		}
-		dbUser = bridge.DB.User.New()
+		dbUser = br.DB.User.New()
 		dbUser.MXID = *mxid
 		dbUser.MXID = *mxid
 		dbUser.Insert()
 		dbUser.Insert()
 	}
 	}
-	user := bridge.NewUser(dbUser)
-	bridge.usersByMXID[user.MXID] = user
+	user := br.NewUser(dbUser)
+	br.usersByMXID[user.MXID] = user
 	if !user.JID.IsEmpty() {
 	if !user.JID.IsEmpty() {
 		var err error
 		var err error
-		user.Session, err = bridge.WAContainer.GetDevice(user.JID)
+		user.Session, err = br.WAContainer.GetDevice(user.JID)
 		if err != nil {
 		if err != nil {
 			user.log.Errorfln("Failed to load user's whatsapp session: %v", err)
 			user.log.Errorfln("Failed to load user's whatsapp session: %v", err)
 		} else if user.Session == nil {
 		} else if user.Session == nil {
@@ -174,20 +183,20 @@ func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
 			user.Update()
 			user.Update()
 		} else {
 		} else {
 			user.Session.Log = &waLogger{user.log.Sub("Session")}
 			user.Session.Log = &waLogger{user.log.Sub("Session")}
-			bridge.usersByUsername[user.JID.User] = user
+			br.usersByUsername[user.JID.User] = user
 		}
 		}
 	}
 	}
 	if len(user.ManagementRoom) > 0 {
 	if len(user.ManagementRoom) > 0 {
-		bridge.managementRooms[user.ManagementRoom] = user
+		br.managementRooms[user.ManagementRoom] = user
 	}
 	}
 	return user
 	return user
 }
 }
 
 
-func (bridge *Bridge) NewUser(dbUser *database.User) *User {
+func (br *WABridge) NewUser(dbUser *database.User) *User {
 	user := &User{
 	user := &User{
 		User:   dbUser,
 		User:   dbUser,
-		bridge: bridge,
-		log:    bridge.Log.Sub("User").Sub(string(dbUser.MXID)),
+		bridge: br,
+		log:    br.Log.Sub("User").Sub(string(dbUser.MXID)),
 
 
 		historySyncs: make(chan *events.HistorySync, 32),
 		historySyncs: make(chan *events.HistorySync, 32),
 		lastPresence: types.PresenceUnavailable,
 		lastPresence: types.PresenceUnavailable,