ソースを参照

Add provisioning api for bridging guilds

This isn't exactly the same as the puppet bridge, basically it removes the
puppetID from the url and just works against the logged in user.

There is a known errata right now where some times all channels aren't left
when unbridging a guild. I figured it's more important to get some more testing
on this while I try to figure that out in the mean time.

Also when you call /guilds/:guildID/joinentire, it joins each channel serially.
I tried to make this concurrent but I maxed out the rate limit. We can do this
in the future, but again, rather get this into everyone's hands sooner rather
than later. I also did the same thing with unbridge and had to revert for the
same reasons.

Refs #8
Gary Kramlich 3 年 前
コミット
b66556ad99
4 ファイル変更176 行追加2 行削除
  1. 16 0
      bridge/discord.go
  2. 9 0
      bridge/portal.go
  3. 75 0
      bridge/provisioning.go
  4. 76 2
      bridge/user.go

+ 16 - 0
bridge/discord.go

@@ -0,0 +1,16 @@
+package bridge
+
+import (
+	"github.com/bwmarrin/discordgo"
+)
+
+func channelIsBridgeable(channel *discordgo.Channel) bool {
+	switch channel.Type {
+	case discordgo.ChannelTypeGuildText:
+		fallthrough
+	case discordgo.ChannelTypeGuildNews:
+		return true
+	}
+
+	return false
+}

+ 9 - 0
bridge/portal.go

@@ -659,6 +659,15 @@ func (p *Portal) handleMatrixLeave(sender *User) {
 	p.cleanupIfEmpty()
 	p.cleanupIfEmpty()
 }
 }
 
 
+func (p *Portal) leave(sender *User) {
+	if p.MXID == "" {
+		return
+	}
+
+	intent := p.bridge.GetPuppetByID(sender.ID).IntentFor(p)
+	intent.LeaveRoom(p.MXID)
+}
+
 func (p *Portal) delete() {
 func (p *Portal) delete() {
 	p.Portal.Delete()
 	p.Portal.Delete()
 	p.bridge.portalsLock.Lock()
 	p.bridge.portalsLock.Lock()

+ 75 - 0
bridge/provisioning.go

@@ -11,6 +11,7 @@ import (
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
+	"github.com/gorilla/mux"
 	"github.com/gorilla/websocket"
 	"github.com/gorilla/websocket"
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
 
 
@@ -48,6 +49,12 @@ func newProvisioningAPI(bridge *Bridge) *ProvisioningAPI {
 	r.HandleFunc("/logout", p.logout).Methods(http.MethodPost)
 	r.HandleFunc("/logout", p.logout).Methods(http.MethodPost)
 	r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost)
 	r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost)
 
 
+	// Setup the guild endpoints
+	r.HandleFunc("/guilds", p.guildsList).Methods(http.MethodGet)
+	r.HandleFunc("/guilds/{guildID}/bridge", p.guildsBridge).Methods(http.MethodPost)
+	r.HandleFunc("/guilds/{guildID}/unbridge", p.guildsUnbridge).Methods(http.MethodPost)
+	r.HandleFunc("/guilds/{guildID}/joinentire", p.guildsJoinEntire).Methods(http.MethodPost)
+
 	return p
 	return p
 }
 }
 
 
@@ -381,3 +388,71 @@ func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) {
 		})
 		})
 	}
 	}
 }
 }
