Quellcode durchsuchen

Add support for creating polls from Matrix

Tulir Asokan vor 2 Jahren
Ursprung
Commit
0305680317

+ 4 - 0
CHANGELOG.md

@@ -1,3 +1,7 @@
+# unreleased
+
+* Added support for sending polls from Matrix to WhatsApp.
+
 # v0.8.0 (2022-12-16)
 
 * Added support for bridging polls from WhatsApp and votes in both directions.

+ 1 - 1
ROADMAP.md

@@ -6,7 +6,7 @@
     * [x] Location messages
     * [x] Media/files
     * [x] Replies
-    * [ ] Polls
+    * [x] Polls
     * [x] Poll votes
   * [x] Message redactions
   * [x] Reactions

+ 6 - 5
database/message.go

@@ -134,11 +134,12 @@ const (
 type MessageType string
 
 const (
-	MsgUnknown  MessageType = ""
-	MsgFake     MessageType = "fake"
-	MsgNormal   MessageType = "message"
-	MsgReaction MessageType = "reaction"
-	MsgEdit     MessageType = "edit"
+	MsgUnknown    MessageType = ""
+	MsgFake       MessageType = "fake"
+	MsgNormal     MessageType = "message"
+	MsgReaction   MessageType = "reaction"
+	MsgEdit       MessageType = "edit"
+	MsgMatrixPoll MessageType = "matrix-poll"
 )
 
 type Message struct {

+ 118 - 0
database/polloption.go

@@ -0,0 +1,118 @@
+// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
+// Copyright (C) 2022 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+package database
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/lib/pq"
+	"maunium.net/go/mautrix/util/dbutil"
+)
+
+func scanPollOptionMapping(rows dbutil.Rows) (id string, hashArr [32]byte, err error) {
+	var hash []byte
+	err = rows.Scan(&id, &hash)
+	if err != nil {
+		// return below
+	} else if len(hash) != 32 {
+		err = fmt.Errorf("unexpected hash length %d", len(hash))
+	} else {
+		hashArr = *(*[32]byte)(hash)
+	}
+	return
+}
+
+func (msg *Message) PutPollOptions(opts map[[32]byte]string) {
+	query := "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
+	args := make([]any, len(opts)*2+1)
+	placeholders := make([]string, len(opts))
+	args[0] = msg.MXID
+	i := 0
+	for hash, id := range opts {
+		args[i*2+1] = id
+		hashCopy := hash
+		args[i*2+2] = hashCopy[:]
+		placeholders[i] = fmt.Sprintf("($1, $%d, $%d)", i*2+2, i*2+3)
+		i++
+	}
+	query = strings.ReplaceAll(query, "($1, $2, $3)", strings.Join(placeholders, ","))
+	_, err := msg.db.Exec(query, args...)
+	if err != nil {
+		msg.log.Errorfln("Failed to save poll options for %s: %v", msg.MXID, err)
+	}
+}
+
+func (msg *Message) GetPollOptionIDs(hashes [][]byte) map[[32]byte]string {
+	query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
+	var args []any
+	if msg.db.Dialect == dbutil.Postgres {
+		args = []any{msg.MXID, pq.Array(hashes)}
+	} else {
+		query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(hashes)), ",")))
+		args = make([]any, len(hashes)+1)
+		args[0] = msg.MXID
+		for i, hash := range hashes {
+			args[i+1] = hash
+		}
+	}
+	ids := make(map[[32]byte]string, len(hashes))
+	rows, err := msg.db.Query(query, args...)
+	if err != nil {
+		msg.log.Errorfln("Failed to query poll option IDs for %s: %v", msg.MXID, err)
+	} else {
+		for rows.Next() {
+			id, hash, err := scanPollOptionMapping(rows)
+			if err != nil {
+				msg.log.Errorfln("Failed to scan poll option ID for %s: %v", msg.MXID, err)
+				break
+			}
+			ids[hash] = id
+		}
+	}
+	return ids
+}
+
+func (msg *Message) GetPollOptionHashes(ids []string) map[string][32]byte {
+	query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
+	var args []any
+	if msg.db.Dialect == dbutil.Postgres {
+		args = []any{msg.MXID, pq.Array(ids)}
+	} else {
+		query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")))
+		args = make([]any, len(ids)+1)
+		args[0] = msg.MXID
+		for i, id := range ids {
+			args[i+1] = id
+		}
+	}
+	hashes := make(map[string][32]byte, len(ids))
+	rows, err := msg.db.Query(query, args...)
+	if err != nil {
+		msg.log.Errorfln("Failed to query poll option hashes for %s: %v", msg.MXID, err)
+	} else {
+		for rows.Next() {
+			id, hash, err := scanPollOptionMapping(rows)
+			if err != nil {
+				msg.log.Errorfln("Failed to scan poll option hash for %s: %v", msg.MXID, err)
+				break
+			}
+			hashes[id] = hash
+		}
+	}
+	return hashes
+}

