// Copyright (C) 2021-present sam (starshines.gay)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
using System.Data.Common;
|
|
using Dapper;
|
|
using NodaTime;
|
|
|
|
namespace Catalogger.Backend.Database;
|
|
|
|
/// <summary>
/// Applies the SQL migration scripts embedded in this assembly to the database behind
/// <c>conn</c>, recording each applied migration in the <c>migrations</c> table.
/// Disposing the migrator disposes the connection it was given.
/// </summary>
public class DatabaseMigrator(ILogger logger, IClock clock, DatabaseConnection conn)
    : IDisposable,
        IAsyncDisposable
{
    private const string RootPath = "Catalogger.Backend.Database";

    /// Resource-name prefix shared by every embedded migration script (includes the trailing dot).
    private const string MigrationsPrefix = RootPath + ".Migrations.";

    /// Resource-name suffix of the script that applies a migration.
    private const string UpSuffix = ".up.sql";

    /// Resource-name suffix of the script that reverts a migration.
    private const string DownSuffix = ".down.sql";

    /// <summary>
    /// Applies every embedded migration whose name sorts (ordinally) after the most recently
    /// applied migration. All pending migrations run in one transaction, which is committed
    /// only after every one of them has executed.
    /// </summary>
    public async Task Migrate()
    {
        var migrations = GetMigrationNames().ToArray();

        logger.Debug("Getting current database migration");
        var currentMigration = await GetCurrentMigration();
        if (currentMigration != null)
        {
            // Pending = strictly newer than the current migration. This uses the same ordinal
            // comparison that GetMigrationNames sorts by, so "newer" and "applied later" agree.
            migrations = migrations
                .Where(s => string.CompareOrdinal(s, currentMigration.MigrationName) > 0)
                .ToArray();
        }

        logger.Information(
            "Current migration: {Migration}. Applying {Count} migrations",
            currentMigration?.MigrationName,
            migrations.Length
        );
        if (migrations.Length == 0)
        {
            return;
        }

        // Wrap all migrations in a transaction. If anything throws, the transaction is disposed
        // without being committed.
        await using var tx = await conn.BeginTransactionAsync();
        var totalStartTime = clock.GetCurrentInstant();
        foreach (var migration in migrations)
        {
            logger.Debug("Executing migration {Migration}", migration);
            var startTime = clock.GetCurrentInstant();
            await ExecuteMigration(tx, migration);
            var took = clock.GetCurrentInstant() - startTime;
            logger.Debug("Executed migration {Migration} in {Took}", migration, took);
        }

        var totalTook = clock.GetCurrentInstant() - totalStartTime;
        logger.Information("Executed {Count} migrations in {Took}", migrations.Length, totalTook);

        // Finally, commit the transaction
        await tx.CommitAsync();
    }

    /// <summary>
    /// Runs one embedded migration script and updates the migrations table. Both statements
    /// run in <paramref name="tx"/>, so the script and its bookkeeping are committed together.
    /// </summary>
    /// <param name="tx">Transaction every statement is executed in.</param>
    /// <param name="migrationName">
    /// Migration name without the resource prefix or the <c>.up.sql</c>/<c>.down.sql</c> suffix.
    /// </param>
    /// <param name="up">
    /// <c>true</c> to run the up script and record the migration;
    /// <c>false</c> to run the down script and remove its record.
    /// </param>
    /// <exception cref="ArgumentException">The script resource does not exist.</exception>
    private async Task ExecuteMigration(DbTransaction tx, string migrationName, bool up = true)
    {
        var query = await GetResource(MigrationsPrefix + migrationName + (up ? UpSuffix : DownSuffix));

        // Run the migration
        await conn.ExecuteAsync(query, transaction: tx);

        // Store (or, when reverting, forget) that we ran the migration. The transaction is passed
        // explicitly, like the script above, so this never runs outside it.
        if (up)
        {
            await conn.ExecuteAsync(
                "INSERT INTO migrations (migration_name, applied_at) VALUES (@MigrationName, @AppliedAt)",
                new { MigrationName = migrationName, AppliedAt = clock.GetCurrentInstant() },
                transaction: tx
            );
        }
        else
        {
            await conn.ExecuteAsync(
                "DELETE FROM migrations WHERE migration_name = @MigrationName",
                new { MigrationName = migrationName },
                transaction: tx
            );
        }
    }

    /// <summary>
    /// Returns the most recently applied migration (by <c>applied_at</c>), or <c>null</c> if no
    /// migrations have been applied. If the migrations table does not exist yet, it is created
    /// from the embedded <c>setup_migrations.sql</c> script and <c>null</c> is returned.
    /// </summary>
    private async Task<MigrationEntry?> GetCurrentMigration()
    {
        // Check if the migrations table exists.
        // NOTE(review): this counts `migrations` tables in every schema exposed by
        // information_schema, so a same-named table in another schema makes the count 2 and the
        // database is treated as new. Consider also filtering on table_schema.
        var hasMigrationTable =
            await conn.QuerySingleOrDefaultAsync<int>(
                "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'migrations'"
            ) == 1;

        // If so, return the current migration
        if (hasMigrationTable)
        {
            return await conn.QuerySingleOrDefaultAsync<MigrationEntry>(
                "SELECT * FROM migrations ORDER BY applied_at DESC LIMIT 1"
            );
        }

        logger.Debug("Migrations table does not exist, assuming this is a new database");

        // Else, create the migrations table then return null
        var migrationTableQuery = await GetResource($"{RootPath}.setup_migrations.sql");
        await conn.ExecuteAsync(migrationTableQuery);
        return null;
    }

    /// <summary>Reads an embedded resource from this assembly as a string.</summary>
    /// <param name="name">Fully-qualified manifest resource name.</param>
    /// <exception cref="ArgumentException">No resource with that name exists.</exception>
    private static async Task<string> GetResource(string name)
    {
        await using var stream =
            typeof(DatabasePool).Assembly.GetManifestResourceStream(name)
            ?? throw new ArgumentException($"Invalid resource '{name}'");
        using var reader = new StreamReader(stream);
        return await reader.ReadToEndAsync();
    }

    /// <summary>
    /// Lists the names of all embedded "up" migration scripts, with the resource prefix and the
    /// <c>.up.sql</c> suffix stripped, in ordinal order. Ordinal order is the order
    /// <see cref="Migrate"/> applies them in, and it matches the ordinal comparison
    /// <see cref="Migrate"/> uses to decide which migrations are pending.
    /// </summary>
    public static IEnumerable<string> GetMigrationNames() =>
        typeof(DatabasePool)
            .Assembly.GetManifestResourceNames()
            // Ordinal matching on the exact prefix (with its trailing dot) keeps the slice below
            // aligned, and the length check guarantees a non-empty name.
            .Where(s =>
                s.Length > MigrationsPrefix.Length + UpSuffix.Length
                && s.StartsWith(MigrationsPrefix, StringComparison.Ordinal)
                && s.EndsWith(UpSuffix, StringComparison.Ordinal)
            )
            .Select(s => s[MigrationsPrefix.Length..^UpSuffix.Length])
            .OrderBy(s => s, StringComparer.Ordinal);

    /// A row of the migrations table.
    private record MigrationEntry
    {
        public string MigrationName { get; init; } = null!;
        public Instant AppliedAt { get; init; }
    }

    /// Disposes the underlying database connection.
    public void Dispose()
    {
        conn.Dispose();
        GC.SuppressFinalize(this);
    }

    /// Asynchronously disposes the underlying database connection.
    public async ValueTask DisposeAsync()
    {
        await conn.DisposeAsync();
        GC.SuppressFinalize(this);
    }
}
|