Pārlūkot izejas kodu

Add option for autojoining threads when opened

Tulir Asokan 2 gadi atpakaļ
vecāks
revīzija
f268ddd132

+ 1 - 0
config/bridge.go

@@ -41,6 +41,7 @@ type BridgeConfig struct {
 	MessageStatusEvents         bool `yaml:"message_status_events"`
 	MessageErrorNotices         bool `yaml:"message_error_notices"`
 	RestrictedRooms             bool `yaml:"restricted_rooms"`
+	AutojoinThreadOnOpen        bool `yaml:"autojoin_thread_on_open"`
 	SyncDirectChatList          bool `yaml:"sync_direct_chat_list"`
 	ResendBridgeInfo            bool `yaml:"resend_bridge_info"`
 	DeletePortalOnChannelDelete bool `yaml:"delete_portal_on_channel_delete"`

+ 1 - 0
config/upgrade.go

@@ -36,6 +36,7 @@ func DoUpgrade(helper *up.Helper) {
 	helper.Copy(up.Bool, "bridge", "message_status_events")
 	helper.Copy(up.Bool, "bridge", "message_error_notices")
 	helper.Copy(up.Bool, "bridge", "restricted_rooms")
+	helper.Copy(up.Bool, "bridge", "autojoin_thread_on_open")
 	helper.Copy(up.Bool, "bridge", "sync_direct_chat_list")
 	helper.Copy(up.Bool, "bridge", "resend_bridge_info")
 	helper.Copy(up.Bool, "bridge", "delete_portal_on_channel_delete")

+ 5 - 7
database/message.go

@@ -60,9 +60,9 @@ func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Mes
 	return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
 }
 
-func (mq *MessageQuery) GetClosestBefore(key PortalKey, ts time.Time) *Message {
-	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND timestamp<=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
-	return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, ts.UnixMilli()))
+func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time.Time) *Message {
+	query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND timestamp<=$4 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
+	return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID, ts.UnixMilli()))
 }
 
 func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
@@ -115,9 +115,8 @@ func (m *Message) DiscordProtoChannelID() string {
 
 func (m *Message) Scan(row dbutil.Scannable) *Message {
 	var ts int64
-	var threadID sql.NullString
 
-	err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID)
+	err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.ThreadID, &m.MXID)
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 			m.log.Errorln("Database scan failed:", err)
@@ -130,7 +129,6 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
 	if ts != 0 {
 		m.Timestamp = time.UnixMilli(ts)
 	}
-	m.ThreadID = threadID.String
 
 	return m
 }
@@ -181,7 +179,7 @@ func (m *Message) MassInsert(msgs []MessagePart) {
 func (m *Message) Insert() {
 	_, err := m.db.Exec(messageInsertQuery,
 		m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
-		m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID)
+		m.Timestamp.UnixMilli(), m.ThreadID, m.MXID)
 
 	if err != nil {
 		m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)

+ 26 - 19
database/thread.go

@@ -16,7 +16,7 @@ type ThreadQuery struct {
 }
 
 const (
-	threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid FROM thread"
+	threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid FROM thread"
 )
 
 func (tq *ThreadQuery) New() *Thread {
@@ -37,17 +37,6 @@ func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread {
 	return tq.New().Scan(row)
 }
 
-//func (tq *ThreadQuery) GetByDiscordRootMsg(channelID, messageID string) *Thread {
-//	query := messageSelect + " WHERE parent_chan_id=$1 AND root_msg_dcid=$2"
-//
-//	row := tq.db.QueryRow(query, channelID, messageID)
-//	if row == nil {
-//		return nil
-//	}
-//
-//	return tq.New().Scan(row)
-//}
-
 func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
 	query := threadSelect + " WHERE root_msg_mxid=$1"
 
@@ -59,6 +48,17 @@ func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
 	return tq.New().Scan(row)
 }
 
+func (tq *ThreadQuery) GetByMatrixRootOrCreationNoticeMsg(mxid id.EventID) *Thread {
+	query := threadSelect + " WHERE root_msg_mxid=$1 OR creation_notice_mxid=$1"
+
+	row := tq.db.QueryRow(query, mxid)
+	if row == nil {
+		return nil
+	}
+
+	return tq.New().Scan(row)
+}
+
 type Thread struct {
 	db  *Database
 	log log.Logger
@@ -68,10 +68,12 @@ type Thread struct {
 
 	RootDiscordID string
 	RootMXID      id.EventID
+
+	CreationNoticeMXID id.EventID
 }
 
 func (t *Thread) Scan(row dbutil.Scannable) *Thread {
-	err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID)
+	err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID, &t.CreationNoticeMXID)
 	if err != nil {
 		if !errors.Is(err, sql.ErrNoRows) {
 			t.log.Errorln("Database scan failed:", err)
@@ -83,21 +85,26 @@ func (t *Thread) Scan(row dbutil.Scannable) *Thread {
 }
 
 func (t *Thread) Insert() {
-	query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid) VALUES ($1, $2, $3, $4)"
-
-	_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID)
-
+	query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid) VALUES ($1, $2, $3, $4, $5)"
+	_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID, t.CreationNoticeMXID)
 	if err != nil {
 		t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err)
 		panic(err)
 	}
 }
 
