cryptostore.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2020 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. // +build cgo
  17. package database
  18. import (
  19. "database/sql"
  20. "fmt"
  21. "strings"
  22. "sync"
  23. "github.com/lib/pq"
  24. "github.com/pkg/errors"
  25. log "maunium.net/go/maulogger/v2"
  26. "maunium.net/go/mautrix/crypto"
  27. "maunium.net/go/mautrix/crypto/olm"
  28. "maunium.net/go/mautrix/id"
  29. )
  30. type SQLCryptoStore struct {
  31. db *Database
  32. log log.Logger
  33. UserID id.UserID
  34. DeviceID id.DeviceID
  35. SyncToken string
  36. PickleKey []byte
  37. Account *crypto.OlmAccount
  38. GhostIDFormat string
  39. OGSLock sync.RWMutex
  40. OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession
  41. }
  42. var _ crypto.Store = (*SQLCryptoStore)(nil)
  43. func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
  44. return &SQLCryptoStore{
  45. db: db,
  46. log: db.log.Sub("CryptoStore"),
  47. PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
  48. DeviceID: deviceID,
  49. OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession),
  50. }
  51. }
  52. func (db *Database) FindDeviceID() (deviceID id.DeviceID) {
  53. err := db.QueryRow("SELECT device_id FROM crypto_account LIMIT 1").Scan(&deviceID)
  54. if err != nil && err != sql.ErrNoRows {
  55. db.log.Warnln("Failed to scan device ID:", err)
  56. }
  57. return
  58. }
  59. func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.UserID, err error) {
  60. var rows *sql.Rows
  61. rows, err = store.db.Query(`
  62. SELECT user_id FROM mx_user_profile
  63. WHERE room_id=$1
  64. AND (membership='join' OR membership='invite')
  65. AND user_id<>$2
  66. AND user_id NOT LIKE $3
  67. `, roomID, store.UserID, store.GhostIDFormat)
  68. if err != nil {
  69. return
  70. }
  71. for rows.Next() {
  72. var userID id.UserID
  73. err := rows.Scan(&userID)
  74. if err != nil {
  75. store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
  76. } else {
  77. members = append(members, userID)
  78. }
  79. }
  80. return
  81. }
  82. func (store *SQLCryptoStore) Flush() error {
  83. return nil
  84. }
  85. func (store *SQLCryptoStore) PutNextBatch(nextBatch string) {
  86. store.SyncToken = nextBatch
  87. _, err := store.db.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE device_id=$2`, store.SyncToken, store.DeviceID)
  88. if err != nil {
  89. store.log.Warnln("Failed to store sync token:", err)
  90. }
  91. }
  92. func (store *SQLCryptoStore) GetNextBatch() string {
  93. if store.SyncToken == "" {
  94. err := store.db.
  95. QueryRow("SELECT sync_token FROM crypto_account WHERE device_id=$1", store.DeviceID).
  96. Scan(&store.SyncToken)
  97. if err != nil && err != sql.ErrNoRows {
  98. store.log.Warnln("Failed to scan sync token:", err)
  99. }
  100. }
  101. return store.SyncToken
  102. }
  103. func (store *SQLCryptoStore) PutAccount(account *crypto.OlmAccount) error {
  104. store.Account = account
  105. bytes := account.Internal.Pickle(store.PickleKey)
  106. var err error
  107. if store.db.dialect == "postgres" {
  108. _, err = store.db.Exec(`
  109. INSERT INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)
  110. ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`,
  111. store.DeviceID, account.Shared, store.SyncToken, bytes)
  112. } else if store.db.dialect == "sqlite3" {
  113. _, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)",
  114. store.DeviceID, account.Shared, store.SyncToken, bytes)
  115. } else {
  116. err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
  117. }
  118. if err != nil {
  119. store.log.Warnln("Failed to store account:", err)
  120. }
  121. return nil
  122. }
  123. func (store *SQLCryptoStore) GetAccount() (*crypto.OlmAccount, error) {
  124. if store.Account == nil {
  125. row := store.db.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE device_id=$1", store.DeviceID)
  126. acc := &crypto.OlmAccount{Internal: *olm.NewBlankAccount()}
  127. var accountBytes []byte
  128. err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
  129. if err == sql.ErrNoRows {
  130. return nil, nil
  131. } else if err != nil {
  132. return nil, err
  133. }
  134. err = acc.Internal.Unpickle(accountBytes, store.PickleKey)
  135. if err != nil {
  136. return nil, err
  137. }
  138. store.Account = acc
  139. }
  140. return store.Account, nil
  141. }
  142. func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
  143. // TODO this may need to be changed if olm sessions start expiring
  144. var sessionID id.SessionID
  145. err := store.db.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 LIMIT 1", key).Scan(&sessionID)
  146. if err == sql.ErrNoRows {
  147. return false
  148. }
  149. return len(sessionID) > 0
  150. }
  151. func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (crypto.OlmSessionList, error) {
  152. rows, err := store.db.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id", key)
  153. if err != nil {
  154. return nil, err
  155. }
  156. list := crypto.OlmSessionList{}
  157. for rows.Next() {
  158. sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
  159. var sessionBytes []byte
  160. err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
  161. if err != nil {
  162. return nil, err
  163. }
  164. err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
  165. if err != nil {
  166. return nil, err
  167. }
  168. list = append(list, &sess)
  169. }
  170. return list, nil
  171. }
  172. func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*crypto.OlmSession, error) {
  173. row := store.db.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id DESC LIMIT 1", key)
  174. sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
  175. var sessionBytes []byte
  176. err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
  177. if err == sql.ErrNoRows {
  178. return nil, nil
  179. } else if err != nil {
  180. return nil, err
  181. }
  182. return &sess, sess.Internal.Unpickle(sessionBytes, store.PickleKey)
  183. }
  184. func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *crypto.OlmSession) error {
  185. sessionBytes := session.Internal.Pickle(store.PickleKey)
  186. _, err := store.db.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used) VALUES ($1, $2, $3, $4, $5)",
  187. session.ID(), key, sessionBytes, session.CreationTime, session.UseTime)
  188. return err
  189. }
  190. func (store *SQLCryptoStore) UpdateSession(key id.SenderKey, session *crypto.OlmSession) error {
  191. sessionBytes := session.Internal.Pickle(store.PickleKey)
  192. _, err := store.db.Exec("UPDATE crypto_olm_session SET session=$1, last_used=$2 WHERE session_id=$3",
  193. sessionBytes, session.UseTime, session.ID())
  194. return err
  195. }
  196. func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *crypto.InboundGroupSession) error {
  197. sessionBytes := session.Internal.Pickle(store.PickleKey)
  198. forwardingChains := strings.Join(session.ForwardingChains, ",")
  199. _, err := store.db.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, signing_key, room_id, session, forwarding_chains) VALUES ($1, $2, $3, $4, $5, $6)",
  200. sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains)
  201. return err
  202. }
  203. func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
  204. var signingKey id.Ed25519
  205. var sessionBytes []byte
  206. var forwardingChains string
  207. err := store.db.QueryRow(`
  208. SELECT signing_key, session, forwarding_chains
  209. FROM crypto_megolm_inbound_session
  210. WHERE room_id=$1 AND sender_key=$2 AND session_id=$3`,
  211. roomID, senderKey, sessionID,
  212. ).Scan(&signingKey, &sessionBytes, &forwardingChains)
  213. if err == sql.ErrNoRows {
  214. return nil, nil
  215. } else if err != nil {
  216. return nil, err
  217. }
  218. igs := olm.NewBlankInboundGroupSession()
  219. err = igs.Unpickle(sessionBytes, store.PickleKey)
  220. if err != nil {
  221. return nil, err
  222. }
  223. return &crypto.InboundGroupSession{
  224. Internal: *igs,
  225. SigningKey: signingKey,
  226. SenderKey: senderKey,
  227. RoomID: roomID,
  228. ForwardingChains: strings.Split(forwardingChains, ","),
  229. }, nil
  230. }
  231. func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error {
  232. store.OGSLock.Lock()
  233. store.OutGroupSessions[roomID] = session
  234. store.OGSLock.Unlock()
  235. return nil
  236. }
  237. func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
  238. store.OGSLock.RLock()
  239. defer store.OGSLock.RUnlock()
  240. return store.OutGroupSessions[roomID], nil
  241. }
  242. func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error {
  243. store.OGSLock.Lock()
  244. delete(store.OutGroupSessions, roomID)
  245. store.OGSLock.Unlock()
  246. return nil
  247. }
  248. func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
  249. var resultEventID id.EventID
  250. var resultTimestamp int64
  251. err := store.db.QueryRow(
  252. `SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3`,
  253. senderKey, sessionID, index,
  254. ).Scan(&resultEventID, &resultTimestamp)
  255. if err == sql.ErrNoRows {
  256. _, err := store.db.Exec(`INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5)`,
  257. senderKey, sessionID, index, eventID, timestamp)
  258. if err != nil {
  259. store.log.Warnln("Failed to store message index:", err)
  260. }
  261. return true
  262. } else if err != nil {
  263. store.log.Warnln("Failed to scan message index:", err)
  264. return true
  265. }
  266. if resultEventID != eventID || resultTimestamp != timestamp {
  267. return false
  268. }
  269. return true
  270. }
  271. func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*crypto.DeviceIdentity, error) {
  272. var ignore id.UserID
  273. err := store.db.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
  274. if err == sql.ErrNoRows {
  275. return nil, nil
  276. } else if err != nil {
  277. return nil, err
  278. }
  279. rows, err := store.db.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1", userID)
  280. if err != nil {
  281. return nil, err
  282. }
  283. data := make(map[id.DeviceID]*crypto.DeviceIdentity)
  284. for rows.Next() {
  285. var identity crypto.DeviceIdentity
  286. err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
  287. if err != nil {
  288. return nil, err
  289. }
  290. identity.UserID = userID
  291. data[identity.DeviceID] = &identity
  292. }
  293. return data, nil
  294. }
  295. func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*crypto.DeviceIdentity) error {
  296. tx, err := store.db.Begin()
  297. if err != nil {
  298. return err
  299. }
  300. if store.db.dialect == "postgres" {
  301. _, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
  302. } else if store.db.dialect == "sqlite3" {
  303. _, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_user (user_id) VALUES ($1)", userID)
  304. } else {
  305. err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
  306. }
  307. if err != nil {
  308. return errors.Wrap(err, "failed to add user to tracked users list")
  309. }
  310. _, err = tx.Exec("DELETE FROM crypto_device WHERE user_id=$1", userID)
  311. if err != nil {
  312. _ = tx.Rollback()
  313. return errors.Wrap(err, "failed to delete old devices")
  314. }
  315. if len(devices) == 0 {
  316. err = tx.Commit()
  317. if err != nil {
  318. return errors.Wrap(err, "failed to commit changes (no devices added)")
  319. }
  320. return nil
  321. }
  322. // TODO do this in batches to avoid too large db queries
  323. values := make([]interface{}, 1, len(devices)*6+1)
  324. values[0] = userID
  325. valueStrings := make([]string, 0, len(devices))
  326. i := 2
  327. for deviceID, identity := range devices {
  328. values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
  329. valueStrings = append(valueStrings, fmt.Sprintf("($1, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5))
  330. i += 6
  331. }
  332. valueString := strings.Join(valueStrings, ",")
  333. _, err = tx.Exec("INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES "+valueString, values...)
  334. if err != nil {
  335. _ = tx.Rollback()
  336. return errors.Wrap(err, "failed to insert new devices")
  337. }
  338. err = tx.Commit()
  339. if err != nil {
  340. return errors.Wrap(err, "failed to commit changes")
  341. }
  342. return nil
  343. }
  344. func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
  345. var rows *sql.Rows
  346. var err error
  347. if store.db.dialect == "postgres" {
  348. rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users))
  349. } else {
  350. queryString := make([]string, len(users))
  351. params := make([]interface{}, len(users))
  352. for i, user := range users {
  353. queryString[i] = fmt.Sprintf("$%d", i+1)
  354. params[i] = user
  355. }
  356. rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN (" + strings.Join(queryString, ",") + ")", params...)
  357. }
  358. if err != nil {
  359. store.log.Warnln("Failed to filter tracked users:", err)
  360. return users
  361. }
  362. var ptr int
  363. for rows.Next() {
  364. err = rows.Scan(&users[ptr])
  365. if err != nil {
  366. store.log.Warnln("Failed to tracked user ID:", err)
  367. } else {
  368. ptr++
  369. }
  370. }
  371. return users[:ptr]
  372. }