Browse Source

Store the message map

Gary Kramlich 3 years ago
parent
commit
312018173f
7 changed files with 213 additions and 10 deletions
  1. 1 1
      bridge/matrix.go
  2. 56 6
      bridge/portal.go
  3. 9 3
      database/database.go
  4. 66 0
      database/message.go
  5. 64 0
      database/messagequery.go
  6. 14 0
      database/migrations/01-initial.sql
  7. 3 0
      database/portal.go

+ 1 - 1
bridge/matrix.go

@@ -226,7 +226,7 @@ func (mh *matrixHandler) handleMembership(evt *event.Event) {
 		} else if puppet != nil {
 			portal.handleMatrixKick(user, puppet)
 		}
-	} else if content.Membership == event.MembershipInvite && !isSelf {
+	} else if content.Membership == event.MembershipInvite {
 		portal.handleMatrixInvite(user, evt)
 	}
 }

+ 56 - 6
bridge/portal.go

@@ -3,6 +3,7 @@ package bridge
 import (
 	"fmt"
 	"sync"
+	"time"
 
 	"github.com/bwmarrin/discordgo"
 
@@ -297,21 +298,51 @@ func (p *Portal) ensureUserInvited(user *User) bool {
 	return user.ensureInvited(p.MainIntent(), p.MXID, p.IsPrivateChat())
 }
 
+func (p *Portal) markMessageHandled(msg *database.Message, discordID string, mxid id.EventID, authorID string, timestamp time.Time) *database.Message {
+	if msg == nil {
+		msg := p.bridge.db.Message.New()
+		msg.Channel = p.Key
+		msg.DiscordID = discordID
+		msg.MatrixID = mxid
+		msg.AuthorID = authorID
+		msg.Timestamp = timestamp
+		msg.Insert()
+	} else {
+		msg.UpdateMatrixID(mxid)
+	}
+
+	return msg
+}
+
 func (p *Portal) handleDiscordMessage(msg *discordgo.Message) {
 	if p.MXID == "" {
 		p.log.Warnln("handle message called without a valid portal")
+
+		return
+	}
+
+	existing := p.bridge.db.Message.GetByDiscordID(p.Key, msg.ID)
+	if existing != nil {
+		p.log.Debugln("not handling duplicate message", msg.ID)
+
 		return
 	}
 
-	// TODO: Check if we already got the message
 	content := &event.MessageEventContent{
 		Body:    msg.Content,
 		MsgType: event.MsgText,
 	}
 
-	resp, err := p.MainIntent().SendMessageEvent(p.MXID, event.EventMessage, content)
-	p.log.Warnln("response:", resp)
-	p.log.Warnln("error:", err)
+	intent := p.bridge.GetPuppetByID(msg.Author.ID).IntentFor(p)
+
+	resp, err := intent.SendMessageEvent(p.MXID, event.EventMessage, content)
+	if err != nil {
+		p.log.Warnfln("failed to send message %q to matrix: %v", msg.ID, err)
+		return
+	}
+
+	ts, _ := msg.Timestamp.Parse()
+	p.markMessageHandled(nil, msg.ID, resp.EventID, msg.Author.ID, ts)
 }
 
 func (p *Portal) syncParticipants(source *User, participants []*discordgo.User) {
@@ -344,6 +375,13 @@ func (p *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
 		return
 	}
 
+	existing := p.bridge.db.Message.GetByMatrixID(p.Key, evt.ID)
+	if existing != nil {
+		p.log.Debugln("not handling duplicate message", evt.ID)
+
+		return
+	}
+
 	content, ok := evt.Content.Parsed.(*event.MessageEventContent)
 	if !ok {
 		p.log.Debugfln("Failed to handle event %s: unexpected parsed content type %T", evt.ID, evt.Content.Parsed)
@@ -351,8 +389,20 @@ func (p *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
 		return
 	}
 
-	sender.Session.ChannelMessageSend(p.Key.ChannelID, content.Body)
-	p.log.Debugln("sent message:", content.Body)
+	msg, err := sender.Session.ChannelMessageSend(p.Key.ChannelID, content.Body)
+	if err != nil {
+		p.log.Errorfln("Failed to send message: %v", err)
+
+		return
+	}
+
+	dbMsg := p.bridge.db.Message.New()
+	dbMsg.Channel = p.Key
+	dbMsg.DiscordID = msg.ID
+	dbMsg.MatrixID = evt.ID
+	dbMsg.AuthorID = sender.ID
+	dbMsg.Timestamp = time.Now()
+	dbMsg.Insert()
 }
 
 func (p *Portal) handleMatrixLeave(sender *User) {

+ 9 - 3
database/database.go

@@ -16,9 +16,10 @@ type Database struct {
 	log     log.Logger
 	dialect string
 
-	User   *UserQuery
-	Portal *PortalQuery
-	Puppet *PuppetQuery
+	User    *UserQuery
+	Portal  *PortalQuery
+	Puppet  *PuppetQuery
+	Message *MessageQuery
 }
 
 func New(dbType, uri string, maxOpenConns, maxIdleConns int, baseLog log.Logger) (*Database, error) {
@@ -61,5 +62,10 @@ func New(dbType, uri string, maxOpenConns, maxIdleConns int, baseLog log.Logger)
 		log: db.log.Sub("Puppet"),
 	}
 
+	db.Message = &MessageQuery{
+		db:  db,
+		log: db.log.Sub("Message"),
+	}
+
 	return db, nil
 }

+ 66 - 0
database/message.go

@@ -0,0 +1,66 @@
+package database
+
+import (
+	"database/sql"
+	"errors"
+	"time"
+
+	log "maunium.net/go/maulogger/v2"
+	"maunium.net/go/mautrix/id"
+)
+
+type Message struct {
+	db  *Database
+	log log.Logger
+
+	Channel PortalKey
+
+	DiscordID string
+	MatrixID  id.EventID
+
+	AuthorID  string
+	Timestamp time.Time
+}
+
+func (m *Message) Scan(row Scannable) *Message {
+	var ts int64
+
+	err := row.Scan(&m.Channel.ChannelID, &m.Channel.Receiver, &m.DiscordID, &m.MatrixID, &m.AuthorID, &ts)
+	if err != nil {
+		if !errors.Is(err, sql.ErrNoRows) {
+			m.log.Errorln("Database scan failed:", err)
+		}
+
+		return nil
+	}
+
+	if ts != 0 {
+		m.Timestamp = time.Unix(ts, 0)
+	}
+
+	return m
+}
+
+func (m *Message) Insert() {
+	query := "INSERT INTO message" +
+		" (channel_id, receiver, discord_message_id, matrix_message_id," +
+		" author_id, timestamp) VALUES ($1, $2, $3, $4, $5, $6)"
+
+	_, err := m.db.Exec(query, m.Channel.ChannelID, m.Channel.Receiver,
+		m.DiscordID, m.MatrixID, m.AuthorID, m.Timestamp.Unix())
+
+	if err != nil {
+		m.log.Warnfln("Failed to insert %s@%s: %v", m.Channel, m.DiscordID, err)
+	}
+}
+
+func (m *Message) UpdateMatrixID(mxid id.EventID) {
+	query := "UPDATE message SET matrix_message_id=$1 WHERE channel_id=$2" +
+		"AND receiver=$3 AND discord_message_id=$4"
+	m.MatrixID = mxid
+
+	_, err := m.db.Exec(query, m.MatrixID, m.Channel.ChannelID, m.Channel.Receiver, m.DiscordID)
+	if err != nil {
+		m.log.Warnfln("Failed to update %s@%s: %v", m.Channel, m.DiscordID, err)
+	}
+}

+ 64 - 0
database/messagequery.go

@@ -0,0 +1,64 @@
+package database
+
+import (
+	log "maunium.net/go/maulogger/v2"
+	"maunium.net/go/mautrix/id"
+)
+
+type MessageQuery struct {
+	db  *Database
+	log log.Logger
+}
+
+const (
+	messageSelect = "SELECT channel_id, receiver, discord_message_id," +
+		" matrix_message_id, author_id, timestamp FROM message"
+)
+
+func (mq *MessageQuery) New() *Message {
+	return &Message{
+		db:  mq.db,
+		log: mq.log,
+	}
+}
+
+func (mq *MessageQuery) GetAll(key PortalKey) []*Message {
+	query := messageSelect + " WHERE channeld_id=$1 AND receiver=$2"
+
+	rows, err := mq.db.Query(query, key.ChannelID, key.Receiver)
+	if err != nil || rows == nil {
+		return nil
+	}
+
+	messages := []*Message{}
+	for rows.Next() {
+		messages = append(messages, mq.New().Scan(rows))
+	}
+
+	return messages
+}
+
+func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message {
+	query := messageSelect + " WHERE channel_id=$1 AND receiver=$2 AND" +
+		" discord_message_id=$3"
+
+	row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)
+	if row == nil {
+		mq.log.Debugfln("failed to find existing message for discord_id %s", discordID)
+		return nil
+	}
+
+	return mq.New().Scan(row)
+}
+
+func (mq *MessageQuery) GetByMatrixID(key PortalKey, matrixID id.EventID) *Message {
+	query := messageSelect + " WHERE channel_id=$1 AND receiver=$2 AND" +
+		" matrix_message_id=$3"
+
+	row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, matrixID)
+	if row == nil {
+		return nil
+	}
+
+	return mq.New().Scan(row)
+}

+ 14 - 0
database/migrations/01-initial.sql

@@ -33,6 +33,20 @@ CREATE TABLE user (
 	token TEXT
 );
 
+CREATE TABLE message (
+	channel_id TEXT NOT NULL,
+	receiver TEXT NOT NULL,
+
+	discord_message_id TEXT NOT NULL,
+	matrix_message_id TEXT NOT NULL UNIQUE,
+
+	author_id TEXT NOT NULL,
+	timestamp BIGINT NOT NULL,
+
+	PRIMARY KEY(discord_message_id, channel_id, receiver),
+	FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE
+);
+
 CREATE TABLE mx_user_profile (
 	room_id     TEXT,
 	user_id     TEXT,

+ 3 - 0
database/portal.go

@@ -20,6 +20,9 @@ type Portal struct {
 	Avatar    string
 	AvatarURL id.ContentURI
 
+	Type   int
+	DMUser string
+
 	FirstEventID id.EventID
 }