cryptostore.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2020 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. // +build cgo
  17. package database
  18. import (
  19. "database/sql"
  20. "fmt"
  21. "strings"
  22. "github.com/lib/pq"
  23. "github.com/pkg/errors"
  24. log "maunium.net/go/maulogger/v2"
  25. "maunium.net/go/mautrix/crypto"
  26. "maunium.net/go/mautrix/crypto/olm"
  27. "maunium.net/go/mautrix/id"
  28. )
// SQLCryptoStore is a crypto.Store implementation backed by the bridge's SQL
// database. Dialect-specific statements are chosen via db.dialect
// ("postgres" or "sqlite3").
type SQLCryptoStore struct {
	db  *Database
	log log.Logger

	// UserID is excluded from GetRoomMembers results.
	UserID id.UserID
	// DeviceID selects the crypto_account row this store reads and writes.
	DeviceID id.DeviceID
	// SyncToken caches the sync_token column of this device's crypto_account row.
	SyncToken string
	// PickleKey is the key passed to Pickle/Unpickle for all stored Olm/Megolm state.
	PickleKey []byte
	// Account caches the Olm account once it has been stored or loaded.
	Account *crypto.OlmAccount
	// GhostIDFormat is used as a SQL LIKE pattern; matching user IDs are
	// excluded from GetRoomMembers results.
	GhostIDFormat string
}

