Explorar o código

Sync group DM participants on change

Tulir Asokan %!s(int64=2) %!d(string=hai) anos
pai
achega
3e1d1740f7
Modificáronse 4 ficheiros con 52 adicións e 3 borrados
  1. 1 1
      go.mod
  2. 2 2
      go.sum
  3. 32 0
      portal.go
  4. 17 0
      user.go

+ 1 - 1
go.mod

@@ -37,4 +37,4 @@ require (
 	maunium.net/go/mauflag v1.0.0 // indirect
 	maunium.net/go/mauflag v1.0.0 // indirect
 )
 )
 
 
-replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230425174737-526618ee92f8
+replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230426180404-ce66567c447b

+ 2 - 2
go.sum

@@ -1,6 +1,6 @@
 github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
 github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
-github.com/beeper/discordgo v0.0.0-20230425174737-526618ee92f8 h1:9GiqpUOVfcgn27okKcuTLgOJ2BucQlwpX0wrFj+c6WA=
-github.com/beeper/discordgo v0.0.0-20230425174737-526618ee92f8/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
+github.com/beeper/discordgo v0.0.0-20230426180404-ce66567c447b h1:Xk0iNigYnqfx4TpbW6X5qjeO9TTq7eTW9CJIi+YfC94=
+github.com/beeper/discordgo v0.0.0-20230426180404-ce66567c447b/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
 github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
 github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
 github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
 github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

+ 32 - 0
portal.go

@@ -898,6 +898,32 @@ func (portal *Portal) handleDiscordTyping(evt *discordgo.TypingStart) {
 	}
 	}
 }
 }
 
 
+func (portal *Portal) syncParticipant(source *User, participant *discordgo.User, remove bool) {
+	puppet := portal.bridge.GetPuppetByID(participant.ID)
+	puppet.UpdateInfo(source, participant)
+	log := portal.zlog.With().Str("participant_id", participant.ID).
+		Str("ghost_mxid", puppet.MXID.String()).
+		Str("room_id", portal.MXID.String()).
+		Logger()
+
+	user := portal.bridge.GetUserByID(participant.ID)
+	if user != nil {
+		log.Debug().Msg("Ensuring Matrix user is invited or joined to room")
+		portal.ensureUserInvited(user)
+	}
+
+	if remove {
+		_, err := puppet.DefaultIntent().LeaveRoom(portal.MXID)
+		if err != nil {
+			log.Warn().Err(err).Msg("Failed to make ghost leave room after member remove event")
+		}
+	} else if user == nil || !puppet.IntentFor(portal).IsCustomPuppet {
+		if err := puppet.IntentFor(portal).EnsureJoined(portal.MXID); err != nil {
+			log.Warn().Err(err).Msg("Failed to add ghost to room")
+		}
+	}
+}
+
 func (portal *Portal) syncParticipants(source *User, participants []*discordgo.User) {
 func (portal *Portal) syncParticipants(source *User, participants []*discordgo.User) {
 	for _, participant := range participants {
 	for _, participant := range participants {
 		puppet := portal.bridge.GetPuppetByID(participant.ID)
 		puppet := portal.bridge.GetPuppetByID(participant.ID)
@@ -2097,8 +2123,14 @@ func (portal *Portal) UpdateInfo(source *User, meta *discordgo.Channel) *discord
 				changed = portal.UpdateNameDirect(puppet.Name, false) || changed
 				changed = portal.UpdateNameDirect(puppet.Name, false) || changed
 			}
 			}
 		}
 		}
+		if portal.MXID != "" {
+			portal.syncParticipants(source, meta.Recipients)
+		}
 	case discordgo.ChannelTypeGroupDM:
 	case discordgo.ChannelTypeGroupDM:
 		changed = portal.UpdateGroupDMAvatar(meta.Icon) || changed
 		changed = portal.UpdateGroupDMAvatar(meta.Icon) || changed
+		if portal.MXID != "" {
+			portal.syncParticipants(source, meta.Recipients)
+		}
 		fallthrough
 		fallthrough
 	default:
 	default:
 		changed = portal.UpdateName(meta) || changed
 		changed = portal.UpdateName(meta) || changed

+ 17 - 0
user.go

@@ -595,6 +595,9 @@ func (user *User) Connect() error {
 	user.Session.AddHandler(user.channelPinsUpdateHandler)
 	user.Session.AddHandler(user.channelPinsUpdateHandler)
 	user.Session.AddHandler(user.channelUpdateHandler)
 	user.Session.AddHandler(user.channelUpdateHandler)
 
 
+	user.Session.AddHandler(user.channelRecipientAdd)
+	user.Session.AddHandler(user.channelRecipientRemove)
+
 	user.Session.AddHandler(user.relationshipAddHandler)
 	user.Session.AddHandler(user.relationshipAddHandler)
 	user.Session.AddHandler(user.relationshipRemoveHandler)
 	user.Session.AddHandler(user.relationshipRemoveHandler)
 	user.Session.AddHandler(user.relationshipUpdateHandler)
 	user.Session.AddHandler(user.relationshipUpdateHandler)
@@ -1071,6 +1074,20 @@ func (user *User) channelUpdateHandler(_ *discordgo.Session, c *discordgo.Channe
 	}
 	}
 }
 }
 
 
+func (user *User) channelRecipientAdd(_ *discordgo.Session, c *discordgo.ChannelRecipientAdd) {
+	portal := user.GetExistingPortalByID(c.ChannelID)
+	if portal != nil {
+		portal.syncParticipant(user, c.User, false)
+	}
+}
+
+func (user *User) channelRecipientRemove(_ *discordgo.Session, c *discordgo.ChannelRecipientRemove) {
+	portal := user.GetExistingPortalByID(c.ChannelID)
+	if portal != nil {
+		portal.syncParticipant(user, c.User, true)
+	}
+}
+
 func (user *User) findPortal(channelID string) (*Portal, *Thread) {
 func (user *User) findPortal(channelID string) (*Portal, *Thread) {
 	portal := user.GetExistingPortalByID(channelID)
 	portal := user.GetExistingPortalByID(channelID)
 	if portal != nil {
 	if portal != nil {