file.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. package database
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "time"
  7. log "maunium.net/go/maulogger/v2"
  8. "maunium.net/go/mautrix/crypto/attachment"
  9. "maunium.net/go/mautrix/id"
  10. "maunium.net/go/mautrix/util/dbutil"
  11. )
  12. type FileQuery struct {
  13. db *Database
  14. log log.Logger
  15. }
// language=postgresql
const (
	// fileSelect reads discord_file columns in the exact order File.Scan
	// expects them; callers (e.g. FileQuery.Get) append their own WHERE clause.
	fileSelect = "SELECT url, encrypted, id, mxc, size, width, height, decryption_info, timestamp FROM discord_file"
	// fileInsert adds one discord_file row. Its nine placeholders follow the
	// same column order as fileSelect and are bound by File.Insert.
	fileInsert = `
INSERT INTO discord_file (url, encrypted, id, mxc, size, width, height, decryption_info, timestamp)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
)
  24. func (fq *FileQuery) New() *File {
  25. return &File{
  26. db: fq.db,
  27. log: fq.log,
  28. }
  29. }
  30. func (fq *FileQuery) Get(url string, encrypted bool) *File {
  31. query := fileSelect + " WHERE url=$1 AND encrypted=$2"
  32. return fq.New().Scan(fq.db.QueryRow(query, url, encrypted))
  33. }
  34. type File struct {
  35. db *Database
  36. log log.Logger
  37. URL string
  38. Encrypted bool
  39. ID string
  40. MXC id.ContentURI
  41. Size int
  42. Width int
  43. Height int
  44. DecryptionInfo *attachment.EncryptedFile
  45. Timestamp time.Time
  46. }
  47. func (f *File) Scan(row dbutil.Scannable) *File {
  48. var fileID sql.NullString
  49. var decryptionInfo []byte
  50. var width, height sql.NullInt32
  51. var timestamp int64
  52. var mxc string
  53. err := row.Scan(&f.URL, &f.Encrypted, &fileID, &mxc, &f.Size, &width, &height, &decryptionInfo, &timestamp)
  54. if err != nil {
  55. if !errors.Is(err, sql.ErrNoRows) {
  56. f.log.Errorln("Database scan failed:", err)
  57. panic(err)
  58. }
  59. return nil
  60. }
  61. f.ID = fileID.String
  62. f.Timestamp = time.UnixMilli(timestamp)
  63. f.Width = int(width.Int32)
  64. f.Height = int(height.Int32)
  65. f.MXC, err = id.ParseContentURI(mxc)
  66. if err != nil {
  67. f.log.Errorfln("Failed to parse content URI %s: %v", mxc, err)
  68. panic(err)
  69. }
  70. if decryptionInfo != nil {
  71. err = json.Unmarshal(decryptionInfo, &f.DecryptionInfo)
  72. if err != nil {
  73. f.log.Errorfln("Failed to unmarshal decryption info of %v: %v", f.MXC, err)
  74. panic(err)
  75. }
  76. }
  77. return f
  78. }
  79. func positiveIntToNullInt32(val int) (ptr sql.NullInt32) {
  80. if val > 0 {
  81. ptr.Valid = true
  82. ptr.Int32 = int32(val)
  83. }
  84. return
  85. }
  86. func (f *File) Insert(txn dbutil.Execable) {
  87. if txn == nil {
  88. txn = f.db
  89. }
  90. var err error
  91. var decryptionInfo []byte
  92. if f.DecryptionInfo != nil {
  93. decryptionInfo, err = json.Marshal(f.DecryptionInfo)
  94. if err != nil {
  95. f.log.Warnfln("Failed to marshal decryption info of %v: %v", f.MXC, err)
  96. panic(err)
  97. }
  98. }
  99. _, err = txn.Exec(fileInsert,
  100. f.URL, f.Encrypted, strPtr(f.ID), f.MXC.String(), f.Size,
  101. positiveIntToNullInt32(f.Width), positiveIntToNullInt32(f.Height),
  102. decryptionInfo, f.Timestamp.UnixMilli(),
  103. )
  104. if err != nil {
  105. f.log.Warnfln("Failed to insert copied file %v: %v", f.MXC, err)
  106. panic(err)
  107. }
  108. }
  109. func (f *File) Delete() {
  110. _, err := f.db.Exec("DELETE FROM discord_file WHERE url=$1 AND encrypted=$2", f.URL, f.Encrypted)
  111. if err != nil {
  112. f.log.Warnfln("Failed to delete copied file %v: %v", f.MXC, err)
  113. panic(err)
  114. }
  115. }