This commit is contained in:
sam 2024-09-03 00:07:12 +02:00
commit b3bf3a7c16
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
43 changed files with 2057 additions and 0 deletions

View file

@ -0,0 +1,128 @@
using System.Diagnostics;
using System.Net.Http.Headers;
using System.Net.Http.Json;
using System.Text.Json;
using Polly;
using Serilog;
namespace Foxcord.Rest;
public class BaseRestClient
{
public HttpClient Client { get; }
private readonly ILogger _logger;
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly Func<string>? _tokenFactory;
private readonly Func<string, string> _pathCleaner;
private readonly Func<HttpResponseMessage, CancellationToken, Task> _errorHandler;
protected ResiliencePipeline<HttpResponseMessage> Pipeline { get; set; } =
new ResiliencePipelineBuilder<HttpResponseMessage>().Build();
private readonly string _apiBaseUrl;
public BaseRestClient(ILogger logger, RestClientOptions options)
{
Client = new HttpClient();
Client.DefaultRequestHeaders.TryAddWithoutValidation("User-Agent", options.UserAgent);
_logger = logger.ForContext<BaseRestClient>();
_jsonSerializerOptions = options.JsonSerializerOptions ?? JsonSerializerOptions.Default;
_apiBaseUrl = options.ApiBaseUrl;
_tokenFactory = options.TokenFactory;
_pathCleaner = options.PathLogCleaner ?? (s => s);
_errorHandler = options.ErrorHandler;
}
public async Task<T> RequestAsync<T>(HttpMethod method, string path, CancellationToken ct = default)
where T : class
{
var req = new HttpRequestMessage(method, $"{_apiBaseUrl}{path}");
if (_tokenFactory != null)
req.Headers.Add("Authorization", _tokenFactory());
var resp = await DoRequestAsync(path, req, ct);
return await resp.Content.ReadFromJsonAsync<T>(_jsonSerializerOptions, ct) ??
throw new DiscordRequestError("Content was deserialized as null");
}
public async Task<TResponse> RequestAsync<TRequest, TResponse>(HttpMethod method, string path, TRequest reqBody,
CancellationToken ct = default)
{
var req = new HttpRequestMessage(method, $"{_apiBaseUrl}{path}");
if (_tokenFactory != null)
req.Headers.Add("Authorization", _tokenFactory());
var body = JsonSerializer.Serialize(reqBody, _jsonSerializerOptions);
req.Content = new StringContent(body, new MediaTypeHeaderValue("application/json", "utf-8"));
var resp = await DoRequestAsync(path, req, ct);
return await resp.Content.ReadFromJsonAsync<TResponse>(_jsonSerializerOptions, ct) ??
throw new DiscordRequestError("Content was deserialized as null");
}
private async Task<HttpResponseMessage> DoRequestAsync(string path, HttpRequestMessage req,
CancellationToken ct = default)
{
var context = ResilienceContextPool.Shared.Get(ct);
context.Properties.Set(new ResiliencePropertyKey<string>("Path"), path);
try
{
return await Pipeline.ExecuteAsync(async ctx =>
{
HttpResponseMessage resp;
var stopwatch = new Stopwatch();
stopwatch.Start();
try
{
resp = await Client.SendAsync(req, HttpCompletionOption.ResponseHeadersRead,
ctx.CancellationToken);
stopwatch.Stop();
}
catch (Exception e)
{
_logger.Error(e, "HTTP error: {Method} {Path}", req.Method, _pathCleaner(path));
throw;
}
_logger.Debug("Response: {Method} {Path} -> {StatusCode} {ReasonPhrase} (in {ResponseMs} ms)",
req.Method, _pathCleaner(path), (int)resp.StatusCode, resp.ReasonPhrase,
stopwatch.ElapsedMilliseconds);
await _errorHandler(resp, context.CancellationToken);
return resp;
}, context);
}
finally
{
ResilienceContextPool.Shared.Return(context);
}
}
}
public class RestClientOptions
{
public JsonSerializerOptions? JsonSerializerOptions { get; init; }
public required string UserAgent { get; init; }
public required string ApiBaseUrl { get; init; }
/// <summary>
/// A function that converts non-2XX responses to errors. This should usually throw;
/// if not, the request will continue to be handled normally, which will probably cause errors.
/// </summary>
public required Func<HttpResponseMessage, CancellationToken, Task> ErrorHandler { get; init; }
/// <summary>
/// A function that returns a token. If not set, the client will not add an Authorization header.
/// </summary>
public Func<string>? TokenFactory { get; init; }
/// <summary>
/// A function that cleans up paths for logging. This should remove sensitive content (i.e. tokens).
/// If not set, paths will not be cleaned before being logged.
/// </summary>
public Func<string, string>? PathLogCleaner { get; init; }
}

