portal.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2018 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. log "maunium.net/go/maulogger"
  19. "maunium.net/go/mautrix-whatsapp/types"
  20. "database/sql"
  21. )
// PortalQuery groups the operations on the portal table: creating the table,
// looking up portals, and constructing new Portal values that share this
// query helper's database handle and logger.
type PortalQuery struct {
	db  *Database
	log log.Logger
}
  26. func (pq *PortalQuery) CreateTable() error {
  27. _, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
  28. jid VARCHAR(255),
  29. owner VARCHAR(255),
  30. mxid VARCHAR(255) UNIQUE,
  31. name VARCHAR(255) NOT NULL,
  32. topic VARCHAR(255) NOT NULL,
  33. avatar VARCHAR(255) NOT NULL,
  34. PRIMARY KEY (jid, owner),
  35. FOREIGN KEY (owner) REFERENCES user(mxid)
  36. )`)
  37. return err
  38. }
  39. func (pq *PortalQuery) New() *Portal {
  40. return &Portal{
  41. db: pq.db,
  42. log: pq.log,
  43. }
  44. }
  45. func (pq *PortalQuery) GetAll(owner types.MatrixUserID) (portals []*Portal) {
  46. rows, err := pq.db.Query("SELECT * FROM portal WHERE owner=?", owner)
  47. if err != nil || rows == nil {
  48. return nil
  49. }
  50. defer rows.Close()
  51. for rows.Next() {
  52. portals = append(portals, pq.New().Scan(rows))
  53. }
  54. return
  55. }
  56. func (pq *PortalQuery) GetByJID(owner types.MatrixUserID, jid types.WhatsAppID) *Portal {
  57. return pq.get("SELECT * FROM portal WHERE jid=? AND owner=?", jid, owner)
  58. }
  59. func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
  60. return pq.get("SELECT * FROM portal WHERE mxid=?", mxid)
  61. }
  62. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  63. row := pq.db.QueryRow(query, args...)
  64. if row == nil {
  65. return nil
  66. }
  67. return pq.New().Scan(row)
  68. }
// Portal is one row of the portal table. It links a WhatsApp chat, as seen
// by one owning Matrix user, to a Matrix room.
type Portal struct {
	db  *Database
	log log.Logger

	// JID is the WhatsApp chat ID. Together with Owner it forms the primary key.
	JID types.WhatsAppID
	// MXID is the Matrix room ID. It is empty when no room is linked, and
	// Insert/Update store the empty value as SQL NULL.
	MXID types.MatrixRoomID
	// Owner is the Matrix user this portal belongs to; it references user(mxid).
	Owner types.MatrixUserID

	// Name, Topic and Avatar are stored as NOT NULL strings. They presumably
	// mirror the chat's metadata; the exact avatar format is not visible
	// here (TODO confirm).
	Name   string
	Topic  string
	Avatar string
}
  79. func (portal *Portal) Scan(row Scannable) *Portal {
  80. err := row.Scan(&portal.JID, &portal.Owner, &portal.MXID, &portal.Name, &portal.Topic, &portal.Avatar)
  81. if err != nil {
  82. if err != sql.ErrNoRows {
  83. portal.log.Errorln("Database scan failed:", err)
  84. }
  85. return nil
  86. }
  87. return portal
  88. }
  89. func (portal *Portal) Insert() error {
  90. var mxid *string
  91. if len(portal.MXID) > 0 {
  92. mxid = &portal.MXID
  93. }
  94. _, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)",
  95. portal.JID, portal.Owner, mxid, portal.Name, portal.Topic, portal.Avatar)
  96. if err != nil {
  97. portal.log.Warnfln("Failed to insert %s->%s: %v", portal.JID, portal.Owner, err)
  98. }
  99. return err
  100. }
  101. func (portal *Portal) Update() error {
  102. var mxid *string
  103. if len(portal.MXID) > 0 {
  104. mxid = &portal.MXID
  105. }
  106. _, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND owner=?",
  107. mxid, portal.Name, portal.Topic, portal.Avatar, portal.JID, portal.Owner)
  108. if err != nil {
  109. portal.log.Warnfln("Failed to update %s->%s: %v", portal.JID, portal.Owner, err)
  110. }
  111. return err
  112. }