puppet.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2021 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. log "maunium.net/go/maulogger/v2"
  20. "maunium.net/go/mautrix/id"
  21. "maunium.net/go/mautrix/util/dbutil"
  22. "go.mau.fi/whatsmeow/types"
  23. )
  24. type PuppetQuery struct {
  25. db *Database
  26. log log.Logger
  27. }
  28. func (pq *PuppetQuery) New() *Puppet {
  29. return &Puppet{
  30. db: pq.db,
  31. log: pq.log,
  32. EnablePresence: true,
  33. EnableReceipts: true,
  34. }
  35. }
  36. func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
  37. rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet")
  38. if err != nil || rows == nil {
  39. return nil
  40. }
  41. defer rows.Close()
  42. for rows.Next() {
  43. puppets = append(puppets, pq.New().Scan(rows))
  44. }
  45. return
  46. }
  47. func (pq *PuppetQuery) Get(jid types.JID) *Puppet {
  48. row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE username=$1", jid.User)
  49. if row == nil {
  50. return nil
  51. }
  52. return pq.New().Scan(row)
  53. }
  54. func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
  55. row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid=$1", mxid)
  56. if row == nil {
  57. return nil
  58. }
  59. return pq.New().Scan(row)
  60. }
  61. func (pq *PuppetQuery) GetAllWithCustomMXID() (puppets []*Puppet) {
  62. rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid<>''")
  63. if err != nil || rows == nil {
  64. return nil
  65. }
  66. defer rows.Close()
  67. for rows.Next() {
  68. puppets = append(puppets, pq.New().Scan(rows))
  69. }
  70. return
  71. }
  72. type Puppet struct {
  73. db *Database
  74. log log.Logger
  75. JID types.JID
  76. Avatar string
  77. AvatarURL id.ContentURI
  78. Displayname string
  79. NameQuality int8
  80. CustomMXID id.UserID
  81. AccessToken string
  82. NextBatch string
  83. EnablePresence bool
  84. EnableReceipts bool
  85. }
  86. func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet {
  87. var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
  88. var quality sql.NullInt64
  89. var enablePresence, enableReceipts sql.NullBool
  90. var username string
  91. err := row.Scan(&username, &avatar, &avatarURL, &displayname, &quality, &customMXID, &accessToken, &nextBatch, &enablePresence, &enableReceipts)
  92. if err != nil {
  93. if err != sql.ErrNoRows {
  94. puppet.log.Errorln("Database scan failed:", err)
  95. }
  96. return nil
  97. }
  98. puppet.JID = types.NewJID(username, types.DefaultUserServer)
  99. puppet.Displayname = displayname.String
  100. puppet.Avatar = avatar.String
  101. puppet.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
  102. puppet.NameQuality = int8(quality.Int64)
  103. puppet.CustomMXID = id.UserID(customMXID.String)
  104. puppet.AccessToken = accessToken.String
  105. puppet.NextBatch = nextBatch.String
  106. puppet.EnablePresence = enablePresence.Bool
  107. puppet.EnableReceipts = enableReceipts.Bool
  108. return puppet
  109. }
  110. func (puppet *Puppet) Insert() {
  111. if puppet.JID.Server != types.DefaultUserServer {
  112. puppet.log.Warnfln("Not inserting %s: not a user", puppet.JID)
  113. return
  114. }
  115. _, err := puppet.db.Exec("INSERT INTO puppet (username, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch, enable_presence, enable_receipts) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)",
  116. puppet.JID.User, puppet.Avatar, puppet.AvatarURL.String(), puppet.Displayname, puppet.NameQuality, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.EnablePresence, puppet.EnableReceipts)
  117. if err != nil {
  118. puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
  119. }
  120. }
  121. func (puppet *Puppet) Update() {
  122. _, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3, avatar_url=$4, custom_mxid=$5, access_token=$6, next_batch=$7, enable_presence=$8, enable_receipts=$9 WHERE username=$10",
  123. puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.AvatarURL.String(), puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.EnablePresence, puppet.EnableReceipts, puppet.JID.User)
  124. if err != nil {
  125. puppet.log.Warnfln("Failed to update %s: %v", puppet.JID, err)
  126. }
  127. }