Преглед изворни кода

Add the db for reactions and removals from discord

Gary Kramlich пре 3 година
родитељ
комит
9d13905a77
6 измењених фајлова са 306 додато и 11 уклоњено
  1. 58 5
      bridge/portal.go
  2. 21 2
      bridge/user.go
  3. 10 4
      database/database.go
  4. 20 0
      database/migrations/01-initial.sql
  5. 115 0
      database/reaction.go
  6. 82 0
      database/reactionquery.go

+ 58 - 5
bridge/portal.go

@@ -291,6 +291,8 @@ func (p *Portal) handleDiscordMessages(msg portalDiscordMessage) {
 		p.handleDiscordMessage(msg.user, msg.msg.(*discordgo.MessageCreate).Message)
 	case *discordgo.MessageReactionAdd:
 		p.handleDiscordReaction(msg.user, msg.msg.(*discordgo.MessageReactionAdd).MessageReaction, true)
+	case *discordgo.MessageReactionRemove:
+		p.handleDiscordReaction(msg.user, msg.msg.(*discordgo.MessageReactionRemove).MessageReaction, false)
 	default:
 		p.log.Warnln("unknown message type")
 	}
@@ -549,12 +551,18 @@ func (p *Portal) handleDiscordReaction(user *User, reaction *discordgo.MessageRe
 		return
 	}
 
