Parcourir la source

Load users from the database during startup

Gary Kramlich il y a 3 ans
Parent
commit
de1f524e25
5 fichiers modifiés avec 60 ajouts et 8 suppressions
  1. 2 0
      bridge/bridge.go
  2. 1 1
      bridge/commands.go
  3. 36 2
      bridge/user.go
  4. 5 5
      database/user.go
  5. 16 0
      database/userquery.go

+ 2 - 0
bridge/bridge.go

@@ -141,6 +141,8 @@ func (b *Bridge) Start() error {
 
 
 	go b.updateBotProfile()
 	go b.updateBotProfile()
 
 
+	go b.startUsers()
+
 	// Finally tell the appservice we're ready
 	// Finally tell the appservice we're ready
 	b.as.Ready = true
 	b.as.Ready = true
 
 

+ 1 - 1
bridge/commands.go

@@ -123,7 +123,7 @@ func (l *loginCmd) Run(g *globals) error {
 		return err
 		return err
 	}
 	}
 
 
-	if err := g.user.login(user.Token); err != nil {
+	if err := g.user.Login(user.Token); err != nil {
 		fmt.Println(g.context.Stdout, "failed to login", err)
 		fmt.Println(g.context.Stdout, "failed to login", err)
 
 
 		return err
 		return err

+ 36 - 2
bridge/user.go

@@ -74,6 +74,36 @@ func (b *Bridge) NewUser(dbUser *database.User) *User {
 	return user
 	return user
 }
 }
 
 
+func (b *Bridge) getAllUsers() []*User {
+	b.usersLock.Lock()
+	defer b.usersLock.Unlock()
+
+	dbUsers := b.db.User.GetAll()
+	users := make([]*User, len(dbUsers))
+
+	for idx, dbUser := range dbUsers {
+		user, ok := b.usersByMXID[dbUser.MXID]
+		if !ok {
+			user = b.loadUser(dbUser, nil)
+		}
+		users[idx] = user
+	}
+
+	return users
+}
+
+func (b *Bridge) startUsers() {
+	b.log.Debugln("Starting users")
+
+	for _, user := range b.getAllUsers() {
+		// if user.ID != "" {
+		// 	haveSessions = true
+		// }
+
+		go user.Connect()
+	}
+}
+
 func (u *User) SetManagementRoom(roomID id.RoomID) {
 func (u *User) SetManagementRoom(roomID id.RoomID) {
 	u.bridge.managementRoomsLock.Lock()
 	u.bridge.managementRoomsLock.Lock()
 	defer u.bridge.managementRoomsLock.Unlock()
 	defer u.bridge.managementRoomsLock.Unlock()
@@ -137,12 +167,16 @@ func (u *User) uploadQRCode(code string) (id.ContentURI, error) {
 	return resp.ContentURI, nil
 	return resp.ContentURI, nil
 }
 }
 
 
-func (u *User) login(token string) error {
-	err := u.User.Login(token)
+func (u *User) Login(token string) error {
+	err := u.User.NewSession(token)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
+	return u.Connect()
+}
+
+func (u *User) Connect() error {
 	u.User.Session.AddHandler(u.messageHandler)
 	u.User.Session.AddHandler(u.messageHandler)
 
 
 	u.log.Warnln("logged in, opening websocket")
 	u.log.Warnln("logged in, opening websocket")

+ 5 - 5
database/user.go

@@ -21,10 +21,10 @@ type User struct {
 	Session *discordgo.Session
 	Session *discordgo.Session
 }
 }
 
 
-// Login is just used to create the session and update the database and should
-// only be called by bridge.User.Login which will continue setting up event
-// handlers.
-func (u *User) Login(token string) error {
+// NewSession is just used to create the session and update the database. It
+// should only be called by bridge.User.Connect which will continue setting up
+// event handlers and everything else.
+func (u *User) NewSession(token string) error {
 	session, err := discordgo.New(token)
 	session, err := discordgo.New(token)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -50,7 +50,7 @@ func (u *User) Scan(row Scannable) *User {
 	}
 	}
 
 
 	if token.Valid {
 	if token.Valid {
-		if err := u.Login(token.String); err != nil {
+		if err := u.NewSession(token.String); err != nil {
 			u.log.Errorln("Failed to login: ", err)
 			u.log.Errorln("Failed to login: ", err)
 		}
 		}
 	}
 	}

+ 16 - 0
database/userquery.go

@@ -25,3 +25,19 @@ func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
 
 
 	return uq.New().Scan(row)
 	return uq.New().Scan(row)
 }
 }
+
+func (uq *UserQuery) GetAll() []*User {
+	rows, err := uq.db.Query("SELECT mxid, id, management_room, token FROM user")
+	if err != nil || rows == nil {
+		return nil
+	}
+
+	defer rows.Close()
+
+	users := []*User{}
+	for rows.Next() {
+		users = append(users, uq.New().Scan(rows))
+	}
+
+	return users
+}