feat: add rate limiter proxy

This commit is contained in:
sam 2024-09-10 16:53:43 +02:00
parent 13a0cac663
commit 3d22385689
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
10 changed files with 498 additions and 0 deletions

7
rate/README.md Normal file
View file

@ -0,0 +1,7 @@
# Rate limiting proxy
This is a service that's meant to sit between nginx (or another reverse proxy) and Foxnouns.Backend.
To configure, copy `proxy-config.example.json` to your working directory, rename it to `proxy-config.json`,
and change any keys you need to.
Build with `go build -v .` and run with `./rate`

192
rate/handler.go Normal file
View file

@ -0,0 +1,192 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"strconv"
"strings"
"time"
)
const notFoundError = `{"status":404,"code":"NOT_FOUND","message":"Not found"}`
const internalServerErrorTemplate = `{"status":500,"code":"INTERNAL_SERVER_ERROR","error_id":"%v","message":"Internal server error"}`
const rateLimitedErrorTemplate = `{"status":429,"code":"RATE_LIMITED","reset":%v,"message":"You are being rate limited"}`
// error ID chosen by fair duckduckgo search, guaranteed to be random
// (we just need to return *any* error ID, this will do)
const errorID = "951c03eadb6b474db35c23d47a10f6c0"
type Handler struct {
Port int `json:"port"`
ProxyTarget string `json:"proxy_target"`
Debug bool `json:"debug"`
limiter *Limiter
proxy *httputil.ReverseProxy
client *http.Client
}
func (hn *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// all public api endpoints are prefixed with this
if !strings.HasPrefix(r.URL.Path, "/api/v2") {
w.WriteHeader(http.StatusNotFound)
return
}
data, err := hn.requestData(r)
if err != nil {
if rde, ok := err.(requestDataError); ok {
switch rde.Type {
case "badRequest":
if hn.Debug {
log.Printf("Bad request error for path %v\n", r.URL.Path)
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(notFoundError)))
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(notFoundError))
return
case "internalServerError":
respBody := fmt.Sprintf(internalServerErrorTemplate, rde.ErrorID)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(respBody)))
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(respBody))
return
}
}
log.Printf("internal error while parsing request data response: %v\n", err)
respBody := fmt.Sprintf(internalServerErrorTemplate, errorID)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(respBody)))
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(respBody))
return
}
if hn.Debug {
log.Printf("proxying request to %v %v", r.Method, data.Template)
}
isLimited, err := hn.limiter.TryLimit(w, r, data.UserID, data.Template)
if err != nil {
log.Printf("error checking rate limit: %v\n", err)
respBody := fmt.Sprintf(internalServerErrorTemplate, errorID)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(respBody)))
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(respBody))
return
}
if isLimited {
var resetTime time.Time
if reset := getReset(w); reset == 0 {
resetTime = time.Now().UTC().Add(10 * time.Second)
} else {
resetTime = time.Unix(reset, 0)
}
respBody := fmt.Sprintf(rateLimitedErrorTemplate, resetTime.UnixMilli())
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(respBody)))
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(respBody))
return
}
hn.proxy.ServeHTTP(w, r)
}
func (hn *Handler) requestData(r *http.Request) (data requestDataResponse, err error) {
var token *string
if header := r.Header.Get("Authorization"); header != "" {
token = &header
}
url := hn.ProxyTarget + "/api/internal/request-data"
reqData, err := json.Marshal(requestDataRequest{Token: token, Method: r.Method, Path: r.URL.Path})
if err != nil {
return data, fmt.Errorf("marshaling request json: %w", err)
}
req, err := http.NewRequestWithContext(r.Context(), "POST", url, bytes.NewReader(reqData))
if err != nil {
return data, fmt.Errorf("creating request: %w", err)
}
req.Header.Set("User-Agent", "Foxnouns.NET/rate")
req.Header.Set("Content-Type", "application/json; charset=utf-8")
resp, err := hn.client.Do(req)
if err != nil {
return data, fmt.Errorf("making request: %w", err)
}
defer resp.Body.Close()
// If we got a bad request error, that means the endpoint wasn't found, and we should tell the client that
if resp.StatusCode == http.StatusBadRequest {
return data, requestDataError{Type: "badRequest"}
}
// If we got an internal server error we have to forward the error ID
if resp.StatusCode == http.StatusInternalServerError {
b, err := io.ReadAll(resp.Body)
if err != nil {
return data, fmt.Errorf("reading internal server error body: %w", err)
}
var fne foxnounsError
err = json.Unmarshal(b, &fne)
if err != nil {
return data, fmt.Errorf("unmarshaling internal server error: %w", err)
}
return data, requestDataError{Type: "internalServerError", ErrorID: fne.ErrorId}
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return data, fmt.Errorf("reading body: %w", err)
}
err = json.Unmarshal(b, &data)
if err != nil {
return data, fmt.Errorf("unmarshaling data: %w", err)
}
return data, nil
}
type requestDataRequest struct {
Token *string `json:"token"`
Method string `json:"method"`
Path string `json:"path"`
}
type requestDataResponse struct {
UserID *string `json:"user_id"`
Template string `json:"template"`
}
type foxnounsError struct {
ErrorId string `json:"error_id"`
}
type requestDataError struct {
Type string
ErrorID string
}
func (e requestDataError) Error() string { return "request data error: " + e.Type }

