89 lines
2.2 KiB
Go
89 lines
2.2 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"codeberg.org/pronounscc/pronouns.cc/backend/log"
|
|
"codeberg.org/pronounscc/pronouns.cc/backend/server/auth"
|
|
"github.com/go-chi/render"
|
|
)
|
|
|
|
// maybeAuth is a globally-used middleware.
|
|
func (s *Server) maybeAuth(next http.Handler) http.Handler {
|
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
|
token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
|
if token == "" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
claims, err := s.Auth.Claims(token)
|
|
if err != nil {
|
|
render.Status(r, errCodeStatuses[ErrInvalidToken])
|
|
render.JSON(w, r, APIError{
|
|
Code: ErrInvalidToken,
|
|
Message: errCodeMessages[ErrInvalidToken],
|
|
})
|
|
return
|
|
}
|
|
|
|
// "valid" here refers to existence and expiry date, not whether the token is known
|
|
valid, err := s.DB.TokenValid(r.Context(), claims.UserID, claims.TokenID)
|
|
if err != nil {
|
|
log.Errorf("validating token for user %v: %v", claims.UserID, err)
|
|
render.Status(r, errCodeStatuses[ErrInternalServerError])
|
|
render.JSON(w, r, APIError{
|
|
Code: ErrInternalServerError,
|
|
Message: errCodeMessages[ErrInternalServerError],
|
|
})
|
|
return
|
|
}
|
|
|
|
if !valid {
|
|
render.Status(r, errCodeStatuses[ErrInvalidToken])
|
|
render.JSON(w, r, APIError{
|
|
Code: ErrInvalidToken,
|
|
Message: errCodeMessages[ErrInvalidToken],
|
|
})
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), ctxKeyClaims, claims)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
|
|
return http.HandlerFunc(fn)
|
|
}
|
|
|
|
// MustAuth makes a valid token required
|
|
func MustAuth(next http.Handler) http.Handler {
|
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
|
_, ok := ClaimsFromContext(r.Context())
|
|
if !ok {
|
|
render.Status(r, errCodeStatuses[ErrForbidden])
|
|
render.JSON(w, r, APIError{
|
|
Code: ErrForbidden,
|
|
Message: errCodeMessages[ErrForbidden],
|
|
})
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
}
|
|
|
|
return http.HandlerFunc(fn)
|
|
}
|
|
|
|
// ClaimsFromContext returns the auth.Claims in the context, if any.
|
|
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
|
|
v := ctx.Value(ctxKeyClaims)
|
|
if v == nil {
|
|
return auth.Claims{}, false
|
|
}
|
|
|
|
claims, ok := v.(auth.Claims)
|
|
return claims, ok
|
|
}
|