package web

import (
	"bytes"
	"crypto/sha256"
	"database/sql"
	"encoding/hex"
	"fmt"
	"io"
	"net/http"

	"codeberg.org/u1f320/filer/db/queries"
	"emperror.dev/errors"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"
)

// maxMemory is passed to ParseMultipartForm: up to this many bytes (500
// megabytes) of multipart file data are held in memory, and net/http stores
// the remainder in temporary files on disk.
const maxMemory = 500_000_000

// uploadFile handles a multipart file upload authenticated by the raw value
// of the Authorization header.
//
// The file is read from the "file" form field and buffered in memory. Its
// content type is sniffed with http.DetectContentType and it is hashed with
// SHA-256. A database record is created for it, and the contents are then
// written to app.db.Store. If the store write fails, the database record is
// deleted again.
//
// On success it responds 200 with the body "/<hex sha256>/<original filename>".
// Error responses:
//   - 401 when no token is sent
//   - 403 when the token matches no user
//   - 400 for unparseable multipart data or a missing "file" field
//   - 500 for any internal failure
func (app *App) uploadFile(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	token := r.Header.Get("Authorization")
	if token == "" {
		http.Error(w, "No token provided", http.StatusUnauthorized)
		return
	}

	u, err := app.db.GetUserByToken(ctx, token)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			http.Error(w, "Invalid token provided", http.StatusForbidden)
			return
		}
		// Any other lookup failure must stop the request; otherwise the
		// upload would continue with a zero-value user.
		log.Err(err).Msg("getting user by token")
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	err = r.ParseMultipartForm(maxMemory)
	if err != nil {
		log.Err(err).Msg("parsing multipart form data")
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	file, fileHeader, err := r.FormFile("file")
	if err != nil {
		log.Err(err).Msg("getting file from multipart form")
		http.Error(w, "Bad request", http.StatusBadRequest)
		return // without this, the nil file below would panic in io.Copy
	}
	defer file.Close()

	buffer := new(bytes.Buffer)
	_, err = io.Copy(buffer, file)
	if err != nil {
		log.Err(err).Msg("copying file to temporary buffer")
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	contentType := http.DetectContentType(buffer.Bytes())

	hash, err := getSha256Hash(buffer.Bytes())
	if err != nil {
		log.Err(err).Msg("generating file hash")
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	dbfile, err := app.db.CreateFile(ctx, queries.CreateFileParams{
		ID:          uuid.New(),
		UserID:      u.ID,
		Filename:    fileHeader.Filename,
		ContentType: contentType,
		Hash:        hash,
		Size:        int64(buffer.Len()),
	})
	if err != nil {
		log.Err(err).Msg("creating file in database")
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	err = app.db.Store.WriteFile(dbfile.ID, buffer, contentType)
	if err != nil {
		log.Err(err).Msg("creating file in store")
		// Roll back the database record so it does not reference missing data.
		// NOTE(review): this uses the request context; if the client has gone
		// away the delete may fail and leave an orphaned row.
		if delErr := app.db.DeleteFile(ctx, dbfile.ID); delErr != nil {
			log.Err(delErr).Msg("deleting file from database")
		}
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	// NOTE(review): the filename is echoed unescaped into a URL-like path;
	// consider url.PathEscape if clients use this directly as a URL.
	w.WriteHeader(http.StatusOK)
	if _, err = fmt.Fprintf(w, "/%s/%s", hash, fileHeader.Filename); err != nil {
		log.Err(err).Msg("writing response")
	}
}

// getSha256Hash returns the lowercase hex encoding of the SHA-256 digest of b.
// The error result is always nil; it is kept so existing callers still compile.
func getSha256Hash(b []byte) (string, error) {
	sum := sha256.Sum256(b)
	return hex.EncodeToString(sum[:]), nil
}