46
rate/main.go Normal file
View file

@ -0,0 +1,46 @@
package main
import (
"encoding/json"
"log"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
)
func main() {
// read config file and parse it
confB, err := os.ReadFile("proxy-config.json")
if err != nil {
log.Fatalf("reading config.json: %v", err)
}
hn := &Handler{}
err = json.Unmarshal(confB, hn)
if err != nil {
log.Fatalf("unmarshaling config.json: %v", err)
}
proxyURL, err := url.Parse(hn.ProxyTarget)
if err != nil {
log.Fatalf("parsing proxy_target as URL: %v", err)
}
hn.proxy = &httputil.ReverseProxy{
Rewrite: func(pr *httputil.ProxyRequest) {
pr.SetURL(proxyURL)
pr.Out.Host = pr.In.Host
},
}
hn.limiter = NewLimiter()
hn.client = &http.Client{}
log.Printf("serving on port %v", hn.Port)
err = http.ListenAndServe(":"+strconv.Itoa(hn.Port), hn)
if err != nil {
log.Fatalf("listening on port %v: %v", hn.Port, err)
}
}

View file

@ -0,0 +1,5 @@
{
"port": 5003,
"proxy_target": "http://localhost:5000",
"debug": true
}

145
rate/rate_limiter.go Normal file
View file

@ -0,0 +1,145 @@
package main
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/go-chi/httprate"
)
func userOrIp(r *http.Request, userId *string) (string, error) {
if userId != nil {
return *userId, nil
}
return httprate.KeyByRealIP(r)
}
type Limiter struct {
bucketRateLimiters map[bucketKey]*httprate.RateLimiter
globalRateLimiters map[string]*httprate.RateLimiter
bucketMu *sync.Mutex
globalMu *sync.Mutex
}
type LimitData struct {
Global bool
Bucket string
Remaining int
Reset int
}
type bucketKey struct {
user, bucket string
}
func NewLimiter() *Limiter {
return &Limiter{
bucketRateLimiters: make(map[bucketKey]*httprate.RateLimiter),
globalRateLimiters: make(map[string]*httprate.RateLimiter),
bucketMu: new(sync.Mutex),
globalMu: new(sync.Mutex),
}
}
// TryLimit handles rate limiting for a request.
// If this function returns true, a rate limit was hit, and the request should be aborted with a 429 status.
func (l *Limiter) TryLimit(w http.ResponseWriter, r *http.Request, userID *string, template string) (bool, error) {
user, err := userOrIp(r, userID)
if err != nil {
return false, fmt.Errorf("getting user or IP: %w", err)
}
// Check the globalLimiter rate limit for this user
globalLimiter := l.globalLimiter(user)
if globalLimiter.OnLimit(w, r, "") {
w.Header().Add("X-RateLimit-Global", "true")
return true, nil
}
bucket := requestBucket(r.Method, template)
w.Header().Add("X-RateLimit-Bucket", bucket)
bucketLimiter := l.bucketLimiter(user, r.Method, bucket)
if bucketLimiter.OnLimit(w, r, "") {
return true, nil
}
return false, nil
}
func getReset(w http.ResponseWriter) int64 {
header := w.Header().Get("X-RateLimit-Reset")
if header == "" {
return 0
}
i, err := strconv.ParseInt(header, 10, 64)
if err != nil {
return 0
}
return i
}
func requestBucket(method, template string) string {
hasher := sha256.New()
_, err := hasher.Write([]byte(method + "-" + template))
if err != nil {
panic(err)
}
return hex.EncodeToString(hasher.Sum(nil))
}
func (l *Limiter) globalLimiter(user string) *httprate.RateLimiter {
l.globalMu.Lock()
defer l.globalMu.Unlock()
limiter, ok := l.globalRateLimiters[user]
if ok {
return limiter
}
limiter = httprate.NewRateLimiter(20, time.Second)
l.globalRateLimiters[user] = limiter
return limiter
}
func (l *Limiter) bucketLimiter(user, method, bucket string) *httprate.RateLimiter {
l.bucketMu.Lock()
defer l.bucketMu.Unlock()
limiter, ok := l.bucketRateLimiters[bucketKey{user, bucket}]
if ok {
return limiter
}
requestLimit, windowLength := requestLimitFor(method)
limiter = httprate.NewRateLimiter(requestLimit, windowLength)
l.bucketRateLimiters[bucketKey{user, bucket}] = limiter
return limiter
}
// returns the request limit and window length for the given method.
// methods that change state have lower rate limits
func requestLimitFor(method string) (int, time.Duration) {
switch strings.ToUpper(method) {
case "PATCH", "POST":
return 1, time.Second
case "DELETE":
return 1, 5 * time.Second
case "GET":
return 3, time.Second
default:
return 2, time.Second
}
}