cryptostore.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
  2. // Copyright (C) 2020 Tulir Asokan
  3. //
  4. // This program is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Affero General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // This program is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Affero General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Affero General Public License
  15. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. package database
  17. import (
  18. "database/sql"
  19. "fmt"
  20. "strings"
  21. "sync"
  22. "github.com/lib/pq"
  23. "github.com/pkg/errors"
  24. log "maunium.net/go/maulogger/v2"
  25. "maunium.net/go/mautrix/crypto"
  26. "maunium.net/go/mautrix/crypto/olm"
  27. "maunium.net/go/mautrix/id"
  28. )
  29. type SQLCryptoStore struct {
  30. db *Database
  31. log log.Logger
  32. UserID id.UserID
  33. DeviceID id.DeviceID
  34. SyncToken string
  35. PickleKey []byte
  36. Account *crypto.OlmAccount
  37. GhostIDFormat string
  38. OGSLock sync.RWMutex
  39. OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession
  40. }
  41. var _ crypto.Store = (*SQLCryptoStore)(nil)
  42. func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
  43. return &SQLCryptoStore{
  44. db: db,
  45. log: db.log.Sub("CryptoStore"),
  46. PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
  47. DeviceID: deviceID,
  48. OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession),
  49. }
  50. }
  51. func (db *Database) FindDeviceID() (deviceID id.DeviceID) {
  52. err := db.QueryRow("SELECT device_id FROM crypto_account LIMIT 1").Scan(&deviceID)
  53. if err != nil && err != sql.ErrNoRows {
  54. db.log.Warnln("Failed to scan device ID:", err)
  55. }
  56. return
  57. }
  58. func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.UserID, err error) {
  59. var rows *sql.Rows
  60. rows, err = store.db.Query(`
  61. SELECT user_id FROM mx_user_profile
  62. WHERE room_id=$1
  63. AND (membership='join' OR membership='invite')
  64. AND user_id<>$2
  65. AND user_id NOT LIKE $3
  66. `, roomID, store.UserID, store.GhostIDFormat)
  67. if err != nil {
  68. return
  69. }
  70. for rows.Next() {
  71. var userID id.UserID
  72. err := rows.Scan(&userID)
  73. if err != nil {
  74. store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
  75. } else {
  76. members = append(members, userID)
  77. }
  78. }
  79. return
  80. }
  81. func (store *SQLCryptoStore) Flush() error {
  82. return nil
  83. }
  84. func (store *SQLCryptoStore) PutNextBatch(nextBatch string) {
  85. store.SyncToken = nextBatch
  86. _, err := store.db.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE device_id=$2`, store.SyncToken, store.DeviceID)
  87. if err != nil {
  88. store.log.Warnln("Failed to store sync token:", err)
  89. }
  90. }
  91. func (store *SQLCryptoStore) GetNextBatch() string {
  92. if store.SyncToken == "" {
  93. err := store.db.
  94. QueryRow("SELECT sync_token FROM crypto_account WHERE device_id=$1", store.DeviceID).
  95. Scan(&store.SyncToken)
  96. if err != nil && err != sql.ErrNoRows {
  97. store.log.Warnln("Failed to scan sync token:", err)
  98. }
  99. }
  100. return store.SyncToken
  101. }
  102. func (store *SQLCryptoStore) PutAccount(account *crypto.OlmAccount) error {
  103. store.Account = account
  104. bytes := account.Internal.Pickle(store.PickleKey)
  105. var err error
  106. if store.db.dialect == "postgres" {
  107. _, err = store.db.Exec(`
  108. INSERT INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)
  109. ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`,
  110. store.DeviceID, account.Shared, store.SyncToken, bytes)
  111. } else if store.db.dialect == "sqlite3" {
  112. _, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (deivce_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)",
  113. store.DeviceID, account.Shared, store.SyncToken, bytes)
  114. } else {
  115. err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
  116. }
  117. if err != nil {
  118. store.log.Warnln("Failed to store account:", err)
  119. }
  120. return nil
  121. }
  122. func (store *SQLCryptoStore) GetAccount() (*crypto.OlmAccount, error) {
  123. if store.Account == nil {
  124. row := store.db.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE device_id=$1", store.DeviceID)
  125. acc := &crypto.OlmAccount{Internal: *olm.NewBlankAccount()}
  126. var accountBytes []byte
  127. err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
  128. if err == sql.ErrNoRows {
  129. return nil, nil
  130. } else if err != nil {
  131. return nil, err
  132. }
  133. err = acc.Internal.Unpickle(accountBytes, store.PickleKey)
  134. if err != nil {
  135. return nil, err
  136. }
  137. store.Account = acc
  138. }
  139. return store.Account, nil
  140. }
  141. func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
  142. // TODO this may need to be changed if olm sessions start expiring
  143. var sessionID id.SessionID
  144. err := store.db.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 LIMIT 1", key).Scan(&sessionID)
  145. if err == sql.ErrNoRows {
  146. return false
  147. }
  148. return len(sessionID) > 0
  149. }
  150. func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (crypto.OlmSessionList, error) {
  151. rows, err := store.db.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id", key)
  152. if err != nil {
  153. return nil, err
  154. }
  155. list := crypto.OlmSessionList{}
  156. for rows.Next() {
  157. sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
  158. var sessionBytes []byte
  159. err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
  160. if err != nil {
  161. return nil, err
  162. }
  163. err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
  164. if err != nil {
  165. return nil, err
  166. }
  167. list = append(list, &sess)
  168. }
  169. return list, nil
  170. }
  171. func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*crypto.OlmSession, error) {
  172. row := store.db.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id DESC LIMIT 1", key)
  173. sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
  174. var sessionBytes []byte
  175. err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
  176. if err == sql.ErrNoRows {
  177. return nil, nil
  178. } else if err != nil {
  179. return nil, err
  180. }
  181. return &sess, sess.Internal.Unpickle(sessionBytes, store.PickleKey)
  182. }
  183. func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *crypto.OlmSession) error {
  184. sessionBytes := session.Internal.Pickle(store.PickleKey)
  185. _, err := store.db.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used) VALUES ($1, $2, $3, $4, $5)",
  186. session.ID(), key, sessionBytes, session.CreationTime, session.UseTime)
  187. return err
  188. }
  189. func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *crypto.InboundGroupSession) error {
  190. sessionBytes := session.Internal.Pickle(store.PickleKey)
  191. forwardingChains := strings.Join(session.ForwardingChains, ",")
  192. _, err := store.db.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, signing_key, room_id, session, forwarding_chains) VALUES ($1, $2, $3, $4, $5, $6)",
  193. sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains)
  194. return err
  195. }
  196. func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
  197. var signingKey id.Ed25519
  198. var sessionBytes []byte
  199. var forwardingChains string
  200. err := store.db.QueryRow(`
  201. SELECT signing_key, session, forwarding_chains
  202. FROM crypto_megolm_inbound_session
  203. WHERE room_id=$1 AND sender_key=$2 AND session_id=$3`,
  204. roomID, senderKey, sessionID,
  205. ).Scan(&signingKey, &sessionBytes, &forwardingChains)
  206. if err == sql.ErrNoRows {
  207. return nil, nil
  208. } else if err != nil {
  209. return nil, err
  210. }
  211. igs := olm.NewBlankInboundGroupSession()
  212. err = igs.Unpickle(sessionBytes, store.PickleKey)
  213. if err != nil {
  214. return nil, err
  215. }
  216. return &crypto.InboundGroupSession{
  217. Internal: *igs,
  218. SigningKey: signingKey,
  219. SenderKey: senderKey,
  220. RoomID: roomID,
  221. ForwardingChains: strings.Split(forwardingChains, ","),
  222. }, nil
  223. }
  224. func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error {
  225. store.OGSLock.Lock()
  226. store.OutGroupSessions[roomID] = session
  227. store.OGSLock.Unlock()
  228. return nil
  229. }
  230. func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
  231. store.OGSLock.RLock()
  232. defer store.OGSLock.RUnlock()
  233. return store.OutGroupSessions[roomID], nil
  234. }
  235. func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error {
  236. store.OGSLock.Lock()
  237. delete(store.OutGroupSessions, roomID)
  238. store.OGSLock.Unlock()
  239. return nil
  240. }
  241. func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
  242. var resultEventID id.EventID
  243. var resultTimestamp int64
  244. err := store.db.QueryRow(
  245. "SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND index=$3",
  246. senderKey, sessionID, index,
  247. ).Scan(&resultEventID, &resultTimestamp)
  248. if err == sql.ErrNoRows {
  249. _, err := store.db.Exec("INSERT INTO crypto_message_index (sender_key, session_id, index, event_id, timestamp) VALUES ($1, $2, $3, $4, $5)",
  250. senderKey, sessionID, index, eventID, timestamp)
  251. if err != nil {
  252. store.log.Warnln("Failed to store message index:", err)
  253. }
  254. return true
  255. } else if err != nil {
  256. store.log.Warnln("Failed to scan message index:", err)
  257. return true
  258. }
  259. if resultEventID != eventID || resultTimestamp != timestamp {
  260. return false
  261. }
  262. return true
  263. }
  264. func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*crypto.DeviceIdentity, error) {
  265. var ignore id.UserID
  266. err := store.db.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
  267. if err == sql.ErrNoRows {
  268. return nil, nil
  269. } else if err != nil {
  270. return nil, err
  271. }
  272. rows, err := store.db.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1", userID)
  273. if err != nil {
  274. return nil, err
  275. }
  276. data := make(map[id.DeviceID]*crypto.DeviceIdentity)
  277. for rows.Next() {
  278. var identity crypto.DeviceIdentity
  279. err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
  280. if err != nil {
  281. return nil, err
  282. }
  283. identity.UserID = userID
  284. data[identity.DeviceID] = &identity
  285. }
  286. return data, nil
  287. }
  288. func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*crypto.DeviceIdentity) error {
  289. tx, err := store.db.Begin()
  290. if err != nil {
  291. return err
  292. }
  293. if store.db.dialect == "postgres" {
  294. _, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
  295. } else if store.db.dialect == "sqlite3" {
  296. _, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_users (user_id) VALUES ($1)", userID)
  297. } else {
  298. err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
  299. }
  300. if err != nil {
  301. return errors.Wrap(err, "failed to add user to tracked users list")
  302. }
  303. _, err = tx.Exec("DELETE FROM crypto_device WHERE user_id=$1", userID)
  304. if err != nil {
  305. _ = tx.Rollback()
  306. return errors.Wrap(err, "failed to delete old devices")
  307. }
  308. if len(devices) == 0 {
  309. err = tx.Commit()
  310. if err != nil {
  311. return errors.Wrap(err, "failed to commit changes (no devices added)")
  312. }
  313. return nil
  314. }
  315. // TODO do this in batches to avoid too large db queries
  316. values := make([]interface{}, 1, len(devices)*6+1)
  317. values[0] = userID
  318. valueStrings := make([]string, 0, len(devices))
  319. i := 2
  320. for deviceID, identity := range devices {
  321. values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
  322. valueStrings = append(valueStrings, fmt.Sprintf("($1, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5))
  323. i += 6
  324. }
  325. valueString := strings.Join(valueStrings, ",")
  326. _, err = tx.Exec("INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES "+valueString, values...)
  327. if err != nil {
  328. _ = tx.Rollback()
  329. return errors.Wrap(err, "failed to insert new devices")
  330. }
  331. err = tx.Commit()
  332. if err != nil {
  333. return errors.Wrap(err, "failed to commit changes")
  334. }
  335. return nil
  336. }
  337. func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
  338. var rows *sql.Rows
  339. var err error
  340. if store.db.dialect == "postgres" {
  341. rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users))
  342. } else {
  343. rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ($1)", users)
  344. }
  345. if err != nil {
  346. store.log.Warnln("Failed to filter tracked users:", err)
  347. return users
  348. }
  349. var ptr int
  350. for rows.Next() {
  351. err = rows.Scan(&users[ptr])
  352. if err != nil {
  353. store.log.Warnln("Failed to tracked user ID:", err)
  354. } else {
  355. ptr++
  356. }
  357. }
  358. return users[:ptr]
  359. }