diff --git a/Catalogger.Backend/Bot/Commands/MetaCommands.cs b/Catalogger.Backend/Bot/Commands/MetaCommands.cs index bfd32a4..466d10c 100644 --- a/Catalogger.Backend/Bot/Commands/MetaCommands.cs +++ b/Catalogger.Backend/Bot/Commands/MetaCommands.cs @@ -11,7 +11,9 @@ using Remora.Commands.Attributes; using Remora.Commands.Groups; using Remora.Discord.API.Abstractions.Objects; using Remora.Discord.API.Abstractions.Rest; +using Remora.Discord.Commands.Extensions; using Remora.Discord.Commands.Feedback.Services; +using Remora.Discord.Commands.Services; using Remora.Discord.Extensions.Embeds; using Remora.Discord.Gateway; using Remora.Results; @@ -25,8 +27,9 @@ public class MetaCommands( ILogger logger, IClock clock, Config config, - DiscordGatewayClient client, + ShardedGatewayClient client, IFeedbackService feedbackService, + ContextInjectionService contextInjection, GuildCache guildCache, ChannelCache channelCache, IDiscordRestChannelAPI channelApi) : CommandGroup @@ -38,6 +41,13 @@ public class MetaCommands( [Description("Ping pong! See the bot's latency")] public async Task PingAsync() { + var shardId = contextInjection.Context?.TryGetGuildID(out var guildId) == true + ? client.ShardIdFor(guildId.Value) + : 0; + + var averageLatency = client.Shards.Values.Select(x => x.Latency.TotalMilliseconds).Sum() / + client.Shards.Count; + var t1 = clock.GetCurrentInstant(); var msg = await feedbackService.SendContextualAsync("...").GetOrThrow(); var elapsed = clock.GetCurrentInstant() - t1; @@ -49,7 +59,9 @@ public class MetaCommands( .WithColour(DiscordUtils.Purple) .WithFooter($"{RuntimeInformation.FrameworkDescription} on {RuntimeInformation.RuntimeIdentifier}") .WithCurrentTimestamp(); - embed.AddField("Ping", $"Gateway: {client.Latency.Humanize()}\nAPI: {elapsed.ToTimeSpan().Humanize()}", + embed.AddField("Ping", + $"Gateway: {client.Shards[shardId].Latency.TotalMilliseconds:N0}ms (average: {averageLatency:N0}ms)\n" + + $"API: {elapsed.TotalMilliseconds:N0}ms", inline: true); embed.AddField("Memory usage", memoryUsage.Bytes().Humanize(), inline: true); @@ -60,16 +72,18 @@ public class MetaCommands( : $"{CataloggerMetrics.MessagesReceived.Value:N0} since last restart", true); - embed.AddField("Numbers", - $"{CataloggerMetrics.MessagesStored.Value:N0} messages " + - $"from {guildCache.Size:N0} servers\nCached {channelCache.Size:N0} channels", - true); + embed.AddField("Shard", $"{shardId + 1} of {client.Shards.Count}", true); embed.AddField("Uptime", $"{(CataloggerMetrics.Startup - clock.GetCurrentInstant()).Prettify(TimeUnit.Second)}\n" + $"since ", true); + embed.AddField("Numbers", + $"{CataloggerMetrics.MessagesStored.Value:N0} messages " + + $"from {guildCache.Size:N0} servers\nCached {channelCache.Size:N0} channels", + false); + IEmbed[] embeds = [embed.Build().GetOrThrow()]; return (Result)await channelApi.EditMessageAsync(msg.ChannelID, msg.ID, content: "", embeds: embeds); diff --git a/Catalogger.Backend/Bot/ShardedDiscordService.cs b/Catalogger.Backend/Bot/ShardedDiscordService.cs new file mode 100644 index 0000000..c961e81 --- /dev/null +++ b/Catalogger.Backend/Bot/ShardedDiscordService.cs @@ -0,0 +1,13 @@ +using Remora.Discord.Gateway.Results; + +namespace Catalogger.Backend.Bot; + +public class ShardedDiscordService(ShardedGatewayClient client, IHostApplicationLifetime lifetime) : BackgroundService +{ + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + var result = await client.RunAsync(stoppingToken); + if (result.Error is GatewayError { IsCritical: true }) + lifetime.StopApplication(); + } +} \ No newline at end of file diff --git a/Catalogger.Backend/Bot/ShardedGatewayClient.cs b/Catalogger.Backend/Bot/ShardedGatewayClient.cs new file mode 100644 index 0000000..56e9207 --- /dev/null +++ b/Catalogger.Backend/Bot/ShardedGatewayClient.cs @@ -0,0 +1,120 @@ +using System.Collections.Concurrent; +using System.Reflection; +using Microsoft.Extensions.Options; +using Remora.Discord.API.Abstractions.Rest; +using Remora.Discord.API.Gateway.Commands; +using Remora.Discord.Gateway; +using Remora.Rest.Core; +using Remora.Results; + +namespace Catalogger.Backend.Bot; + +// This class is based on VelvetToroyashi/RemoraShardHelper, licensed under the Apache 2.0 license: +// https://github.com/VelvetToroyashi/RemoraShardHelper +public class ShardedGatewayClient( + ILogger logger, + IDiscordRestGatewayAPI gatewayApi, + IServiceProvider services, + IOptions gatewayClientOptions, + Config config) + : IDisposable +{ + private int _shardCount = config.Discord.ShardCount ?? 0; + private readonly ILogger _logger = logger.ForContext(); + private readonly ConcurrentDictionary _gatewayClients = new(); + + private static readonly FieldInfo Field = + typeof(DiscordGatewayClient).GetField("_connectionStatus", BindingFlags.Instance | BindingFlags.NonPublic)!; + + private static readonly Func GetConnectionStatus = + client => (GatewayConnectionStatus)Field.GetValue(client)!; + + public IReadOnlyDictionary Shards => _gatewayClients; + + public async Task RunAsync(CancellationToken ct = default) + { + var gatewayResult = await gatewayApi.GetGatewayBotAsync(ct); + if (!gatewayResult.IsSuccess) + { + _logger.Error("Failed to retrieve gateway endpoint: {Error}", gatewayResult.Error); + return (Result)gatewayResult; + } + + if (gatewayResult.Entity.Shards.IsDefined(out var discordShardCount)) + { + if (_shardCount < discordShardCount && _shardCount != 0) + _logger.Warning( + "Discord recommends {DiscordShardCount} for this bot, but only {ConfigShardCount} shards are requested. This may cause issues later", + discordShardCount, _shardCount); + + if (_shardCount == 0) _shardCount = discordShardCount; + } + + var clients = Enumerable.Range(0, _shardCount).Select(s => + { + var client = + ActivatorUtilities.CreateInstance(services, + CloneOptions(gatewayClientOptions.Value, s)); + _gatewayClients[s] = client; + return client; + }).ToArray(); + + var tasks = new List>(); + + for (var shardIndex = 0; shardIndex < clients.Length; shardIndex++) + { + _logger.Debug("Starting shard {ShardId}/{ShardCount}", shardIndex, _shardCount); + + var client = clients[shardIndex]; + var res = client.RunAsync(ct); + tasks.Add(res); + + while (GetConnectionStatus(client) is not GatewayConnectionStatus.Connected && !res.IsCompleted) + { + await Task.Delay(100, ct); + } + + if (res is { IsCompleted: true, Result.IsSuccess: false }) + { + return res.Result; + } + + _logger.Information("Started shard {ShardId}/{ShardCount}", shardIndex, _shardCount); + } + + return await await Task.WhenAny(tasks); + } + + public int ShardIdFor(ulong guildId) => (int)((guildId >> 22) % (ulong)_shardCount); + + public DiscordGatewayClient ClientFor(Snowflake guildId) => ClientFor(guildId.Value); + + public DiscordGatewayClient ClientFor(ulong guildId) => + _gatewayClients.TryGetValue(ShardIdFor(guildId), out var client) + ? client + : throw new CataloggerError("Shard was null, has ShardedGatewayClient.RunAsync been called?"); + + public void Dispose() + { + foreach (var client in _gatewayClients.Values) + client.Dispose(); + } + + private IOptions CloneOptions(DiscordGatewayClientOptions options, int shardId) + { + var ret = new DiscordGatewayClientOptions + { + ShardIdentification = new ShardIdentification(shardId, _shardCount), + Intents = options.Intents, + Presence = options.Presence, + ConnectionProperties = options.ConnectionProperties, + HeartbeatHeadroom = options.HeartbeatHeadroom, + LargeThreshold = options.LargeThreshold, + CommandBurstRate = options.CommandBurstRate, + HeartbeatSafetyMargin = options.HeartbeatSafetyMargin, + MinimumSafetyMargin = options.MinimumSafetyMargin + }; + + return Options.Create(ret); + } +} \ No newline at end of file diff --git a/Catalogger.Backend/Config.cs b/Catalogger.Backend/Config.cs index 6ebd9aa..cc65037 100644 --- a/Catalogger.Backend/Config.cs +++ b/Catalogger.Backend/Config.cs @@ -34,6 +34,7 @@ public class Config public bool SyncCommands { get; init; } public ulong? CommandsGuildId { get; init; } public ulong? GuildLogId { get; init; } + public int? ShardCount { get; init; } } public class WebConfig diff --git a/Catalogger.Backend/Extensions/StartupExtensions.cs b/Catalogger.Backend/Extensions/StartupExtensions.cs index 483bee9..e996000 100644 --- a/Catalogger.Backend/Extensions/StartupExtensions.cs +++ b/Catalogger.Backend/Extensions/StartupExtensions.cs @@ -1,3 +1,4 @@ +using Catalogger.Backend.Bot; using Catalogger.Backend.Bot.Commands; using Catalogger.Backend.Bot.Responders.Messages; using Catalogger.Backend.Cache; @@ -12,6 +13,7 @@ using NodaTime; using Remora.Discord.API; using Remora.Discord.API.Abstractions.Rest; using Remora.Discord.Commands.Services; +using Remora.Discord.Gateway.Extensions; using Remora.Discord.Interactivity.Services; using Remora.Rest.Core; using Serilog; @@ -81,6 +83,13 @@ public static class StartupExtensions .AddSingleton() .AddHostedService(serviceProvider => serviceProvider.GetRequiredService()); + public static IHostBuilder AddShardedDiscordService(this IHostBuilder builder, + Func tokenFactory) => + builder.ConfigureServices((_, services) => services + .AddDiscordGateway(tokenFactory) + .AddSingleton() + .AddHostedService()); + public static IServiceCollection MaybeAddRedisCaches(this IServiceCollection services, Config config) { if (config.Database.Redis == null) diff --git a/Catalogger.Backend/Program.cs b/Catalogger.Backend/Program.cs index 1f2528a..0ff4dba 100644 --- a/Catalogger.Backend/Program.cs +++ b/Catalogger.Backend/Program.cs @@ -28,7 +28,7 @@ builder.Services }); builder.Host - .AddDiscordService(_ => config.Discord.Token) + .AddShardedDiscordService(_ => config.Discord.Token) .ConfigureServices(s => s.AddRespondersFromAssembly(typeof(Program).Assembly) .Configure(g => diff --git a/Catalogger.Backend/Services/GuildFetchService.cs b/Catalogger.Backend/Services/GuildFetchService.cs index afb43dc..ee3c084 100644 --- a/Catalogger.Backend/Services/GuildFetchService.cs +++ b/Catalogger.Backend/Services/GuildFetchService.cs @@ -1,4 +1,5 @@ using System.Collections.Concurrent; +using Catalogger.Backend.Bot; using Catalogger.Backend.Cache; using Humanizer; using Remora.Discord.API.Abstractions.Rest; @@ -10,7 +11,7 @@ namespace Catalogger.Backend.Services; public class GuildFetchService( ILogger logger, - DiscordGatewayClient gatewayClient, + ShardedGatewayClient client, IDiscordRestGuildAPI guildApi, IInviteCache inviteCache) : BackgroundService { @@ -25,7 +26,7 @@ public class GuildFetchService( if (!_guilds.TryPeek(out var guildId)) continue; _logger.Debug("Fetching members and invites for guild {GuildId}", guildId); - gatewayClient.SubmitCommand(new RequestGuildMembers(guildId, "", 0)); + client.ClientFor(guildId).SubmitCommand(new RequestGuildMembers(guildId, "", 0)); var res = await guildApi.GetGuildInvitesAsync(guildId, stoppingToken); if (res.Error != null) {