Эх сурвалжийг харах

Add more options for guild message handling

Tulir Asokan 2 жил өмнө
parent
commit
4676ec98c4

+ 37 - 7
commands.go

@@ -34,6 +34,7 @@ import (
 	"maunium.net/go/mautrix/event"
 	"maunium.net/go/mautrix/id"
 
+	"go.mau.fi/mautrix-discord/database"
 	"go.mau.fi/mautrix-discord/remoteauth"
 )
 
@@ -305,18 +306,19 @@ var cmdGuilds = &commands.FullHandler{
 	Help: commands.HelpMeta{
 		Section:     commands.HelpSectionUnclassified,
 		Description: "Guild bridging management",
-		Args:        "<status/bridge/unbridge> [_guild ID_] [--entire]",
+		Args:        "<status/bridge/unbridge/bridging-mode> [_guild ID_] [...]",
 	},
 	RequiresLogin: true,
 }
 
-const smallGuildsHelp = "**Usage**: `$cmdprefix guilds <help/status/bridge/unbridge> [guild ID] [--entire]`"
+const smallGuildsHelp = "**Usage**: `$cmdprefix guilds <help/status/bridge/unbridge> [guild ID] [...]`"
 
 const fullGuildsHelp = smallGuildsHelp + `
 
 * **help** - View this help message.
 * **status** - View the list of guilds and their bridging status.
 * **bridge <_guild ID_> [--entire]** - Enable bridging for a guild. The --entire flag auto-creates portals for all channels.
+* **bridging-mode <_guild ID_> <_mode_>** - Set the mode for bridging messages and new channels in a guild.
 * **unbridge <_guild ID_>** - Unbridge a guild and delete all channel portal rooms.`
 
 func fnGuilds(ce *WrappedCommandEvent) {
@@ -333,6 +335,8 @@ func fnGuilds(ce *WrappedCommandEvent) {
 		fnBridgeGuild(ce)
 	case "unbridge", "delete":
 		fnUnbridgeGuild(ce)
+	case "bridging-mode", "mode":
+		fnGuildBridgingMode(ce)
 	case "help":
 		ce.Reply(fullGuildsHelp)
 	default:
@@ -347,15 +351,11 @@ func fnListGuilds(ce *WrappedCommandEvent) {
 		if guild == nil {
 			continue
 		}
-		status := "not bridged"
-		if guild.MXID != "" {
-			status = "bridged"
-		}
 		var avatarHTML string
 		if !guild.AvatarURL.IsEmpty() {
 			avatarHTML = fmt.Sprintf(`<img data-mx-emoticon height="24" src="%s" alt="" title="Guild avatar"> `, guild.AvatarURL.String())
 		}
-		items = append(items, fmt.Sprintf("<li>%s%s (<code>%s</code>) - %s</li>", avatarHTML, html.EscapeString(guild.Name), guild.ID, status))
+		items = append(items, fmt.Sprintf("<li>%s%s (<code>%s</code>) - %s</li>", avatarHTML, html.EscapeString(guild.Name), guild.ID, guild.BridgingMode.Description()))
 	}
 	if len(items) == 0 {
 		ce.Reply("No guilds found")
@@ -384,6 +384,36 @@ func fnUnbridgeGuild(ce *WrappedCommandEvent) {
 	}
 }
 
+const availableModes = "Available modes:\n" +
+	"* `nothing` to never bridge any messages (default when unbridged)\n" +
+	"* `if-portal-exists` to bridge messages in existing portals, but drop messages in unbridged channels\n" +
+	"* `create-on-message` to bridge all messages and create portals if necessary on incoming messages (default after bridging)\n" +
+	"* `everything` to bridge all messages and create portals proactively on bridge startup (default if bridged with `--entire`)\n"
+
+func fnGuildBridgingMode(ce *WrappedCommandEvent) {
+	if len(ce.Args) == 0 || len(ce.Args) > 2 {
+		ce.Reply("**Usage**: `$cmdprefix guilds bridging-mode <guild ID> [mode]`\n\n" + availableModes)
+		return
+	}
+	guild := ce.Bridge.GetGuildByID(ce.Args[0], false)
+	if guild == nil {
+		ce.Reply("Guild not found")
+		return
+	}
+	if len(ce.Args) == 1 {
+		ce.Reply("%s (%s) is currently set to %s (`%s`)\n\n%s", guild.PlainName, guild.ID, guild.BridgingMode.Description(), guild.BridgingMode.String(), availableModes)
+		return
+	}
+	mode := database.ParseGuildBridgingMode(ce.Args[1])
+	if mode == database.GuildBridgeInvalid {
+		ce.Reply("Invalid guild bridging mode `%s`", ce.Args[1])
+		return
+	}
+	guild.BridgingMode = mode
+	guild.Update()
+	ce.Reply("Set guild bridging mode to %s", mode.Description())
+}
+
 var cmdDeleteAllPortals = &commands.FullHandler{
 	Func: wrapCommand(fnDeleteAllPortals),
 	Name: "delete-all-portals",

+ 75 - 7
database/guild.go

@@ -3,6 +3,8 @@ package database
 import (
 	"database/sql"
 	"errors"
+	"fmt"
+	"strings"
 
 	log "maunium.net/go/maulogger/v2"
 	"maunium.net/go/mautrix/id"
@@ -10,13 +12,76 @@ import (
 	"maunium.net/go/mautrix/util/dbutil"
 )
 
+type GuildBridgingMode int
+
+const (
+	// GuildBridgeNothing tells the bridge to never bridge messages, not even checking if a portal exists.
+	GuildBridgeNothing GuildBridgingMode = iota
+	// GuildBridgeIfPortalExists tells the bridge to bridge messages in channels that already have portals.
+	GuildBridgeIfPortalExists
+	// GuildBridgeCreateOnMessage tells the bridge to create portals as soon as a message is received.
+	GuildBridgeCreateOnMessage
+	// GuildBridgeEverything tells the bridge to proactively create portals on startup and when receiving channel create notifications.
+	GuildBridgeEverything
+
+	GuildBridgeInvalid GuildBridgingMode = -1
+)
+
+func ParseGuildBridgingMode(str string) GuildBridgingMode {
+	str = strings.ToLower(str)
+	str = strings.ReplaceAll(str, "-", "")
+	str = strings.ReplaceAll(str, "_", "")
+	switch str {
+	case "nothing", "0":
+		return GuildBridgeNothing
+	case "ifportalexists", "1":
+		return GuildBridgeIfPortalExists
+	case "createonmessage", "2":
+		return GuildBridgeCreateOnMessage
+	case "everything", "3":
+		return GuildBridgeEverything
+	default:
+		return GuildBridgeInvalid
+	}
+}
+
+func (gbm GuildBridgingMode) String() string {
+	switch gbm {
+	case GuildBridgeNothing:
+		return "nothing"
+	case GuildBridgeIfPortalExists:
+		return "if-portal-exists"
+	case GuildBridgeCreateOnMessage:
+		return "create-on-message"
+	case GuildBridgeEverything:
+		return "everything"
+	default:
+		return ""
+	}
+}
+
+func (gbm GuildBridgingMode) Description() string {
+	switch gbm {
+	case GuildBridgeNothing:
+		return "never bridge messages"
+	case GuildBridgeIfPortalExists:
+		return "bridge messages in existing portals"
+	case GuildBridgeCreateOnMessage:
+		return "bridge all messages and create portals on first message"
+	case GuildBridgeEverything:
+		return "bridge all messages and create portals proactively"
+	default:
+		return ""
+	}
+}
+
 type GuildQuery struct {
 	db  *Database
 	log log.Logger
 }
 
 const (
-	guildSelect = "SELECT dcid, mxid, plain_name, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels FROM guild"
+	guildSelect = "SELECT dcid, mxid, plain_name, name, name_set, avatar, avatar_url, avatar_set, bridging_mode FROM guild"
 )
 
 func (gq *GuildQuery) New() *Guild {
@@ -67,13 +132,13 @@ type Guild struct {
 	AvatarURL id.ContentURI
 	AvatarSet bool
 
-	AutoBridgeChannels bool
+	BridgingMode GuildBridgingMode
 }
 
 func (g *Guild) Scan(row dbutil.Scannable) *Guild {
 	var mxid sql.NullString
 	var avatarURL string
-	err := row.Scan(&g.ID, &mxid, &g.PlainName, &g.Name, &g.NameSet, &g.Avatar, &avatarURL, &g.AvatarSet, &g.AutoBridgeChannels)
+	err := row.Scan(&g.ID, &mxid, &g.PlainName, &g.Name, &g.NameSet, &g.Avatar, &avatarURL, &g.AvatarSet, &g.BridgingMode)
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 			g.log.Errorln("Database scan failed:", err)
@@ -82,6 +147,9 @@ func (g *Guild) Scan(row dbutil.Scannable) *Guild {
 
 		return nil
 	}
+	if g.BridgingMode < GuildBridgeNothing || g.BridgingMode > GuildBridgeEverything {
+		panic(fmt.Errorf("invalid guild bridging mode %d in guild %s", g.BridgingMode, g.ID))
+	}
 	g.MXID = id.RoomID(mxid.String)
 	g.AvatarURL, _ = id.ParseContentURI(avatarURL)
 	return g
@@ -96,10 +164,10 @@ func (g *Guild) mxidPtr() *id.RoomID {
 
 func (g *Guild) Insert() {
 	query := `
-		INSERT INTO guild (dcid, mxid, plain_name, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels)
+		INSERT INTO guild (dcid, mxid, plain_name, name, name_set, avatar, avatar_url, avatar_set, bridging_mode)
 		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
 	`
-	_, err := g.db.Exec(query, g.ID, g.mxidPtr(), g.PlainName, g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels)
+	_, err := g.db.Exec(query, g.ID, g.mxidPtr(), g.PlainName, g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.BridgingMode)
 	if err != nil {
 		g.log.Warnfln("Failed to insert %s: %v", g.ID, err)
 		panic(err)
@@ -108,10 +176,10 @@ func (g *Guild) Insert() {
 
 func (g *Guild) Update() {
 	query := `
-		UPDATE guild SET mxid=$1, plain_name=$2, name=$3, name_set=$4, avatar=$5, avatar_url=$6, avatar_set=$7, auto_bridge_channels=$8
+		UPDATE guild SET mxid=$1, plain_name=$2, name=$3, name_set=$4, avatar=$5, avatar_url=$6, avatar_set=$7, bridging_mode=$8
 		WHERE dcid=$9
 	`
-	_, err := g.db.Exec(query, g.mxidPtr(), g.PlainName, g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels, g.ID)
+	_, err := g.db.Exec(query, g.mxidPtr(), g.PlainName, g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.BridgingMode, g.ID)
 	if err != nil {
 		g.log.Warnfln("Failed to update %s: %v", g.ID, err)
 		panic(err)

+ 2 - 2
database/upgrades/00-latest-revision.sql

@@ -1,4 +1,4 @@
--- v0 -> v13: Latest revision
+-- v0 -> v14: Latest revision
 
 CREATE TABLE guild (
     dcid       TEXT PRIMARY KEY,
@@ -10,7 +10,7 @@ CREATE TABLE guild (
     avatar_url TEXT NOT NULL,
     avatar_set BOOLEAN NOT NULL,
 
-    auto_bridge_channels BOOLEAN NOT NULL
+    bridging_mode INTEGER NOT NULL
 );
 
 CREATE TABLE portal (

+ 7 - 0
database/upgrades/14-guild-bridging-mode.sql

@@ -0,0 +1,7 @@
+-- v14: Add more modes of bridging guilds
+ALTER TABLE guild ADD COLUMN bridging_mode INTEGER NOT NULL DEFAULT 0;
+UPDATE guild SET bridging_mode=2 WHERE mxid<>'';
+UPDATE guild SET bridging_mode=3 WHERE auto_bridge_channels=true;
+ALTER TABLE guild DROP COLUMN auto_bridge_channels;
+-- only: postgres
+ALTER TABLE guild ALTER COLUMN bridging_mode DROP DEFAULT;

+ 1 - 1
guildportal.go

@@ -312,6 +312,6 @@ func (guild *Guild) RemoveMXID() {
 	guild.MXID = ""
 	guild.AvatarSet = false
 	guild.NameSet = false
-	guild.AutoBridgeChannels = false
+	guild.BridgingMode = database.GuildBridgeNothing
 	guild.Update()
 }

+ 14 - 11
provisioning.go

@@ -18,6 +18,7 @@ import (
 	"maunium.net/go/mautrix/bridge/bridgeconfig"
 	"maunium.net/go/mautrix/id"
 
+	"go.mau.fi/mautrix-discord/database"
 	"go.mau.fi/mautrix-discord/remoteauth"
 )
 
@@ -429,11 +430,12 @@ func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) {
 }
 
 type guildEntry struct {
-	ID         string        `json:"id"`
-	Name       string        `json:"name"`
-	AvatarURL  id.ContentURI `json:"avatar_url"`
-	MXID       id.RoomID     `json:"mxid"`
-	AutoBridge bool          `json:"auto_bridge_channels"`
+	ID           string        `json:"id"`
+	Name         string        `json:"name"`
+	AvatarURL    id.ContentURI `json:"avatar_url"`
+	MXID         id.RoomID     `json:"mxid"`
+	AutoBridge   bool          `json:"auto_bridge_channels"`
+	BridgingMode string        `json:"bridging_mode"`
 }
 
 type respGuildsList struct {
@@ -451,11 +453,12 @@ func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request) {
 			continue
 		}
 		resp.Guilds = append(resp.Guilds, guildEntry{
-			ID:         guild.ID,
-			Name:       guild.PlainName,
-			AvatarURL:  guild.AvatarURL,
-			MXID:       guild.MXID,
-			AutoBridge: guild.AutoBridgeChannels,
+			ID:           guild.ID,
+			Name:         guild.PlainName,
+			AvatarURL:    guild.AvatarURL,
+			MXID:         guild.MXID,
+			AutoBridge:   guild.BridgingMode == database.GuildBridgeEverything,
+			BridgingMode: guild.BridgingMode.String(),
 		})
 	}
 
@@ -526,7 +529,7 @@ func (p *ProvisioningAPI) guildsUnbridge(w http.ResponseWriter, r *http.Request)
 			Error:   "Guild not found",
 			ErrCode: mautrix.MNotFound.ErrCode,
 		})
-	} else if !guild.AutoBridgeChannels && guild.MXID == "" {
+	} else if guild.BridgingMode == database.GuildBridgeNothing && guild.MXID == "" {
 		jsonResponse(w, http.StatusNotFound, Error{
 			Error:   "That guild is not bridged",
 			ErrCode: ErrCodeGuildNotBridged,

+ 16 - 11
user.go

@@ -567,12 +567,15 @@ func (user *User) Disconnect() error {
 	return nil
 }
 
-func (user *User) bridgeMessage(guildID string) bool {
+func (user *User) getGuildBridgingMode(guildID string) database.GuildBridgingMode {
 	if guildID == "" {
-		return true
+		return database.GuildBridgeEverything
 	}
 	guild := user.bridge.GetGuildByID(guildID, false)
-	return guild != nil && guild.MXID != ""
+	if guild == nil {
+		return database.GuildBridgeNothing
+	}
+	return guild.BridgingMode
 }
 
 func (user *User) readyHandler(_ *discordgo.Session, r *discordgo.Ready) {
@@ -769,7 +772,7 @@ func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSp
 	if len(meta.Channels) > 0 {
 		for _, ch := range meta.Channels {
 			portal := user.GetPortalByMeta(ch)
-			if guild.AutoBridgeChannels && portal.MXID == "" && user.channelIsBridgeable(ch) {
+			if guild.BridgingMode >= database.GuildBridgeEverything && portal.MXID == "" && user.channelIsBridgeable(ch) {
 				err := portal.CreateMatrixRoom(user, ch)
 				if err != nil {
 					user.log.Errorfln("Failed to create portal for guild channel %s/%s in initial sync: %v", guild.ID, ch.ID, err)
@@ -843,7 +846,7 @@ func (user *User) guildUpdateHandler(_ *discordgo.Session, g *discordgo.GuildUpd
 }
 
 func (user *User) channelCreateHandler(_ *discordgo.Session, c *discordgo.ChannelCreate) {
-	if !user.bridgeMessage(c.GuildID) {
+	if user.getGuildBridgingMode(c.GuildID) < database.GuildBridgeEverything {
 		user.log.Debugfln("Ignoring channel create event in unbridged guild %s/%s", c.GuildID, c.ID)
 		return
 	}
@@ -893,7 +896,8 @@ func (user *User) channelUpdateHandler(_ *discordgo.Session, c *discordgo.Channe
 }
 
 func (user *User) pushPortalMessage(msg interface{}, typeName, channelID, guildID string) {
-	if !user.bridgeMessage(guildID) {
+	if user.getGuildBridgingMode(guildID) <= database.GuildBridgeNothing {
+		// If guild bridging mode is nothing, don't even check if the portal exists
 		return
 	}
 
@@ -907,8 +911,7 @@ func (user *User) pushPortalMessage(msg interface{}, typeName, channelID, guildI
 		}
 		portal = thread.Parent
 	}
-	// Double check because some messages don't have the guild ID specified.
-	if !user.bridgeMessage(portal.GuildID) {
+	if mode := user.getGuildBridgingMode(portal.GuildID); mode <= database.GuildBridgeNothing || (portal.MXID == "" && mode <= database.GuildBridgeIfPortalExists) {
 		return
 	}
 
@@ -1150,7 +1153,9 @@ func (user *User) bridgeGuild(guildID string, everything bool) error {
 			}
 		}
 	}
-	guild.AutoBridgeChannels = everything
+	if everything {
+		guild.BridgingMode = database.GuildBridgeEverything
+	}
 	guild.Update()
 
 	user.log.Debugfln("Subscribing to guild %s after bridging", guild.ID)
@@ -1177,10 +1182,10 @@ func (user *User) unbridgeGuild(guildID string) error {
 	}
 	guild.roomCreateLock.Lock()
 	defer guild.roomCreateLock.Unlock()
-	if !guild.AutoBridgeChannels && guild.MXID == "" {
+	if guild.BridgingMode == database.GuildBridgeNothing && guild.MXID == "" {
 		return errors.New("that guild is not bridged")
 	}
-	guild.AutoBridgeChannels = false
+	guild.BridgingMode = database.GuildBridgeNothing
 	guild.Update()
 	for _, portal := range user.bridge.GetAllPortalsInGuild(guild.ID) {
 		portal.cleanup(false)