// Compile-time check that *SQLCryptoStore implements crypto.Store.
var _ crypto.Store = (*SQLCryptoStore)(nil)
  40. func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
  41. return &SQLCryptoStore{
  42. db: db,
  43. log: db.log.Sub("CryptoStore"),
  44. PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
  45. DeviceID: deviceID,
  46. }
  47. }
  48. func (db *Database) FindDeviceID() (deviceID id.DeviceID) {
  49. err := db.QueryRow("SELECT device_id FROM crypto_account LIMIT 1").Scan(&deviceID)
  50. if err != nil && err != sql.ErrNoRows {
  51. db.log.Warnln("Failed to scan device ID:", err)
  52. }
  53. return
  54. }
  55. func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.UserID, err error) {
  56. var rows *sql.Rows
  57. rows, err = store.db.Query(`
  58. SELECT user_id FROM mx_user_profile
  59. WHERE room_id=$1
  60. AND (membership='join' OR membership='invite')
  61. AND user_id<>$2
  62. AND user_id NOT LIKE $3
  63. `, roomID, store.UserID, store.GhostIDFormat)
  64. if err != nil {
  65. return
  66. }
  67. for rows.Next() {
  68. var userID id.UserID
  69. err := rows.Scan(&userID)
  70. if err != nil {
  71. store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
  72. } else {
  73. members = append(members, userID)
  74. }
  75. }
  76. return
  77. }
  78. func (store *SQLCryptoStore) Flush() error {
  79. return nil
  80. }
  81. func (store *SQLCryptoStore) PutNextBatch(nextBatch string) {
  82. store.SyncToken = nextBatch
  83. _, err := store.db.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE device_id=$2`, store.SyncToken, store.DeviceID)
  84. if err != nil {
  85. store.log.Warnln("Failed to store sync token:", err)
  86. }
  87. }
  88. func (store *SQLCryptoStore) GetNextBatch() string {
  89. if store.SyncToken == "" {
  90. err := store.db.
  91. QueryRow("SELECT sync_token FROM crypto_account WHERE device_id=$1", store.DeviceID).
  92. Scan(&store.SyncToken)
  93. if err != nil && err != sql.ErrNoRows {
  94. store.log.Warnln("Failed to scan sync token:", err)
  95. }
  96. }
  97. return store.SyncToken
  98. }
  99. func (store *SQLCryptoStore) PutAccount(account *crypto.OlmAccount) error {
  100. store.Account = account
  101. bytes := account.Internal.Pickle(store.PickleKey)
  102. var err error
  103. if store.db.dialect == "postgres" {
  104. _, err = store.db.Exec(`
  105. INSERT INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)
  106. ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`,
  107. store.DeviceID, account.Shared, store.SyncToken, bytes)
  108. } else if store.db.dialect == "sqlite3" {
  109. _, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)",
  110. store.DeviceID, account.Shared, store.SyncToken, bytes)
  111. } else {
  112. err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
  113. }
  114. if err != nil {
  115. store.log.Warnln("Failed to store account:", err)
  116. }
  117. return nil
  118. }
  119. func (store *SQLCryptoStore) GetAccount() (*crypto.OlmAccount, error) {
  120. if store.Account == nil {
  121. row := store.db.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE device_id=$1", store.DeviceID)
  122. acc := &crypto.OlmAccount{Internal: *olm.NewBlankAccount()}
  123. var accountBytes []byte
  124. err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
  125. if err == sql.ErrNoRows {
  126. return nil, nil
  127. } else if err != nil {
  128. return nil, err
  129. }
  130. err = acc.Internal.Unpickle(accountBytes, store.PickleKey)
  131. if err != nil {
  132. return nil, err
  133. }
  134. store.Account = acc
  135. }
  136. return store.Account, nil
  137. }
  138. func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
  139. // TODO this may need to be changed if olm sessions start expiring
  140. var sessionID id.SessionID
  141. err := store.db.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 LIMIT 1", key).Scan(&sessionID)
  142. if err == sql.ErrNoRows {
  143. return false
  144. }
  145. return len(sessionID) > 0
  146. }
  147. func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (crypto.OlmSessionList, error) {
  148. rows, err := store.db.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id", key)
  149. if err != nil {
  150. return nil, err
  151. }
  152. list := crypto.OlmSessionList{}
  153. for rows.Next() {
  154. sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
  155. var sessionBytes []byte
  156. err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
  157. if err != nil {
  158. return nil, err
  159. }
  160. err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
  161. if err != nil {
  162. return nil, err
  163. }
  164. list = append(list, &sess)
  165. }
  166. return list, nil
  167. }
  168. func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*crypto.OlmSession, error) {
  169. row := store.db.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id DESC LIMIT 1", key)
  170. sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
  171. var sessionBytes []byte
  172. err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
  173. if err == sql.ErrNoRows {
  174. return nil, nil
  175. } else if err != nil {
  176. return nil, err
  177. }
  178. return &sess, sess.Internal.Unpickle(sessionBytes, store.PickleKey)
  179. }
  180. func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *crypto.OlmSession) error {
  181. sessionBytes := session.Internal.Pickle(store.PickleKey)
  182. _, err := store.db.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used) VALUES ($1, $2, $3, $4, $5)",
  183. session.ID(), key, sessionBytes, session.CreationTime, session.UseTime)
  184. return err
  185. }
  186. func (store *SQLCryptoStore) UpdateSession(key id.SenderKey, session *crypto.OlmSession) error {
  187. sessionBytes := session.Internal.Pickle(store.PickleKey)
  188. _, err := store.db.Exec("UPDATE crypto_olm_session SET session=$1, last_used=$2 WHERE session_id=$3",
  189. sessionBytes, session.UseTime, session.ID())
  190. return err
  191. }
  192. func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *crypto.InboundGroupSession) error {
  193. sessionBytes := session.Internal.Pickle(store.PickleKey)
  194. forwardingChains := strings.Join(session.ForwardingChains, ",")
  195. _, err := store.db.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, signing_key, room_id, session, forwarding_chains) VALUES ($1, $2, $3, $4, $5, $6)",
  196. sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains)
  197. return err
  198. }
  199. func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
  200. var signingKey id.Ed25519
  201. var sessionBytes []byte
  202. var forwardingChains string
  203. err := store.db.QueryRow(`
  204. SELECT signing_key, session, forwarding_chains
  205. FROM crypto_megolm_inbound_session
  206. WHERE room_id=$1 AND sender_key=$2 AND session_id=$3`,
  207. roomID, senderKey, sessionID,
  208. ).Scan(&signingKey, &sessionBytes, &forwardingChains)
  209. if err == sql.ErrNoRows {
  210. return nil, nil
  211. } else if err != nil {
  212. return nil, err
  213. }
  214. igs := olm.NewBlankInboundGroupSession()
  215. err = igs.Unpickle(sessionBytes, store.PickleKey)
  216. if err != nil {
  217. return nil, err
  218. }
  219. return &crypto.InboundGroupSession{
  220. Internal: *igs,
  221. SigningKey: signingKey,
  222. SenderKey: senderKey,
  223. RoomID: roomID,
  224. ForwardingChains: strings.Split(forwardingChains, ","),
  225. }, nil
  226. }
  227. func (store *SQLCryptoStore) AddOutboundGroupSession(session *crypto.OutboundGroupSession) (err error) {
  228. sessionBytes := session.Internal.Pickle(store.PickleKey)
  229. if store.db.dialect == "postgres" {
  230. _, err = store.db.Exec(`
  231. INSERT INTO crypto_megolm_outbound_session (
  232. room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used
  233. ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
  234. ON CONFLICT (room_id) DO UPDATE SET session_id=$2, session=$3, shared=$4, max_messages=$5, message_count=$6, max_age=$7, created_at=$8, last_used=$9`,
  235. session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime)
  236. } else if store.db.dialect == "sqlite3" {
  237. _, err = store.db.Exec(`
  238. INSERT OR REPLACE INTO crypto_megolm_outbound_session (
  239. room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used
  240. ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
  241. session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime)
  242. } else {
  243. err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
  244. }
  245. return
  246. }
  247. func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *crypto.OutboundGroupSession) error {
  248. sessionBytes := session.Internal.Pickle(store.PickleKey)
  249. _, err := store.db.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5",
  250. sessionBytes, session.MessageCount, session.UseTime, session.RoomID, session.ID())
  251. return err
  252. }
  253. func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
  254. var ogs crypto.OutboundGroupSession
  255. var sessionBytes []byte
  256. err := store.db.QueryRow(`
  257. SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
  258. FROM crypto_megolm_outbound_session WHERE room_id=$1`,
  259. roomID,
  260. ).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.UseTime)
  261. if err == sql.ErrNoRows {
  262. return nil, nil
  263. } else if err != nil {
  264. return nil, err
  265. }
  266. intOGS := olm.NewBlankOutboundGroupSession()
  267. err = intOGS.Unpickle(sessionBytes, store.PickleKey)
  268. if err != nil {
  269. return nil, err
  270. }
  271. ogs.Internal = *intOGS
  272. ogs.RoomID = roomID
  273. return &ogs, nil
  274. }
  275. func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
  276. _, err := store.db.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1", roomID)
  277. return err
  278. }
  279. func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
  280. var resultEventID id.EventID
  281. var resultTimestamp int64
  282. err := store.db.QueryRow(
  283. `SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3`,
  284. senderKey, sessionID, index,
  285. ).Scan(&resultEventID, &resultTimestamp)
  286. if err == sql.ErrNoRows {
  287. _, err := store.db.Exec(`INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5)`,
  288. senderKey, sessionID, index, eventID, timestamp)
  289. if err != nil {
  290. store.log.Warnln("Failed to store message index:", err)
  291. }
  292. return true
  293. } else if err != nil {
  294. store.log.Warnln("Failed to scan message index:", err)
  295. return true
  296. }
  297. if resultEventID != eventID || resultTimestamp != timestamp {
  298. return false
  299. }
  300. return true
  301. }
  302. func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*crypto.DeviceIdentity, error) {
  303. var ignore id.UserID
  304. err := store.db.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
  305. if err == sql.ErrNoRows {
  306. return nil, nil
  307. } else if err != nil {
  308. return nil, err
  309. }
  310. rows, err := store.db.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1", userID)
  311. if err != nil {
  312. return nil, err
  313. }
  314. data := make(map[id.DeviceID]*crypto.DeviceIdentity)
  315. for rows.Next() {
  316. var identity crypto.DeviceIdentity
  317. err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
  318. if err != nil {
  319. return nil, err
  320. }
  321. identity.UserID = userID
  322. data[identity.DeviceID] = &identity
  323. }
  324. return data, nil
  325. }
  326. func (store *SQLCryptoStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*crypto.DeviceIdentity, error) {
  327. var identity crypto.DeviceIdentity
  328. err := store.db.QueryRow(`
  329. SELECT identity_key, signing_key, trust, deleted, name
  330. FROM crypto_device WHERE user_id=$1 AND device_id=$2`,
  331. userID, deviceID,
  332. ).Scan(&identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
  333. if err != nil {
  334. if err == sql.ErrNoRows {
  335. return nil, nil
  336. }
  337. return nil, err
  338. }
  339. return &identity, nil
  340. }
  341. func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*crypto.DeviceIdentity) error {
  342. tx, err := store.db.Begin()
  343. if err != nil {
  344. return err
  345. }
  346. if store.db.dialect == "postgres" {
  347. _, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
  348. } else if store.db.dialect == "sqlite3" {
  349. _, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_user (user_id) VALUES ($1)", userID)
  350. } else {
  351. err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
  352. }
  353. if err != nil {
  354. return errors.Wrap(err, "failed to add user to tracked users list")
  355. }
  356. _, err = tx.Exec("DELETE FROM crypto_device WHERE user_id=$1", userID)
  357. if err != nil {
  358. _ = tx.Rollback()
  359. return errors.Wrap(err, "failed to delete old devices")
  360. }
  361. if len(devices) == 0 {
  362. err = tx.Commit()
  363. if err != nil {
  364. return errors.Wrap(err, "failed to commit changes (no devices added)")
  365. }
  366. return nil
  367. }
  368. // TODO do this in batches to avoid too large db queries
  369. values := make([]interface{}, 1, len(devices)*6+1)
  370. values[0] = userID
  371. valueStrings := make([]string, 0, len(devices))
  372. i := 2
  373. for deviceID, identity := range devices {
  374. values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
  375. valueStrings = append(valueStrings, fmt.Sprintf("($1, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5))
  376. i += 6
  377. }
  378. valueString := strings.Join(valueStrings, ",")
  379. _, err = tx.Exec("INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES "+valueString, values...)
  380. if err != nil {
  381. _ = tx.Rollback()
  382. return errors.Wrap(err, "failed to insert new devices")
  383. }
  384. err = tx.Commit()
  385. if err != nil {
  386. return errors.Wrap(err, "failed to commit changes")
  387. }
  388. return nil
  389. }
  390. func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
  391. var rows *sql.Rows
  392. var err error
  393. if store.db.dialect == "postgres" {
  394. rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users))
  395. } else {
  396. queryString := make([]string, len(users))
  397. params := make([]interface{}, len(users))
  398. for i, user := range users {
  399. queryString[i] = fmt.Sprintf("$%d", i+1)
  400. params[i] = user
  401. }
  402. rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
  403. }
  404. if err != nil {
  405. store.log.Warnln("Failed to filter tracked users:", err)
  406. return users
  407. }
  408. var ptr int
  409. for rows.Next() {
  410. err = rows.Scan(&users[ptr])
  411. if err != nil {
  412. store.log.Warnln("Failed to tracked user ID:", err)
  413. } else {
  414. ptr++
  415. }
  416. }
  417. return users[:ptr]
  418. }