Foxnouns.NET/rate/rate_limiter.go
2024-09-10 16:53:43 +02:00

145 lines
3.1 KiB
Go

package main
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/go-chi/httprate"
)
// userOrIp returns the key a request is rate limited under: the caller's user
// ID when userID is non-nil, otherwise the client IP as resolved by
// httprate.KeyByRealIP (whose error, if any, is returned unchanged).
func userOrIp(r *http.Request, userID *string) (string, error) {
	if userID == nil {
		return httprate.KeyByRealIP(r)
	}
	return *userID, nil
}
// Limiter holds lazily created httprate limiters: one global limiter per
// rate-limit key (user ID or client IP) and one limiter per (key, bucket)
// pair. Construct it with NewLimiter; the zero value has nil maps and nil
// mutex pointers and is not usable.
//
// NOTE(review): the code visible here only ever adds entries to these maps and
// never removes them, so memory grows with the number of distinct users/IPs
// seen. Consider evicting idle limiters.
type Limiter struct {
	// bucketRateLimiters is keyed by (user, bucket) and guarded by bucketMu.
	bucketRateLimiters map[bucketKey]*httprate.RateLimiter
	// globalRateLimiters is keyed by user ID or client IP and guarded by globalMu.
	globalRateLimiters map[string]*httprate.RateLimiter

	bucketMu, globalMu *sync.Mutex
}
// LimitData describes the outcome of a rate-limit check.
//
// NOTE(review): this type is not referenced in the code visible here. The field
// meanings below are inferred from the names and from the X-RateLimit-* headers
// this file handles. Confirm them against callers.
type LimitData struct {
	Global bool   // presumably: the global (per-user) limit was the one hit
	Bucket string // presumably: the bucket hash, as sent in X-RateLimit-Bucket

	Remaining, Reset int // presumably: mirror X-RateLimit-Remaining / X-RateLimit-Reset
}
// bucketKey identifies one per-bucket limiter in Limiter.bucketRateLimiters.
// Both fields are strings, so the struct is comparable and usable as a map key.
type bucketKey struct {
	user   string // user ID or client IP, as returned by userOrIp
	bucket string // route bucket hash, as returned by requestBucket
}
// NewLimiter returns a ready-to-use Limiter with empty limiter maps and its
// own pair of mutexes.
func NewLimiter() *Limiter {
	l := &Limiter{
		bucketMu: &sync.Mutex{},
		globalMu: &sync.Mutex{},
	}
	l.bucketRateLimiters = map[bucketKey]*httprate.RateLimiter{}
	l.globalRateLimiters = map[string]*httprate.RateLimiter{}
	return l
}
// TryLimit handles rate limiting for a request.
//
// The request is keyed by *userID when userID is non-nil, and by the client IP
// otherwise (see userOrIp). The per-key global limit is checked first. If it
// is exceeded, X-RateLimit-Global is set to "true" and the bucket limit is not
// consulted. Otherwise X-RateLimit-Bucket is set to the hash of r.Method and
// template, and the request is checked against that bucket's limit.
//
// If this function returns true, a rate limit was hit, and the request should
// be aborted with a 429 status; writing that response is the caller's job.
// A non-nil error means the rate-limit key could not be determined.
func (l *Limiter) TryLimit(w http.ResponseWriter, r *http.Request, userID *string, template string) (bool, error) {
	user, err := userOrIp(r, userID)
	if err != nil {
		return false, fmt.Errorf("getting user or IP: %w", err)
	}

	// Check the global rate limit for this user before any per-bucket limit.
	if l.globalLimiter(user).OnLimit(w, r, "") {
		// Set, not Add: this header is single-valued and must not be duplicated
		// if it is already present on w.
		w.Header().Set("X-RateLimit-Global", "true")
		return true, nil
	}

	bucket := requestBucket(r.Method, template)
	w.Header().Set("X-RateLimit-Bucket", bucket)

	return l.bucketLimiter(user, r.Method, bucket).OnLimit(w, r, ""), nil
}
// getReset reads the X-RateLimit-Reset header already present on w and parses
// it as a base-10 int64. A missing, empty, malformed, or out-of-range value
// yields 0.
func getReset(w http.ResponseWriter) int64 {
	// An absent header reads as "", which ParseInt rejects, so it falls into
	// the same zero result as any other unparsable value.
	reset, err := strconv.ParseInt(w.Header().Get("X-RateLimit-Reset"), 10, 64)
	if err != nil {
		return 0
	}
	return reset
}
// requestBucket returns the bucket identifier for a route: the lowercase
// hex-encoded SHA-256 of method + "-" + template (always 64 characters).
// The value is sent to clients in X-RateLimit-Bucket, so its format must stay
// stable.
func requestBucket(method, template string) string {
	// sha256.Sum256 has no error path, unlike hash.Hash.Write, whose error
	// return is documented as always nil, so no error handling is needed.
	sum := sha256.Sum256([]byte(method + "-" + template))
	return hex.EncodeToString(sum[:])
}
// globalLimiter returns the limiter capping user's total request rate across
// all buckets, creating it (20 requests per second) on first use. The map
// lookup and insert happen while holding globalMu.
func (l *Limiter) globalLimiter(user string) *httprate.RateLimiter {
	l.globalMu.Lock()
	defer l.globalMu.Unlock()

	if existing, found := l.globalRateLimiters[user]; found {
		return existing
	}

	created := httprate.NewRateLimiter(20, time.Second)
	l.globalRateLimiters[user] = created
	return created
}
// bucketLimiter returns the limiter for user's requests to bucket, creating it
// on first use with the limit and window chosen by requestLimitFor(method).
// The method is consulted only at creation time. The map lookup and insert
// happen while holding bucketMu.
func (l *Limiter) bucketLimiter(user, method, bucket string) *httprate.RateLimiter {
	key := bucketKey{user: user, bucket: bucket}

	l.bucketMu.Lock()
	defer l.bucketMu.Unlock()

	if existing, found := l.bucketRateLimiters[key]; found {
		return existing
	}

	limit, window := requestLimitFor(method)
	created := httprate.NewRateLimiter(limit, window)
	l.bucketRateLimiters[key] = created
	return created
}
// requestLimitFor returns the request limit and window length for the given
// HTTP method, matched case-insensitively. State-changing methods get tighter
// limits than GET:
//
//	GET           3 per second
//	PATCH, POST   1 per second
//	DELETE        1 per 5 seconds
//	anything else 2 per second
func requestLimitFor(method string) (limit int, window time.Duration) {
	window = time.Second

	upper := strings.ToUpper(method)
	switch upper {
	case "GET":
		limit = 3
	case "PATCH", "POST":
		limit = 1
	case "DELETE":
		limit, window = 1, 5*time.Second
	default:
		limit = 2
	}
	return limit, window
}