portal.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2021 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. "fmt"
  20. "time"
  21. "go.mau.fi/util/dbutil"
  22. "go.mau.fi/whatsmeow/types"
  23. log "maunium.net/go/maulogger/v2"
  24. "maunium.net/go/mautrix/id"
  25. )
// PortalKey uniquely identifies a portal row (the portal table is keyed by
// jid and receiver, see GetByJID). For group chats NewPortalKey sets Receiver
// equal to JID; for other chats Receiver is the JID it was built with.
type PortalKey struct {
	// JID is the chat JID.
	JID types.JID
	// Receiver is the owning user's JID, or the group JID itself for groups.
	Receiver types.JID
}
  30. func NewPortalKey(jid, receiver types.JID) PortalKey {
  31. if jid.Server == types.GroupServer {
  32. receiver = jid
  33. } else if jid.Server == types.LegacyUserServer {
  34. jid.Server = types.DefaultUserServer
  35. }
  36. return PortalKey{
  37. JID: jid.ToNonAD(),
  38. Receiver: receiver.ToNonAD(),
  39. }
  40. }
  41. func (key PortalKey) String() string {
  42. if key.Receiver == key.JID {
  43. return key.JID.String()
  44. }
  45. return key.JID.String() + "-" + key.Receiver.String()
  46. }
// PortalQuery bundles the database handle and logger used to run queries
// against the portal table and to construct Portal values (see New).
type PortalQuery struct {
	db  *Database
	log log.Logger
}
  51. func (pq *PortalQuery) New() *Portal {
  52. return &Portal{
  53. db: pq.db,
  54. log: pq.log,
  55. }
  56. }
  57. const portalColumns = "jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id, relay_user_id, expiration_time"
  58. func (pq *PortalQuery) GetAll() []*Portal {
  59. return pq.getAll(fmt.Sprintf("SELECT %s FROM portal", portalColumns))
  60. }
  61. func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
  62. return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1 AND receiver=$2", portalColumns), key.JID, key.Receiver)
  63. }
  64. func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
  65. return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE mxid=$1", portalColumns), mxid)
  66. }
  67. func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal {
  68. return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1", portalColumns), jid.ToNonAD())
  69. }
  70. func (pq *PortalQuery) GetAllByParentGroup(jid types.JID) []*Portal {
  71. return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE parent_group=$1", portalColumns), jid)
  72. }
  73. func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal {
  74. return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE receiver=$1 AND jid LIKE '%%@s.whatsapp.net'", portalColumns), receiver.ToNonAD())
  75. }
  76. func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) {
  77. receiver = receiver.ToNonAD()
  78. rows, err := pq.db.Query(`
  79. SELECT jid FROM portal
  80. LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver
  81. WHERE mxid<>'' AND receiver=$1 AND (user_portal.in_space=false OR user_portal.in_space IS NULL)
  82. `, receiver)
  83. if err != nil {
  84. pq.log.Errorfln("Failed to find private chats not in space for %s: %v", receiver, err)
  85. return
  86. } else if rows == nil {
  87. return
  88. }
  89. for rows.Next() {
  90. var key PortalKey
  91. key.Receiver = receiver
  92. err = rows.Scan(&key.JID)
  93. if err == nil {
  94. keys = append(keys, key)
  95. }
  96. }
  97. return
  98. }
  99. func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
  100. rows, err := pq.db.Query(query, args...)
  101. if err != nil || rows == nil {
  102. return nil
  103. }
  104. defer rows.Close()
  105. for rows.Next() {
  106. portals = append(portals, pq.New().Scan(rows))
  107. }
  108. return
  109. }
  110. func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
  111. row := pq.db.QueryRow(query, args...)
  112. if row == nil {
  113. return nil
  114. }
  115. return pq.New().Scan(row)
  116. }
