portal.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2021 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. log "maunium.net/go/maulogger/v2"
  20. "maunium.net/go/mautrix/id"
  21. "go.mau.fi/whatsmeow/types"
  22. )
  23. type PortalKey struct {
  24. JID types.JID
  25. Receiver types.JID
  26. }
  27. func GroupPortalKey(jid types.JID) PortalKey {
  28. return NewPortalKey(jid, jid)
  29. }
  30. func NewPortalKey(jid, receiver types.JID) PortalKey {
  31. if jid.Server == types.GroupServer {
  32. receiver = jid
  33. } else if jid.Server == types.LegacyUserServer {
  34. jid.Server = types.DefaultUserServer
  35. }
  36. return PortalKey{
  37. JID: jid.ToNonAD(),
  38. Receiver: receiver.ToNonAD(),
  39. }
  40. }
  41. func (key PortalKey) String() string {
  42. if key.Receiver == key.JID {
  43. return key.JID.String()
  44. }
  45. return key.JID.String() + "-" + key.Receiver.String()
  46. }
  47. type PortalQuery struct {
  48. db *Database
  49. log log.Logger
  50. }
  51. func (pq *PortalQuery) New() *Portal {
  52. return &Portal{
  53. db: pq.db,
  54. log: pq.log,
  55. }
  56. }
  57. func (pq *PortalQuery) GetAll() []*Portal {
  58. return pq.getAll("SELECT * FROM portal")
  59. }
  60. func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
  61. return pq.get("SELECT * FROM portal WHERE jid=$1 AND receiver=$2", key.JID, key.Receiver)
  62. }
  63. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  64. return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
  65. }
  66. func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal {
  67. return pq.getAll("SELECT * FROM portal WHERE jid=$1", jid.ToNonAD())
  68. }
  69. func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal {
  70. return pq.getAll("SELECT * FROM portal WHERE receiver=$1 AND jid LIKE '%@s.whatsapp.net'", receiver.ToNonAD())
  71. }
  72. func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
  73. rows, err := pq.db.Query(query, args...)
  74. if err != nil || rows == nil {
  75. return nil
  76. }
  77. defer rows.Close()
  78. for rows.Next() {
  79. portals = append(portals, pq.New().Scan(rows))
  80. }
  81. return
  82. }
  83. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  84. row := pq.db.QueryRow(query, args...)
  85. if row == nil {
  86. return nil
  87. }
  88. return pq.New().Scan(row)
  89. }
  90. type Portal struct {
  91. db *Database
  92. log log.Logger
  93. Key PortalKey
  94. MXID id.RoomID
  95. Name string
  96. Topic string
  97. Avatar string
  98. AvatarURL id.ContentURI
  99. Encrypted bool
  100. FirstEventID id.EventID
  101. NextBatchID id.BatchID
  102. RelayUserID id.UserID
  103. }
  104. func (portal *Portal) Scan(row Scannable) *Portal {
  105. var mxid, avatarURL, firstEventID, nextBatchID, relayUserID sql.NullString
  106. err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL, &portal.Encrypted, &firstEventID, &nextBatchID, &relayUserID)
  107. if err != nil {
  108. if err != sql.ErrNoRows {
  109. portal.log.Errorln("Database scan failed:", err)
  110. }
  111. return nil
  112. }
  113. portal.MXID = id.RoomID(mxid.String)
  114. portal.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
  115. portal.FirstEventID = id.EventID(firstEventID.String)
  116. portal.NextBatchID = id.BatchID(nextBatchID.String)
  117. portal.RelayUserID = id.UserID(relayUserID.String)
  118. return portal
  119. }
  120. func (portal *Portal) mxidPtr() *id.RoomID {
  121. if len(portal.MXID) > 0 {
  122. return &portal.MXID
  123. }
  124. return nil
  125. }
  126. func (portal *Portal) relayUserPtr() *id.UserID {
  127. if len(portal.RelayUserID) > 0 {
  128. return &portal.RelayUserID
  129. }
  130. return nil
  131. }
  132. func (portal *Portal) Insert() {
  133. _, err := portal.db.Exec("INSERT INTO portal (jid, receiver, mxid, name, topic, avatar, avatar_url, encrypted, first_event_id, next_batch_id, relay_user_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)",
  134. portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr())
  135. if err != nil {
  136. portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
  137. }
  138. }
  139. func (portal *Portal) Update() {
  140. _, err := portal.db.Exec("UPDATE portal SET mxid=$3, name=$4, topic=$5, avatar=$6, avatar_url=$7, encrypted=$8, first_event_id=$9, next_batch_id=$10, relay_user_id=$11 WHERE jid=$1 AND receiver=$2",
  141. portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr())
  142. if err != nil {
  143. portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
  144. }
  145. }
  146. func (portal *Portal) Delete() {
  147. _, err := portal.db.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver)
  148. if err != nil {
  149. portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err)
  150. }
  151. }
  152. //func (portal *Portal) GetUserIDs() []id.UserID {
  153. // rows, err := portal.db.Query(`SELECT "user".mxid FROM "user", user_portal
  154. // WHERE "user".jid=user_portal.user_jid
  155. // AND user_portal.portal_jid=$1
  156. // AND user_portal.portal_receiver=$2`,
  157. // portal.Key.JID, portal.Key.Receiver)
  158. // if err != nil {
  159. // portal.log.Debugln("Failed to get portal user ids:", err)
  160. // return nil
  161. // }
  162. // var userIDs []id.UserID
  163. // for rows.Next() {
  164. // var userID id.UserID
  165. // err = rows.Scan(&userID)
  166. // if err != nil {
  167. // portal.log.Warnln("Failed to scan row:", err)
  168. // continue
  169. // }
  170. // userIDs = append(userIDs, userID)
  171. // }
  172. // return userIDs
  173. //}