+ 12 - 1
database/upgrades/00-latest-revision.sql

@@ -1,4 +1,4 @@
--- v0 -> v52: Latest revision
+-- v0 -> v54: Latest revision
 
 CREATE TABLE "user" (
     mxid     TEXT PRIMARY KEY,
@@ -40,6 +40,7 @@ CREATE TABLE portal (
 
     PRIMARY KEY (jid, receiver)
 );
+CREATE INDEX portal_parent_group_idx ON portal(parent_group);
 
 CREATE TABLE puppet (
     username     TEXT PRIMARY KEY,
@@ -79,6 +80,16 @@ CREATE TABLE message (
     FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
 );
 
+CREATE TABLE poll_option_id (
+    msg_mxid TEXT,
+    opt_id   TEXT,
+    opt_hash bytea CHECK ( length(opt_hash) = 32 ),
+
+    PRIMARY KEY (msg_mxid, opt_id),
+    CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash),
+    CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE
+);
+
 CREATE TABLE reaction (
     chat_jid      TEXT,
     chat_receiver TEXT,

+ 11 - 0
database/upgrades/54-poll-option-id-map.sql

@@ -0,0 +1,11 @@
+-- v54: Store mapping for poll option IDs from Matrix
+
+CREATE TABLE poll_option_id (
+    msg_mxid TEXT,
+    opt_id   TEXT,
+    opt_hash bytea CHECK ( length(opt_hash) = 32 ),
+
+    PRIMARY KEY (msg_mxid, opt_id),
+    CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash),
+    CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE
+);

+ 11 - 3
formatting.go

@@ -35,7 +35,8 @@ var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$)
 var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
 var inlineURLRegex = regexp.MustCompile(`\[(.+?)]\((.+?)\)`)
 
