Browse Source

Add support for creating polls from Matrix

Tulir Asokan 2 năm trước cách đây
mục cha
commit
0305680317

+ 4 - 0
CHANGELOG.md

@@ -1,3 +1,7 @@
+# unreleased
+
+* Added support for sending polls from Matrix to WhatsApp.
+
 # v0.8.0 (2022-12-16)
 # v0.8.0 (2022-12-16)
 
 
 * Added support for bridging polls from WhatsApp and votes in both directions.
 * Added support for bridging polls from WhatsApp and votes in both directions.

+ 1 - 1
ROADMAP.md

@@ -6,7 +6,7 @@
     * [x] Location messages
     * [x] Location messages
     * [x] Media/files
     * [x] Media/files
     * [x] Replies
     * [x] Replies
-    * [ ] Polls
+    * [x] Polls
     * [x] Poll votes
     * [x] Poll votes
   * [x] Message redactions
   * [x] Message redactions
   * [x] Reactions
   * [x] Reactions

+ 6 - 5
database/message.go

@@ -134,11 +134,12 @@ const (
 type MessageType string
 type MessageType string
 
 
 const (
 const (
-	MsgUnknown  MessageType = ""
-	MsgFake     MessageType = "fake"
-	MsgNormal   MessageType = "message"
-	MsgReaction MessageType = "reaction"
-	MsgEdit     MessageType = "edit"
+	MsgUnknown    MessageType = ""
+	MsgFake       MessageType = "fake"
+	MsgNormal     MessageType = "message"
+	MsgReaction   MessageType = "reaction"
+	MsgEdit       MessageType = "edit"
+	MsgMatrixPoll MessageType = "matrix-poll"
 )
 )
 
 
 type Message struct {
 type Message struct {

+ 118 - 0
database/polloption.go

@@ -0,0 +1,118 @@
+// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
+// Copyright (C) 2022 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+package database
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/lib/pq"
+	"maunium.net/go/mautrix/util/dbutil"
+)
+
+func scanPollOptionMapping(rows dbutil.Rows) (id string, hashArr [32]byte, err error) {
+	var hash []byte
+	err = rows.Scan(&id, &hash)
+	if err != nil {
+		// return below
+	} else if len(hash) != 32 {
+		err = fmt.Errorf("unexpected hash length %d", len(hash))
+	} else {
+		hashArr = *(*[32]byte)(hash)
+	}
+	return
+}
+
+func (msg *Message) PutPollOptions(opts map[[32]byte]string) {
+	query := "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
+	args := make([]any, len(opts)*2+1)
+	placeholders := make([]string, len(opts))
+	args[0] = msg.MXID
+	i := 0
+	for hash, id := range opts {
+		args[i*2+1] = id
+		hashCopy := hash
+		args[i*2+2] = hashCopy[:]
+		placeholders[i] = fmt.Sprintf("($1, $%d, $%d)", i*2+2, i*2+3)
+		i++
+	}
+	query = strings.ReplaceAll(query, "($1, $2, $3)", strings.Join(placeholders, ","))
+	_, err := msg.db.Exec(query, args...)
+	if err != nil {
+		msg.log.Errorfln("Failed to save poll options for %s: %v", msg.MXID, err)
+	}
+}
+
+func (msg *Message) GetPollOptionIDs(hashes [][]byte) map[[32]byte]string {
+	query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
+	var args []any
+	if msg.db.Dialect == dbutil.Postgres {
+		args = []any{msg.MXID, pq.Array(hashes)}
+	} else {
+		query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(hashes)), ",")))
+		args = make([]any, len(hashes)+1)
+		args[0] = msg.MXID
+		for i, hash := range hashes {
+			args[i+1] = hash
+		}
+	}
+	ids := make(map[[32]byte]string, len(hashes))
+	rows, err := msg.db.Query(query, args...)
+	if err != nil {
+		msg.log.Errorfln("Failed to query poll option IDs for %s: %v", msg.MXID, err)
+	} else {
+		for rows.Next() {
+			id, hash, err := scanPollOptionMapping(rows)
+			if err != nil {
+				msg.log.Errorfln("Failed to scan poll option ID for %s: %v", msg.MXID, err)
+				break
+			}
+			ids[hash] = id
+		}
+	}
+	return ids
+}
+
+func (msg *Message) GetPollOptionHashes(ids []string) map[string][32]byte {
+	query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
+	var args []any
+	if msg.db.Dialect == dbutil.Postgres {
+		args = []any{msg.MXID, pq.Array(ids)}
+	} else {
+		query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")))
+		args = make([]any, len(ids)+1)
+		args[0] = msg.MXID
+		for i, id := range ids {
+			args[i+1] = id
+		}
+	}
+	hashes := make(map[string][32]byte, len(ids))
+	rows, err := msg.db.Query(query, args...)
+	if err != nil {
+		msg.log.Errorfln("Failed to query poll option hashes for %s: %v", msg.MXID, err)
+	} else {
+		for rows.Next() {
+			id, hash, err := scanPollOptionMapping(rows)
+			if err != nil {
+				msg.log.Errorfln("Failed to scan poll option hash for %s: %v", msg.MXID, err)
+				break
+			}
+			hashes[id] = hash
+		}
+	}
+	return hashes
+}

