portal.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2021 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. "fmt"
  20. "time"
  21. log "maunium.net/go/maulogger/v2"
  22. "maunium.net/go/mautrix/id"
  23. "maunium.net/go/mautrix/util/dbutil"
  24. "go.mau.fi/whatsmeow/types"
  25. )
  26. type PortalKey struct {
  27. JID types.JID
  28. Receiver types.JID
  29. }
  30. func NewPortalKey(jid, receiver types.JID) PortalKey {
  31. if jid.Server == types.GroupServer {
  32. receiver = jid
  33. } else if jid.Server == types.LegacyUserServer {
  34. jid.Server = types.DefaultUserServer
  35. }
  36. return PortalKey{
  37. JID: jid.ToNonAD(),
  38. Receiver: receiver.ToNonAD(),
  39. }
  40. }
  41. func (key PortalKey) String() string {
  42. if key.Receiver == key.JID {
  43. return key.JID.String()
  44. }
  45. return key.JID.String() + "-" + key.Receiver.String()
  46. }
  47. type PortalQuery struct {
  48. db *Database
  49. log log.Logger
  50. }
  51. func (pq *PortalQuery) New() *Portal {
  52. return &Portal{
  53. db: pq.db,
  54. log: pq.log,
  55. }
  56. }
  // portalColumns lists the portal table columns in exactly the order that
  // (*Portal).Scan reads them; the two must be kept in sync.
  const portalColumns = "jid, receiver, mxid, name, name_set, topic, topic_set, " +
  	"avatar, avatar_url, avatar_set, encrypted, last_sync, is_parent, parent_group, " +
  	"in_space, first_event_id, next_batch_id, relay_user_id, expiration_time"
  58. func (pq *PortalQuery) GetAll() []*Portal {
  59. return pq.getAll(fmt.Sprintf("SELECT %s FROM portal", portalColumns))
  60. }
  61. func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
  62. return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1 AND receiver=$2", portalColumns), key.JID, key.Receiver)
  63. }
  64. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  65. return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE mxid=$1", portalColumns), mxid)
  66. }
  67. func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal {
  68. return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1", portalColumns), jid.ToNonAD())
  69. }
  70. func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal {
  71. return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE receiver=$1 AND jid LIKE '%%@s.whatsapp.net'", portalColumns), receiver.ToNonAD())
  72. }
  73. func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) {
  74. receiver = receiver.ToNonAD()
  75. rows, err := pq.db.Query(`
  76. SELECT jid FROM portal
  77. LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver
  78. WHERE mxid<>'' AND receiver=$1 AND (in_space=false OR in_space IS NULL)
  79. `, receiver)
  80. if err != nil || rows == nil {
  81. return
  82. }
  83. for rows.Next() {
  84. var key PortalKey
  85. key.Receiver = receiver
  86. err = rows.Scan(&key.JID)
  87. if err == nil {
  88. keys = append(keys, key)
  89. }
  90. }
  91. return
  92. }
  93. func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
  94. rows, err := pq.db.Query(query, args...)
  95. if err != nil || rows == nil {
  96. return nil
  97. }
  98. defer rows.Close()
  99. for rows.Next() {
  100. portals = append(portals, pq.New().Scan(rows))
  101. }
  102. return
  103. }
  104. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  105. row := pq.db.QueryRow(query, args...)
  106. if row == nil {
  107. return nil
  108. }
  109. return pq.New().Scan(row)
  110. }
  111. type Portal struct {
  112. db *Database
  113. log log.Logger
  114. Key PortalKey
  115. MXID id.RoomID
  116. Name string
  117. NameSet bool
  118. Topic string
  119. TopicSet bool
  120. Avatar string
  121. AvatarURL id.ContentURI
  122. AvatarSet bool
  123. Encrypted bool
  124. LastSync time.Time
  125. IsParent bool
  126. ParentGroup types.JID
  127. InSpace bool
  128. FirstEventID id.EventID
  129. NextBatchID id.BatchID
  130. RelayUserID id.UserID
  131. ExpirationTime uint32
  132. }
  133. func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
  134. var mxid, avatarURL, firstEventID, nextBatchID, relayUserID, parentGroupJID sql.NullString
  135. var lastSyncTs int64
  136. err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet, &portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted, &lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
  137. if err != nil {
  138. if err != sql.ErrNoRows {
  139. portal.log.Errorln("Database scan failed:", err)
  140. }
  141. return nil
  142. }
  143. if lastSyncTs > 0 {
  144. portal.LastSync = time.Unix(lastSyncTs, 0)
  145. }
  146. portal.MXID = id.RoomID(mxid.String)
  147. portal.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
  148. if parentGroupJID.Valid {
  149. portal.ParentGroup, _ = types.ParseJID(parentGroupJID.String)
  150. }
  151. portal.FirstEventID = id.EventID(firstEventID.String)
  152. portal.NextBatchID = id.BatchID(nextBatchID.String)
  153. portal.RelayUserID = id.UserID(relayUserID.String)
  154. return portal
  155. }
  156. func (portal *Portal) mxidPtr() *id.RoomID {
  157. if len(portal.MXID) > 0 {
  158. return &portal.MXID
  159. }
  160. return nil
  161. }
  162. func (portal *Portal) relayUserPtr() *id.UserID {
  163. if len(portal.RelayUserID) > 0 {
  164. return &portal.RelayUserID
  165. }
  166. return nil
  167. }
  168. func (portal *Portal) parentGroupPtr() *string {
  169. if !portal.ParentGroup.IsEmpty() {
  170. val := portal.ParentGroup.String()
  171. return &val
  172. }
  173. return nil
  174. }
  175. func (portal *Portal) lastSyncTs() int64 {
  176. if portal.LastSync.IsZero() {
  177. return 0
  178. }
  179. return portal.LastSync.Unix()
  180. }
  181. func (portal *Portal) Insert() {
  182. _, err := portal.db.Exec(`
  183. INSERT INTO portal (jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
  184. encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id,
  185. relay_user_id, expiration_time)
  186. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
  187. `,
  188. portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet,
  189. portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(),
  190. portal.IsParent, portal.parentGroupPtr(), portal.InSpace, portal.FirstEventID.String(), portal.NextBatchID.String(),
  191. portal.relayUserPtr(), portal.ExpirationTime)
  192. if err != nil {
  193. portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
  194. }
  195. }
  196. func (portal *Portal) Update(txn dbutil.Execable) {
  197. if txn == nil {
  198. txn = portal.db
  199. }
  200. _, err := txn.Exec(`
  201. UPDATE portal
  202. SET mxid=$1, name=$2, name_set=$3, topic=$4, topic_set=$5, avatar=$6, avatar_url=$7, avatar_set=$8,
  203. encrypted=$9, last_sync=$10, is_parent=$11, parent_group=$12, in_space=$13,
  204. first_event_id=$14, next_batch_id=$15, relay_user_id=$16, expiration_time=$17
  205. WHERE jid=$18 AND receiver=$19
  206. `, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(),
  207. portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(), portal.IsParent, portal.parentGroupPtr(), portal.InSpace,
  208. portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime,
  209. portal.Key.JID, portal.Key.Receiver)
  210. if err != nil {
  211. portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
  212. }
  213. }
  214. func (portal *Portal) Delete() {
  215. _, err := portal.db.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver)
  216. if err != nil {
  217. portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err)
  218. }
  219. }