123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- package database
import (
	"database/sql"
	"encoding/json"
	"errors"
	"math"
	"time"

	"go.mau.fi/util/dbutil"
	log "maunium.net/go/maulogger/v2"
	"maunium.net/go/mautrix/crypto/attachment"
	"maunium.net/go/mautrix/id"
)
// FileQuery runs lookups against the discord_file table and creates File
// values bound to the same database handle and logger.
type FileQuery struct {
	// db is the handle used for queries; it is copied into every File made by New.
	db *Database
	// log is copied into every File made by New and used there to report
	// scan/insert/delete failures.
	log log.Logger
}
// fileColumns is the single column list shared by fileSelect and fileInsert.
// File.Scan reads result columns in this order and File.Insert passes its
// arguments in this order, so keeping one definition prevents the SELECT and
// INSERT statements from drifting apart.
//
// language=postgresql
const (
	fileColumns = "url, encrypted, mxc, id, emoji_name, size, width, height, mime_type, decryption_info, timestamp"
	fileSelect  = "SELECT " + fileColumns + " FROM discord_file"
	fileInsert  = "INSERT INTO discord_file (" + fileColumns + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
)
- func (fq *FileQuery) New() *File {
- return &File{
- db: fq.db,
- log: fq.log,
- }
- }
- func (fq *FileQuery) Get(url string, encrypted bool) *File {
- query := fileSelect + " WHERE url=$1 AND encrypted=$2"
- return fq.New().Scan(fq.db.QueryRow(query, url, encrypted))
- }
- func (fq *FileQuery) GetEmojiByMXC(mxc id.ContentURI) *File {
- query := fileSelect + " WHERE mxc=$1 AND emoji_name<>'' LIMIT 1"
- return fq.New().Scan(fq.db.QueryRow(query, mxc.String()))
- }
// File is one row of the discord_file table.
type File struct {
	// db and log are copied from the FileQuery that created this value (see New).
	db  *Database
	log log.Logger

	// URL and Encrypted together identify the row: Get and Delete both match
	// on url=$1 AND encrypted=$2.
	URL       string
	Encrypted bool
	// MXC is stored in the mxc column as its string form.
	MXC id.ContentURI
	// ID is read from a nullable column; NULL loads as "".
	ID string
	// EmojiName is read from a nullable column; NULL loads as "". A non-empty
	// value is what GetEmojiByMXC filters on.
	EmojiName string
	Size      int
	// Width and Height are written as NULL unless positive and load as 0 when NULL.
	Width    int
	Height   int
	MimeType string
	// DecryptionInfo is stored as JSON; it stays nil when the column is NULL.
	DecryptionInfo *attachment.EncryptedFile
	// Timestamp is stored as Unix milliseconds and loaded in UTC.
	Timestamp time.Time
}
- func (f *File) Scan(row dbutil.Scannable) *File {
- var fileID, emojiName, decryptionInfo sql.NullString
- var width, height sql.NullInt32
- var timestamp int64
- var mxc string
- err := row.Scan(&f.URL, &f.Encrypted, &mxc, &fileID, &emojiName, &f.Size, &width, &height, &f.MimeType, &decryptionInfo, ×tamp)
- if err != nil {
- if !errors.Is(err, sql.ErrNoRows) {
- f.log.Errorln("Database scan failed:", err)
- panic(err)
- }
- return nil
- }
- f.ID = fileID.String
- f.EmojiName = emojiName.String
- f.Timestamp = time.UnixMilli(timestamp).UTC()
- f.Width = int(width.Int32)
- f.Height = int(height.Int32)
- f.MXC, err = id.ParseContentURI(mxc)
- if err != nil {
- f.log.Errorfln("Failed to parse content URI %s: %v", mxc, err)
- panic(err)
- }
- if decryptionInfo.Valid {
- err = json.Unmarshal([]byte(decryptionInfo.String), &f.DecryptionInfo)
- if err != nil {
- f.log.Errorfln("Failed to unmarshal decryption info of %v: %v", f.MXC, err)
- panic(err)
- }
- }
- return f
- }
// positiveIntToNullInt32 converts val into a nullable int32 for storage.
//
// Non-positive values are treated as "unknown" and become NULL (Valid=false).
// Values above math.MaxInt32 cannot be represented in an int32. They also
// become NULL, rather than silently wrapping around to a wrong, possibly
// negative number.
func positiveIntToNullInt32(val int) sql.NullInt32 {
	if val <= 0 || val > math.MaxInt32 {
		return sql.NullInt32{}
	}
	return sql.NullInt32{Int32: int32(val), Valid: true}
}
- func (f *File) Insert(txn dbutil.Execable) {
- if txn == nil {
- txn = f.db
- }
- var decryptionInfoStr sql.NullString
- if f.DecryptionInfo != nil {
- decryptionInfo, err := json.Marshal(f.DecryptionInfo)
- if err != nil {
- f.log.Warnfln("Failed to marshal decryption info of %v: %v", f.MXC, err)
- panic(err)
- }
- decryptionInfoStr.Valid = true
- decryptionInfoStr.String = string(decryptionInfo)
- }
- _, err := txn.Exec(fileInsert,
- f.URL, f.Encrypted, f.MXC.String(), strPtr(f.ID), strPtr(f.EmojiName), f.Size,
- positiveIntToNullInt32(f.Width), positiveIntToNullInt32(f.Height), f.MimeType,
- decryptionInfoStr, f.Timestamp.UnixMilli(),
- )
- if err != nil {
- f.log.Warnfln("Failed to insert copied file %v: %v", f.MXC, err)
- panic(err)
- }
- }
- func (f *File) Delete() {
- _, err := f.db.Exec("DELETE FROM discord_file WHERE url=$1 AND encrypted=$2", f.URL, f.Encrypted)
- if err != nil {
- f.log.Warnfln("Failed to delete copied file %v: %v", f.MXC, err)
- panic(err)
- }
- }
|