portal.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2018 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. "strings"
  20. log "maunium.net/go/maulogger/v2"
  21. "maunium.net/go/mautrix-whatsapp/types"
  22. )
  23. type PortalKey struct {
  24. JID types.WhatsAppID
  25. Receiver types.WhatsAppID
  26. }
  27. func GroupPortalKey(jid types.WhatsAppID) PortalKey {
  28. return PortalKey{
  29. JID: jid,
  30. Receiver: jid,
  31. }
  32. }
  33. func NewPortalKey(jid, receiver types.WhatsAppID) PortalKey {
  34. if strings.HasSuffix(jid, "@g.us") {
  35. receiver = jid
  36. }
  37. return PortalKey{
  38. JID: jid,
  39. Receiver: receiver,
  40. }
  41. }
  42. func (key PortalKey) String() string {
  43. if key.Receiver == key.JID {
  44. return key.JID
  45. }
  46. return key.JID + "-" + key.Receiver
  47. }
  48. type PortalQuery struct {
  49. db *Database
  50. log log.Logger
  51. }
  52. func (pq *PortalQuery) CreateTable() error {
  53. _, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
  54. jid VARCHAR(25),
  55. receiver VARCHAR(25),
  56. mxid VARCHAR(255) UNIQUE,
  57. name VARCHAR(255) NOT NULL,
  58. topic VARCHAR(255) NOT NULL,
  59. avatar VARCHAR(255) NOT NULL,
  60. PRIMARY KEY (jid, receiver),
  61. FOREIGN KEY (receiver) REFERENCES user(mxid)
  62. )`)
  63. return err
  64. }
  65. func (pq *PortalQuery) New() *Portal {
  66. return &Portal{
  67. db: pq.db,
  68. log: pq.log,
  69. }
  70. }
  71. func (pq *PortalQuery) GetAll() (portals []*Portal) {
  72. rows, err := pq.db.Query("SELECT * FROM portal")
  73. if err != nil || rows == nil {
  74. return nil
  75. }
  76. defer rows.Close()
  77. for rows.Next() {
  78. portals = append(portals, pq.New().Scan(rows))
  79. }
  80. return
  81. }
  82. func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
  83. return pq.get("SELECT * FROM portal WHERE jid=? AND receiver=?", key.JID, key.Receiver)
  84. }
  85. func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
  86. return pq.get("SELECT * FROM portal WHERE mxid=?", mxid)
  87. }
  88. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  89. row := pq.db.QueryRow(query, args...)
  90. if row == nil {
  91. return nil
  92. }
  93. return pq.New().Scan(row)
  94. }
  95. type Portal struct {
  96. db *Database
  97. log log.Logger
  98. Key PortalKey
  99. MXID types.MatrixRoomID
  100. Name string
  101. Topic string
  102. Avatar string
  103. }
  104. func (portal *Portal) Scan(row Scannable) *Portal {
  105. var mxid sql.NullString
  106. err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar)
  107. if err != nil {
  108. if err != sql.ErrNoRows {
  109. portal.log.Errorln("Database scan failed:", err)
  110. }
  111. return nil
  112. }
  113. portal.MXID = mxid.String
  114. return portal
  115. }
  116. func (portal *Portal) mxidPtr() *string {
  117. if len(portal.MXID) > 0 {
  118. return &portal.MXID
  119. }
  120. return nil
  121. }
  122. func (portal *Portal) Insert() error {
  123. _, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)",
  124. portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar)
  125. if err != nil {
  126. portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
  127. }
  128. return err
  129. }
  130. func (portal *Portal) Update() error {
  131. var mxid *string
  132. if len(portal.MXID) > 0 {
  133. mxid = &portal.MXID
  134. }
  135. _, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND receiver=?",
  136. mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver)
  137. if err != nil {
  138. portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
  139. }
  140. return err
  141. }