file.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. package database
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "time"
  7. "go.mau.fi/util/dbutil"
  8. log "maunium.net/go/maulogger/v2"
  9. "maunium.net/go/mautrix/crypto/attachment"
  10. "maunium.net/go/mautrix/id"
  11. )
  12. type FileQuery struct {
  13. db *Database
  14. log log.Logger
  15. }
  // fileSelect lists every discord_file column in exactly the order File.Scan
// reads them; callers append their own WHERE clause. fileInsert writes one
// row using the same column order, bound to placeholders $1..$11.
//
// language=postgresql
const (
	fileSelect = "SELECT url, encrypted, mxc, id, emoji_name, size, width, height, mime_type, decryption_info, timestamp FROM discord_file"
	fileInsert = `
INSERT INTO discord_file (url, encrypted, mxc, id, emoji_name, size, width, height, mime_type, decryption_info, timestamp)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`
)
  24. func (fq *FileQuery) New() *File {
  25. return &File{
  26. db: fq.db,
  27. log: fq.log,
  28. }
  29. }
  30. func (fq *FileQuery) Get(url string, encrypted bool) *File {
  31. query := fileSelect + " WHERE url=$1 AND encrypted=$2"
  32. return fq.New().Scan(fq.db.QueryRow(query, url, encrypted))
  33. }
  34. func (fq *FileQuery) GetEmojiByMXC(mxc id.ContentURI) *File {
  35. query := fileSelect + " WHERE mxc=$1 AND emoji_name<>'' LIMIT 1"
  36. return fq.New().Scan(fq.db.QueryRow(query, mxc.String()))
  37. }
  38. type File struct {
  39. db *Database
  40. log log.Logger
  41. URL string
  42. Encrypted bool
  43. MXC id.ContentURI
  44. ID string
  45. EmojiName string
  46. Size int
  47. Width int
  48. Height int
  49. MimeType string
  50. DecryptionInfo *attachment.EncryptedFile
  51. Timestamp time.Time
  52. }
  53. func (f *File) Scan(row dbutil.Scannable) *File {
  54. var fileID, emojiName, decryptionInfo sql.NullString
  55. var width, height sql.NullInt32
  56. var timestamp int64
  57. var mxc string
  58. err := row.Scan(&f.URL, &f.Encrypted, &mxc, &fileID, &emojiName, &f.Size, &width, &height, &f.MimeType, &decryptionInfo, &timestamp)
  59. if err != nil {
  60. if !errors.Is(err, sql.ErrNoRows) {
  61. f.log.Errorln("Database scan failed:", err)
  62. panic(err)
  63. }
  64. return nil
  65. }
  66. f.ID = fileID.String
  67. f.EmojiName = emojiName.String
  68. f.Timestamp = time.UnixMilli(timestamp).UTC()
  69. f.Width = int(width.Int32)
  70. f.Height = int(height.Int32)
  71. f.MXC, err = id.ParseContentURI(mxc)
  72. if err != nil {
  73. f.log.Errorfln("Failed to parse content URI %s: %v", mxc, err)
  74. panic(err)
  75. }
  76. if decryptionInfo.Valid {
  77. err = json.Unmarshal([]byte(decryptionInfo.String), &f.DecryptionInfo)
  78. if err != nil {
  79. f.log.Errorfln("Failed to unmarshal decryption info of %v: %v", f.MXC, err)
  80. panic(err)
  81. }
  82. }
  83. return f
  84. }
  // positiveIntToNullInt32 converts val for storage in a nullable INTEGER
// column. Non-positive values mean "unknown" and become NULL (Valid == false).
//
// Values above the int32 range also become NULL. A plain int32(val)
// conversion would silently wrap them, for example on 64-bit platforms, and
// store a wrong and possibly negative number.
func positiveIntToNullInt32(val int) sql.NullInt32 {
	const maxInt32 = 1<<31 - 1
	if val <= 0 || val > maxInt32 {
		return sql.NullInt32{}
	}
	return sql.NullInt32{Int32: int32(val), Valid: true}
}
  92. func (f *File) Insert(txn dbutil.Execable) {
  93. if txn == nil {
  94. txn = f.db
  95. }
  96. var decryptionInfoStr sql.NullString
  97. if f.DecryptionInfo != nil {
  98. decryptionInfo, err := json.Marshal(f.DecryptionInfo)
  99. if err != nil {
  100. f.log.Warnfln("Failed to marshal decryption info of %v: %v", f.MXC, err)
  101. panic(err)
  102. }
  103. decryptionInfoStr.Valid = true
  104. decryptionInfoStr.String = string(decryptionInfo)
  105. }
  106. _, err := txn.Exec(fileInsert,
  107. f.URL, f.Encrypted, f.MXC.String(), strPtr(f.ID), strPtr(f.EmojiName), f.Size,
  108. positiveIntToNullInt32(f.Width), positiveIntToNullInt32(f.Height), f.MimeType,
  109. decryptionInfoStr, f.Timestamp.UnixMilli(),
  110. )
  111. if err != nil {
  112. f.log.Warnfln("Failed to insert copied file %v: %v", f.MXC, err)
  113. panic(err)
  114. }
  115. }
  116. func (f *File) Delete() {
  117. _, err := f.db.Exec("DELETE FROM discord_file WHERE url=$1 AND encrypted=$2", f.URL, f.Encrypted)
  118. if err != nil {
  119. f.log.Warnfln("Failed to delete copied file %v: %v", f.MXC, err)
  120. panic(err)
  121. }
  122. }