+func (t *Thread) Update() {
+	query := "UPDATE thread SET creation_notice_mxid=$2 WHERE dcid=$1"
+	_, err := t.db.Exec(query, t.ID, t.CreationNoticeMXID)
+	if err != nil {
+		t.log.Warnfln("Failed to update %s@%s: %v", t.ID, t.ParentID, err)
+		panic(err)
+	}
+}
+
 func (t *Thread) Delete() {
 	query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2"
-
 	_, err := t.db.Exec(query, t.ID, t.ParentID)
-
 	if err != nil {
 		t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err)
 		panic(err)

+ 6 - 5
database/upgrades/00-latest-revision.sql

@@ -1,4 +1,4 @@
--- v0 -> v8: Latest revision
+-- v0 -> v9: Latest revision
 
 CREATE TABLE guild (
     dcid       TEXT PRIMARY KEY,
@@ -49,6 +49,7 @@ CREATE TABLE thread (
     parent_chan_id TEXT NOT NULL,
     root_msg_dcid  TEXT NOT NULL,
     root_msg_mxid  TEXT NOT NULL,
+    creation_notice_mxid TEXT NOT NULL,
     -- This is also not accessed by the bridge.
     receiver   TEXT NOT NULL DEFAULT '',
 
@@ -98,9 +99,9 @@ CREATE TABLE message (
     dc_edit_index    INTEGER,
     dc_chan_id       TEXT,
     dc_chan_receiver TEXT,
-    dc_sender        TEXT NOT NULL,
+    dc_sender        TEXT   NOT NULL,
     timestamp        BIGINT NOT NULL,
-    dc_thread_id     TEXT,
+    dc_thread_id     TEXT   NOT NULL,
 
     mxid TEXT NOT NULL UNIQUE,
 
@@ -114,9 +115,9 @@ CREATE TABLE reaction (
     dc_msg_id        TEXT,
     dc_sender        TEXT,
     dc_emoji_name    TEXT,
-    dc_thread_id     TEXT,
+    dc_thread_id     TEXT NOT NULL,
 
-    dc_first_attachment_id TEXT NOT NULL,
+    dc_first_attachment_id TEXT    NOT NULL,
     _dc_first_edit_index   INTEGER NOT NULL DEFAULT 0,
 
     mxid TEXT NOT NULL UNIQUE,

+ 9 - 0
database/upgrades/09-more-thread-data.sql

@@ -0,0 +1,9 @@
+-- v9: Store more info for proper thread support
+ALTER TABLE thread ADD COLUMN creation_notice_mxid TEXT NOT NULL DEFAULT '';
+UPDATE message SET dc_thread_id='' WHERE dc_thread_id IS NULL;
+UPDATE reaction SET dc_thread_id='' WHERE dc_thread_id IS NULL;
+
+-- only: postgres for next 3 lines
+ALTER TABLE thread ALTER COLUMN creation_notice_mxid DROP DEFAULT;
+ALTER TABLE message ALTER COLUMN dc_thread_id SET NOT NULL;
+ALTER TABLE reaction ALTER COLUMN dc_thread_id SET NOT NULL;

+ 15 - 4
database/userportal.go

@@ -10,8 +10,9 @@ import (
 )
 
 const (
-	UserPortalTypeDM    = "dm"
-	UserPortalTypeGuild = "guild"
+	UserPortalTypeDM     = "dm"
+	UserPortalTypeGuild  = "guild"
+	UserPortalTypeThread = "thread"
 )
 
 type UserPortal struct {
@@ -62,6 +63,16 @@ func (u *User) IsInSpace(discordID string) (isIn bool) {
 	return
 }
 
+func (u *User) IsInPortal(discordID string) (isIn bool) {
+	query := `SELECT EXISTS(SELECT 1 FROM user_portal WHERE user_mxid=$1 AND discord_id=$2)`
+	err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn)
+	if err != nil && !errors.Is(err, sql.ErrNoRows) {
+		u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err)
+		panic(err)
+	}
+	return
+}
+
 func (u *User) MarkInPortal(portal UserPortal) {
 	query := `
 		INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space)
@@ -88,8 +99,8 @@ func (u *User) MarkNotInPortal(discordID string) {
 func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal {
 	query := `
 		DELETE FROM user_portal
-			WHERE user_mxid=$1 AND timestamp<$2
-			RETURNING discord_id, type, timestamp, in_space
+		WHERE user_mxid=$1 AND timestamp<$2 AND type IN ('dm', 'guild')
+		RETURNING discord_id, type, timestamp, in_space
 	`
 	rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli())
 	if err != nil {

+ 3 - 0
example-config.yaml

@@ -106,6 +106,9 @@ bridge:
     # Should the bridge use space-restricted join rules instead of invite-only for guild rooms?
     # This can avoid unnecessary invite events in guild rooms when members are synced in.
     restricted_rooms: true
+    # Should the bridge automatically join the user to threads on Discord when the thread is opened on Matrix?
+    # This only works with clients that support thread read receipts (MSC3771 added in Matrix v1.4).
+    autojoin_thread_on_open: true
     # Should the bridge update the m.direct account data event when double puppeting is enabled.
     # Note that updating the m.direct event is not atomic (except with mautrix-asmux)
     # and is therefore prone to race conditions.

+ 6 - 6
go.mod

@@ -6,13 +6,13 @@ require (
 	github.com/bwmarrin/discordgo v0.26.1
 	github.com/gorilla/mux v1.8.0
 	github.com/gorilla/websocket v1.5.0
-	github.com/lib/pq v1.10.6
+	github.com/lib/pq v1.10.7
 	github.com/mattn/go-sqlite3 v1.14.15
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/stretchr/testify v1.8.0
-	github.com/yuin/goldmark v1.4.13
+	github.com/yuin/goldmark v1.5.2
 	maunium.net/go/maulogger/v2 v2.3.2
-	maunium.net/go/mautrix v0.12.2-0.20220919211529-34e3589a5d5e
+	maunium.net/go/mautrix v0.12.2-0.20221016082058-f7d28cc18df8
 )
 
 require (
@@ -20,13 +20,13 @@ require (
 	github.com/mattn/go-colorable v0.1.12 // indirect
 	github.com/mattn/go-isatty v0.0.14 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/rs/zerolog v1.27.0 // indirect
+	github.com/rs/zerolog v1.28.0 // indirect
 	github.com/tidwall/gjson v1.14.3 // indirect
 	github.com/tidwall/match v1.1.1 // indirect
 	github.com/tidwall/pretty v1.2.0 // indirect
 	github.com/tidwall/sjson v1.2.5 // indirect
-	golang.org/x/crypto v0.0.0-20220817201139-bc19a97f63c8 // indirect
-	golang.org/x/net v0.0.0-20220812174116-3211cb980234 // indirect
+	golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect
+	golang.org/x/net v0.0.0-20221014081412-f15817d10f9b // indirect
 	golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 	maunium.net/go/mauflag v1.0.0 // indirect

+ 13 - 13
go.sum

@@ -11,8 +11,8 @@ github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB7
 github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
 github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs=
-github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw=
+github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
 github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
 github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
 github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
@@ -22,9 +22,9 @@ github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
-github.com/rs/zerolog v1.27.0 h1:1T7qCieN22GVc8S4Q2yuexzBb1EqjbgjSH9RohbMjKs=
-github.com/rs/zerolog v1.27.0/go.mod h1:7frBqO0oezxmnO7GF86FY++uy8I0Tk/If5ni1G9Qc0U=
+github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
+github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
+github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -41,14 +41,14 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
 github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
 github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
 github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
-github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
-github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
+github.com/yuin/goldmark v1.5.2 h1:ALmeCk/px5FSm1MAcFBAsVKZjDuMVj8Tm7FFIlMJnqU=
+github.com/yuin/goldmark v1.5.2/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
 golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.0.0-20220817201139-bc19a97f63c8 h1:GIAS/yBem/gq2MUqgNIzUHW7cJMmx3TGZOrnyYaNQ6c=
-golang.org/x/crypto v0.0.0-20220817201139-bc19a97f63c8/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
+golang.org/x/crypto v0.0.0-20221012134737-56aed061732a h1:NmSIgad6KjE6VvHciPZuNRTKxGhlPfD6OA87W/PLkqg=
+golang.org/x/crypto v0.0.0-20221012134737-56aed061732a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20220812174116-3211cb980234 h1:RDqmgfe7SvlMWoqC3xwQ2blLO3fcWcxMa3eBLRdRW7E=
-golang.org/x/net v0.0.0-20220812174116-3211cb980234/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
+golang.org/x/net v0.0.0-20221014081412-f15817d10f9b h1:tvrvnPFcdzp294diPnrdZZZ8XUt2Tyj7svb7X52iDuU=
+golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -66,5 +66,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
 maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
 maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
 maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
-maunium.net/go/mautrix v0.12.2-0.20220919211529-34e3589a5d5e h1:NkB/p2VTBiF+kis+JhOYKKyAgE0KH7kN9744WETU9aA=
-maunium.net/go/mautrix v0.12.2-0.20220919211529-34e3589a5d5e/go.mod h1:/jxQFIipObSsjZPH6o3xyUi8uoULz3Hfr/8p9loqpYE=
+maunium.net/go/mautrix v0.12.2-0.20221016082058-f7d28cc18df8 h1:OsKa24dXPnBgT5lr5ResRbTO9f+bZpZe/K5ioO2+1a8=
+maunium.net/go/mautrix v0.12.2-0.20221016082058-f7d28cc18df8/go.mod h1:bCw45Qx/m9qsz7eazmbe7Rzq5ZbTPzwRE1UgX2S9DXs=

+ 7 - 5
main.go

@@ -59,9 +59,10 @@ type DiscordBridge struct {
 	portalsByID   map[database.PortalKey]*Portal
 	portalsLock   sync.Mutex
 
-	threadsByID       map[string]*Thread
-	threadsByRootMXID map[id.EventID]*Thread
-	threadsLock       sync.Mutex
+	threadsByID                 map[string]*Thread
+	threadsByRootMXID           map[id.EventID]*Thread
+	threadsByCreationNoticeMXID map[id.EventID]*Thread
+	threadsLock                 sync.Mutex
 
 	guildsByMXID map[id.RoomID]*Guild
 	guildsByID   map[string]*Guild
@@ -153,8 +154,9 @@ func main() {
 		portalsByMXID: make(map[id.RoomID]*Portal),
 		portalsByID:   make(map[database.PortalKey]*Portal),
 
-		threadsByID:       make(map[string]*Thread),
-		threadsByRootMXID: make(map[id.EventID]*Thread),
+		threadsByID:                 make(map[string]*Thread),
+		threadsByRootMXID:           make(map[id.EventID]*Thread),
+		threadsByCreationNoticeMXID: make(map[id.EventID]*Thread),
 
 		guildsByID:   make(map[string]*Guild),
 		guildsByMXID: make(map[id.RoomID]*Guild),

+ 88 - 11
portal.go

@@ -59,6 +59,14 @@ type Portal struct {
 	currentlyTypingLock sync.Mutex
 }
 
+var _ bridge.Portal = (*Portal)(nil)
+var _ bridge.ReadReceiptHandlingPortal = (*Portal)(nil)
+var _ bridge.MembershipHandlingPortal = (*Portal)(nil)
+var _ bridge.TypingPortal = (*Portal)(nil)
+
+//var _ bridge.MetaHandlingPortal = (*Portal)(nil)
+//var _ bridge.DisappearingPortal = (*Portal)(nil)
+
 func (portal *Portal) IsEncrypted() bool {
 	return portal.Encrypted
 }
@@ -74,8 +82,6 @@ func (portal *Portal) ReceiveMatrixEvent(user bridge.User, evt *event.Event) {
 	}
 }
 
-var _ bridge.Portal = (*Portal)(nil)
-
 var (
 	portalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType}
 )
@@ -714,6 +720,48 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
 	}
 }
 
+const JoinThreadReaction = "join thread"
+
+func (portal *Portal) sendThreadCreationNotice(thread *Thread) {
+	thread.creationNoticeLock.Lock()
+	defer thread.creationNoticeLock.Unlock()
+	if thread.CreationNoticeMXID != "" {
+		return
+	}
+	creationNotice := "Thread created. React to this message with \"join thread\" to join the thread on Discord."
+	if portal.bridge.Config.Bridge.AutojoinThreadOnOpen {
+		creationNotice = "Thread created. Opening this thread will auto-join you to it on Discord."
+	}
+	resp, err := portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
+		Body:      creationNotice,
+		MsgType:   event.MsgNotice,
+		RelatesTo: (&event.RelatesTo{}).SetThread(thread.RootMXID, thread.RootMXID),
+	}, nil, time.Now().UnixMilli())
+	if err != nil {
+		portal.log.Errorfln("Failed to send thread creation notice: %v", err)
+		return
+	}
+	portal.bridge.threadsLock.Lock()
+	thread.CreationNoticeMXID = resp.EventID
+	portal.bridge.threadsByCreationNoticeMXID[resp.EventID] = thread
+	portal.bridge.threadsLock.Unlock()
+	thread.Update()
+	portal.log.Debugfln("Sent notice %s about thread for %s being created", thread.CreationNoticeMXID, thread.ID)
+
+	resp, err = portal.MainIntent().SendMessageEvent(portal.MXID, event.EventReaction, &event.ReactionEventContent{
+		RelatesTo: event.RelatesTo{
+			Type:    event.RelAnnotation,
+			EventID: thread.CreationNoticeMXID,
+			Key:     JoinThreadReaction,
+		},
+	})
+	if err != nil {
+		portal.log.Errorfln("Failed to send prefilled reaction to thread creation notice: %v", err)
+	} else {
+		portal.log.Debugfln("Sent prefilled reaction %s to thread creation notice %s", resp.EventID, thread.CreationNoticeMXID)
+	}
+}
+
 func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Message) {
 	if portal.MXID == "" {
 		portal.log.Warnln("handle message called without a valid portal")
@@ -728,13 +776,11 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
 	}
 
 	if msg.Flags == discordgo.MessageFlagsHasThread {
-		portal.bridge.GetThreadByID(msg.ID, existing[0])
+		thread := portal.bridge.GetThreadByID(msg.ID, existing[0])
 		portal.log.Debugfln("Marked %s as a thread root", msg.ID)
-		// TODO make autojoining configurable
-		//err := user.Session.ThreadJoinWithLocation(msg.ID, discordgo.ThreadJoinLocationContextMenu)
-		//if err != nil {
-		//	user.log.Warnfln("Error autojoining thread %s@%s: %v", msg.ChannelID, portal.Key.ChannelID, err)
-		//}
+		if thread.CreationNoticeMXID == "" {
+			portal.sendThreadCreationNotice(thread)
+		}
 	}
 
 	// There's a few scenarios where the author is nil but I haven't figured
@@ -1322,6 +1368,16 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) {
 		return
 	}
 
+	if reaction.RelatesTo.Key == JoinThreadReaction {
+		thread := portal.bridge.GetThreadByRootOrCreationNoticeMXID(reaction.RelatesTo.EventID)
+		if thread == nil {
+			go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring thread join")
+			return
+		}
+		thread.Join(sender)
+		return
+	}
+
 	msg := portal.bridge.DB.Message.GetByMXID(portal.Key, reaction.RelatesTo.EventID)
 	if msg == nil {
 		go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring")
@@ -1479,14 +1535,31 @@ func (portal *Portal) handleMatrixRedaction(sender *User, evt *event.Event) {
 	go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring")
 }
 
-func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.EventID, receiptTimestamp time.Time) {
+func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.EventID, receipt event.ReadReceipt) {
 	sender := brUser.(*User)
 	if sender.Session == nil {
 		return
 	}
+	var thread *Thread
+	discordThreadID := ""
+	if receipt.ThreadID != "" && receipt.ThreadID != event.ReadReceiptThreadMain {
+		thread = portal.bridge.GetThreadByRootMXID(receipt.ThreadID)
+		if thread != nil {
+			discordThreadID = thread.ID
+		}
+	}
+	if thread != nil {
+		if portal.bridge.Config.Bridge.AutojoinThreadOnOpen {
+			thread.Join(sender)
+		}
+		if eventID == thread.CreationNoticeMXID {
+			portal.log.Debugfln("Dropping Matrix read receipt from %s for thread creation notice %s of %s", sender.MXID, thread.CreationNoticeMXID, thread.ID)
+			return
+		}
+	}
 	msg := portal.bridge.DB.Message.GetByMXID(portal.Key, eventID)
 	if msg == nil {
-		msg = portal.bridge.DB.Message.GetClosestBefore(portal.Key, receiptTimestamp)
+		msg = portal.bridge.DB.Message.GetClosestBefore(portal.Key, discordThreadID, receipt.Timestamp)
 		if msg == nil {
 			portal.log.Debugfln("Dropping Matrix read receipt from %s for %s: no messages found", sender.MXID, eventID)
 			return
@@ -1494,13 +1567,17 @@ func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.Eve
 			portal.log.Debugfln("Matrix read receipt target %s from %s not found, using closest message %s", eventID, sender.MXID, msg.MXID)
 		}
 	}
+	if receipt.ThreadID != "" && msg.ThreadID != discordThreadID {
+		portal.log.Debugfln("Dropping Matrix read receipt from %s for %s in unexpected thread (receipt: %s, message: %s)", receipt.ThreadID, msg.ThreadID)
+		return
+	}
 	resp, err := sender.Session.ChannelMessageAckNoToken(msg.DiscordProtoChannelID(), msg.DiscordID)
 	if err != nil {
 		portal.log.Warnfln("Failed to handle read receipt for %s/%s from %s: %v", msg.MXID, msg.DiscordID, sender.MXID, err)
 	} else if resp.Token != nil {
 		portal.log.Debugfln("Marked %s/%s as read by %s (and got unexpected non-nil token %s)", msg.MXID, msg.DiscordID, sender.MXID, *resp.Token)
 	} else {
-		portal.log.Debugfln("Marked %s/%s as read by %s", msg.MXID, msg.DiscordID, sender.MXID)
+		portal.log.Debugfln("Marked %s/%s in %s as read by %s", msg.MXID, msg.DiscordID, msg.DiscordProtoChannelID(), sender.MXID)
 	}
 }
 

+ 39 - 0
thread.go

@@ -1,6 +1,10 @@
 package main
 
 import (
+	"sync"
+	"time"
+
+	"github.com/bwmarrin/discordgo"
 	"maunium.net/go/mautrix/id"
 
 	"go.mau.fi/mautrix-discord/database"
@@ -9,6 +13,8 @@ import (
 type Thread struct {
 	*database.Thread
 	Parent *Portal
+
+	creationNoticeLock sync.Mutex
 }
 
 func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread {
@@ -31,6 +37,19 @@ func (br *DiscordBridge) GetThreadByRootMXID(mxid id.EventID) *Thread {
 	return thread
 }
 
+func (br *DiscordBridge) GetThreadByRootOrCreationNoticeMXID(mxid id.EventID) *Thread {
+	br.threadsLock.Lock()
+	defer br.threadsLock.Unlock()
+	thread, ok := br.threadsByRootMXID[mxid]
+	if !ok {
+		thread, ok = br.threadsByCreationNoticeMXID[mxid]
+		if !ok {
+			return br.loadThread(br.DB.Thread.GetByMatrixRootOrCreationNoticeMsg(mxid), "", nil)
+		}
+	}
+	return thread
+}
+
 func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread {
 	if dbThread == nil {
 		if root == nil {
@@ -49,5 +68,25 @@ func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *
 	thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, ""))
 	br.threadsByID[thread.ID] = thread
 	br.threadsByRootMXID[thread.RootMXID] = thread
+	if thread.CreationNoticeMXID != "" {
+		br.threadsByCreationNoticeMXID[thread.CreationNoticeMXID] = thread
+	}
 	return thread
 }
+
+func (thread *Thread) Join(user *User) {
+	if user.IsInPortal(thread.ID) {
+		return
+	}
+	user.log.Debugfln("Joining thread %s@%s", thread.ID, thread.ParentID)
+	err := user.Session.ThreadJoinWithLocation(thread.ID, discordgo.ThreadJoinLocationContextMenu)
+	if err != nil {
+		user.log.Errorfln("Error joining thread %s@%s: %v", thread.ID, thread.ParentID, err)
+	} else {
+		user.MarkInPortal(database.UserPortal{
+			DiscordID: thread.ID,
+			Type:      database.UserPortalTypeThread,
+			Timestamp: time.Now(),
+		})
+	}
+}