View file

@ -0,0 +1,66 @@
using System.Net;
using System.Net.Http.Json;
using System.Text.Json;
using Foxcord.Models;
using Foxcord.Rest.Rate;
using Foxcord.Rest.Types;
using Foxcord.Serialization;
using Polly;
using Serilog;
namespace Foxcord.Rest;
public class DiscordRestClient : BaseRestClient
{
public DiscordRestClient(ILogger logger, DiscordRestClientOptions options) : base(logger, new RestClientOptions
{
JsonSerializerOptions = JsonSerializerOptions,
ApiBaseUrl = options.ApiBaseUrl ?? DefaultApiBaseUrl,
UserAgent = options.UserAgent ?? DefaultUserAgent,
TokenFactory = () => $"Bot {options.Token}",
ErrorHandler = HandleError,
})
{
Pipeline = new ResiliencePipelineBuilder<HttpResponseMessage>()
.AddDiscordStrategy(new RateLimiter(logger.ForContext<DiscordRestClient>()))
.Build();
}
public const string DefaultUserAgent = "DiscordBot (https://code.vulpine.solutions/sam/Foxcord, v1)";
public const string DefaultApiBaseUrl = "https://discord.com/api/v10";
public async Task<GetGatewayBotResponse> GatewayBotAsync(CancellationToken ct = default) =>
await RequestAsync<GetGatewayBotResponse>(HttpMethod.Get, "/gateway/bot", ct);
public async Task<Message> CreateMessageAsync(Snowflake channelId, CreateMessageParams message,
CancellationToken ct = default) =>
await RequestAsync<CreateMessageParams, Message>(HttpMethod.Post, $"/channels/{channelId}/messages", message,
ct);
#region BaseRestClient parameters
private static readonly JsonSerializerOptions JsonSerializerOptions = JsonSerializerExtensions.CreateSerializer();
private static async Task HandleError(HttpResponseMessage resp, CancellationToken ct)
{
if (resp.IsSuccessStatusCode) return;
if (resp.StatusCode == HttpStatusCode.TooManyRequests)
{
throw new RateLimitError(resp.Headers);
}
var error = await resp.Content.ReadFromJsonAsync<DiscordRestError>(JsonSerializerOptions, ct);
if (error == null) throw new ArgumentNullException(nameof(resp), "Response was deserialized as null");
throw error;
}
#endregion
}
public class DiscordRestClientOptions
{
public required string Token { get; init; }
public string? UserAgent { get; init; }
public string? ApiBaseUrl { get; init; }
}

View file

@ -0,0 +1,22 @@
using System.Net.Http.Headers;
namespace Foxcord.Rest;
public class DiscordRestError : Exception
{
public required DiscordErrorCode Code { get; init; }
public new required string Message { get; init; }
}
public enum DiscordErrorCode
{
GeneralError = 0,
UnknownAccount = 10001,
}
public class DiscordRequestError(string message) : Exception(message);
public class RateLimitError(HttpHeaders headers) : Exception("Rate limit error")
{
public HttpHeaders Headers { get; } = headers;
}

View file

