Parcourir la source

Bridge incoming votes in MSC3381 polls

Tulir Asokan il y a 2 ans
Parent
commit
b8dc3c0e56
3 fichiers modifiés avec 85 ajouts et 5 suppressions
  1. 2 0
      main.go
  2. 2 0
      messagetracking.go
  3. 81 5
      portal.go

+ 2 - 0
main.go

@@ -87,6 +87,8 @@ func (br *WABridge) Init() {
 
 	// TODO this is a weird place for this
 	br.EventProcessor.On(event.EphemeralEventPresence, br.HandlePresence)
+	br.EventProcessor.On(TypeMSC3881PollResponse, br.MatrixHandler.HandleMessage)
+	br.EventProcessor.On(TypeMSC3881V2PollResponse, br.MatrixHandler.HandleMessage)
 
 	Segment.log = br.Log.Sub("Segment")
 	Segment.key = br.Config.SegmentKey

+ 2 - 0
messagetracking.go

@@ -185,6 +185,8 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin
 		msgType = "reaction"
 	case event.EventRedaction:
 		msgType = "redaction"
+	case TypeMSC3881PollResponse, TypeMSC3881V2PollResponse:
+		msgType = "poll response"
 	default:
 		msgType = "unknown event"
 	}

+ 81 - 5
portal.go

@@ -33,6 +33,7 @@ import (
 	"math"
 	"mime"
 	"net/http"
+	"reflect"
 	"runtime/debug"
 	"strconv"
 	"strings"
@@ -315,7 +316,7 @@ func (portal *Portal) handleMatrixMessageLoopItem(msg PortalMatrixMessage) {
 	portal.handleMatrixReadReceipt(msg.user, "", evtTS, false)
 	timings.implicitRR = time.Since(implicitRRStart)
 	switch msg.evt.Type {
-	case event.EventMessage, event.EventSticker:
+	case event.EventMessage, event.EventSticker, TypeMSC3881V2PollResponse, TypeMSC3881PollResponse:
 		portal.HandleMatrixMessage(msg.user, msg.evt, timings)
 	case event.EventRedaction:
 		portal.HandleMatrixRedaction(msg.user, msg.evt)
@@ -2121,9 +2122,9 @@ func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, sou
 		selectedHashes[i] = hex.EncodeToString(opt)
 	}
 
-	evtType := event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
+	evtType := TypeMSC3881PollResponse
 	if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
-		evtType.Type = "org.matrix.msc3381.v2.poll.response"
+		evtType = TypeMSC3881V2PollResponse
 	}
 	return &ConvertedMessage{
 		Intent: intent,
@@ -3333,7 +3334,81 @@ func getUnstableWaveform(content map[string]interface{}) []byte {
 	return output
 }
 
+var (
+	TypeMSC3881PollResponse   = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
+	TypeMSC3881V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3881.v2.poll.response"}
+)
+
+type PollResponseContent struct {
+	RelatesTo  event.RelatesTo `json:"m.relates_to"`
+	V1Response struct {
+		Answers []string `json:"answers"`
+	} `json:"org.matrix.msc3381.poll.response"`
+	V2Selections []string `json:"org.matrix.msc3381.v2.selections"`
+}
+
+func (content *PollResponseContent) GetRelatesTo() *event.RelatesTo {
+	return &content.RelatesTo
+}
+
+func (content *PollResponseContent) OptionalGetRelatesTo() *event.RelatesTo {
+	if content.RelatesTo.Type == "" {
+		return nil
+	}
+	return &content.RelatesTo
+}
+
+func (content *PollResponseContent) SetRelatesTo(rel *event.RelatesTo) {
+	content.RelatesTo = *rel
+}
+
+func init() {
+	event.TypeMap[TypeMSC3881PollResponse] = reflect.TypeOf(PollResponseContent{})
+	event.TypeMap[TypeMSC3881V2PollResponse] = reflect.TypeOf(PollResponseContent{})
+}
+
+func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) {
+	content, ok := evt.Content.Parsed.(*PollResponseContent)
+	if !ok {
+		return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
+	}
+	var answers []string
+	if content.V1Response.Answers != nil {
+		answers = content.V1Response.Answers
+	} else if content.V2Selections != nil {
+		answers = content.V2Selections
+	}
+	pollMsg := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID)
+	if pollMsg == nil {
+		return nil, sender, errTargetNotFound
+	}
+	pollMsgInfo := &types.MessageInfo{
+		MessageSource: types.MessageSource{
+			Chat:     portal.Key.JID,
+			Sender:   pollMsg.Sender,
+			IsFromMe: pollMsg.Sender.User == sender.JID.User,
+			IsGroup:  portal.IsGroupChat(),
+		},
+		ID:   pollMsg.JID,
+		Type: "poll",
+	}
+	optionHashes := make([][]byte, 0, len(answers))
+	for _, selection := range answers {
+		hash, _ := hex.DecodeString(selection)
+		if hash != nil && len(hash) == 32 {
+			optionHashes = append(optionHashes, hash)
+		}
+	}
+	pollUpdate, err := sender.Client.EncryptPollVote(pollMsgInfo, &waProto.PollVoteMessage{
+		SelectedOptions: optionHashes,
+	})
+	return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, err
+}
+
 func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) {
+	if evt.Type == TypeMSC3881PollResponse || evt.Type == TypeMSC3881V2PollResponse {
+		return portal.convertMatrixPollVote(ctx, sender, evt)
+	}
 	content, ok := evt.Content.Parsed.(*event.MessageEventContent)
 	if !ok {
 		return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
@@ -3351,7 +3426,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 
 	msg := &waProto.Message{}
 	var ctxInfo waProto.ContextInfo
-	replyToID := content.GetReplyTo()
+	replyToID := content.RelatesTo.GetReplyTo()
 	if len(replyToID) > 0 {
 		replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID)
 		if replyToMsg != nil && !replyToMsg.IsFakeJID() && replyToMsg.Type == database.MsgNormal {
@@ -3570,7 +3645,8 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
 	start := time.Now()
 	ms := metricSender{portal: portal, timings: &timings}
 
-	if err := portal.canBridgeFrom(sender, true); err != nil {
+	allowRelay := evt.Type != TypeMSC3881PollResponse && evt.Type != TypeMSC3881V2PollResponse
+	if err := portal.canBridgeFrom(sender, allowRelay); err != nil {
 		go ms.sendMessageMetrics(evt, err, "Ignoring", true)
 		return
 	} else if portal.Key.JID == types.StatusBroadcastJID && portal.bridge.Config.Bridge.DisableStatusBroadcastSend {