Browse Source

Prevent handling too many attachments in parallel

Tulir Asokan 2 years ago
parent
commit
b4249488db
4 changed files with 46 additions and 7 deletions
  1. 38 5
      attachments.go
  2. 1 0
      go.mod
  3. 2 0
      go.sum
  4. 5 2
      main.go

+ 38 - 5
attachments.go

@@ -3,6 +3,7 @@ package main
 import (
 import (
 	"bytes"
 	"bytes"
 	"context"
 	"context"
+	"errors"
 	"fmt"
 	"fmt"
 	"image"
 	"image"
 	"io"
 	"io"
@@ -12,6 +13,7 @@ import (
 	"path/filepath"
 	"path/filepath"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
+	"sync"
 	"time"
 	"time"
 
 
 	"github.com/bwmarrin/discordgo"
 	"github.com/bwmarrin/discordgo"
@@ -28,7 +30,7 @@ import (
 	"go.mau.fi/mautrix-discord/database"
 	"go.mau.fi/mautrix-discord/database"
 )
 )
 
 
-func downloadDiscordAttachment(url string) ([]byte, error) {
+func downloadDiscordAttachment(url string, maxSize int64) ([]byte, error) {
 	req, err := http.NewRequest(http.MethodGet, url, nil)
 	req, err := http.NewRequest(http.MethodGet, url, nil)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -46,7 +48,22 @@ func downloadDiscordAttachment(url string) ([]byte, error) {
 		data, _ := io.ReadAll(resp.Body)
 		data, _ := io.ReadAll(resp.Body)
 		return nil, fmt.Errorf("unexpected status %d downloading %s: %s", resp.StatusCode, url, data)
 		return nil, fmt.Errorf("unexpected status %d downloading %s: %s", resp.StatusCode, url, data)
 	}
 	}
-	return io.ReadAll(resp.Body)
+	if resp.Header.Get("Content-Length") != "" {
+		length, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
+		if err != nil {
+			return nil, fmt.Errorf("failed to parse content length: %w", err)
+		} else if length > maxSize {
+			return nil, fmt.Errorf("attachment too large (%d > %d)", length, maxSize)
+		}
+		return io.ReadAll(resp.Body)
+	} else {
+		var mbe *http.MaxBytesError
+		data, err := io.ReadAll(http.MaxBytesReader(nil, resp.Body, maxSize))
+		if err != nil && errors.As(err, &mbe) {
+			return nil, fmt.Errorf("attachment too large (over %d)", maxSize)
+		}
+		return data, err
+	}
 }
 }
 
 
 func uploadDiscordAttachment(url string, data []byte) error {
 func uploadDiscordAttachment(url string, data []byte) error {
@@ -99,7 +116,7 @@ func downloadMatrixAttachment(intent *appservice.IntentAPI, content *event.Messa
 	return data, nil
 	return data, nil
 }
 }
 
 
-func (br *DiscordBridge) uploadMatrixAttachment(intent *appservice.IntentAPI, data []byte, url string, encrypt bool, meta AttachmentMeta) (*database.File, error) {
+func (br *DiscordBridge) uploadMatrixAttachment(intent *appservice.IntentAPI, data []byte, url string, encrypt bool, meta AttachmentMeta, semaWg *sync.WaitGroup) (*database.File, error) {
 	dbFile := br.DB.File.New()
 	dbFile := br.DB.File.New()
 	dbFile.Timestamp = time.Now()
 	dbFile.Timestamp = time.Now()
 	dbFile.URL = url
 	dbFile.URL = url
@@ -135,7 +152,9 @@ func (br *DiscordBridge) uploadMatrixAttachment(intent *appservice.IntentAPI, da
 		dbFile.MXC = resp.ContentURI
 		dbFile.MXC = resp.ContentURI
 		req.MXC = resp.ContentURI
 		req.MXC = resp.ContentURI
 		req.UnstableUploadURL = resp.UnstableUploadURL
 		req.UnstableUploadURL = resp.UnstableUploadURL
+		semaWg.Add(1)
 		go func() {
 		go func() {
+			defer semaWg.Done()
 			_, err = intent.UploadMedia(req)
 			_, err = intent.UploadMedia(req)
 			if err != nil {
 			if err != nil {
 				br.Log.Errorfln("Failed to upload %s: %v", req.MXC, err)
 				br.Log.Errorfln("Failed to upload %s: %v", req.MXC, err)
@@ -259,8 +278,21 @@ func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, ur
 				}
 				}
 			}
 			}
 
 
+			const attachmentSizeVal = 1
+			onceErr = br.parallelAttachmentSemaphore.Acquire(context.Background(), attachmentSizeVal)
+			if onceErr != nil {
+				onceErr = fmt.Errorf("failed to acquire semaphore: %w", onceErr)
+				return
+			}
+			var semaWg sync.WaitGroup
+			semaWg.Add(1)
+			go func() {
+				semaWg.Wait()
+				br.parallelAttachmentSemaphore.Release(attachmentSizeVal)
+			}()
+
 			var data []byte
 			var data []byte
-			data, onceErr = downloadDiscordAttachment(url)
+			data, onceErr = downloadDiscordAttachment(url, br.MediaConfig.UploadSize)
 			if onceErr != nil {
 			if onceErr != nil {
 				return
 				return
 			}
 			}
@@ -273,7 +305,7 @@ func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, ur
 				}
 				}
 			}
 			}
 
 
-			onceDBFile, onceErr = br.uploadMatrixAttachment(intent, data, url, encrypt, meta)
+			onceDBFile, onceErr = br.uploadMatrixAttachment(intent, data, url, encrypt, meta, &semaWg)
 			if onceErr != nil {
 			if onceErr != nil {
 				return
 				return
 			}
 			}
@@ -281,6 +313,7 @@ func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, ur
 				onceDBFile.Insert(nil)
 				onceDBFile.Insert(nil)
 			}
 			}
 			br.attachmentTransfers.Delete(transferKey)
 			br.attachmentTransfers.Delete(transferKey)