@ -0,0 +1,65 @@
using System.Text;
namespace Foxcord.Rest.Rate;
// All of this code is taken from Arikawa:
// https://github.com/diamondburned/arikawa/blob/v3/api/rate/rate.go
internal static class BucketKeyUtils
{
private static readonly string[] MajorRootPaths = ["channels", "guilds"];
internal static string Parse(string path)
{
path = path.Split("?", 2)[0];
var parts = path.Split("/");
if (parts.Length == 0) return path;
parts = parts.Skip(1).ToArray();
var skip = 0;
if (MajorRootPaths.Contains(parts[0])) skip = 2;
skip++;
for (; skip < parts.Length; skip += 2)
{
if (long.TryParse(parts[skip], out _) || StringIsEmojiOnly(parts[skip]) || StringIsCustomEmoji(parts[skip]))
parts[skip] = "";
}
path = string.Join("/", parts);
return $"/{path}";
}
private static bool StringIsCustomEmoji(string emoji)
{
var parts = emoji.Split(":");
if (parts.Length != 2) return false;
if (!long.TryParse(parts[1], out _)) return false;
if (parts[0].Contains(' ')) return false;
return true;
}
private static bool StringIsEmojiOnly(string emoji)
{
var runes = emoji.EnumerateRunes().ToArray();
switch (runes.Length)
{
case 0:
return false;
case 1:
case 2:
return EmojiRune(runes[0]);
}
return false;
}
private static bool EmojiRune(Rune r)
{
if (r == new Rune('\u00a9') || r == new Rune('\u00ae') ||
(r >= new Rune('\u2000') && r <= new Rune('\u3300'))) return true;
return false;
}
}

View file

@ -0,0 +1,34 @@
using Polly;
namespace Foxcord.Rest.Rate;
public class DiscordResilienceStrategy(RateLimiter rateLimiter)
: ResilienceStrategy<HttpResponseMessage>
{
protected override async ValueTask<Outcome<HttpResponseMessage>> ExecuteCore<TState>(
Func<ResilienceContext, TState, ValueTask<Outcome<HttpResponseMessage>>> callback, ResilienceContext context,
TState state)
{
var path = context.Properties.GetValue(new ResiliencePropertyKey<string>("Path"), string.Empty);
if (path == string.Empty) throw new DiscordRequestError("Path was not set in Polly context");
var b = await rateLimiter.LockBucket(BucketKeyUtils.Parse(path), context.CancellationToken);
var response = await callback(context, state).ConfigureAwait(context.ContinueOnCapturedContext);
if (response.Exception is RateLimitError rateLimitError)
b.Release(rateLimitError.Headers);
else if (response.Result != null)
b.Release(response.Result.Headers);
return response;
}
}
public class DiscordResilienceStrategyOptions : ResilienceStrategyOptions;
public static class DiscordResilienceStrategyExtensions
{
public static ResiliencePipelineBuilder<HttpResponseMessage> AddDiscordStrategy(
this ResiliencePipelineBuilder<HttpResponseMessage> builder, RateLimiter rateLimiter) =>
builder.AddStrategy(_ => new DiscordResilienceStrategy(rateLimiter), new DiscordResilienceStrategyOptions());
}

View file

