Эх сурвалжийг харах

Reject unsupported message types and add some more conversions

Closes #524
Closes #539
Fixes #510
Tulir Asokan 2 жил өмнө
parent
commit
76c9660849
5 өөрчлөгдсөн 109 нэмэгдсэн , 27 устгасан
  1. 1 0
      go.mod
  2. 2 0
      go.sum
  3. 3 0
      messagetracking.go
  4. 102 26
      portal.go
  5. 1 1
      urlpreview.go

+ 1 - 0
go.mod

@@ -22,6 +22,7 @@ require (
 	filippo.io/edwards25519 v1.0.0 // indirect
 	filippo.io/edwards25519 v1.0.0 // indirect
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/cespare/xxhash/v2 v2.1.2 // indirect
 	github.com/cespare/xxhash/v2 v2.1.2 // indirect
+	github.com/chai2010/webp v1.1.1 // indirect
 	github.com/golang/protobuf v1.5.2 // indirect
 	github.com/golang/protobuf v1.5.2 // indirect
 	github.com/mattn/go-colorable v0.1.12 // indirect
 	github.com/mattn/go-colorable v0.1.12 // indirect
 	github.com/mattn/go-isatty v0.0.14 // indirect
 	github.com/mattn/go-isatty v0.0.14 // indirect

+ 2 - 0
go.sum

@@ -5,6 +5,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
 github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
 github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
+github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
 github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
 github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/go-kit/log v0.2.0/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0=
 github.com/go-kit/log v0.2.0/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0=

+ 3 - 0
messagetracking.go

@@ -45,6 +45,7 @@ var (
 	errMediaDecryptFailed          = errors.New("failed to decrypt media")
 	errMediaDecryptFailed          = errors.New("failed to decrypt media")
 	errMediaConvertFailed          = errors.New("failed to convert media")
 	errMediaConvertFailed          = errors.New("failed to convert media")
 	errMediaWhatsAppUploadFailed   = errors.New("failed to upload media to WhatsApp")
 	errMediaWhatsAppUploadFailed   = errors.New("failed to upload media to WhatsApp")
+	errMediaUnsupportedType        = errors.New("unsupported media type")
 	errTargetNotFound              = errors.New("target event not found")
 	errTargetNotFound              = errors.New("target event not found")
 	errReactionDatabaseNotFound    = errors.New("reaction database entry not found")
 	errReactionDatabaseNotFound    = errors.New("reaction database entry not found")
 	errReactionTargetNotFound      = errors.New("reaction target message not found")
 	errReactionTargetNotFound      = errors.New("reaction target message not found")
@@ -73,6 +74,8 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev
 		errors.Is(err, errBroadcastReactionNotSupported),
 		errors.Is(err, errBroadcastReactionNotSupported),
 		errors.Is(err, errBroadcastSendDisabled):
 		errors.Is(err, errBroadcastSendDisabled):
 		return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, ""
 		return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, ""
+	case errors.Is(err, errMediaUnsupportedType):
+		return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, err.Error()
 	case errors.Is(err, errTimeoutBeforeHandling):
 	case errors.Is(err, errTimeoutBeforeHandling):
 		return event.MessageStatusTooOld, event.MessageStatusRetriable, true, true, "the message was too old when it reached the bridge, so it was not handled"
 		return event.MessageStatusTooOld, event.MessageStatusRetriable, true, true, "the message was too old when it reached the bridge, so it was not handled"
 	case errors.Is(err, context.DeadlineExceeded):
 	case errors.Is(err, context.DeadlineExceeded):

+ 102 - 26
portal.go

@@ -35,9 +35,9 @@ import (
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
+	"github.com/chai2010/webp"
 	"github.com/tidwall/gjson"
 	"github.com/tidwall/gjson"
 	"golang.org/x/image/draw"
 	"golang.org/x/image/draw"
-	"golang.org/x/image/webp"
 	"google.golang.org/protobuf/proto"
 	"google.golang.org/protobuf/proto"
 
 
 	log "maunium.net/go/maulogger/v2"
 	log "maunium.net/go/maulogger/v2"
@@ -2526,7 +2526,7 @@ func (portal *Portal) convertMediaMessage(intent *appservice.IntentAPI, source *
 		if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia && isBackfill {
 		if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia && isBackfill {
 			errorText += " Media will be automatically requested from your phone later."
 			errorText += " Media will be automatically requested from your phone later."
 		} else {
 		} else {
-			errorText += ` React with the \u267b (recycle) emoji to request this media from your phone.`
+			errorText += " React with the \u267b (recycle) emoji to request this media from your phone."
 		}
 		}
 
 
 		return portal.makeMediaBridgeFailureMessage(info, err, converted, &FailedMediaKeys{
 		return portal.makeMediaBridgeFailureMessage(info, err, converted, &FailedMediaKeys{
@@ -2737,7 +2737,7 @@ func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID, mediaKey
 const thumbnailMaxSize = 72
 const thumbnailMaxSize = 72
 const thumbnailMinSize = 24
 const thumbnailMinSize = 24
 
 
-func createJPEGThumbnailAndGetSize(source []byte) ([]byte, int, int, error) {
+func createThumbnailAndGetSize(source []byte, pngThumbnail bool) ([]byte, int, int, error) {
 	src, _, err := image.Decode(bytes.NewReader(source))
 	src, _, err := image.Decode(bytes.NewReader(source))
 	if err != nil {
 	if err != nil {
 		return nil, 0, 0, fmt.Errorf("failed to decode thumbnail: %w", err)
 		return nil, 0, 0, fmt.Errorf("failed to decode thumbnail: %w", err)
@@ -2771,19 +2771,23 @@ func createJPEGThumbnailAndGetSize(source []byte) ([]byte, int, int, error) {
 	}
 	}
 
 
 	var buf bytes.Buffer
 	var buf bytes.Buffer
-	err = jpeg.Encode(&buf, img, &jpeg.Options{Quality: jpeg.DefaultQuality})
+	if pngThumbnail {
+		err = png.Encode(&buf, img)
+	} else {
+		err = jpeg.Encode(&buf, img, &jpeg.Options{Quality: jpeg.DefaultQuality})
+	}
 	if err != nil {
 	if err != nil {
 		return nil, width, height, fmt.Errorf("failed to re-encode thumbnail: %w", err)
 		return nil, width, height, fmt.Errorf("failed to re-encode thumbnail: %w", err)
 	}
 	}
 	return buf.Bytes(), width, height, nil
 	return buf.Bytes(), width, height, nil
 }
 }
 
 
-func createJPEGThumbnail(source []byte) ([]byte, error) {
-	data, _, _, err := createJPEGThumbnailAndGetSize(source)
+func createThumbnail(source []byte, png bool) ([]byte, error) {
+	data, _, _, err := createThumbnailAndGetSize(source, png)
 	return data, err
 	return data, err
 }
 }
 
 
-func (portal *Portal) downloadThumbnail(ctx context.Context, original []byte, thumbnailURL id.ContentURIString, eventID id.EventID) ([]byte, error) {
+func (portal *Portal) downloadThumbnail(ctx context.Context, original []byte, thumbnailURL id.ContentURIString, eventID id.EventID, png bool) ([]byte, error) {
 	if len(thumbnailURL) == 0 {
 	if len(thumbnailURL) == 0 {
 		// just fall back to making thumbnail of original
 		// just fall back to making thumbnail of original
 	} else if mxc, err := thumbnailURL.Parse(); err != nil {
 	} else if mxc, err := thumbnailURL.Parse(); err != nil {
@@ -2791,9 +2795,9 @@ func (portal *Portal) downloadThumbnail(ctx context.Context, original []byte, th
 	} else if thumbnail, err := portal.MainIntent().DownloadBytesContext(ctx, mxc); err != nil {
 	} else if thumbnail, err := portal.MainIntent().DownloadBytesContext(ctx, mxc); err != nil {
 		portal.log.Warnfln("Failed to download thumbnail in %s: %v (falling back to generating thumbnail from source)", eventID, err)
 		portal.log.Warnfln("Failed to download thumbnail in %s: %v (falling back to generating thumbnail from source)", eventID, err)
 	} else {
 	} else {
-		return createJPEGThumbnail(thumbnail)
+		return createThumbnail(thumbnail, png)
 	}
 	}
-	return createJPEGThumbnail(original)
+	return createThumbnail(original, png)
 }
 }
 
 
 func (portal *Portal) convertWebPtoPNG(webpImage []byte) ([]byte, error) {
 func (portal *Portal) convertWebPtoPNG(webpImage []byte) ([]byte, error) {
@@ -2803,7 +2807,21 @@ func (portal *Portal) convertWebPtoPNG(webpImage []byte) ([]byte, error) {
 	}
 	}
 
 
 	var pngBuffer bytes.Buffer
 	var pngBuffer bytes.Buffer
-	if err := png.Encode(&pngBuffer, webpDecoded); err != nil {
+	if err = png.Encode(&pngBuffer, webpDecoded); err != nil {
+		return nil, fmt.Errorf("failed to encode png image: %w", err)
+	}
+
+	return pngBuffer.Bytes(), nil
+}
+
+func (portal *Portal) convertToWebP(img []byte) ([]byte, error) {
+	webpDecoded, _, err := image.Decode(bytes.NewReader(img))
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode image: %w", err)
+	}
+
+	var pngBuffer bytes.Buffer
+	if err = webp.Encode(&pngBuffer, webpDecoded, nil); err != nil {
 		return nil, fmt.Errorf("failed to encode png image: %w", err)
 		return nil, fmt.Errorf("failed to encode png image: %w", err)
 	}
 	}
 
 
@@ -2815,6 +2833,7 @@ func (portal *Portal) preprocessMatrixMedia(ctx context.Context, sender *User, r
 	var caption string
 	var caption string
 	var mentionedJIDs []string
 	var mentionedJIDs []string
 	var hasHTMLCaption bool
 	var hasHTMLCaption bool
+	isSticker := string(content.MsgType) == event.EventSticker.Type
 	if content.FileName != "" && content.Body != content.FileName {
 	if content.FileName != "" && content.Body != content.FileName {
 		fileName = content.FileName
 		fileName = content.FileName
 		caption = content.Body
 		caption = content.Body
@@ -2844,22 +2863,58 @@ func (portal *Portal) preprocessMatrixMedia(ctx context.Context, sender *User, r
 			return nil, util.NewDualError(errMediaDecryptFailed, err)
 			return nil, util.NewDualError(errMediaDecryptFailed, err)
 		}
 		}
 	}
 	}
-	if mediaType == whatsmeow.MediaVideo && content.GetInfo().MimeType == "image/gif" {
-		data, err = ffmpeg.ConvertBytes(ctx, data, ".mp4", []string{"-f", "gif"}, []string{
-			"-pix_fmt", "yuv420p", "-c:v", "libx264", "-movflags", "+faststart",
-			"-filter:v", "crop='floor(in_w/2)*2:floor(in_h/2)*2'",
-		}, content.GetInfo().MimeType)
-		if err != nil {
-			return nil, util.NewDualError(fmt.Errorf("%w (gif to mp4)", errMediaConvertFailed), err)
+	mimeType := content.GetInfo().MimeType
+	var convertErr error
+	// Allowed mime types from https://developers.facebook.com/docs/whatsapp/on-premises/reference/media
+	switch {
+	case isSticker:
+		if mimeType != "image/webp" {
+			data, convertErr = portal.convertToWebP(data)
+			content.Info.MimeType = "image/webp"
+		}
+	case mediaType == whatsmeow.MediaVideo:
+		switch mimeType {
+		case "video/mp4", "video/3gpp":
+			// Allowed
+		case "image/gif":
+			data, convertErr = ffmpeg.ConvertBytes(ctx, data, ".mp4", []string{"-f", "gif"}, []string{
+				"-pix_fmt", "yuv420p", "-c:v", "libx264", "-movflags", "+faststart",
+				"-filter:v", "crop='floor(in_w/2)*2:floor(in_h/2)*2'",
+			}, mimeType)
+			content.Info.MimeType = "video/mp4"
+		case "video/webm":
+			data, convertErr = ffmpeg.ConvertBytes(ctx, data, ".mp4", []string{"-f", "webm"}, []string{
+				"-pix_fmt", "yuv420p", "-c:v", "libx264",
+			}, mimeType)
+			content.Info.MimeType = "video/mp4"
+		default:
+			return nil, fmt.Errorf("%w %q in video message", errMediaUnsupportedType, mimeType)
+		}
+	case mediaType == whatsmeow.MediaImage:
+		switch mimeType {
+		case "image/jpeg", "image/png":
+			// Allowed
+		case "image/webp":
+			data, convertErr = portal.convertWebPtoPNG(data)
+			content.Info.MimeType = "image/png"
+		default:
+			return nil, fmt.Errorf("%w %q in image message", errMediaUnsupportedType, mimeType)
+		}
+	case mediaType == whatsmeow.MediaAudio:
+		switch mimeType {
+		case "audio/aac", "audio/mp4", "audio/amr", "audio/mpeg", "audio/ogg; codecs=opus":
+			// Allowed
+		case "audio/ogg":
+			// Hopefully it's opus already
+			content.Info.MimeType = "audio/ogg; codecs=opus"
+		default:
+			return nil, fmt.Errorf("%w %q in audio message", errMediaUnsupportedType, mimeType)
 		}
 		}
-		content.Info.MimeType = "video/mp4"
+	case mediaType == whatsmeow.MediaDocument:
+		// Everything is allowed
 	}
 	}
-	if mediaType == whatsmeow.MediaImage && content.GetInfo().MimeType == "image/webp" {
-		data, err = portal.convertWebPtoPNG(data)
-		if err != nil {
-			return nil, util.NewDualError(fmt.Errorf("%w (webp to png)", errMediaConvertFailed), err)
-		}
-		content.Info.MimeType = "image/png"
+	if convertErr != nil {
+		return nil, util.NewDualError(fmt.Errorf("%w (%s to %s)", errMediaConvertFailed, mimeType, content.Info.MimeType), err)
 	}
 	}
 	uploadResp, err := sender.Client.Upload(ctx, data, mediaType)
 	uploadResp, err := sender.Client.Upload(ctx, data, mediaType)
 	if err != nil {
 	if err != nil {
@@ -2869,7 +2924,7 @@ func (portal *Portal) preprocessMatrixMedia(ctx context.Context, sender *User, r
 	// Audio doesn't have thumbnails
 	// Audio doesn't have thumbnails
 	var thumbnail []byte
 	var thumbnail []byte
 	if mediaType != whatsmeow.MediaAudio {
 	if mediaType != whatsmeow.MediaAudio {
-		thumbnail, err = portal.downloadThumbnail(ctx, data, content.GetInfo().ThumbnailURL, eventID)
+		thumbnail, err = portal.downloadThumbnail(ctx, data, content.GetInfo().ThumbnailURL, eventID, isSticker)
 		// Ignore format errors for non-image files, we don't care about those thumbnails
 		// Ignore format errors for non-image files, we don't care about those thumbnails
 		if err != nil && (!errors.Is(err, image.ErrFormat) || mediaType == whatsmeow.MediaImage) {
 		if err != nil && (!errors.Is(err, image.ErrFormat) || mediaType == whatsmeow.MediaImage) {
 			portal.log.Warnfln("Failed to generate thumbnail for %s: %v", eventID, err)
 			portal.log.Warnfln("Failed to generate thumbnail for %s: %v", eventID, err)
@@ -2996,7 +3051,12 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 		sender = portal.GetRelayUser()
 		sender = portal.GetRelayUser()
 	}
 	}
 	if evt.Type == event.EventSticker {
 	if evt.Type == event.EventSticker {
-		content.MsgType = event.MsgImage
+		if relaybotFormatted {
+			// Stickers can't have captions, so force relaybot stickers to be images
+			content.MsgType = event.MsgImage
+		} else {
+			content.MsgType = event.MessageType(event.EventSticker.Type)
+		}
 	}
 	}
 	if content.MsgType == event.MsgImage && content.GetInfo().MimeType == "image/gif" {
 	if content.MsgType == event.MsgImage && content.GetInfo().MimeType == "image/gif" {
 		content.MsgType = event.MsgVideo
 		content.MsgType = event.MsgVideo
@@ -3044,6 +3104,22 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 			FileSha256:    media.FileSHA256,
 			FileSha256:    media.FileSHA256,
 			FileLength:    proto.Uint64(uint64(media.FileLength)),
 			FileLength:    proto.Uint64(uint64(media.FileLength)),
 		}
 		}
+	case event.MessageType(event.EventSticker.Type):
+		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
+		if media == nil {
+			return nil, sender, err
+		}
+		ctxInfo.MentionedJid = media.MentionedJIDs
+		msg.StickerMessage = &waProto.StickerMessage{
+			ContextInfo:   &ctxInfo,
+			PngThumbnail:  media.Thumbnail,
+			Url:           &media.URL,
+			MediaKey:      media.MediaKey,
+			Mimetype:      &content.GetInfo().MimeType,
+			FileEncSha256: media.FileEncSHA256,
+			FileSha256:    media.FileSHA256,
+			FileLength:    proto.Uint64(uint64(media.FileLength)),
+		}
 	case event.MsgVideo:
 	case event.MsgVideo:
 		gifPlayback := content.GetInfo().MimeType == "image/gif"
 		gifPlayback := content.GetInfo().MimeType == "image/gif"
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaVideo)
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaVideo)

+ 1 - 1
urlpreview.go

@@ -186,7 +186,7 @@ func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *U
 		dest.ThumbnailDirectPath = &uploadResp.DirectPath
 		dest.ThumbnailDirectPath = &uploadResp.DirectPath
 		dest.MediaKey = uploadResp.MediaKey
 		dest.MediaKey = uploadResp.MediaKey
 		var width, height int
 		var width, height int
-		dest.JpegThumbnail, width, height, err = createJPEGThumbnailAndGetSize(data)
+		dest.JpegThumbnail, width, height, err = createThumbnailAndGetSize(data, false)
 		if err != nil {
 		if err != nil {
 			portal.log.Warnfln("Failed to create JPEG thumbnail for URL preview in %s: %v", evt.ID, err)
 			portal.log.Warnfln("Failed to create JPEG thumbnail for URL preview in %s: %v", evt.ID, err)
 		}
 		}