Browse Source

Add (dis|re)connect commands and provision api

Also fixed a number of data races.
Gary Kramlich 3 years ago
parent
commit
4b87ea1cc7
8 changed files with 292 additions and 115 deletions
  1. 6 1
      bridge/commandhandler.go
  2. 53 5
      bridge/commands.go
  3. 1 1
      bridge/matrix.go
  4. 92 16
      bridge/provisioning.go
  5. 91 40
      bridge/user.go
  6. 18 52
      database/user.go
  7. 15 0
      remoteauth/client.go
  8. 16 0
      remoteauth/serverpackets.go

+ 6 - 1
bridge/commandhandler.go

@@ -101,7 +101,12 @@ func (h *commandHandler) handle(roomID id.RoomID, user *User, message string, re
 	if err != nil {
 		h.log.Warnf("Command %q failed: %v", message, err)
 
-		cmd.globals.reply("unexpected failure")
+		output := buf.String()
+		if output != "" {
+			cmd.globals.reply(output)
+		} else {
+			cmd.globals.reply("unexpected failure")
+		}
 
 		return
 	}

+ 53 - 5
bridge/commands.go

@@ -46,10 +46,12 @@ func (g *globals) reply(msg string) {
 type commands struct {
 	globals
 
-	Help    helpCmd    `kong:"cmd,help='Displays this message.'"`
-	Login   loginCmd   `kong:"cmd,help='Log in to Discord.'"`
-	Logout  logoutCmd  `kong:"cmd,help='Log out of Discord.'"`
-	Version versionCmd `kong:"cmd,help='Displays the version of the bridge.'"`
+	Disconnect disconnectCmd `kong:"cmd,help='Disconnect from Discord'"`
+	Help       helpCmd       `kong:"cmd,help='Displays this message.'"`
+	Login      loginCmd      `kong:"cmd,help='Log in to Discord.'"`
+	Logout     logoutCmd     `kong:"cmd,help='Log out of Discord.'"`
+	Reconnect  reconnectCmd  `kong:"cmd,help='Reconnect to Discord'"`
+	Version    versionCmd    `kong:"cmd,help='Displays the version of the bridge.'"`
 }
 
 type helpCmd struct {
@@ -87,6 +89,12 @@ func (c *versionCmd) Run(g *globals) error {
 type loginCmd struct{}
 
 func (l *loginCmd) Run(g *globals) error {
+	if g.user.LoggedIn() {
+		fmt.Fprintf(g.context.Stdout, "You are already logged in")
+
+		return fmt.Errorf("user already logged in")
+	}
+
 	client, err := remoteauth.New()
 	if err != nil {
 		return err
@@ -145,7 +153,7 @@ func (l *logoutCmd) Run(g *globals) error {
 		return fmt.Errorf("user is not logged in")
 	}
 
-	err := g.user.DeleteSession()
+	err := g.user.Logout()
 	if err != nil {
 		fmt.Fprintln(g.context.Stdout, "Failed to log out")
 
@@ -156,3 +164,43 @@ func (l *logoutCmd) Run(g *globals) error {
 
 	return nil
 }
+
+type disconnectCmd struct{}
+
+func (d *disconnectCmd) Run(g *globals) error {
+	if !g.user.Connected() {
+		fmt.Fprintln(g.context.Stdout, "You are not connected")
+
+		return fmt.Errorf("user is not connected")
+	}
+
+	if err := g.user.Disconnect(); err != nil {
+		fmt.Fprintln(g.context.Stdout, "Failed to disconnect")
+
+		return err
+	}
+
+	fmt.Fprintln(g.context.Stdout, "Successfully disconnected")
+
+	return nil
+}
+
+type reconnectCmd struct{}
+
+func (r *reconnectCmd) Run(g *globals) error {
+	if g.user.Connected() {
+		fmt.Fprintln(g.context.Stdout, "You are already connected")
+
+		return fmt.Errorf("user is already connected")
+	}
+
+	if err := g.user.Connect(); err != nil {
+		fmt.Fprintln(g.context.Stdout, "Failed to connect")
+
+		return err
+	}
+
+	fmt.Fprintln(g.context.Stdout, "Successfully connected")
+
+	return nil
+}

+ 1 - 1
bridge/matrix.go

@@ -162,7 +162,7 @@ func (mh *matrixHandler) handleBotInvite(evt *event.Event) {
 	mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Welcome)
 
 	if evt.RoomID == user.ManagementRoom {
-		if user.HasSession() {
+		if user.Connected() {
 			mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Connected)
 		} else {
 			mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.NotConnected)

+ 92 - 16
bridge/provisioning.go

@@ -42,9 +42,11 @@ func newProvisioningAPI(bridge *Bridge) *ProvisioningAPI {
 
 	r.Use(p.authMiddleware)
 
-	r.HandleFunc("/ping", p.Ping).Methods(http.MethodGet)
-	r.HandleFunc("/login", p.Login).Methods(http.MethodGet)
-	r.HandleFunc("/logout", p.Logout).Methods(http.MethodPost)
+	r.HandleFunc("/disconnect", p.disconnect).Methods(http.MethodPost)
+	r.HandleFunc("/ping", p.ping).Methods(http.MethodGet)
+	r.HandleFunc("/login", p.login).Methods(http.MethodGet)
+	r.HandleFunc("/logout", p.logout).Methods(http.MethodPost)
+	r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost)
 
 	return p
 }
@@ -138,38 +140,78 @@ var upgrader = websocket.Upgrader{
 }
 
 // Handlers
-func (p *ProvisioningAPI) Ping(w http.ResponseWriter, r *http.Request) {
+func (p *ProvisioningAPI) disconnect(w http.ResponseWriter, r *http.Request) {
+	user := r.Context().Value("user").(*User)
+
+	if !user.Connected() {
+		jsonResponse(w, http.StatusConflict, Error{
+			Error:   "You're not connected to discord",
+			ErrCode: "not connected",
+		})
+
+		return
+	}
+
+	if err := user.Disconnect(); err != nil {
+		jsonResponse(w, http.StatusInternalServerError, Error{
+			Error:   "Failed to disconnect from discord",
+			ErrCode: "failed to disconnect",
+		})
+	} else {
+		jsonResponse(w, http.StatusOK, Response{
+			Success: true,
+			Status:  "Disconnected from Discord",
+		})
+	}
+}
+
+func (p *ProvisioningAPI) ping(w http.ResponseWriter, r *http.Request) {
 	user := r.Context().Value("user").(*User)
 
 	discord := map[string]interface{}{
-		"has_session":     user.Session != nil,
-		"management_room": user.ManagementRoom,
-		"conn":            nil,
+		"logged_in": user.LoggedIn(),
+		"connected": user.Connected(),
+		"conn":      nil,
 	}
 
+	user.Lock()
 	if user.ID != "" {
 		discord["id"] = user.ID
 	}
 
 	if user.Session != nil {
+		user.Session.Lock()
 		discord["conn"] = map[string]interface{}{
 			"last_heartbeat_ack":  user.Session.LastHeartbeatAck,
 			"last_heartbeat_sent": user.Session.LastHeartbeatSent,
 		}
+		user.Session.Unlock()
 	}
 
 	resp := map[string]interface{}{
-		"mxid":    user.MXID,
-		"discord": discord,
+		"discord":         discord,
+		"management_room": user.ManagementRoom,
+		"mxid":            user.MXID,
 	}
 
+	user.Unlock()
+
 	jsonResponse(w, http.StatusOK, resp)
 }
 
-func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
+func (p *ProvisioningAPI) logout(w http.ResponseWriter, r *http.Request) {
 	user := r.Context().Value("user").(*User)
 	force := strings.ToLower(r.URL.Query().Get("force")) != "false"
 
+	if !user.LoggedIn() {
+		jsonResponse(w, http.StatusNotFound, Error{
+			Error:   "You're not logged in",
+			ErrCode: "not logged in",
+		})
+
+		return
+	}
+
 	if user.Session == nil {
 		if force {
 			jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
@@ -183,7 +225,7 @@ func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	err := user.DeleteSession()
+	err := user.Logout()
 	if err != nil {
 		user.log.Warnln("Error while logging out:", err)
 
@@ -200,7 +242,7 @@ func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
 	jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
 }
 
-func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
+func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) {
 	userID := r.URL.Query().Get("user_id")
 	user := p.bridge.GetUserByMXID(id.UserID(userID))
 
@@ -220,7 +262,7 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 	go func() {
 		// Read everything so SetCloseHandler() works
 		for {
-			_, _, err = c.ReadMessage()
+			_, _, err := c.ReadMessage()
 			if err != nil {
 				break
 			}
@@ -236,6 +278,15 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 		return nil
 	})
 
+	if user.LoggedIn() {
+		c.WriteJSON(Error{
+			Error:   "You're already logged into Discord",
+			ErrCode: "already logged in",
+		})
+
+		return
+	}
+
 	client, err := remoteauth.New()
 	if err != nil {
 		user.log.Errorf("Failed to log in from provisioning API:", err)
@@ -280,6 +331,9 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 				return
 			}
 
+			user.ID = discordUser.UserID
+			user.Update()
+
 			if err := user.Login(discordUser.Token); err != nil {
 				c.WriteJSON(Error{
 					Error:   "Failed to connect to Discord",
@@ -291,9 +345,6 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 				return
 			}
 
-			user.ID = discordUser.UserID
-			user.Update()
-
 			c.WriteJSON(map[string]interface{}{
 				"success": true,
 				"id":      user.ID,
@@ -305,3 +356,28 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 }
+
+func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) {
+	user := r.Context().Value("user").(*User)
+
+	if user.Connected() {
+		jsonResponse(w, http.StatusConflict, Error{
+			Error:   "You're already connected to discord",
+			ErrCode: "already connected",
+		})
+
+		return
+	}
+
+	if err := user.Connect(); err != nil {
+		jsonResponse(w, http.StatusInternalServerError, Error{
+			Error:   "Failed to connect to discord",
+			ErrCode: "failed to connect",
+		})
+	} else {
+		jsonResponse(w, http.StatusOK, Response{
+			Success: true,
+			Status:  "Connected to Discord",
+		})
+	}
+}

+ 91 - 40
bridge/user.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"fmt"
 	"strings"
+	"sync"
 
 	"github.com/bwmarrin/discordgo"
 	"github.com/skip2/go-qrcode"
@@ -17,11 +18,20 @@ import (
 	"gitlab.com/beeper/discord/database"
 )
 
+var (
+	ErrNotConnected = errors.New("not connected")
+	ErrNotLoggedIn  = errors.New("not logged in")
+)
+
 type User struct {
 	*database.User
 
+	sync.Mutex
+
 	bridge *Bridge
 	log    log.Logger
+
+	Session *discordgo.Session
 }
 
 func (b *Bridge) loadUser(dbUser *database.User, mxid *id.UserID) *User {
@@ -140,10 +150,6 @@ func (u *User) SetManagementRoom(roomID id.RoomID) {
 	u.Update()
 }
 
-func (u *User) HasSession() bool {
-	return u.User.Session != nil
-}
-
 func (u *User) sendQRCode(bot *appservice.IntentAPI, roomID id.RoomID, code string) (id.EventID, error) {
 	url, err := u.uploadQRCode(code)
 	if err != nil {
@@ -189,23 +195,65 @@ func (u *User) Login(token string) error {
 		return fmt.Errorf("No token specified")
 	}
 
-	err := u.User.NewSession(token)
-	if err != nil {
-		return err
-	}
+	u.Token = token
+	u.Update()
 
 	return u.Connect()
 }
 
 func (u *User) LoggedIn() bool {
+	u.Lock()
+	defer u.Unlock()
+
+	return u.Token != ""
+}
+
+func (u *User) Logout() error {
+	u.Lock()
+	defer u.Unlock()
+
+	if u.Session == nil {
+		return ErrNotLoggedIn
+	}
+
+	if err := u.Session.Close(); err != nil {
+		return err
+	}
+
+	u.Session = nil
+
+	u.Token = ""
+	u.Update()
+
+	return nil
+}
+
+func (u *User) Connected() bool {
+	u.Lock()
+	defer u.Unlock()
+
 	return u.Session != nil
 }
 
 func (u *User) Connect() error {
+	u.Lock()
+	defer u.Unlock()
+
+	if u.Token == "" {
+		return ErrNotLoggedIn
+	}
+
 	u.log.Debugln("connecting to discord")
 
+	session, err := discordgo.New(u.Token)
+	if err != nil {
+		return err
+	}
+
+	u.Session = session
+
 	// get our user info
-	user, err := u.User.Session.User("@me")
+	user, err := u.Session.User("@me")
 	if err != nil {
 		return err
 	}
@@ -213,37 +261,40 @@ func (u *User) Connect() error {
 	u.User.ID = user.ID
 
 	// Add our event handlers
-	u.User.Session.AddHandler(u.connectedHandler)
-	u.User.Session.AddHandler(u.disconnectedHandler)
-
-	u.User.Session.AddHandler(u.channelCreateHandler)
-	u.User.Session.AddHandler(u.channelDeleteHandler)
-	u.User.Session.AddHandler(u.channelPinsUpdateHandler)
-	u.User.Session.AddHandler(u.channelUpdateHandler)
-
-	u.User.Session.AddHandler(u.messageCreateHandler)
-	u.User.Session.AddHandler(u.messageDeleteHandler)
-	u.User.Session.AddHandler(u.messageUpdateHandler)
-	u.User.Session.AddHandler(u.reactionAddHandler)
-	u.User.Session.AddHandler(u.reactionRemoveHandler)
-
-	// u.User.Session.Identify.Capabilities = 125
-	// // Setup our properties
-	// u.User.Session.Identify.Properties = discordgo.IdentifyProperties{
-	// 	OS:                "Windows",
-	// 	OSVersion:         "10",
-	// 	Browser:           "Chrome",
-	// 	BrowserUserAgent:  "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.159 Safari/537.36",
-	// 	BrowserVersion:    "92.0.4515.159",
-	// 	Referrer:          "https://discord.com/channels/@me",
-	// 	ReferringDomain:   "discord.com",
-	// 	ClientBuildNumber: "83364",
-	// 	ReleaseChannel:    "stable",
-	// }
-
-	u.User.Session.Identify.Presence.Status = "online"
-
-	return u.User.Session.Open()
+	u.Session.AddHandler(u.connectedHandler)
+	u.Session.AddHandler(u.disconnectedHandler)
+
+	u.Session.AddHandler(u.channelCreateHandler)
+	u.Session.AddHandler(u.channelDeleteHandler)
+	u.Session.AddHandler(u.channelPinsUpdateHandler)
+	u.Session.AddHandler(u.channelUpdateHandler)
+
+	u.Session.AddHandler(u.messageCreateHandler)
+	u.Session.AddHandler(u.messageDeleteHandler)
+	u.Session.AddHandler(u.messageUpdateHandler)
+	u.Session.AddHandler(u.reactionAddHandler)
+	u.Session.AddHandler(u.reactionRemoveHandler)
+
+	u.Session.Identify.Presence.Status = "online"
+
+	return u.Session.Open()
+}
+
+func (u *User) Disconnect() error {
+	u.Lock()
+	defer u.Unlock()
+
+	if u.Session == nil {
+		return ErrNotConnected
+	}
+
+	if err := u.Session.Close(); err != nil {
+		return err
+	}
+
+	u.Session = nil
+
+	return nil
 }
 
 func (u *User) connectedHandler(s *discordgo.Session, c *discordgo.Connect) {

+ 18 - 52
database/user.go

@@ -3,8 +3,6 @@ package database
 import (
 	"database/sql"
 
-	"github.com/bwmarrin/discordgo"
-
 	log "maunium.net/go/maulogger/v2"
 	"maunium.net/go/mautrix/id"
 )
@@ -18,38 +16,7 @@ type User struct {
 
 	ManagementRoom id.RoomID
 
-	Session *discordgo.Session
-}
-
-// NewSession is just used to create the session and update the database. It
-// should only be called by bridge.User.Connect which will continue setting up
-// event handlers and everything else.
-func (u *User) NewSession(token string) error {
-	session, err := discordgo.New(token)
-	if err != nil {
-		return err
-	}
-
-	u.Session = session
-
-	u.Update()
-
-	return nil
-}
-
-// DeleteSession tries to logout and delete the session from the database.
-func (u *User) DeleteSession() error {
-	err := u.Session.Close()
-
-	if err != nil {
-		u.log.Warnfln("failed to close the session for %s: %v", u.ID, err)
-	}
-
-	u.Session = nil
-
-	u.Update()
-
-	return nil
+	Token string
 }
 
 func (u *User) Scan(row Scannable) *User {
@@ -65,31 +32,25 @@ func (u *User) Scan(row Scannable) *User {
 	}
 
 	if token.Valid {
-		if err := u.NewSession(token.String); err != nil {
-			u.log.Errorln("Failed to login: ", err)
-		}
+		u.Token = token.String
 	}
 
 	return u
 }
 
-func (u *User) sessionNonptr() discordgo.Session {
-	if u.Session != nil {
-		return *u.Session
-	}
-
-	return discordgo.Session{}
-}
-
 func (u *User) Insert() {
-	session := u.sessionNonptr()
-
 	query := "INSERT INTO \"user\"" +
 		" (mxid, id, management_room, token)" +
 		" VALUES ($1, $2, $3, $4);"
 
-	_, err := u.db.Exec(query, u.MXID, u.ID, u.ManagementRoom,
-		session.Identify.Token)
+	var token sql.NullString
+
+	if u.Token != "" {
+		token.String = u.Token
+		token.Valid = true
+	}
+
+	_, err := u.db.Exec(query, u.MXID, u.ID, u.ManagementRoom, token)
 
 	if err != nil {
 		u.log.Warnfln("Failed to insert %s: %v", u.MXID, err)
@@ -97,13 +58,18 @@ func (u *User) Insert() {
 }
 
 func (u *User) Update() {
-	session := u.sessionNonptr()
-
 	query := "UPDATE \"user\" SET" +
 		" id=$1, management_room=$2, token=$3" +
 		" WHERE mxid=$4;"
 
-	_, err := u.db.Exec(query, u.ID, u.ManagementRoom, session.Identify.Token, u.MXID)
+	var token sql.NullString
+
+	if u.Token != "" {
+		token.String = u.Token
+		token.Valid = true
+	}
+
+	_, err := u.db.Exec(query, u.ID, u.ManagementRoom, token, u.MXID)
 
 	if err != nil {
 		u.log.Warnfln("Failed to update %q: %v", u.MXID, err)

+ 15 - 0
remoteauth/client.go

@@ -8,11 +8,14 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"net/http"
+	"sync"
 
 	"github.com/gorilla/websocket"
 )
 
 type Client struct {
+	sync.Mutex
+
 	URL    string
 	Origin string
 
@@ -48,6 +51,9 @@ func New() (*Client, error) {
 // Dial will start the QRCode login process. ctx may be used to abandon the
 // process.
 func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan struct{}) error {
+	c.Lock()
+	defer c.Unlock()
+
 	header := http.Header{
 		"Origin": []string{c.Origin},
 	}
@@ -68,10 +74,16 @@ func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan str
 }
 
 func (c *Client) Result() (User, error) {
+	c.Lock()
+	defer c.Unlock()
+
 	return c.user, c.err
 }
 
 func (c *Client) close() error {
+	c.Lock()
+	defer c.Unlock()
+
 	if c.closed {
 		return nil
 	}
@@ -89,6 +101,9 @@ func (c *Client) close() error {
 }
 
 func (c *Client) write(p clientPacket) error {
+	c.Lock()
+	defer c.Unlock()
+
 	payload, err := json.Marshal(p)
 	if err != nil {
 		return err

+ 16 - 0
remoteauth/serverpackets.go

@@ -22,10 +22,15 @@ func (c *Client) processMessages() {
 	defer c.close()
 
 	for {
+		c.Lock()
 		_, packet, err := c.conn.ReadMessage()
+		c.Unlock()
+
 		if err != nil {
 			if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
+				c.Lock()
 				c.err = err
+				c.Unlock()
 			}
 
 			return
@@ -33,7 +38,9 @@ func (c *Client) processMessages() {
 
 		raw := rawPacket{}
 		if err := json.Unmarshal(packet, &raw); err != nil {
+			c.Lock()
 			c.err = err
+			c.Unlock()
 
 			return
 		}
@@ -57,7 +64,9 @@ func (c *Client) processMessages() {
 		}
 
 		if err := json.Unmarshal(packet, dest); err != nil {
+			c.Lock()
 			c.err = err
+			c.Unlock()
 
 			return
 		}
@@ -65,7 +74,9 @@ func (c *Client) processMessages() {
 		op := dest.(serverPacket)
 		err = op.process(c)
 		if err != nil {
+			c.Lock()
 			c.err = err
+			c.Unlock()
 
 			return
 		}
@@ -92,7 +103,10 @@ func (h *serverHello) process(client *Client) error {
 			case <-ticker.C:
 				h := clientHeartbeat{}
 				if err := h.send(client); err != nil {
+					client.Lock()
 					client.err = err
+					client.Unlock()
+
 					return
 				}
 			}
@@ -104,8 +118,10 @@ func (h *serverHello) process(client *Client) error {
 
 		<-time.After(duration)
 
+		client.Lock()
 		client.err = fmt.Errorf("Timed out after %s", duration)
 		client.close()
+		client.Unlock()
 	}()
 
 	i := clientInit{}