// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

using System.Data.Common;
using Dapper;
using NodaTime;

namespace Catalogger.Backend.Database;

/// <summary>
/// Manages database migrations.
/// </summary>
/// <remarks>
/// Migrations are embedded resources named
/// <c>Catalogger.Backend.Database.Migrations.{name}.up.sql</c> and
/// <c>Catalogger.Backend.Database.Migrations.{name}.down.sql</c>. Applied migrations are
/// recorded by name in the <c>migrations</c> table. Migrations are ordered by an ordinal
/// (culture-independent) comparison of their names.
/// </remarks>
public class DatabaseMigrator(ILogger logger, IClock clock, DatabaseConnection conn)
    : IDisposable,
        IAsyncDisposable
{
    private const string RootPath = "Catalogger.Backend.Database";

    /// <summary>Resource-name prefix shared by every migration script (including the trailing dot).</summary>
    private const string MigrationsPrefix = $"{RootPath}.Migrations.";

    /// <summary>Resource-name suffix identifying a migration's "up" script.</summary>
    private const string UpSuffix = ".up.sql";

    /// <summary>
    /// Migrates the database to the latest version.
    /// </summary>
    /// <remarks>
    /// Every embedded migration whose name sorts (ordinally) after the most recently applied
    /// migration is executed in name order. All of them run inside one transaction. If any
    /// migration throws, the transaction is disposed without being committed, so none of this
    /// run's migrations are kept.
    /// </remarks>
    public async Task MigrateUp()
    {
        var migrations = GetMigrationNames().ToArray();

        logger.Debug("Getting current database migration");
        var currentMigration = await GetCurrentMigration();
        if (currentMigration is not null)
        {
            var currentName = currentMigration.MigrationName;
            // Ordinal comparison, matching the ordinal sort in GetMigrationNames, so the
            // "pending" cut-off and the execution order always agree.
            migrations = migrations
                .Where(s => string.CompareOrdinal(s, currentName) > 0)
                .ToArray();
        }

        logger.Information(
            "Current migration: {Migration}. Applying {Count} migrations",
            currentMigration?.MigrationName,
            migrations.Length
        );
        if (migrations.Length == 0)
        {
            return;
        }

        // Wrap all migrations in a transaction
        await using var tx = await conn.BeginTransactionAsync();

        var totalStartTime = clock.GetCurrentInstant();
        foreach (var migration in migrations)
        {
            logger.Debug("Executing migration {Migration}", migration);
            var startTime = clock.GetCurrentInstant();
            await ExecuteMigration(tx, migration);
            var took = clock.GetCurrentInstant() - startTime;
            logger.Debug("Executed migration {Migration} in {Took}", migration, took);
        }

        var totalTook = clock.GetCurrentInstant() - totalStartTime;
        logger.Information("Executed {Count} migrations in {Took}", migrations.Length, totalTook);

        // Finally, commit the transaction
        await tx.CommitAsync();
    }

    /// <summary>
    /// Migrates the database to a previous version.
    /// </summary>
    /// <param name="count">The number of migrations to revert. If higher than the number of applied migrations,
    /// reverts the database to a clean slate. Values of zero or less revert nothing.</param>
    /// <remarks>All reverts run inside a single transaction that is committed at the end.</remarks>
    public async Task MigrateDown(int count = 1)
    {
        await using var tx = await conn.BeginTransactionAsync();

        var migrationCount = 0;
        var totalStartTime = clock.GetCurrentInstant();
        for (var i = count; i > 0; i--)
        {
            // Read inside the same transaction so we see the row deleted by the previous iteration.
            var migration = await GetCurrentMigration(tx);
            if (migration is null)
            {
                logger.Information(
                    "More down migrations requested than were in the database, finishing early"
                );
                break;
            }

            logger.Debug("Reverting migration {Migration}", migration.MigrationName);
            var startTime = clock.GetCurrentInstant();
            await ExecuteMigration(tx, migration.MigrationName, up: false);
            var took = clock.GetCurrentInstant() - startTime;
            logger.Debug(
                "Reverted migration {Migration} in {Took}",
                migration.MigrationName,
                took
            );
            migrationCount++;
        }

        var totalTook = clock.GetCurrentInstant() - totalStartTime;
        logger.Information("Reverted {Count} migrations in {Took}", migrationCount, totalTook);

        // Finally, commit the transaction
        await tx.CommitAsync();
    }

    /// <summary>
    /// Runs a single migration script and records (or un-records) it in the <c>migrations</c> table.
    /// </summary>
    /// <param name="tx">The transaction in which both the script and the bookkeeping statement run.</param>
    /// <param name="migrationName">The migration name, without the resource prefix or the
    /// <c>.up.sql</c>/<c>.down.sql</c> suffix.</param>
    /// <param name="up"><c>true</c> to apply the migration, <c>false</c> to revert it.</param>
    /// <exception cref="ArgumentException">The requested script is not an embedded resource.</exception>
    private async Task ExecuteMigration(DbTransaction tx, string migrationName, bool up = true)
    {
        var query = await GetResource(
            $"{MigrationsPrefix}{migrationName}.{(up ? "up" : "down")}.sql"
        );

        // Run the migration
        await conn.ExecuteAsync(query, transaction: tx);

        // Store that we ran the migration (or reverted it). This uses the same transaction as
        // the script itself so the schema and the bookkeeping cannot diverge.
        if (up)
        {
            await conn.ExecuteAsync(
                "INSERT INTO migrations (migration_name, applied_at) VALUES (@MigrationName, now())",
                new { MigrationName = migrationName },
                transaction: tx
            );
        }
        else
        {
            await conn.ExecuteAsync(
                "DELETE FROM migrations WHERE migration_name = @MigrationName",
                new { MigrationName = migrationName },
                transaction: tx
            );
        }
    }

    /// <summary>
    /// Returns the current (most recently applied) migration, or <c>null</c> if no migrations
    /// have been applied. If the <c>migrations</c> table does not exist, it is created first.
    /// </summary>
    /// <param name="tx">Optional transaction to run the queries in.</param>
    private async Task<MigrationEntry?> GetCurrentMigration(DbTransaction? tx = null)
    {
        // Check if the migrations table exists
        var hasMigrationTable =
            await conn.QuerySingleOrDefaultAsync<int>(
                "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'",
                transaction: tx
            ) == 1;

        // If so, return the current migration.
        // NOTE: assuming PostgreSQL, now() is the transaction start time, so every migration
        // applied by one MigrateUp call shares the same applied_at; migration_name breaks the tie.
        if (hasMigrationTable)
        {
            return await conn.QuerySingleOrDefaultAsync<MigrationEntry>(
                "SELECT * FROM migrations ORDER BY applied_at DESC, migration_name DESC LIMIT 1",
                transaction: tx
            );
        }

        logger.Debug("Migrations table does not exist, assuming this is a new database");

        // Else, create the migrations table then return null
        var migrationTableQuery = await GetResource($"{RootPath}.setup_migrations.sql");
        await conn.ExecuteAsync(migrationTableQuery, transaction: tx);
        return null;
    }

    /// <summary>Returns an embedded resource (from the assembly containing <c>DatabasePool</c>) as a string.</summary>
    /// <param name="name">The fully-qualified manifest resource name.</param>
    /// <exception cref="ArgumentException">No resource with that name exists.</exception>
    private static async Task<string> GetResource(string name)
    {
        await using var stream =
            typeof(DatabasePool).Assembly.GetManifestResourceStream(name)
            ?? throw new ArgumentException($"Invalid resource '{name}'");
        using var reader = new StreamReader(stream);
        return await reader.ReadToEndAsync();
    }

    /// <summary>
    /// Returns the names of all embedded "up" migrations, stripped of the resource prefix and
    /// the <c>.up.sql</c> suffix, sorted ordinally.
    /// </summary>
    private static IEnumerable<string> GetMigrationNames() =>
        typeof(DatabasePool)
            .Assembly.GetManifestResourceNames()
            .Where(s =>
                // Require a non-empty name between prefix and suffix so the slice below is valid.
                s.Length > MigrationsPrefix.Length + UpSuffix.Length
                && s.StartsWith(MigrationsPrefix, StringComparison.Ordinal)
                && s.EndsWith(UpSuffix, StringComparison.Ordinal)
            )
            .Select(s => s[MigrationsPrefix.Length..^UpSuffix.Length])
            // Ordinal, to match the string.CompareOrdinal filter in MigrateUp.
            .OrderBy(s => s, StringComparer.Ordinal);

    /// <summary>A row of the <c>migrations</c> table.</summary>
    private record MigrationEntry
    {
        public string MigrationName { get; init; } = null!;
        public Instant AppliedAt { get; init; }
    }

    /// <summary>Disposes the database connection this migrator was constructed with.</summary>
    public void Dispose()
    {
        conn.Dispose();
        GC.SuppressFinalize(this);
    }

    /// <summary>Asynchronously disposes the database connection this migrator was constructed with.</summary>
    public async ValueTask DisposeAsync()
    {
        await conn.DisposeAsync();
        GC.SuppressFinalize(this);
    }
}