Forráskód Böngészése

Add user-portal mapping to database

Tulir Asokan 6 éve
szülő
commit
dce08b1422
6 módosított fájl, 167 hozzáadás és 62 törlés
  1. 24 1
      database/portal.go
  2. 19 0
      database/upgrades/2019-05-28-user-portal-table.go
  3. 1 1
      database/upgrades/upgrades.go
  4. 48 0
      database/user.go
  5. 24 25
      portal.go
  6. 51 35
      user.go

+ 24 - 1
database/portal.go

@@ -42,7 +42,7 @@ func NewPortalKey(jid, receiver types.WhatsAppID) PortalKey {
 		receiver = jid
 	}
 	return PortalKey{
-		JID: jid,
+		JID:      jid,
 		Receiver: receiver,
 	}
 }
@@ -152,3 +152,26 @@ func (portal *Portal) Delete() {
 		portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err)
 	}
 }
+
+func (portal *Portal) GetUserIDs() []types.MatrixUserID {
+	rows, err := portal.db.Query(`SELECT "user".mxid FROM "user", user_portal
+		WHERE "user".jid=user_portal.user_jid
+			AND user_portal.portal_jid=$1
+			AND user_portal.portal_receiver=$2`,
+		portal.Key.JID, portal.Key.Receiver)
+	if err != nil {
+		portal.log.Debugln("Failed to get portal user ids:", err)
+		return nil
+	}
+	var userIDs []types.MatrixUserID
+	for rows.Next() {
+		var userID types.MatrixUserID
+		err = rows.Scan(&userID)
+		if err != nil {
+			portal.log.Warnln("Failed to scan row:", err)
+			continue
+		}
+		userIDs = append(userIDs, userID)
+	}
+	return userIDs
+}

+ 19 - 0
database/upgrades/2019-05-28-user-portal-table.go

@@ -0,0 +1,19 @@
+package upgrades
+
+import (
+	"database/sql"
+)
+
+func init() {
+	upgrades[6] = upgrade{"Add user-portal mapping table", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error {
+		_, err := tx.Exec(`CREATE TABLE user_portal (
+			user_jid        VARCHAR(255),
+			portal_jid      VARCHAR(255),
+			portal_receiver VARCHAR(255),
+			PRIMARY KEY (user_jid, portal_jid, portal_receiver),
+			FOREIGN KEY (user_jid)                    REFERENCES "user"(jid)           ON DELETE CASCADE,
+			FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
+		)`)
+		return err
+	}}
+}

+ 1 - 1
database/upgrades/upgrades.go

@@ -22,7 +22,7 @@ type upgrade struct {
 	fn upgradeFunc
 }
 
-const NumberOfUpgrades = 6
+const NumberOfUpgrades = 7
 
 var upgrades [NumberOfUpgrades]upgrade
 

+ 48 - 0
database/user.go

@@ -18,6 +18,7 @@ package database
 
 import (
 	"database/sql"
+	"fmt"
 	"strings"
 	"time"
 
@@ -165,3 +166,50 @@ func (user *User) Update() {
 		user.log.Warnfln("Failed to update %s: %v", user.MXID, err)
 	}
 }
+
+func (user *User) SetPortalKeys(newKeys []PortalKey) error {
+	tx, err := user.db.Begin()
+	if err != nil {
+		return err
+	}
+	_, err = tx.Exec("DELETE FROM user_portal WHERE user_jid=$1", user.jidPtr())
+	if err != nil {
+		_ = tx.Rollback()
+		return err
+	}
+	valueStrings := make([]string, len(newKeys))
+	values := make([]interface{}, len(newKeys)*3)
+	for i, key := range newKeys {
+		valueStrings[i] = fmt.Sprintf("($%d, $%d, $%d)", i*3+1, i*3+2, i*3+3)
+		values[i*3] = user.jidPtr()
+		values[i*3+1] = key.JID
+		values[i*3+2] = key.Receiver
+	}
+	query := fmt.Sprintf("INSERT INTO user_portal (user_jid, portal_jid, portal_receiver) VALUES %s",
+		strings.Join(valueStrings, ", "))
+	_, err = tx.Exec(query, values...)
+	if err != nil {
+		_ = tx.Rollback()
+		return err
+	}
+	return tx.Commit()
+}
+
+func (user *User) GetPortalKeys() []PortalKey {
+	rows, err := user.db.Query(`SELECT portal_jid, portal_receiver FROM user_portal WHERE user_jid=$1`, user.jidPtr())
+	if err != nil {
+		user.log.Warnln("Failed to get user portal keys:", err)
+		return nil
+	}
+	var keys []PortalKey
+	for rows.Next() {
+		var key PortalKey
+		err = rows.Scan(&key.JID, &key.Receiver)
+		if err != nil {
+			user.log.Warnln("Failed to scan row:", err)
+			continue
+		}
+		keys = append(keys, key)
+	}
+	return keys
+}

+ 24 - 25
portal.go

@@ -50,15 +50,7 @@ func (bridge *Bridge) GetPortalByMXID(mxid types.MatrixRoomID) *Portal {
 	defer bridge.portalsLock.Unlock()
 	portal, ok := bridge.portalsByMXID[mxid]
 	if !ok {
-		dbPortal := bridge.DB.Portal.GetByMXID(mxid)
-		if dbPortal == nil {
-			return nil
-		}
-		portal = bridge.NewPortal(dbPortal)
-		bridge.portalsByJID[portal.Key] = portal
-		if len(portal.MXID) > 0 {
-			bridge.portalsByMXID[portal.MXID] = portal
-		}
+		return bridge.loadDBPortal(bridge.DB.Portal.GetByMXID(mxid), nil)
 	}
 	return portal
 }
@@ -68,17 +60,7 @@ func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal {
 	defer bridge.portalsLock.Unlock()
 	portal, ok := bridge.portalsByJID[key]
 	if !ok {
-		dbPortal := bridge.DB.Portal.GetByJID(key)
-		if dbPortal == nil {
-			dbPortal = bridge.DB.Portal.New()
-			dbPortal.Key = key
-			dbPortal.Insert()
-		}
-		portal = bridge.NewPortal(dbPortal)
-		bridge.portalsByJID[portal.Key] = portal
-		if len(portal.MXID) > 0 {
-			bridge.portalsByMXID[portal.MXID] = portal
-		}
+		return bridge.loadDBPortal(bridge.DB.Portal.GetByJID(key), &key)
 	}
 	return portal
 }
@@ -91,17 +73,34 @@ func (bridge *Bridge) GetAllPortals() []*Portal {
 	for index, dbPortal := range dbPortals {
 		portal, ok := bridge.portalsByJID[dbPortal.Key]
 		if !ok {
-			portal = bridge.NewPortal(dbPortal)
-			bridge.portalsByJID[portal.Key] = portal
-			if len(dbPortal.MXID) > 0 {
-				bridge.portalsByMXID[dbPortal.MXID] = portal
-			}
+			portal = bridge.loadDBPortal(dbPortal, nil)
 		}
 		output[index] = portal
 	}
 	return output
 }
 
