add invite repository to replace ef core

This commit is contained in:
sam 2024-10-27 23:30:02 +01:00
parent 5891f28f7c
commit 64b4c26d93
Signed by: sam
GPG key ID: 5F3C3C1B3166639D
12 changed files with 112 additions and 301 deletions

View file

@ -15,12 +15,10 @@
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
using Catalogger.Backend.Cache;
using Catalogger.Backend.Cache.InMemoryCache;
using Catalogger.Backend.Database;
using Catalogger.Backend.Database.Dapper.Repositories;
using Catalogger.Backend.Extensions;
using Microsoft.EntityFrameworkCore;
using Remora.Commands.Attributes;
using Remora.Commands.Groups;
using Remora.Discord.API.Abstractions.Objects;
@ -33,7 +31,6 @@ using Remora.Discord.Commands.Feedback.Services;
using Remora.Discord.Commands.Services;
using Remora.Discord.Pagination.Extensions;
using Remora.Rest.Core;
using Invite = Catalogger.Backend.Database.Models.Invite;
using IResult = Remora.Results.IResult;
namespace Catalogger.Backend.Bot.Commands;
@ -43,7 +40,7 @@ namespace Catalogger.Backend.Bot.Commands;
[DiscordDefaultMemberPermissions(DiscordPermission.ManageGuild)]
public class InviteCommands(
ILogger logger,
DatabaseContext db,
InviteRepository inviteRepository,
GuildCache guildCache,
IInviteCache inviteCache,
IDiscordRestChannelAPI channelApi,
@ -63,7 +60,7 @@ public class InviteCommands(
if (!guildCache.TryGet(guildId, out var guild))
throw new CataloggerError("Guild not in cache");
var dbInvites = await db.Invites.Where(i => i.GuildId == guildId.Value).ToListAsync();
var dbInvites = await inviteRepository.GetGuildInvitesAsync(guildId);
var fields = guildInvites
.Select(i => new PartialNamedInvite(
@ -153,14 +150,7 @@ public class InviteCommands(
+ $"\nLink: https://discord.gg/{inviteResult.Entity.Code}"
);
var dbInvite = new Invite
{
GuildId = guildId.Value,
Code = inviteResult.Entity.Code,
Name = name,
};
db.Add(dbInvite);
await db.SaveChangesAsync();
await inviteRepository.SetInviteNameAsync(guildId, inviteResult.Entity.Code, name);
return await feedbackService.ReplyAsync(
$"Created a new invite in <#{channel.ID}> with the name **{name}**!"
@ -188,39 +178,18 @@ public class InviteCommands(
);
}
var namedInvite = await db
.Invites.Where(i => i.GuildId == guildId.Value && i.Code == invite)
.FirstOrDefaultAsync();
if (namedInvite == null)
{
if (name == null)
return await feedbackService.ReplyAsync($"Invite `{invite}` already has no name.");
namedInvite = new Invite
{
GuildId = guildId.Value,
Code = invite,
Name = name,
};
db.Add(namedInvite);
await db.SaveChangesAsync();
return await feedbackService.ReplyAsync(
$"New name set! The invite `{invite}` will now show up as **{name}** in logs."
);
}
var namedInvite = await inviteRepository.GetInviteAsync(guildId, invite);
if (namedInvite == null && name == null)
return await feedbackService.ReplyAsync($"Invite `{invite}` already has no name.");
if (name == null)
{
db.Invites.Remove(namedInvite);
await db.SaveChangesAsync();
await inviteRepository.DeleteInviteAsync(guildId, invite);
return await feedbackService.ReplyAsync($"Removed the name for `{invite}`.");
}
namedInvite.Name = name;
db.Update(namedInvite);
await db.SaveChangesAsync();
await inviteRepository.SetInviteNameAsync(guildId, invite, name);
return await feedbackService.ReplyAsync(
$"New name set! The invite `{invite}` will now show up as **{name}** in logs."
@ -230,7 +199,7 @@ public class InviteCommands(
public class InviteAutocompleteProvider(
ILogger logger,
DatabaseContext db,
InviteRepository inviteRepository,
IInviteCache inviteCache,
ContextInjectionService contextInjection
) : IAutocompleteProvider
@ -262,13 +231,13 @@ public class InviteAutocompleteProvider(
return [];
}
var namedInvites = await db
.Invites.Where(i =>
i.GuildId == guildId.Value && i.Name.ToLower().StartsWith(userInput.ToLower())
)
// We're filtering and ordering on the client side because a guild won't have infinite invites
// (the maximum on Discord's end is 1500-ish)
// and this way we don't need an index on (guild_id, name) for this *one* case.
var namedInvites = (await inviteRepository.GetGuildInvitesAsync(guildId))
.Where(i => i.Name.StartsWith(userInput, StringComparison.InvariantCultureIgnoreCase))
.OrderBy(i => i.Name)
.Take(25)
.ToListAsync(ct);
.ToList();
if (namedInvites.Count != 0)
{

View file

@ -32,7 +32,7 @@ namespace Catalogger.Backend.Bot.Responders.Invites;
public class InviteDeleteResponder(
ILogger logger,
GuildRepository guildRepository,
DatabaseContext db,
InviteRepository inviteRepository,
IInviteCache inviteCache,
WebhookExecutorService webhookExecutor,
IDiscordRestGuildAPI guildApi
@ -44,9 +44,7 @@ public class InviteDeleteResponder(
{
var guildId = evt.GuildID.Value;
var dbDeleteCount = await db
.Invites.Where(i => i.GuildId == guildId.Value && i.Code == evt.Code)
.ExecuteDeleteAsync(ct);
var dbDeleteCount = await inviteRepository.DeleteInviteAsync(guildId, evt.Code);
if (dbDeleteCount != 0)
_logger.Information(
"Deleted named invite {Invite} for guild {Guild}",

View file

@ -36,6 +36,7 @@ namespace Catalogger.Backend.Bot.Responders.Members;
public class GuildMemberAddResponder(
ILogger logger,
DatabaseContext db,
InviteRepository inviteRepository,
GuildRepository guildRepository,
IMemberCache memberCache,
IInviteCache inviteCache,
@ -128,11 +129,7 @@ public class GuildMemberAddResponder(
goto afterInvite;
}
var inviteName =
await db
.Invites.Where(i => i.Code == usedInvite.Code && i.GuildId == member.GuildID.Value)
.Select(i => i.Name)
.FirstOrDefaultAsync(ct) ?? "*(unnamed)*";
var inviteName = inviteRepository.GetInviteNameAsync(member.GuildID, usedInvite.Code);
var inviteDescription = $"""
**Code:** {usedInvite.Code}

View file

@ -29,7 +29,7 @@ public class MessageCreateResponder(
ILogger logger,
Config config,
GuildRepository guildRepository,
DapperMessageRepository messageRepository,
MessageRepository messageRepository,
UserCache userCache,
PkMessageHandler pkMessageHandler
) : IResponder<IMessageCreate>
@ -146,7 +146,7 @@ public partial class PkMessageHandler(ILogger logger, IServiceProvider services)
await using var scope = services.CreateAsyncScope();
await using var messageRepository =
scope.ServiceProvider.GetRequiredService<DapperMessageRepository>();
scope.ServiceProvider.GetRequiredService<MessageRepository>();
await Task.WhenAll(
messageRepository.SetProxiedMessageDataAsync(
@ -166,7 +166,7 @@ public partial class PkMessageHandler(ILogger logger, IServiceProvider services)
await using var scope = services.CreateAsyncScope();
await using var messageRepository =
scope.ServiceProvider.GetRequiredService<DapperMessageRepository>();
scope.ServiceProvider.GetRequiredService<MessageRepository>();
var pluralkitApi = scope.ServiceProvider.GetRequiredService<PluralkitApiService>();
var (isStored, hasProxyInfo) = await messageRepository.HasProxyInfoAsync(msgId);

View file

@ -32,7 +32,7 @@ namespace Catalogger.Backend.Bot.Responders.Messages;
public class MessageDeleteBulkResponder(
ILogger logger,
GuildRepository guildRepository,
DapperMessageRepository messageRepository,
MessageRepository messageRepository,
WebhookExecutorService webhookExecutor,
ChannelCache channelCache
) : IResponder<IMessageDeleteBulk>
@ -128,7 +128,7 @@ public class MessageDeleteBulkResponder(
return Result.Success;
}
private string RenderMessage(Snowflake messageId, DapperMessageRepository.Message? message)
private string RenderMessage(Snowflake messageId, MessageRepository.Message? message)
{
var timestamp = messageId.Timestamp.ToOffsetDateTime().ToString();

View file

@ -33,7 +33,7 @@ namespace Catalogger.Backend.Bot.Responders.Messages;
public class MessageDeleteResponder(
ILogger logger,
GuildRepository guildRepository,
DapperMessageRepository messageRepository,
MessageRepository messageRepository,
WebhookExecutorService webhookExecutor,
ChannelCache channelCache,
UserCache userCache,

View file

@ -35,7 +35,7 @@ public class MessageUpdateResponder(
DatabaseContext db,
ChannelCache channelCache,
UserCache userCache,
DapperMessageRepository messageRepository,
MessageRepository messageRepository,
WebhookExecutorService webhookExecutor,
PluralkitApiService pluralkitApi
) : IResponder<IMessageUpdate>

View file

@ -0,0 +1,79 @@
// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
using Catalogger.Backend.Database.Models;
using Dapper;
using Remora.Rest.Core;
namespace Catalogger.Backend.Database.Dapper.Repositories;
public class InviteRepository(ILogger logger, DatabaseConnection conn)
: IDisposable,
IAsyncDisposable
{
private readonly ILogger _logger = logger.ForContext<InviteRepository>();
public async Task<List<Invite>> GetGuildInvitesAsync(Snowflake guildId) =>
(
await conn.QueryAsync<Invite>(
"select * from invites where guild_id = @GuildId",
new { GuildId = guildId.Value }
)
).ToList();
public async Task SetInviteNameAsync(Snowflake guildId, string code, string name) =>
await conn.ExecuteAsync(
"""
insert into invites (code, guild_id, name) values
(@Code, @GuildId, @Name) on conflict (code, guild_id) do update set name = @Name
""",
new
{
Code = code,
GuildId = guildId.Value,
Name = name,
}
);
public async Task<Invite?> GetInviteAsync(Snowflake guildId, string code) =>
await conn.QueryFirstOrDefaultAsync<Invite>(
"select * from invites where guild_id = @GuildId and code = @Code",
new { GuildId = guildId.Value, Code = code }
);
public async Task<string> GetInviteNameAsync(Snowflake guildId, string code) =>
await conn.ExecuteScalarAsync<string>(
"select name from invites where guild_id = @GuildId and code = @Code",
new { GuildId = guildId.Value, Code = code }
) ?? "(unnamed)";
public async Task<int> DeleteInviteAsync(Snowflake guildId, string code) =>
await conn.ExecuteAsync(
"delete from invites where guild_id = @GuildId and code = @Code",
new { GuildId = guildId.Value, Code = code }
);
public void Dispose()
{
conn.Dispose();
GC.SuppressFinalize(this);
}
public async ValueTask DisposeAsync()
{
await conn.DisposeAsync();
GC.SuppressFinalize(this);
}
}

View file

@ -22,13 +22,13 @@ using Remora.Rest.Core;
namespace Catalogger.Backend.Database.Dapper.Repositories;
public class DapperMessageRepository(
public class MessageRepository(
ILogger logger,
DatabaseConnection conn,
IEncryptionService encryptionService
) : IDisposable, IAsyncDisposable
{
private readonly ILogger _logger = logger.ForContext<DapperMessageRepository>();
private readonly ILogger _logger = logger.ForContext<MessageRepository>();
public async Task<Message?> GetMessageAsync(ulong id, CancellationToken ct = default)
{

View file

@ -1,232 +0,0 @@
// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
using System.Text.Json;
using Catalogger.Backend.Extensions;
using Microsoft.EntityFrameworkCore;
using Remora.Discord.API;
using Remora.Discord.API.Abstractions.Gateway.Events;
using Remora.Rest.Core;
using DbMessage = Catalogger.Backend.Database.Models.Message;
namespace Catalogger.Backend.Database.Queries;
public class MessageRepository(
ILogger logger,
DatabaseContext db,
IEncryptionService encryptionService
)
{
private readonly ILogger _logger = logger.ForContext<MessageRepository>();
public async Task SaveMessageAsync(IMessageCreate msg, CancellationToken ct = default)
{
_logger.Debug("Saving message {MessageId}", msg.ID);
var metadata = new Metadata(
IsWebhook: msg.WebhookID.HasValue,
msg.Attachments.Select(a => new Attachment(a.Filename, a.Size, a.ContentType.Value))
);
var dbMessage = new DbMessage
{
Id = msg.ID.ToUlong(),
UserId = msg.Author.ID.ToUlong(),
ChannelId = msg.ChannelID.ToUlong(),
GuildId = msg.GuildID.ToUlong(),
Content = await Task.Run(
() =>
encryptionService.Encrypt(
string.IsNullOrWhiteSpace(msg.Content) ? "None" : msg.Content
),
ct
),
Username = await Task.Run(() => encryptionService.Encrypt(msg.Author.Tag()), ct),
Metadata = await Task.Run(
() => encryptionService.Encrypt(JsonSerializer.Serialize(metadata)),
ct
),
AttachmentSize = msg.Attachments.Select(a => a.Size).Sum(),
};
db.Add(dbMessage);
await db.SaveChangesAsync(ct);
}
/// <summary>
/// Updates an edited message.
/// </summary>
/// <returns>true if the message was already stored and got updated,
/// false if the message wasn't stored and was newly inserted.</returns>
public async Task<bool> UpdateMessageAsync(IMessageCreate msg, CancellationToken ct = default)
{
_logger.Debug("Updating message {MessageId}", msg.ID);
var tx = await db.Database.BeginTransactionAsync(ct);
var (isStored, _) = await HasProxyInfoAsync(msg.ID.Value);
if (!isStored)
{
_logger.Debug("Edited message {MessageId} is not stored yet, storing it", msg.ID);
await SaveMessageAsync(msg, ct);
await tx.CommitAsync(ct);
return false;
}
else
{
var metadata = new Metadata(
IsWebhook: msg.WebhookID.HasValue,
msg.Attachments.Select(a => new Attachment(a.Filename, a.Size, a.ContentType.Value))
);
var dbMsg = await db.Messages.FindAsync(
new object?[] { msg.ID.Value },
cancellationToken: ct
);
if (dbMsg == null)
throw new CataloggerError(
"Message was null despite HasProxyInfoAsync returning true"
);
dbMsg.Content = await Task.Run(
() =>
encryptionService.Encrypt(
string.IsNullOrWhiteSpace(msg.Content) ? "None" : msg.Content
),
ct
);
dbMsg.Username = await Task.Run(() => encryptionService.Encrypt(msg.Author.Tag()), ct);
dbMsg.Metadata = await Task.Run(
() => encryptionService.Encrypt(JsonSerializer.Serialize(metadata)),
ct
);
db.Update(dbMsg);
await db.SaveChangesAsync(ct);
await tx.CommitAsync(ct);
return true;
}
}
public async Task<Message?> GetMessageAsync(ulong id, CancellationToken ct = default)
{
_logger.Debug("Retrieving message {MessageId}", id);
var dbMsg = await db.Messages.AsNoTracking().FirstOrDefaultAsync(m => m.Id == id, ct);
if (dbMsg == null)
return null;
return new Message(
dbMsg.Id,
dbMsg.OriginalId,
dbMsg.UserId,
dbMsg.ChannelId,
dbMsg.GuildId,
dbMsg.Member,
dbMsg.System,
Username: await Task.Run(() => encryptionService.Decrypt(dbMsg.Username), ct),
Content: await Task.Run(() => encryptionService.Decrypt(dbMsg.Content), ct),
Metadata: dbMsg.Metadata != null
? JsonSerializer.Deserialize<Metadata>(
await Task.Run(() => encryptionService.Decrypt(dbMsg.Metadata), ct)
)
: null,
dbMsg.AttachmentSize
);
}
/// <summary>
/// Checks if a message has proxy information.
/// If yes, returns (true, true). If no, returns (true, false). If the message isn't saved at all, returns (false, false).
/// </summary>
public async Task<(bool, bool)> HasProxyInfoAsync(ulong id)
{
_logger.Debug("Checking if message {MessageId} has proxy information", id);
var msg = await db
.Messages.AsNoTracking()
.Select(m => new { m.Id, m.OriginalId })
.FirstOrDefaultAsync(m => m.Id == id);
return (msg != null, msg?.OriginalId != null);
}
public async Task SetProxiedMessageDataAsync(
ulong id,
ulong originalId,
ulong authorId,
string? systemId,
string? memberId
)
{
_logger.Debug("Setting proxy information for message {MessageId}", id);
var message = await db.Messages.FirstOrDefaultAsync(m => m.Id == id);
if (message == null)
{
_logger.Debug("Message {MessageId} not found", id);
return;
}
_logger.Debug("Updating message {MessageId}", id);
message.OriginalId = originalId;
message.UserId = authorId;
message.System = systemId;
message.Member = memberId;
db.Update(message);
await db.SaveChangesAsync();
}
public async Task<bool> IsMessageIgnoredAsync(ulong id, CancellationToken ct = default)
{
_logger.Debug("Checking if message {MessageId} is ignored", id);
return await db.IgnoredMessages.AsNoTracking().FirstOrDefaultAsync(m => m.Id == id, ct)
!= null;
}
public const int MaxMessageAgeDays = 15;
public async Task<(int Messages, int IgnoredMessages)> DeleteExpiredMessagesAsync()
{
var cutoff = DateTimeOffset.UtcNow - TimeSpan.FromDays(MaxMessageAgeDays);
var cutoffId = Snowflake.CreateTimestampSnowflake(cutoff, Constants.DiscordEpoch);
var msgCount = await db.Messages.Where(m => m.Id < cutoffId.Value).ExecuteDeleteAsync();
var ignoredMsgCount = await db
.IgnoredMessages.Where(m => m.Id < cutoffId.Value)
.ExecuteDeleteAsync();
return (msgCount, ignoredMsgCount);
}
public record Message(
ulong Id,
ulong? OriginalId,
ulong UserId,
ulong ChannelId,
ulong GuildId,
string? Member,
string? System,
string Username,
string Content,
Metadata? Metadata,
int AttachmentSize
);
public record Metadata(bool IsWebhook, IEnumerable<Attachment> Attachments);
public record Attachment(string Filename, int Size, string ContentType);
}

View file

@ -106,8 +106,9 @@ public static class StartupExtensions
services
.AddSingleton<IClock>(SystemClock.Instance)
.AddDatabasePool()
.AddScoped<DapperMessageRepository>()
.AddScoped<MessageRepository>()
.AddScoped<GuildRepository>()
.AddScoped<InviteRepository>()
.AddSingleton<GuildCache>()
.AddSingleton<RoleCache>()
.AddSingleton<ChannelCache>()
@ -118,7 +119,6 @@ public static class StartupExtensions
.AddSingleton<NewsService>()
.AddScoped<IEncryptionService, EncryptionService>()
.AddSingleton<MetricsCollectionService>()
// .AddScoped<MessageRepository>()
.AddSingleton<WebhookExecutorService>()
.AddSingleton<PkMessageHandler>()
.AddSingleton(InMemoryDataService<Snowflake, ChannelCommandData>.Instance)

View file

@ -35,7 +35,7 @@ public class BackgroundTasksService(ILogger logger, IServiceProvider services) :
await using var scope = services.CreateAsyncScope();
await using var messageRepository =
scope.ServiceProvider.GetRequiredService<DapperMessageRepository>();
scope.ServiceProvider.GetRequiredService<MessageRepository>();
var (msgCount, ignoredCount) = await messageRepository.DeleteExpiredMessagesAsync();
if (msgCount != 0 || ignoredCount != 0)
@ -44,7 +44,7 @@ public class BackgroundTasksService(ILogger logger, IServiceProvider services) :
"Deleted {Count} messages and {IgnoredCount} ignored message IDs older than {MaxDays} days old",
msgCount,
ignoredCount,
DapperMessageRepository.MaxMessageAgeDays
MessageRepository.MaxMessageAgeDays
);
}
}