144 lines
5 KiB
C#
144 lines
5 KiB
C#
using System.Collections.Concurrent;
|
|
using System.Reflection;
|
|
using Microsoft.Extensions.Options;
|
|
using Remora.Discord.API.Abstractions.Rest;
|
|
using Remora.Discord.API.Gateway.Commands;
|
|
using Remora.Discord.Gateway;
|
|
using Remora.Rest.Core;
|
|
using Remora.Results;
|
|
|
|
namespace Catalogger.Backend.Bot;
|
|
|
|
// This class is based on VelvetToroyashi/RemoraShardHelper, licensed under the Apache 2.0 license:
|
|
// https://github.com/VelvetToroyashi/RemoraShardHelper
|
|
/// <summary>
/// Creates and runs one <see cref="DiscordGatewayClient"/> per shard and exposes the
/// running clients by shard ID.
/// </summary>
/// <remarks>
/// The shard count comes from <c>config.Discord.ShardCount</c> when it is set. When it is not,
/// <see cref="RunAsync"/> fills it in from the shard count returned by the gateway endpoint.
/// </remarks>
public class ShardedGatewayClient(
    ILogger logger,
    IDiscordRestGatewayAPI gatewayApi,
    IServiceProvider services,
    IOptions<DiscordGatewayClientOptions> gatewayClientOptions,
    Config config
) : IDisposable
{
    /// <summary>
    /// Number of shards this client runs. Zero until <see cref="RunAsync"/> resolves it
    /// when no shard count is configured.
    /// </summary>
    public int TotalShards { get; private set; } = config.Discord.ShardCount ?? 0;

    private readonly ILogger _logger = logger.ForContext<ShardedGatewayClient>();
    private readonly ConcurrentDictionary<int, DiscordGatewayClient> _gatewayClients = new();

    // The connection status is read through reflection from the client's non-public
    // "_connectionStatus" field.
    private static readonly FieldInfo Field = typeof(DiscordGatewayClient).GetField(
        "_connectionStatus",
        BindingFlags.Instance | BindingFlags.NonPublic
    )!;

    private static readonly Func<
        DiscordGatewayClient,
        GatewayConnectionStatus
    > GetConnectionStatus = client => (GatewayConnectionStatus)Field.GetValue(client)!;

    /// <summary>
    /// Returns whether <paramref name="client"/> currently reports
    /// <see cref="GatewayConnectionStatus.Connected"/>.
    /// </summary>
    public static bool IsConnected(DiscordGatewayClient client) =>
        GetConnectionStatus(client) == GatewayConnectionStatus.Connected;

    /// <summary>The gateway clients created by <see cref="RunAsync"/>, keyed by shard ID.</summary>
    public IReadOnlyDictionary<int, DiscordGatewayClient> Shards => _gatewayClients;

    /// <summary>
    /// Resolves the shard count, creates one gateway client per shard, and starts them one
    /// after another, waiting for each shard to connect (or finish) before starting the next.
    /// </summary>
    /// <param name="ct">Token used to cancel the gateway lookup, the startup wait and the shards.</param>
    /// <returns>
    /// The gateway lookup error if that request fails; an error if no positive shard count
    /// could be determined; the result of the first shard that fails during startup;
    /// otherwise the result of whichever shard finishes first.
    /// </returns>
    public async Task<Result> RunAsync(CancellationToken ct = default)
    {
        var gatewayResult = await gatewayApi.GetGatewayBotAsync(ct);
        if (!gatewayResult.IsSuccess)
        {
            _logger.Error("Failed to retrieve gateway endpoint: {Error}", gatewayResult.Error);
            return (Result)gatewayResult;
        }

        if (gatewayResult.Entity.Shards.IsDefined(out var discordShardCount))
        {
            if (TotalShards < discordShardCount && TotalShards != 0)
            {
                _logger.Warning(
                    "Discord recommends {DiscordShardCount} for this bot, but only {ConfigShardCount} shards are requested. This may cause issues later",
                    discordShardCount,
                    TotalShards
                );
            }

            if (TotalShards == 0)
            {
                TotalShards = discordShardCount;
            }
        }

        // Without at least one shard there is nothing to run: Enumerable.Range rejects a
        // negative count and Task.WhenAny rejects an empty task list, so fail with a result
        // instead of letting either exception escape.
        if (TotalShards <= 0)
        {
            _logger.Error(
                "Cannot start gateway: shard count is {ShardCount} and Discord did not provide one",
                TotalShards
            );
            return Result.FromError(
                new InvalidOperationError(
                    "No positive shard count is configured and Discord did not provide one."
                )
            );
        }

        // Create and register every client up front so Shards is fully populated before
        // the first shard starts.
        var clients = new DiscordGatewayClient[TotalShards];
        for (var shardId = 0; shardId < clients.Length; shardId++)
        {
            var created = ActivatorUtilities.CreateInstance<DiscordGatewayClient>(
                services,
                CloneOptions(gatewayClientOptions.Value, shardId)
            );
            _gatewayClients[shardId] = created;
            clients[shardId] = created;
        }

        var tasks = new List<Task<Result>>(clients.Length);

        for (var shardIndex = 0; shardIndex < clients.Length; shardIndex++)
        {
            _logger.Debug("Starting shard {ShardId}/{ShardCount}", shardIndex, TotalShards);

            var client = clients[shardIndex];
            var shardTask = client.RunAsync(ct);
            tasks.Add(shardTask);

            // Poll until the shard reports Connected or its run task ends, whichever is first.
            while (!IsConnected(client) && !shardTask.IsCompleted)
            {
                await Task.Delay(100, ct);
            }

            if (shardTask.IsCompleted)
            {
                // Awaiting the already-completed task (rather than reading .Result) surfaces
                // faults and cancellation the same way as the final await below.
                var earlyResult = await shardTask;
                if (!earlyResult.IsSuccess)
                {
                    return earlyResult;
                }
            }

            _logger.Information("Started shard {ShardId}/{ShardCount}", shardIndex, TotalShards);
        }

        // All shards are up; complete as soon as any one of them stops.
        return await await Task.WhenAny(tasks);
    }

    /// <summary>
    /// Computes the shard ID for a guild as <c>(guildId &gt;&gt; 22) % TotalShards</c>.
    /// </summary>
    /// <param name="guildId">The guild's snowflake value.</param>
    /// <exception cref="DivideByZeroException"><see cref="TotalShards"/> is zero.</exception>
    public int ShardIdFor(ulong guildId) => (int)((guildId >> 22) % (ulong)TotalShards);

    /// <summary>Returns the gateway client for the shard that <paramref name="guildId"/> maps to.</summary>
    /// <exception cref="CataloggerError">No client exists for that shard.</exception>
    public DiscordGatewayClient ClientFor(Snowflake guildId) => ClientFor(guildId.Value);

    /// <summary>Returns the gateway client for the shard that <paramref name="guildId"/> maps to.</summary>
    /// <exception cref="CataloggerError">
    /// The shard count is not known yet, or no client exists for that shard.
    /// </exception>
    public DiscordGatewayClient ClientFor(ulong guildId)
    {
        // Check the shard count first: with TotalShards == 0, ShardIdFor would divide by
        // zero and hide the more useful error below.
        if (TotalShards > 0 && _gatewayClients.TryGetValue(ShardIdFor(guildId), out var client))
        {
            return client;
        }

        throw new CataloggerError(
            "Shard was null, has ShardedGatewayClient.RunAsync been called?"
        );
    }

    /// <summary>Disposes every gateway client created by this instance.</summary>
    public void Dispose()
    {
        GC.SuppressFinalize(this);
        foreach (var client in _gatewayClients.Values)
        {
            client.Dispose();
        }
    }

    /// <summary>
    /// Copies <paramref name="options"/> into a new options instance whose shard
    /// identification is (<paramref name="shardId"/>, <see cref="TotalShards"/>).
    /// </summary>
    private IOptions<DiscordGatewayClientOptions> CloneOptions(
        DiscordGatewayClientOptions options,
        int shardId
    )
    {
        var ret = new DiscordGatewayClientOptions
        {
            ShardIdentification = new ShardIdentification(shardId, TotalShards),
            Intents = options.Intents,
            Presence = options.Presence,
            ConnectionProperties = options.ConnectionProperties,
            HeartbeatHeadroom = options.HeartbeatHeadroom,
            LargeThreshold = options.LargeThreshold,
            CommandBurstRate = options.CommandBurstRate,
            HeartbeatSafetyMargin = options.HeartbeatSafetyMargin,
            MinimumSafetyMargin = options.MinimumSafetyMargin,
        };

        return Options.Create(ret);
    }
}
|