using System.Buffers;
using System.Diagnostics.CodeAnalysis;
using System.Net.WebSockets;
using System.Text.Json;
using Foxcord.Gateway.Events;
using Foxcord.Gateway.Events.Commands;
using Foxcord.Serialization;
using Serilog;

namespace Foxcord.Gateway;

/// <summary>
/// WebSocket client for the gateway: connects, performs the HELLO/IDENTIFY handshake,
/// then runs background heartbeat and receive loops.
/// </summary>
public partial class DiscordGatewayClient
{
    private const int GatewayVersion = 10;

    /// <summary>Size in bytes of the pooled buffer used for each WebSocket receive call.</summary>
    private const int BufferSize = 64 * 1024;

    private readonly ILogger _logger;
    private ClientWebSocket? _ws;

    private readonly string _token;
    private readonly Uri _gatewayUri;
    private readonly GatewayIntent _intents;
    private readonly IdentifyProperties? _properties;
    private readonly int[]? _shardInfo;
    private readonly PresenceUpdateCommand? _initialPresence;

    private readonly JsonSerializerOptions _jsonSerializerOptions = JsonSerializerExtensions.CreateSerializer();

    // Sequence number of the most recent dispatch packet; echoed back in heartbeats.
    private long? _lastSequence;
    private DateTimeOffset _lastHeartbeatSend = DateTimeOffset.UnixEpoch;

    // Not written in this part of the class; presumably updated by the heartbeat ACK handler
    // in another part of this partial class.
    private DateTimeOffset _lastHeartbeatAck = DateTimeOffset.UnixEpoch;

    /// <summary>
    /// Difference between the last recorded heartbeat ACK time and the last heartbeat send time.
    /// </summary>
    public TimeSpan Latency => _lastHeartbeatAck - _lastHeartbeatSend;

    /// <summary>
    /// Creates a client from <paramref name="opts"/>. No connection is opened until
    /// <see cref="ConnectAsync"/> is called.
    /// </summary>
    /// <param name="logger">Logger; a context-specific logger for this class is derived from it.</param>
    /// <param name="opts">Token, gateway URI, intents and optional identify/shard/presence settings.</param>
    public DiscordGatewayClient(ILogger logger, DiscordGatewayClientOptions opts)
    {
        _logger = logger.ForContext<DiscordGatewayClient>();
        _token = $"Bot {opts.Token}";

        var uriBuilder = new UriBuilder(opts.Uri)
        {
            Query = $"v={GatewayVersion}&encoding=json"
        };
        _gatewayUri = uriBuilder.Uri;

        _intents = opts.Intents;
        _properties = opts.IdentifyProperties;
        _shardInfo = opts.Shards;
        _initialPresence = opts.InitialPresence;
    }

    /// <summary>Current connection state. Starts as <see cref="ConnectionStatus.Dead"/>.</summary>
    public ConnectionStatus Status { get; private set; } = ConnectionStatus.Dead;

