feat: add sharding support

This commit is contained in:
sam 2024-08-24 19:02:19 +02:00
parent 8f39d85486
commit c42ca3f888
Signed by: sam
GPG key ID: 5F3C3C1B3166639D
7 changed files with 167 additions and 9 deletions

View file

@ -11,7 +11,9 @@ using Remora.Commands.Attributes;
using Remora.Commands.Groups; using Remora.Commands.Groups;
using Remora.Discord.API.Abstractions.Objects; using Remora.Discord.API.Abstractions.Objects;
using Remora.Discord.API.Abstractions.Rest; using Remora.Discord.API.Abstractions.Rest;
using Remora.Discord.Commands.Extensions;
using Remora.Discord.Commands.Feedback.Services; using Remora.Discord.Commands.Feedback.Services;
using Remora.Discord.Commands.Services;
using Remora.Discord.Extensions.Embeds; using Remora.Discord.Extensions.Embeds;
using Remora.Discord.Gateway; using Remora.Discord.Gateway;
using Remora.Results; using Remora.Results;
@ -25,8 +27,9 @@ public class MetaCommands(
ILogger logger, ILogger logger,
IClock clock, IClock clock,
Config config, Config config,
DiscordGatewayClient client, ShardedGatewayClient client,
IFeedbackService feedbackService, IFeedbackService feedbackService,
ContextInjectionService contextInjection,
GuildCache guildCache, GuildCache guildCache,
ChannelCache channelCache, ChannelCache channelCache,
IDiscordRestChannelAPI channelApi) : CommandGroup IDiscordRestChannelAPI channelApi) : CommandGroup
@ -38,6 +41,13 @@ public class MetaCommands(
[Description("Ping pong! See the bot's latency")] [Description("Ping pong! See the bot's latency")]
public async Task<IResult> PingAsync() public async Task<IResult> PingAsync()
{ {
var shardId = contextInjection.Context?.TryGetGuildID(out var guildId) == true
? client.ShardIdFor(guildId.Value)
: 0;
var averageLatency = client.Shards.Values.Select(x => x.Latency.TotalMilliseconds).Sum() /
client.Shards.Count;
var t1 = clock.GetCurrentInstant(); var t1 = clock.GetCurrentInstant();
var msg = await feedbackService.SendContextualAsync("...").GetOrThrow(); var msg = await feedbackService.SendContextualAsync("...").GetOrThrow();
var elapsed = clock.GetCurrentInstant() - t1; var elapsed = clock.GetCurrentInstant() - t1;
@ -49,7 +59,9 @@ public class MetaCommands(
.WithColour(DiscordUtils.Purple) .WithColour(DiscordUtils.Purple)
.WithFooter($"{RuntimeInformation.FrameworkDescription} on {RuntimeInformation.RuntimeIdentifier}") .WithFooter($"{RuntimeInformation.FrameworkDescription} on {RuntimeInformation.RuntimeIdentifier}")
.WithCurrentTimestamp(); .WithCurrentTimestamp();
embed.AddField("Ping", $"Gateway: {client.Latency.Humanize()}\nAPI: {elapsed.ToTimeSpan().Humanize()}", embed.AddField("Ping",
$"Gateway: {client.Shards[shardId].Latency.TotalMilliseconds:N0}ms (average: {averageLatency:N0}ms)\n" +
$"API: {elapsed.TotalMilliseconds:N0}ms",
inline: true); inline: true);
embed.AddField("Memory usage", memoryUsage.Bytes().Humanize(), inline: true); embed.AddField("Memory usage", memoryUsage.Bytes().Humanize(), inline: true);
@ -60,16 +72,18 @@ public class MetaCommands(
: $"{CataloggerMetrics.MessagesReceived.Value:N0} since last restart", : $"{CataloggerMetrics.MessagesReceived.Value:N0} since last restart",
true); true);
embed.AddField("Numbers", embed.AddField("Shard", $"{shardId + 1} of {client.Shards.Count}", true);
$"{CataloggerMetrics.MessagesStored.Value:N0} messages " +
$"from {guildCache.Size:N0} servers\nCached {channelCache.Size:N0} channels",
true);
embed.AddField("Uptime", embed.AddField("Uptime",
$"{(CataloggerMetrics.Startup - clock.GetCurrentInstant()).Prettify(TimeUnit.Second)}\n" + $"{(CataloggerMetrics.Startup - clock.GetCurrentInstant()).Prettify(TimeUnit.Second)}\n" +
$"since <t:{CataloggerMetrics.Startup.ToUnixTimeSeconds()}:F>", $"since <t:{CataloggerMetrics.Startup.ToUnixTimeSeconds()}:F>",
true); true);
embed.AddField("Numbers",
$"{CataloggerMetrics.MessagesStored.Value:N0} messages " +
$"from {guildCache.Size:N0} servers\nCached {channelCache.Size:N0} channels",
false);
IEmbed[] embeds = [embed.Build().GetOrThrow()]; IEmbed[] embeds = [embed.Build().GetOrThrow()];
return (Result)await channelApi.EditMessageAsync(msg.ChannelID, msg.ID, content: "", embeds: embeds); return (Result)await channelApi.EditMessageAsync(msg.ChannelID, msg.ID, content: "", embeds: embeds);

View file

@ -0,0 +1,13 @@
using Remora.Discord.Gateway.Results;
namespace Catalogger.Backend.Bot;
public class ShardedDiscordService(ShardedGatewayClient client, IHostApplicationLifetime lifetime) : BackgroundService
{
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
var result = await client.RunAsync(stoppingToken);
if (result.Error is GatewayError { IsCritical: true })
lifetime.StopApplication();
}
}

View file

@ -0,0 +1,120 @@
using System.Collections.Concurrent;
using System.Reflection;
using Microsoft.Extensions.Options;
using Remora.Discord.API.Abstractions.Rest;
using Remora.Discord.API.Gateway.Commands;
using Remora.Discord.Gateway;
using Remora.Rest.Core;
using Remora.Results;
namespace Catalogger.Backend.Bot;
// This class is based on VelvetToroyashi/RemoraShardHelper, licensed under the Apache 2.0 license:
// https://github.com/VelvetToroyashi/RemoraShardHelper
public class ShardedGatewayClient(
ILogger logger,
IDiscordRestGatewayAPI gatewayApi,
IServiceProvider services,
IOptions<DiscordGatewayClientOptions> gatewayClientOptions,
Config config)
: IDisposable
{
private int _shardCount = config.Discord.ShardCount ?? 0;
private readonly ILogger _logger = logger.ForContext<ShardedGatewayClient>();
private readonly ConcurrentDictionary<int, DiscordGatewayClient> _gatewayClients = new();
private static readonly FieldInfo Field =
typeof(DiscordGatewayClient).GetField("_connectionStatus", BindingFlags.Instance | BindingFlags.NonPublic)!;
private static readonly Func<DiscordGatewayClient, GatewayConnectionStatus> GetConnectionStatus =
client => (GatewayConnectionStatus)Field.GetValue(client)!;
public IReadOnlyDictionary<int, DiscordGatewayClient> Shards => _gatewayClients;
public async Task<Result> RunAsync(CancellationToken ct = default)
{
var gatewayResult = await gatewayApi.GetGatewayBotAsync(ct);
if (!gatewayResult.IsSuccess)
{
_logger.Error("Failed to retrieve gateway endpoint: {Error}", gatewayResult.Error);
return (Result)gatewayResult;
}
if (gatewayResult.Entity.Shards.IsDefined(out var discordShardCount))
{
if (_shardCount < discordShardCount && _shardCount != 0)
_logger.Warning(
"Discord recommends {DiscordShardCount} for this bot, but only {ConfigShardCount} shards are requested. This may cause issues later",
discordShardCount, _shardCount);
if (_shardCount == 0) _shardCount = discordShardCount;
}
var clients = Enumerable.Range(0, _shardCount).Select(s =>
{
var client =
ActivatorUtilities.CreateInstance<DiscordGatewayClient>(services,
CloneOptions(gatewayClientOptions.Value, s));
_gatewayClients[s] = client;
return client;
}).ToArray();
var tasks = new List<Task<Result>>();
for (var shardIndex = 0; shardIndex < clients.Length; shardIndex++)
{
_logger.Debug("Starting shard {ShardId}/{ShardCount}", shardIndex, _shardCount);
var client = clients[shardIndex];
var res = client.RunAsync(ct);
tasks.Add(res);
while (GetConnectionStatus(client) is not GatewayConnectionStatus.Connected && !res.IsCompleted)
{
await Task.Delay(100, ct);
}
if (res is { IsCompleted: true, Result.IsSuccess: false })
{
return res.Result;
}
_logger.Information("Started shard {ShardId}/{ShardCount}", shardIndex, _shardCount);
}
return await await Task.WhenAny(tasks);
}
public int ShardIdFor(ulong guildId) => (int)((guildId >> 22) % (ulong)_shardCount);
public DiscordGatewayClient ClientFor(Snowflake guildId) => ClientFor(guildId.Value);
public DiscordGatewayClient ClientFor(ulong guildId) =>
_gatewayClients.TryGetValue(ShardIdFor(guildId), out var client)
? client
: throw new CataloggerError("Shard was null, has ShardedGatewayClient.RunAsync been called?");
public void Dispose()
{
foreach (var client in _gatewayClients.Values)
client.Dispose();
}
private IOptions<DiscordGatewayClientOptions> CloneOptions(DiscordGatewayClientOptions options, int shardId)
{
var ret = new DiscordGatewayClientOptions
{
ShardIdentification = new ShardIdentification(shardId, _shardCount),
Intents = options.Intents,
Presence = options.Presence,
ConnectionProperties = options.ConnectionProperties,
HeartbeatHeadroom = options.HeartbeatHeadroom,
LargeThreshold = options.LargeThreshold,
CommandBurstRate = options.CommandBurstRate,
HeartbeatSafetyMargin = options.HeartbeatSafetyMargin,
MinimumSafetyMargin = options.MinimumSafetyMargin
};
return Options.Create(ret);
}
}

View file

@ -34,6 +34,7 @@ public class Config
public bool SyncCommands { get; init; } public bool SyncCommands { get; init; }
public ulong? CommandsGuildId { get; init; } public ulong? CommandsGuildId { get; init; }
public ulong? GuildLogId { get; init; } public ulong? GuildLogId { get; init; }
public int? ShardCount { get; init; }
} }
public class WebConfig public class WebConfig

