// Catalogger.NET/Catalogger.Backend/Bot/ShardedGatewayClient.cs

// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
using System.Collections.Concurrent;
using System.Reflection;
using Microsoft.Extensions.Options;
using Remora.Discord.API.Abstractions.Rest;
using Remora.Discord.API.Gateway.Commands;
using Remora.Discord.Gateway;
using Remora.Rest.Core;
using Remora.Results;
namespace Catalogger.Backend.Bot;
// This class is based on VelvetToroyashi/RemoraShardHelper, licensed under the Apache 2.0 license:
// https://github.com/VelvetToroyashi/RemoraShardHelper
public class ShardedGatewayClient(
    ILogger logger,
    IDiscordRestGatewayAPI gatewayApi,
    IServiceProvider services,
    IOptions<DiscordGatewayClientOptions> gatewayClientOptions,
    Config config
) : IDisposable
{
    /// <summary>
    /// The total number of shards this client runs. Initialised from <c>config.Discord.ShardCount</c>
    /// (0 when unset). If it is still 0 when <see cref="RunAsync"/> runs, it is replaced by the count
    /// Discord recommends, or by 1 if Discord does not return a count.
    /// </summary>
    public int TotalShards { get; private set; } = config.Discord.ShardCount ?? 0;

    private readonly ILogger _logger = logger.ForContext<ShardedGatewayClient>();
    private readonly ConcurrentDictionary<int, DiscordGatewayClient> _gatewayClients = new();

    // The connection status is read via reflection from DiscordGatewayClient's non-public
    // "_connectionStatus" field. NOTE(review): if a Remora update renames that field, Field will be
    // null and GetConnectionStatus will throw at runtime.
    private static readonly FieldInfo Field = typeof(DiscordGatewayClient).GetField(
        "_connectionStatus",
        BindingFlags.Instance | BindingFlags.NonPublic
    )!;

    private static readonly Func<
        DiscordGatewayClient,
        GatewayConnectionStatus
    > GetConnectionStatus = client => (GatewayConnectionStatus)Field.GetValue(client)!;

    /// <summary>Returns whether <paramref name="client"/> currently reports a connected status.</summary>
    public static bool IsConnected(DiscordGatewayClient client) =>
        GetConnectionStatus(client) == GatewayConnectionStatus.Connected;

    /// <summary>The shard clients created by <see cref="RunAsync"/>, keyed by shard ID.</summary>
    public IReadOnlyDictionary<int, DiscordGatewayClient> Shards => _gatewayClients;

    /// <summary>
    /// Creates one <see cref="DiscordGatewayClient"/> per shard and starts them one after another,
    /// waiting for each to connect before starting the next. It then runs until any shard's run loop
    /// finishes. Every shard client is disposed when this method exits, whether it returns or throws.
    /// </summary>
    /// <param name="ct">
    /// Passed to the gateway lookup, to every shard's run loop, and to the startup polling delay.
    /// </param>
    /// <returns>
    /// The failed gateway lookup result, the failure of a shard that stopped during startup, or the
    /// result of the first shard to stop once all shards are running.
    /// </returns>
    public async Task<Result> RunAsync(CancellationToken ct = default)
    {
        var gatewayResult = await gatewayApi.GetGatewayBotAsync(ct);
        if (!gatewayResult.IsSuccess)
        {
            _logger.Error("Failed to retrieve gateway endpoint: {Error}", gatewayResult.Error);
            return (Result)gatewayResult;
        }

        if (gatewayResult.Entity.Shards.IsDefined(out var discordShardCount))
        {
            if (TotalShards < discordShardCount && TotalShards != 0)
                _logger.Warning(
                    "Discord recommends {DiscordShardCount} shards for this bot, but only {ConfigShardCount} shards are requested. This may cause issues later",
                    discordShardCount,
                    TotalShards
                );

            if (TotalShards == 0)
                TotalShards = discordShardCount;
        }

        // Neither the config nor Discord supplied a shard count. Run a single shard rather than zero;
        // zero shards would leave Task.WhenAny below with an empty list, which throws.
        if (TotalShards == 0)
            TotalShards = 1;

        var tasks = new List<Task<Result>>(TotalShards);
        try
        {
            var clients = new DiscordGatewayClient[TotalShards];
            for (var shardId = 0; shardId < clients.Length; shardId++)
            {
                var client = ActivatorUtilities.CreateInstance<DiscordGatewayClient>(
                    services,
                    CloneOptions(gatewayClientOptions.Value, shardId)
                );
                _gatewayClients[shardId] = client;
                clients[shardId] = client;
            }

            for (var shardIndex = 0; shardIndex < clients.Length; shardIndex++)
            {
                _logger.Debug("Starting shard {ShardId}/{ShardCount}", shardIndex, TotalShards);

                var client = clients[shardIndex];
                var res = client.RunAsync(ct);
                tasks.Add(res);

                // Start shards sequentially: poll until this shard is connected, or until its run loop
                // has already exited, before moving on to the next one.
                while (
                    GetConnectionStatus(client) is not GatewayConnectionStatus.Connected
                    && !res.IsCompleted
                )
                {
                    await Task.Delay(100, ct);
                }

                if (res.IsCompleted)
                {
                    // Await instead of reading .Result so a faulted task rethrows its own exception
                    // rather than an AggregateException.
                    var startResult = await res;
                    if (!startResult.IsSuccess)
                        return startResult;
                }

                _logger.Information("Started shard {ShardId}/{ShardCount}", shardIndex, TotalShards);
            }

            // Run until the first shard stops.
            return await await Task.WhenAny(tasks);
        }
        finally
        {
            // Dispose every shard on all exit paths (normal return, startup failure, cancellation,
            // or an exception) so no started shard is left running.
            Disconnect();
        }
    }

    /// <summary>
    /// Computes the shard a guild belongs to with Discord's formula
    /// <c>(guild_id &gt;&gt; 22) % num_shards</c>.
    /// </summary>
    /// <exception cref="CataloggerError">
    /// The shard count is not known yet: none was configured and <see cref="RunAsync"/> has not run.
    /// </exception>
    public int ShardIdFor(ulong guildId) =>
        TotalShards > 0
            ? (int)((guildId >> 22) % (ulong)TotalShards)
            : throw new CataloggerError(
                "Shard count is not known yet, has ShardedGatewayClient.RunAsync been called?"
            );

    /// <summary>Returns the shard client responsible for the given guild.</summary>
    public DiscordGatewayClient ClientFor(Snowflake guildId) => ClientFor(guildId.Value);

    /// <summary>Returns the shard client responsible for the given guild.</summary>
    /// <exception cref="CataloggerError">The shard client has not been created yet.</exception>
    public DiscordGatewayClient ClientFor(ulong guildId) =>
        _gatewayClients.TryGetValue(ShardIdFor(guildId), out var client)
            ? client
            : throw new CataloggerError(
                "Shard was null, has ShardedGatewayClient.RunAsync been called?"
            );

    /// <summary>Disposes all shard clients that are still registered.</summary>
    public void Dispose()
    {
        GC.SuppressFinalize(this);

        // Remove each client before disposing it. TryRemove hands a given client to only one
        // caller, so a client is not disposed twice even if Dispose runs more than once or
        // alongside Disconnect.
        foreach (var shardId in _gatewayClients.Keys)
        {
            if (_gatewayClients.TryRemove(shardId, out var client))
                client.Dispose();
        }
    }

    /// <summary>Removes and disposes every registered shard client, logging each one.</summary>
    private void Disconnect()
    {
        _logger.Information("Disconnecting from Discord");
        foreach (var shardId in _gatewayClients.Keys)
        {
            _logger.Debug("Disposing shard {shardId}", shardId);
            if (_gatewayClients.TryRemove(shardId, out var client))
                client.Dispose();
        }
    }

    /// <summary>
    /// Copies <paramref name="options"/> into a new options instance whose
    /// <c>ShardIdentification</c> is set to <paramref name="shardId"/> out of <see cref="TotalShards"/>.
    /// </summary>
    // NOTE(review): only the properties listed here are copied. Any option added to
    // DiscordGatewayClientOptions in a future Remora release must be added here too.
    private IOptions<DiscordGatewayClientOptions> CloneOptions(
        DiscordGatewayClientOptions options,
        int shardId
    )
    {
        var ret = new DiscordGatewayClientOptions
        {
            ShardIdentification = new ShardIdentification(shardId, TotalShards),
            Intents = options.Intents,
            Presence = options.Presence,
            ConnectionProperties = options.ConnectionProperties,
            HeartbeatHeadroom = options.HeartbeatHeadroom,
            LargeThreshold = options.LargeThreshold,
            CommandBurstRate = options.CommandBurstRate,
            HeartbeatSafetyMargin = options.HeartbeatSafetyMargin,
            MinimumSafetyMargin = options.MinimumSafetyMargin,
        };
        return Options.Create(ret);
    }
}