feat: add sharding support

This commit is contained in:
sam 2024-08-24 19:02:19 +02:00
parent 8f39d85486
commit c42ca3f888
Signed by: sam
GPG key ID: 5F3C3C1B3166639D
7 changed files with 167 additions and 9 deletions

View file

@ -11,7 +11,9 @@ using Remora.Commands.Attributes;
using Remora.Commands.Groups;
using Remora.Discord.API.Abstractions.Objects;
using Remora.Discord.API.Abstractions.Rest;
using Remora.Discord.Commands.Extensions;
using Remora.Discord.Commands.Feedback.Services;
using Remora.Discord.Commands.Services;
using Remora.Discord.Extensions.Embeds;
using Remora.Discord.Gateway;
using Remora.Results;
@ -25,8 +27,9 @@ public class MetaCommands(
ILogger logger,
IClock clock,
Config config,
DiscordGatewayClient client,
ShardedGatewayClient client,
IFeedbackService feedbackService,
ContextInjectionService contextInjection,
GuildCache guildCache,
ChannelCache channelCache,
IDiscordRestChannelAPI channelApi) : CommandGroup
@ -38,6 +41,13 @@ public class MetaCommands(
[Description("Ping pong! See the bot's latency")]
public async Task<IResult> PingAsync()
{
var shardId = contextInjection.Context?.TryGetGuildID(out var guildId) == true
? client.ShardIdFor(guildId.Value)
: 0;
var averageLatency = client.Shards.Values.Select(x => x.Latency.TotalMilliseconds).Sum() /
client.Shards.Count;
var t1 = clock.GetCurrentInstant();
var msg = await feedbackService.SendContextualAsync("...").GetOrThrow();
var elapsed = clock.GetCurrentInstant() - t1;
@ -49,7 +59,9 @@ public class MetaCommands(
.WithColour(DiscordUtils.Purple)
.WithFooter($"{RuntimeInformation.FrameworkDescription} on {RuntimeInformation.RuntimeIdentifier}")
.WithCurrentTimestamp();
embed.AddField("Ping", $"Gateway: {client.Latency.Humanize()}\nAPI: {elapsed.ToTimeSpan().Humanize()}",
embed.AddField("Ping",
$"Gateway: {client.Shards[shardId].Latency.TotalMilliseconds:N0}ms (average: {averageLatency:N0}ms)\n" +
$"API: {elapsed.TotalMilliseconds:N0}ms",
inline: true);
embed.AddField("Memory usage", memoryUsage.Bytes().Humanize(), inline: true);
@ -60,16 +72,18 @@ public class MetaCommands(
: $"{CataloggerMetrics.MessagesReceived.Value:N0} since last restart",
true);
embed.AddField("Numbers",
$"{CataloggerMetrics.MessagesStored.Value:N0} messages " +
$"from {guildCache.Size:N0} servers\nCached {channelCache.Size:N0} channels",
true);
embed.AddField("Shard", $"{shardId + 1} of {client.Shards.Count}", true);
embed.AddField("Uptime",
$"{(CataloggerMetrics.Startup - clock.GetCurrentInstant()).Prettify(TimeUnit.Second)}\n" +
$"since <t:{CataloggerMetrics.Startup.ToUnixTimeSeconds()}:F>",
true);
embed.AddField("Numbers",
$"{CataloggerMetrics.MessagesStored.Value:N0} messages " +
$"from {guildCache.Size:N0} servers\nCached {channelCache.Size:N0} channels",
false);
IEmbed[] embeds = [embed.Build().GetOrThrow()];
return (Result)await channelApi.EditMessageAsync(msg.ChannelID, msg.ID, content: "", embeds: embeds);

View file

@ -0,0 +1,13 @@
using Remora.Discord.Gateway.Results;
namespace Catalogger.Backend.Bot;
public class ShardedDiscordService(ShardedGatewayClient client, IHostApplicationLifetime lifetime) : BackgroundService
{
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
var result = await client.RunAsync(stoppingToken);
if (result.Error is GatewayError { IsCritical: true })
lifetime.StopApplication();
}
}

View file

@ -0,0 +1,120 @@
using System.Collections.Concurrent;
using System.Reflection;
using Microsoft.Extensions.Options;
using Remora.Discord.API.Abstractions.Rest;
using Remora.Discord.API.Gateway.Commands;
using Remora.Discord.Gateway;
using Remora.Rest.Core;
using Remora.Results;
namespace Catalogger.Backend.Bot;
// This class is based on VelvetToroyashi/RemoraShardHelper, licensed under the Apache 2.0 license:
// https://github.com/VelvetToroyashi/RemoraShardHelper
public class ShardedGatewayClient(
ILogger logger,
IDiscordRestGatewayAPI gatewayApi,
IServiceProvider services,
IOptions<DiscordGatewayClientOptions> gatewayClientOptions,
Config config)
: IDisposable
{
private int _shardCount = config.Discord.ShardCount ?? 0;
private readonly ILogger _logger = logger.ForContext<ShardedGatewayClient>();
private readonly ConcurrentDictionary<int, DiscordGatewayClient> _gatewayClients = new();
private static readonly FieldInfo Field =
typeof(DiscordGatewayClient).GetField("_connectionStatus", BindingFlags.Instance | BindingFlags.NonPublic)!;
private static readonly Func<DiscordGatewayClient, GatewayConnectionStatus> GetConnectionStatus =
client => (GatewayConnectionStatus)Field.GetValue(client)!;
public IReadOnlyDictionary<int, DiscordGatewayClient> Shards => _gatewayClients;
public async Task<Result> RunAsync(CancellationToken ct = default)
{
var gatewayResult = await gatewayApi.GetGatewayBotAsync(ct);
if (!gatewayResult.IsSuccess)
{
_logger.Error("Failed to retrieve gateway endpoint: {Error}", gatewayResult.Error);
return (Result)gatewayResult;
}
if (gatewayResult.Entity.Shards.IsDefined(out var discordShardCount))
{
if (_shardCount < discordShardCount && _shardCount != 0)
_logger.Warning(
"Discord recommends {DiscordShardCount} for this bot, but only {ConfigShardCount} shards are requested. This may cause issues later",
discordShardCount, _shardCount);
if (_shardCount == 0) _shardCount = discordShardCount;
}
var clients = Enumerable.Range(0, _shardCount).Select(s =>
{
var client =
ActivatorUtilities.CreateInstance<DiscordGatewayClient>(services,
CloneOptions(gatewayClientOptions.Value, s));
_gatewayClients[s] = client;
return client;
}).ToArray();
var tasks = new List<Task<Result>>();
for (var shardIndex = 0; shardIndex < clients.Length; shardIndex++)
{
_logger.Debug("Starting shard {ShardId}/{ShardCount}", shardIndex, _shardCount);
var client = clients[shardIndex];
var res = client.RunAsync(ct);
tasks.Add(res);
while (GetConnectionStatus(client) is not GatewayConnectionStatus.Connected && !res.IsCompleted)
{
await Task.Delay(100, ct);
}
if (res is { IsCompleted: true, Result.IsSuccess: false })
{
return res.Result;
}
_logger.Information("Started shard {ShardId}/{ShardCount}", shardIndex, _shardCount);
}
return await await Task.WhenAny(tasks);
}
public int ShardIdFor(ulong guildId) => (int)((guildId >> 22) % (ulong)_shardCount);
public DiscordGatewayClient ClientFor(Snowflake guildId) => ClientFor(guildId.Value);
public DiscordGatewayClient ClientFor(ulong guildId) =>
_gatewayClients.TryGetValue(ShardIdFor(guildId), out var client)
? client
: throw new CataloggerError("Shard was null, has ShardedGatewayClient.RunAsync been called?");
public void Dispose()
{
foreach (var client in _gatewayClients.Values)
client.Dispose();
}
private IOptions<DiscordGatewayClientOptions> CloneOptions(DiscordGatewayClientOptions options, int shardId)
{
var ret = new DiscordGatewayClientOptions
{
ShardIdentification = new ShardIdentification(shardId, _shardCount),
Intents = options.Intents,
Presence = options.Presence,
ConnectionProperties = options.ConnectionProperties,
HeartbeatHeadroom = options.HeartbeatHeadroom,
LargeThreshold = options.LargeThreshold,
CommandBurstRate = options.CommandBurstRate,
HeartbeatSafetyMargin = options.HeartbeatSafetyMargin,
MinimumSafetyMargin = options.MinimumSafetyMargin
};
return Options.Create(ret);
}
}