diff --git a/Catalogger.Backend/Bot/Commands/InviteCommands.cs b/Catalogger.Backend/Bot/Commands/InviteCommands.cs index bace978..68d11f9 100644 --- a/Catalogger.Backend/Bot/Commands/InviteCommands.cs +++ b/Catalogger.Backend/Bot/Commands/InviteCommands.cs @@ -15,12 +15,10 @@ using System.ComponentModel; using System.Diagnostics.CodeAnalysis; -using System.Runtime.InteropServices; using Catalogger.Backend.Cache; using Catalogger.Backend.Cache.InMemoryCache; -using Catalogger.Backend.Database; +using Catalogger.Backend.Database.Dapper.Repositories; using Catalogger.Backend.Extensions; -using Microsoft.EntityFrameworkCore; using Remora.Commands.Attributes; using Remora.Commands.Groups; using Remora.Discord.API.Abstractions.Objects; @@ -33,7 +31,6 @@ using Remora.Discord.Commands.Feedback.Services; using Remora.Discord.Commands.Services; using Remora.Discord.Pagination.Extensions; using Remora.Rest.Core; -using Invite = Catalogger.Backend.Database.Models.Invite; using IResult = Remora.Results.IResult; namespace Catalogger.Backend.Bot.Commands; @@ -43,7 +40,7 @@ namespace Catalogger.Backend.Bot.Commands; [DiscordDefaultMemberPermissions(DiscordPermission.ManageGuild)] public class InviteCommands( ILogger logger, - DatabaseContext db, + InviteRepository inviteRepository, GuildCache guildCache, IInviteCache inviteCache, IDiscordRestChannelAPI channelApi, @@ -63,7 +60,7 @@ public class InviteCommands( if (!guildCache.TryGet(guildId, out var guild)) throw new CataloggerError("Guild not in cache"); - var dbInvites = await db.Invites.Where(i => i.GuildId == guildId.Value).ToListAsync(); + var dbInvites = await inviteRepository.GetGuildInvitesAsync(guildId); var fields = guildInvites .Select(i => new PartialNamedInvite( @@ -153,14 +150,7 @@ public class InviteCommands( + $"\nLink: https://discord.gg/{inviteResult.Entity.Code}" ); - var dbInvite = new Invite - { - GuildId = guildId.Value, - Code = inviteResult.Entity.Code, - Name = name, - }; - db.Add(dbInvite); - await db.SaveChangesAsync(); + await inviteRepository.SetInviteNameAsync(guildId, inviteResult.Entity.Code, name); return await feedbackService.ReplyAsync( $"Created a new invite in <#{channel.ID}> with the name **{name}**!" @@ -188,39 +178,18 @@ public class InviteCommands( ); } - var namedInvite = await db - .Invites.Where(i => i.GuildId == guildId.Value && i.Code == invite) - .FirstOrDefaultAsync(); - if (namedInvite == null) - { - if (name == null) - return await feedbackService.ReplyAsync($"Invite `{invite}` already has no name."); - - namedInvite = new Invite - { - GuildId = guildId.Value, - Code = invite, - Name = name, - }; - db.Add(namedInvite); - await db.SaveChangesAsync(); - - return await feedbackService.ReplyAsync( - $"New name set! The invite `{invite}` will now show up as **{name}** in logs." - ); - } + var namedInvite = await inviteRepository.GetInviteAsync(guildId, invite); + if (namedInvite == null && name == null) + return await feedbackService.ReplyAsync($"Invite `{invite}` already has no name."); if (name == null) { - db.Invites.Remove(namedInvite); - await db.SaveChangesAsync(); + await inviteRepository.DeleteInviteAsync(guildId, invite); return await feedbackService.ReplyAsync($"Removed the name for `{invite}`."); } - namedInvite.Name = name; - db.Update(namedInvite); - await db.SaveChangesAsync(); + await inviteRepository.SetInviteNameAsync(guildId, invite, name); return await feedbackService.ReplyAsync( $"New name set! The invite `{invite}` will now show up as **{name}** in logs." @@ -230,7 +199,7 @@ public class InviteCommands( public class InviteAutocompleteProvider( ILogger logger, - DatabaseContext db, + InviteRepository inviteRepository, IInviteCache inviteCache, ContextInjectionService contextInjection ) : IAutocompleteProvider @@ -262,13 +231,13 @@ public class InviteAutocompleteProvider( return []; } - var namedInvites = await db - .Invites.Where(i => - i.GuildId == guildId.Value && i.Name.ToLower().StartsWith(userInput.ToLower()) - ) + // We're filtering and ordering on the client side because a guild won't have infinite invites + // (the maximum on Discord's end is 1500-ish) + // and this way we don't need an index on (guild_id, name) for this *one* case. + var namedInvites = (await inviteRepository.GetGuildInvitesAsync(guildId)) + .Where(i => i.Name.StartsWith(userInput, StringComparison.InvariantCultureIgnoreCase)) .OrderBy(i => i.Name) - .Take(25) - .ToListAsync(ct); + .ToList(); if (namedInvites.Count != 0) { diff --git a/Catalogger.Backend/Bot/Responders/Invites/InviteDeleteResponder.cs b/Catalogger.Backend/Bot/Responders/Invites/InviteDeleteResponder.cs index f736953..c0a2896 100644 --- a/Catalogger.Backend/Bot/Responders/Invites/InviteDeleteResponder.cs +++ b/Catalogger.Backend/Bot/Responders/Invites/InviteDeleteResponder.cs @@ -32,7 +32,7 @@ namespace Catalogger.Backend.Bot.Responders.Invites; public class InviteDeleteResponder( ILogger logger, GuildRepository guildRepository, - DatabaseContext db, + InviteRepository inviteRepository, IInviteCache inviteCache, WebhookExecutorService webhookExecutor, IDiscordRestGuildAPI guildApi @@ -44,9 +44,7 @@ public class InviteDeleteResponder( { var guildId = evt.GuildID.Value; - var dbDeleteCount = await db - .Invites.Where(i => i.GuildId == guildId.Value && i.Code == evt.Code) - .ExecuteDeleteAsync(ct); + var dbDeleteCount = await inviteRepository.DeleteInviteAsync(guildId, evt.Code); if (dbDeleteCount != 0) _logger.Information( "Deleted named invite {Invite} for guild {Guild}", diff --git a/Catalogger.Backend/Bot/Responders/Members/GuildMemberAddResponder.cs b/Catalogger.Backend/Bot/Responders/Members/GuildMemberAddResponder.cs index c92c8e0..6460717 100644 --- a/Catalogger.Backend/Bot/Responders/Members/GuildMemberAddResponder.cs +++ b/Catalogger.Backend/Bot/Responders/Members/GuildMemberAddResponder.cs @@ -36,6 +36,7 @@ namespace Catalogger.Backend.Bot.Responders.Members; public class GuildMemberAddResponder( ILogger logger, DatabaseContext db, + InviteRepository inviteRepository, GuildRepository guildRepository, IMemberCache memberCache, IInviteCache inviteCache, @@ -128,11 +129,7 @@ public class GuildMemberAddResponder( goto afterInvite; } - var inviteName = - await db - .Invites.Where(i => i.Code == usedInvite.Code && i.GuildId == member.GuildID.Value) - .Select(i => i.Name) - .FirstOrDefaultAsync(ct) ?? "*(unnamed)*"; + var inviteName = inviteRepository.GetInviteNameAsync(member.GuildID, usedInvite.Code); var inviteDescription = $""" **Code:** {usedInvite.Code} diff --git a/Catalogger.Backend/Bot/Responders/Messages/MessageCreateResponder.cs b/Catalogger.Backend/Bot/Responders/Messages/MessageCreateResponder.cs index c83e9b1..8364dcf 100644 --- a/Catalogger.Backend/Bot/Responders/Messages/MessageCreateResponder.cs +++ b/Catalogger.Backend/Bot/Responders/Messages/MessageCreateResponder.cs @@ -29,7 +29,7 @@ public class MessageCreateResponder( ILogger logger, Config config, GuildRepository guildRepository, - DapperMessageRepository messageRepository, + MessageRepository messageRepository, UserCache userCache, PkMessageHandler pkMessageHandler ) : IResponder @@ -146,7 +146,7 @@ public partial class PkMessageHandler(ILogger logger, IServiceProvider services) await using var scope = services.CreateAsyncScope(); await using var messageRepository = - scope.ServiceProvider.GetRequiredService(); + scope.ServiceProvider.GetRequiredService(); await Task.WhenAll( messageRepository.SetProxiedMessageDataAsync( @@ -166,7 +166,7 @@ public partial class PkMessageHandler(ILogger logger, IServiceProvider services) await using var scope = services.CreateAsyncScope(); await using var messageRepository = - scope.ServiceProvider.GetRequiredService(); + scope.ServiceProvider.GetRequiredService(); var pluralkitApi = scope.ServiceProvider.GetRequiredService(); var (isStored, hasProxyInfo) = await messageRepository.HasProxyInfoAsync(msgId); diff --git a/Catalogger.Backend/Bot/Responders/Messages/MessageDeleteBulkResponder.cs b/Catalogger.Backend/Bot/Responders/Messages/MessageDeleteBulkResponder.cs index a986113..b9c66a4 100644 --- a/Catalogger.Backend/Bot/Responders/Messages/MessageDeleteBulkResponder.cs +++ b/Catalogger.Backend/Bot/Responders/Messages/MessageDeleteBulkResponder.cs @@ -32,7 +32,7 @@ namespace Catalogger.Backend.Bot.Responders.Messages; public class MessageDeleteBulkResponder( ILogger logger, GuildRepository guildRepository, - DapperMessageRepository messageRepository, + MessageRepository messageRepository, WebhookExecutorService webhookExecutor, ChannelCache channelCache ) : IResponder @@ -128,7 +128,7 @@ public class MessageDeleteBulkResponder( return Result.Success; } - private string RenderMessage(Snowflake messageId, DapperMessageRepository.Message? message) + private string RenderMessage(Snowflake messageId, MessageRepository.Message? message) { var timestamp = messageId.Timestamp.ToOffsetDateTime().ToString(); diff --git a/Catalogger.Backend/Bot/Responders/Messages/MessageDeleteResponder.cs b/Catalogger.Backend/Bot/Responders/Messages/MessageDeleteResponder.cs index f538995..d97b09d 100644 --- a/Catalogger.Backend/Bot/Responders/Messages/MessageDeleteResponder.cs +++ b/Catalogger.Backend/Bot/Responders/Messages/MessageDeleteResponder.cs @@ -33,7 +33,7 @@ namespace Catalogger.Backend.Bot.Responders.Messages; public class MessageDeleteResponder( ILogger logger, GuildRepository guildRepository, - DapperMessageRepository messageRepository, + MessageRepository messageRepository, WebhookExecutorService webhookExecutor, ChannelCache channelCache, UserCache userCache, diff --git a/Catalogger.Backend/Bot/Responders/Messages/MessageUpdateResponder.cs b/Catalogger.Backend/Bot/Responders/Messages/MessageUpdateResponder.cs index 8933c06..641fabd 100644 --- a/Catalogger.Backend/Bot/Responders/Messages/MessageUpdateResponder.cs +++ b/Catalogger.Backend/Bot/Responders/Messages/MessageUpdateResponder.cs @@ -35,7 +35,7 @@ public class MessageUpdateResponder( DatabaseContext db, ChannelCache channelCache, UserCache userCache, - DapperMessageRepository messageRepository, + MessageRepository messageRepository, WebhookExecutorService webhookExecutor, PluralkitApiService pluralkitApi ) : IResponder diff --git a/Catalogger.Backend/Database/Dapper/Repositories/InviteRepository.cs b/Catalogger.Backend/Database/Dapper/Repositories/InviteRepository.cs new file mode 100644 index 0000000..45183de --- /dev/null +++ b/Catalogger.Backend/Database/Dapper/Repositories/InviteRepository.cs @@ -0,0 +1,79 @@ +// Copyright (C) 2021-present sam (starshines.gay) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +using Catalogger.Backend.Database.Models; +using Dapper; +using Remora.Rest.Core; + +namespace Catalogger.Backend.Database.Dapper.Repositories; + +public class InviteRepository(ILogger logger, DatabaseConnection conn) + : IDisposable, + IAsyncDisposable +{ + private readonly ILogger _logger = logger.ForContext(); + + public async Task> GetGuildInvitesAsync(Snowflake guildId) => + ( + await conn.QueryAsync( + "select * from invites where guild_id = @GuildId", + new { GuildId = guildId.Value } + ) + ).ToList(); + + public async Task SetInviteNameAsync(Snowflake guildId, string code, string name) => + await conn.ExecuteAsync( + """ + insert into invites (code, guild_id, name) values + (@Code, @GuildId, @Name) on conflict (code, guild_id) do update set name = @Name + """, + new + { + Code = code, + GuildId = guildId.Value, + Name = name, + } + ); + + public async Task GetInviteAsync(Snowflake guildId, string code) => + await conn.QueryFirstOrDefaultAsync( + "select * from invites where guild_id = @GuildId and code = @Code", + new { GuildId = guildId.Value, Code = code } + ); + + public async Task GetInviteNameAsync(Snowflake guildId, string code) => + await conn.ExecuteScalarAsync( + "select name from invites where guild_id = @GuildId and code = @Code", + new { GuildId = guildId.Value, Code = code } + ) ?? "(unnamed)"; + + public async Task DeleteInviteAsync(Snowflake guildId, string code) => + await conn.ExecuteAsync( + "delete from invites where guild_id = @GuildId and code = @Code", + new { GuildId = guildId.Value, Code = code } + ); + + public void Dispose() + { + conn.Dispose(); + GC.SuppressFinalize(this); + } + + public async ValueTask DisposeAsync() + { + await conn.DisposeAsync(); + GC.SuppressFinalize(this); + } +} diff --git a/Catalogger.Backend/Database/Dapper/Repositories/DapperMessageRepository.cs b/Catalogger.Backend/Database/Dapper/Repositories/MessageRepository.cs similarity index 98% rename from Catalogger.Backend/Database/Dapper/Repositories/DapperMessageRepository.cs rename to Catalogger.Backend/Database/Dapper/Repositories/MessageRepository.cs index c006f96..49ba4ab 100644 --- a/Catalogger.Backend/Database/Dapper/Repositories/DapperMessageRepository.cs +++ b/Catalogger.Backend/Database/Dapper/Repositories/MessageRepository.cs @@ -22,13 +22,13 @@ using Remora.Rest.Core; namespace Catalogger.Backend.Database.Dapper.Repositories; -public class DapperMessageRepository( +public class MessageRepository( ILogger logger, DatabaseConnection conn, IEncryptionService encryptionService ) : IDisposable, IAsyncDisposable { - private readonly ILogger _logger = logger.ForContext(); + private readonly ILogger _logger = logger.ForContext(); public async Task GetMessageAsync(ulong id, CancellationToken ct = default) { diff --git a/Catalogger.Backend/Database/Queries/MessageRepository.cs b/Catalogger.Backend/Database/Queries/MessageRepository.cs deleted file mode 100644 index 6f88bb5..0000000 --- a/Catalogger.Backend/Database/Queries/MessageRepository.cs +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright (C) 2021-present sam (starshines.gay) -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published -// by the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -using System.Text.Json; -using Catalogger.Backend.Extensions; -using Microsoft.EntityFrameworkCore; -using Remora.Discord.API; -using Remora.Discord.API.Abstractions.Gateway.Events; -using Remora.Rest.Core; -using DbMessage = Catalogger.Backend.Database.Models.Message; - -namespace Catalogger.Backend.Database.Queries; - -public class MessageRepository( - ILogger logger, - DatabaseContext db, - IEncryptionService encryptionService -) -{ - private readonly ILogger _logger = logger.ForContext(); - - public async Task SaveMessageAsync(IMessageCreate msg, CancellationToken ct = default) - { - _logger.Debug("Saving message {MessageId}", msg.ID); - - var metadata = new Metadata( - IsWebhook: msg.WebhookID.HasValue, - msg.Attachments.Select(a => new Attachment(a.Filename, a.Size, a.ContentType.Value)) - ); - - var dbMessage = new DbMessage - { - Id = msg.ID.ToUlong(), - UserId = msg.Author.ID.ToUlong(), - ChannelId = msg.ChannelID.ToUlong(), - GuildId = msg.GuildID.ToUlong(), - - Content = await Task.Run( - () => - encryptionService.Encrypt( - string.IsNullOrWhiteSpace(msg.Content) ? "None" : msg.Content - ), - ct - ), - Username = await Task.Run(() => encryptionService.Encrypt(msg.Author.Tag()), ct), - Metadata = await Task.Run( - () => encryptionService.Encrypt(JsonSerializer.Serialize(metadata)), - ct - ), - AttachmentSize = msg.Attachments.Select(a => a.Size).Sum(), - }; - - db.Add(dbMessage); - await db.SaveChangesAsync(ct); - } - - /// - /// Updates an edited message. - /// - /// true if the message was already stored and got updated, - /// false if the message wasn't stored and was newly inserted. - public async Task UpdateMessageAsync(IMessageCreate msg, CancellationToken ct = default) - { - _logger.Debug("Updating message {MessageId}", msg.ID); - - var tx = await db.Database.BeginTransactionAsync(ct); - var (isStored, _) = await HasProxyInfoAsync(msg.ID.Value); - if (!isStored) - { - _logger.Debug("Edited message {MessageId} is not stored yet, storing it", msg.ID); - await SaveMessageAsync(msg, ct); - await tx.CommitAsync(ct); - return false; - } - else - { - var metadata = new Metadata( - IsWebhook: msg.WebhookID.HasValue, - msg.Attachments.Select(a => new Attachment(a.Filename, a.Size, a.ContentType.Value)) - ); - - var dbMsg = await db.Messages.FindAsync( - new object?[] { msg.ID.Value }, - cancellationToken: ct - ); - if (dbMsg == null) - throw new CataloggerError( - "Message was null despite HasProxyInfoAsync returning true" - ); - - dbMsg.Content = await Task.Run( - () => - encryptionService.Encrypt( - string.IsNullOrWhiteSpace(msg.Content) ? "None" : msg.Content - ), - ct - ); - dbMsg.Username = await Task.Run(() => encryptionService.Encrypt(msg.Author.Tag()), ct); - dbMsg.Metadata = await Task.Run( - () => encryptionService.Encrypt(JsonSerializer.Serialize(metadata)), - ct - ); - - db.Update(dbMsg); - await db.SaveChangesAsync(ct); - await tx.CommitAsync(ct); - return true; - } - } - - public async Task GetMessageAsync(ulong id, CancellationToken ct = default) - { - _logger.Debug("Retrieving message {MessageId}", id); - - var dbMsg = await db.Messages.AsNoTracking().FirstOrDefaultAsync(m => m.Id == id, ct); - if (dbMsg == null) - return null; - - return new Message( - dbMsg.Id, - dbMsg.OriginalId, - dbMsg.UserId, - dbMsg.ChannelId, - dbMsg.GuildId, - dbMsg.Member, - dbMsg.System, - Username: await Task.Run(() => encryptionService.Decrypt(dbMsg.Username), ct), - Content: await Task.Run(() => encryptionService.Decrypt(dbMsg.Content), ct), - Metadata: dbMsg.Metadata != null - ? JsonSerializer.Deserialize( - await Task.Run(() => encryptionService.Decrypt(dbMsg.Metadata), ct) - ) - : null, - dbMsg.AttachmentSize - ); - } - - /// - /// Checks if a message has proxy information. - /// If yes, returns (true, true). If no, returns (true, false). If the message isn't saved at all, returns (false, false). - /// - public async Task<(bool, bool)> HasProxyInfoAsync(ulong id) - { - _logger.Debug("Checking if message {MessageId} has proxy information", id); - - var msg = await db - .Messages.AsNoTracking() - .Select(m => new { m.Id, m.OriginalId }) - .FirstOrDefaultAsync(m => m.Id == id); - return (msg != null, msg?.OriginalId != null); - } - - public async Task SetProxiedMessageDataAsync( - ulong id, - ulong originalId, - ulong authorId, - string? systemId, - string? memberId - ) - { - _logger.Debug("Setting proxy information for message {MessageId}", id); - - var message = await db.Messages.FirstOrDefaultAsync(m => m.Id == id); - if (message == null) - { - _logger.Debug("Message {MessageId} not found", id); - return; - } - - _logger.Debug("Updating message {MessageId}", id); - - message.OriginalId = originalId; - message.UserId = authorId; - message.System = systemId; - message.Member = memberId; - - db.Update(message); - await db.SaveChangesAsync(); - } - - public async Task IsMessageIgnoredAsync(ulong id, CancellationToken ct = default) - { - _logger.Debug("Checking if message {MessageId} is ignored", id); - return await db.IgnoredMessages.AsNoTracking().FirstOrDefaultAsync(m => m.Id == id, ct) - != null; - } - - public const int MaxMessageAgeDays = 15; - - public async Task<(int Messages, int IgnoredMessages)> DeleteExpiredMessagesAsync() - { - var cutoff = DateTimeOffset.UtcNow - TimeSpan.FromDays(MaxMessageAgeDays); - var cutoffId = Snowflake.CreateTimestampSnowflake(cutoff, Constants.DiscordEpoch); - - var msgCount = await db.Messages.Where(m => m.Id < cutoffId.Value).ExecuteDeleteAsync(); - var ignoredMsgCount = await db - .IgnoredMessages.Where(m => m.Id < cutoffId.Value) - .ExecuteDeleteAsync(); - - return (msgCount, ignoredMsgCount); - } - - public record Message( - ulong Id, - ulong? OriginalId, - ulong UserId, - ulong ChannelId, - ulong GuildId, - string? Member, - string? System, - string Username, - string Content, - Metadata? Metadata, - int AttachmentSize - ); - - public record Metadata(bool IsWebhook, IEnumerable Attachments); - - public record Attachment(string Filename, int Size, string ContentType); -} diff --git a/Catalogger.Backend/Extensions/StartupExtensions.cs b/Catalogger.Backend/Extensions/StartupExtensions.cs index e6577c3..a921874 100644 --- a/Catalogger.Backend/Extensions/StartupExtensions.cs +++ b/Catalogger.Backend/Extensions/StartupExtensions.cs @@ -106,8 +106,9 @@ public static class StartupExtensions services .AddSingleton(SystemClock.Instance) .AddDatabasePool() - .AddScoped() + .AddScoped() .AddScoped() + .AddScoped() .AddSingleton() .AddSingleton() .AddSingleton() @@ -118,7 +119,6 @@ public static class StartupExtensions .AddSingleton() .AddScoped() .AddSingleton() - // .AddScoped() .AddSingleton() .AddSingleton() .AddSingleton(InMemoryDataService.Instance) diff --git a/Catalogger.Backend/Services/BackgroundTasksService.cs b/Catalogger.Backend/Services/BackgroundTasksService.cs index 6efb6a2..936ed45 100644 --- a/Catalogger.Backend/Services/BackgroundTasksService.cs +++ b/Catalogger.Backend/Services/BackgroundTasksService.cs @@ -35,7 +35,7 @@ public class BackgroundTasksService(ILogger logger, IServiceProvider services) : await using var scope = services.CreateAsyncScope(); await using var messageRepository = - scope.ServiceProvider.GetRequiredService(); + scope.ServiceProvider.GetRequiredService(); var (msgCount, ignoredCount) = await messageRepository.DeleteExpiredMessagesAsync(); if (msgCount != 0 || ignoredCount != 0) @@ -44,7 +44,7 @@ public class BackgroundTasksService(ILogger logger, IServiceProvider services) : "Deleted {Count} messages and {IgnoredCount} ignored message IDs older than {MaxDays} days old", msgCount, ignoredCount, - DapperMessageRepository.MaxMessageAgeDays + MessageRepository.MaxMessageAgeDays ); } }