+ 12 - 1
database/upgrades/00-latest-revision.sql

@@ -1,4 +1,4 @@
--- v0 -> v52: Latest revision
+-- v0 -> v54: Latest revision
 
 
 CREATE TABLE "user" (
 CREATE TABLE "user" (
     mxid     TEXT PRIMARY KEY,
     mxid     TEXT PRIMARY KEY,
@@ -40,6 +40,7 @@ CREATE TABLE portal (
 
 
     PRIMARY KEY (jid, receiver)
     PRIMARY KEY (jid, receiver)
 );
 );
+CREATE INDEX portal_parent_group_idx ON portal(parent_group);
 
 
 CREATE TABLE puppet (
 CREATE TABLE puppet (
     username     TEXT PRIMARY KEY,
     username     TEXT PRIMARY KEY,
@@ -79,6 +80,16 @@ CREATE TABLE message (
     FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
     FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
 );
 );
 
 
+CREATE TABLE poll_option_id (
+    msg_mxid TEXT,
+    opt_id   TEXT,
+    opt_hash bytea CHECK ( length(opt_hash) = 32 ),
+
+    PRIMARY KEY (msg_mxid, opt_id),
+    CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash),
+    CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE
+);
+
 CREATE TABLE reaction (
 CREATE TABLE reaction (
     chat_jid      TEXT,
     chat_jid      TEXT,
     chat_receiver TEXT,
     chat_receiver TEXT,

+ 11 - 0
database/upgrades/54-poll-option-id-map.sql

@@ -0,0 +1,11 @@
+-- v54: Store mapping for poll option IDs from Matrix
+
+CREATE TABLE poll_option_id (
+    msg_mxid TEXT,
+    opt_id   TEXT,
+    opt_hash bytea CHECK ( length(opt_hash) = 32 ),
+
+    PRIMARY KEY (msg_mxid, opt_id),
+    CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash),
+    CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE
+);

+ 11 - 3
formatting.go

@@ -35,7 +35,8 @@ var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$)
 var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
 var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
 var inlineURLRegex = regexp.MustCompile(`\[(.+?)]\((.+?)\)`)
 var inlineURLRegex = regexp.MustCompile(`\[(.+?)]\((.+?)\)`)
 
 
-const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids"
+const mentionedJIDsContextKey = "fi.mau.whatsapp.mentioned_jids"
+const disableMentionsContextKey = "fi.mau.whatsapp.no_mentions"
 
 
 type Formatter struct {
 type Formatter struct {
 	bridge *WABridge
 	bridge *WABridge
@@ -55,7 +56,8 @@ func NewFormatter(bridge *WABridge) *Formatter {
 			Newline:      "\n",
 			Newline:      "\n",
 
 
 			PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string {
 			PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string {
-				if mxid[0] == '@' {
+				_, disableMentions := ctx[disableMentionsContextKey]
+				if mxid[0] == '@' && !disableMentions {
 					puppet := bridge.GetPuppetByMXID(id.UserID(mxid))
 					puppet := bridge.GetPuppetByMXID(id.UserID(mxid))
 					if puppet != nil {
 					if puppet != nil {
 						jids, ok := ctx[mentionedJIDsContextKey].([]string)
 						jids, ok := ctx[mentionedJIDsContextKey].([]string)
@@ -67,7 +69,7 @@ func NewFormatter(bridge *WABridge) *Formatter {
 						return "@" + puppet.JID.User
 						return "@" + puppet.JID.User
 					}
 					}
 				}
 				}
-				return mxid
+				return displayname
 			},
 			},
 			BoldConverter:           func(text string, _ format.Context) string { return fmt.Sprintf("*%s*", text) },
 			BoldConverter:           func(text string, _ format.Context) string { return fmt.Sprintf("*%s*", text) },
 			ItalicConverter:         func(text string, _ format.Context) string { return fmt.Sprintf("_%s_", text) },
 			ItalicConverter:         func(text string, _ format.Context) string { return fmt.Sprintf("_%s_", text) },
@@ -151,3 +153,9 @@ func (formatter *Formatter) ParseMatrix(html string) (string, []string) {
 	mentionedJIDs, _ := ctx[mentionedJIDsContextKey].([]string)
 	mentionedJIDs, _ := ctx[mentionedJIDsContextKey].([]string)
 	return result, mentionedJIDs
 	return result, mentionedJIDs
 }
 }
+
+func (formatter *Formatter) ParseMatrixWithoutMentions(html string) string {
+	ctx := make(format.Context)
+	ctx[disableMentionsContextKey] = true
+	return formatter.matrixHTMLParser.Parse(html, ctx)
+}

+ 3 - 2
main.go

@@ -87,8 +87,9 @@ func (br *WABridge) Init() {
 
 
 	// TODO this is a weird place for this
 	// TODO this is a weird place for this
 	br.EventProcessor.On(event.EphemeralEventPresence, br.HandlePresence)
 	br.EventProcessor.On(event.EphemeralEventPresence, br.HandlePresence)
-	br.EventProcessor.On(TypeMSC3881PollResponse, br.MatrixHandler.HandleMessage)
-	br.EventProcessor.On(TypeMSC3881V2PollResponse, br.MatrixHandler.HandleMessage)
+	br.EventProcessor.On(TypeMSC3381PollStart, br.MatrixHandler.HandleMessage)
+	br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage)
+	br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage)
 
 
 	Segment.log = br.Log.Sub("Segment")
 	Segment.log = br.Log.Sub("Segment")
 	Segment.key = br.Config.SegmentKey
 	Segment.key = br.Config.SegmentKey

+ 6 - 2
messagetracking.go

@@ -52,6 +52,8 @@ var (
 	errTargetIsFake                = errors.New("target is a fake event")
 	errTargetIsFake                = errors.New("target is a fake event")
 	errReactionSentBySomeoneElse   = errors.New("target reaction was sent by someone else")
 	errReactionSentBySomeoneElse   = errors.New("target reaction was sent by someone else")
 	errDMSentByOtherUser           = errors.New("target message was sent by the other user in a DM")
 	errDMSentByOtherUser           = errors.New("target message was sent by the other user in a DM")
+	errPollMissingQuestion         = errors.New("poll message is missing question")
+	errPollDuplicateOption         = errors.New("poll options must be unique")
 
 
 	errBroadcastReactionNotSupported = errors.New("reacting to status messages is not currently supported")
 	errBroadcastReactionNotSupported = errors.New("reacting to status messages is not currently supported")
 	errBroadcastSendDisabled         = errors.New("sending status messages is disabled")
 	errBroadcastSendDisabled         = errors.New("sending status messages is disabled")
@@ -76,7 +78,7 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev
 		return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, ""
 		return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, ""
 	case errors.Is(err, errMNoticeDisabled):
 	case errors.Is(err, errMNoticeDisabled):
 		return event.MessageStatusUnsupported, event.MessageStatusFail, true, false, ""
 		return event.MessageStatusUnsupported, event.MessageStatusFail, true, false, ""
-	case errors.Is(err, errMediaUnsupportedType):
+	case errors.Is(err, errMediaUnsupportedType), errors.Is(err, errPollMissingQuestion), errors.Is(err, errPollDuplicateOption):
 		return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, err.Error()
 		return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, err.Error()
 	case errors.Is(err, errTimeoutBeforeHandling):
 	case errors.Is(err, errTimeoutBeforeHandling):
 		return event.MessageStatusTooOld, event.MessageStatusRetriable, true, true, "the message was too old when it reached the bridge, so it was not handled"
 		return event.MessageStatusTooOld, event.MessageStatusRetriable, true, true, "the message was too old when it reached the bridge, so it was not handled"
@@ -185,8 +187,10 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin
 		msgType = "reaction"
 		msgType = "reaction"
 	case event.EventRedaction:
 	case event.EventRedaction:
 		msgType = "redaction"
 		msgType = "redaction"
-	case TypeMSC3881PollResponse, TypeMSC3881V2PollResponse:
+	case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
 		msgType = "poll response"
 		msgType = "poll response"
+	case TypeMSC3381PollStart:
+		msgType = "poll start"
 	default:
 	default:
 		msgType = "unknown event"
 		msgType = "unknown event"
 	}
 	}

+ 205 - 62
portal.go

@@ -19,6 +19,7 @@ package main
 import (
 import (
 	"bytes"
 	"bytes"
 	"context"
 	"context"
+	"crypto/rand"
 	"crypto/sha256"
 	"crypto/sha256"
 	"encoding/hex"
 	"encoding/hex"
 	"encoding/json"
 	"encoding/json"
@@ -322,7 +323,7 @@ func (portal *Portal) handleMatrixMessageLoopItem(msg PortalMatrixMessage) {
 	portal.handleMatrixReadReceipt(msg.user, "", evtTS, false)
 	portal.handleMatrixReadReceipt(msg.user, "", evtTS, false)
 	timings.implicitRR = time.Since(implicitRRStart)
 	timings.implicitRR = time.Since(implicitRRStart)
 	switch msg.evt.Type {
 	switch msg.evt.Type {
-	case event.EventMessage, event.EventSticker, TypeMSC3881V2PollResponse, TypeMSC3881PollResponse:
+	case event.EventMessage, event.EventSticker, TypeMSC3381V2PollResponse, TypeMSC3381PollResponse, TypeMSC3381PollStart:
 		portal.HandleMatrixMessage(msg.user, msg.evt, timings)
 		portal.HandleMatrixMessage(msg.user, msg.evt, timings)
 	case event.EventRedaction:
 	case event.EventRedaction:
 		portal.HandleMatrixRedaction(msg.user, msg.evt)
 		portal.HandleMatrixRedaction(msg.user, msg.evt)
@@ -2267,9 +2268,6 @@ func (portal *Portal) convertListResponseMessage(intent *appservice.IntentAPI, m
 }
 }
 
 
 func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg *waProto.PollUpdateMessage) *ConvertedMessage {
 func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg *waProto.PollUpdateMessage) *ConvertedMessage {
-	if !portal.bridge.Config.Bridge.ExtEvPolls {
-		return nil
-	}
 	pollMessage := portal.bridge.DB.Message.GetByJID(portal.Key, msg.GetPollCreationMessageKey().GetId())
 	pollMessage := portal.bridge.DB.Message.GetByJID(portal.Key, msg.GetPollCreationMessageKey().GetId())
 	if pollMessage == nil {
 	if pollMessage == nil {
 		portal.log.Warnfln("Failed to convert vote message %s: poll message %s not found", info.ID, msg.GetPollCreationMessageKey().GetId())
 		portal.log.Warnfln("Failed to convert vote message %s: poll message %s not found", info.ID, msg.GetPollCreationMessageKey().GetId())
@@ -2284,13 +2282,28 @@ func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, sou
 		return nil
 		return nil
 	}
 	}
 	selectedHashes := make([]string, len(vote.GetSelectedOptions()))
 	selectedHashes := make([]string, len(vote.GetSelectedOptions()))
-	for i, opt := range vote.GetSelectedOptions() {
-		selectedHashes[i] = hex.EncodeToString(opt)
+	if pollMessage.Type == database.MsgMatrixPoll {
+		mappedAnswers := pollMessage.GetPollOptionIDs(vote.GetSelectedOptions())
+		for i, opt := range vote.GetSelectedOptions() {
+			if len(opt) != 32 {
+				portal.log.Warnfln("Unexpected option hash length %d in %s's vote to %s", len(opt), info.Sender, pollMessage.MXID)
+				continue
+			}
+			var ok bool
+			selectedHashes[i], ok = mappedAnswers[*(*[32]byte)(opt)]
+			if !ok {
+				portal.log.Warnfln("Didn't find ID for option %X in %s's vote to %s", opt, info.Sender, pollMessage.MXID)
+			}
+		}
+	} else {
+		for i, opt := range vote.GetSelectedOptions() {
+			selectedHashes[i] = hex.EncodeToString(opt)
+		}
 	}
 	}
 
 
-	evtType := TypeMSC3881PollResponse
+	evtType := TypeMSC3381PollResponse
 	//if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
 	//if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
-	//	evtType = TypeMSC3881V2PollResponse
+	//	evtType = TypeMSC3381V2PollResponse
 	//}
 	//}
 	return &ConvertedMessage{
 	return &ConvertedMessage{
 		Intent: intent,
 		Intent: intent,
@@ -2341,7 +2354,7 @@ func (portal *Portal) convertPollCreationMessage(intent *appservice.IntentAPI, m
 	}
 	}
 	evtType := event.EventMessage
 	evtType := event.EventMessage
 	if portal.bridge.Config.Bridge.ExtEvPolls {
 	if portal.bridge.Config.Bridge.ExtEvPolls {
-		evtType.Type = "org.matrix.msc3381.poll.start"
+		evtType = TypeMSC3381PollStart
 	}
 	}
 	//else if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
 	//else if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
 	//	evtType.Type = "org.matrix.msc3381.v2.poll.start"
 	//	evtType.Type = "org.matrix.msc3381.v2.poll.start"
@@ -3505,8 +3518,9 @@ func getUnstableWaveform(content map[string]interface{}) []byte {
 }
 }
 
 
 var (
 var (
-	TypeMSC3881PollResponse   = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
-	TypeMSC3881V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3881.v2.poll.response"}
+	TypeMSC3381PollStart      = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.start"}
+	TypeMSC3381PollResponse   = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
+	TypeMSC3381V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.v2.poll.response"}
 )
 )
 
 
 type PollResponseContent struct {
 type PollResponseContent struct {
@@ -3532,15 +3546,71 @@ func (content *PollResponseContent) SetRelatesTo(rel *event.RelatesTo) {
 	content.RelatesTo = *rel
 	content.RelatesTo = *rel
 }
 }
 
 
+type MSC1767Message struct {
+	Text    string `json:"org.matrix.msc1767.text,omitempty"`
+	HTML    string `json:"org.matrix.msc1767.html,omitempty"`
+	Message []struct {
+		MimeType string `json:"mimetype"`
+		Body     string `json:"body"`
+	} `json:"org.matrix.msc1767.message,omitempty"`
+}
+
+func (portal *Portal) msc1767ToWhatsApp(msg MSC1767Message, mentions bool) (string, []string) {
+	for _, part := range msg.Message {
+		if part.MimeType == "text/html" && msg.HTML == "" {
+			msg.HTML = part.Body
+		} else if part.MimeType == "text/plain" && msg.Text == "" {
+			msg.Text = part.Body
+		}
+	}
+	if msg.HTML != "" {
+		if mentions {
+			return portal.bridge.Formatter.ParseMatrix(msg.HTML)
+		} else {
+			return portal.bridge.Formatter.ParseMatrixWithoutMentions(msg.HTML), nil
+		}
+	}
+	return msg.Text, nil
+}
+
+type PollStartContent struct {
+	RelatesTo *event.RelatesTo `json:"m.relates_to"`
+	PollStart struct {
+		Kind          string         `json:"kind"`
+		MaxSelections int            `json:"max_selections"`
+		Question      MSC1767Message `json:"question"`
+		Answers       []struct {
+			ID string `json:"id"`
+			MSC1767Message
+		} `json:"answers"`
+	} `json:"org.matrix.msc3381.poll.start"`
+}
+
+func (content *PollStartContent) GetRelatesTo() *event.RelatesTo {
+	if content.RelatesTo == nil {
+		content.RelatesTo = &event.RelatesTo{}
+	}
+	return content.RelatesTo
+}
+
+func (content *PollStartContent) OptionalGetRelatesTo() *event.RelatesTo {
+	return content.RelatesTo
+}
+
+func (content *PollStartContent) SetRelatesTo(rel *event.RelatesTo) {
+	content.RelatesTo = rel
+}
+
 func init() {
 func init() {
-	event.TypeMap[TypeMSC3881PollResponse] = reflect.TypeOf(PollResponseContent{})
-	event.TypeMap[TypeMSC3881V2PollResponse] = reflect.TypeOf(PollResponseContent{})
+	event.TypeMap[TypeMSC3381PollResponse] = reflect.TypeOf(PollResponseContent{})
+	event.TypeMap[TypeMSC3381V2PollResponse] = reflect.TypeOf(PollResponseContent{})
+	event.TypeMap[TypeMSC3381PollStart] = reflect.TypeOf(PollStartContent{})
 }
 }
 
 
-func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) {
+func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
 	content, ok := evt.Content.Parsed.(*PollResponseContent)
 	content, ok := evt.Content.Parsed.(*PollResponseContent)
 	if !ok {
 	if !ok {
-		return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
+		return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
 	}
 	}
 	var answers []string
 	var answers []string
 	if content.V1Response.Answers != nil {
 	if content.V1Response.Answers != nil {
@@ -3550,7 +3620,7 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt
 	}
 	}
 	pollMsg := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID)
 	pollMsg := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID)
 	if pollMsg == nil {
 	if pollMsg == nil {
-		return nil, sender, errTargetNotFound
+		return nil, sender, nil, errTargetNotFound
 	}
 	}
 	pollMsgInfo := &types.MessageInfo{
 	pollMsgInfo := &types.MessageInfo{
 		MessageSource: types.MessageSource{
 		MessageSource: types.MessageSource{
@@ -3563,43 +3633,81 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt
 		Type: "poll",
 		Type: "poll",
 	}
 	}
 	optionHashes := make([][]byte, 0, len(answers))
 	optionHashes := make([][]byte, 0, len(answers))
-	for _, selection := range answers {
-		hash, _ := hex.DecodeString(selection)
-		if hash != nil && len(hash) == 32 {
-			optionHashes = append(optionHashes, hash)
+	if pollMsg.Type == database.MsgMatrixPoll {
+		mappedAnswers := pollMsg.GetPollOptionHashes(answers)
+		for _, selection := range answers {
+			hash, ok := mappedAnswers[selection]
+			if ok {
+				optionHashes = append(optionHashes, hash[:])
+			} else {
+				portal.log.Warnfln("Didn't find hash for option %s in %s's vote to %s", selection, evt.Sender, pollMsg.MXID)
+			}
+		}
+	} else {
+		for _, selection := range answers {
+			hash, _ := hex.DecodeString(selection)
+			if hash != nil && len(hash) == 32 {
+				optionHashes = append(optionHashes, hash)
+			}
 		}
 		}
 	}
 	}
 	pollUpdate, err := sender.Client.EncryptPollVote(pollMsgInfo, &waProto.PollVoteMessage{
 	pollUpdate, err := sender.Client.EncryptPollVote(pollMsgInfo, &waProto.PollVoteMessage{
 		SelectedOptions: optionHashes,
 		SelectedOptions: optionHashes,
 	})
 	})
-	return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, err
+	return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, nil, err
 }
 }
 
 
-func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) {
-	if evt.Type == TypeMSC3881PollResponse || evt.Type == TypeMSC3881V2PollResponse {
-		return portal.convertMatrixPollVote(ctx, sender, evt)
-	}
-	content, ok := evt.Content.Parsed.(*event.MessageEventContent)
+func (portal *Portal) convertMatrixPollStart(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
+	content, ok := evt.Content.Parsed.(*PollStartContent)
 	if !ok {
 	if !ok {
-		return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
-	}
-	var editRootMsg *database.Message
-	if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits {
-		editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID)
-		if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User {
-			return nil, sender, fmt.Errorf("edit rejected") // TODO more specific error message
-		}
-		if content.NewContent != nil {
-			content = content.NewContent
-		}
-	}
+		return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
+	}
+	maxAnswers := content.PollStart.MaxSelections
+	if maxAnswers >= len(content.PollStart.Answers) || maxAnswers < 0 {
+		maxAnswers = 0
+	}
+	fmt.Printf("%+v\n", content.PollStart)
+	ctxInfo := portal.generateContextInfo(content.RelatesTo)
+	var question string
+	question, ctxInfo.MentionedJid = portal.msc1767ToWhatsApp(content.PollStart.Question, true)
+	if len(question) == 0 {
+		return nil, sender, nil, errPollMissingQuestion
+	}
+	options := make([]*waProto.PollCreationMessage_Option, len(content.PollStart.Answers))
+	optionMap := make(map[[32]byte]string, len(options))
+	for i, opt := range content.PollStart.Answers {
+		body, _ := portal.msc1767ToWhatsApp(opt.MSC1767Message, false)
+		hash := sha256.Sum256([]byte(body))
+		if _, alreadyExists := optionMap[hash]; alreadyExists {
+			portal.log.Warnfln("Poll %s by %s has option %q more than once, rejecting", evt.ID, evt.Sender, body)
+			return nil, sender, nil, errPollDuplicateOption
+		}
+		optionMap[hash] = opt.ID
+		options[i] = &waProto.PollCreationMessage_Option{
+			OptionName: proto.String(body),
+		}
+	}
+	secret := make([]byte, 32)
+	_, err := rand.Read(secret)
+	return &waProto.Message{
+		PollCreationMessage: &waProto.PollCreationMessage{
+			Name:                   proto.String(question),
+			Options:                options,
+			SelectableOptionsCount: proto.Uint32(uint32(maxAnswers)),
+			ContextInfo:            ctxInfo,
+		},
+		MessageContextInfo: &waProto.MessageContextInfo{
+			MessageSecret: secret,
+		},
+	}, sender, &extraConvertMeta{PollOptions: optionMap}, err
+}
 
 
-	msg := &waProto.Message{}
+func (portal *Portal) generateContextInfo(relatesTo *event.RelatesTo) *waProto.ContextInfo {
 	var ctxInfo waProto.ContextInfo
 	var ctxInfo waProto.ContextInfo
-	replyToID := content.RelatesTo.GetReplyTo()
+	replyToID := relatesTo.GetReplyTo()
 	if len(replyToID) > 0 {
 	if len(replyToID) > 0 {
 		replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID)
 		replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID)
-		if replyToMsg != nil && !replyToMsg.IsFakeJID() && replyToMsg.Type == database.MsgNormal {
+		if replyToMsg != nil && !replyToMsg.IsFakeJID() && (replyToMsg.Type == database.MsgNormal || replyToMsg.Type == database.MsgMatrixPoll) {
 			ctxInfo.StanzaId = &replyToMsg.JID
 			ctxInfo.StanzaId = &replyToMsg.JID
 			ctxInfo.Participant = proto.String(replyToMsg.Sender.ToNonAD().String())
 			ctxInfo.Participant = proto.String(replyToMsg.Sender.ToNonAD().String())
 			// Using blank content here seems to work fine on all official WhatsApp apps.
 			// Using blank content here seems to work fine on all official WhatsApp apps.
@@ -3613,10 +3721,40 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	if portal.ExpirationTime != 0 {
 	if portal.ExpirationTime != 0 {
 		ctxInfo.Expiration = proto.Uint32(portal.ExpirationTime)
 		ctxInfo.Expiration = proto.Uint32(portal.ExpirationTime)
 	}
 	}
+	return &ctxInfo
+}
+
+type extraConvertMeta struct {
+	PollOptions map[[32]byte]string
+}
+
+func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
+	if evt.Type == TypeMSC3381PollResponse || evt.Type == TypeMSC3381V2PollResponse {
+		return portal.convertMatrixPollVote(ctx, sender, evt)
+	} else if evt.Type == TypeMSC3381PollStart {
+		return portal.convertMatrixPollStart(ctx, sender, evt)
+	}
+	content, ok := evt.Content.Parsed.(*event.MessageEventContent)
+	if !ok {
+		return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
+	}
+	var editRootMsg *database.Message
+	if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits {
+		editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID)
+		if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User {
+			return nil, sender, nil, fmt.Errorf("edit rejected") // TODO more specific error message
+		}
+		if content.NewContent != nil {
+			content = content.NewContent
+		}
+	}
+
+	msg := &waProto.Message{}
+	ctxInfo := portal.generateContextInfo(content.RelatesTo)
 	relaybotFormatted := false
 	relaybotFormatted := false
 	if !sender.IsLoggedIn() || (portal.IsPrivateChat() && sender.JID.User != portal.Key.Receiver.User) {
 	if !sender.IsLoggedIn() || (portal.IsPrivateChat() && sender.JID.User != portal.Key.Receiver.User) {
 		if !portal.HasRelaybot() {
 		if !portal.HasRelaybot() {
-			return nil, sender, errUserNotLoggedIn
+			return nil, sender, nil, errUserNotLoggedIn
 		}
 		}
 		relaybotFormatted = portal.addRelaybotFormat(sender, content)
 		relaybotFormatted = portal.addRelaybotFormat(sender, content)
 		sender = portal.GetRelayUser()
 		sender = portal.GetRelayUser()
@@ -3637,7 +3775,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MsgText, event.MsgEmote, event.MsgNotice:
 	case event.MsgText, event.MsgEmote, event.MsgNotice:
 		text := content.Body
 		text := content.Body
 		if content.MsgType == event.MsgNotice && !portal.bridge.Config.Bridge.BridgeNotices {
 		if content.MsgType == event.MsgNotice && !portal.bridge.Config.Bridge.BridgeNotices {
-			return nil, sender, errMNoticeDisabled
+			return nil, sender, nil, errMNoticeDisabled
 		}
 		}
 		if content.Format == event.FormatHTML {
 		if content.Format == event.FormatHTML {
 			text, ctxInfo.MentionedJid = portal.bridge.Formatter.ParseMatrix(content.FormattedBody)
 			text, ctxInfo.MentionedJid = portal.bridge.Formatter.ParseMatrix(content.FormattedBody)
@@ -3647,11 +3785,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 		}
 		}
 		msg.ExtendedTextMessage = &waProto.ExtendedTextMessage{
 		msg.ExtendedTextMessage = &waProto.ExtendedTextMessage{
 			Text:        &text,
 			Text:        &text,
-			ContextInfo: &ctxInfo,
+			ContextInfo: ctxInfo,
 		}
 		}
 		hasPreview := portal.convertURLPreviewToWhatsApp(ctx, sender, evt, msg.ExtendedTextMessage)
 		hasPreview := portal.convertURLPreviewToWhatsApp(ctx, sender, evt, msg.ExtendedTextMessage)
 		if ctx.Err() != nil {
 		if ctx.Err() != nil {
-			return nil, nil, ctx.Err()
+			return nil, nil, nil, ctx.Err()
 		}
 		}
 		if ctxInfo.StanzaId == nil && ctxInfo.MentionedJid == nil && ctxInfo.Expiration == nil && !hasPreview {
 		if ctxInfo.StanzaId == nil && ctxInfo.MentionedJid == nil && ctxInfo.Expiration == nil && !hasPreview {
 			// No need for extended message
 			// No need for extended message
@@ -3661,11 +3799,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MsgImage:
 	case event.MsgImage:
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
 		if media == nil {
 		if media == nil {
-			return nil, sender, err
+			return nil, sender, nil, err
 		}
 		}
 		ctxInfo.MentionedJid = media.MentionedJIDs
 		ctxInfo.MentionedJid = media.MentionedJIDs
 		msg.ImageMessage = &waProto.ImageMessage{
 		msg.ImageMessage = &waProto.ImageMessage{
-			ContextInfo:   &ctxInfo,
+			ContextInfo:   ctxInfo,
 			Caption:       &media.Caption,
 			Caption:       &media.Caption,
 			JpegThumbnail: media.Thumbnail,
 			JpegThumbnail: media.Thumbnail,
 			Url:           &media.URL,
 			Url:           &media.URL,
@@ -3678,11 +3816,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MessageType(event.EventSticker.Type):
 	case event.MessageType(event.EventSticker.Type):
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
 		if media == nil {
 		if media == nil {
-			return nil, sender, err
+			return nil, sender, nil, err
 		}
 		}
 		ctxInfo.MentionedJid = media.MentionedJIDs
 		ctxInfo.MentionedJid = media.MentionedJIDs
 		msg.StickerMessage = &waProto.StickerMessage{
 		msg.StickerMessage = &waProto.StickerMessage{
-			ContextInfo:   &ctxInfo,
+			ContextInfo:   ctxInfo,
 			PngThumbnail:  media.Thumbnail,
 			PngThumbnail:  media.Thumbnail,
 			Url:           &media.URL,
 			Url:           &media.URL,
 			MediaKey:      media.MediaKey,
 			MediaKey:      media.MediaKey,
@@ -3695,12 +3833,12 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 		gifPlayback := content.GetInfo().MimeType == "image/gif"
 		gifPlayback := content.GetInfo().MimeType == "image/gif"
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaVideo)
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaVideo)
 		if media == nil {
 		if media == nil {
-			return nil, sender, err
+			return nil, sender, nil, err
 		}
 		}
 		duration := uint32(content.GetInfo().Duration / 1000)
 		duration := uint32(content.GetInfo().Duration / 1000)
 		ctxInfo.MentionedJid = media.MentionedJIDs
 		ctxInfo.MentionedJid = media.MentionedJIDs
 		msg.VideoMessage = &waProto.VideoMessage{
 		msg.VideoMessage = &waProto.VideoMessage{
-			ContextInfo:   &ctxInfo,
+			ContextInfo:   ctxInfo,
 			Caption:       &media.Caption,
 			Caption:       &media.Caption,
 			JpegThumbnail: media.Thumbnail,
 			JpegThumbnail: media.Thumbnail,
 			Url:           &media.URL,
 			Url:           &media.URL,
@@ -3715,11 +3853,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MsgAudio:
 	case event.MsgAudio:
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaAudio)
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaAudio)
 		if media == nil {
 		if media == nil {
-			return nil, sender, err
+			return nil, sender, nil, err
 		}
 		}
 		duration := uint32(content.GetInfo().Duration / 1000)
 		duration := uint32(content.GetInfo().Duration / 1000)
 		msg.AudioMessage = &waProto.AudioMessage{
 		msg.AudioMessage = &waProto.AudioMessage{
-			ContextInfo:   &ctxInfo,
+			ContextInfo:   ctxInfo,
 			Url:           &media.URL,
 			Url:           &media.URL,
 			MediaKey:      media.MediaKey,
 			MediaKey:      media.MediaKey,
 			Mimetype:      &content.GetInfo().MimeType,
 			Mimetype:      &content.GetInfo().MimeType,
@@ -3738,10 +3876,10 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MsgFile:
 	case event.MsgFile:
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaDocument)
 		media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaDocument)
 		if media == nil {
 		if media == nil {
-			return nil, sender, err
+			return nil, sender, nil, err
 		}
 		}
 		msg.DocumentMessage = &waProto.DocumentMessage{
 		msg.DocumentMessage = &waProto.DocumentMessage{
-			ContextInfo:   &ctxInfo,
+			ContextInfo:   ctxInfo,
 			Caption:       &media.Caption,
 			Caption:       &media.Caption,
 			JpegThumbnail: media.Thumbnail,
 			JpegThumbnail: media.Thumbnail,
 			Url:           &media.URL,
 			Url:           &media.URL,
@@ -3764,16 +3902,16 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 	case event.MsgLocation:
 	case event.MsgLocation:
 		lat, long, err := parseGeoURI(content.GeoURI)
 		lat, long, err := parseGeoURI(content.GeoURI)
 		if err != nil {
 		if err != nil {
-			return nil, sender, fmt.Errorf("%w: %v", errInvalidGeoURI, err)
+			return nil, sender, nil, fmt.Errorf("%w: %v", errInvalidGeoURI, err)
 		}
 		}
 		msg.LocationMessage = &waProto.LocationMessage{
 		msg.LocationMessage = &waProto.LocationMessage{
 			DegreesLatitude:  &lat,
 			DegreesLatitude:  &lat,
 			DegreesLongitude: &long,
 			DegreesLongitude: &long,
 			Comment:          &content.Body,
 			Comment:          &content.Body,
-			ContextInfo:      &ctxInfo,
+			ContextInfo:      ctxInfo,
 		}
 		}
 	default:
 	default:
-		return nil, sender, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType)
+		return nil, sender, nil, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType)
 	}
 	}
 
 
 	if editRootMsg != nil {
 	if editRootMsg != nil {
@@ -3795,7 +3933,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
 		}
 		}
 	}
 	}
 
 
-	return msg, sender, nil
+	return msg, sender, nil, nil
 }
 }
 
 
 func (portal *Portal) generateMessageInfo(sender *User) *types.MessageInfo {
 func (portal *Portal) generateMessageInfo(sender *User) *types.MessageInfo {
@@ -3815,7 +3953,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
 	start := time.Now()
 	start := time.Now()
 	ms := metricSender{portal: portal, timings: &timings}
 	ms := metricSender{portal: portal, timings: &timings}
 
 
-	allowRelay := evt.Type != TypeMSC3881PollResponse && evt.Type != TypeMSC3881V2PollResponse
+	allowRelay := evt.Type != TypeMSC3381PollResponse && evt.Type != TypeMSC3381V2PollResponse && evt.Type != TypeMSC3381PollStart
 	if err := portal.canBridgeFrom(sender, allowRelay); err != nil {
 	if err := portal.canBridgeFrom(sender, allowRelay); err != nil {
 		go ms.sendMessageMetrics(evt, err, "Ignoring", true)
 		go ms.sendMessageMetrics(evt, err, "Ignoring", true)
 		return
 		return
@@ -3875,14 +4013,16 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
 
 
 	timings.preproc = time.Since(start)
 	timings.preproc = time.Since(start)
 	start = time.Now()
 	start = time.Now()
-	msg, sender, err := portal.convertMatrixMessage(ctx, sender, evt)
+	msg, sender, extraMeta, err := portal.convertMatrixMessage(ctx, sender, evt)
 	timings.convert = time.Since(start)
 	timings.convert = time.Since(start)
 	if msg == nil {
 	if msg == nil {
 		go ms.sendMessageMetrics(evt, err, "Error converting", true)
 		go ms.sendMessageMetrics(evt, err, "Error converting", true)
 		return
 		return
 	}
 	}
 	dbMsgType := database.MsgNormal
 	dbMsgType := database.MsgNormal
-	if msg.EditedMessage == nil {
+	if msg.PollCreationMessage != nil || msg.PollCreationMessageV2 != nil {
+		dbMsgType = database.MsgMatrixPoll
+	} else if msg.EditedMessage == nil {
 		portal.MarkDisappearing(nil, origEvtID, portal.ExpirationTime, true)
 		portal.MarkDisappearing(nil, origEvtID, portal.ExpirationTime, true)
 	} else {
 	} else {
 		dbMsgType = database.MsgEdit
 		dbMsgType = database.MsgEdit
@@ -3893,6 +4033,9 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
 	} else {
 	} else {
 		info.ID = dbMsg.JID
 		info.ID = dbMsg.JID
 	}
 	}
+	if dbMsgType == database.MsgMatrixPoll && extraMeta != nil && extraMeta.PollOptions != nil {
+		dbMsg.PutPollOptions(extraMeta.PollOptions)
+	}
 	portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
 	portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
 	start = time.Now()
 	start = time.Now()
 	resp, err := sender.Client.SendMessage(ctx, portal.Key.JID, info.ID, msg)
 	resp, err := sender.Client.SendMessage(ctx, portal.Key.JID, info.ID, msg)