portal.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2021 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. "fmt"
  20. log "maunium.net/go/maulogger/v2"
  21. "maunium.net/go/mautrix/id"
  22. "maunium.net/go/mautrix/util/dbutil"
  23. "go.mau.fi/whatsmeow/types"
  24. )
  25. type PortalKey struct {
  26. JID types.JID
  27. Receiver types.JID
  28. }
  29. func NewPortalKey(jid, receiver types.JID) PortalKey {
  30. if jid.Server == types.GroupServer {
  31. receiver = jid
  32. } else if jid.Server == types.LegacyUserServer {
  33. jid.Server = types.DefaultUserServer
  34. }
  35. return PortalKey{
  36. JID: jid.ToNonAD(),
  37. Receiver: receiver.ToNonAD(),
  38. }
  39. }
  40. func (key PortalKey) String() string {
  41. if key.Receiver == key.JID {
  42. return key.JID.String()
  43. }
  44. return key.JID.String() + "-" + key.Receiver.String()
  45. }
  46. type PortalQuery struct {
  47. db *Database
  48. log log.Logger
  49. }
  50. func (pq *PortalQuery) New() *Portal {
  51. return &Portal{
  52. db: pq.db,
  53. log: pq.log,
  54. }
  55. }
  // portalColumns is the SELECT column list for the portal table, in exactly the
// order Portal.Scan reads them.
const portalColumns = "jid, receiver, mxid, " +
	"name, name_set, topic, topic_set, " +
	"avatar, avatar_url, avatar_set, " +
	"encrypted, first_event_id, next_batch_id, relay_user_id, expiration_time"
  57. func (pq *PortalQuery) GetAll() []*Portal {
  58. return pq.getAll(fmt.Sprintf("SELECT %s FROM portal", portalColumns))
  59. }
  60. func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
  61. return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1 AND receiver=$2", portalColumns), key.JID, key.Receiver)
  62. }
  63. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  64. return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE mxid=$1", portalColumns), mxid)
  65. }
  66. func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal {
  67. return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1", portalColumns), jid.ToNonAD())
  68. }
  69. func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal {
  70. return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE receiver=$1 AND jid LIKE '%@s.whatsapp.net'", portalColumns), receiver.ToNonAD())
  71. }
  72. func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) {
  73. receiver = receiver.ToNonAD()
  74. rows, err := pq.db.Query(`
  75. SELECT jid FROM portal
  76. LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver
  77. WHERE mxid<>'' AND receiver=$1 AND (in_space=false OR in_space IS NULL)
  78. `, receiver)
  79. if err != nil || rows == nil {
  80. return
  81. }
  82. for rows.Next() {
  83. var key PortalKey
  84. key.Receiver = receiver
  85. err = rows.Scan(&key.JID)
  86. if err == nil {
  87. keys = append(keys, key)
  88. }
  89. }
  90. return
  91. }
  92. func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
  93. rows, err := pq.db.Query(query, args...)
  94. if err != nil || rows == nil {
  95. return nil
  96. }
  97. defer rows.Close()
  98. for rows.Next() {
  99. portals = append(portals, pq.New().Scan(rows))
  100. }
  101. return
  102. }
  103. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  104. row := pq.db.QueryRow(query, args...)
  105. if row == nil {
  106. return nil
  107. }
  108. return pq.New().Scan(row)
  109. }
  110. type Portal struct {
  111. db *Database
  112. log log.Logger
  113. Key PortalKey
  114. MXID id.RoomID
  115. Name string
  116. NameSet bool
  117. Topic string
  118. TopicSet bool
  119. Avatar string
  120. AvatarURL id.ContentURI
  121. AvatarSet bool
  122. Encrypted bool
  123. FirstEventID id.EventID
  124. NextBatchID id.BatchID
  125. RelayUserID id.UserID
  126. ExpirationTime uint32
  127. }
  128. func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
  129. var mxid, avatarURL, firstEventID, nextBatchID, relayUserID sql.NullString
  130. err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet, &portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
  131. if err != nil {
  132. if err != sql.ErrNoRows {
  133. portal.log.Errorln("Database scan failed:", err)
  134. }
  135. return nil
  136. }
  137. portal.MXID = id.RoomID(mxid.String)
  138. portal.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
  139. portal.FirstEventID = id.EventID(firstEventID.String)
  140. portal.NextBatchID = id.BatchID(nextBatchID.String)
  141. portal.RelayUserID = id.UserID(relayUserID.String)
  142. return portal
  143. }
  144. func (portal *Portal) mxidPtr() *id.RoomID {
  145. if len(portal.MXID) > 0 {
  146. return &portal.MXID
  147. }
  148. return nil
  149. }
  150. func (portal *Portal) relayUserPtr() *id.UserID {
  151. if len(portal.RelayUserID) > 0 {
  152. return &portal.RelayUserID
  153. }
  154. return nil
  155. }
  156. func (portal *Portal) Insert() {
  157. _, err := portal.db.Exec(`
  158. INSERT INTO portal (jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
  159. encrypted, first_event_id, next_batch_id, relay_user_id, expiration_time)
  160. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
  161. `,
  162. portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet,
  163. portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted, portal.FirstEventID.String(),
  164. portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime)
  165. if err != nil {
  166. portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
  167. }
  168. }
  169. func (portal *Portal) Update(txn *sql.Tx) {
  170. query := `
  171. UPDATE portal
  172. SET mxid=$1, name=$2, name_set=$3, topic=$4, topic_set=$5, avatar=$6, avatar_url=$7, avatar_set=$8,
  173. encrypted=$9, first_event_id=$10, next_batch_id=$11, relay_user_id=$12, expiration_time=$13
  174. WHERE jid=$14 AND receiver=$15
  175. `
  176. args := []interface{}{
  177. portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted,
  178. portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime,
  179. portal.Key.JID, portal.Key.Receiver,
  180. }
  181. var err error
  182. if txn != nil {
  183. _, err = txn.Exec(query, args...)
  184. } else {
  185. _, err = portal.db.Exec(query, args...)
  186. }
  187. if err != nil {
  188. portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
  189. }
  190. }
  191. func (portal *Portal) Delete() {
  192. _, err := portal.db.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver)
  193. if err != nil {
  194. portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err)
  195. }
  196. }