// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

using System.Data;
using System.Text.Json;
using System.Text.Json.Serialization;
using Catalogger.Backend.Database.Models;
using Dapper;
using NodaTime;
using Npgsql;

namespace Catalogger.Backend.Database.Dapper;

/// <summary>
/// Owns the application's <see cref="NpgsqlDataSource"/> and hands out opened connections
/// wrapped in <see cref="DatabaseConnection"/>.
/// </summary>
public class DatabasePool
{
    private readonly ILogger _rootLogger;
    private readonly ILogger _logger;
    private readonly NpgsqlDataSource _dataSource;

    // Shared across all instances; only touched through Interlocked in this file.
    private static int _openConnections;

    /// <summary>
    /// Number of connections counted by <see cref="IncrementConnections"/> that have not yet been
    /// released via <see cref="DecrementConnections"/>.
    /// </summary>
    public static int OpenConnections => _openConnections;

    /// <summary>
    /// Builds the Npgsql data source from configuration.
    /// </summary>
    /// <param name="config">
    /// Application config. Uses <c>Database.Url</c>, <c>Database.Timeout</c> (default 5),
    /// <c>Database.MaxPoolSize</c> (default 50) and <c>Logging.LogQueries</c>.
    /// </param>
    /// <param name="logger">Root logger; passed on to every <see cref="DatabaseConnection"/>.</param>
    /// <param name="loggerFactory">
    /// Handed to Npgsql only when query logging is enabled; may be null.
    /// </param>
    public DatabasePool(Config config, ILogger logger, ILoggerFactory? loggerFactory)
    {
        _rootLogger = logger;
        _logger = logger.ForContext<DatabasePool>();

        var connString = new NpgsqlConnectionStringBuilder(config.Database.Url)
        {
            Timeout = config.Database.Timeout ?? 5,
            MaxPoolSize = config.Database.MaxPoolSize ?? 50,
        }.ConnectionString;

        var dataSourceBuilder = new NpgsqlDataSourceBuilder(connString);
        dataSourceBuilder.EnableDynamicJson().UseNodaTime();
        if (config.Logging.LogQueries)
            dataSourceBuilder.UseLoggerFactory(loggerFactory);
        _dataSource = dataSourceBuilder.Build();
    }

    /// <summary>
    /// Opens a connection from the pool and wraps it in a <see cref="DatabaseConnection"/>.
    /// The caller owns the result and must dispose it.
    /// </summary>
    public async Task<DatabaseConnection> AcquireAsync(CancellationToken ct = default)
    {
        // Open first, count second: if opening throws (timeout, cancellation, server error),
        // the open-connection counter must not be bumped, because no DatabaseConnection will
        // exist to release it again.
        var connection = await _dataSource.OpenConnectionAsync(ct);
        return new DatabaseConnection(LogOpen(), _rootLogger, connection);
    }

    /// <summary>
    /// Synchronous counterpart of <see cref="AcquireAsync"/>; blocks while the connection opens.
    /// </summary>
    public DatabaseConnection Acquire()
    {
        // Same ordering as AcquireAsync: only count connections that actually opened.
        var connection = _dataSource.OpenConnection();
        return new DatabaseConnection(LogOpen(), _rootLogger, connection);
    }

    /// <summary>
    /// Generates a connection id, logs it at debug level and increments the open-connection
    /// counter. Call only after the underlying connection has been opened successfully.
    /// </summary>
    private Guid LogOpen()
    {
        var connId = Guid.NewGuid();
        _logger.Debug("Opening database connection {ConnId}", connId);
        IncrementConnections();
        return connId;
    }

    /// <summary>
    /// Acquires a connection, runs <paramref name="func"/> with it, and disposes the connection
    /// afterwards, including when <paramref name="func"/> throws.
    /// </summary>
    public async Task ExecuteAsync(
        Func<DatabaseConnection, Task> func,
        CancellationToken ct = default
    )
    {
        await using var conn = await AcquireAsync(ct);
        await func(conn);
    }

    /// <summary>
    /// Acquires a connection, runs <paramref name="func"/> with it, returns its result, and
    /// disposes the connection afterwards.
    /// </summary>
    public async Task<T> ExecuteAsync<T>(
        Func<DatabaseConnection, Task<T>> func,
        CancellationToken ct = default
    )
    {
        await using var conn = await AcquireAsync(ct);
        return await func(conn);
    }

    /// <summary>
    /// Sequence-returning overload of <see cref="ExecuteAsync{T}(Func{DatabaseConnection, Task{T}}, CancellationToken)"/>.
    /// Because the connection is disposed before this returns, <paramref name="func"/> must hand
    /// back an already-materialised sequence, not one that still reads from the connection.
    /// </summary>
    public async Task<IEnumerable<T>> ExecuteAsync<T>(
        Func<DatabaseConnection, Task<IEnumerable<T>>> func,
        CancellationToken ct = default
    )
    {
        await using var conn = await AcquireAsync(ct);
        return await func(conn);
    }

