Selaa lähdekoodia

Add support for running Discord bot commands. Fixes #35

Tulir Asokan 2 vuotta sitten
vanhempi
sitoutus
98f45991a4
6 muutettua tiedostoa jossa 352 lisäystä ja 6 poistoa
  1. 2 0
      commands.go
  2. 316 0
      commands_botinteraction.go
  3. 3 2
      go.mod
  4. 6 4
      go.sum
  5. 5 0
      portal.go
  6. 20 0
      user.go

+ 2 - 0
commands.go

@@ -54,6 +54,8 @@ func (br *DiscordBridge) RegisterCommands() {
 		cmdGuilds,
 		cmdRejoinSpace,
 		cmdDeleteAllPortals,
+		cmdExec,
+		cmdCommands,
 	)
 }
 

+ 316 - 0
commands_botinteraction.go

@@ -0,0 +1,316 @@
+// mautrix-discord - A Matrix-Discord puppeting bridge.
+// Copyright (C) 2023 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+package main
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/bwmarrin/discordgo"
+	"github.com/google/shlex"
+
+	"maunium.net/go/mautrix/bridge/commands"
+)
+
+var cmdCommands = &commands.FullHandler{
+	Func:    wrapCommand(fnCommands),
+	Name:    "commands",
+	Aliases: []string{"cmds", "cs"},
+	Help: commands.HelpMeta{
+		Section:     commands.HelpSectionUnclassified,
+		Description: "View parameters of bot interaction commands on Discord",
+		Args:        "search <_query_> OR help <_command_>",
+	},
+	RequiresPortal: true,
+	RequiresLogin:  true,
+}
+
+var cmdExec = &commands.FullHandler{
+	Func:    wrapCommand(fnExec),
+	Name:    "exec",
+	Aliases: []string{"command", "cmd", "c", "exec", "e"},
+	Help: commands.HelpMeta{
+		Section:     commands.HelpSectionUnclassified,
+		Description: "Run bot interaction commands on Discord",
+		Args:        "<_command_> [_arg=value ..._]",
+	},
+	RequiresLogin:  true,
+	RequiresPortal: true,
+}
+
+func (portal *Portal) getCommand(user *User, command string) (*discordgo.ApplicationCommand, error) {
+	portal.commandsLock.Lock()
+	defer portal.commandsLock.Unlock()
+	cmd, ok := portal.commands[command]
+	if !ok {
+		results, err := user.Session.ApplicationCommandsSearch(portal.Key.ChannelID, command)
+		if err != nil {
+			return nil, err
+		}
+		for _, result := range results {
+			if result.Name == command {
+				portal.commands[result.Name] = result
+				cmd = result
+				break
+			}
+		}
+		if cmd == nil {
+			return nil, nil
+		}
+	}
+	return cmd, nil
+}
+
+func getCommandOptionTypeName(optType discordgo.ApplicationCommandOptionType) string {
+	switch optType {
+	case discordgo.ApplicationCommandOptionSubCommand:
+		return "subcommand"
+	case discordgo.ApplicationCommandOptionSubCommandGroup:
+		return "subcommand group (unsupported)"
+	case discordgo.ApplicationCommandOptionString:
+		return "string"
+	case discordgo.ApplicationCommandOptionInteger:
+		return "integer"
+	case discordgo.ApplicationCommandOptionBoolean:
+		return "boolean"
+	case discordgo.ApplicationCommandOptionUser:
+		return "user (unsupported)"
+	case discordgo.ApplicationCommandOptionChannel:
+		return "channel (unsupported)"
+	case discordgo.ApplicationCommandOptionRole:
+		return "role (unsupported)"
+	case discordgo.ApplicationCommandOptionMentionable:
+		return "mentionable (unsupported)"
+	case discordgo.ApplicationCommandOptionNumber:
+		return "number"
+	case discordgo.ApplicationCommandOptionAttachment:
+		return "attachment (unsupported)"
+	default:
+		return fmt.Sprintf("unknown type %d", optType)
+	}
+}
+
+func parseCommandOptionValue(optType discordgo.ApplicationCommandOptionType, value string) (any, error) {
+	switch optType {
+	case discordgo.ApplicationCommandOptionSubCommandGroup:
+		return nil, fmt.Errorf("subcommand groups aren't supported")
+	case discordgo.ApplicationCommandOptionString:
+		return value, nil
+	case discordgo.ApplicationCommandOptionInteger:
+		return strconv.ParseInt(value, 10, 64)
+	case discordgo.ApplicationCommandOptionBoolean:
+		return strconv.ParseBool(value)
+	case discordgo.ApplicationCommandOptionUser:
+		return nil, fmt.Errorf("user options aren't supported")
+	case discordgo.ApplicationCommandOptionChannel:
+		return nil, fmt.Errorf("channel options aren't supported")
+	case discordgo.ApplicationCommandOptionRole:
+		return nil, fmt.Errorf("role options aren't supported")
+	case discordgo.ApplicationCommandOptionMentionable:
+		return nil, fmt.Errorf("mentionable options aren't supported")
+	case discordgo.ApplicationCommandOptionNumber:
+		return strconv.ParseFloat(value, 64)
+	case discordgo.ApplicationCommandOptionAttachment:
+		return nil, fmt.Errorf("attachment options aren't supported")
+	default:
+		return nil, fmt.Errorf("unknown option type %d", optType)
+	}
+}
+
+func indent(text, with string) string {
+	split := strings.Split(text, "\n")
+	for i, part := range split {
+		split[i] = with + part
+	}
+	return strings.Join(split, "\n")
+}
+
+func formatOption(opt *discordgo.ApplicationCommandOption) string {
+	argText := fmt.Sprintf("* `%s`: %s", opt.Name, getCommandOptionTypeName(opt.Type))
+	if strings.ToLower(opt.Description) != opt.Name {
+		argText += fmt.Sprintf(" - %s", opt.Description)
+	}
+	if opt.Required {
+		argText += " (required)"
+	}
+	if len(opt.Options) > 0 {
+		subopts := make([]string, len(opt.Options))
+		for i, subopt := range opt.Options {
+			subopts[i] = indent(formatOption(subopt), "  ")
+		}
+		argText += "\n" + strings.Join(subopts, "\n")
+	}
+	return argText
+}
+
+func formatCommand(cmd *discordgo.ApplicationCommand) string {
+	baseText := fmt.Sprintf("$cmdprefix exec %s", cmd.Name)
+	if len(cmd.Options) > 0 {
+		args := make([]string, len(cmd.Options))
+		argPlaceholder := "[arg=value ...]"
+		for i, opt := range cmd.Options {
+			args[i] = formatOption(opt)
+			if opt.Required {
+				argPlaceholder = "<arg=value ...>"
+			}
+		}
+		baseText = fmt.Sprintf("`%s %s` - %s\n%s", baseText, argPlaceholder, cmd.Description, strings.Join(args, "\n"))
+	} else {
+		baseText = fmt.Sprintf("`%s` - %s", baseText, cmd.Description)
+	}
+	return baseText
+}
+
+func parseCommandOptions(opts []*discordgo.ApplicationCommandOption, subcommands []string, namedArgs map[string]string) (res []*discordgo.ApplicationCommandOptionInput, err error) {
+	subcommandDone := false
+	for _, opt := range opts {
+		optRes := &discordgo.ApplicationCommandOptionInput{
+			Type: opt.Type,
+			Name: opt.Name,
+		}
+		if opt.Type == discordgo.ApplicationCommandOptionSubCommand {
+			if !subcommandDone && len(subcommands) > 0 && subcommands[0] == opt.Name {
+				subcommandDone = true
+				optRes.Options, err = parseCommandOptions(opt.Options, subcommands[1:], namedArgs)
+				if err != nil {
+					err = fmt.Errorf("error parsing subcommand %s: %v", opt.Name, err)
+					break
+				}
+				subcommands = subcommands[1:]
+			} else {
+				continue
+			}
+		} else if argVal, ok := namedArgs[opt.Name]; ok {
+			optRes.Value, err = parseCommandOptionValue(opt.Type, argVal)
+			if err != nil {
+				err = fmt.Errorf("error parsing parameter %s: %v", opt.Name, err)
+				break
+			}
+		} else if opt.Required {
+			switch opt.Type {
+			case discordgo.ApplicationCommandOptionSubCommandGroup, discordgo.ApplicationCommandOptionUser,
+				discordgo.ApplicationCommandOptionChannel, discordgo.ApplicationCommandOptionRole,
+				discordgo.ApplicationCommandOptionMentionable, discordgo.ApplicationCommandOptionAttachment:
+				err = fmt.Errorf("missing required parameter %s (which is not supported by the bridge)", opt.Name)
+			default:
+				err = fmt.Errorf("missing required parameter %s", opt.Name)
+			}
+			break
+		} else {
+			continue
+		}
+		res = append(res, optRes)
+	}
+	if len(subcommands) > 0 {
+		err = fmt.Errorf("unparsed subcommands left over (did you forget quoting for parameters with spaces?)")
+	}
+	return
+}
+
+func executeCommand(cmd *discordgo.ApplicationCommand, args []string) (res []*discordgo.ApplicationCommandOptionInput, err error) {
+	namedArgs := map[string]string{}
+	n := 0
+	for _, arg := range args {
+		name, value, isNamed := strings.Cut(arg, "=")
+		if isNamed {
+			namedArgs[name] = value
+		} else {
+			args[n] = arg
+			n++
+		}
+	}
+	return parseCommandOptions(cmd.Options, args[:n], namedArgs)
+}
+
+func fnCommands(ce *WrappedCommandEvent) {
+	if len(ce.Args) < 2 {
+		ce.Reply("**Usage**: `$cmdprefix commands search <_query_>` OR `$cmdprefix commands help <_command_>`")
+		return
+	}
+	subcmd := strings.ToLower(ce.Args[0])
+	if subcmd == "search" {
+		results, err := ce.User.Session.ApplicationCommandsSearch(ce.Portal.Key.ChannelID, ce.Args[1])
+		if err != nil {
+			ce.Reply("Error searching for commands: %v", err)
+			return
+		}
+		formatted := make([]string, len(results))
+		ce.Portal.commandsLock.Lock()
+		for i, result := range results {
+			ce.Portal.commands[result.Name] = result
+			formatted[i] = indent(formatCommand(result), "  ")
+			formatted[i] = "*" + formatted[i][1:]
+		}
+		ce.Portal.commandsLock.Unlock()
+		ce.Reply("Found results:\n" + strings.Join(formatted, "\n"))
+	} else if subcmd == "help" {
+		command := strings.ToLower(ce.Args[1])
+		cmd, err := ce.Portal.getCommand(ce.User, command)
+		if err != nil {
+			ce.Reply("Error searching for commands: %v", err)
+		} else if cmd == nil {
+			ce.Reply("Command %q not found", command)
+		} else {
+			ce.Reply(formatCommand(cmd))
+		}
+	}
+}
+
+func fnExec(ce *WrappedCommandEvent) {
+	if len(ce.Args) == 0 {
+		ce.Reply("**Usage**: `$cmdprefix exec <command> [arg=value ...]`")
+		return
+	}
+	args, err := shlex.Split(ce.RawArgs)
+	if err != nil {
+		ce.Reply("Error parsing args with shlex: %v", err)
+		return
+	}
+	command := strings.ToLower(args[0])
+	cmd, err := ce.Portal.getCommand(ce.User, command)
+	if err != nil {
+		ce.Reply("Error searching for commands: %v", err)
+	} else if cmd == nil {
+		ce.Reply("Command %q not found", command)
+	} else if options, err := executeCommand(cmd, args[1:]); err != nil {
+		ce.Reply("Error parsing arguments: %v\n\n**Usage:** "+formatCommand(cmd), err)
+	} else {
+		nonce := generateNonce()
+		ce.User.pendingInteractionsLock.Lock()
+		ce.User.pendingInteractions[nonce] = ce
+		ce.User.pendingInteractionsLock.Unlock()
+		err = ce.User.Session.SendInteractions(ce.Portal.GuildID, ce.Portal.Key.ChannelID, cmd, options, nonce)
+		if err != nil {
+			ce.Reply("Error sending interaction: %v", err)
+			ce.User.pendingInteractionsLock.Lock()
+			delete(ce.User.pendingInteractions, nonce)
+			ce.User.pendingInteractionsLock.Unlock()
+		} else {
+			go func() {
+				time.Sleep(10 * time.Second)
+				ce.User.pendingInteractionsLock.Lock()
+				if _, stillWaiting := ce.User.pendingInteractions[nonce]; stillWaiting {
+					delete(ce.User.pendingInteractions, nonce)
+					ce.Reply("Timed out waiting for interaction success")
+				}
+				ce.User.pendingInteractionsLock.Unlock()
+			}()
+		}
+	}
+}

