From 3d2238568964bac103c0c5bac61c67cc39ce256a Mon Sep 17 00:00:00 2001 From: sam Date: Tue, 10 Sep 2024 16:53:43 +0200 Subject: [PATCH] feat: add rate limiter proxy --- .gitignore | 1 + .../Controllers/InternalController.cs | 83 ++++++++ Foxnouns.Backend/Database/DatabaseContext.cs | 3 + go.mod | 9 + go.sum | 7 + rate/README.md | 7 + rate/handler.go | 192 ++++++++++++++++++ rate/main.go | 46 +++++ rate/proxy-config.example.json | 5 + rate/rate_limiter.go | 145 +++++++++++++ 10 files changed, 498 insertions(+) create mode 100644 Foxnouns.Backend/Controllers/InternalController.cs create mode 100644 go.mod create mode 100644 go.sum create mode 100644 rate/README.md create mode 100644 rate/handler.go create mode 100644 rate/main.go create mode 100644 rate/proxy-config.example.json create mode 100644 rate/rate_limiter.go diff --git a/.gitignore b/.gitignore index 8e84a9d..2d0d096 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ obj/ .version config.ini *.DotSettings.user +proxy-config.json diff --git a/Foxnouns.Backend/Controllers/InternalController.cs b/Foxnouns.Backend/Controllers/InternalController.cs new file mode 100644 index 0000000..859c2fe --- /dev/null +++ b/Foxnouns.Backend/Controllers/InternalController.cs @@ -0,0 +1,83 @@ +using System.Security.Cryptography; +using System.Text.RegularExpressions; +using Foxnouns.Backend.Database; +using Foxnouns.Backend.Utils; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Controllers; +using Microsoft.AspNetCore.Mvc.Routing; +using Microsoft.AspNetCore.Routing.Template; +using Microsoft.EntityFrameworkCore; +using NodaTime; + +namespace Foxnouns.Backend.Controllers; + +[ApiController] +[Route("/api/internal")] +public partial class InternalController(DatabaseContext db, IClock clock) : ControllerBase +{ + [GeneratedRegex(@"(\{\w+\})")] + private static partial Regex PathVarRegex(); + + private static string GetCleanedTemplate(string template) + { + if (template.StartsWith("api/v2")) template = template.Substring("api/v2".Length); + template = PathVarRegex() + .Replace(template, "{id}") // Replace all path variables (almost always IDs) with `{id}` + .Replace("@me", "{id}"); // Also replace hardcoded `@me` with `{id}` + if (template.Contains("{id}")) return template.Split("{id}")[0] + "{id}"; + return template; + } + + [HttpPost("request-data")] + public async Task GetRequestDataAsync([FromBody] RequestDataRequest req) + { + var endpoint = GetEndpoint(HttpContext, req.Path, req.Method); + if (endpoint == null) throw new ApiError.BadRequest("Path/method combination is invalid"); + + var actionDescriptor = endpoint.Metadata.GetMetadata(); + var template = actionDescriptor?.AttributeRouteInfo?.Template; + if (template == null) throw new FoxnounsError("Template value was null on valid endpoint"); + template = GetCleanedTemplate(template); + + // If no token was supplied, or it isn't valid base 64, return a null user ID (limiting by IP) + if (req.Token == null) return Ok(new RequestDataResponse(null, template)); + if (!AuthUtils.TryFromBase64String(req.Token, out var rawToken)) + return Ok(new RequestDataResponse(null, template)); + + var hash = SHA512.HashData(rawToken); + var oauthToken = await db.Tokens + .Include(t => t.Application) + .Include(t => t.User) + .FirstOrDefaultAsync(t => t.Hash == hash && t.ExpiresAt > clock.GetCurrentInstant() && !t.ManuallyExpired); + + return Ok(new RequestDataResponse(oauthToken?.UserId, template)); + } + + public record RequestDataRequest(string? Token, string Method, string Path); + + public record RequestDataResponse( + Snowflake? UserId, + string Template); + + private static Endpoint? GetEndpoint(HttpContext httpContext, string url, string requestMethod) + { + var endpointDataSource = httpContext.RequestServices.GetService(); + if (endpointDataSource == null) return null; + var endpoints = endpointDataSource.Endpoints.OfType(); + + foreach (var endpoint in endpoints) + { + if (endpoint.RoutePattern.RawText == null) continue; + + var templateMatcher = new TemplateMatcher(TemplateParser.Parse(endpoint.RoutePattern.RawText), new()); + if (!templateMatcher.TryMatch(url, new())) continue; + var httpMethodAttribute = endpoint.Metadata.GetMetadata(); + if (httpMethodAttribute != null && + !httpMethodAttribute.HttpMethods.Any(x => x.Equals(requestMethod, StringComparison.OrdinalIgnoreCase))) + continue; + return endpoint; + } + + return null; + } +} \ No newline at end of file diff --git a/Foxnouns.Backend/Database/DatabaseContext.cs b/Foxnouns.Backend/Database/DatabaseContext.cs index 26087b8..8ce0d6f 100644 --- a/Foxnouns.Backend/Database/DatabaseContext.cs +++ b/Foxnouns.Backend/Database/DatabaseContext.cs @@ -28,6 +28,9 @@ public class DatabaseContext : DbContext { Timeout = config.Database.Timeout ?? 5, MaxPoolSize = config.Database.MaxPoolSize ?? 50, + MinPoolSize = 0, + ConnectionPruningInterval = 10, + ConnectionIdleLifetime = 10, }.ConnectionString; var dataSourceBuilder = new NpgsqlDataSourceBuilder(connString); diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..4636156 --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module code.vulpine.solutions/sam/Foxnouns.NET + +go 1.20 + +require ( + github.com/go-chi/httprate v0.14.1 +) + +require github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..cfbff79 --- /dev/null +++ b/go.sum @@ -0,0 +1,7 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/go-chi/httprate v0.14.1 h1:EKZHYEZ58Cg6hWcYzoZILsv7ppb46Wt4uQ738IRtpZs= +github.com/go-chi/httprate v0.14.1/go.mod h1:TUepLXaz/pCjmCtf/obgOQJ2Sz6rC8fSf5cAt5cnTt0= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= diff --git a/rate/README.md b/rate/README.md new file mode 100644 index 0000000..d65f0f1 --- /dev/null +++ b/rate/README.md @@ -0,0 +1,7 @@ +# Rate limiting proxy + +This is a service that's meant to sit between nginx (or another reverse proxy) and Foxnouns.Backend. +To configure, copy `proxy-config.example.json` to your working directory, rename it to `proxy-config.json`, +and change any keys you need to. + +Build with `go build -v .` and run with `./rate` diff --git a/rate/handler.go b/rate/handler.go new file mode 100644 index 0000000..5cbfff3 --- /dev/null +++ b/rate/handler.go @@ -0,0 +1,192 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/http/httputil" + "strconv" + "strings" + "time" +) + +const notFoundError = `{"status":404,"code":"NOT_FOUND","message":"Not found"}` +const internalServerErrorTemplate = `{"status":500,"code":"INTERNAL_SERVER_ERROR","error_id":"%v","message":"Internal server error"}` +const rateLimitedErrorTemplate = `{"status":429,"code":"RATE_LIMITED","reset":%v,"message":"You are being rate limited"}` + +// error ID chosen by fair duckduckgo search, guaranteed to be random +// (we just need to return *any* error ID, this will do) +const errorID = "951c03eadb6b474db35c23d47a10f6c0" + +type Handler struct { + Port int `json:"port"` + ProxyTarget string `json:"proxy_target"` + Debug bool `json:"debug"` + + limiter *Limiter + proxy *httputil.ReverseProxy + client *http.Client +} + +func (hn *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // all public api endpoints are prefixed with this + if !strings.HasPrefix(r.URL.Path, "/api/v2") { + w.WriteHeader(http.StatusNotFound) + return + } + + data, err := hn.requestData(r) + if err != nil { + if rde, ok := err.(requestDataError); ok { + switch rde.Type { + case "badRequest": + if hn.Debug { + log.Printf("Bad request error for path %v\n", r.URL.Path) + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(notFoundError))) + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(notFoundError)) + + return + case "internalServerError": + respBody := fmt.Sprintf(internalServerErrorTemplate, rde.ErrorID) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(respBody))) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(respBody)) + + return + } + } + + log.Printf("internal error while parsing request data response: %v\n", err) + + respBody := fmt.Sprintf(internalServerErrorTemplate, errorID) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(respBody))) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(respBody)) + + return + } + + if hn.Debug { + log.Printf("proxying request to %v %v", r.Method, data.Template) + } + + isLimited, err := hn.limiter.TryLimit(w, r, data.UserID, data.Template) + if err != nil { + log.Printf("error checking rate limit: %v\n", err) + + respBody := fmt.Sprintf(internalServerErrorTemplate, errorID) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(respBody))) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(respBody)) + return + } + + if isLimited { + var resetTime time.Time + if reset := getReset(w); reset == 0 { + resetTime = time.Now().UTC().Add(10 * time.Second) + } else { + resetTime = time.Unix(reset, 0) + } + + respBody := fmt.Sprintf(rateLimitedErrorTemplate, resetTime.UnixMilli()) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(respBody))) + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(respBody)) + return + } + + hn.proxy.ServeHTTP(w, r) +} + +func (hn *Handler) requestData(r *http.Request) (data requestDataResponse, err error) { + var token *string + if header := r.Header.Get("Authorization"); header != "" { + token = &header + } + + url := hn.ProxyTarget + "/api/internal/request-data" + + reqData, err := json.Marshal(requestDataRequest{Token: token, Method: r.Method, Path: r.URL.Path}) + if err != nil { + return data, fmt.Errorf("marshaling request json: %w", err) + } + + req, err := http.NewRequestWithContext(r.Context(), "POST", url, bytes.NewReader(reqData)) + if err != nil { + return data, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("User-Agent", "Foxnouns.NET/rate") + req.Header.Set("Content-Type", "application/json; charset=utf-8") + + resp, err := hn.client.Do(req) + if err != nil { + return data, fmt.Errorf("making request: %w", err) + } + defer resp.Body.Close() + + // If we got a bad request error, that means the endpoint wasn't found, and we should tell the client that + if resp.StatusCode == http.StatusBadRequest { + return data, requestDataError{Type: "badRequest"} + } + + // If we got an internal server error we have to forward the error ID + if resp.StatusCode == http.StatusInternalServerError { + b, err := io.ReadAll(resp.Body) + if err != nil { + return data, fmt.Errorf("reading internal server error body: %w", err) + } + + var fne foxnounsError + err = json.Unmarshal(b, &fne) + if err != nil { + return data, fmt.Errorf("unmarshaling internal server error: %w", err) + } + + return data, requestDataError{Type: "internalServerError", ErrorID: fne.ErrorId} + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + return data, fmt.Errorf("reading body: %w", err) + } + + err = json.Unmarshal(b, &data) + if err != nil { + return data, fmt.Errorf("unmarshaling data: %w", err) + } + return data, nil +} + +type requestDataRequest struct { + Token *string `json:"token"` + Method string `json:"method"` + Path string `json:"path"` +} + +type requestDataResponse struct { + UserID *string `json:"user_id"` + Template string `json:"template"` +} + +type foxnounsError struct { + ErrorId string `json:"error_id"` +} + +type requestDataError struct { + Type string + ErrorID string +} + +func (e requestDataError) Error() string { return "request data error: " + e.Type } diff --git a/rate/main.go b/rate/main.go new file mode 100644 index 0000000..33cccfd --- /dev/null +++ b/rate/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "encoding/json" + "log" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strconv" +) + +func main() { + // read config file and parse it + confB, err := os.ReadFile("proxy-config.json") + if err != nil { + log.Fatalf("reading config.json: %v", err) + } + + hn := &Handler{} + err = json.Unmarshal(confB, hn) + if err != nil { + log.Fatalf("unmarshaling config.json: %v", err) + } + + proxyURL, err := url.Parse(hn.ProxyTarget) + if err != nil { + log.Fatalf("parsing proxy_target as URL: %v", err) + } + hn.proxy = &httputil.ReverseProxy{ + Rewrite: func(pr *httputil.ProxyRequest) { + pr.SetURL(proxyURL) + pr.Out.Host = pr.In.Host + }, + } + hn.limiter = NewLimiter() + hn.client = &http.Client{} + + log.Printf("serving on port %v", hn.Port) + + err = http.ListenAndServe(":"+strconv.Itoa(hn.Port), hn) + if err != nil { + log.Fatalf("listening on port %v: %v", hn.Port, err) + } + +} diff --git a/rate/proxy-config.example.json b/rate/proxy-config.example.json new file mode 100644 index 0000000..1d2d6af --- /dev/null +++ b/rate/proxy-config.example.json @@ -0,0 +1,5 @@ +{ + "port": 5003, + "proxy_target": "http://localhost:5000", + "debug": true +} diff --git a/rate/rate_limiter.go b/rate/rate_limiter.go new file mode 100644 index 0000000..28368e2 --- /dev/null +++ b/rate/rate_limiter.go @@ -0,0 +1,145 @@ +package main + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-chi/httprate" +) + +func userOrIp(r *http.Request, userId *string) (string, error) { + if userId != nil { + return *userId, nil + } + + return httprate.KeyByRealIP(r) +} + +type Limiter struct { + bucketRateLimiters map[bucketKey]*httprate.RateLimiter + globalRateLimiters map[string]*httprate.RateLimiter + + bucketMu *sync.Mutex + globalMu *sync.Mutex +} + +type LimitData struct { + Global bool + Bucket string + Remaining int + Reset int +} + +type bucketKey struct { + user, bucket string +} + +func NewLimiter() *Limiter { + return &Limiter{ + bucketRateLimiters: make(map[bucketKey]*httprate.RateLimiter), + globalRateLimiters: make(map[string]*httprate.RateLimiter), + + bucketMu: new(sync.Mutex), + globalMu: new(sync.Mutex), + } +} + +// TryLimit handles rate limiting for a request. +// If this function returns true, a rate limit was hit, and the request should be aborted with a 429 status. +func (l *Limiter) TryLimit(w http.ResponseWriter, r *http.Request, userID *string, template string) (bool, error) { + user, err := userOrIp(r, userID) + if err != nil { + return false, fmt.Errorf("getting user or IP: %w", err) + } + + // Check the globalLimiter rate limit for this user + globalLimiter := l.globalLimiter(user) + if globalLimiter.OnLimit(w, r, "") { + w.Header().Add("X-RateLimit-Global", "true") + return true, nil + } + + bucket := requestBucket(r.Method, template) + w.Header().Add("X-RateLimit-Bucket", bucket) + + bucketLimiter := l.bucketLimiter(user, r.Method, bucket) + if bucketLimiter.OnLimit(w, r, "") { + return true, nil + } + + return false, nil +} + +func getReset(w http.ResponseWriter) int64 { + header := w.Header().Get("X-RateLimit-Reset") + if header == "" { + return 0 + } + + i, err := strconv.ParseInt(header, 10, 64) + if err != nil { + return 0 + } + + return i +} + +func requestBucket(method, template string) string { + hasher := sha256.New() + _, err := hasher.Write([]byte(method + "-" + template)) + if err != nil { + panic(err) + } + return hex.EncodeToString(hasher.Sum(nil)) +} + +func (l *Limiter) globalLimiter(user string) *httprate.RateLimiter { + l.globalMu.Lock() + defer l.globalMu.Unlock() + + limiter, ok := l.globalRateLimiters[user] + if ok { + return limiter + } + + limiter = httprate.NewRateLimiter(20, time.Second) + l.globalRateLimiters[user] = limiter + return limiter +} + +func (l *Limiter) bucketLimiter(user, method, bucket string) *httprate.RateLimiter { + l.bucketMu.Lock() + defer l.bucketMu.Unlock() + + limiter, ok := l.bucketRateLimiters[bucketKey{user, bucket}] + if ok { + return limiter + } + + requestLimit, windowLength := requestLimitFor(method) + + limiter = httprate.NewRateLimiter(requestLimit, windowLength) + l.bucketRateLimiters[bucketKey{user, bucket}] = limiter + return limiter +} + +// returns the request limit and window length for the given method. +// methods that change state have lower rate limits +func requestLimitFor(method string) (int, time.Duration) { + switch strings.ToUpper(method) { + case "PATCH", "POST": + return 1, time.Second + case "DELETE": + return 1, 5 * time.Second + case "GET": + return 3, time.Second + default: + return 2, time.Second + } +}