Catalogger.NET/Catalogger.Backend/Database/DatabaseMigrator.cs

147 lines
5.3 KiB
C#

// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
using System.Data.Common;
using Dapper;
using NodaTime;
namespace Catalogger.Backend.Database;
/// <summary>
/// Applies the SQL migrations embedded in this assembly (resources named
/// <c>Catalogger.Backend.Database.Migrations.{name}.up.sql</c>) and records each applied
/// migration in the <c>migrations</c> table.
/// </summary>
/// <remarks>
/// Migration names are sorted and compared with ordinal string comparison, so they must sort
/// correctly as plain strings (for example by using a zero-padded numeric prefix).
/// Disposing the migrator disposes the connection it was given.
/// </remarks>
public class DatabaseMigrator(ILogger logger, IClock clock, DatabaseConnection conn)
    : IDisposable,
        IAsyncDisposable
{
    private const string RootPath = "Catalogger.Backend.Database";

    // Trailing dot included so that e.g. "...MigrationsOld.x.up.sql" is not picked up.
    private const string MigrationsPrefix = $"{RootPath}.Migrations.";
    private const string UpSuffix = ".up.sql";

    /// <summary>
    /// Applies every embedded migration whose name sorts (ordinally) after the most recently
    /// applied one. All pending migrations run inside a single transaction, so either all of
    /// them are applied or none are.
    /// </summary>
    public async Task Migrate()
    {
        var migrations = GetMigrationNames().ToArray();

        logger.Debug("Getting current database migration");
        var currentMigration = await GetCurrentMigration();
        if (currentMigration is not null)
        {
            // Same ordinal comparison GetMigrationNames sorts with, so "pending" and
            // "apply order" agree.
            migrations = migrations
                .Where(s => string.CompareOrdinal(s, currentMigration.MigrationName) > 0)
                .ToArray();
        }

        logger.Information(
            "Current migration: {Migration}. Applying {Count} migrations",
            currentMigration?.MigrationName,
            migrations.Length
        );
        if (migrations.Length == 0)
        {
            return;
        }

        // Wrap all migrations in a transaction; disposing it without a commit rolls back.
        await using var tx = await conn.BeginTransactionAsync();
        var totalStartTime = clock.GetCurrentInstant();
        foreach (var migration in migrations)
        {
            logger.Debug("Executing migration {Migration}", migration);
            var startTime = clock.GetCurrentInstant();
            await ExecuteMigration(tx, migration);
            var took = clock.GetCurrentInstant() - startTime;
            logger.Debug("Executed migration {Migration} in {Took}", migration, took);
        }

        var totalTook = clock.GetCurrentInstant() - totalStartTime;
        logger.Information("Executed {Count} migrations in {Took}", migrations.Length, totalTook);

        // Finally, commit the transaction
        await tx.CommitAsync();
    }

    /// <summary>
    /// Runs a single migration script inside <paramref name="tx"/> and updates the
    /// <c>migrations</c> bookkeeping table in the same transaction.
    /// </summary>
    /// <param name="tx">Transaction every statement is executed in.</param>
    /// <param name="migrationName">
    /// Migration name without the resource prefix or the <c>.up.sql</c>/<c>.down.sql</c> suffix.
    /// </param>
    /// <param name="up">
    /// <c>true</c> to run <c>.up.sql</c> and record the migration as applied;
    /// <c>false</c> to run <c>.down.sql</c> and remove its record.
    /// </param>
    /// <exception cref="ArgumentException">The migration script resource does not exist.</exception>
    private async Task ExecuteMigration(DbTransaction tx, string migrationName, bool up = true)
    {
        var query = await GetResource(
            $"{MigrationsPrefix}{migrationName}.{(up ? "up" : "down")}.sql"
        );

        // Run the migration
        await conn.ExecuteAsync(query, transaction: tx);

        // Update the bookkeeping in the SAME transaction, so a failure in a later migration
        // rolls this back too instead of leaving the table out of sync with the schema.
        if (up)
        {
            await conn.ExecuteAsync(
                "INSERT INTO migrations (migration_name, applied_at) VALUES (@MigrationName, @AppliedAt)",
                new { MigrationName = migrationName, AppliedAt = clock.GetCurrentInstant() },
                transaction: tx
            );
        }
        else
        {
            // A reverted migration must no longer count as applied.
            await conn.ExecuteAsync(
                "DELETE FROM migrations WHERE migration_name = @MigrationName",
                new { MigrationName = migrationName },
                transaction: tx
            );
        }
    }

    /// <summary>
    /// Returns the most recently applied migration, or <c>null</c> if none has been applied.
    /// If the <c>migrations</c> table does not exist yet, it is created first (outside of any
    /// migration transaction) from the embedded <c>setup_migrations.sql</c> script.
    /// </summary>
    private async Task<MigrationEntry?> GetCurrentMigration()
    {
        // to_regclass resolves the name through search_path, exactly like the unqualified
        // queries below, and yields NULL when no such relation exists. Counting
        // information_schema rows and comparing to 1 misfired when several schemas
        // contained a "migrations" table.
        var hasMigrationTable = await conn.QuerySingleAsync<bool>(
            "SELECT to_regclass('migrations') IS NOT NULL"
        );

        // If so, return the current migration
        if (hasMigrationTable)
        {
            // Tiebreak on the name (byte order, matching the ordinal comparison in Migrate)
            // so equal applied_at timestamps still pick the latest migration deterministically.
            return await conn.QuerySingleOrDefaultAsync<MigrationEntry>(
                """
                SELECT * FROM migrations
                ORDER BY applied_at DESC, migration_name COLLATE "C" DESC
                LIMIT 1
                """
            );
        }

        logger.Debug("Migrations table does not exist, assuming this is a new database");

        // Else, create the migrations table then return null
        var migrationTableQuery = await GetResource($"{RootPath}.setup_migrations.sql");
        await conn.ExecuteAsync(migrationTableQuery);
        return null;
    }

    /// <summary>Reads an embedded resource of this assembly as a string.</summary>
    /// <param name="name">Fully qualified manifest resource name.</param>
    /// <exception cref="ArgumentException">No resource with that name exists.</exception>
    private static async Task<string> GetResource(string name)
    {
        await using var stream =
            typeof(DatabasePool).Assembly.GetManifestResourceStream(name)
            ?? throw new ArgumentException($"Invalid resource '{name}'");
        using var reader = new StreamReader(stream);
        return await reader.ReadToEndAsync();
    }

    /// <summary>
    /// Returns the names of all embedded "up" migrations with the resource prefix and
    /// <c>.up.sql</c> suffix stripped. The names are sorted with ordinal comparison, the same
    /// ordering <see cref="Migrate"/> uses to decide which migrations are pending.
    /// </summary>
    public static IEnumerable<string> GetMigrationNames() =>
        typeof(DatabasePool)
            .Assembly.GetManifestResourceNames()
            .Where(s =>
                s.StartsWith(MigrationsPrefix, StringComparison.Ordinal)
                && s.EndsWith(UpSuffix, StringComparison.Ordinal)
                // Guards against names too short to hold a non-empty migration name,
                // which would otherwise make the slice below throw.
                && s.Length > MigrationsPrefix.Length + UpSuffix.Length
            )
            .Select(s => s[MigrationsPrefix.Length..^UpSuffix.Length])
            .OrderBy(s => s, StringComparer.Ordinal);

    /// <summary>A row of the <c>migrations</c> table.</summary>
    private record MigrationEntry
    {
        public string MigrationName { get; init; } = null!;
        public Instant AppliedAt { get; init; }
    }

    /// <summary>Disposes the underlying database connection.</summary>
    public void Dispose()
    {
        conn.Dispose();
        GC.SuppressFinalize(this);
    }

    /// <summary>Asynchronously disposes the underlying database connection.</summary>
    public async ValueTask DisposeAsync()
    {
        await conn.DisposeAsync();
        GC.SuppressFinalize(this);
    }
}