+ 3 - 2
go.mod

@@ -5,6 +5,7 @@ go 1.18
 require (
 	github.com/bwmarrin/discordgo v0.26.1
 	github.com/gabriel-vasile/mimetype v1.4.1
+	github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
 	github.com/gorilla/mux v1.8.0
 	github.com/gorilla/websocket v1.5.0
 	github.com/lib/pq v1.10.7
@@ -13,7 +14,7 @@ require (
 	github.com/stretchr/testify v1.8.1
 	github.com/yuin/goldmark v1.5.3
 	maunium.net/go/maulogger/v2 v2.3.2
-	maunium.net/go/mautrix v0.13.1-0.20230128124647-7d98a9f8e3a6
+	maunium.net/go/mautrix v0.13.1-0.20230129104640-4a2a7653e437
 )
 
 require (
@@ -33,4 +34,4 @@ require (
 	maunium.net/go/mauflag v1.0.0 // indirect
 )
 
-replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230128134018-766d08cb045e
+replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230129113304-139f45f429a0

+ 6 - 4
go.sum

@@ -1,6 +1,6 @@
 github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
-github.com/beeper/discordgo v0.0.0-20230128134018-766d08cb045e h1:R0Db6p3gANvV2Hk8lbSSlPDNG3zzeOM8nyZHmLl3tkI=
-github.com/beeper/discordgo v0.0.0-20230128134018-766d08cb045e/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
+github.com/beeper/discordgo v0.0.0-20230129113304-139f45f429a0 h1:RrF9ffkMyEsUtZqWR/m/KXSrYbpyT7bkuL+KY8pexSE=
+github.com/beeper/discordgo v0.0.0-20230129113304-139f45f429a0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
 github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -8,6 +8,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
 github.com/gabriel-vasile/mimetype v1.4.1 h1:TRWk7se+TOjCYgRth7+1/OYLNiRNIotknkFtf/dnN7Q=
 github.com/gabriel-vasile/mimetype v1.4.1/go.mod h1:05Vi0w3Y9c/lNvJOdmIwvrrAhX3rYhfQQCaf9VJcv7M=
 github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
 github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
 github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
 github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
@@ -75,5 +77,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
 maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
 maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
 maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
-maunium.net/go/mautrix v0.13.1-0.20230128124647-7d98a9f8e3a6 h1:c2HKxT3wYxWS213BXaWDY3UkHGfXGmhzOv4h1OKZm20=
-maunium.net/go/mautrix v0.13.1-0.20230128124647-7d98a9f8e3a6/go.mod h1:gYMQPsZ9lQpyKlVp+DGwOuc9LIcE/c8GZW2CvKHISgM=
+maunium.net/go/mautrix v0.13.1-0.20230129104640-4a2a7653e437 h1:BMfeE1yJNs97rIXCRzIY284g7dXa7E6OaM0HCWpddwU=
+maunium.net/go/mautrix v0.13.1-0.20230129104640-4a2a7653e437/go.mod h1:gYMQPsZ9lQpyKlVp+DGwOuc9LIcE/c8GZW2CvKHISgM=

+ 5 - 0
portal.go

@@ -60,6 +60,9 @@ type Portal struct {
 
 	recentMessages *util.RingBuffer[string, *discordgo.Message]
 
+	commands     map[string]*discordgo.ApplicationCommand
+	commandsLock sync.RWMutex
+
 	currentlyTyping     []id.UserID
 	currentlyTypingLock sync.Mutex
 }
@@ -232,6 +235,8 @@ func (br *DiscordBridge) NewPortal(dbPortal *database.Portal) *Portal {
 		matrixMessages:  make(chan portalMatrixMessage, br.Config.Bridge.PortalMessageBuffer),
 
 		recentMessages: util.NewRingBuffer[string, *discordgo.Message](recentMessageBufferSize),
+
+		commands: make(map[string]*discordgo.ApplicationCommand),
 	}
 
 	go portal.messageLoop()

+ 20 - 0
user.go

@@ -59,6 +59,9 @@ type User struct {
 	markedOpened     map[string]time.Time
 	markedOpenedLock sync.Mutex
 
+	pendingInteractions     map[string]*WrappedCommandEvent
+	pendingInteractionsLock sync.Mutex
+
 	nextDiscordUploadID atomic.Int32
 }
 
@@ -197,6 +200,8 @@ func (br *DiscordBridge) NewUser(dbUser *database.User) *User {
 
 		markedOpened:    make(map[string]time.Time),
 		PermissionLevel: br.Config.Bridge.Permissions.Get(dbUser.MXID),
+
+		pendingInteractions: make(map[string]*WrappedCommandEvent),
 	}
 	user.nextDiscordUploadID.Store(rand.Int31n(100))
 	user.BridgeState = br.NewBridgeStateQueue(user, user.log)
@@ -540,6 +545,8 @@ func (user *User) Connect() error {
 	user.Session.AddHandler(user.messageAckHandler)
 	user.Session.AddHandler(user.typingStartHandler)
 
+	user.Session.AddHandler(user.interactionSuccessHandler)
+
 	user.Session.Identify.Presence.Status = "online"
 
 	return user.Session.Open()
@@ -963,6 +970,19 @@ func (user *User) typingStartHandler(_ *discordgo.Session, t *discordgo.TypingSt
 	}
 }
 
+func (user *User) interactionSuccessHandler(_ *discordgo.Session, s *discordgo.InteractionSuccess) {
+	user.pendingInteractionsLock.Lock()
+	defer user.pendingInteractionsLock.Unlock()
+	ce, ok := user.pendingInteractions[s.Nonce]
+	if !ok {
+		user.log.Debugfln("Got interaction success for unknown interaction %s/%s", s.Nonce, s.ID)
+	} else {
+		user.log.Infofln("Got interaction success for pending interaction %s/%s", s.Nonce, s.ID)
+		ce.React("✅")
+		delete(user.pendingInteractions, s.Nonce)
+	}
+}
+
 func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID, isDirect bool) bool {
 	if intent == nil {
 		intent = user.bridge.Bot