foxcord/Foxcord/Gateway/DiscordGatewayClient.cs

253 lines
No EOL
9.6 KiB
C#

using System.Buffers;
using System.Diagnostics.CodeAnalysis;
using System.Net.WebSockets;
using System.Text.Json;
using Foxcord.Gateway.Events;
using Foxcord.Gateway.Events.Commands;
using Foxcord.Serialization;
using Serilog;
namespace Foxcord.Gateway;
/// <summary>
/// Client for the Discord gateway WebSocket (gateway version <see cref="GatewayVersion"/>, JSON encoding).
/// This part of the partial class handles connecting, the HELLO/IDENTIFY handshake, heartbeating and
/// reading/writing raw gateway packets. The handlers it calls (<c>HandleDispatch</c>,
/// <c>HandleHeartbeatRequest</c>, <c>HandleHeartbeatAck</c>, <c>ParseDispatchEvent</c>) are defined elsewhere,
/// presumably in another part of this partial class.
/// </summary>
public partial class DiscordGatewayClient
{
    private const int GatewayVersion = 10;
    private const int BufferSize = 64 * 1024;

    private readonly ILogger _logger;
    private ClientWebSocket? _ws;
    private readonly string _token;
    private readonly Uri _gatewayUri;
    private readonly GatewayIntent _intents;
    private readonly IdentifyProperties? _properties;
    private readonly int[]? _shardInfo;
    private readonly PresenceUpdateCommand? _initialPresence;
    private readonly JsonSerializerOptions _jsonSerializerOptions = JsonSerializerExtensions.CreateSerializer();

    // ClientWebSocket allows only one outstanding SendAsync at a time. The heartbeat loop, gateway heartbeat
    // requests, IDENTIFY and public SendCommandAsync callers can all send concurrently, so sends are serialized.
    private readonly SemaphoreSlim _sendLock = new(1, 1);

    private long? _lastSequence;
    private DateTimeOffset _lastHeartbeatSend = DateTimeOffset.UnixEpoch;
    private DateTimeOffset _lastHeartbeatAck = DateTimeOffset.UnixEpoch;

    /// <summary>
    /// Time between the most recent heartbeat send and the most recent heartbeat acknowledgement.
    /// While a heartbeat is still awaiting its acknowledgement this value is negative.
    /// </summary>
    public TimeSpan Latency => _lastHeartbeatAck - _lastHeartbeatSend;

    /// <summary>Creates a client; no connection is made until <see cref="ConnectAsync"/> is called.</summary>
    /// <param name="logger">Serilog logger; a context for this type is derived from it.</param>
    /// <param name="opts">Token, gateway URI, intents and optional identify/shard/presence settings.</param>
    public DiscordGatewayClient(ILogger logger, DiscordGatewayClientOptions opts)
    {
        _logger = logger.ForContext<DiscordGatewayClient>();
        _token = $"Bot {opts.Token}";
        // Any query string on the configured URI is replaced by the version/encoding parameters.
        var uriBuilder = new UriBuilder(opts.Uri)
        {
            Query = $"v={GatewayVersion}&encoding=json"
        };
        _gatewayUri = uriBuilder.Uri;
        _intents = opts.Intents;
        _properties = opts.IdentifyProperties;
        _shardInfo = opts.Shards;
        _initialPresence = opts.InitialPresence;
    }

    /// <summary>Current connection state of this client.</summary>
    public ConnectionStatus Status { get; private set; } = ConnectionStatus.Dead;