// Portal is one row of the portal table, together with the database handle
// and logger needed to persist it (Insert, Update, Delete).
type Portal struct {
	db  *Database
	log log.Logger

	// Key is the (jid, receiver) primary identity of the row.
	Key PortalKey
	// MXID is the Matrix room ID. It is empty when unset; mxidPtr writes
	// NULL in that case, and Scan maps NULL back to "".
	MXID id.RoomID

	Name string
	// NameSet presumably records whether Name was applied to the Matrix room — confirm.
	NameSet bool
	Topic   string
	// TopicSet presumably records whether Topic was applied to the Matrix room — confirm.
	TopicSet bool
	Avatar   string
	// AvatarURL is stored as its string form. Scan ignores parse errors.
	AvatarURL id.ContentURI
	// AvatarSet presumably records whether the avatar was applied to the Matrix room — confirm.
	AvatarSet bool
	Encrypted bool
	// LastSync is persisted as Unix seconds. A zero time is stored as 0,
	// and 0 is read back as the zero time.
	LastSync time.Time

	IsParent bool
	// ParentGroup is persisted as its string form, or NULL when empty.
	ParentGroup types.JID
	// InSpace is reset to false for children when their parent group is deleted.
	InSpace bool

	// FirstEventID and NextBatchID are stored as plain strings. They presumably
	// track backfill state (batch sending) — TODO confirm.
	FirstEventID id.EventID
	NextBatchID  id.BatchID
	// RelayUserID is the relay user's Matrix ID, stored as NULL when empty.
	RelayUserID id.UserID
	// ExpirationTime presumably holds the disappearing-message timer; units not shown here — confirm.
	ExpirationTime uint32
}
  139. func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
  140. var mxid, avatarURL, firstEventID, nextBatchID, relayUserID, parentGroupJID sql.NullString
  141. var lastSyncTs int64
  142. err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet, &portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted, &lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
  143. if err != nil {
  144. if err != sql.ErrNoRows {
  145. portal.log.Errorln("Database scan failed:", err)
  146. }
  147. return nil
  148. }
  149. if lastSyncTs > 0 {
  150. portal.LastSync = time.Unix(lastSyncTs, 0)
  151. }
  152. portal.MXID = id.RoomID(mxid.String)
  153. portal.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
  154. if parentGroupJID.Valid {
  155. portal.ParentGroup, _ = types.ParseJID(parentGroupJID.String)
  156. }
  157. portal.FirstEventID = id.EventID(firstEventID.String)
  158. portal.NextBatchID = id.BatchID(nextBatchID.String)
  159. portal.RelayUserID = id.UserID(relayUserID.String)
  160. return portal
  161. }
  162. func (portal *Portal) mxidPtr() *id.RoomID {
  163. if len(portal.MXID) > 0 {
  164. return &portal.MXID
  165. }
  166. return nil
  167. }
  168. func (portal *Portal) relayUserPtr() *id.UserID {
  169. if len(portal.RelayUserID) > 0 {
  170. return &portal.RelayUserID
  171. }
  172. return nil
  173. }
  174. func (portal *Portal) parentGroupPtr() *string {
  175. if !portal.ParentGroup.IsEmpty() {
  176. val := portal.ParentGroup.String()
  177. return &val
  178. }
  179. return nil
  180. }
  181. func (portal *Portal) lastSyncTs() int64 {
  182. if portal.LastSync.IsZero() {
  183. return 0
  184. }
  185. return portal.LastSync.Unix()
  186. }
  187. func (portal *Portal) Insert() {
  188. _, err := portal.db.Exec(`
  189. INSERT INTO portal (jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
  190. encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id,
  191. relay_user_id, expiration_time)
  192. VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
  193. `,
  194. portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet,
  195. portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(),
  196. portal.IsParent, portal.parentGroupPtr(), portal.InSpace, portal.FirstEventID.String(), portal.NextBatchID.String(),
  197. portal.relayUserPtr(), portal.ExpirationTime)
  198. if err != nil {
  199. portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
  200. }
  201. }
  202. func (portal *Portal) Update(txn dbutil.Execable) {
  203. if txn == nil {
  204. txn = portal.db
  205. }
  206. _, err := txn.Exec(`
  207. UPDATE portal
  208. SET mxid=$1, name=$2, name_set=$3, topic=$4, topic_set=$5, avatar=$6, avatar_url=$7, avatar_set=$8,
  209. encrypted=$9, last_sync=$10, is_parent=$11, parent_group=$12, in_space=$13,
  210. first_event_id=$14, next_batch_id=$15, relay_user_id=$16, expiration_time=$17
  211. WHERE jid=$18 AND receiver=$19
  212. `, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(),
  213. portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(), portal.IsParent, portal.parentGroupPtr(), portal.InSpace,
  214. portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime,
  215. portal.Key.JID, portal.Key.Receiver)
  216. if err != nil {
  217. portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
  218. }
  219. }
  220. func (portal *Portal) Delete() {
  221. txn, err := portal.db.Begin()
  222. if err != nil {
  223. portal.log.Errorfln("Failed to begin transaction to delete portal %v: %v", portal.Key, err)
  224. return
  225. }
  226. defer func() {
  227. if err != nil {
  228. err = txn.Rollback()
  229. if err != nil {
  230. portal.log.Warnfln("Failed to rollback failed portal delete transaction: %v", err)
  231. }
  232. } else if err = txn.Commit(); err != nil {
  233. portal.log.Warnfln("Failed to commit portal delete transaction: %v", err)
  234. }
  235. }()
  236. _, err = txn.Exec("UPDATE portal SET in_space=false WHERE parent_group=$1", portal.Key.JID)
  237. if err != nil {
  238. portal.log.Warnfln("Failed to mark child groups of %v as not in space: %v", portal.Key.JID, err)
  239. return
  240. }
  241. _, err = txn.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver)
  242. if err != nil {
  243. portal.log.Warnfln("Failed to delete %v: %v", portal.Key, err)
  244. }
  245. }