// Catalogger.NET/Catalogger.Backend/Services/WebhookExecutorService.cs

// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using Catalogger.Backend.Cache;
using Catalogger.Backend.Cache.InMemoryCache;
using Catalogger.Backend.Extensions;
using OneOf;
using Remora.Discord.API;
using Remora.Discord.API.Abstractions.Objects;
using Remora.Discord.API.Abstractions.Rest;
using Remora.Rest.Core;
using Guild = Catalogger.Backend.Database.Models.Guild;
namespace Catalogger.Backend.Services;
/// <summary>
/// Delivers log messages to Discord channels through a webhook owned by this application, and decides
/// which log channel (if any) a given event should be sent to based on the guild's ignore and redirect
/// settings.
/// </summary>
/// <remarks>
/// Embeds queued with <c>QueueLog</c> are buffered per channel and flushed by a timer that is re-armed
/// every time a new embed is queued for that channel, so bursts of events are batched into as few
/// webhook messages as possible. <c>SendLogAsync</c> sends immediately and bypasses the buffer.
/// </remarks>
[SuppressMessage(
    "ReSharper",
    "InconsistentlySynchronizedField",
    Justification = "ILogger doesn't need to be synchronized"
)]
public class WebhookExecutorService(
    Config config,
    ILogger logger,
    IWebhookCache webhookCache,
    ChannelCache channelCache,
    IDiscordRestWebhookAPI webhookApi
)
{
    /// <summary>Maximum combined text length of all embeds in a single message.</summary>
    private const int MaxContentLength = 6000;

    /// <summary>Maximum number of embeds in a single message.</summary>
    private const int MaxEmbedCount = 10;

    /// <summary>
    /// Delay, in milliseconds, between the most recently queued embed and the channel's queue being flushed.
    /// </summary>
    private const int FlushDelayMilliseconds = 3000;

    private readonly ILogger _logger = logger.ForContext<WebhookExecutorService>();
    private readonly Snowflake _applicationId = DiscordSnowflake.New(config.Discord.ApplicationId);

    // Embeds waiting to be sent, keyed by log channel ID.
    private readonly ConcurrentDictionary<ulong, ConcurrentQueue<IEmbed>> _cache = new();

    // One dedicated lock object per log channel. Guards draining that channel's queue and replacing its timer.
    private readonly ConcurrentDictionary<ulong, object> _locks = new();

    // The flush timer currently armed for each log channel.
    private readonly ConcurrentDictionary<ulong, Timer> _timers = new();

    private IUser? _selfUser;

    /// <summary>
    /// Sets the current user for this webhook executor service. This must be called as soon as possible,
    /// before any logs are sent, such as in a READY event.
    /// The user's name and avatar are used as the webhook's username and avatar for every log message.
    /// </summary>
    public void SetSelfUser(IUser user) => _selfUser = user;

    /// <summary>
    /// Queues a log embed for the given log channel type.
    /// If the log channel is already known, use the ulong overload of this method instead.
    /// If the log channel depends on the source channel and source user, also use the ulong overload.
    /// </summary>
    /// <param name="guildConfig">The guild's configuration, used to resolve the log channel.</param>
    /// <param name="logChannelType">The kind of event being logged.</param>
    /// <param name="embed">The embed to queue.</param>
    public void QueueLog(Guild guildConfig, LogChannelType logChannelType, IEmbed embed)
    {
        var logChannel = GetLogChannel(
            guildConfig,
            logChannelType,
            channelId: null,
            userId: null,
            roleId: null,
            roleIds: null
        );
        if (logChannel == null)
            return;

        QueueLog(logChannel.Value, embed);
    }

    /// <summary>
    /// Queues a log embed for the given channel ID and (re)arms that channel's flush timer.
    /// Does nothing if <paramref name="channelId"/> is null or 0 (logging disabled).
    /// </summary>
    /// <param name="channelId">The log channel to send the embed to.</param>
    /// <param name="embed">The embed to queue.</param>
    public void QueueLog(ulong? channelId, IEmbed embed)
    {
        if (channelId is null or 0)
            return;

        // The factory overload avoids allocating a throwaway queue when one already exists.
        var queue = _cache.GetOrAdd(channelId.Value, _ => new ConcurrentQueue<IEmbed>());
        queue.Enqueue(embed);

        SetTimer(channelId.Value, queue);
    }

    /// <summary>
    /// Sends multiple embeds and/or files to a channel, bypassing the embed queue.
    /// Does nothing when <paramref name="channelId"/> is 0 or test mode is enabled. Logs an error and
    /// returns without sending when the content would be rejected by Discord (nothing to send, too many
    /// embeds, or embeds too long). Logs an error if Discord rejects the webhook execution.
    /// </summary>
    /// <param name="channelId">The channel ID to send the content to.</param>
    /// <param name="embeds">The embeds to send. At most 10, and under 6000 characters in length total.</param>
    /// <param name="files">The files to send.</param>
    /// <exception cref="InvalidOperationException">
    /// <see cref="SetSelfUser"/> has not been called yet.
    /// </exception>
    public async Task SendLogAsync(
        ulong channelId,
        List<IEmbed> embeds,
        IEnumerable<FileData> files
    )
    {
        if (channelId == 0)
            return;

        if (config.Discord.TestMode)
        {
            _logger.Information(
                "Should have logged to {ChannelId}, but test mode is enabled, ignoring",
                channelId
            );
            return;
        }

        var attachments = files
            .Select<FileData, OneOf<FileData, IPartialAttachment>>(f => f)
            .ToList();

        if (embeds.Count == 0 && attachments.Count == 0)
        {
            _logger.Error(
                "SendLogAsync was called with zero embeds and zero attachments, bailing to prevent a bad request error"
            );
            return;
        }

        if (embeds.Count > MaxEmbedCount)
        {
            _logger.Error(
                "SendLogAsync was called with more than 10 embeds, bailing to prevent a bad request error"
            );
            return;
        }

        if (embeds.Sum(e => e.TextLength()) > MaxContentLength)
        {
            _logger.Error(
                "SendLogAsync was called with embeds totaling more than 6000 characters, bailing to prevent a bad request error"
            );
            return;
        }

        // Fail with a clear message instead of a NullReferenceException, and before fetching or creating
        // a webhook we could not use anyway.
        var selfUser =
            _selfUser
            ?? throw new InvalidOperationException(
                "SetSelfUser must be called before any logs are sent"
            );

        _logger.Debug(
            "Sending {EmbedCount} embeds/{FileCount} files to channel {ChannelId}",
            embeds.Count,
            attachments.Count,
            channelId
        );

        var webhook = await webhookCache.GetOrFetchWebhookAsync(
            channelId,
            id => FetchWebhookAsync(id)
        );

        var result = await webhookApi.ExecuteWebhookAsync(
            DiscordSnowflake.New(webhook.Id),
            webhook.Token,
            shouldWait: false,
            embeds: embeds,
            attachments: attachments,
            username: selfUser.Username,
            avatarUrl: selfUser.AvatarUrl()
        );

        // The REST call reports failures through its result rather than by throwing, so check it
        // explicitly; otherwise rejected logs would vanish without a trace.
        if (!result.IsSuccess)
        {
            _logger.Error(
                "Error executing webhook for channel {ChannelId}: {Error}",
                channelId,
                result.Error
            );
        }
    }

    /// <summary>
    /// (Re)arms the flush timer for the given channel so its queue is flushed 3 seconds from now.
    /// Any timer already armed for the channel is disposed, so each newly queued embed pushes the flush back.
    /// After a flush, the timer is re-armed if embeds remain that did not fit into the sent batch.
    /// </summary>
    /// <param name="channelId">The log channel whose queue the timer flushes.</param>
    /// <param name="queue">That channel's queue, checked after each flush for leftover embeds.</param>
    private void SetTimer(ulong channelId, ConcurrentQueue<IEmbed> queue)
    {
        // Replace the timer under the channel lock. Otherwise two concurrent callers could both create a
        // timer and the one that gets overwritten would never be disposed.
        lock (GetChannelLock(channelId))
        {
            if (_timers.TryGetValue(channelId, out var existingTimer))
                existingTimer.Dispose();

            _timers[channelId] = new Timer(
                state =>
                {
                    // Fire-and-forget: FlushQueueAsync takes the batch synchronously and logs its own errors.
                    _ = FlushQueueAsync(channelId);

                    // One message may not hold everything that was queued; schedule another flush.
                    if (!queue.IsEmpty)
                        SetTimer(channelId, queue);
                },
                null,
                FlushDelayMilliseconds,
                Timeout.Infinite
            );
        }
    }

    /// <summary>
    /// Takes the next batch of queued embeds for a channel and sends it, if there is anything to send.
    /// Exceptions are logged rather than propagated, because this runs fire-and-forget from a timer
    /// callback where no caller would ever observe them.
    /// </summary>
    /// <param name="channelId">The log channel to flush.</param>
    private async Task FlushQueueAsync(ulong channelId)
    {
        // Runs synchronously before the first await, so the batch is dequeued before the caller continues.
        var embeds = TakeFromQueue(channelId);
        if (embeds.Count == 0)
            return;

        try
        {
            await SendLogAsync(channelId, embeds, []);
        }
        catch (Exception e)
        {
            _logger.Error(
                e,
                "Error sending {EmbedCount} queued embeds to {ChannelId}",
                embeds.Count,
                channelId
            );
        }
    }

    /// <summary>
    /// Returns the dedicated lock object for a channel, creating it on first use.
    /// </summary>
    /// <param name="channelId">The log channel ID.</param>
    private object GetChannelLock(ulong channelId) =>
        _locks.GetOrAdd(channelId, _ => new object());

    /// <summary>
    /// Takes as many embeds as possible from the queue for the given channel.
    /// Up to ten embeds are returned, or fewer if their combined length would exceed 6000 characters.
    /// Embeds that are individually longer than 6000 characters can never be sent; they are discarded
    /// with a warning.
    /// Note that this locks the queue to prevent duplicate embeds from being sent.
    /// </summary>
    /// <param name="channelId">The log channel whose queue to drain.</param>
    /// <returns>The embeds removed from the queue, in queue order. Possibly empty.</returns>
    private List<IEmbed> TakeFromQueue(ulong channelId)
    {
        var queue = _cache.GetOrAdd(channelId, _ => new ConcurrentQueue<IEmbed>());
        lock (GetChannelLock(channelId))
        {
            var totalContentLength = 0;
            var embeds = new List<IEmbed>();

            // All dequeues happen under this lock, so the peeked embed is always the one dequeued below.
            while (
                embeds.Count < MaxEmbedCount
                && totalContentLength < MaxContentLength
                && queue.TryPeek(out var embed)
            )
            {
                var length = embed.TextLength();
                if (length > MaxContentLength)
                {
                    _logger.Warning(
                        "Queued embed for {ChannelId} exceeds maximum length, discarding it",
                        channelId
                    );
                    queue.TryDequeue(out _);

                    // Keep filling the batch rather than cutting it short because of one bad embed.
                    continue;
                }

                // Leave the embed queued for the next batch if it would push this one over the limit.
                if (totalContentLength + length > MaxContentLength)
                    break;

                totalContentLength += length;
                queue.TryDequeue(out _);
                embeds.Add(embed);
            }

            if (embeds.Count == 0)
                return embeds;

            _logger.Debug(
                "Took {EmbedCount} embeds from queue for {ChannelId}, total length is {TotalLength}",
                embeds.Count,
                channelId,
                totalContentLength
            );
            return embeds;
        }
    }

    /// <summary>
    /// Finds a webhook in the channel that belongs to this application and has a token, creating a new
    /// "Catalogger" webhook if none exists. Throws if either REST call fails (via <c>GetOrThrow</c>).
    /// </summary>
    /// <param name="channelId">The channel to find or create the webhook in.</param>
    /// <param name="ct">Cancellation token passed through to the REST calls.</param>
    // TODO: make it so this method can only have one request per channel in flight simultaneously
    private async Task<IWebhook> FetchWebhookAsync(
        Snowflake channelId,
        CancellationToken ct = default
    )
    {
        var channelWebhooks = await webhookApi.GetChannelWebhooksAsync(channelId, ct).GetOrThrow();

        // Only a webhook created by this application with a token can be executed by us.
        var webhook = channelWebhooks.FirstOrDefault(w =>
            w.ApplicationID == _applicationId && w.Token.IsDefined()
        );
        if (webhook != null)
            return webhook;

        return await webhookApi
            .CreateWebhookAsync(
                channelId,
                "Catalogger",
                default,
                reason: "Creating logging webhook",
                ct: ct
            )
            .GetOrThrow();
    }

    /// <summary>
    /// Resolves the channel an event of the given type should be logged to, taking the guild's ignored
    /// channels, users and roles and its message-log redirects into account.
    /// </summary>
    /// <param name="guild">The guild's configuration.</param>
    /// <param name="logChannelType">The kind of event being logged.</param>
    /// <param name="channelId">The channel the event happened in or concerns, if any.</param>
    /// <param name="userId">The user that triggered the event, if any (message events only).</param>
    /// <param name="roleId">The role the event concerns, if any (role create/update/delete).</param>
    /// <param name="roleIds">The roles changed on a member, if any (member update).</param>
    /// <returns>
    /// The log channel ID, or null if the event is ignored. The value can also be 0 when no channel is
    /// configured for this event type; callers should treat 0 like null.
    /// </returns>
    public ulong? GetLogChannel(
        Guild guild,
        LogChannelType logChannelType,
        Snowflake? channelId = null,
        ulong? userId = null,
        Snowflake? roleId = null,
        IReadOnlyList<Snowflake>? roleIds = null
    )
    {
        var isMessageLog =
            logChannelType
            is LogChannelType.MessageUpdate
                or LogChannelType.MessageDelete
                or LogChannelType.MessageDeleteBulk;

        // Check if we're getting the channel for a channel log
        var isChannelLog =
            channelId != null
            && logChannelType
                is LogChannelType.ChannelCreate
                    or LogChannelType.ChannelDelete
                    or LogChannelType.ChannelUpdate;

        // Check if we're getting the channel for a role log
        var isRoleLog =
            roleId != null
            && logChannelType
                is LogChannelType.GuildRoleCreate
                    or LogChannelType.GuildRoleUpdate
                    or LogChannelType.GuildRoleDelete;

        // Check if we're getting the channel for a member update log
        var isMemberRoleUpdateLog =
            roleIds != null && logChannelType is LogChannelType.GuildMemberUpdate;

        if (isMessageLog)
            return GetMessageLogChannel(guild, logChannelType, channelId, userId);

        if (isChannelLog)
            return GetChannelLogChannel(guild, logChannelType, channelId!.Value);

        if (isRoleLog && guild.IgnoredRoles.Contains(roleId!.Value.Value))
            return null;

        // Member update logs are only ignored if *all* updated roles are ignored
        if (isMemberRoleUpdateLog && roleIds!.All(r => guild.IgnoredRoles.Contains(r.Value)))
            return null;

        // If nothing is ignored, return the correct log channel!
        return GetDefaultLogChannel(guild, logChannelType);
    }

    /// <summary>
    /// Resolves the log channel for a channel create/update/delete event. Threads are checked against
    /// their parent channel. Returns null if the channel or its category is ignored. Falls back to the
    /// default log channel when the channel, or a thread's parent, is not in the channel cache.
    /// </summary>
    private ulong? GetChannelLogChannel(
        Guild guild,
        LogChannelType logChannelType,
        Snowflake channelId
    )
    {
        if (!channelCache.TryGet(channelId, out var channel))
            return GetDefaultLogChannel(guild, logChannelType);

        Snowflake? categoryId;
        if (
            channel.Type
            is ChannelType.AnnouncementThread
                or ChannelType.PrivateThread
                or ChannelType.PublicThread
        )
        {
            // parent_id should always have a value for threads
            channelId = channel.ParentID.Value!.Value;
            if (!channelCache.TryGet(channelId, out var parentChannel))
                return GetDefaultLogChannel(guild, logChannelType);
            categoryId = parentChannel.ParentID.Value;
        }
        else
        {
            channelId = channel.ID;
            categoryId = channel.ParentID.Value;
        }

        // Check if the channel or its category is ignored
        if (
            guild.IgnoredChannels.Contains(channelId.Value)
            || (categoryId != null && guild.IgnoredChannels.Contains(categoryId.Value.Value))
        )
            return null;

        return GetDefaultLogChannel(guild, logChannelType);
    }

    /// <summary>
    /// Resolves the log channel for a message update/delete/bulk-delete event. Applies, in order:
    /// globally ignored users, ignored channels/categories, per-channel and per-category ignored users,
    /// then channel and category redirects. Threads are checked against their parent channel.
    /// Returns null when the channel is not in the channel cache.
    /// </summary>
    private ulong? GetMessageLogChannel(
        Guild guild,
        LogChannelType logChannelType,
        Snowflake? channelId = null,
        ulong? userId = null
    )
    {
        // Check if the user is ignored globally
        if (userId != null && guild.Messages.IgnoredUsers.Contains(userId.Value))
            return null;

        // If the user isn't ignored and we didn't get a channel ID, return the default log channel
        if (channelId == null)
            return GetDefaultLogChannel(guild, logChannelType);

        if (!channelCache.TryGet(channelId.Value, out var channel))
            return null;

        Snowflake? categoryId;
        if (
            channel.Type
            is ChannelType.AnnouncementThread
                or ChannelType.PrivateThread
                or ChannelType.PublicThread
        )
        {
            // parent_id should always have a value for threads
            channelId = channel.ParentID.Value!.Value;
            if (!channelCache.TryGet(channelId.Value, out var parentChannel))
                return GetDefaultLogChannel(guild, logChannelType);
            categoryId = parentChannel.ParentID.Value;
        }
        else
        {
            channelId = channel.ID;
            categoryId = channel.ParentID.Value;
        }

        // Check if the channel or its category is ignored
        if (
            guild.Messages.IgnoredChannels.Contains(channelId.Value.Value)
            || (
                categoryId != null
                && guild.Messages.IgnoredChannels.Contains(categoryId.Value.Value)
            )
        )
            return null;

        if (userId != null)
        {
            // Check the channel-local and category-local ignored users
            var channelIgnoredUsers =
                guild.Messages.IgnoredUsersPerChannel.GetValueOrDefault(channelId.Value.Value)
                ?? [];
            var categoryIgnoredUsers =
                (
                    categoryId != null
                        ? guild.Messages.IgnoredUsersPerChannel.GetValueOrDefault(
                            categoryId.Value.Value
                        )
                        : []
                ) ?? [];
            if (channelIgnoredUsers.Concat(categoryIgnoredUsers).Contains(userId.Value))
                return null;
        }

        // These three events can be redirected to other channels. Redirects can be on a channel or category level.
        // The events are only redirected if they're supposed to be logged in the first place.
        if (GetDefaultLogChannel(guild, logChannelType) == 0)
            return null;

        var categoryRedirect =
            categoryId != null
                ? guild.Channels.Redirects.GetValueOrDefault(categoryId.Value.Value)
                : 0;

        // A channel-level redirect takes precedence over a category-level one.
        if (guild.Channels.Redirects.TryGetValue(channelId.Value.Value, out var channelRedirect))
            return channelRedirect;

        return categoryRedirect != 0
            ? categoryRedirect
            : GetDefaultLogChannel(guild, logChannelType);
    }

    /// <summary>
    /// Returns the guild's configured log channel for the given event type, ignoring any redirects or
    /// ignore lists. A value of 0 means no channel is configured for that event type.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="logChannelType"/> is not a known <see cref="LogChannelType"/> value.
    /// </exception>
    public static ulong GetDefaultLogChannel(Guild guild, LogChannelType logChannelType) =>
        logChannelType switch
        {
            LogChannelType.GuildUpdate => guild.Channels.GuildUpdate,
            LogChannelType.GuildEmojisUpdate => guild.Channels.GuildEmojisUpdate,
            LogChannelType.GuildRoleCreate => guild.Channels.GuildRoleCreate,
            LogChannelType.GuildRoleUpdate => guild.Channels.GuildRoleUpdate,
            LogChannelType.GuildRoleDelete => guild.Channels.GuildRoleDelete,
            LogChannelType.ChannelCreate => guild.Channels.ChannelCreate,
            LogChannelType.ChannelUpdate => guild.Channels.ChannelUpdate,
            LogChannelType.ChannelDelete => guild.Channels.ChannelDelete,
            LogChannelType.GuildMemberAdd => guild.Channels.GuildMemberAdd,
            LogChannelType.GuildMemberUpdate => guild.Channels.GuildMemberUpdate,
            LogChannelType.GuildKeyRoleUpdate => guild.Channels.GuildKeyRoleUpdate,
            LogChannelType.GuildMemberNickUpdate => guild.Channels.GuildMemberNickUpdate,
            LogChannelType.GuildMemberAvatarUpdate => guild.Channels.GuildMemberAvatarUpdate,
            LogChannelType.GuildMemberTimeout => guild.Channels.GuildMemberTimeout,
            LogChannelType.GuildMemberRemove => guild.Channels.GuildMemberRemove,
            LogChannelType.GuildMemberKick => guild.Channels.GuildMemberKick,
            LogChannelType.GuildBanAdd => guild.Channels.GuildBanAdd,
            LogChannelType.GuildBanRemove => guild.Channels.GuildBanRemove,
            LogChannelType.InviteCreate => guild.Channels.InviteCreate,
            LogChannelType.InviteDelete => guild.Channels.InviteDelete,
            LogChannelType.MessageUpdate => guild.Channels.MessageUpdate,
            LogChannelType.MessageDelete => guild.Channels.MessageDelete,
            LogChannelType.MessageDeleteBulk => guild.Channels.MessageDeleteBulk,
            _ => throw new ArgumentOutOfRangeException(nameof(logChannelType)),
        };
}
/// <summary>
/// The kinds of events that can be logged. Each value maps to its own configurable log channel in
/// <see cref="WebhookExecutorService.GetDefaultLogChannel"/>.
/// </summary>
/// <remarks>
/// Values are written out explicitly. They match the declaration order (0-22) so that the numbering
/// cannot change by accident if members are reordered.
/// </remarks>
public enum LogChannelType
{
    // Guild-level events
    GuildUpdate = 0,
    GuildEmojisUpdate = 1,

    // Role events (can be ignored per role)
    GuildRoleCreate = 2,
    GuildRoleUpdate = 3,
    GuildRoleDelete = 4,

    // Channel events (can be ignored per channel or category)
    ChannelCreate = 5,
    ChannelUpdate = 6,
    ChannelDelete = 7,

    // Member events
    GuildMemberAdd = 8,
    GuildMemberUpdate = 9,
    GuildKeyRoleUpdate = 10,
    GuildMemberNickUpdate = 11,
    GuildMemberAvatarUpdate = 12,
    GuildMemberTimeout = 13,
    GuildMemberRemove = 14,
    GuildMemberKick = 15,

    // Moderation events
    GuildBanAdd = 16,
    GuildBanRemove = 17,

    // Invite events
    InviteCreate = 18,
    InviteDelete = 19,

    // Message events (can be redirected per channel or category)
    MessageUpdate = 20,
    MessageDelete = 21,
    MessageDeleteBulk = 22,
}