    /// <summary>
    /// Connects to the gateway. This method returns after a connection is established.
    /// <c>ct</c> is stored by the client and cancelling it will close the connection.
    /// The caller must pause indefinitely or the bot will shut down immediately.
    /// </summary>
    /// <param name="ct">Token that controls the whole lifetime of the connection, not just this call.</param>
    /// <exception cref="DiscordGatewayRequestError">
    /// A connection is already in progress/established, or the gateway did not open with a valid HELLO.
    /// </exception>
    public async Task ConnectAsync(CancellationToken ct = default)
    {
        if (Status != ConnectionStatus.Dead)
            throw new DiscordGatewayRequestError(
                "Gateway is connecting or connected, only one concurrent connection allowed.");

        try
        {
            _ws = new ClientWebSocket();
            Status = ConnectionStatus.Connecting;
            await _ws.ConnectAsync(_gatewayUri, ct);

            // The first packet must be HELLO; give up if it has not arrived within 30 seconds.
            using var helloCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
            helloCts.CancelAfter(TimeSpan.FromSeconds(30));
            var (rawHelloPacketType, rawHelloPacket) = await ReadPacketAsync(helloCts.Token);
            if (rawHelloPacketType == WebSocketMessageType.Close)
                throw new DiscordGatewayRequestError("First packet received was a close message");
            if (!TryDeserializeEvent(rawHelloPacket, out var rawHelloEvent))
                throw new DiscordGatewayRequestError("First packet was not a valid event");
            if (rawHelloEvent is not HelloEvent hello)
                throw new DiscordGatewayRequestError("First event was not a HELLO event");

            _logger.Debug("Received HELLO, heartbeat interval is {HeartbeatInterval}", hello.HeartbeatInterval);

            // Both loops run for the lifetime of the connection and handle/log their own failures.
            _ = HeartbeatLoopAsync(hello.HeartbeatInterval, ct);
            _ = ReceiveLoopAsync(ct);

            _logger.Debug("Sending IDENTIFY");
            await WritePacketAsync(new GatewayPacket
            {
                Opcode = GatewayOpcode.Identify,
                Payload = new IdentifyEvent
                {
                    Token = _token,
                    Intents = _intents,
                    Properties = _properties ?? new IdentifyProperties(),
                    Shards = _shardInfo,
                    Presence = _initialPresence?.ToPayload()
                }
            }, ct);
        }
        catch (Exception e)
        {
            _logger.Error(e, "Error connecting to gateway");
            Status = ConnectionStatus.Dead;
            // Release the failed socket instead of just dropping the reference; any loops already
            // started will observe the non-open socket and stop.
            var failedSocket = _ws;
            _ws = null;
            failedSocket?.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Sends heartbeats for as long as <paramref name="ct"/> is not cancelled and sends succeed.
    /// The first heartbeat is delayed by <c>heartbeatInterval * random[0, 1)</c> milliseconds (jitter);
    /// subsequent heartbeats follow every <paramref name="heartbeatInterval"/> milliseconds.
    /// </summary>
    /// <param name="heartbeatInterval">Interval in milliseconds, as received in HELLO.</param>
    /// <param name="ct">Connection lifetime token.</param>
    private async Task HeartbeatLoopAsync(int heartbeatInterval, CancellationToken ct = default)
    {
        var interval = TimeSpan.FromMilliseconds(heartbeatInterval);
        var delay = TimeSpan.FromMilliseconds(heartbeatInterval * Random.Shared.NextDouble());
        _logger.Debug("Waiting {Delay} before sending first heartbeat", delay);

        try
        {
            // The original code logged this delay but never waited for it.
            await Task.Delay(delay, ct);

            using var timer = new PeriodicTimer(interval);
            do
            {
                _logger.Debug("Sending heartbeat with sequence {Sequence}", _lastSequence);
                _lastHeartbeatSend = DateTimeOffset.UtcNow;
                await SendCommandAsync(new HeartbeatCommand(_lastSequence), ct);
            } while (await timer.WaitForNextTickAsync(ct));
        }
        catch (OperationCanceledException) when (ct.IsCancellationRequested)
        {
            // Normal shutdown: the connection token was cancelled.
        }
        catch (Exception e)
        {
            // This task is fire-and-forget; log instead of leaving the failure unobserved.
            _logger.Error(e, "Heartbeat loop stopped");
        }
    }

    /// <summary>
    /// Reads packets until a Close frame arrives, <paramref name="ct"/> is cancelled, or the socket stops
    /// being open. Each packet is handed to a fire-and-forget handler.
    /// </summary>
    private async Task ReceiveLoopAsync(CancellationToken ct = default)
    {
        while (!ct.IsCancellationRequested)
        {
            try
            {
                var (type, packet) = await ReadPacketAsync(ct);
                if (type == WebSocketMessageType.Close || packet == null)
                {
                    // TODO: close websocket
                    return;
                }

                _ = ReceiveAsync(packet, ct);
            }
            catch (OperationCanceledException) when (ct.IsCancellationRequested)
            {
                return;
            }
            catch (Exception e)
            {
                _logger.Error(e, "Error while receiving data");
                // A socket that is aborted, closed or gone fails every further read; stop instead of spinning.
                if (_ws is not { State: WebSocketState.Open })
                    return;
            }
        }

        // Handles one packet. TryDeserializeEvent runs synchronously before the first await, so sequence
        // numbers are recorded in receive order even though handling itself is fire-and-forget.
        async Task ReceiveAsync(GatewayPacket packet, CancellationToken ct2 = default)
        {
            try
            {
                if (!TryDeserializeEvent(packet, out var gatewayEvent))
                {
                    _logger.Debug("Event {EventType} didn't have payload", packet.Opcode);
                    return;
                }

                switch (gatewayEvent)
                {
                    case HeartbeatEvent:
                        await HandleHeartbeatRequest(ct2);
                        break;
                    case HeartbeatAckEvent:
                        HandleHeartbeatAck();
                        break;
                    case DispatchEvent dispatch:
                        await HandleDispatch(dispatch.Payload, ct2);
                        break;
                }
            }
            catch (Exception e)
            {
                // Nobody awaits this task, so failures must be logged here to be seen at all.
                _logger.Error(e, "Error while handling {Opcode} packet", packet.Opcode);
            }
        }
    }

    /// <summary>Serializes <paramref name="command"/> to a gateway packet and sends it.</summary>
    /// <exception cref="InvalidOperationException">The client is not connected.</exception>
    public async ValueTask SendCommandAsync(IGatewayCommand command, CancellationToken ct = default) =>
        await WritePacketAsync(command.ToGatewayPacket(), ct);

    /// <summary>Sends one packet as a single JSON text message; concurrent callers are serialized.</summary>
    /// <exception cref="InvalidOperationException">The client is not connected.</exception>
    private async ValueTask WritePacketAsync(GatewayPacket packet, CancellationToken ct = default)
    {
        var ws = _ws ?? throw new InvalidOperationException("Not connected to the Discord gateway.");
        var json = JsonSerializer.SerializeToUtf8Bytes(packet, _jsonSerializerOptions);

        await _sendLock.WaitAsync(ct);
        try
        {
            await ws.SendAsync(json.AsMemory(), WebSocketMessageType.Text, true, ct);
        }
        finally
        {
            _sendLock.Release();
        }
    }

    /// <summary>
    /// Reads one complete WebSocket message and deserializes it as a <see cref="GatewayPacket"/>.
    /// Returns a null packet when the message is a Close frame.
    /// </summary>
    private async ValueTask<(WebSocketMessageType type, GatewayPacket? packet)> ReadPacketAsync(
        CancellationToken ct = default)
    {
        using var buf = MemoryPool<byte>.Shared.Rent(BufferSize);
        var res = await _ws!.ReceiveAsync(buf.Memory, ct);
        if (res.MessageType == WebSocketMessageType.Close)
            return (res.MessageType, null);
        if (res.EndOfMessage)
            return DeserializePacket(res, buf.Memory.Span[..res.Count]);
        return await DeserializeMultipleBufferAsync(res, buf, ct);
    }

    /// <summary>
    /// Continues reading a message whose first fragment is already in <paramref name="buf"/>,
    /// accumulating fragments until the end of the message, then deserializes the whole payload.
    /// Returns a null packet if a Close frame arrives before the message is complete.
    /// </summary>
    private async Task<(WebSocketMessageType type, GatewayPacket? packet)> DeserializeMultipleBufferAsync(
        ValueWebSocketReceiveResult res, IMemoryOwner<byte> buf, CancellationToken ct = default)
    {
        await using var stream = new MemoryStream(BufferSize * 4);
        stream.Write(buf.Memory.Span[..res.Count]);
        while (!res.EndOfMessage)
        {
            res = await _ws!.ReceiveAsync(buf.Memory, ct);
            if (res.MessageType == WebSocketMessageType.Close)
                return (WebSocketMessageType.Close, null);
            stream.Write(buf.Memory.Span[..res.Count]);
        }

        return DeserializePacket(res, stream.GetBuffer().AsSpan(0, (int)stream.Length));
    }

    /// <summary>Deserializes UTF-8 JSON in <paramref name="span"/>, paired with the message type of <paramref name="res"/>.</summary>
    private (WebSocketMessageType type, GatewayPacket packet) DeserializePacket(
        ValueWebSocketReceiveResult res, Span<byte> span) => (res.MessageType,
        JsonSerializer.Deserialize<GatewayPacket>(span, _jsonSerializerOptions)!);

    /// <summary>
    /// Maps a packet's opcode to a typed gateway event. For Dispatch packets this also records the
    /// packet's sequence number in <c>_lastSequence</c>.
    /// </summary>
    /// <returns><c>false</c> only when <paramref name="packet"/> is null.</returns>
    /// <exception cref="NotImplementedException">Reconnect and InvalidSession are not handled yet.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The opcode is not one the client receives.</exception>
    private bool TryDeserializeEvent(GatewayPacket? packet, [NotNullWhen(true)] out IGatewayEvent? gatewayEvent)
    {
        gatewayEvent = null;
        if (packet == null) return false;

        // Packets without a JSON payload are handled as an empty (default) element.
        var payload = packet.Payload is JsonElement element ? element : default;
        switch (packet.Opcode)
        {
            case GatewayOpcode.Hello:
                gatewayEvent = payload.Deserialize<HelloEvent>(_jsonSerializerOptions)!;
                break;
            case GatewayOpcode.Dispatch:
                _lastSequence = packet.Sequence;
                gatewayEvent = new DispatchEvent { Payload = ParseDispatchEvent(packet.EventType!, payload) };
                break;
            case GatewayOpcode.Heartbeat:
                gatewayEvent = new HeartbeatEvent();
                break;
            case GatewayOpcode.Reconnect:
            case GatewayOpcode.InvalidSession:
                throw new NotImplementedException($"Gateway opcode {packet.Opcode} is not handled yet.");
            case GatewayOpcode.HeartbeatAck:
                gatewayEvent = new HeartbeatAckEvent();
                break;
            default:
                throw new ArgumentOutOfRangeException(nameof(packet), packet.Opcode, "Unexpected gateway opcode.");
        }

        return true;
    }

    /// <summary>Lifecycle states of a <see cref="DiscordGatewayClient"/> connection.</summary>
    public enum ConnectionStatus
    {
        /// <summary>No connection; <see cref="ConnectAsync"/> may be called.</summary>
        Dead,
        /// <summary>A connection attempt is in progress.</summary>
        Connecting,
        /// <summary>Connected.</summary>
        Connected
    }
}
/// <summary>
/// Settings used to construct a <see cref="DiscordGatewayClient"/>.
/// </summary>
public class DiscordGatewayClientOptions
{
    /// <summary>Bot token without the <c>Bot </c> prefix; the client prepends it.</summary>
    public required string Token { get; init; }

    /// <summary>Gateway URI; the client overwrites its query string with version and encoding parameters.</summary>
    public required string Uri { get; init; }

    /// <summary>Intents sent in the IDENTIFY payload.</summary>
    public required GatewayIntent Intents { get; init; }

    /// <summary>Connection properties for IDENTIFY; when null the client sends a default instance.</summary>
    public IdentifyProperties? IdentifyProperties { get; init; }

    /// <summary>Shard id; only used when <see cref="ShardCount"/> is also set.</summary>
    public int? ShardId { get; init; }

    /// <summary>Total shard count; only used when <see cref="ShardId"/> is also set.</summary>
    public int? ShardCount { get; init; }

    /// <summary>Presence sent along with IDENTIFY, if any.</summary>
    public PresenceUpdateCommand? InitialPresence { get; init; }

    /// <summary>
    /// The two-element <c>[ShardId, ShardCount]</c> array, or <c>null</c> unless both values are set.
    /// </summary>
    internal int[]? Shards
    {
        get
        {
            if (ShardId is not { } shardId || ShardCount is not { } shardCount)
                return null;

            return new[] { shardId, shardCount };
        }
    }
}