// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

using System.Data;
using System.Text.Json;
using Catalogger.Backend.Database.Models;
using Dapper;
using NodaTime;
using Npgsql;

namespace Catalogger.Backend.Database;

/// <summary>
/// Owns the application's Npgsql data source and hands out <see cref="DatabaseConnection"/> wrappers
/// around connections opened from it, while keeping a count of connections handed out.
/// </summary>
public class DatabasePool
{
    private readonly NpgsqlDataSource _dataSource;

    // Static: the count is shared by every DatabasePool instance in the process.
    private static int _openConnections;

    /// <summary>
    /// Number of connections currently counted as open. It is incremented by <see cref="Acquire"/> and
    /// <see cref="AcquireAsync"/>, and decremented through <see cref="DecrementConnections"/>.
    /// </summary>
    public static int OpenConnections => _openConnections;

    /// <summary>
    /// Builds the Npgsql data source from the configured connection string.
    /// </summary>
    /// <param name="config">Supplies the database URL, connect timeout, max pool size and the query-logging flag.</param>
    /// <param name="loggerFactory">Handed to Npgsql only when <c>config.Logging.LogQueries</c> is set.</param>
    public DatabasePool(Config config, ILoggerFactory? loggerFactory)
    {
        var connString = new NpgsqlConnectionStringBuilder(config.Database.Url)
        {
            // Npgsql's Timeout is the time (in seconds) to wait while establishing a connection.
            Timeout = config.Database.Timeout ?? 5,
            MaxPoolSize = config.Database.MaxPoolSize ?? 50,
        }.ConnectionString;

        var dataSourceBuilder = new NpgsqlDataSourceBuilder(connString);
        dataSourceBuilder.EnableDynamicJson().UseNodaTime();
        if (config.Logging.LogQueries)
        {
            dataSourceBuilder.UseLoggerFactory(loggerFactory);
        }
        _dataSource = dataSourceBuilder.Build();
    }

    /// <summary>
    /// Opens a connection from the pool and wraps it in a <see cref="DatabaseConnection"/>.
    /// </summary>
    /// <param name="ct">Cancels waiting for / opening the connection.</param>
    /// <returns>The wrapped connection; the caller must dispose it.</returns>
    public async Task<DatabaseConnection> AcquireAsync(CancellationToken ct = default)
    {
        IncrementConnections();
        try
        {
            return new DatabaseConnection(await _dataSource.OpenConnectionAsync(ct));
        }
        catch
        {
            // Opening (or wrapping) failed, so no DatabaseConnection exists that could later balance
            // the increment above. Undo it here so OpenConnections does not drift upwards.
            // NOTE(review): assumes DatabaseConnection calls DecrementConnections when disposed — confirm.
            DecrementConnections();
            throw;
        }
    }

    /// <summary>
    /// Synchronously opens a connection from the pool and wraps it in a <see cref="DatabaseConnection"/>.
    /// </summary>
    /// <returns>The wrapped connection; the caller must dispose it.</returns>
    public DatabaseConnection Acquire()
    {
        IncrementConnections();
        try
        {
            return new DatabaseConnection(_dataSource.OpenConnection());
        }
        catch
        {
            // See AcquireAsync: balance the increment when no wrapper was handed out.
            DecrementConnections();
            throw;
        }
    }

    /// <summary>
    /// Acquires a connection, runs <paramref name="func"/> with it, and disposes the connection afterwards.
    /// </summary>
    public async Task ExecuteAsync(
        Func<DatabaseConnection, Task> func,
        CancellationToken ct = default
    )
    {
        await using var conn = await AcquireAsync(ct);
        await func(conn);
    }

    /// <summary>
    /// Acquires a connection, runs <paramref name="func"/> with it, disposes the connection,
    /// and returns <paramref name="func"/>'s result.
    /// </summary>
    public async Task<T> ExecuteAsync<T>(
        Func<DatabaseConnection, Task<T>> func,
        CancellationToken ct = default
    )
    {
        await using var conn = await AcquireAsync(ct);
        return await func(conn);
    }

    /// <summary>
    /// Overload of <see cref="ExecuteAsync{T}(Func{DatabaseConnection, Task{T}}, CancellationToken)"/>
    /// for functions that return a sequence (for example, the result of a Dapper query).
    /// </summary>
    public async Task<IEnumerable<T>> ExecuteAsync<T>(
        Func<DatabaseConnection, Task<IEnumerable<T>>> func,
        CancellationToken ct = default
    )
    {
        await using var conn = await AcquireAsync(ct);
        return await func(conn);
    }