    /// <summary>
    /// Connects to the gateway. This method returns once the HELLO packet has been received,
    /// the background loops have been started and IDENTIFY has been sent.
    /// <paramref name="ct"/> is handed to the background heartbeat and receive loops, so
    /// cancelling it stops them.
    /// The caller must keep the process alive after this method returns; it does not block
    /// for the lifetime of the connection.
    /// </summary>
    /// <param name="ct">Token governing the connection attempt and the background loops.</param>
    /// <exception cref="DiscordGatewayRequestError">
    /// Thrown when a connection is already in progress or established, or when the first
    /// packet is a close message, is not a valid event, or is not a HELLO event.
    /// </exception>
    public async Task ConnectAsync(CancellationToken ct = default)
    {
        if (Status != ConnectionStatus.Dead)
        {
            throw new DiscordGatewayRequestError(
                "Gateway is connecting or connected, only one concurrent connection allowed.");
        }

        try
        {
            _ws = new ClientWebSocket();
            Status = ConnectionStatus.Connecting;
            await _ws.ConnectAsync(_gatewayUri, ct);

            // Give the server at most 30 seconds to send HELLO.
            using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
            cts.CancelAfter(TimeSpan.FromSeconds(30));

            var (rawHelloPacketType, rawHelloPacket) = await ReadPacketAsync(cts.Token);
            if (rawHelloPacketType == WebSocketMessageType.Close)
            {
                throw new DiscordGatewayRequestError("First packet received was a close message");
            }

            if (!TryDeserializeEvent(rawHelloPacket, out var rawHelloEvent))
            {
                throw new DiscordGatewayRequestError("First packet was not a valid event");
            }

            if (rawHelloEvent is not HelloEvent hello)
            {
                throw new DiscordGatewayRequestError("First event was not a HELLO event");
            }

            _logger.Debug("Received HELLO, heartbeat interval is {HeartbeatInterval}", hello.HeartbeatInterval);

            // Both loops run in the background for the lifetime of the connection.
            _ = HeartbeatLoopAsync(hello.HeartbeatInterval, ct);
            _ = ReceiveLoopAsync(ct);

            _logger.Debug("Sending IDENTIFY");
            await WritePacketAsync(new GatewayPacket
            {
                Opcode = GatewayOpcode.Identify,
                Payload = new IdentifyEvent
                {
                    Token = _token,
                    Intents = _intents,
                    Properties = _properties ?? new IdentifyProperties(),
                    Shards = _shardInfo,
                    Presence = _initialPresence?.ToPayload()
                }
            }, ct);
        }
        catch (Exception e)
        {
            _logger.Error(e, "Error connecting to gateway");
            Status = ConnectionStatus.Dead;

            // Release the socket instead of just dropping the reference; disposing also
            // unblocks any background receive that may already be pending on it.
            _ws?.Dispose();
            _ws = null;
            throw;
        }
    }

    /// <summary>
    /// Sends heartbeats until <paramref name="ct"/> is cancelled. The first heartbeat is sent
    /// after a random fraction of the interval (jitter); subsequent ones every
    /// <paramref name="heartbeatInterval"/> milliseconds.
    /// </summary>
    /// <param name="heartbeatInterval">Heartbeat interval in milliseconds, as received in HELLO.</param>
    /// <param name="ct">Cancels the loop.</param>
    private async Task HeartbeatLoopAsync(int heartbeatInterval, CancellationToken ct = default)
    {
        try
        {
            var delay = TimeSpan.FromMilliseconds(heartbeatInterval * Random.Shared.NextDouble());
            _logger.Debug("Waiting {Delay} before sending first heartbeat", delay);

            // Actually wait for the jittered delay that is logged above, then send the first beat.
            await Task.Delay(delay, ct);
            await SendHeartbeatAsync();

            using var timer = new PeriodicTimer(TimeSpan.FromMilliseconds(heartbeatInterval));
            while (await timer.WaitForNextTickAsync(ct))
            {
                await SendHeartbeatAsync();
            }
        }
        catch (Exception e) when (e is not OperationCanceledException)
        {
            // This task is started fire-and-forget, so log before the exception is lost.
            _logger.Error(e, "Error in heartbeat loop");
            throw;
        }

        // Records the send time and sends a heartbeat carrying the last seen sequence number.
        async Task SendHeartbeatAsync()
        {
            _logger.Debug("Sending heartbeat with sequence {Sequence}", _lastSequence);
            _lastHeartbeatSend = DateTimeOffset.UtcNow;
            await SendCommandAsync(new HeartbeatCommand(_lastSequence), ct);
        }
    }