View file

@ -1,3 +1,4 @@
using Catalogger.Backend.Bot;
using Catalogger.Backend.Bot.Commands; using Catalogger.Backend.Bot.Commands;
using Catalogger.Backend.Bot.Responders.Messages; using Catalogger.Backend.Bot.Responders.Messages;
using Catalogger.Backend.Cache; using Catalogger.Backend.Cache;
@ -12,6 +13,7 @@ using NodaTime;
using Remora.Discord.API; using Remora.Discord.API;
using Remora.Discord.API.Abstractions.Rest; using Remora.Discord.API.Abstractions.Rest;
using Remora.Discord.Commands.Services; using Remora.Discord.Commands.Services;
using Remora.Discord.Gateway.Extensions;
using Remora.Discord.Interactivity.Services; using Remora.Discord.Interactivity.Services;
using Remora.Rest.Core; using Remora.Rest.Core;
using Serilog; using Serilog;
@ -81,6 +83,13 @@ public static class StartupExtensions
.AddSingleton<GuildFetchService>() .AddSingleton<GuildFetchService>()
.AddHostedService(serviceProvider => serviceProvider.GetRequiredService<GuildFetchService>()); .AddHostedService(serviceProvider => serviceProvider.GetRequiredService<GuildFetchService>());
public static IHostBuilder AddShardedDiscordService(this IHostBuilder builder,
Func<IServiceProvider, string> tokenFactory) =>
builder.ConfigureServices((_, services) => services
.AddDiscordGateway(tokenFactory)
.AddSingleton<ShardedGatewayClient>()
.AddHostedService<ShardedDiscordService>());
public static IServiceCollection MaybeAddRedisCaches(this IServiceCollection services, Config config) public static IServiceCollection MaybeAddRedisCaches(this IServiceCollection services, Config config)
{ {
if (config.Database.Redis == null) if (config.Database.Redis == null)

View file

@ -28,7 +28,7 @@ builder.Services
}); });
builder.Host builder.Host
.AddDiscordService(_ => config.Discord.Token) .AddShardedDiscordService(_ => config.Discord.Token)
.ConfigureServices(s => .ConfigureServices(s =>
s.AddRespondersFromAssembly(typeof(Program).Assembly) s.AddRespondersFromAssembly(typeof(Program).Assembly)
.Configure<DiscordGatewayClientOptions>(g => .Configure<DiscordGatewayClientOptions>(g =>

View file

@ -1,4 +1,5 @@
using System.Collections.Concurrent; using System.Collections.Concurrent;
using Catalogger.Backend.Bot;
using Catalogger.Backend.Cache; using Catalogger.Backend.Cache;
using Humanizer; using Humanizer;
using Remora.Discord.API.Abstractions.Rest; using Remora.Discord.API.Abstractions.Rest;
@ -10,7 +11,7 @@ namespace Catalogger.Backend.Services;
public class GuildFetchService( public class GuildFetchService(
ILogger logger, ILogger logger,
DiscordGatewayClient gatewayClient, ShardedGatewayClient client,
IDiscordRestGuildAPI guildApi, IDiscordRestGuildAPI guildApi,
IInviteCache inviteCache) : BackgroundService IInviteCache inviteCache) : BackgroundService
{ {
@ -25,7 +26,7 @@ public class GuildFetchService(
if (!_guilds.TryPeek(out var guildId)) continue; if (!_guilds.TryPeek(out var guildId)) continue;
_logger.Debug("Fetching members and invites for guild {GuildId}", guildId); _logger.Debug("Fetching members and invites for guild {GuildId}", guildId);
gatewayClient.SubmitCommand(new RequestGuildMembers(guildId, "", 0)); client.ClientFor(guildId).SubmitCommand(new RequestGuildMembers(guildId, "", 0));
var res = await guildApi.GetGuildInvitesAsync(guildId, stoppingToken); var res = await guildApi.GetGuildInvitesAsync(guildId, stoppingToken);
if (res.Error != null) if (res.Error != null)
{ {