+			semaWg.Done()
 			return
 			return
 		})
 		})
 	}
 	}

+ 1 - 0
go.mod

@@ -15,6 +15,7 @@ require (
 	github.com/stretchr/testify v1.8.4
 	github.com/stretchr/testify v1.8.4
 	github.com/yuin/goldmark v1.5.4
 	github.com/yuin/goldmark v1.5.4
 	golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
 	golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
+	golang.org/x/sync v0.3.0
 	maunium.net/go/maulogger/v2 v2.4.1
 	maunium.net/go/maulogger/v2 v2.4.1
 	maunium.net/go/mautrix v0.15.4-0.20230623121006-d8b15c18dc3f
 	maunium.net/go/mautrix v0.15.4-0.20230623121006-d8b15c18dc3f
 )
 )

+ 2 - 0
go.sum

@@ -51,6 +51,8 @@ golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERs
 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
 golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU=
 golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU=
 golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ=
 golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ=
+golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
+golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
 golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=

+ 5 - 2
main.go

@@ -20,6 +20,7 @@ import (
 	_ "embed"
 	_ "embed"
 	"sync"
 	"sync"
 
 
+	"golang.org/x/sync/semaphore"
 	"maunium.net/go/mautrix/bridge"
 	"maunium.net/go/mautrix/bridge"
 	"maunium.net/go/mautrix/bridge/commands"
 	"maunium.net/go/mautrix/bridge/commands"
 	"maunium.net/go/mautrix/id"
 	"maunium.net/go/mautrix/id"
@@ -73,7 +74,8 @@ type DiscordBridge struct {
 	puppetsByCustomMXID map[id.UserID]*Puppet
 	puppetsByCustomMXID map[id.UserID]*Puppet
 	puppetsLock         sync.Mutex
 	puppetsLock         sync.Mutex
 
 
-	attachmentTransfers *util.SyncMap[attachmentKey, *util.ReturnableOnce[*database.File]]
+	attachmentTransfers         *util.SyncMap[attachmentKey, *util.ReturnableOnce[*database.File]]
+	parallelAttachmentSemaphore *semaphore.Weighted
 }
 }
 
 
 func (br *DiscordBridge) GetExampleConfig() string {
 func (br *DiscordBridge) GetExampleConfig() string {
@@ -170,7 +172,8 @@ func main() {
 		puppets:             make(map[string]*Puppet),
 		puppets:             make(map[string]*Puppet),
 		puppetsByCustomMXID: make(map[id.UserID]*Puppet),
 		puppetsByCustomMXID: make(map[id.UserID]*Puppet),
 
 
-		attachmentTransfers: util.NewSyncMap[attachmentKey, *util.ReturnableOnce[*database.File]](),
+		attachmentTransfers:         util.NewSyncMap[attachmentKey, *util.ReturnableOnce[*database.File]](),
+		parallelAttachmentSemaphore: semaphore.NewWeighted(3),
 	}
 	}
 	br.Bridge = bridge.Bridge{
 	br.Bridge = bridge.Bridge{
 		Name:              "mautrix-discord",
 		Name:              "mautrix-discord",