using System.Collections.Concurrent;
using System.Reflection;
using Microsoft.Extensions.Options;
using Remora.Discord.API.Abstractions.Rest;
using Remora.Discord.API.Gateway.Commands;
using Remora.Discord.Gateway;
using Remora.Rest.Core;
using Remora.Results;

namespace Catalogger.Backend.Bot;

// This class is based on VelvetToroyashi/RemoraShardHelper, licensed under the Apache 2.0 license:
// https://github.com/VelvetToroyashi/RemoraShardHelper

/// <summary>
/// Runs one <see cref="DiscordGatewayClient"/> per shard and maps guild IDs to the shard client
/// responsible for them.
/// </summary>
public class ShardedGatewayClient(
    ILogger logger,
    IDiscordRestGatewayAPI gatewayApi,
    IServiceProvider services,
    IOptions<DiscordGatewayClientOptions> gatewayClientOptions,
    Config config
) : IDisposable
{
    // A value <= 0 means "not configured"; RunAsync then replaces it with the count Discord
    // recommends (or 1 if Discord gives none).
    private int _shardCount = config.Discord.ShardCount ?? 0;
    private readonly ILogger _logger = logger.ForContext<ShardedGatewayClient>();
    private readonly ConcurrentDictionary<int, DiscordGatewayClient> _gatewayClients = new();

    // The gateway client's connection status is read from a non-public field via reflection.
    // This depends on the private field name "_connectionStatus" in DiscordGatewayClient and
    // will fail at type initialization if that field is renamed.
    private static readonly FieldInfo Field = typeof(DiscordGatewayClient).GetField(
        "_connectionStatus",
        BindingFlags.Instance | BindingFlags.NonPublic
    )!;

    private static readonly Func<
        DiscordGatewayClient,
        GatewayConnectionStatus
    > GetConnectionStatus = client => (GatewayConnectionStatus)Field.GetValue(client)!;

    /// <summary>
    /// The shard clients created by <see cref="RunAsync"/>, keyed by shard ID.
    /// Empty until <see cref="RunAsync"/> has been called.
    /// </summary>
    public IReadOnlyDictionary<int, DiscordGatewayClient> Shards => _gatewayClients;

    /// <summary>
    /// Fetches the gateway bot info, creates one gateway client per shard and starts the
    /// clients one at a time. Each shard must reach <see cref="GatewayConnectionStatus.Connected"/>
    /// (or finish running) before the next shard is started.
    /// </summary>
    /// <param name="ct">Cancellation token passed to the REST call, the shard clients and the startup polling.</param>
    /// <returns>
    /// The error from the gateway REST call; the failing result of a shard that stopped
    /// unsuccessfully during startup; otherwise the result of the first shard whose run completes.
    /// </returns>
    public async Task<Result> RunAsync(CancellationToken ct = default)
    {
        var gatewayResult = await gatewayApi.GetGatewayBotAsync(ct);
        if (!gatewayResult.IsSuccess)
        {
            _logger.Error("Failed to retrieve gateway endpoint: {Error}", gatewayResult.Error);
            return (Result)gatewayResult;
        }

        if (gatewayResult.Entity.Shards.IsDefined(out var discordShardCount))
        {
            if (_shardCount > 0 && _shardCount < discordShardCount)
            {
                _logger.Warning(
                    "Discord recommends {DiscordShardCount} shards for this bot, but only {ConfigShardCount} shards are requested. This may cause issues later",
                    discordShardCount,
                    _shardCount
                );
            }

            if (_shardCount <= 0)
                _shardCount = discordShardCount;
        }

        // Without any shards, Task.WhenAny below would throw on an empty list; fall back to
        // a single shard instead.
        if (_shardCount <= 0)
        {
            _logger.Warning(
                "No shard count was configured or recommended by Discord, using a single shard"
            );
            _shardCount = 1;
        }

        var clients = new DiscordGatewayClient[_shardCount];
        for (var shardId = 0; shardId < clients.Length; shardId++)
        {
            var client = ActivatorUtilities.CreateInstance<DiscordGatewayClient>(
                services,
                CloneOptions(gatewayClientOptions.Value, shardId)
            );
            clients[shardId] = client;
            _gatewayClients[shardId] = client;
        }

        var tasks = new List<Task<Result>>(clients.Length);
        for (var shardIndex = 0; shardIndex < clients.Length; shardIndex++)
        {
            _logger.Debug("Starting shard {ShardId}/{ShardCount}", shardIndex, _shardCount);

            var client = clients[shardIndex];
            var res = client.RunAsync(ct);
            tasks.Add(res);

            // Wait until this shard has connected (or stopped) before starting the next one.
            while (
                GetConnectionStatus(client) is not GatewayConnectionStatus.Connected
                && !res.IsCompleted
            )
            {
                await Task.Delay(100, ct);
            }

            if (res.IsCompleted)
            {
                // Awaiting (rather than reading .Result) surfaces a faulted task's exception
                // directly instead of wrapping it in an AggregateException.
                var startResult = await res;
                if (!startResult.IsSuccess)
                    return startResult;
            }

            _logger.Information("Started shard {ShardId}/{ShardCount}", shardIndex, _shardCount);
        }

        return await await Task.WhenAny(tasks);
    }

    /// <summary>
    /// Computes the shard ID for a guild as <c>(guildId &gt;&gt; 22) % shardCount</c>.
    /// Divides by the current shard count, so it throws if the count is still 0 (not
    /// configured and <see cref="RunAsync"/> not yet called).
    /// </summary>
    public int ShardIdFor(ulong guildId) => (int)((guildId >> 22) % (ulong)_shardCount);

    /// <summary>Returns the shard client responsible for the given guild.</summary>
    public DiscordGatewayClient ClientFor(Snowflake guildId) => ClientFor(guildId.Value);

    /// <summary>Returns the shard client responsible for the given guild.</summary>
    /// <exception cref="CataloggerError">No client exists yet for the guild's shard.</exception>
    public DiscordGatewayClient ClientFor(ulong guildId)
    {
        // Check the shard count first: ShardIdFor would otherwise throw DivideByZeroException
        // and hide the real problem (RunAsync not called yet).
        if (_shardCount <= 0 || !_gatewayClients.TryGetValue(ShardIdFor(guildId), out var client))
        {
            throw new CataloggerError(
                "Shard was null, has ShardedGatewayClient.RunAsync been called?"
            );
        }

        return client;
    }

    /// <summary>Disposes every shard client created by <see cref="RunAsync"/>.</summary>
    public void Dispose()
    {
        foreach (var client in _gatewayClients.Values)
            client.Dispose();
    }

    /// <summary>
    /// Copies the shared gateway options, setting <see cref="ShardIdentification"/> to this
    /// shard's ID and the current shard count.
    /// </summary>
    private IOptions<DiscordGatewayClientOptions> CloneOptions(
        DiscordGatewayClientOptions options,
        int shardId
    )
    {
        var ret = new DiscordGatewayClientOptions
        {
            ShardIdentification = new ShardIdentification(shardId, _shardCount),
            Intents = options.Intents,
            Presence = options.Presence,
            ConnectionProperties = options.ConnectionProperties,
            HeartbeatHeadroom = options.HeartbeatHeadroom,
            LargeThreshold = options.LargeThreshold,
            CommandBurstRate = options.CommandBurstRate,
            HeartbeatSafetyMargin = options.HeartbeatSafetyMargin,
            MinimumSafetyMargin = options.MinimumSafetyMargin,
        };

        return Options.Create(ret);
    }
}