sam
b47ed7b699
the /users/{id} prefix contains most API routes so it's not a good idea to put a single rate limit on *all* of them combined. the rate limiter will now ignore the /users/{id} prefix *if* there's a second {id} parameter in the URL. also, X-RateLimit-Bucket is no longer hashed, so it can be directly decoded by clients to get the actual bucket name. i'm not sure if this will actually be useful, but it's nice to have the option.
139 lines
3 KiB
Go
139 lines
3 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/hex"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-chi/httprate"
|
|
)
|
|
|
|
func userOrIp(r *http.Request, userId *string) (string, error) {
|
|
if userId != nil {
|
|
return *userId, nil
|
|
}
|
|
|
|
return httprate.KeyByRealIP(r)
|
|
}
|
|
|
|
type Limiter struct {
|
|
bucketRateLimiters map[bucketKey]*httprate.RateLimiter
|
|
globalRateLimiters map[string]*httprate.RateLimiter
|
|
|
|
bucketMu *sync.Mutex
|
|
globalMu *sync.Mutex
|
|
}
|
|
|
|
// LimitData is a plain record of rate-limit information.
//
// NOTE(review): nothing in this file reads or writes it; the field meanings
// presumably mirror the X-RateLimit-* response headers — confirm against the
// code that uses it.
type LimitData struct {
	Global           bool
	Bucket           string
	Remaining, Reset int
}
|
|
|
|
// bucketKey is the map key for per-bucket rate limiters: it pairs the user
// key with the bucket name. Field order matters because the key is also
// built with positional literals.
type bucketKey struct {
	user   string
	bucket string
}
|
|
|
|
func NewLimiter() *Limiter {
|
|
return &Limiter{
|
|
bucketRateLimiters: make(map[bucketKey]*httprate.RateLimiter),
|
|
globalRateLimiters: make(map[string]*httprate.RateLimiter),
|
|
|
|
bucketMu: new(sync.Mutex),
|
|
globalMu: new(sync.Mutex),
|
|
}
|
|
}
|
|
|
|
// TryLimit handles rate limiting for a request.
|
|
// If this function returns true, a rate limit was hit, and the request should be aborted with a 429 status.
|
|
func (l *Limiter) TryLimit(w http.ResponseWriter, r *http.Request, userID *string, template string) (bool, error) {
|
|
user, err := userOrIp(r, userID)
|
|
if err != nil {
|
|
return false, fmt.Errorf("getting user or IP: %w", err)
|
|
}
|
|
|
|
// Check the globalLimiter rate limit for this user
|
|
globalLimiter := l.globalLimiter(user)
|
|
if globalLimiter.OnLimit(w, r, "") {
|
|
w.Header().Add("X-RateLimit-Global", "true")
|
|
return true, nil
|
|
}
|
|
|
|
bucket := requestBucket(r.Method, template)
|
|
w.Header().Add("X-RateLimit-Bucket", bucket)
|
|
|
|
bucketLimiter := l.bucketLimiter(user, r.Method, bucket)
|
|
if bucketLimiter.OnLimit(w, r, "") {
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// getReset reads the X-RateLimit-Reset header already set on w and returns it
// parsed as a base-10 int64. It returns 0 when the header is absent, empty, or
// not a valid integer.
//
// NOTE(review): the value is presumably the Unix timestamp written by
// httprate — confirm.
func getReset(w http.ResponseWriter) int64 {
	raw := w.Header().Get("X-RateLimit-Reset")
	if raw != "" {
		if reset, err := strconv.ParseInt(raw, 10, 64); err == nil {
			return reset
		}
	}

	return 0
}
|
|
|
|
// requestBucket returns the rate-limit bucket name for a route: the lowercase
// hex encoding of "<METHOD>-<template>". Hex encoding is reversible, so the
// name can be decoded back into the method and route template.
//
// The method is upper-cased before encoding, matching the case-insensitive
// handling in requestLimitFor, so that requests differing only in method case
// share one bucket instead of each getting a fresh limiter.
func requestBucket(method, template string) string {
	name := strings.ToUpper(method) + "-" + template
	return hex.EncodeToString([]byte(name))
}
|
|
|
|
func (l *Limiter) globalLimiter(user string) *httprate.RateLimiter {
|
|
l.globalMu.Lock()
|
|
defer l.globalMu.Unlock()
|
|
|
|
limiter, ok := l.globalRateLimiters[user]
|
|
if ok {
|
|
return limiter
|
|
}
|
|
|
|
limiter = httprate.NewRateLimiter(20, time.Second)
|
|
l.globalRateLimiters[user] = limiter
|
|
return limiter
|
|
}
|
|
|
|
func (l *Limiter) bucketLimiter(user, method, bucket string) *httprate.RateLimiter {
|
|
l.bucketMu.Lock()
|
|
defer l.bucketMu.Unlock()
|
|
|
|
limiter, ok := l.bucketRateLimiters[bucketKey{user, bucket}]
|
|
if ok {
|
|
return limiter
|
|
}
|
|
|
|
requestLimit, windowLength := requestLimitFor(method)
|
|
|
|
limiter = httprate.NewRateLimiter(requestLimit, windowLength)
|
|
l.bucketRateLimiters[bucketKey{user, bucket}] = limiter
|
|
return limiter
|
|
}
|
|
|
|
// requestLimitFor returns the request limit and window length for the given
// HTTP method, compared case-insensitively. Methods that change state get
// lower limits than GET:
//
//	PATCH, POST  3 per second
//	DELETE       2 per 5 seconds
//	GET          10 per second
//	anything else 5 per second
func requestLimitFor(method string) (int, time.Duration) {
	// Start from the fallback and override for the methods with their own limits.
	limit, window := 5, time.Second

	switch strings.ToUpper(method) {
	case "GET":
		limit = 10
	case "POST", "PATCH":
		limit = 3
	case "DELETE":
		limit, window = 2, 5*time.Second
	}

	return limit, window
}
|