Catalogger.NET/Catalogger.Backend/Database/DatabasePool.cs

175 lines
5.9 KiB
C#

// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
using System.Data;
using System.Text.Json;
using Catalogger.Backend.Database.Models;
using Dapper;
using NodaTime;
using Npgsql;
namespace Catalogger.Backend.Database;
/// <summary>
/// Owns the Npgsql data source used by the backend and hands out <see cref="DatabaseConnection"/>
/// wrappers around connections opened from it. Also keeps a process-wide count of connections
/// acquired through this class.
/// </summary>
public class DatabasePool
{
    private readonly ILogger _rootLogger;
    private readonly ILogger _logger;
    private readonly NpgsqlDataSource _dataSource;

    // Static, so shared by every DatabasePool instance; only modified via Interlocked below.
    private static int _openConnections;

    /// <summary>
    /// Number of connections currently counted as open. Incremented each time a connection is
    /// successfully acquired through this class. The matching decrement is presumably performed by
    /// <see cref="DatabaseConnection"/> when it is disposed; that code is not part of this file.
    /// </summary>
    public static int OpenConnections => _openConnections;

    /// <summary>
    /// Builds the underlying <see cref="NpgsqlDataSource"/> from the configured database URL.
    /// </summary>
    /// <param name="config">Application config; supplies the connection URL, timeout, pool size and query-logging flag.</param>
    /// <param name="logger">Root logger, passed on to every <see cref="DatabaseConnection"/>.</param>
    /// <param name="loggerFactory">Factory handed to Npgsql only when query logging is enabled.</param>
    public DatabasePool(Config config, ILogger logger, ILoggerFactory? loggerFactory)
    {
        _rootLogger = logger;
        _logger = logger.ForContext<DatabasePool>();

        var connString = new NpgsqlConnectionStringBuilder(config.Database.Url)
        {
            // Npgsql's Timeout is the connection-establishment timeout, in seconds.
            Timeout = config.Database.Timeout ?? 5,
            MaxPoolSize = config.Database.MaxPoolSize ?? 50,
        }.ConnectionString;

        var dataSourceBuilder = new NpgsqlDataSourceBuilder(connString);
        dataSourceBuilder.EnableDynamicJson().UseNodaTime();
        if (config.Logging.LogQueries)
        {
            dataSourceBuilder.UseLoggerFactory(loggerFactory);
        }

        _dataSource = dataSourceBuilder.Build();
    }

    /// <summary>
    /// Opens a connection from the data source and wraps it in a <see cref="DatabaseConnection"/>.
    /// </summary>
    /// <param name="ct">Cancels the open attempt.</param>
    /// <returns>A wrapper the caller is responsible for disposing.</returns>
    public async Task<DatabaseConnection> AcquireAsync(CancellationToken ct = default)
    {
        var connId = LogOpen();
        NpgsqlConnection inner;
        try
        {
            inner = await _dataSource.OpenConnectionAsync(ct);
        }
        catch
        {
            // LogOpen already counted this connection, but no DatabaseConnection will exist to
            // undo that, so roll the counter back before propagating the failure.
            DecrementConnections();
            throw;
        }

        return new DatabaseConnection(connId, _rootLogger, inner);
    }

    /// <summary>
    /// Synchronous counterpart of <see cref="AcquireAsync"/>; blocks while the connection opens.
    /// </summary>
    /// <returns>A wrapper the caller is responsible for disposing.</returns>
    public DatabaseConnection Acquire()
    {
        var connId = LogOpen();
        NpgsqlConnection inner;
        try
        {
            inner = _dataSource.OpenConnection();
        }
        catch
        {
            // See AcquireAsync: undo LogOpen's increment when the open attempt fails.
            DecrementConnections();
            throw;
        }

        return new DatabaseConnection(connId, _rootLogger, inner);
    }

    // Generates an id used to correlate log lines for one connection, logs it, and bumps the counter.
    private Guid LogOpen()
    {
        var connId = Guid.NewGuid();
        _logger.Verbose("Opening database connection {ConnId}", connId);
        IncrementConnections();
        return connId;
    }

    /// <summary>
    /// Runs <paramref name="func"/> with a freshly acquired connection, disposing it afterwards.
    /// </summary>
    public async Task ExecuteAsync(
        Func<DatabaseConnection, Task> func,
        CancellationToken ct = default
    )
    {
        await using var conn = await AcquireAsync(ct);
        await func(conn);
    }

    /// <summary>
    /// Runs <paramref name="func"/> with a freshly acquired connection and returns its result,
    /// disposing the connection afterwards.
    /// </summary>
    public async Task<T> ExecuteAsync<T>(
        Func<DatabaseConnection, Task<T>> func,
        CancellationToken ct = default
    )
    {
        await using var conn = await AcquireAsync(ct);
        return await func(conn);
    }

