// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
using System.Data;
using Dapper;
using NodaTime;
using Npgsql;
namespace Catalogger.Backend.Database.Dapper;
/// <summary>
/// Hands out <see cref="DatabaseConnection"/>s backed by a single Npgsql data source and keeps a
/// process-wide count of how many are currently open.
/// </summary>
public class DatabasePool
{
    private readonly ILogger _rootLogger;
    private readonly ILogger _logger;
    private readonly NpgsqlDataSource _dataSource;

    // Static, so the count is shared by every DatabasePool instance in the process.
    private static int _openConnections;

    /// <summary>
    /// Number of acquired connections that have not yet been released. It is incremented when a
    /// connection is acquired and decremented through <see cref="DecrementConnections"/>. The
    /// decrement is presumably called when a <see cref="DatabaseConnection"/> is disposed; that
    /// code is not visible here.
    /// </summary>
    public static int OpenConnections => Volatile.Read(ref _openConnections);

    /// <summary>
    /// Builds the Npgsql data source from the configured connection string.
    /// </summary>
    /// <param name="config">
    /// Application config. Supplies the database URL, the connect timeout, the maximum pool size
    /// and whether queries should be logged.
    /// </param>
    /// <param name="logger">Root logger, passed on to every <see cref="DatabaseConnection"/>.</param>
    /// <param name="loggerFactory">
    /// Logger factory that Npgsql logs through. It is only used when query logging is enabled.
    /// </param>
    public DatabasePool(Config config, ILogger logger, ILoggerFactory? loggerFactory)
    {
        _rootLogger = logger;
        _logger = logger.ForContext<DatabasePool>();

        var connString = new NpgsqlConnectionStringBuilder(config.Database.Url)
        {
            // Fallbacks when the config leaves these unset: Npgsql's Timeout is the connect
            // timeout in seconds; MaxPoolSize caps the number of pooled physical connections.
            Timeout = config.Database.Timeout ?? 5,
            MaxPoolSize = config.Database.MaxPoolSize ?? 50,
        }.ConnectionString;

        var dataSourceBuilder = new NpgsqlDataSourceBuilder(connString);
        // NodaTime support lets Npgsql read and write NodaTime types directly. The passthrough
        // Dapper handler registered in ConfigureDapper relies on this.
        dataSourceBuilder.EnableDynamicJson().UseNodaTime();
        if (config.Logging.LogQueries)
        {
            dataSourceBuilder.UseLoggerFactory(loggerFactory);
        }

        _dataSource = dataSourceBuilder.Build();
    }

    /// <summary>
    /// Opens a pooled connection asynchronously and wraps it in a <see cref="DatabaseConnection"/>.
    /// </summary>
    /// <param name="ct">Cancels waiting for / opening the connection.</param>
    /// <returns>An open connection; the caller is responsible for disposing it.</returns>
    public async Task<DatabaseConnection> AcquireAsync(CancellationToken ct = default)
    {
        // Open the physical connection *before* counting it. If opening throws, no
        // DatabaseConnection is created to release the count again, so incrementing first
        // (as the argument order `new DatabaseConnection(LogOpen(), ..., await Open...)` did)
        // would leave OpenConnections permanently too high.
        var inner = await _dataSource.OpenConnectionAsync(ct);
        return new DatabaseConnection(LogOpen(), _rootLogger, inner);
    }

    /// <summary>
    /// Synchronous counterpart of <see cref="AcquireAsync"/>. It blocks while the connection opens.
    /// </summary>
    /// <returns>An open connection; the caller is responsible for disposing it.</returns>
    public DatabaseConnection Acquire()
    {
        // Same ordering as AcquireAsync: only count the connection once it has actually opened.
        var inner = _dataSource.OpenConnection();
        return new DatabaseConnection(LogOpen(), _rootLogger, inner);
    }

    /// <summary>
    /// Assigns an ID to a newly acquired connection, logs it at debug level and increments the
    /// open-connection counter.
    /// </summary>
    private Guid LogOpen()
    {
        var connId = Guid.NewGuid();
        _logger.Debug("Opening database connection {ConnId}", connId);
        IncrementConnections();
        return connId;
    }

    /// <summary>
    /// Acquires a connection, runs <paramref name="func"/> with it and disposes it afterwards.
    /// </summary>
    /// <param name="func">Work to run against the connection.</param>
    /// <param name="ct">Cancels acquiring the connection.</param>
    /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
    public async Task ExecuteAsync(
        Func<DatabaseConnection, Task> func,
        CancellationToken ct = default
    )
    {
        ArgumentNullException.ThrowIfNull(func);
        await using var conn = await AcquireAsync(ct);
        await func(conn);
    }

    /// <summary>
    /// Acquires a connection, runs <paramref name="func"/> with it and returns its result. The
    /// connection is disposed once the result has been awaited.
    /// </summary>
    /// <param name="func">Work to run against the connection.</param>
    /// <param name="ct">Cancels acquiring the connection.</param>
    /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
    public async Task<T> ExecuteAsync<T>(
        Func<DatabaseConnection, Task<T>> func,
        CancellationToken ct = default
    )
    {
        ArgumentNullException.ThrowIfNull(func);
        await using var conn = await AcquireAsync(ct);
        return await func(conn);
    }

