pronounscc/backend/server/rate/rate.go
2023-03-13 01:31:12 +01:00

96 lines
2.2 KiB
Go

package rate
import (
"net/http"
"os"
"sort"
"strings"
"time"
"github.com/go-chi/httprate"
"github.com/gobwas/glob"
)
// Limiter selects a rate-limit bucket for each request from a list of
// registered scopes and falls back to a default limiter when none match.
// Create one with NewLimiter, register scopes with Scope, then mount the
// middleware returned by Handler.
type Limiter struct {
	// scopes holds the method-specific scopes added by Scope. Handler
	// sorts them and appends wildcardScopes; handle then walks this
	// slice in order and uses the first match.
	scopes []*scopedLimiter
	// defaultLimiter is applied when no scope matches (bucket "/").
	defaultLimiter func(http.Handler) http.Handler
	// windowLength and options are reused for every limiter built by Scope.
	windowLength time.Duration
	options      []httprate.Option
	// wildcardScopes holds scopes registered with method "*"; they are
	// only considered after every method-specific scope.
	wildcardScopes []*scopedLimiter
	// frontendIP comes from the FRONTEND_IP environment variable; a
	// request whose httprate.KeyByIP key equals it skips rate limiting.
	frontendIP string
}
// scopedLimiter is a single registered scope. A request matches it when
// its method equals Method (or Method is "*") and its URL path matches glob.
type scopedLimiter struct {
	// Method is an HTTP method or "*". Pattern is the path pattern as
	// given to Scope; it is also used in the X-RateLimit-Bucket header.
	Method, Pattern string
	// glob is Pattern compiled with a "/v*" prefix and '/' as separator.
	glob glob.Glob
	// handler is the httprate middleware created for this scope's limit.
	handler func(http.Handler) http.Handler
}
// NewLimiter returns a Limiter whose fallback bucket allows defaultLimit
// requests per windowLength. windowLength and options are remembered so
// that every scope later added with Scope is built with the same httprate
// configuration. If the FRONTEND_IP environment variable is set, requests
// whose client key equals it are not rate limited.
func NewLimiter(defaultLimit int, windowLength time.Duration, options ...httprate.Option) *Limiter {
	limiter := &Limiter{
		frontendIP:   os.Getenv("FRONTEND_IP"),
		windowLength: windowLength,
		options:      options,
	}
	limiter.defaultLimiter = httprate.Limit(defaultLimit, windowLength, options...)
	return limiter
}
// Scope registers a dedicated limit of requestLimit requests per window.
// The limit applies to requests whose method equals method (any method
// when method is "*") and whose path matches the glob "/v*"+pattern, with
// '/' as the separator so "*" does not cross path segments. If the glob
// fails to compile, that error is returned and nothing is registered.
func (l *Limiter) Scope(method, pattern string, requestLimit int) error {
	matcher, err := glob.Compile("/v*"+pattern, '/')
	if err != nil {
		return err
	}

	scope := &scopedLimiter{
		Method:  method,
		Pattern: pattern,
		glob:    matcher,
		handler: httprate.Limit(requestLimit, l.windowLength, l.options...),
	}

	// Wildcard scopes are kept apart so Handler can place them after all
	// method-specific scopes.
	switch method {
	case "*":
		l.wildcardScopes = append(l.wildcardScopes, scope)
	default:
		l.scopes = append(l.scopes, scope)
	}
	return nil
}
// Handler finalizes the scope order and returns the rate-limiting
// middleware.
//
// Method-specific scopes are ordered deepest pattern first (most
// '/'-separated segments), and registration order is kept among patterns
// of equal depth. Wildcard ("*") scopes follow, in registration order.
// Because the first matching scope wins, more specific patterns must come
// first.
//
// Each call rebuilds the order from the registered scopes, so calling
// Handler more than once neither duplicates the wildcard scopes nor sorts
// them in among the method-specific ones. Register all scopes with Scope
// before calling Handler.
func (l *Limiter) Handler() func(http.Handler) http.Handler {
	ordered := make([]*scopedLimiter, 0, len(l.scopes)+len(l.wildcardScopes))
	// After an earlier call, l.scopes also holds the wildcard scopes. Scope
	// only ever stores method "*" in wildcardScopes, so keeping the
	// non-"*" entries recovers exactly the method-specific set.
	for _, s := range l.scopes {
		if s.Method != "*" {
			ordered = append(ordered, s)
		}
	}

	// Comparing separator counts is equivalent to comparing
	// len(strings.Split(p, "/")) and avoids allocating in every comparison.
	// The stable sort keeps equal-depth patterns in registration order.
	sort.SliceStable(ordered, func(i, j int) bool {
		return strings.Count(ordered[i].Pattern, "/") > strings.Count(ordered[j].Pattern, "/")
	})

	l.scopes = append(ordered, l.wildcardScopes...)
	return l.handle
}
// handle is the middleware returned by Handler. For each request it does
// the first of the following that applies:
//   - If FRONTEND_IP is set and the client key from httprate.KeyByIP
//     equals it, the request is passed to next without rate limiting.
//   - If a scope in l.scopes matches (in order), that scope's limiter is
//     applied. X-RateLimit-Bucket is set to "METHOD pattern", or just
//     "pattern" for wildcard scopes.
//   - Otherwise the default limiter is applied, with bucket "/".
//
// next is wrapped by each limiter once, when the middleware is mounted,
// not on every request. The limiters are created once by httprate.Limit
// in NewLimiter and Scope and shared by every handler they wrap, so
// re-wrapping per request only allocated. As a result, scopes must be
// registered before the middleware is mounted.
func (l *Limiter) handle(next http.Handler) http.Handler {
	// route pairs a scope with its precomputed bucket name and the next
	// handler already wrapped by that scope's limiter.
	type route struct {
		scope   *scopedLimiter
		bucket  string
		limited http.Handler
	}

	routes := make([]route, 0, len(l.scopes))
	for _, s := range l.scopes {
		bucket := s.Pattern
		if s.Method != "*" {
			bucket = s.Method + " " + s.Pattern
		}
		routes = append(routes, route{scope: s, bucket: bucket, limited: s.handler(next)})
	}
	fallback := l.defaultLimiter(next)

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if l.frontendIP != "" {
			ip, err := httprate.KeyByIP(r)
			if err == nil && ip == l.frontendIP {
				// frontend gets to bypass ratelimit
				next.ServeHTTP(w, r)
				return
			}
		}
		for _, rt := range routes {
			s := rt.scope
			if (r.Method == s.Method || s.Method == "*") && s.glob.Match(r.URL.Path) {
				w.Header().Set("X-RateLimit-Bucket", rt.bucket)
				rt.limited.ServeHTTP(w, r)
				return
			}
		}
		w.Header().Set("X-RateLimit-Bucket", "/")
		fallback.ServeHTTP(w, r)
	})
}