@ -0,0 +1,161 @@
using System.Collections.Concurrent;
using System.Globalization;
using System.Net.Http.Headers;
using Serilog;
namespace Foxcord.Rest.Rate;
// Most of this code is taken from discordgo:
// https://github.com/bwmarrin/discordgo/blob/master/ratelimit.go
public class RateLimiter(ILogger logger)
{
private readonly ILogger _logger = logger.ForContext<RateLimiter>();
private readonly ConcurrentDictionary<string, Bucket> _buckets = new();
private readonly ConcurrentDictionary<string, CustomRateLimit> _customRateLimits = new([
new KeyValuePair<string, CustomRateLimit>("//reactions//", new CustomRateLimit
{
Requests = 1,
Reset = TimeSpan.FromMilliseconds(200)
})
]);
internal long Global;
internal Bucket GetBucket(string key)
{
key = BucketKeyUtils.Parse(key);
var bucket = _buckets.GetOrAdd(key, _ => new Bucket
{
Key = key,
Remaining = 1,
RateLimiter = this,
Logger = _logger
});
if (_customRateLimits.Any(r => key.EndsWith(r.Key)))
bucket.CustomRateLimit = _customRateLimits.First(r => key.EndsWith(r.Key)).Value;
return bucket;
}
internal TimeSpan GetWaitTime(Bucket b, int minRemaining)
{
if (b.Remaining < minRemaining && b.Reset > DateTimeOffset.UtcNow)
return b.Reset - DateTimeOffset.UtcNow;
var sleepTo = DateTimeOffset.FromUnixTimeMilliseconds(Global);
if (sleepTo > DateTimeOffset.UtcNow)
return sleepTo - DateTimeOffset.UtcNow;
return TimeSpan.Zero;
}
internal async Task<Bucket> LockBucket(string bucketId, CancellationToken ct = default) =>
await LockBucket(GetBucket(bucketId), ct);
internal async Task<Bucket> LockBucket(Bucket b, CancellationToken ct = default)
{
_logger.Verbose("Locking bucket {Bucket}", b.Key);
await b.Semaphore.WaitAsync(ct);
var waitTime = GetWaitTime(b, 1);
if (waitTime > TimeSpan.Zero) await Task.Delay(waitTime, ct);
b.Remaining--;
_logger.Verbose("Letting request for bucket {Bucket} through", b.Key);
return b;
}
}
internal class CustomRateLimit
{
internal int Requests;
internal TimeSpan Reset;
}
internal class Bucket
{
internal readonly SemaphoreSlim Semaphore = new(1);
internal required string Key;
internal required ILogger Logger { private get; init; }
internal int Remaining;
internal DateTimeOffset Reset;
private DateTimeOffset _lastReset;
internal CustomRateLimit? CustomRateLimit;
internal required RateLimiter RateLimiter;
// discordgo mentions that this is required to prevent 429s, I trust that
private static readonly TimeSpan ExtraResetTime = TimeSpan.FromMilliseconds(250);
internal void Release(HttpHeaders headers)
{
try
{
if (CustomRateLimit != null)
{
if (DateTimeOffset.UtcNow - _lastReset >= CustomRateLimit.Reset)
{
Remaining = CustomRateLimit.Requests - 1;
_lastReset = DateTimeOffset.UtcNow;
}
if (Remaining < 1)
{
Reset = DateTimeOffset.UtcNow + CustomRateLimit.Reset;
}
return;
}
var remaining = TryGetHeaderValue(headers, "X-RateLimit-Remaining");
var reset = TryGetHeaderValue(headers, "X-RateLimit-Reset");
var global = TryGetHeaderValue(headers, "X-RateLimit-Global");
var resetAfter = TryGetHeaderValue(headers, "X-RateLimit-Reset-After");
if (resetAfter != null)
{
if (!double.TryParse(resetAfter, out var parsedResetAfter))
throw new InvalidRateLimitHeaderException("X-RateLimit-Reset-After was not a valid double");
var resetAt = DateTimeOffset.UtcNow + TimeSpan.FromSeconds(parsedResetAfter);
if (global != null) RateLimiter.Global = resetAt.ToUnixTimeMilliseconds();
else Reset = resetAt;
}
else if (reset != null)
{
var dateHeader = TryGetHeaderValue(headers, "Date");
if (dateHeader == null) throw new InvalidRateLimitHeaderException("Date header was not set");
if (!DateTimeOffset.TryParseExact(dateHeader, "r", CultureInfo.InvariantCulture,
DateTimeStyles.AssumeUniversal, out var parsedDate))
throw new InvalidRateLimitHeaderException("Date was not a valid date");
if (!long.TryParse(reset, out var parsedReset))
throw new InvalidRateLimitHeaderException("X-RateLimit-Reset was not a valid long");
var delta = DateTimeOffset.FromUnixTimeMilliseconds(parsedReset) - parsedDate + ExtraResetTime;
Reset = DateTimeOffset.UtcNow + delta;
}
if (remaining == null) return;
if (!int.TryParse(remaining, out var parsedRemaining))
throw new InvalidRateLimitHeaderException("X-RateLimit-Remaining was not a valid integer");
Remaining = parsedRemaining;
Logger.Verbose("New remaining for bucket {Bucket} is {Remaining}", Key, Remaining);
}
finally
{
Logger.Verbose("Releasing bucket {Bucket}", Key);
Semaphore.Release();
}
}
private static string? TryGetHeaderValue(HttpHeaders headers, string key) =>
headers.TryGetValues(key, out var values) ? values.FirstOrDefault() : null;
}
public class InvalidRateLimitHeaderException(string message) : Exception(message);

View file

@ -0,0 +1,6 @@
namespace Foxcord.Rest.Types;
public record CreateMessageParams(
string? Content = null,
object? Nonce = null,
bool Tts = false);

View file

@ -0,0 +1,5 @@
namespace Foxcord.Rest.Types;
public record GetGatewayBotResponse(string Url, int Shards, SessionStartLimit SessionStartLimit);
public record SessionStartLimit(int Total, int Remaining, int ResetAfter, int MaxConcurrency);