    /// <summary>
    /// Reads packets until <paramref name="ct"/> is cancelled, a close message arrives, or the
    /// socket is no longer open, handing each packet to the matching handler.
    /// </summary>
    /// <param name="ct">Cancels the loop.</param>
    private async Task ReceiveLoopAsync(CancellationToken ct = default)
    {
        while (!ct.IsCancellationRequested)
        {
            try
            {
                var (type, packet) = await ReadPacketAsync(ct);
                if (type == WebSocketMessageType.Close || packet == null)
                {
                    // TODO: close websocket
                    return;
                }

                // Handlers run without being awaited so a slow handler does not stall receiving.
                _ = ReceiveAsync(packet, ct);
            }
            catch (OperationCanceledException) when (ct.IsCancellationRequested)
            {
                // Normal shutdown, not an error.
                return;
            }
            catch (Exception e)
            {
                _logger.Error(e, "Error while receiving data");

                // A socket that is gone or no longer open fails every subsequent receive
                // immediately; stop instead of spinning.
                if (_ws is not { State: WebSocketState.Open })
                {
                    return;
                }
            }
        }

        // Deserializes one packet and dispatches it; failures are logged because nothing
        // awaits the returned task.
        async Task ReceiveAsync(GatewayPacket packet, CancellationToken ct2 = default)
        {
            try
            {
                if (!TryDeserializeEvent(packet, out var gatewayEvent))
                {
                    _logger.Debug("Event {EventType} didn't have payload", packet.Opcode);
                    return;
                }

                switch (gatewayEvent)
                {
                    case HeartbeatEvent:
                        await HandleHeartbeatRequest(ct2);
                        break;
                    case HeartbeatAckEvent:
                        HandleHeartbeatAck();
                        break;
                    case DispatchEvent dispatch:
                        await HandleDispatch(dispatch.Payload, ct2);
                        break;
                }
            }
            catch (OperationCanceledException) when (ct2.IsCancellationRequested)
            {
                // Cancelled during shutdown; nothing to report.
            }
            catch (Exception e)
            {
                _logger.Error(e, "Error while handling event {EventType}", packet.Opcode);
            }
        }
    }

    /// <summary>Converts <paramref name="command"/> to a gateway packet and sends it.</summary>
    /// <param name="command">Command to send.</param>
    /// <param name="ct">Cancels the send.</param>
    /// <exception cref="DiscordGatewayRequestError">Thrown when there is no gateway socket.</exception>
    public async ValueTask SendCommandAsync(IGatewayCommand command, CancellationToken ct = default) =>
        await WritePacketAsync(command.ToGatewayPacket(), ct);

    /// <summary>Serializes <paramref name="packet"/> to UTF-8 JSON and sends it as one text message.</summary>
    /// <exception cref="DiscordGatewayRequestError">Thrown when there is no gateway socket.</exception>
    private async ValueTask WritePacketAsync(GatewayPacket packet, CancellationToken ct = default)
    {
        var ws = _ws ?? throw new DiscordGatewayRequestError("Gateway is not connected");

        // NOTE(review): sends from the heartbeat loop and from callers are not serialized here;
        // confirm whether concurrent SendAsync calls can occur and need a lock.
        var json = JsonSerializer.SerializeToUtf8Bytes(packet, _jsonSerializerOptions);
        await ws.SendAsync(json.AsMemory(), WebSocketMessageType.Text, true, ct);
    }

    /// <summary>
    /// Reads one complete message from the socket and deserializes it.
    /// </summary>
    /// <returns>
    /// The message type and the packet; the packet is <c>null</c> when the message is a close message.
    /// </returns>
    private async ValueTask<(WebSocketMessageType type, GatewayPacket? packet)> ReadPacketAsync(
        CancellationToken ct = default)
    {
        using var buf = MemoryPool<byte>.Shared.Rent(BufferSize);
        var res = await _ws!.ReceiveAsync(buf.Memory, ct);

        if (res.MessageType == WebSocketMessageType.Close)
        {
            return (res.MessageType, null);
        }

        // Fast path: the whole message fit in a single receive.
        if (res.EndOfMessage)
        {
            return DeserializePacket(res, buf.Memory.Span[..res.Count]);
        }

        return await DeserializeMultipleBufferAsync(res, buf, ct);
    }