+
+func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request) {
+	user := r.Context().Value("user").(*User)
+
+	user.guildsLock.Lock()
+	defer user.guildsLock.Unlock()
+
+	data := make([]map[string]interface{}, len(user.guilds))
+	idx := 0
+	for _, guild := range user.guilds {
+		data[idx] = map[string]interface{}{
+			"name":    guild.GuildName,
+			"id":      guild.GuildID,
+			"bridged": guild.Bridge,
+		}
+
+		idx++
+	}
+
+	jsonResponse(w, http.StatusOK, data)
+}
+
+func (p *ProvisioningAPI) guildsBridge(w http.ResponseWriter, r *http.Request) {
+	user := r.Context().Value("user").(*User)
+
+	guildID, _ := mux.Vars(r)["guildID"]
+
+	if err := user.bridgeGuild(guildID, false); err != nil {
+		jsonResponse(w, http.StatusNotFound, Error{
+			Error:   err.Error(),
+			ErrCode: "M_NOT_FOUND",
+		})
+	} else {
+		w.WriteHeader(http.StatusCreated)
+	}
+}
+
+func (p *ProvisioningAPI) guildsUnbridge(w http.ResponseWriter, r *http.Request) {
+	user := r.Context().Value("user").(*User)
+
+	guildID, _ := mux.Vars(r)["guildID"]
+
+	if err := user.unbridgeGuild(guildID); err != nil {
+		jsonResponse(w, http.StatusNotFound, Error{
+			Error:   err.Error(),
+			ErrCode: "M_NOT_FOUND",
+		})
+
+		return
+	}
+
+	w.WriteHeader(http.StatusNoContent)
+}
+
+func (p *ProvisioningAPI) guildsJoinEntire(w http.ResponseWriter, r *http.Request) {
+	user := r.Context().Value("user").(*User)
+
+	guildID, _ := mux.Vars(r)["guildID"]
+
+	if err := user.bridgeGuild(guildID, true); err != nil {
+		jsonResponse(w, http.StatusNotFound, Error{
+			Error:   err.Error(),
+			ErrCode: "M_NOT_FOUND",
+		})
+	} else {
+		w.WriteHeader(http.StatusCreated)
+	}
+}

+ 76 - 2
bridge/user.go

@@ -503,7 +503,7 @@ func (u *User) guildUpdateHandler(s *discordgo.Session, g *discordgo.GuildUpdate
 	}
 	}
 }
 }
 
 
-func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCreate) {
+func (u *User) createChannel(c *discordgo.Channel) {
 	key := database.NewPortalKey(c.ID, u.User.ID)
 	key := database.NewPortalKey(c.ID, u.User.ID)
 	portal := u.bridge.GetPortalByID(key)
 	portal := u.bridge.GetPortalByID(key)
 
 
@@ -525,7 +525,11 @@ func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCr
 
 
 	portal.Update()
 	portal.Update()
 
 
-	portal.createMatrixRoom(u, c.Channel)
+	portal.createMatrixRoom(u, c)
+}
+
+func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCreate) {
+	u.createChannel(c.Channel)
 }
 }
 
 
 func (u *User) channelDeleteHandler(s *discordgo.Session, c *discordgo.ChannelDelete) {
 func (u *User) channelDeleteHandler(s *discordgo.Session, c *discordgo.ChannelDelete) {
@@ -738,3 +742,73 @@ func (u *User) updateDirectChats(chats map[id.UserID][]id.RoomID) {
 		u.log.Warnln("Failed to update m.direct list:", err)
 		u.log.Warnln("Failed to update m.direct list:", err)
 	}
 	}
 }
 }
+
+func (u *User) bridgeGuild(guildID string, everything bool) error {
+	u.guildsLock.Lock()
+	defer u.guildsLock.Unlock()
+
+	guild, found := u.guilds[guildID]
+	if !found {
+		return fmt.Errorf("guildID not found")
+	}
+
+	// Update the guild
+	guild.Bridge = true
+	guild.Upsert()
+
+	// If this is a full bridge, create portals for all the channels
+	if everything {
+		channels, err := u.Session.GuildChannels(guildID)
+		if err != nil {
+			return err
+		}
+
+		for _, channel := range channels {
+			if channelIsBridgeable(channel) {
+				u.createChannel(channel)
+			}
+		}
+	}
+
+	return nil
+}
+
+func (u *User) unbridgeGuild(guildID string) error {
+	u.guildsLock.Lock()
+	defer u.guildsLock.Unlock()
+
+	guild, exists := u.guilds[guildID]
+	if !exists {
+		return fmt.Errorf("guildID not found")
+	}
+
+	if !guild.Bridge {
+		return fmt.Errorf("guild not bridged")
+	}
+
+	// First update the guild so we don't have any other go routines recreating
+	// channels we're about to destroy.
+	guild.Bridge = false
+	guild.Upsert()
+
+	// Now run through the channels in the guild and remove any portals we
+	// have for them.
+	channels, err := u.Session.GuildChannels(guildID)
+	if err != nil {
+		return err
+	}
+
+	for _, channel := range channels {
+		if channelIsBridgeable(channel) {
+			key := database.PortalKey{
+				ChannelID: channel.ID,
+				Receiver:  u.ID,
+			}
+
+			portal := u.bridge.GetPortalByID(key)
+			portal.leave(u)
+		}
+	}
+
+	return nil
+}