user.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. package database
import (
	"database/sql"
	"errors"

	log "maunium.net/go/maulogger/v2"
	"maunium.net/go/mautrix/id"
	"maunium.net/go/mautrix/util/dbutil"
)
  8. type UserQuery struct {
  9. db *Database
  10. log log.Logger
  11. }
  12. func (uq *UserQuery) New() *User {
  13. return &User{
  14. db: uq.db,
  15. log: uq.log,
  16. }
  17. }
  18. func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
  19. query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE mxid=$1`
  20. row := uq.db.QueryRow(query, userID)
  21. if row == nil {
  22. return nil
  23. }
  24. return uq.New().Scan(row)
  25. }
  26. func (uq *UserQuery) GetByID(id string) *User {
  27. query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE dcid=$1`
  28. row := uq.db.QueryRow(query, id)
  29. if row == nil {
  30. return nil
  31. }
  32. return uq.New().Scan(row)
  33. }
  34. func (uq *UserQuery) GetAll() []*User {
  35. rows, err := uq.db.Query(`SELECT mxid, dcid, management_room, token FROM "user" WHERE token IS NOT NULL`)
  36. if err != nil || rows == nil {
  37. return nil
  38. }
  39. defer rows.Close()
  40. users := []*User{}
  41. for rows.Next() {
  42. users = append(users, uq.New().Scan(rows))
  43. }
  44. return users
  45. }
  46. type User struct {
  47. db *Database
  48. log log.Logger
  49. MXID id.UserID
  50. ID string
  51. ManagementRoom id.RoomID
  52. Token string
  53. }
  54. func (u *User) Scan(row dbutil.Scannable) *User {
  55. var token sql.NullString
  56. var discordID sql.NullString
  57. err := row.Scan(&u.MXID, &discordID, &u.ManagementRoom, &token)
  58. if err != nil {
  59. if err != sql.ErrNoRows {
  60. u.log.Errorln("Database scan failed:", err)
  61. }
  62. return nil
  63. }
  64. if token.Valid {
  65. u.Token = token.String
  66. }
  67. if discordID.Valid {
  68. u.ID = discordID.String
  69. }
  70. return u
  71. }
  72. func (u *User) Insert() {
  73. query := "INSERT INTO \"user\" (mxid, dcid, management_room, token) VALUES ($1, $2, $3, $4)"
  74. var token sql.NullString
  75. var discordID sql.NullString
  76. if u.Token != "" {
  77. token.String = u.Token
  78. token.Valid = true
  79. }
  80. if u.ID != "" {
  81. discordID.String = u.ID
  82. discordID.Valid = true
  83. }
  84. _, err := u.db.Exec(query, u.MXID, discordID, u.ManagementRoom, token)
  85. if err != nil {
  86. u.log.Warnfln("Failed to insert %s: %v", u.MXID, err)
  87. }
  88. }
  89. func (u *User) Update() {
  90. query := "UPDATE \"user\" SET dcid=$1, management_room=$2, token=$3 WHERE mxid=$4"
  91. var token sql.NullString
  92. var discordID sql.NullString
  93. if u.Token != "" {
  94. token.String = u.Token
  95. token.Valid = true
  96. }
  97. if u.ID != "" {
  98. discordID.String = u.ID
  99. discordID.Valid = true
  100. }
  101. _, err := u.db.Exec(query, discordID, u.ManagementRoom, token, u.MXID)
  102. if err != nil {
  103. u.log.Warnfln("Failed to update %q: %v", u.MXID, err)
  104. }
  105. }