file.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. package database
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "time"
  7. log "maunium.net/go/maulogger/v2"
  8. "maunium.net/go/mautrix/crypto/attachment"
  9. "maunium.net/go/mautrix/id"
  10. "maunium.net/go/mautrix/util/dbutil"
  11. )
  12. type FileQuery struct {
  13. db *Database
  14. log log.Logger
  15. }
  // language=postgresql
  const (
  	// fileSelect reads every discord_file column, in exactly the order that
  	// File.Scan passes its scan targets.
  	fileSelect = "SELECT url, encrypted, id, mxc, size, width, height, mime_type, decryption_info, timestamp FROM discord_file"

  	// fileInsert writes one discord_file row; placeholders $1..$10 follow the
  	// same column order as fileSelect.
  	fileInsert = "\n" +
  		"INSERT INTO discord_file (url, encrypted, id, mxc, size, width, height, mime_type, decryption_info, timestamp)\n" +
  		"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)\n"
  )
  24. func (fq *FileQuery) New() *File {
  25. return &File{
  26. db: fq.db,
  27. log: fq.log,
  28. }
  29. }
  30. func (fq *FileQuery) Get(url string, encrypted bool) *File {
  31. query := fileSelect + " WHERE url=$1 AND encrypted=$2"
  32. return fq.New().Scan(fq.db.QueryRow(query, url, encrypted))
  33. }
  34. type File struct {
  35. db *Database
  36. log log.Logger
  37. URL string
  38. Encrypted bool
  39. ID string
  40. MXC id.ContentURI
  41. Size int
  42. Width int
  43. Height int
  44. MimeType string
  45. DecryptionInfo *attachment.EncryptedFile
  46. Timestamp time.Time
  47. }
  48. func (f *File) Scan(row dbutil.Scannable) *File {
  49. var fileID, decryptionInfo sql.NullString
  50. var width, height sql.NullInt32
  51. var timestamp int64
  52. var mxc string
  53. err := row.Scan(&f.URL, &f.Encrypted, &fileID, &mxc, &f.Size, &width, &height, &f.MimeType, &decryptionInfo, &timestamp)
  54. if err != nil {
  55. if !errors.Is(err, sql.ErrNoRows) {
  56. f.log.Errorln("Database scan failed:", err)
  57. panic(err)
  58. }
  59. return nil
  60. }
  61. f.ID = fileID.String
  62. f.Timestamp = time.UnixMilli(timestamp)
  63. f.Width = int(width.Int32)
  64. f.Height = int(height.Int32)
  65. f.MXC, err = id.ParseContentURI(mxc)
  66. if err != nil {
  67. f.log.Errorfln("Failed to parse content URI %s: %v", mxc, err)
  68. panic(err)
  69. }
  70. if decryptionInfo.Valid {
  71. err = json.Unmarshal([]byte(decryptionInfo.String), &f.DecryptionInfo)
  72. if err != nil {
  73. f.log.Errorfln("Failed to unmarshal decryption info of %v: %v", f.MXC, err)
  74. panic(err)
  75. }
  76. }
  77. return f
  78. }
  // positiveIntToNullInt32 converts val into a nullable 32-bit column value.
  //
  // Zero and negative values are treated as "unknown" and become NULL. Values
  // above the int32 maximum cannot be represented: a plain int32 conversion
  // would silently wrap them to an unrelated (possibly negative) number. They
  // therefore also become NULL rather than storing corrupted data.
  func positiveIntToNullInt32(val int) sql.NullInt32 {
  	const maxInt32 = 1<<31 - 1
  	if val <= 0 || val > maxInt32 {
  		return sql.NullInt32{}
  	}
  	return sql.NullInt32{Int32: int32(val), Valid: true}
  }
  86. func (f *File) Insert(txn dbutil.Execable) {
  87. if txn == nil {
  88. txn = f.db
  89. }
  90. var decryptionInfoStr sql.NullString
  91. if f.DecryptionInfo != nil {
  92. decryptionInfo, err := json.Marshal(f.DecryptionInfo)
  93. if err != nil {
  94. f.log.Warnfln("Failed to marshal decryption info of %v: %v", f.MXC, err)
  95. panic(err)
  96. }
  97. decryptionInfoStr.Valid = true
  98. decryptionInfoStr.String = string(decryptionInfo)
  99. }
  100. _, err := txn.Exec(fileInsert,
  101. f.URL, f.Encrypted, strPtr(f.ID), f.MXC.String(), f.Size,
  102. positiveIntToNullInt32(f.Width), positiveIntToNullInt32(f.Height), f.MimeType,
  103. decryptionInfoStr, f.Timestamp.UnixMilli(),
  104. )
  105. if err != nil {
  106. f.log.Warnfln("Failed to insert copied file %v: %v", f.MXC, err)
  107. panic(err)
  108. }
  109. }
  110. func (f *File) Delete() {
  111. _, err := f.db.Exec("DELETE FROM discord_file WHERE url=$1 AND encrypted=$2", f.URL, f.Encrypted)
  112. if err != nil {
  113. f.log.Warnfln("Failed to delete copied file %v: %v", f.MXC, err)
  114. panic(err)
  115. }
  116. }