+func (bridge *Bridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal {
+	if dbPortal == nil {
+		if key == nil {
+			return nil
+		}
+		dbPortal = bridge.DB.Portal.New()
+		dbPortal.Key = *key
+		dbPortal.Insert()
+	}
+	portal := bridge.NewPortal(dbPortal)
+	bridge.portalsByJID[portal.Key] = portal
+	if len(portal.MXID) > 0 {
+		bridge.portalsByMXID[portal.MXID] = portal
+	}
+	return portal
+}
+
+func (portal *Portal) GetUsers() []*User {
+	return nil
+}
+
 func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
 	portal := &Portal{
 		Portal: dbPortal,

+ 51 - 35
user.go

@@ -66,20 +66,7 @@ func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User {
 	defer bridge.usersLock.Unlock()
 	user, ok := bridge.usersByMXID[userID]
 	if !ok {
-		dbUser := bridge.DB.User.GetByMXID(userID)
-		if dbUser == nil {
-			dbUser = bridge.DB.User.New()
-			dbUser.MXID = userID
-			dbUser.Insert()
-		}
-		user = bridge.NewUser(dbUser)
-		bridge.usersByMXID[user.MXID] = user
-		if len(user.JID) > 0 {
-			bridge.usersByJID[user.JID] = user
-		}
-		if len(user.ManagementRoom) > 0 {
-			bridge.managementRooms[user.ManagementRoom] = user
-		}
+		return bridge.loadDBUser(bridge.DB.User.GetByMXID(userID), &userID)
 	}
 	return user
 }
@@ -89,16 +76,7 @@ func (bridge *Bridge) GetUserByJID(userID types.WhatsAppID) *User {
 	defer bridge.usersLock.Unlock()
 	user, ok := bridge.usersByJID[userID]
 	if !ok {
-		dbUser := bridge.DB.User.GetByJID(userID)
-		if dbUser == nil {
-			return nil
-		}
-		user = bridge.NewUser(dbUser)
-		bridge.usersByMXID[user.MXID] = user
-		bridge.usersByJID[user.JID] = user
-		if len(user.ManagementRoom) > 0 {
-			bridge.managementRooms[user.ManagementRoom] = user
-		}
+		return bridge.loadDBUser(bridge.DB.User.GetByJID(userID), nil)
 	}
 	return user
 }
@@ -111,20 +89,50 @@ func (bridge *Bridge) GetAllUsers() []*User {
 	for index, dbUser := range dbUsers {
 		user, ok := bridge.usersByMXID[dbUser.MXID]
 		if !ok {
-			user = bridge.NewUser(dbUser)
-			bridge.usersByMXID[user.MXID] = user
-			if len(user.JID) > 0 {
-				bridge.usersByJID[user.JID] = user
-			}
-			if len(user.ManagementRoom) > 0 {
-				bridge.managementRooms[user.ManagementRoom] = user
-			}
+			user = bridge.loadDBUser(dbUser, nil)
 		}
 		output[index] = user
 	}
 	return output
 }
 
+func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *types.MatrixUserID) *User {
+	if dbUser == nil {
+		if mxid == nil {
+			return nil
+		}
+		dbUser = bridge.DB.User.New()
+		dbUser.MXID = *mxid
+		dbUser.Insert()
+	}
+	user := bridge.NewUser(dbUser)
+	bridge.usersByMXID[user.MXID] = user
+	if len(user.JID) > 0 {
+		bridge.usersByJID[user.JID] = user
+	}
+	if len(user.ManagementRoom) > 0 {
+		bridge.managementRooms[user.ManagementRoom] = user
+	}
+	return user
+}
+
+func (user *User) GetPortals() []*Portal {
+	keys := user.User.GetPortalKeys()
+	portals := make([]*Portal, len(keys))
+
+	user.bridge.portalsLock.Lock()
+	defer user.bridge.portalsLock.Unlock()
+
+	for i, key := range keys {
+		portal, ok := user.bridge.portalsByJID[key]
+		if !ok {
+			portal = user.bridge.loadDBPortal(user.bridge.DB.Portal.GetByJID(key), &key)
+		}
+		portals[i] = portal
+	}
+	return portals
+}
+
 func (bridge *Bridge) NewUser(dbUser *database.User) *User {
 	user := &User{
 		User:   dbUser,
@@ -295,18 +303,26 @@ func (user *User) PostLogin() {
 }
 
 func (user *User) syncPortals(createAll bool) {
-	var chats ChatList
+	chats := make(ChatList, 0, len(user.Conn.Store.Chats))
+	portalKeys := make([]database.PortalKey, 0, len(user.Conn.Store.Chats))
 	for _, chat := range user.Conn.Store.Chats {
 		ts, err := strconv.ParseUint(chat.LastMessageTime, 10, 64)
 		if err != nil {
 			user.log.Warnfln("Non-integer last message time in %s: %s", chat.Jid, chat.LastMessageTime)
 			continue
 		}
+		portal := user.GetPortalByJID(chat.Jid)
+
 		chats = append(chats, Chat{
-			Portal:          user.GetPortalByJID(chat.Jid),
+			Portal:          portal,
 			Contact:         user.Conn.Store.Contacts[chat.Jid],
 			LastMessageTime: ts,
 		})
+		portalKeys = append(portalKeys, portal.Key)
+	}
+	err := user.SetPortalKeys(portalKeys)
+	if err != nil {
+		user.log.Warnln("Failed to update user-portal mapping:", err)
 	}
 	sort.Sort(chats)
 	limit := user.bridge.Config.Bridge.InitialChatSync
@@ -315,7 +331,7 @@ func (user *User) syncPortals(createAll bool) {
 	}
 	now := uint64(time.Now().Unix())
 	for i, chat := range chats {
-		if chat.LastMessageTime + user.bridge.Config.Bridge.SyncChatMaxAge < now {
+		if chat.LastMessageTime+user.bridge.Config.Bridge.SyncChatMaxAge < now {
 			break
 		}
 		create := (chat.LastMessageTime >= user.LastConnection && user.LastConnection > 0) || i < limit