-const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids"
+const mentionedJIDsContextKey = "fi.mau.whatsapp.mentioned_jids"
+const disableMentionsContextKey = "fi.mau.whatsapp.no_mentions"
 
 type Formatter struct {
 	bridge *WABridge
@@ -55,7 +56,8 @@ func NewFormatter(bridge *WABridge) *Formatter {
 			Newline:      "\n",
 
 			PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string {
-				if mxid[0] == '@' {
+				_, disableMentions := ctx[disableMentionsContextKey]
+				if mxid[0] == '@' && !disableMentions {
 					puppet := bridge.GetPuppetByMXID(id.UserID(mxid))
 					if puppet != nil {
 						jids, ok := ctx[mentionedJIDsContextKey].([]string)
@@ -67,7 +69,7 @@ func NewFormatter(bridge *WABridge) *Formatter {
 						return "@" + puppet.JID.User
 					}
 				}
-				return mxid
+				return displayname
 			},
 			BoldConverter:           func(text string, _ format.Context) string { return fmt.Sprintf("*%s*", text) },
 			ItalicConverter:         func(text string, _ format.Context) string { return fmt.Sprintf("_%s_", text) },
@@ -151,3 +153,9 @@ func (formatter *Formatter) ParseMatrix(html string) (string, []string) {
 	mentionedJIDs, _ := ctx[mentionedJIDsContextKey].([]string)
 	return result, mentionedJIDs
 }
+
+func (formatter *Formatter) ParseMatrixWithoutMentions(html string) string {
+	ctx := make(format.Context)
+	ctx[disableMentionsContextKey] = true
+	return formatter.matrixHTMLParser.Parse(html, ctx)
+}

+ 3 - 2
main.go

@@ -87,8 +87,9 @@ func (br *WABridge) Init() {
 
 	// TODO this is a weird place for this
 	br.EventProcessor.On(event.EphemeralEventPresence, br.HandlePresence)
-	br.EventProcessor.On(TypeMSC3881PollResponse, br.MatrixHandler.HandleMessage)
-	br.EventProcessor.On(TypeMSC3881V2PollResponse, br.MatrixHandler.HandleMessage)
+	br.EventProcessor.On(TypeMSC3381PollStart, br.MatrixHandler.HandleMessage)
+	br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage)
+	br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage)
 
 	Segment.log = br.Log.Sub("Segment")
 	Segment.key = br.Config.SegmentKey

+ 6 - 2
messagetracking.go

@@ -52,6 +52,8 @@ var (
 	errTargetIsFake                = errors.New("target is a fake event")
 	errReactionSentBySomeoneElse   = errors.New("target reaction was sent by someone else")
 	errDMSentByOtherUser           = errors.New("target message was sent by the other user in a DM")
+	errPollMissingQuestion         = errors.New("poll message is missing question")
+	errPollDuplicateOption         = errors.New("poll options must be unique")
 
 	errBroadcastReactionNotSupported = errors.New("reacting to status messages is not currently supported")
 	errBroadcastSendDisabled         = errors.New("sending status messages is disabled")
@@ -76,7 +78,7 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev
 		return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, ""
 	case errors.Is(err, errMNoticeDisabled):
 		return event.MessageStatusUnsupported, event.MessageStatusFail, true, false, ""
-	case errors.Is(err, errMediaUnsupportedType):
+	case errors.Is(err, errMediaUnsupportedType), errors.Is(err, errPollMissingQuestion), errors.Is(err, errPollDuplicateOption):
 		return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, err.Error()
 	case errors.Is(err, errTimeoutBeforeHandling):
 		return event.MessageStatusTooOld, event.MessageStatusRetriable, true, true, "the message was too old when it reached the bridge, so it was not handled"
@@ -185,8 +187,10 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin
 		msgType = "reaction"
 	case event.EventRedaction:
 		msgType = "redaction"
-	case TypeMSC3881PollResponse, TypeMSC3881V2PollResponse:
+	case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
 		msgType = "poll response"
+	case TypeMSC3381PollStart:
+		msgType = "poll start"
 	default:
 		msgType = "unknown event"
 	}

+ 205 - 62
portal.go

@@ -19,6 +19,7 @@ package main
 import (
 	"bytes"
 	"context"
+	"crypto/rand"
 	"crypto/sha256"
 	"encoding/hex"
 	"encoding/json"
@@ -322,7 +323,7 @@ func (portal *Portal) handleMatrixMessageLoopItem(msg PortalMatrixMessage) {
 	portal.handleMatrixReadReceipt(msg.user, "", evtTS, false)
 	timings.implicitRR = time.Since(implicitRRStart)
 	switch msg.evt.Type {
-	case event.EventMessage, event.EventSticker, TypeMSC3881V2PollResponse, TypeMSC3881PollResponse:
+	case event.EventMessage, event.EventSticker, TypeMSC3381V2PollResponse, TypeMSC3381PollResponse, TypeMSC3381PollStart:
 		portal.HandleMatrixMessage(msg.user, msg.evt, timings)
 	case event.EventRedaction:
 		portal.HandleMatrixRedaction(msg.user, msg.evt)
@@ -2267,9 +2268,6 @@ func (portal *Portal) convertListResponseMessage(intent *appservice.IntentAPI, m
 }
 
 func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg *waProto.PollUpdateMessage) *ConvertedMessage {
-	if !portal.bridge.Config.Bridge.ExtEvPolls {
-		return nil
-	}
 	pollMessage := portal.bridge.DB.Message.GetByJID(portal.Key, msg.GetPollCreationMessageKey().GetId())
 	if pollMessage == nil {
 		portal.log.Warnfln("Failed to convert vote message %s: poll message %s not found", info.ID, msg.GetPollCreationMessageKey().GetId())
@@ -2284,13 +2282,28 @@ func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, sou
 		return nil
 	}
 	selectedHashes := make([]string, len(vote.GetSelectedOptions()))
-	for i, opt := range vote.GetSelectedOptions() {
-		selectedHashes[i] = hex.EncodeToString(opt)
+	if pollMessage.Type == database.MsgMatrixPoll {
+		mappedAnswers := pollMessage.GetPollOptionIDs(vote.GetSelectedOptions())
+		for i, opt := range vote.GetSelectedOptions() {
+			if len(opt) != 32 {
+				portal.log.Warnfln("Unexpected option hash length %d in %s's vote to %s", len(opt), info.Sender, pollMessage.MXID)
+				continue
+			}
+			var ok bool
+			selectedHashes[i], ok = mappedAnswers[*(*[32]byte)(opt)]
+			if !ok {
+				portal.log.Warnfln("Didn't find ID for option %X in %s's vote to %s", opt, info.Sender, pollMessage.MXID)
+			}
+		}
+	} else {
+		for i, opt := range vote.GetSelectedOptions() {
+			selectedHashes[i] = hex.EncodeToString(opt)
+		}
 	}
 
-	evtType := TypeMSC3881PollResponse
+	evtType := TypeMSC3381PollResponse
 	//if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
-	//	evtType = TypeMSC3881V2PollResponse
+	//	evtType = TypeMSC3381V2PollResponse
 	//}
 	return &ConvertedMessage{
 		Intent: intent,
@@ -2341,7 +2354,7 @@ func (portal *Portal) convertPollCreationMessage(intent *appservice.IntentAPI, m
 	}
 	evtType := event.EventMessage
 	if portal.bridge.Config.Bridge.ExtEvPolls {
-		evtType.Type = "org.matrix.msc3381.poll.start"
+		evtType = TypeMSC3381PollStart
 	}
 	//else if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
 	//	evtType.Type = "org.matrix.msc3381.v2.poll.start"
@@ -3505,8 +3518,9 @@ func getUnstableWaveform(content map[string]interface{}) []byte {
 }
 
 var (
-	TypeMSC3881PollResponse   = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
-	TypeMSC3881V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3881.v2.poll.response"}
+	TypeMSC3381PollStart      = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.start"}
+	TypeMSC3381PollResponse   = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
+	TypeMSC3381V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.v2.poll.response"}
 )
 
 type PollResponseContent struct {
@@ -3532,15 +3546,71 @@ func (content *PollResponseContent) SetRelatesTo(rel *event.RelatesTo) {
 	content.RelatesTo = *rel
 }
 
+type MSC1767Message struct {
+	Text    string `json:"org.matrix.msc1767.text,omitempty"`
+	HTML    string `json:"org.matrix.msc1767.html,omitempty"`
+	Message []struct {
+		MimeType string `json:"mimetype"`
+		Body     string `json:"body"`
+	} `json:"org.matrix.msc1767.message,omitempty"`
+}
+
+func (portal *Portal) msc1767ToWhatsApp(msg MSC1767Message, mentions bool) (string, []string) {
+	for _, part := range msg.Message {
+		if part.MimeType == "text/html" && msg.HTML == "" {
+			msg.HTML = part.Body
+		} else if part.MimeType == "text/plain" && msg.Text == "" {
+			msg.Text = part.Body
+		}
+	}
+	if msg.HTML != "" {
+		if mentions {
+			return portal.bridge.Formatter.ParseMatrix(msg.HTML)
+		} else {
+			return portal.bridge.Formatter.ParseMatrixWithoutMentions(msg.HTML), nil
+		}
+	}
+	return msg.Text, nil
+}
+
+type PollStartContent struct {
+	RelatesTo *event.RelatesTo `json:"m.relates_to"`
+	PollStart struct {
+		Kind          string         `json:"kind"`
+		MaxSelections int            `json:"max_selections"`
+		Question      MSC1767Message `json:"question"`
+		Answers       []struct {
+			ID string `json:"id"`
+			MSC1767Message
+		} `json:"answers"`
+	} `json:"org.matrix.msc3381.poll.start"`
+}
+
+func (content *PollStartContent) GetRelatesTo() *event.RelatesTo {
+	if content.RelatesTo == nil {
+		content.RelatesTo = &event.RelatesTo{}
+	}
+	return content.RelatesTo
+}
+
+func (content *PollStartContent) OptionalGetRelatesTo() *event.RelatesTo {
+	return content.RelatesTo
+}
+
+func (content *PollStartContent) SetRelatesTo(rel *event.RelatesTo) {
+	content.RelatesTo = rel
+}
+
 func init() {
-	event.TypeMap[TypeMSC3881PollResponse] = reflect.TypeOf(PollResponseContent{})
-	event.TypeMap[TypeMSC3881V2PollResponse] = reflect.TypeOf(PollResponseContent{})
+	event.TypeMap[TypeMSC3381PollResponse] = reflect.TypeOf(PollResponseContent{})
+	event.TypeMap[TypeMSC3381V2PollResponse] = reflect.TypeOf(PollResponseContent{})
+	event.TypeMap[TypeMSC3381PollStart] = reflect.TypeOf(PollStartContent{})
 }
 
-func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) {
+func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
 	content, ok := evt.Content.Parsed.(*PollResponseContent)
 	if !ok {
-		return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
+		return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
 	}
 	var answers []string
 	if content.V1Response.Answers != nil {
@@ -3550,7 +3620,7 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt
 	}
 	pollMsg := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID)
 	if pollMsg == nil {
-		return nil, sender, errTargetNotFound
+		return nil, sender, nil, errTargetNotFound
 	}
 	pollMsgInfo := &types.MessageInfo{
 		MessageSource: types.MessageSource{
@@ -3563,43 +3633,81 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt
 		Type: "poll",
 	}
 	optionHashes := make([][]byte, 0, len(answers))
-	for _, selection := range answers {
-		hash, _ := hex.DecodeString(selection)
-		if hash != nil && len(hash) == 32 {
-			optionHashes = append(optionHashes, hash)
+	if pollMsg.Type == database.MsgMatrixPoll {
+		mappedAnswers := pollMsg.GetPollOptionHashes(answers)
+		for _, selection := range answers {
+			hash, ok := mappedAnswers[selection]
+			if ok {
+				optionHashes = append(optionHashes, hash[:])
+			} else {
+				portal.log.Warnfln("Didn't find hash for option %s in %s's vote to %s", selection, evt.Sender, pollMsg.MXID)
+			}
+		}
+	} else {
+		for _, selection := range answers {
+			hash, _ := hex.DecodeString(selection)
+			if hash != nil && len(hash) == 32 {
+				optionHashes = append(optionHashes, hash)
+			}
 		}
 	}
 	pollUpdate, err := sender.Client.EncryptPollVote(pollMsgInfo, &waProto.PollVoteMessage{
 		SelectedOptions: optionHashes,
 	})
-	return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, err
+	return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, nil, err
 }
 
-func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) {
-	if evt.Type == TypeMSC3881PollResponse || evt.Type == TypeMSC3881V2PollResponse {
-		return portal.convertMatrixPollVote(ctx, sender, evt)
-	}
-	content, ok := evt.Content.Parsed.(*event.MessageEventContent)
+func (portal *Portal) convertMatrixPollStart(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
+	content, ok := evt.Content.Parsed.(*PollStartContent)
 	if !ok {
-		return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
-	}
-	var editRootMsg *database.Message
-	if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits {
-		editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID)
-		if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User {
-			return nil, sender, fmt.Errorf("edit rejected") // TODO more specific error message
-		}
-		if content.NewContent != nil {
-			content = content.NewContent
-		}
-	}
+		return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
+	}
+	maxAnswers := content.PollStart.MaxSelections
+	if maxAnswers >= len(content.PollStart.Answers) || maxAnswers < 0 {
+		maxAnswers = 0
+	}
+	fmt.Printf("%+v\n", content.PollStart)
+	ctxInfo := portal.generateContextInfo(content.RelatesTo)
+	var question string
+	question, ctxInfo.MentionedJid = portal.msc1767ToWhatsApp(content.PollStart.Question, true)
+	if len(question) == 0 {
+		return nil, sender, nil, errPollMissingQuestion
+	}
+	options := make([]*waProto.PollCreationMessage_Option, len(content.PollStart.Answers))
+	optionMap := make(map[[32]byte]string, len(options))
+	for i, opt := range content.PollStart.Answers {
+		body, _ := portal.msc1767ToWhatsApp(opt.MSC1767Message, false)
+		hash := sha256.Sum256([]byte(body))
+		if _, alreadyExists := optionMap[hash]; alreadyExists {
+			portal.log.Warnfln("Poll %s by %s has option %q more than once, rejecting", evt.ID, evt.Sender, body)
+			return nil, sender, nil, errPollDuplicateOption
+		}
+		optionMap[hash] = opt.ID
+		options[i] = &waProto.PollCreationMessage_Option{
+			OptionName: proto.String(body),
+		}
+	}
+	secret := make([]byte, 32)
+	_, err := rand.Read(secret)
+	return &waProto.Message{
+		PollCreationMessage: &waProto.PollCreationMessage{
+			Name:                   proto.String(question),
+			Options:                options,
+			SelectableOptionsCount: proto.Uint32(uint32(maxAnswers)),
+			ContextInfo:            ctxInfo,
+		},
+		MessageContextInfo: &waProto.MessageContextInfo{
+			MessageSecret: secret,
+		},
+	}, sender, &extraConvertMeta{PollOptions: optionMap}, err
+}
 
-	msg := &waProto.Message{}
+func (portal *Portal) generateContextInfo(relatesTo *event.RelatesTo) *waProto.ContextInfo {
 	var ctxInfo waProto.ContextInfo
-	replyToID := content.RelatesTo.GetReplyTo()
+	replyToID := relatesTo.GetReplyTo()
 	if len(replyToID) > 0 {
 		replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID)
-		if replyToMsg != nil && !replyToMsg.IsFakeJID() && replyToMsg.Type == database.MsgNormal {
+		if replyToMsg != nil && !replyToMsg.IsFakeJID() && (replyToMsg.Type == database.MsgNormal || replyToMsg.Type == database.MsgMatrixPoll) {
 			ctxInfo.StanzaId = &replyToMsg.JID
 			ctxInfo.Participant = proto.String(replyToMsg.Sender.ToNonAD().String())
 			// Using blank content here seems to work fine on all official WhatsApp apps.
@@ -3613,10 +3721,40 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	if portal.ExpirationTime != 0 {
 		ctxInfo.Expiration = proto.Uint32(portal.ExpirationTime)
 	}
+	return &ctxInfo
+}
+
+type extraConvertMeta struct {
+	PollOptions map[[32]byte]string
+}
+
+func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
+	if evt.Type == TypeMSC3381PollResponse || evt.Type == TypeMSC3381V2PollResponse {
+		return portal.convertMatrixPollVote(ctx, sender, evt)
+	} else if evt.Type == TypeMSC3381PollStart {
+		return portal.convertMatrixPollStart(ctx, sender, evt)
+	}
+	content, ok := evt.Content.Parsed.(*event.MessageEventContent)
+	if !ok {
+		return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
+	}
+	var editRootMsg *database.Message
+	if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits {
+		editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID)
+		if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User {
+			return nil, sender, nil, fmt.Errorf("edit rejected") // TODO more specific error message
+		}
+		if content.NewContent != nil {
+			content = content.NewContent
+		}
+	}
+
+	msg := &waProto.Message{}
+	ctxInfo := portal.generateContextInfo(content.RelatesTo)
 	relaybotFormatted := false
 	if !sender.IsLoggedIn() || (portal.IsPrivateChat() && sender.JID.User != portal.Key.Receiver.User) {
 		if !portal.HasRelaybot() {
-			return nil, sender, errUserNotLoggedIn
+			return nil, sender, nil, errUserNotLoggedIn
 		}
 		relaybotFormatted = portal.addRelaybotFormat(sender, content)
 		sender = portal.GetRelayUser()
@@ -3637,7 +3775,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MsgText, event.MsgEmote, event.MsgNotice:
 		text := content.Body
 		if content.MsgType == event.MsgNotice && !portal.bridge.Config.Bridge.BridgeNotices {
-			return nil, sender, errMNoticeDisabled
+			return nil, sender, nil, errMNoticeDisabled
 		}
 		if content.Format == event.FormatHTML {
 			text, ctxInfo.MentionedJid = portal.bridge.Formatter.ParseMatrix(content.FormattedBody)
@@ -3647,11 +3785,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 		}
 		msg.ExtendedTextMessage = &waProto.ExtendedTextMessage{
 			Text:        &text,
-			ContextInfo: &ctxInfo,
+			ContextInfo: ctxInfo,
 		}
 		hasPreview := portal.convertURLPreviewToWhatsApp(ctx, sender, evt, msg.ExtendedTextMessage)
 		if ctx.Err() != nil {
-			return nil, nil, ctx.Err()
+			return nil, nil, nil, ctx.Err()
 		}
 		if ctxInfo.StanzaId == nil && ctxInfo.MentionedJid == nil && ctxInfo.Expiration == nil && !hasPreview {
 			// No need for extended message
@@ -3661,11 +3799,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MsgImage:
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
 		if media == nil {
-			return nil, sender, err
+			return nil, sender, nil, err
 		}
 		ctxInfo.MentionedJid = media.MentionedJIDs
 		msg.ImageMessage = &waProto.ImageMessage{
-			ContextInfo:   &ctxInfo,
+			ContextInfo:   ctxInfo,
 			Caption:       &media.Caption,
 			JpegThumbnail: media.Thumbnail,
 			Url:           &media.URL,
@@ -3678,11 +3816,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MessageType(event.EventSticker.Type):
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
 		if media == nil {
-			return nil, sender, err
+			return nil, sender, nil, err
 		}
 		ctxInfo.MentionedJid = media.MentionedJIDs
 		msg.StickerMessage = &waProto.StickerMessage{
-			ContextInfo:   &ctxInfo,
+			ContextInfo:   ctxInfo,
 			PngThumbnail:  media.Thumbnail,
 			Url:           &media.URL,
 			MediaKey:      media.MediaKey,
@@ -3695,12 +3833,12 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 		gifPlayback := content.GetInfo().MimeType == "image/gif"
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaVideo)
 		if media == nil {
-			return nil, sender, err
+			return nil, sender, nil, err
 		}
 		duration := uint32(content.GetInfo().Duration / 1000)
 		ctxInfo.MentionedJid = media.MentionedJIDs
 		msg.VideoMessage = &waProto.VideoMessage{
-			ContextInfo:   &ctxInfo,
+			ContextInfo:   ctxInfo,
 			Caption:       &media.Caption,
 			JpegThumbnail: media.Thumbnail,
 			Url:           &media.URL,
@@ -3715,11 +3853,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MsgAudio:
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaAudio)
 		if media == nil {
-			return nil, sender, err
+			return nil, sender, nil, err
 		}
 		duration := uint32(content.GetInfo().Duration / 1000)
 		msg.AudioMessage = &waProto.AudioMessage{
-			ContextInfo:   &ctxInfo,
+			ContextInfo:   ctxInfo,
 			Url:           &media.URL,
 			MediaKey:      media.MediaKey,
 			Mimetype:      &content.GetInfo().MimeType,
@@ -3738,10 +3876,10 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MsgFile:
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaDocument)
 		if media == nil {
-			return nil, sender, err
+			return nil, sender, nil, err
 		}
 		msg.DocumentMessage = &waProto.DocumentMessage{
-			ContextInfo:   &ctxInfo,
+			ContextInfo:   ctxInfo,
 			Caption:       &media.Caption,
 			JpegThumbnail: media.Thumbnail,
 			Url:           &media.URL,
@@ -3764,16 +3902,16 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MsgLocation:
 		lat, long, err := parseGeoURI(content.GeoURI)
 		if err != nil {
-			return nil, sender, fmt.Errorf("%w: %v", errInvalidGeoURI, err)
+			return nil, sender, nil, fmt.Errorf("%w: %v", errInvalidGeoURI, err)
 		}
 		msg.LocationMessage = &waProto.LocationMessage{
 			DegreesLatitude:  &lat,
 			DegreesLongitude: &long,
 			Comment:          &content.Body,
-			ContextInfo:      &ctxInfo,
+			ContextInfo:      ctxInfo,
 		}
 	default:
-		return nil, sender, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType)
+		return nil, sender, nil, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType)
 	}
 
 	if editRootMsg != nil {
@@ -3795,7 +3933,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 		}
 	}
 
-	return msg, sender, nil
+	return msg, sender, nil, nil
 }
 
 func (portal *Portal) generateMessageInfo(sender *User) *types.MessageInfo {
@@ -3815,7 +3953,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
 	start := time.Now()
 	ms := metricSender{portal: portal, timings: &timings}
 
-	allowRelay := evt.Type != TypeMSC3881PollResponse && evt.Type != TypeMSC3881V2PollResponse
+	allowRelay := evt.Type != TypeMSC3381PollResponse && evt.Type != TypeMSC3381V2PollResponse && evt.Type != TypeMSC3381PollStart
 	if err := portal.canBridgeFrom(sender, allowRelay); err != nil {
 		go ms.sendMessageMetrics(evt, err, "Ignoring", true)
 		return
@@ -3875,14 +4013,16 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
 
 	timings.preproc = time.Since(start)
 	start = time.Now()
-	msg, sender, err := portal.convertMatrixMessage(ctx, sender, evt)
+	msg, sender, extraMeta, err := portal.convertMatrixMessage(ctx, sender, evt)
 	timings.convert = time.Since(start)
 	if msg == nil {
 		go ms.sendMessageMetrics(evt, err, "Error converting", true)
 		return
 	}
 	dbMsgType := database.MsgNormal
-	if msg.EditedMessage == nil {
+	if msg.PollCreationMessage != nil || msg.PollCreationMessageV2 != nil {
+		dbMsgType = database.MsgMatrixPoll
+	} else if msg.EditedMessage == nil {
 		portal.MarkDisappearing(nil, origEvtID, portal.ExpirationTime, true)
 	} else {
 		dbMsgType = database.MsgEdit
@@ -3893,6 +4033,9 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
 	} else {
 		info.ID = dbMsg.JID
 	}
+	if dbMsgType == database.MsgMatrixPoll && extraMeta != nil && extraMeta.PollOptions != nil {
+		dbMsg.PutPollOptions(extraMeta.PollOptions)
+	}
 	portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
 	start = time.Now()
 	resp, err := sender.Client.SendMessage(ctx, portal.Key.JID, info.ID, msg)