瀏覽代碼

Cache files copied to Matrix

Tulir Asokan 2 年之前
父節點
當前提交
9ca27a8df6
共有 6 個文件被更改,包括 247 次插入39 次删除
  1. 49 29
      attachments.go
  2. 5 0
      database/database.go
  3. 132 0
      database/file.go
  4. 19 1
      database/upgrades/00-latest-revision.sql
  5. 18 0
      database/upgrades/11-cache-reuploaded-files.sql
  6. 24 9
      portal.go

+ 49 - 29
attachments.go

@@ -7,18 +7,19 @@ import (
 	"io"
 	"net/http"
 	"strings"
-
-	"maunium.net/go/mautrix/crypto/attachment"
+	"time"
 
 	"github.com/bwmarrin/discordgo"
 
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/crypto/attachment"
 	"maunium.net/go/mautrix/event"
-	"maunium.net/go/mautrix/id"
+
+	"go.mau.fi/mautrix-discord/database"
 )
 
-func (portal *Portal) downloadDiscordAttachment(url string) ([]byte, error) {
+func downloadDiscordAttachment(url string) ([]byte, error) {
 	req, err := http.NewRequest(http.MethodGet, url, nil)
 	if err != nil {
 		return nil, err
@@ -68,48 +69,67 @@ func (portal *Portal) downloadMatrixAttachment(content *event.MessageEventConten
 	return data, nil
 }
 
-func (portal *Portal) uploadMatrixAttachment(intent *appservice.IntentAPI, data []byte, content *event.MessageEventContent) error {
-	content.Info.Size = len(data)
-	if content.Info.Width == 0 && content.Info.Height == 0 && strings.HasPrefix(content.Info.MimeType, "image/") {
+func (br *DiscordBridge) uploadMatrixAttachment(intent *appservice.IntentAPI, data []byte, url string, encrypt bool, attachmentID, mime string) (*database.File, error) {
+	dbFile := br.DB.File.New()
+	dbFile.Timestamp = time.Now()
+	dbFile.URL = url
+	dbFile.ID = attachmentID
+	dbFile.Size = len(data)
+	if strings.HasPrefix(mime, "image/") {
 		cfg, _, _ := image.DecodeConfig(bytes.NewReader(data))
-		content.Info.Width = cfg.Width
-		content.Info.Height = cfg.Height
+		dbFile.Width = cfg.Width
+		dbFile.Height = cfg.Height
 	}
 
-	uploadMime := content.Info.MimeType
-	var file *attachment.EncryptedFile
-	if portal.Encrypted {
-		file = attachment.NewEncryptedFile()
-		file.EncryptInPlace(data)
+	uploadMime := mime
+	if encrypt {
+		dbFile.Encrypted = true
+		dbFile.DecryptionInfo = attachment.NewEncryptedFile()
+		dbFile.DecryptionInfo.EncryptInPlace(data)
 		uploadMime = "application/octet-stream"
 	}
 	req := mautrix.ReqUploadMedia{
 		ContentBytes: data,
 		ContentType:  uploadMime,
 	}
-	var mxc id.ContentURI
-	if portal.bridge.Config.Homeserver.AsyncMedia {
-		uploaded, err := intent.UnstableUploadAsync(req)
+	if br.Config.Homeserver.AsyncMedia {
+		resp, err := intent.UnstableCreateMXC()
 		if err != nil {
-			return err
+			return nil, err
 		}
-		mxc = uploaded.ContentURI
+		dbFile.MXC = resp.ContentURI
+		req.UnstableMXC = resp.ContentURI
+		req.UploadURL = resp.UploadURL
+		go func() {
+			_, err = intent.UploadMedia(req)
+			if err != nil {
+				br.Log.Errorfln("Failed to upload %s: %v", req.UnstableMXC, err)
+				dbFile.Delete()
+			}
+		}()
 	} else {
 		uploaded, err := intent.UploadMedia(req)
 		if err != nil {
-			return err
+			return nil, err
 		}
-		mxc = uploaded.ContentURI
+		dbFile.MXC = uploaded.ContentURI
 	}
+	dbFile.Insert(nil)
+	return dbFile, nil
+}
 
-	if file != nil {
-		content.File = &event.EncryptedFileInfo{
-			EncryptedFile: *file,
-			URL:           mxc.CUString(),
+func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, url string, encrypt bool, attachmentID, mime string) (*database.File, error) {
+	dbFile := br.DB.File.Get(url, encrypt)
+	if dbFile == nil {
+		data, err := downloadDiscordAttachment(url)
+		if err != nil {
+			return nil, err
 		}
-	} else {
-		content.URL = mxc.CUString()
-	}
 
-	return nil
+		dbFile, err = br.uploadMatrixAttachment(intent, data, url, encrypt, attachmentID, mime)
+		if err != nil {
+			return nil, err
+		}
+	}
+	return dbFile, nil
 }

+ 5 - 0
database/database.go

@@ -24,6 +24,7 @@ type Database struct {
 	Emoji    *EmojiQuery
 	Guild    *GuildQuery
 	Role     *RoleQuery
+	File     *FileQuery
 }
 
 func New(baseDB *dbutil.Database, log maulogger.Logger) *Database {
@@ -65,6 +66,10 @@ func New(baseDB *dbutil.Database, log maulogger.Logger) *Database {
 		db:  db,
 		log: log.Sub("Role"),
 	}
+	db.File = &FileQuery{
+		db:  db,
+		log: log.Sub("File"),
+	}
 	return db
 }
 

+ 132 - 0
database/file.go

@@ -0,0 +1,132 @@
+package database
+
+import (
+	"database/sql"
+	"encoding/json"
+	"errors"
+	"time"
+
+	log "maunium.net/go/maulogger/v2"
+
+	"maunium.net/go/mautrix/crypto/attachment"
+	"maunium.net/go/mautrix/id"
+	"maunium.net/go/mautrix/util/dbutil"
+)
+
+type FileQuery struct {
+	db  *Database
+	log log.Logger
+}
+
+// language=postgresql
+const (
+	fileSelect = "SELECT url, encrypted, id, mxc, size, width, height, decryption_info, timestamp FROM discord_file"
+	fileInsert = `
+		INSERT INTO discord_file (url, encrypted, id, mxc, size, width, height, decryption_info, timestamp)
+		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+	`
+)
+
+func (fq *FileQuery) New() *File {
+	return &File{
+		db:  fq.db,
+		log: fq.log,
+	}
+}
+
+func (fq *FileQuery) Get(url string, encrypted bool) *File {
+	query := fileSelect + " WHERE url=$1 AND encrypted=$2"
+	return fq.New().Scan(fq.db.QueryRow(query, url, encrypted))
+}
+
+type File struct {
+	db  *Database
+	log log.Logger
+
+	URL       string
+	Encrypted bool
+
+	ID  string
+	MXC id.ContentURI
+
+	Size   int
+	Width  int
+	Height int
+
+	DecryptionInfo *attachment.EncryptedFile
+
+	Timestamp time.Time
+}
+
+func (f *File) Scan(row dbutil.Scannable) *File {
+	var fileID sql.NullString
+	var decryptionInfo []byte
+	var width, height sql.NullInt32
+	var timestamp int64
+	var mxc string
+	err := row.Scan(&f.URL, &f.Encrypted, &fileID, &mxc, &f.Size, &width, &height, &decryptionInfo, &timestamp)
+	if err != nil {
+		if !errors.Is(err, sql.ErrNoRows) {
+			f.log.Errorln("Database scan failed:", err)
+			panic(err)
+		}
+		return nil
+	}
+	f.ID = fileID.String
+	f.Timestamp = time.UnixMilli(timestamp)
+	f.Width = int(width.Int32)
+	f.Height = int(height.Int32)
+	f.MXC, err = id.ParseContentURI(mxc)
+	if err != nil {
+		f.log.Errorfln("Failed to parse content URI %s: %v", mxc, err)
+		panic(err)
+	}
+	if decryptionInfo != nil {
+		err = json.Unmarshal(decryptionInfo, &f.DecryptionInfo)
+		if err != nil {
+			f.log.Errorfln("Failed to unmarshal decryption info of %v: %v", f.MXC, err)
+			panic(err)
+		}
+	}
+	return f
+}
+
+func positiveIntToNullInt32(val int) (ptr sql.NullInt32) {
+	if val > 0 {
+		ptr.Valid = true
+		ptr.Int32 = int32(val)
+	}
+	return
+}
+
+func (f *File) Insert(txn dbutil.Execable) {
+	if txn == nil {
+		txn = f.db
+	}
+	var err error
+	var decryptionInfo []byte
+	if f.DecryptionInfo != nil {
+		decryptionInfo, err = json.Marshal(f.DecryptionInfo)
+		if err != nil {
+			f.log.Warnfln("Failed to marshal decryption info of %v: %v", f.MXC, err)
+			panic(err)
+		}
+	}
+	_, err = txn.Exec(fileInsert,
+		f.URL, f.Encrypted, strPtr(f.ID), f.MXC.String(), f.Size,
+		positiveIntToNullInt32(f.Width), positiveIntToNullInt32(f.Height),
+		decryptionInfo, f.Timestamp.UnixMilli(),
+	)
+	if err != nil {
+		f.log.Warnfln("Failed to insert copied file %v: %v", f.MXC, err)
+		panic(err)
+	}
+}
+
+func (f *File) Delete() {
+	_, err := f.db.Exec("DELETE FROM discord_file WHERE url=$1 AND encrypted=$2", f.URL, f.Encrypted)
+	if err != nil {
+		f.log.Warnfln("Failed to delete copied file %v: %v", f.MXC, err)
+		panic(err)
+	}
+}

+ 19 - 1
database/upgrades/00-latest-revision.sql

@@ -1,4 +1,4 @@
--- v0 -> v10: Latest revision
+-- v0 -> v11: Latest revision
 
 CREATE TABLE guild (
     dcid       TEXT PRIMARY KEY,
@@ -150,3 +150,21 @@ CREATE TABLE role (
     PRIMARY KEY (dc_guild_id, dcid),
     CONSTRAINT role_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild (dcid) ON DELETE CASCADE
 );
+
+CREATE TABLE discord_file (
+    url       TEXT,
+    encrypted BOOLEAN,
+
+    id  TEXT,
+    mxc TEXT NOT NULL,
+
+    size   BIGINT NOT NULL,
+    width  INTEGER,
+    height INTEGER,
+
+    decryption_info jsonb,
+
+    timestamp BIGINT NOT NULL,
+
+    PRIMARY KEY (url, encrypted)
+);

+ 18 - 0
database/upgrades/11-cache-reuploaded-files.sql

@@ -0,0 +1,18 @@
+-- v11: Cache files copied from Discord to Matrix
+CREATE TABLE discord_file (
+    url       TEXT,
+    encrypted BOOLEAN,
+
+    id  TEXT,
+    mxc TEXT NOT NULL,
+
+    size   BIGINT NOT NULL,
+    width  INTEGER,
+    height INTEGER,
+
+    decryption_info jsonb,
+
+    timestamp BIGINT NOT NULL,
+
+    PRIMARY KEY (url, encrypted)
+);

+ 24 - 9
portal.go

@@ -523,31 +523,46 @@ func (portal *Portal) markMessageHandled(discordID string, editIndex int, author
 	msg.MassInsert(parts)
 }
 
-func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridgeErr error) {
+func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridgeErr error) id.EventID {
 	content := &event.MessageEventContent{
 		Body:    fmt.Sprintf("Failed to bridge media: %v", bridgeErr),
 		MsgType: event.MsgNotice,
 	}
 
-	_, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, 0)
+	resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, 0)
 	if err != nil {
 		portal.log.Warnfln("Failed to send media error message to matrix: %v", err)
+		return ""
 	}
+	return resp.EventID
 }
 
 const DiscordStickerSize = 160
 
 func (portal *Portal) handleDiscordFile(typeName string, intent *appservice.IntentAPI, id, url string, content *event.MessageEventContent, ts time.Time, threadRelation *event.RelatesTo) *database.MessagePart {
-	data, err := portal.downloadDiscordAttachment(url)
+	dbFile, err := portal.bridge.copyAttachmentToMatrix(intent, url, portal.Encrypted, id, content.Info.MimeType)
 	if err != nil {
-		portal.sendMediaFailedMessage(intent, err)
+		errorEventID := portal.sendMediaFailedMessage(intent, err)
+		if errorEventID != "" {
+			return &database.MessagePart{
+				AttachmentID: id,
+				MXID:         errorEventID,
+			}
+		}
 		return nil
 	}
-
-	err = portal.uploadMatrixAttachment(intent, data, content)
-	if err != nil {
-		portal.sendMediaFailedMessage(intent, err)
-		return nil
+	content.Info.Size = dbFile.Size
+	if content.Info.Width == 0 && content.Info.Height == 0 {
+		content.Info.Width = dbFile.Width
+		content.Info.Height = dbFile.Height
+	}
+	if dbFile.DecryptionInfo != nil {
+		content.File = &event.EncryptedFileInfo{
+			EncryptedFile: *dbFile.DecryptionInfo,
+			URL:           dbFile.MXC.CUString(),
+		}
+	} else {
+		content.URL = dbFile.MXC.CUString()
 	}
 
 	evtType := event.EventMessage