116 lines
3 KiB
Go
116 lines
3 KiB
Go
package app
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.sleepycat.moe/sam/mercury/internal/database"
|
|
"git.sleepycat.moe/sam/mercury/web/api"
|
|
"github.com/go-chi/render"
|
|
)
|
|
|
|
// ctxKey is the unexported type used for this package's context keys. Keys of
// a package-private type cannot be produced by other packages, so they never
// collide with context values set elsewhere.
type ctxKey int

// ctxKeyClaims is the context key under which the auth middleware stores the
// request's database.Token.
const ctxKeyClaims = ctxKey(1)
|
|
|
|
func (app *App) FrontendAuth(next http.Handler) http.Handler {
|
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
|
cookie, err := r.Cookie(database.TokenCookieName)
|
|
if err != nil || cookie.Value == "" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
token, err := app.ParseToken(r.Context(), cookie.Value)
|
|
if err != nil {
|
|
app.ErrorTemplate(w, r, "Invalid token", "The provided token was not valid. Try clearing your cookies and reloading the page.")
|
|
return
|
|
}
|
|
|
|
if token.Expires.Before(time.Now()) {
|
|
http.SetCookie(w, &http.Cookie{Name: database.TokenCookieName, Value: "", Expires: time.Now()})
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), ctxKeyClaims, token)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
|
|
return http.HandlerFunc(fn)
|
|
}
|
|
|
|
func (app *App) APIAuth(scope database.TokenScope, anonAccess bool) func(next http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
header := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
|
if header == "" {
|
|
cookie, err := r.Cookie(database.TokenCookieName)
|
|
if err != nil || cookie.Value == "" {
|
|
if anonAccess { // no token supplied, but the endpoint allows anonymous access
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
render.Status(r, api.ErrCodeStatus(api.ErrInvalidToken))
|
|
render.JSON(w, r, api.Error{
|
|
Code: api.ErrInvalidToken,
|
|
Message: api.ErrCodeMessage(api.ErrInvalidToken),
|
|
Details: "No token supplied",
|
|
})
|
|
return
|
|
}
|
|
header = cookie.Value
|
|
}
|
|
|
|
token, err := app.ParseToken(r.Context(), header)
|
|
if err != nil {
|
|
render.Status(r, api.ErrCodeStatus(api.ErrInvalidToken))
|
|
render.JSON(w, r, api.Error{
|
|
Code: api.ErrInvalidToken,
|
|
Message: api.ErrCodeMessage(api.ErrInvalidToken),
|
|
Details: "Could not parse token",
|
|
})
|
|
return
|
|
}
|
|
|
|
if token.Expires.Before(time.Now()) {
|
|
render.Status(r, api.ErrCodeStatus(api.ErrInvalidToken))
|
|
render.JSON(w, r, api.Error{
|
|
Code: api.ErrInvalidToken,
|
|
Message: api.ErrCodeMessage(api.ErrInvalidToken),
|
|
Details: "Token is expired",
|
|
})
|
|
return
|
|
}
|
|
|
|
if !token.Scopes.Has(scope) {
|
|
render.Status(r, api.ErrCodeStatus(api.ErrMissingScope))
|
|
render.JSON(w, r, api.Error{
|
|
Code: api.ErrMissingScope,
|
|
Message: api.ErrCodeMessage(api.ErrMissingScope),
|
|
})
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), ctxKeyClaims, token)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
func (app *App) TokenFromContext(ctx context.Context) (database.Token, bool) {
|
|
v := ctx.Value(ctxKeyClaims)
|
|
if v == nil {
|
|
return database.Token{}, false
|
|
}
|
|
|
|
claims, ok := v.(database.Token)
|
|
return claims, ok
|
|
}
|