    /// <summary>
    /// Accumulates the remaining fragments of a message whose first fragment is already in
    /// <paramref name="buf"/>, then deserializes the combined bytes.
    /// </summary>
    /// <param name="res">Result of the receive that produced the first fragment.</param>
    /// <param name="buf">Receive buffer holding the first fragment; reused for later fragments.</param>
    /// <param name="ct">Cancels the remaining receives.</param>
    private async Task<(WebSocketMessageType type, GatewayPacket packet)> DeserializeMultipleBufferAsync(
        ValueWebSocketReceiveResult res, IMemoryOwner<byte> buf, CancellationToken ct = default)
    {
        await using var stream = new MemoryStream(BufferSize * 4);
        stream.Write(buf.Memory.Span[..res.Count]);

        while (!res.EndOfMessage)
        {
            // Honor the caller's token so cancellation and timeouts also cover fragmented messages.
            res = await _ws!.ReceiveAsync(buf.Memory, ct);
            stream.Write(buf.Memory.Span[..res.Count]);
        }

        return DeserializePacket(res, stream.GetBuffer().AsSpan(0, (int)stream.Length));
    }

    /// <summary>Deserializes UTF-8 JSON bytes into a <see cref="GatewayPacket"/>.</summary>
    private (WebSocketMessageType type, GatewayPacket packet) DeserializePacket(
        ValueWebSocketReceiveResult res, ReadOnlySpan<byte> span) =>
        (res.MessageType, JsonSerializer.Deserialize<GatewayPacket>(span, _jsonSerializerOptions)!);

    /// <summary>
    /// Converts a raw packet into a typed gateway event based on its opcode. For dispatch
    /// packets this also records the packet's sequence number.
    /// </summary>
    /// <param name="packet">Packet to convert; <c>null</c> yields <c>false</c>.</param>
    /// <param name="gatewayEvent">The resulting event when the method returns <c>true</c>.</param>
    /// <returns><c>true</c> when an event was produced; <c>false</c> when <paramref name="packet"/> is null.</returns>
    /// <exception cref="NotImplementedException">The opcode is Reconnect or InvalidSession.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The opcode is not one handled here.</exception>
    private bool TryDeserializeEvent(GatewayPacket? packet, [NotNullWhen(true)] out IGatewayEvent? gatewayEvent)
    {
        gatewayEvent = null;
        if (packet == null)
        {
            return false;
        }

        // A payload that is not a JsonElement is treated as an empty (default) element.
        var payload = packet.Payload is JsonElement element ? element : default;

        switch (packet.Opcode)
        {
            case GatewayOpcode.Hello:
                gatewayEvent = payload.Deserialize<HelloEvent>(_jsonSerializerOptions)!;
                break;
            case GatewayOpcode.Dispatch:
                _lastSequence = packet.Sequence;
                gatewayEvent = new DispatchEvent { Payload = ParseDispatchEvent(packet.EventType!, payload) };
                break;
            case GatewayOpcode.Heartbeat:
                gatewayEvent = new HeartbeatEvent();
                break;
            case GatewayOpcode.Reconnect:
            case GatewayOpcode.InvalidSession:
                throw new NotImplementedException();
            case GatewayOpcode.HeartbeatAck:
                gatewayEvent = new HeartbeatAckEvent();
                break;
            default:
                throw new ArgumentOutOfRangeException();
        }

        return true;
    }

    /// <summary>Connection states reported by <see cref="Status"/>.</summary>
    public enum ConnectionStatus
    {
        Dead,
        Connecting,
        Connected
    }
}

/// <summary>Settings used to construct a <see cref="DiscordGatewayClient"/>.</summary>
public class DiscordGatewayClientOptions
{
    /// <summary>Bot token, without the "Bot " prefix (the client adds it).</summary>
    public required string Token { get; init; }

    /// <summary>Gateway base URI; the client sets the version and encoding query string.</summary>
    public required string Uri { get; init; }

    /// <summary>Intents sent in IDENTIFY.</summary>
    public required GatewayIntent Intents { get; init; }

    /// <summary>Identify properties; a default instance is used when null.</summary>
    public IdentifyProperties? IdentifyProperties { get; init; }

    public int? ShardId { get; init; }
    public int? ShardCount { get; init; }

    /// <summary>Presence sent with IDENTIFY, if any.</summary>
    public PresenceUpdateCommand? InitialPresence { get; init; }

    /// <summary>
    /// <c>[ShardId, ShardCount]</c> when both are set; otherwise <c>null</c>.
    /// </summary>
    internal int[]? Shards => ShardId != null && ShardCount != null
        ? [ShardId.Value, ShardCount.Value]
        : null;
}