    /// <summary>
    /// Runs <paramref name="func"/> with a freshly acquired connection and returns the items it produced.
    /// </summary>
    /// <remarks>
    /// The sequence returned by <paramref name="func"/> may be lazy (for example, streamed from the
    /// open connection). Returning it directly would let callers enumerate it after the connection
    /// had already been disposed. It is therefore drained into memory while the connection is still
    /// open, and the buffered items are returned.
    /// </remarks>
    public async Task<IAsyncEnumerable<T>> ExecuteAsync<T>(
        Func<DatabaseConnection, Task<IAsyncEnumerable<T>>> func,
        CancellationToken ct = default
    )
    {
        var buffered = new List<T>();
        await using (var conn = await AcquireAsync(ct))
        {
            var items = await func(conn);
            await foreach (var item in items.WithCancellation(ct))
            {
                buffered.Add(item);
            }
        }

        return EnumerateBuffered(buffered);
    }

    // Exposes an already-materialized list through the IAsyncEnumerable interface callers expect.
    private static async IAsyncEnumerable<T> EnumerateBuffered<T>(IReadOnlyList<T> items)
    {
        foreach (var item in items)
        {
            yield return item;
        }

        // No real asynchronous work happens here; this await only keeps the compiler from
        // warning about an async method without awaits.
        await Task.CompletedTask;
    }

    internal static void IncrementConnections() => Interlocked.Increment(ref _openConnections);

    internal static void DecrementConnections() => Interlocked.Decrement(ref _openConnections);

    /// <summary>
    /// Configures Dapper's SQL mapping, as it handles several types incorrectly by default.
    /// Most notably, ulongs and arrays of ulongs.
    /// </summary>
    public static void ConfigureDapper()
    {
        DefaultTypeMap.MatchNamesWithUnderscores = true;
        SqlMapper.RemoveTypeMap(typeof(ulong));
        SqlMapper.AddTypeHandler(new UlongEncodeAsLongHandler());
        SqlMapper.AddTypeHandler(new UlongArrayHandler());
        SqlMapper.AddTypeHandler(new PassthroughTypeHandler<Instant>());
        SqlMapper.AddTypeHandler(new JsonTypeHandler<Guild.ChannelConfig>());
    }

    // Copied from PluralKit:
    // https://github.com/PluralKit/PluralKit/blob/4bf60a47d76a068fa0488bf9be96cdaf57a6fe50/PluralKit.Core/Database/Database.cs#L116
    // Hands the value straight to the driver and casts it straight back, so types the driver
    // already understands (registered above for NodaTime's Instant) bypass Dapper's own mapping.
    private class PassthroughTypeHandler<T> : SqlMapper.TypeHandler<T>
    {
        public override void SetValue(IDbDataParameter parameter, T? value) =>
            parameter.Value = value;

        public override T Parse(object value) => (T)value;
    }

    // Stores ulong values in a signed 64-bit column by reinterpreting the bits.
    private class UlongEncodeAsLongHandler : SqlMapper.TypeHandler<ulong>
    {
        // A boxed long must be unboxed as exactly long before it can be converted to ulong.
        // unchecked makes the bit reinterpretation independent of the project's overflow-check setting.
        public override ulong Parse(object value) => unchecked((ulong)(long)value);

        public override void SetValue(IDbDataParameter parameter, ulong value) =>
            parameter.Value = unchecked((long)value);
    }

    // Same bit reinterpretation as UlongEncodeAsLongHandler, applied element-wise to arrays.
    private class UlongArrayHandler : SqlMapper.TypeHandler<ulong[]>
    {
        // ADO.NET represents SQL NULL parameters with DBNull.Value rather than a CLR null.
        public override void SetValue(IDbDataParameter parameter, ulong[]? value) =>
            parameter.Value = value is null
                ? DBNull.Value
                : (object)Array.ConvertAll(value, i => unchecked((long)i));

        public override ulong[] Parse(object value) =>
            Array.ConvertAll((long[])value, i => unchecked((ulong)i));
    }

    /// <summary>
    /// Maps <typeparamref name="T"/> to and from a JSON string using System.Text.Json.
    /// </summary>
    public class JsonTypeHandler<T> : SqlMapper.TypeHandler<T>
    {
        /// <exception cref="CataloggerError">The stored JSON deserialized to null.</exception>
        public override T Parse(object value)
        {
            string json = (string)value;
            return JsonSerializer.Deserialize<T>(json)
                ?? throw new CataloggerError("JsonTypeHandler<T> returned null");
        }

        public override void SetValue(IDbDataParameter parameter, T? value)
        {
            parameter.Value = JsonSerializer.Serialize(value);
        }
    }
}
/// <summary>
/// Dependency-injection registration helpers for the database layer.
/// </summary>
public static class ServiceCollectionDatabaseExtensions
{
    /// <summary>
    /// Registers <see cref="DatabasePool"/> as a singleton, plus a scoped
    /// <see cref="DatabaseConnection"/> that is acquired synchronously from that pool the first
    /// time it is resolved within a scope.
    /// </summary>
    /// <param name="serviceCollection">The collection to register into.</param>
    /// <returns>The same <paramref name="serviceCollection"/>, for chaining.</returns>
    public static IServiceCollection AddDatabasePool(this IServiceCollection serviceCollection)
    {
        serviceCollection.AddSingleton<DatabasePool>();
        serviceCollection.AddScoped<DatabaseConnection>(static provider =>
        {
            var pool = provider.GetRequiredService<DatabasePool>();
            return pool.Acquire();
        });

        return serviceCollection;
    }
}