    /// <summary>
    /// Overload for sequence-returning work, matching the <c>Task&lt;IEnumerable&lt;T&gt;&gt;</c>
    /// that Dapper's <c>QueryAsync&lt;T&gt;</c> returns. The connection is disposed once the
    /// result has been awaited.
    /// </summary>
    /// <param name="func">Work to run against the connection.</param>
    /// <param name="ct">Cancels acquiring the connection.</param>
    /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
    public async Task<IEnumerable<T>> ExecuteAsync<T>(
        Func<DatabaseConnection, Task<IEnumerable<T>>> func,
        CancellationToken ct = default
    )
    {
        ArgumentNullException.ThrowIfNull(func);
        await using var conn = await AcquireAsync(ct);
        return await func(conn);
    }

    // Atomic updates because connections are acquired and released from arbitrary threads.
    internal static void IncrementConnections() => Interlocked.Increment(ref _openConnections);

    internal static void DecrementConnections() => Interlocked.Decrement(ref _openConnections);

    /// <summary>
    /// Configures Dapper's SQL mapping, as it handles several types incorrectly by default.
    /// Most notably, ulongs and arrays of ulongs.
    /// </summary>
    /// <remarks>
    /// Mutates Dapper's global <see cref="SqlMapper"/> configuration, which affects every Dapper
    /// call in the process.
    /// </remarks>
    public static void ConfigureDapper()
    {
        DefaultTypeMap.MatchNamesWithUnderscores = true;

        // Drop Dapper's built-in ulong mapping so that our handler is used instead.
        SqlMapper.RemoveTypeMap(typeof(ulong));
        SqlMapper.AddTypeHandler(new UlongEncodeAsLongHandler());
        SqlMapper.AddTypeHandler(new UlongArrayHandler());
        SqlMapper.AddTypeHandler(new PassthroughTypeHandler<Instant>());
    }

    // Copied from PluralKit:
    // https://github.com/PluralKit/PluralKit/blob/4bf60a47d76a068fa0488bf9be96cdaf57a6fe50/PluralKit.Core/Database/Database.cs#L116
    // Thanks for not working with common types by default, Dapper. Really nice of you.
    //
    // Hands values of T straight to the ADO.NET provider (Npgsql, configured with UseNodaTime)
    // instead of letting Dapper reject a type it does not know.
    private class PassthroughTypeHandler<T> : SqlMapper.TypeHandler<T>
    {
        // ADO.NET represents SQL NULL as DBNull.Value, not null.
        public override void SetValue(IDbDataParameter parameter, T? value) =>
            parameter.Value = (object?)value ?? DBNull.Value;

        // Assumes the provider already materialised the column as exactly T.
        public override T Parse(object value) => (T)value;
    }

    // Postgres has no unsigned 64-bit type, so ulongs are presumably stored in bigint columns.
    // The full ulong range maps bit-for-bit onto long.
    private class UlongEncodeAsLongHandler : SqlMapper.TypeHandler<ulong>
    {
        // A boxed value can only be unboxed to its exact type, so unbox to long first, then
        // reinterpret the bits as ulong. `unchecked` keeps this a pure bit reinterpretation even
        // when the project is compiled with overflow checking.
        public override ulong Parse(object value) => unchecked((ulong)(long)value);

        public override void SetValue(IDbDataParameter parameter, ulong value) =>
            parameter.Value = unchecked((long)value);
    }

    // Array version of UlongEncodeAsLongHandler: ulong[] <-> long[] element by element.
    private class UlongArrayHandler : SqlMapper.TypeHandler<ulong[]>
    {
        public override void SetValue(IDbDataParameter parameter, ulong[]? value) =>
            parameter.Value = value is null
                ? DBNull.Value
                : (object)Array.ConvertAll(value, i => unchecked((long)i));

        public override ulong[] Parse(object value) =>
            Array.ConvertAll((long[])value, i => unchecked((ulong)i));
    }
}
/// <summary>
/// Dependency-injection registration for the Dapper database layer.
/// </summary>
public static class ServiceCollectionDatabaseExtensions
{
    /// <summary>
    /// Registers <see cref="DatabasePool"/> as a singleton. It also registers a scoped
    /// <see cref="DatabaseConnection"/> that is acquired from that pool.
    /// </summary>
    /// <param name="serviceCollection">The service collection to register into.</param>
    /// <returns>The same <paramref name="serviceCollection"/>, for chaining.</returns>
    public static IServiceCollection AddDatabasePool(this IServiceCollection serviceCollection) =>
        serviceCollection
            .AddSingleton<DatabasePool>()
            // DI factories are synchronous, so this uses the blocking Acquire(). The factory runs
            // the first time a scope resolves DatabaseConnection, not when the scope is created.
            .AddScoped<DatabaseConnection>(services =>
                services.GetRequiredService<DatabasePool>().Acquire()
            );
}