    /// <summary>Atomically increments the shared open-connection counter.</summary>
    internal static void IncrementConnections() => Interlocked.Increment(ref _openConnections);

    /// <summary>
    /// Atomically decrements the shared open-connection counter. Not called in this file;
    /// presumably invoked when a <see cref="DatabaseConnection"/> is disposed (verify).
    /// </summary>
    internal static void DecrementConnections() => Interlocked.Decrement(ref _openConnections);

    /// <summary>
    /// Configures Dapper's SQL mapping, because Dapper handles several types incorrectly by default,
    /// most notably ulongs and arrays of ulongs. Call once at startup, before any queries run.
    /// </summary>
    public static void ConfigureDapper()
    {
        // Lets a column such as `channel_id` bind to a property named ChannelId.
        DefaultTypeMap.MatchNamesWithUnderscores = true;

        // Drop Dapper's built-in ulong mapping so the custom handler below is used instead.
        SqlMapper.RemoveTypeMap(typeof(ulong));
        SqlMapper.AddTypeHandler(new UlongEncodeAsLongHandler());
        SqlMapper.AddTypeHandler(new UlongArrayHandler());

        SqlMapper.AddTypeHandler(new PassthroughTypeHandler<Instant>());
        SqlMapper.AddTypeHandler(new JsonTypeHandler<Guild.ChannelConfig>());
    }

    // Copied from PluralKit:
    // https://github.com/PluralKit/PluralKit/blob/4bf60a47d76a068fa0488bf9be96cdaf57a6fe50/PluralKit.Core/Database/Database.cs#L116
    // Thanks for not working with common types by default, Dapper. Really nice of you.
    /// <summary>
    /// Hands values to and from the driver unchanged. This is for types the driver already
    /// understands but Dapper does not.
    /// </summary>
    private class PassthroughTypeHandler<T> : SqlMapper.TypeHandler<T>
    {
        public override void SetValue(IDbDataParameter parameter, T? value) =>
            parameter.Value = value;

        public override T Parse(object value) => (T)value;
    }

    /// <summary>
    /// Stores <see cref="ulong"/> values as signed 64-bit integers and reads them back the same way.
    /// </summary>
    private class UlongEncodeAsLongHandler : SqlMapper.TypeHandler<ulong>
    {
        // Unboxing must name the exact boxed type (long here), so unbox to long first and only
        // then convert to ulong. In the default unchecked context this round-trips the bit pattern.
        public override ulong Parse(object value) => (ulong)(long)value;

        public override void SetValue(IDbDataParameter parameter, ulong value) =>
            parameter.Value = (long)value;
    }

    /// <summary>
    /// Element-wise version of <see cref="UlongEncodeAsLongHandler"/> for <c>ulong[]</c>, stored
    /// as <c>long[]</c>.
    /// </summary>
    private class UlongArrayHandler : SqlMapper.TypeHandler<ulong[]>
    {
        public override void SetValue(IDbDataParameter parameter, ulong[]? value) =>
            // ADO.NET represents SQL NULL as DBNull.Value, not a CLR null.
            parameter.Value = value is not null
                ? Array.ConvertAll(value, i => (long)i)
                : DBNull.Value;

        public override ulong[] Parse(object value) =>
            Array.ConvertAll((long[])value, i => (ulong)i);
    }

    /// <summary>
    /// Maps <typeparamref name="T"/> to and from a JSON string using System.Text.Json defaults.
    /// </summary>
    public class JsonTypeHandler<T> : SqlMapper.TypeHandler<T>
    {
        /// <exception cref="CataloggerError">The stored JSON deserialises to null.</exception>
        public override T Parse(object value)
        {
            string json = (string)value;
            return JsonSerializer.Deserialize<T>(json)
                ?? throw new CataloggerError("JsonTypeHandler returned null");
        }

        public override void SetValue(IDbDataParameter parameter, T? value)
        {
            parameter.Value = JsonSerializer.Serialize(value);
        }
    }
}

/// <summary>
/// Dependency-injection registration helpers for the database layer.
/// </summary>
public static class ServiceCollectionDatabaseExtensions
{
    /// <summary>
    /// Registers <see cref="DatabasePool"/> as a singleton and <see cref="DatabaseConnection"/>
    /// as a scoped service. The connection is opened synchronously, via
    /// <see cref="DatabasePool.Acquire"/>, the first time it is resolved within a scope.
    /// </summary>
    public static IServiceCollection AddDatabasePool(this IServiceCollection serviceCollection) =>
        serviceCollection
            .AddSingleton<DatabasePool>()
            .AddScoped<DatabaseConnection>(services =>
                services.GetRequiredService<DatabasePool>().Acquire()
            );
}