mercury/web/app/middleware.go
2023-09-15 16:33:08 +02:00

116 lines
3 KiB
Go

package app
import (
"context"
"net/http"
"strings"
"time"
"git.sleepycat.moe/sam/mercury/internal/database"
"git.sleepycat.moe/sam/mercury/web/api"
"github.com/go-chi/render"
)
// ctxKey is the type of the keys this package uses for request-context
// values. Being a distinct unexported type, its values never compare equal
// to context keys defined by other packages.
type ctxKey int

const (
	// ctxKeyClaims is the context key under which the auth middleware stores
	// the parsed token. Its value is 1.
	ctxKeyClaims ctxKey = iota + 1
)
// FrontendAuth is middleware for browser-facing routes that attaches the
// caller's token to the request context when a usable one is present.
//
// Per request, in order:
//   - no token cookie, or an empty one: next is called without a token;
//   - a cookie value that app.ParseToken rejects: an error page is rendered
//     and next is not called;
//   - an expired token: the cookie is cleared and next is called without a
//     token;
//   - otherwise: the token is stored in the request context under
//     ctxKeyClaims (the key TokenFromContext reads) and next is called.
func (app *App) FrontendAuth(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie(database.TokenCookieName)
		if err != nil || cookie.Value == "" {
			next.ServeHTTP(w, r)
			return
		}

		token, err := app.ParseToken(r.Context(), cookie.Value)
		if err != nil {
			app.ErrorTemplate(w, r, "Invalid token", "The provided token was not valid. Try clearing your cookies and reloading the page.")
			return
		}

		if token.Expires.Before(time.Now()) {
			// Ask the browser to drop the stale cookie. MaxAge < 0 is sent as
			// "Max-Age=0", which deletes the cookie regardless of the client's
			// clock; the epoch Expires covers clients that ignore Max-Age.
			// NOTE(review): no Path/Domain is set here. Browsers only replace
			// a cookie whose path and domain match, so confirm this matches
			// how the token cookie is issued.
			http.SetCookie(w, &http.Cookie{
				Name:    database.TokenCookieName,
				Value:   "",
				Expires: time.Unix(0, 0),
				MaxAge:  -1,
			})
			next.ServeHTTP(w, r)
			return
		}

		ctx := context.WithValue(r.Context(), ctxKeyClaims, token)
		next.ServeHTTP(w, r.WithContext(ctx))
	}

	return http.HandlerFunc(fn)
}
// APIAuth returns middleware that authenticates API requests and requires
// the token to carry scope.
//
// The token is taken from the Authorization header, with an optional
// "Bearer " scheme prefix (matched case-insensitively) stripped. If that
// yields nothing, the token cookie is used instead.
//
// When no token is supplied at all, the request is passed through without a
// token if anonAccess is true, and rejected otherwise. A token that is
// supplied but cannot be parsed, is expired, or lacks scope is always
// rejected with a JSON api.Error, even when anonAccess is true. On success
// the token is stored in the request context under ctxKeyClaims.
func (app *App) APIAuth(scope database.TokenScope, anonAccess bool) func(next http.Handler) http.Handler {
	const bearerPrefix = "Bearer "

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// reject writes body as JSON with the given HTTP status.
			reject := func(status int, body api.Error) {
				render.Status(r, status)
				render.JSON(w, r, body)
			}

			// Authentication scheme names are case-insensitive, so accept
			// "bearer", "BEARER", etc. A value without the prefix is still
			// treated as a bare token, as before.
			raw := r.Header.Get("Authorization")
			if len(raw) >= len(bearerPrefix) && strings.EqualFold(raw[:len(bearerPrefix)], bearerPrefix) {
				raw = strings.TrimSpace(raw[len(bearerPrefix):])
			}

			if raw == "" {
				cookie, err := r.Cookie(database.TokenCookieName)
				if err != nil || cookie.Value == "" {
					if anonAccess { // no token supplied, but the endpoint allows anonymous access
						next.ServeHTTP(w, r)
						return
					}

					reject(api.ErrCodeStatus(api.ErrInvalidToken), api.Error{
						Code:    api.ErrInvalidToken,
						Message: api.ErrCodeMessage(api.ErrInvalidToken),
						Details: "No token supplied",
					})
					return
				}
				raw = cookie.Value
			}

			token, err := app.ParseToken(r.Context(), raw)
			if err != nil {
				reject(api.ErrCodeStatus(api.ErrInvalidToken), api.Error{
					Code:    api.ErrInvalidToken,
					Message: api.ErrCodeMessage(api.ErrInvalidToken),
					Details: "Could not parse token",
				})
				return
			}

			if token.Expires.Before(time.Now()) {
				reject(api.ErrCodeStatus(api.ErrInvalidToken), api.Error{
					Code:    api.ErrInvalidToken,
					Message: api.ErrCodeMessage(api.ErrInvalidToken),
					Details: "Token is expired",
				})
				return
			}

			if !token.Scopes.Has(scope) {
				reject(api.ErrCodeStatus(api.ErrMissingScope), api.Error{
					Code:    api.ErrMissingScope,
					Message: api.ErrCodeMessage(api.ErrMissingScope),
				})
				return
			}

			ctx := context.WithValue(r.Context(), ctxKeyClaims, token)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// TokenFromContext returns the value stored in ctx under ctxKeyClaims as a
// database.Token. The boolean is false, and the token is the zero value,
// when nothing is stored under that key or the stored value is not a
// database.Token.
func (app *App) TokenFromContext(ctx context.Context) (database.Token, bool) {
	// A comma-ok assertion on a nil interface yields the zero value and
	// false, so the "no value" case needs no separate check.
	token, found := ctx.Value(ctxKeyClaims).(database.Token)
	return token, found
}