portal.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2021 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. log "maunium.net/go/maulogger/v2"
  20. "maunium.net/go/mautrix/id"
  21. "maunium.net/go/mautrix/util/dbutil"
  22. "go.mau.fi/whatsmeow/types"
  23. )
  24. type PortalKey struct {
  25. JID types.JID
  26. Receiver types.JID
  27. }
  28. func NewPortalKey(jid, receiver types.JID) PortalKey {
  29. if jid.Server == types.GroupServer {
  30. receiver = jid
  31. } else if jid.Server == types.LegacyUserServer {
  32. jid.Server = types.DefaultUserServer
  33. }
  34. return PortalKey{
  35. JID: jid.ToNonAD(),
  36. Receiver: receiver.ToNonAD(),
  37. }
  38. }
  39. func (key PortalKey) String() string {
  40. if key.Receiver == key.JID {
  41. return key.JID.String()
  42. }
  43. return key.JID.String() + "-" + key.Receiver.String()
  44. }
  45. type PortalQuery struct {
  46. db *Database
  47. log log.Logger
  48. }
  49. func (pq *PortalQuery) New() *Portal {
  50. return &Portal{
  51. db: pq.db,
  52. log: pq.log,
  53. }
  54. }
  55. func (pq *PortalQuery) GetAll() []*Portal {
  56. return pq.getAll("SELECT * FROM portal")
  57. }
  58. func (pq *PortalQuery) GetAllForUser(userID id.UserID) []*Portal {
  59. return pq.getAll(`
  60. SELECT p.* FROM portal p
  61. LEFT JOIN user_portal up ON p.jid=up.portal_jid AND p.receiver=up.portal_receiver
  62. WHERE mxid<>'' AND up.user_mxid=$1
  63. `, userID)
  64. }
  65. func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
  66. return pq.get("SELECT * FROM portal WHERE jid=$1 AND receiver=$2", key.JID, key.Receiver)
  67. }
  68. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  69. return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
  70. }
  71. func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal {
  72. return pq.getAll("SELECT * FROM portal WHERE jid=$1", jid.ToNonAD())
  73. }
  74. func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal {
  75. return pq.getAll("SELECT * FROM portal WHERE receiver=$1 AND jid LIKE '%@s.whatsapp.net'", receiver.ToNonAD())
  76. }
  77. func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) {
  78. receiver = receiver.ToNonAD()
  79. rows, err := pq.db.Query(`
  80. SELECT jid FROM portal
  81. LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver
  82. WHERE mxid<>'' AND receiver=$1 AND (in_space=false OR in_space IS NULL)
  83. `, receiver)
  84. if err != nil || rows == nil {
  85. return
  86. }
  87. for rows.Next() {
  88. var key PortalKey
  89. key.Receiver = receiver
  90. err = rows.Scan(&key.JID)
  91. if err == nil {
  92. keys = append(keys, key)
  93. }
  94. }
  95. return
  96. }
  97. func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
  98. rows, err := pq.db.Query(query, args...)
  99. if err != nil || rows == nil {
  100. return nil
  101. }
  102. defer rows.Close()
  103. for rows.Next() {
  104. portals = append(portals, pq.New().Scan(rows))
  105. }
  106. return
  107. }
  108. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  109. row := pq.db.QueryRow(query, args...)
  110. if row == nil {
  111. return nil
  112. }
  113. return pq.New().Scan(row)
  114. }
  115. type Portal struct {
  116. db *Database
  117. log log.Logger
  118. Key PortalKey
  119. MXID id.RoomID
  120. Name string
  121. Topic string
  122. Avatar string
  123. AvatarURL id.ContentURI
  124. Encrypted bool
  125. FirstEventID id.EventID
  126. NextBatchID id.BatchID
  127. RelayUserID id.UserID
  128. ExpirationTime uint32
  129. }
  130. func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
  131. var mxid, avatarURL, firstEventID, nextBatchID, relayUserID sql.NullString
  132. err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL, &portal.Encrypted, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
  133. if err != nil {
  134. if err != sql.ErrNoRows {
  135. portal.log.Errorln("Database scan failed:", err)
  136. }
  137. return nil
  138. }
  139. portal.MXID = id.RoomID(mxid.String)
  140. portal.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
  141. portal.FirstEventID = id.EventID(firstEventID.String)
  142. portal.NextBatchID = id.BatchID(nextBatchID.String)
  143. portal.RelayUserID = id.UserID(relayUserID.String)
  144. return portal
  145. }
  146. func (portal *Portal) mxidPtr() *id.RoomID {
  147. if len(portal.MXID) > 0 {
  148. return &portal.MXID
  149. }
  150. return nil
  151. }
  152. func (portal *Portal) relayUserPtr() *id.UserID {
  153. if len(portal.RelayUserID) > 0 {
  154. return &portal.RelayUserID
  155. }
  156. return nil
  157. }
  158. func (portal *Portal) Insert() {
  159. _, err := portal.db.Exec("INSERT INTO portal (jid, receiver, mxid, name, topic, avatar, avatar_url, encrypted, first_event_id, next_batch_id, relay_user_id, expiration_time) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)",
  160. portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime)
  161. if err != nil {
  162. portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
  163. }
  164. }
  165. func (portal *Portal) Update(txn *sql.Tx) {
  166. query := `
  167. UPDATE portal
  168. SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, encrypted=$6, first_event_id=$7, next_batch_id=$8, relay_user_id=$9, expiration_time=$10
  169. WHERE jid=$11 AND receiver=$12
  170. `
  171. args := []interface{}{
  172. portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime, portal.Key.JID, portal.Key.Receiver,
  173. }
  174. var err error
  175. if txn != nil {
  176. _, err = txn.Exec(query, args...)
  177. } else {
  178. _, err = portal.db.Exec(query, args...)
  179. }
  180. if err != nil {
  181. portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
  182. }
  183. }
  184. func (portal *Portal) Delete() {
  185. _, err := portal.db.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver)
  186. if err != nil {
  187. portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err)
  188. }
  189. }