-	if reaction.Emoji.ID != "" {
+	// Emoji.ID is only set if it's a custom emote, otherwise Emoji.Name is
+	// used.
+	customEmote := (reaction.Emoji.ID != "")
+
+	// This is temporary until we add support for custom emoji.
+	if customEmote {
 		p.log.Debugln("ignoring non-unicode reaction")
 
 		return
 	}
 
+	// Find the message that we're working with.
 	message := p.bridge.db.Message.GetByDiscordID(p.Key, reaction.MessageID)
 	if message == nil {
 		p.log.Debugfln("failed to add reaction to message %s: message not found", reaction.MessageID)
@@ -562,8 +570,34 @@ func (p *Portal) handleDiscordReaction(user *User, reaction *discordgo.MessageRe
 		return
 	}
 
+	// Lookup an existing reaction
+	var existing *database.Reaction
+
+	if customEmote {
+		existing = p.bridge.db.Reaction.GetByDiscordID(p.Key, message.DiscordID, reaction.Emoji.ID)
+	} else {
+		existing = p.bridge.db.Reaction.GetByDiscordName(p.Key, message.DiscordID, reaction.Emoji.Name)
+	}
+
+	if !add && existing == nil {
+		p.log.Debugln("Failed to remove emote for unknown message", reaction.MessageID)
+
+		return
+	}
+
 	intent := p.bridge.GetPuppetByID(reaction.UserID).IntentFor(p)
 
+	if !add {
+		_, err := intent.RedactEvent(p.MXID, existing.MatrixEventID)
+		if err != nil {
+			p.log.Warnfln("Failed to remove reaction from %s: %v", p.MXID, err)
+		}
+
+		existing.Delete()
+
+		return
+	}
+
 	content := event.Content{Parsed: &event.ReactionEventContent{
 		RelatesTo: event.RelatesTo{
 			EventID: message.MatrixID,
@@ -572,10 +606,29 @@ func (p *Portal) handleDiscordReaction(user *User, reaction *discordgo.MessageRe
 		},
 	}}
 
-	_, err := intent.Client.SendMessageEvent(p.MXID, event.EventReaction, &content)
-	if err != nil {
-		p.log.Errorfln("failed to send reaction from %s: %v", reaction.MessageID, err)
+	if add {
+		resp, err := intent.Client.SendMessageEvent(p.MXID, event.EventReaction, &content)
+		if err != nil {
+			p.log.Errorfln("failed to send reaction from %s: %v", reaction.MessageID, err)
 
-		return
+			return
+		}
+
+		if existing == nil {
+			dbReaction := p.bridge.db.Reaction.New()
+			dbReaction.Channel = p.Key
+			dbReaction.DiscordMessageID = message.DiscordID
+			dbReaction.MatrixEventID = resp.EventID
+			dbReaction.AuthorID = reaction.UserID
+
+			if customEmote {
+				// TODO:
+			} else {
+				dbReaction.MatrixName = reaction.Emoji.Name
+				dbReaction.DiscordName = reaction.Emoji.Name
+			}
+
+			dbReaction.Insert()
+		}
 	}
 }

+ 21 - 2
bridge/user.go

@@ -213,7 +213,8 @@ func (u *User) Connect() error {
 	u.User.Session.AddHandler(u.channelUpdateHandler)
 
 	u.User.Session.AddHandler(u.messageHandler)
-	u.User.Session.AddHandler(u.reactionHandler)
+	u.User.Session.AddHandler(u.reactionAddHandler)
+	u.User.Session.AddHandler(u.reactionRemoveHandler)
 
 	// u.User.Session.Identify.Capabilities = 125
 	// // Setup our properties
@@ -296,7 +297,25 @@ func (u *User) messageHandler(s *discordgo.Session, m *discordgo.MessageCreate)
 	portal.discordMessages <- msg
 }
 
-func (u *User) reactionHandler(s *discordgo.Session, m *discordgo.MessageReactionAdd) {
+func (u *User) reactionAddHandler(s *discordgo.Session, m *discordgo.MessageReactionAdd) {
+	if m.GuildID != "" {
+		u.log.Debugln("ignoring reaction for guild message")
+
+		return
+	}
+
+	key := database.NewPortalKey(m.ChannelID, u.User.ID)
+	portal := u.bridge.GetPortalByID(key)
+
+	msg := portalDiscordMessage{
+		msg:  m,
+		user: u,
+	}
+
+	portal.discordMessages <- msg
+}
+
+func (u *User) reactionRemoveHandler(s *discordgo.Session, m *discordgo.MessageReactionRemove) {
 	if m.GuildID != "" {
 		u.log.Debugln("ignoring reaction for guild message")
 

+ 10 - 4
database/database.go

@@ -16,10 +16,11 @@ type Database struct {
 	log     log.Logger
 	dialect string
 
-	User    *UserQuery
-	Portal  *PortalQuery
-	Puppet  *PuppetQuery
-	Message *MessageQuery
+	User     *UserQuery
+	Portal   *PortalQuery
+	Puppet   *PuppetQuery
+	Message  *MessageQuery
+	Reaction *ReactionQuery
 }
 
 func New(dbType, uri string, maxOpenConns, maxIdleConns int, baseLog log.Logger) (*Database, error) {
@@ -67,5 +68,10 @@ func New(dbType, uri string, maxOpenConns, maxIdleConns int, baseLog log.Logger)
 		log: db.log.Sub("Message"),
 	}
 
+	db.Reaction = &ReactionQuery{
+		db:  db,
+		log: db.log.Sub("Reaction"),
+	}
+
 	return db, nil
 }

+ 20 - 0
database/migrations/01-initial.sql

@@ -47,6 +47,26 @@ CREATE TABLE message (
 	FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE
 );
 
+CREATE TABLE reaction (
+	channel_id TEXT NOT NULL,
+	receiver TEXT NOT NULL,
+
+	discord_message_id TEXT NOT NULL,
+	matrix_event_id TEXT NOT NULL UNIQUE,
+
+	author_id TEXT NOT NULL,
+
+	matrix_name TEXT,
+	matrix_url TEXT,
+
+	discord_name TEXT,
+	discord_id TEXT,
+
+	CHECK ((discord_name IS NULL AND discord_id IS NOT NULL) OR (discord_name IS NOT NULL AND discord_id IS NULL)),
+	UNIQUE (discord_name, discord_id, author_id, discord_message_id, channel_id, receiver),
+	FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE
+);
+
 CREATE TABLE mx_user_profile (
 	room_id     TEXT,
 	user_id     TEXT,

+ 115 - 0
database/reaction.go

@@ -0,0 +1,115 @@
+package database
+
+import (
+	"database/sql"
+	"errors"
+
+	log "maunium.net/go/maulogger/v2"
+	"maunium.net/go/mautrix/id"
+)
+
+type Reaction struct {
+	db  *Database
+	log log.Logger
+
+	Channel PortalKey
+
+	DiscordMessageID string
+	MatrixEventID    id.EventID
+
+	// The discord ID of who create this reaction
+	AuthorID string
+
+	MatrixName string
+	MatrixURL  string // Used for custom emoji
+
+	DiscordName string // Used for unicode emoji
+	DiscordID   string // Used for custom emoji
+}
+
+func (r *Reaction) Scan(row Scannable) *Reaction {
+	var discordName, discordID sql.NullString
+
+	err := row.Scan(
+		&r.Channel.ChannelID, &r.Channel.Receiver,
+		&r.DiscordMessageID, &r.MatrixEventID,
+		&r.AuthorID,
+		&r.MatrixName, &r.MatrixURL,
+		&discordName, &discordID)
+
+	if err != nil {
+		if !errors.Is(err, sql.ErrNoRows) {
+			r.log.Errorln("Database scan failed:", err)
+		}
+
+		return nil
+	}
+
+	r.DiscordName = discordName.String
+	r.DiscordID = discordID.String
+
+	return r
+}
+
+func (r *Reaction) Insert() {
+	query := "INSERT INTO reaction" +
+		" (channel_id, receiver, discord_message_id, matrix_event_id," +
+		"  author_id, matrix_name, matrix_url, discord_name, discord_id)" +
+		" VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9);"
+
+	var discordName, discordID sql.NullString
+
+	if r.DiscordName != "" {
+		discordName = sql.NullString{r.DiscordName, true}
+	}
+
+	if r.DiscordID != "" {
+		discordID = sql.NullString{r.DiscordID, true}
+	}
+
+	_, err := r.db.Exec(
+		query,
+		r.Channel.ChannelID, r.Channel.Receiver,
+		r.DiscordMessageID, r.MatrixEventID,
+		r.AuthorID,
+		r.MatrixName, r.MatrixURL,
+		discordName, discordID,
+	)
+
+	if err != nil {
+		r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.Channel, r.DiscordMessageID, err)
+	}
+}
+
+func (r *Reaction) Update() {
+	// TODO: determine if we need this. The only scenario I can think of that
+	// would require this is if we insert a custom emoji before uploading to
+	// the homeserver?
+}
+
+func (r *Reaction) Delete() {
+	query := "DELETE FROM reaction WHERE" +
+		" channel_id=$1 AND receiver=$2 AND discord_message_id=$3 AND" +
+		" author_id=$4 AND discord_name=$5 AND discord_id=$6"
+
+	var discordName, discordID sql.NullString
+
+	if r.DiscordName != "" {
+		discordName = sql.NullString{r.DiscordName, true}
+	}
+
+	if r.DiscordID != "" {
+		discordID = sql.NullString{r.DiscordID, true}
+	}
+
+	_, err := r.db.Exec(
+		query,
+		r.Channel.ChannelID, r.Channel.Receiver,
+		r.DiscordMessageID, r.AuthorID,
+		discordName, discordID,
+	)
+
+	if err != nil {
+		r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.Channel, r.DiscordMessageID, err)
+	}
+}

+ 82 - 0
database/reactionquery.go

@@ -0,0 +1,82 @@
+package database
+
+import (
+	log "maunium.net/go/maulogger/v2"
+	"maunium.net/go/mautrix/id"
+)
+
+type ReactionQuery struct {
+	db  *Database
+	log log.Logger
+}
+
+const (
+	reactionSelect = "SELECT channel_id, receiver, discord_message_id," +
+		" matrix_event_id, author_id, matrix_name, matrix_url, " +
+		" discord_name, discord_id FROM reaction"
+)
+
+func (rq *ReactionQuery) New() *Reaction {
+	return &Reaction{
+		db:  rq.db,
+		log: rq.log,
+	}
+}
+
+func (rq *ReactionQuery) GetAllByDiscordID(key PortalKey, discordMessageID string) []*Reaction {
+	query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2 AND" +
+		" discord_message_id=$3"
+
+	return rq.getAll(query, key.ChannelID, key.Receiver, discordMessageID)
+}
+
+func (rq *ReactionQuery) GetAllByMatrixID(key PortalKey, matrixEventID id.EventID) []*Reaction {
+	query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2 AND" +
+		" matrix_event_id=$3"
+
+	return rq.getAll(query, key.ChannelID, key.Receiver, matrixEventID)
+}
+
+func (rq *ReactionQuery) getAll(query string, args ...interface{}) []*Reaction {
+	rows, err := rq.db.Query(query)
+	if err != nil || rows == nil {
+		return nil
+	}
+
+	reactions := []*Reaction{}
+	for rows.Next() {
+		reactions = append(reactions, rq.New().Scan(rows))
+	}
+
+	return reactions
+}
+
+func (rq *ReactionQuery) GetByDiscordName(key PortalKey, discordMessageID, discordName string) *Reaction {
+	query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2" +
+		" AND discord_message_id=$3 AND discord_name=$4"
+
+	return rq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordName)
+}
+
+func (rq *ReactionQuery) GetByDiscordID(key PortalKey, discordMessageID, discordID string) *Reaction {
+	query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2" +
+		" AND discord_message_id=$3 AND discord_id=$4"
+
+	return rq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordID)
+}
+
+func (rq *ReactionQuery) GetByMatrixName(key PortalKey, matrixEventID id.EventID, matrixName string) *Reaction {
+	query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2" +
+		" AND matrix_event_id=$3 AND matrix_name=$4"
+
+	return rq.get(query, key.ChannelID, key.Receiver, matrixEventID, matrixName)
+}
+
+func (rq *ReactionQuery) get(query string, args ...interface{}) *Reaction {
+	row := rq.db.QueryRow(query, args...)
+	if row == nil {
+		return nil
+	}
+
+	return rq.New().Scan(row)
+}