    internal static void IncrementConnections() => Interlocked.Increment(ref _openConnections);

    internal static void DecrementConnections() => Interlocked.Decrement(ref _openConnections);

    /// <summary>
    /// Configures Dapper's SQL mapping, as it handles several types incorrectly by default.
    /// Most notably, ulongs and arrays of ulongs.
    /// This mutates Dapper's global, static type-mapping configuration.
    /// </summary>
    public static void ConfigureDapper()
    {
        DefaultTypeMap.MatchNamesWithUnderscores = true;

        SqlMapper.RemoveTypeMap(typeof(ulong));
        SqlMapper.AddTypeHandler(new UlongEncodeAsLongHandler());
        SqlMapper.AddTypeHandler(new PassthroughTypeHandler<Instant>());

        SqlMapper.AddTypeHandler(new JsonTypeHandler<Guild.ChannelConfig>());
        SqlMapper.AddTypeHandler(new JsonTypeHandler<Guild.MessageConfig>());
        SqlMapper.AddTypeHandler(new UlongListHandler());
    }

    // Copied from PluralKit:
    // https://github.com/PluralKit/PluralKit/blob/4bf60a47d76a068fa0488bf9be96cdaf57a6fe50/PluralKit.Core/Database/Database.cs#L116
    // Thanks for not working with common types by default, Dapper. Really nice of you.
    /// <summary>
    /// Hands values straight through to and from the driver, letting Npgsql (with its NodaTime plugin)
    /// do the actual conversion.
    /// </summary>
    private class PassthroughTypeHandler<T> : SqlMapper.TypeHandler<T>
    {
        public override void SetValue(IDbDataParameter parameter, T? value) =>
            parameter.Value = value;

        public override T Parse(object value) => (T)value;
    }

    /// <summary>
    /// Stores <see cref="ulong"/> values as <see cref="long"/> (PostgreSQL has no unsigned integer types).
    /// </summary>
    private class UlongEncodeAsLongHandler : SqlMapper.TypeHandler<ulong>
    {
        public override void SetValue(IDbDataParameter parameter, ulong value) =>
            parameter.Value = (long)value;

        public override ulong Parse(object value) =>
            // A boxed long can only be unboxed as long, so unbox first and then convert. In an
            // unchecked context (the C# default) this reinterprets the bits, mirroring SetValue's cast.
            (ulong)(long)value;
    }

    /// <summary>
    /// Stores <see cref="List{T}"/> of <see cref="ulong"/> as a <c>long[]</c>, element-wise, for the same reason.
    /// </summary>
    private class UlongListHandler : SqlMapper.TypeHandler<List<ulong>>
    {
        public override void SetValue(IDbDataParameter parameter, List<ulong>? value) =>
            parameter.Value = value?.Select(i => (long)i).ToArray();

        public override List<ulong>? Parse(object value) =>
            ((long[])value).Select(i => (ulong)i).ToList();
    }

    /// <summary>
    /// Serializes values to JSON strings with System.Text.Json and parses them back.
    /// </summary>
    /// <exception cref="CataloggerError">Thrown by Parse when the JSON deserializes to null.</exception>
    private class JsonTypeHandler<T> : SqlMapper.TypeHandler<T>
    {
        public override void SetValue(IDbDataParameter parameter, T? value) =>
            parameter.Value = JsonSerializer.Serialize(value);

        public override T Parse(object value)
        {
            var json = (string)value;
            return JsonSerializer.Deserialize<T>(json)
                ?? throw new CataloggerError("JsonTypeHandler returned null");
        }
    }
}

/// <summary>
/// Dependency-injection registration for the database layer.
/// </summary>
public static class ServiceCollectionDatabaseExtensions
{
    /// <summary>
    /// Registers <see cref="DatabasePool"/> as a singleton, and a scoped <see cref="DatabaseConnection"/>
    /// that is acquired synchronously from the pool once per scope.
    /// </summary>
    public static IServiceCollection AddDatabasePool(this IServiceCollection serviceCollection) =>
        serviceCollection
            .AddSingleton<DatabasePool>()
            .AddScoped<DatabaseConnection>(services =>
                services.GetRequiredService<DatabasePool>().Acquire()
            );
}