feat: add rate limiter proxy
This commit is contained in:
parent
13a0cac663
commit
3d22385689
10 changed files with 498 additions and 0 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -3,3 +3,4 @@ obj/
|
||||||
.version
|
.version
|
||||||
config.ini
|
config.ini
|
||||||
*.DotSettings.user
|
*.DotSettings.user
|
||||||
|
proxy-config.json
|
||||||
|
|
83
Foxnouns.Backend/Controllers/InternalController.cs
Normal file
83
Foxnouns.Backend/Controllers/InternalController.cs
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
using System.Security.Cryptography;
|
||||||
|
using System.Text.RegularExpressions;
|
||||||
|
using Foxnouns.Backend.Database;
|
||||||
|
using Foxnouns.Backend.Utils;
|
||||||
|
using Microsoft.AspNetCore.Mvc;
|
||||||
|
using Microsoft.AspNetCore.Mvc.Controllers;
|
||||||
|
using Microsoft.AspNetCore.Mvc.Routing;
|
||||||
|
using Microsoft.AspNetCore.Routing.Template;
|
||||||
|
using Microsoft.EntityFrameworkCore;
|
||||||
|
using NodaTime;
|
||||||
|
|
||||||
|
namespace Foxnouns.Backend.Controllers;
|
||||||
|
|
||||||
|
[ApiController]
|
||||||
|
[Route("/api/internal")]
|
||||||
|
public partial class InternalController(DatabaseContext db, IClock clock) : ControllerBase
|
||||||
|
{
|
||||||
|
[GeneratedRegex(@"(\{\w+\})")]
|
||||||
|
private static partial Regex PathVarRegex();
|
||||||
|
|
||||||
|
private static string GetCleanedTemplate(string template)
|
||||||
|
{
|
||||||
|
if (template.StartsWith("api/v2")) template = template.Substring("api/v2".Length);
|
||||||
|
template = PathVarRegex()
|
||||||
|
.Replace(template, "{id}") // Replace all path variables (almost always IDs) with `{id}`
|
||||||
|
.Replace("@me", "{id}"); // Also replace hardcoded `@me` with `{id}`
|
||||||
|
if (template.Contains("{id}")) return template.Split("{id}")[0] + "{id}";
|
||||||
|
return template;
|
||||||
|
}
|
||||||
|
|
||||||
|
[HttpPost("request-data")]
|
||||||
|
public async Task<IActionResult> GetRequestDataAsync([FromBody] RequestDataRequest req)
|
||||||
|
{
|
||||||
|
var endpoint = GetEndpoint(HttpContext, req.Path, req.Method);
|
||||||
|
if (endpoint == null) throw new ApiError.BadRequest("Path/method combination is invalid");
|
||||||
|
|
||||||
|
var actionDescriptor = endpoint.Metadata.GetMetadata<ControllerActionDescriptor>();
|
||||||
|
var template = actionDescriptor?.AttributeRouteInfo?.Template;
|
||||||
|
if (template == null) throw new FoxnounsError("Template value was null on valid endpoint");
|
||||||
|
template = GetCleanedTemplate(template);
|
||||||
|
|
||||||
|
// If no token was supplied, or it isn't valid base 64, return a null user ID (limiting by IP)
|
||||||
|
if (req.Token == null) return Ok(new RequestDataResponse(null, template));
|
||||||
|
if (!AuthUtils.TryFromBase64String(req.Token, out var rawToken))
|
||||||
|
return Ok(new RequestDataResponse(null, template));
|
||||||
|
|
||||||
|
var hash = SHA512.HashData(rawToken);
|
||||||
|
var oauthToken = await db.Tokens
|
||||||
|
.Include(t => t.Application)
|
||||||
|
.Include(t => t.User)
|
||||||
|
.FirstOrDefaultAsync(t => t.Hash == hash && t.ExpiresAt > clock.GetCurrentInstant() && !t.ManuallyExpired);
|
||||||
|
|
||||||
|
return Ok(new RequestDataResponse(oauthToken?.UserId, template));
|
||||||
|
}
|
||||||
|
|
||||||
|
public record RequestDataRequest(string? Token, string Method, string Path);
|
||||||
|
|
||||||
|
public record RequestDataResponse(
|
||||||
|
Snowflake? UserId,
|
||||||
|
string Template);
|
||||||
|
|
||||||
|
private static Endpoint? GetEndpoint(HttpContext httpContext, string url, string requestMethod)
|
||||||
|
{
|
||||||
|
var endpointDataSource = httpContext.RequestServices.GetService<EndpointDataSource>();
|
||||||
|
if (endpointDataSource == null) return null;
|
||||||
|
var endpoints = endpointDataSource.Endpoints.OfType<RouteEndpoint>();
|
||||||
|
|
||||||
|
foreach (var endpoint in endpoints)
|
||||||
|
{
|
||||||
|
if (endpoint.RoutePattern.RawText == null) continue;
|
||||||
|
|
||||||
|
var templateMatcher = new TemplateMatcher(TemplateParser.Parse(endpoint.RoutePattern.RawText), new());
|
||||||
|
if (!templateMatcher.TryMatch(url, new())) continue;
|
||||||
|
var httpMethodAttribute = endpoint.Metadata.GetMetadata<HttpMethodAttribute>();
|
||||||
|
if (httpMethodAttribute != null &&
|
||||||
|
!httpMethodAttribute.HttpMethods.Any(x => x.Equals(requestMethod, StringComparison.OrdinalIgnoreCase)))
|
||||||
|
continue;
|
||||||
|
return endpoint;
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -28,6 +28,9 @@ public class DatabaseContext : DbContext
|
||||||
{
|
{
|
||||||
Timeout = config.Database.Timeout ?? 5,
|
Timeout = config.Database.Timeout ?? 5,
|
||||||
MaxPoolSize = config.Database.MaxPoolSize ?? 50,
|
MaxPoolSize = config.Database.MaxPoolSize ?? 50,
|
||||||
|
MinPoolSize = 0,
|
||||||
|
ConnectionPruningInterval = 10,
|
||||||
|
ConnectionIdleLifetime = 10,
|
||||||
}.ConnectionString;
|
}.ConnectionString;
|
||||||
|
|
||||||
var dataSourceBuilder = new NpgsqlDataSourceBuilder(connString);
|
var dataSourceBuilder = new NpgsqlDataSourceBuilder(connString);
|
||||||
|
|
9
go.mod
Normal file
9
go.mod
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
module code.vulpine.solutions/sam/Foxnouns.NET
|
||||||
|
|
||||||
|
go 1.20
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/go-chi/httprate v0.14.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
7
go.sum
Normal file
7
go.sum
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/go-chi/httprate v0.14.1 h1:EKZHYEZ58Cg6hWcYzoZILsv7ppb46Wt4uQ738IRtpZs=
|
||||||
|
github.com/go-chi/httprate v0.14.1/go.mod h1:TUepLXaz/pCjmCtf/obgOQJ2Sz6rC8fSf5cAt5cnTt0=
|
||||||
|
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||||
|
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
|
||||||
|
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
7
rate/README.md
Normal file
7
rate/README.md
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# Rate limiting proxy
|
||||||
|
|
||||||
|
This is a service that's meant to sit between nginx (or another reverse proxy) and Foxnouns.Backend.
|
||||||
|
To configure, copy `proxy-config.example.json` to your working directory, rename it to `proxy-config.json`,
|
||||||
|
and change any keys you need to.
|
||||||
|
|
||||||
|
Build with `go build -v .` and run with `./rate`
|
192
rate/handler.go
Normal file
192
rate/handler.go
Normal file
|
@ -0,0 +1,192 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const notFoundError = `{"status":404,"code":"NOT_FOUND","message":"Not found"}`
|
||||||
|
const internalServerErrorTemplate = `{"status":500,"code":"INTERNAL_SERVER_ERROR","error_id":"%v","message":"Internal server error"}`
|
||||||
|
const rateLimitedErrorTemplate = `{"status":429,"code":"RATE_LIMITED","reset":%v,"message":"You are being rate limited"}`
|
||||||
|
|
||||||
|
// error ID chosen by fair duckduckgo search, guaranteed to be random
|
||||||
|
// (we just need to return *any* error ID, this will do)
|
||||||
|
const errorID = "951c03eadb6b474db35c23d47a10f6c0"
|
||||||
|
|
||||||
|
type Handler struct {
|
||||||
|
Port int `json:"port"`
|
||||||
|
ProxyTarget string `json:"proxy_target"`
|
||||||
|
Debug bool `json:"debug"`
|
||||||
|
|
||||||
|
limiter *Limiter
|
||||||
|
proxy *httputil.ReverseProxy
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hn *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// all public api endpoints are prefixed with this
|
||||||
|
if !strings.HasPrefix(r.URL.Path, "/api/v2") {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := hn.requestData(r)
|
||||||
|
if err != nil {
|
||||||
|
if rde, ok := err.(requestDataError); ok {
|
||||||
|
switch rde.Type {
|
||||||
|
case "badRequest":
|
||||||
|
if hn.Debug {
|
||||||
|
log.Printf("Bad request error for path %v\n", r.URL.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
w.Header().Set("Content-Length", strconv.Itoa(len(notFoundError)))
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
_, _ = w.Write([]byte(notFoundError))
|
||||||
|
|
||||||
|
return
|
||||||
|
case "internalServerError":
|
||||||
|
respBody := fmt.Sprintf(internalServerErrorTemplate, rde.ErrorID)
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
w.Header().Set("Content-Length", strconv.Itoa(len(respBody)))
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte(respBody))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("internal error while parsing request data response: %v\n", err)
|
||||||
|
|
||||||
|
respBody := fmt.Sprintf(internalServerErrorTemplate, errorID)
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
w.Header().Set("Content-Length", strconv.Itoa(len(respBody)))
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte(respBody))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hn.Debug {
|
||||||
|
log.Printf("proxying request to %v %v", r.Method, data.Template)
|
||||||
|
}
|
||||||
|
|
||||||
|
isLimited, err := hn.limiter.TryLimit(w, r, data.UserID, data.Template)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error checking rate limit: %v\n", err)
|
||||||
|
|
||||||
|
respBody := fmt.Sprintf(internalServerErrorTemplate, errorID)
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
w.Header().Set("Content-Length", strconv.Itoa(len(respBody)))
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte(respBody))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if isLimited {
|
||||||
|
var resetTime time.Time
|
||||||
|
if reset := getReset(w); reset == 0 {
|
||||||
|
resetTime = time.Now().UTC().Add(10 * time.Second)
|
||||||
|
} else {
|
||||||
|
resetTime = time.Unix(reset, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody := fmt.Sprintf(rateLimitedErrorTemplate, resetTime.UnixMilli())
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
w.Header().Set("Content-Length", strconv.Itoa(len(respBody)))
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
_, _ = w.Write([]byte(respBody))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hn.proxy.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hn *Handler) requestData(r *http.Request) (data requestDataResponse, err error) {
|
||||||
|
var token *string
|
||||||
|
if header := r.Header.Get("Authorization"); header != "" {
|
||||||
|
token = &header
|
||||||
|
}
|
||||||
|
|
||||||
|
url := hn.ProxyTarget + "/api/internal/request-data"
|
||||||
|
|
||||||
|
reqData, err := json.Marshal(requestDataRequest{Token: token, Method: r.Method, Path: r.URL.Path})
|
||||||
|
if err != nil {
|
||||||
|
return data, fmt.Errorf("marshaling request json: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(r.Context(), "POST", url, bytes.NewReader(reqData))
|
||||||
|
if err != nil {
|
||||||
|
return data, fmt.Errorf("creating request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", "Foxnouns.NET/rate")
|
||||||
|
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
|
||||||
|
resp, err := hn.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return data, fmt.Errorf("making request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// If we got a bad request error, that means the endpoint wasn't found, and we should tell the client that
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
return data, requestDataError{Type: "badRequest"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we got an internal server error we have to forward the error ID
|
||||||
|
if resp.StatusCode == http.StatusInternalServerError {
|
||||||
|
b, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return data, fmt.Errorf("reading internal server error body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fne foxnounsError
|
||||||
|
err = json.Unmarshal(b, &fne)
|
||||||
|
if err != nil {
|
||||||
|
return data, fmt.Errorf("unmarshaling internal server error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, requestDataError{Type: "internalServerError", ErrorID: fne.ErrorId}
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return data, fmt.Errorf("reading body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(b, &data)
|
||||||
|
if err != nil {
|
||||||
|
return data, fmt.Errorf("unmarshaling data: %w", err)
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestDataRequest struct {
|
||||||
|
Token *string `json:"token"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestDataResponse struct {
|
||||||
|
UserID *string `json:"user_id"`
|
||||||
|
Template string `json:"template"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type foxnounsError struct {
|
||||||
|
ErrorId string `json:"error_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestDataError struct {
|
||||||
|
Type string
|
||||||
|
ErrorID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e requestDataError) Error() string { return "request data error: " + e.Type }
|
46
rate/main.go
Normal file
46
rate/main.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// read config file and parse it
|
||||||
|
confB, err := os.ReadFile("proxy-config.json")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("reading config.json: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hn := &Handler{}
|
||||||
|
err = json.Unmarshal(confB, hn)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("unmarshaling config.json: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, err := url.Parse(hn.ProxyTarget)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("parsing proxy_target as URL: %v", err)
|
||||||
|
}
|
||||||
|
hn.proxy = &httputil.ReverseProxy{
|
||||||
|
Rewrite: func(pr *httputil.ProxyRequest) {
|
||||||
|
pr.SetURL(proxyURL)
|
||||||
|
pr.Out.Host = pr.In.Host
|
||||||
|
},
|
||||||
|
}
|
||||||
|
hn.limiter = NewLimiter()
|
||||||
|
hn.client = &http.Client{}
|
||||||
|
|
||||||
|
log.Printf("serving on port %v", hn.Port)
|
||||||
|
|
||||||
|
err = http.ListenAndServe(":"+strconv.Itoa(hn.Port), hn)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("listening on port %v: %v", hn.Port, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
5
rate/proxy-config.example.json
Normal file
5
rate/proxy-config.example.json
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
{
|
||||||
|
"port": 5003,
|
||||||
|
"proxy_target": "http://localhost:5000",
|
||||||
|
"debug": true
|
||||||
|
}
|
145
rate/rate_limiter.go
Normal file
145
rate/rate_limiter.go
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/httprate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func userOrIp(r *http.Request, userId *string) (string, error) {
|
||||||
|
if userId != nil {
|
||||||
|
return *userId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return httprate.KeyByRealIP(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Limiter struct {
|
||||||
|
bucketRateLimiters map[bucketKey]*httprate.RateLimiter
|
||||||
|
globalRateLimiters map[string]*httprate.RateLimiter
|
||||||
|
|
||||||
|
bucketMu *sync.Mutex
|
||||||
|
globalMu *sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type LimitData struct {
|
||||||
|
Global bool
|
||||||
|
Bucket string
|
||||||
|
Remaining int
|
||||||
|
Reset int
|
||||||
|
}
|
||||||
|
|
||||||
|
type bucketKey struct {
|
||||||
|
user, bucket string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLimiter() *Limiter {
|
||||||
|
return &Limiter{
|
||||||
|
bucketRateLimiters: make(map[bucketKey]*httprate.RateLimiter),
|
||||||
|
globalRateLimiters: make(map[string]*httprate.RateLimiter),
|
||||||
|
|
||||||
|
bucketMu: new(sync.Mutex),
|
||||||
|
globalMu: new(sync.Mutex),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryLimit handles rate limiting for a request.
|
||||||
|
// If this function returns true, a rate limit was hit, and the request should be aborted with a 429 status.
|
||||||
|
func (l *Limiter) TryLimit(w http.ResponseWriter, r *http.Request, userID *string, template string) (bool, error) {
|
||||||
|
user, err := userOrIp(r, userID)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("getting user or IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the globalLimiter rate limit for this user
|
||||||
|
globalLimiter := l.globalLimiter(user)
|
||||||
|
if globalLimiter.OnLimit(w, r, "") {
|
||||||
|
w.Header().Add("X-RateLimit-Global", "true")
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bucket := requestBucket(r.Method, template)
|
||||||
|
w.Header().Add("X-RateLimit-Bucket", bucket)
|
||||||
|
|
||||||
|
bucketLimiter := l.bucketLimiter(user, r.Method, bucket)
|
||||||
|
if bucketLimiter.OnLimit(w, r, "") {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getReset(w http.ResponseWriter) int64 {
|
||||||
|
header := w.Header().Get("X-RateLimit-Reset")
|
||||||
|
if header == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
i, err := strconv.ParseInt(header, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestBucket(method, template string) string {
|
||||||
|
hasher := sha256.New()
|
||||||
|
_, err := hasher.Write([]byte(method + "-" + template))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(hasher.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Limiter) globalLimiter(user string) *httprate.RateLimiter {
|
||||||
|
l.globalMu.Lock()
|
||||||
|
defer l.globalMu.Unlock()
|
||||||
|
|
||||||
|
limiter, ok := l.globalRateLimiters[user]
|
||||||
|
if ok {
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
limiter = httprate.NewRateLimiter(20, time.Second)
|
||||||
|
l.globalRateLimiters[user] = limiter
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Limiter) bucketLimiter(user, method, bucket string) *httprate.RateLimiter {
|
||||||
|
l.bucketMu.Lock()
|
||||||
|
defer l.bucketMu.Unlock()
|
||||||
|
|
||||||
|
limiter, ok := l.bucketRateLimiters[bucketKey{user, bucket}]
|
||||||
|
if ok {
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
requestLimit, windowLength := requestLimitFor(method)
|
||||||
|
|
||||||
|
limiter = httprate.NewRateLimiter(requestLimit, windowLength)
|
||||||
|
l.bucketRateLimiters[bucketKey{user, bucket}] = limiter
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns the request limit and window length for the given method.
|
||||||
|
// methods that change state have lower rate limits
|
||||||
|
func requestLimitFor(method string) (int, time.Duration) {
|
||||||
|
switch strings.ToUpper(method) {
|
||||||
|
case "PATCH", "POST":
|
||||||
|
return 1, time.Second
|
||||||
|
case "DELETE":
|
||||||
|
return 1, 5 * time.Second
|
||||||
|
case "GET":
|
||||||
|
return 3, time.Second
|
||||||
|
default:
|
||||||
|
return 2, time.Second
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue