feat: exorcise entity framework core from most responders

This commit is contained in:
sam 2024-10-27 23:02:42 +01:00
parent 33b78a7ac5
commit 5891f28f7c
Signed by: sam
GPG key ID: 5F3C3C1B3166639D
32 changed files with 743 additions and 145 deletions

View file

@ -0,0 +1,96 @@
// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
using System.Data;
using System.Data.Common;
using System.Diagnostics.CodeAnalysis;
using Npgsql;
namespace Catalogger.Backend.Database.Dapper;
public class DatabaseConnection(Guid id, ILogger logger, NpgsqlConnection inner)
: DbConnection,
IDisposable
{
public Guid ConnectionId => id;
private readonly ILogger _logger = logger.ForContext<DatabaseConnection>();
private readonly DateTimeOffset _openTime = DateTimeOffset.UtcNow;
private bool _hasClosed;
public override async Task OpenAsync(CancellationToken cancellationToken) =>
await inner.OpenAsync(cancellationToken);
public override async Task CloseAsync()
{
if (_hasClosed)
{
await inner.CloseAsync();
return;
}
DatabasePool.DecrementConnections();
var openFor = DateTimeOffset.UtcNow - _openTime;
_logger.Debug("Closing connection {ConnId}, open for {OpenFor}", ConnectionId, openFor);
_hasClosed = true;
await inner.CloseAsync();
}
protected override async ValueTask<DbTransaction> BeginDbTransactionAsync(
IsolationLevel isolationLevel,
CancellationToken cancellationToken
)
{
_logger.Debug("Beginning transaction on connection {ConnId}", ConnectionId);
return await inner.BeginTransactionAsync(isolationLevel, cancellationToken);
}
public new void Dispose()
{
Close();
inner.Dispose();
GC.SuppressFinalize(this);
}
public override async ValueTask DisposeAsync()
{
await CloseAsync();
await inner.DisposeAsync();
GC.SuppressFinalize(this);
}
protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) =>
inner.BeginTransaction(isolationLevel);
public override void ChangeDatabase(string databaseName) => inner.ChangeDatabase(databaseName);
public override void Close() => inner.Close();
public override void Open() => inner.Open();
[AllowNull]
public override string ConnectionString
{
get => inner.ConnectionString;
set => inner.ConnectionString = value;
}
public override string Database => inner.Database;
public override ConnectionState State => inner.State;
public override string DataSource => inner.DataSource;
public override string ServerVersion => inner.ServerVersion;
protected override DbCommand CreateDbCommand() => inner.CreateCommand();
}

View file

@ -0,0 +1,157 @@
// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
using System.Data;
using Dapper;
using NodaTime;
using Npgsql;
namespace Catalogger.Backend.Database.Dapper;
public class DatabasePool
{
private readonly ILogger _rootLogger;
private readonly ILogger _logger;
private readonly NpgsqlDataSource _dataSource;
private static int _openConnections;
public static int OpenConnections => _openConnections;
public DatabasePool(Config config, ILogger logger, ILoggerFactory? loggerFactory)
{
_rootLogger = logger;
_logger = logger.ForContext<DatabasePool>();
var connString = new NpgsqlConnectionStringBuilder(config.Database.Url)
{
Timeout = config.Database.Timeout ?? 5,
MaxPoolSize = config.Database.MaxPoolSize ?? 50,
}.ConnectionString;
var dataSourceBuilder = new NpgsqlDataSourceBuilder(connString);
dataSourceBuilder.EnableDynamicJson().UseNodaTime();
if (config.Logging.LogQueries)
dataSourceBuilder.UseLoggerFactory(loggerFactory);
_dataSource = dataSourceBuilder.Build();
}
public async Task<DatabaseConnection> AcquireAsync(CancellationToken ct = default)
{
return new DatabaseConnection(
LogOpen(),
_rootLogger,
await _dataSource.OpenConnectionAsync(ct)
);
}
public DatabaseConnection Acquire()
{
return new DatabaseConnection(LogOpen(), _rootLogger, _dataSource.OpenConnection());
}
private Guid LogOpen()
{
var connId = Guid.NewGuid();
_logger.Debug("Opening database connection {ConnId}", connId);
IncrementConnections();
return connId;
}
public async Task ExecuteAsync(
Func<DatabaseConnection, Task> func,
CancellationToken ct = default
)
{
await using var conn = await AcquireAsync(ct);
await func(conn);
}
public async Task<T> ExecuteAsync<T>(
Func<DatabaseConnection, Task<T>> func,
CancellationToken ct = default
)
{
await using var conn = await AcquireAsync(ct);
return await func(conn);
}
public async Task<IAsyncEnumerable<T>> ExecuteAsync<T>(
Func<DatabaseConnection, Task<IAsyncEnumerable<T>>> func,
CancellationToken ct = default
)
{
await using var conn = await AcquireAsync(ct);
return await func(conn);
}
internal static void IncrementConnections() => Interlocked.Increment(ref _openConnections);
internal static void DecrementConnections() => Interlocked.Decrement(ref _openConnections);
/// <summary>
/// Configures Dapper's SQL mapping, as it handles several types incorrectly by default.
/// Most notably, ulongs and arrays of ulongs.
/// </summary>
public static void ConfigureDapper()
{
DefaultTypeMap.MatchNamesWithUnderscores = true;
SqlMapper.RemoveTypeMap(typeof(ulong));
SqlMapper.AddTypeHandler(new UlongEncodeAsLongHandler());
SqlMapper.AddTypeHandler(new UlongArrayHandler());
SqlMapper.AddTypeHandler(new PassthroughTypeHandler<Instant>());
}
// Copied from PluralKit:
// https://github.com/PluralKit/PluralKit/blob/4bf60a47d76a068fa0488bf9be96cdaf57a6fe50/PluralKit.Core/Database/Database.cs#L116
// Thanks for not working with common types by default, Dapper. Really nice of you.
private class PassthroughTypeHandler<T> : SqlMapper.TypeHandler<T>
{
public override void SetValue(IDbDataParameter parameter, T? value) =>
parameter.Value = value;
public override T Parse(object value) => (T)value;
}
private class UlongEncodeAsLongHandler : SqlMapper.TypeHandler<ulong>
{
public override ulong Parse(object value) =>
// Cast to long to unbox, then to ulong (???)
(ulong)(long)value;
public override void SetValue(IDbDataParameter parameter, ulong value) =>
parameter.Value = (long)value;
}
private class UlongArrayHandler : SqlMapper.TypeHandler<ulong[]>
{
public override void SetValue(IDbDataParameter parameter, ulong[]? value) =>
parameter.Value = value != null ? Array.ConvertAll(value, i => (long)i) : null;
public override ulong[] Parse(object value) =>
Array.ConvertAll((long[])value, i => (ulong)i);
}
}
public static class ServiceCollectionDatabaseExtensions
{
public static IServiceCollection AddDatabasePool(this IServiceCollection serviceCollection) =>
serviceCollection
.AddSingleton<DatabasePool>()
.AddScoped<DatabaseConnection>(services =>
services.GetRequiredService<DatabasePool>().Acquire()
);
}

View file

@ -0,0 +1,225 @@
// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
using System.Text.Json;
using Catalogger.Backend.Extensions;
using Dapper;
using Remora.Discord.API;
using Remora.Discord.API.Abstractions.Gateway.Events;
using Remora.Rest.Core;
namespace Catalogger.Backend.Database.Dapper.Repositories;
public class DapperMessageRepository(
ILogger logger,
DatabaseConnection conn,
IEncryptionService encryptionService
) : IDisposable, IAsyncDisposable
{
private readonly ILogger _logger = logger.ForContext<DapperMessageRepository>();
public async Task<Message?> GetMessageAsync(ulong id, CancellationToken ct = default)
{
_logger.Debug("Retrieving message {MessageId}", id);
var dbMsg = await conn.QueryFirstOrDefaultAsync<Models.Message>(
"select * from messages where id = @Id",
new { Id = id }
);
if (dbMsg == null)
return null;
return new Message(
dbMsg.Id,
dbMsg.OriginalId,
dbMsg.UserId,
dbMsg.ChannelId,
dbMsg.GuildId,
dbMsg.Member,
dbMsg.System,
Username: await Task.Run(() => encryptionService.Decrypt(dbMsg.Username), ct),
Content: await Task.Run(() => encryptionService.Decrypt(dbMsg.Content), ct),
Metadata: dbMsg.Metadata != null
? JsonSerializer.Deserialize<Metadata>(
await Task.Run(() => encryptionService.Decrypt(dbMsg.Metadata), ct)
)
: null,
dbMsg.AttachmentSize
);
}
/// <summary>
/// Adds a new message. If the message is already in the database, updates the existing message instead.
/// </summary>
public async Task<bool> SaveMessageAsync(IMessageCreate msg, CancellationToken ct = default)
{
var content = await Task.Run(
() =>
encryptionService.Encrypt(
string.IsNullOrWhiteSpace(msg.Content) ? "None" : msg.Content
),
ct
);
var username = await Task.Run(() => encryptionService.Encrypt(msg.Author.Tag()), ct);
var metadata = await Task.Run(
() =>
encryptionService.Encrypt(
JsonSerializer.Serialize(
new Metadata(
IsWebhook: msg.WebhookID.HasValue,
msg.Attachments.Select(a => new Attachment(
a.Filename,
a.Size,
a.ContentType.Value
))
)
)
),
ct
);
// MessageUpdateResponder wants to know whether the message already existed, so query this *before* inserting.
var exists = await conn.ExecuteScalarAsync<bool>(
"select exists(select id from messages where id = @Id)",
new { Id = msg.ID.Value }
);
await conn.ExecuteAsync(
"""
insert into messages (id, user_id, channel_id, guild_id, username, content, metadata, attachment_size)
values (@Id, @UserId, @ChannelId, @GuildId, @Username, @Content, @Metadata, @AttachmentSize)
on conflict (id) do update set username = @Username, content = @Content, metadata = @Metadata
""",
new
{
Id = msg.ID.Value,
UserId = msg.Author.ID.Value,
ChannelId = msg.ChannelID.Value,
GuildId = msg.GuildID.Map(s => s.Value).OrDefault(),
Content = content,
Username = username,
Metadata = metadata,
AttachmentSize = msg.Attachments.Select(a => a.Size).Sum(),
}
);
return exists;
}
public async Task<(bool IsStored, bool HasProxyInfo)> HasProxyInfoAsync(ulong id)
{
_logger.Debug("Checking if message {MessageId} has proxy information", id);
var msg = await conn.QueryFirstOrDefaultAsync<(ulong Id, ulong OriginalId)>(
"select id, original_id from messages where id = @Id",
new { Id = id }
);
return (msg.Id != 0, msg.OriginalId != 0);
}
/// <summary>
/// Updates a stored message with PluralKit information.
/// </summary>
/// <returns>True if the message exists and was updated, false if it doesn't exist.</returns>
public async Task<bool> SetProxiedMessageDataAsync(
ulong id,
ulong originalId,
ulong authorId,
string? systemId,
string? memberId
)
{
_logger.Debug("Setting proxy information for message {MessageId}", id);
var updatedCount = await conn.ExecuteAsync(
"update messages set original_id = @OriginalId, user_id = @AuthorId, system = @SystemId, member = @MemberId where id = @Id",
new
{
Id = id,
OriginalId = originalId,
AuthorId = authorId,
SystemId = systemId,
MemberId = memberId,
}
);
if (updatedCount == 0)
{
_logger.Debug("Message {MessageId} not found, can't set proxy data for it", id);
return false;
}
return true;
}
public async Task<bool> IsMessageIgnoredAsync(ulong id) =>
await conn.ExecuteScalarAsync<bool>(
"select exists(select id from messages where id = @Id)",
new { Id = id }
);
public const int MaxMessageAgeDays = 15;
public async Task<(int Messages, int IgnoredMessages)> DeleteExpiredMessagesAsync()
{
var cutoff = DateTimeOffset.UtcNow - TimeSpan.FromDays(MaxMessageAgeDays);
var cutoffId = Snowflake.CreateTimestampSnowflake(cutoff, Constants.DiscordEpoch).Value;
var msgCount = await conn.ExecuteAsync(
"delete from messages where id < @Cutoff",
new { Cutoff = cutoffId }
);
var ignoredMsgCount = await conn.ExecuteAsync(
"delete from ignored_messages where id < @Cutoff",
new { Cutoff = cutoffId }
);
return (msgCount, ignoredMsgCount);
}
public async Task IgnoreMessageAsync(ulong id) =>
await conn.ExecuteAsync(
"insert into ignored_messages (id) values (@Id) on conflict do nothing",
new { Id = id }
);
public record Message(
ulong Id,
ulong? OriginalId,
ulong UserId,
ulong ChannelId,
ulong GuildId,
string? Member,
string? System,
string Username,
string Content,
Metadata? Metadata,
int AttachmentSize
);
public record Metadata(bool IsWebhook, IEnumerable<Attachment> Attachments);
public record Attachment(string Filename, int Size, string ContentType);
public void Dispose()
{
conn.Dispose();
GC.SuppressFinalize(this);
}
public async ValueTask DisposeAsync()
{
await conn.DisposeAsync();
GC.SuppressFinalize(this);
}
}

View file

@ -0,0 +1,89 @@
// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
using Catalogger.Backend.Database.Models;
using Dapper;
using Remora.Rest.Core;
namespace Catalogger.Backend.Database.Dapper.Repositories;
public class GuildRepository(ILogger logger, DatabaseConnection conn)
: IDisposable,
IAsyncDisposable
{
private readonly ILogger _logger = logger.ForContext<GuildRepository>();
public async Task<Guild> GetAsync(Optional<Snowflake> id) => await GetAsync(id.Value.Value);
public async Task<Guild> GetAsync(Snowflake id) => await GetAsync(id.Value);
public async Task<Guild> GetAsync(ulong id)
{
_logger.Debug("Getting guild config for {GuildId}", id);
var guild = await conn.QueryFirstOrDefaultAsync<Guild>(
"select * from guilds where id = @Id",
new { Id = id }
);
if (guild == null)
throw new CataloggerError("Guild not found, was not initialized during guild create");
return guild;
}
public async Task<bool> IsGuildKnown(ulong id) =>
await conn.ExecuteScalarAsync<bool>(
"select exists(select id from guilds where id = @Id)",
new { Id = id }
);
public async Task AddGuildAsync(ulong id) =>
await conn.ExecuteAsync(
"""
insert into guilds (id, key_roles, banned_systems, key_roles, channels)
values (@Id, array[]::bigint[], array[]::text[], array[]::bigint[], @Channels)
on conflict do nothing
""",
new { Id = id, Channels = new Guild.ChannelConfig() }
);
public async Task BanSystemAsync(Snowflake guildId, string hid, Guid uuid) =>
await conn.ExecuteAsync(
"update guilds set banned_systems = array_cat(banned_systems, @SystemIds) where id = @GuildId",
new { GuildId = guildId.Value, SystemIds = (string[])[hid, uuid.ToString()] }
);
public async Task UnbanSystemAsync(Snowflake guildId, string hid, Guid uuid) =>
await conn.ExecuteAsync(
"update guilds set banned_systems = array_remove(array_remove(banned_systems, @Hid), @Uuid) where id = @Id",
new
{
GuildId = guildId.Value,
Hid = hid,
Uuid = uuid.ToString(),
}
);
public void Dispose()
{
conn.Dispose();
GC.SuppressFinalize(this);
}
public async ValueTask DisposeAsync()
{
await conn.DisposeAsync();
GC.SuppressFinalize(this);
}
}

View file

@ -30,14 +30,11 @@ public class Message
public string? Member { get; set; }
public string? System { get; set; }
[Column("username")]
public byte[] EncryptedUsername { get; set; } = [];
public byte[] Username { get; set; } = [];
[Column("content")]
public byte[] EncryptedContent { get; set; } = [];
public byte[] Content { get; set; } = [];
[Column("metadata")]
public byte[]? EncryptedMetadata { get; set; }
public byte[]? Metadata { get; set; }
public int AttachmentSize { get; set; } = 0;
}

View file

@ -47,18 +47,15 @@ public class MessageRepository(
ChannelId = msg.ChannelID.ToUlong(),
GuildId = msg.GuildID.ToUlong(),
EncryptedContent = await Task.Run(
Content = await Task.Run(
() =>
encryptionService.Encrypt(
string.IsNullOrWhiteSpace(msg.Content) ? "None" : msg.Content
),
ct
),
EncryptedUsername = await Task.Run(
() => encryptionService.Encrypt(msg.Author.Tag()),
ct
),
EncryptedMetadata = await Task.Run(
Username = await Task.Run(() => encryptionService.Encrypt(msg.Author.Tag()), ct),
Metadata = await Task.Run(
() => encryptionService.Encrypt(JsonSerializer.Serialize(metadata)),
ct
),
@ -103,18 +100,15 @@ public class MessageRepository(
"Message was null despite HasProxyInfoAsync returning true"
);
dbMsg.EncryptedContent = await Task.Run(
dbMsg.Content = await Task.Run(
() =>
encryptionService.Encrypt(
string.IsNullOrWhiteSpace(msg.Content) ? "None" : msg.Content
),
ct
);
dbMsg.EncryptedUsername = await Task.Run(
() => encryptionService.Encrypt(msg.Author.Tag()),
ct
);
dbMsg.EncryptedMetadata = await Task.Run(
dbMsg.Username = await Task.Run(() => encryptionService.Encrypt(msg.Author.Tag()), ct);
dbMsg.Metadata = await Task.Run(
() => encryptionService.Encrypt(JsonSerializer.Serialize(metadata)),
ct
);
@ -142,11 +136,11 @@ public class MessageRepository(
dbMsg.GuildId,
dbMsg.Member,
dbMsg.System,
Username: await Task.Run(() => encryptionService.Decrypt(dbMsg.EncryptedUsername), ct),
Content: await Task.Run(() => encryptionService.Decrypt(dbMsg.EncryptedContent), ct),
Metadata: dbMsg.EncryptedMetadata != null
Username: await Task.Run(() => encryptionService.Decrypt(dbMsg.Username), ct),
Content: await Task.Run(() => encryptionService.Decrypt(dbMsg.Content), ct),
Metadata: dbMsg.Metadata != null
? JsonSerializer.Deserialize<Metadata>(
await Task.Run(() => encryptionService.Decrypt(dbMsg.EncryptedMetadata), ct)
await Task.Run(() => encryptionService.Decrypt(dbMsg.Metadata), ct)
)
: null,
dbMsg.AttachmentSize