feat(backend): separate rate limits into buckets
This commit is contained in:
parent
d11f296026
commit
c95285e26b
4 changed files with 118 additions and 11 deletions
83
backend/server/rate/rate.go
Normal file
83
backend/server/rate/rate.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package rate
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/httprate"
|
||||
"github.com/gobwas/glob"
|
||||
)
|
||||
|
||||
type Limiter struct {
|
||||
scopes []*scopedLimiter
|
||||
defaultLimiter func(http.Handler) http.Handler
|
||||
|
||||
windowLength time.Duration
|
||||
options []httprate.Option
|
||||
|
||||
wildcardScopes []*scopedLimiter
|
||||
}
|
||||
|
||||
type scopedLimiter struct {
|
||||
Method, Pattern string
|
||||
|
||||
glob glob.Glob
|
||||
handler func(http.Handler) http.Handler
|
||||
}
|
||||
|
||||
func NewLimiter(defaultLimit int, windowLength time.Duration, options ...httprate.Option) *Limiter {
|
||||
return &Limiter{
|
||||
windowLength: windowLength,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) Scope(method, pattern string, requestLimit int) error {
|
||||
handler := httprate.Limit(requestLimit, l.windowLength, l.options...)
|
||||
|
||||
g, err := glob.Compile("/v*"+pattern, '/')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if method == "*" {
|
||||
l.wildcardScopes = append(l.wildcardScopes, &scopedLimiter{method, pattern, g, handler})
|
||||
} else {
|
||||
l.scopes = append(l.scopes, &scopedLimiter{method, pattern, g, handler})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Limiter) Handler() func(http.Handler) http.Handler {
|
||||
sort.Slice(l.scopes, func(i, j int) bool {
|
||||
len1 := len(strings.Split(l.scopes[i].Pattern, "/"))
|
||||
len2 := len(strings.Split(l.scopes[j].Pattern, "/"))
|
||||
|
||||
return len1 > len2
|
||||
})
|
||||
l.scopes = append(l.scopes, l.wildcardScopes...)
|
||||
|
||||
return l.handle
|
||||
}
|
||||
|
||||
func (l *Limiter) handle(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, s := range l.scopes {
|
||||
if (r.Method == s.Method || s.Method == "*") && s.glob.Match(r.URL.Path) {
|
||||
bucket := s.Pattern
|
||||
if s.Method != "*" {
|
||||
bucket = s.Method + " " + s.Pattern
|
||||
}
|
||||
w.Header().Set("X-RateLimit-Bucket", bucket)
|
||||
|
||||
s.handler(next).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("X-RateLimit-Bucket", "/")
|
||||
l.defaultLimiter(next).ServeHTTP(w, r)
|
||||
})
|
||||
}
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"codeberg.org/u1f320/pronouns.cc/backend/db"
|
||||
"codeberg.org/u1f320/pronouns.cc/backend/server/auth"
|
||||
"codeberg.org/u1f320/pronouns.cc/backend/server/rate"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/httprate"
|
||||
|
@ -45,11 +46,10 @@ func New() (*Server, error) {
|
|||
s.Router.Use(s.maybeAuth)
|
||||
|
||||
// rate limit handling
|
||||
// - 120 req/minute (2/s)
|
||||
// - base is 120 req/minute (2/s)
|
||||
// - keyed by Authorization header if valid token is provided, otherwise by IP
|
||||
// - returns rate limit reset info in error
|
||||
s.Router.Use(httprate.Limit(
|
||||
120, time.Minute,
|
||||
rateLimiter := rate.NewLimiter(120, time.Minute,
|
||||
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
|
||||
_, ok := ClaimsFromContext(r.Context())
|
||||
if token := r.Header.Get("Authorization"); ok && token != "" {
|
||||
|
@ -69,7 +69,33 @@ func New() (*Server, error) {
|
|||
RatelimitReset: &reset,
|
||||
})
|
||||
}),
|
||||
))
|
||||
)
|
||||
|
||||
// set scopes
|
||||
// users
|
||||
rateLimiter.Scope("GET", "/users/*", 60)
|
||||
rateLimiter.Scope("PATCH", "/users/@me", 5)
|
||||
|
||||
// members
|
||||
rateLimiter.Scope("GET", "/users/*/members", 60)
|
||||
rateLimiter.Scope("GET", "/users/*/members/*", 60)
|
||||
|
||||
rateLimiter.Scope("POST", "/members", 5)
|
||||
rateLimiter.Scope("GET", "/members/*", 60)
|
||||
rateLimiter.Scope("PATCH", "/members/*", 5)
|
||||
rateLimiter.Scope("DELETE", "/members/*", 5)
|
||||
|
||||
// auth
|
||||
rateLimiter.Scope("*", "/auth/*", 20)
|
||||
rateLimiter.Scope("*", "/auth/tokens", 10)
|
||||
rateLimiter.Scope("*", "/auth/invites", 10)
|
||||
rateLimiter.Scope("POST", "/auth/discord/*", 10)
|
||||
|
||||
// rate limit handling
|
||||
// - 120 req/minute (2/s)
|
||||
// - keyed by Authorization header if valid token is provided, otherwise by IP
|
||||
// - returns rate limit reset info in error
|
||||
s.Router.Use(rateLimiter.Handler())
|
||||
|
||||
// return an API error for not found + method not allowed
|
||||
s.Router.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue