chore: add csharpier to husky, format backend with csharpier

This commit is contained in:
sam 2024-10-02 00:28:07 +02:00
parent 5fab66444f
commit 7f971e8549
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
73 changed files with 2098 additions and 1048 deletions

View file

@ -4,7 +4,16 @@
"tools": { "tools": {
"husky": { "husky": {
"version": "0.7.1", "version": "0.7.1",
"commands": ["husky"], "commands": [
"husky"
],
"rollForward": false
},
"csharpier": {
"version": "0.29.2",
"commands": [
"dotnet-csharpier"
],
"rollForward": false "rollForward": false
} }
} }

View file

@ -8,9 +8,10 @@
"pathMode": "absolute" "pathMode": "absolute"
}, },
{ {
"name": "dotnet-format", "name": "run-csharpier",
"command": "dotnet", "command": "dotnet",
"args": ["format"] "args": [ "csharpier", "${staged}" ],
"include": [ "**/*.cs" ]
} }
] ]
} }

View file

@ -8,19 +8,24 @@ public static class BuildInfo
public static async Task ReadBuildInfo() public static async Task ReadBuildInfo()
{ {
await using var stream = typeof(BuildInfo).Assembly.GetManifestResourceStream("version"); await using var stream = typeof(BuildInfo).Assembly.GetManifestResourceStream("version");
if (stream == null) return; if (stream == null)
return;
using var reader = new StreamReader(stream); using var reader = new StreamReader(stream);
var data = (await reader.ReadToEndAsync()).Trim().Split("\n"); var data = (await reader.ReadToEndAsync()).Trim().Split("\n");
if (data.Length < 3) return; if (data.Length < 3)
return;
Hash = data[0]; Hash = data[0];
var dirty = data[2] == "dirty"; var dirty = data[2] == "dirty";
var versionData = data[1].Split("-"); var versionData = data[1].Split("-");
if (versionData.Length < 3) return; if (versionData.Length < 3)
return;
Version = versionData[0]; Version = versionData[0];
if (versionData[1] != "0" || dirty) Version += $"+{versionData[2]}"; if (versionData[1] != "0" || dirty)
if (dirty) Version += ".dirty"; Version += $"+{versionData[2]}";
if (dirty)
Version += ".dirty";
} }
} }

View file

@ -11,8 +11,12 @@ using NodaTime;
namespace Foxnouns.Backend.Controllers.Authentication; namespace Foxnouns.Backend.Controllers.Authentication;
[Route("/api/internal/auth")] [Route("/api/internal/auth")]
public class AuthController(Config config, DatabaseContext db, KeyCacheService keyCache, ILogger logger) public class AuthController(
: ApiControllerBase Config config,
DatabaseContext db,
KeyCacheService keyCache,
ILogger logger
) : ApiControllerBase
{ {
private readonly ILogger _logger = logger.ForContext<AuthController>(); private readonly ILogger _logger = logger.ForContext<AuthController>();
@ -20,27 +24,25 @@ public class AuthController(Config config, DatabaseContext db, KeyCacheService k
[ProducesResponseType<UrlsResponse>(StatusCodes.Status200OK)] [ProducesResponseType<UrlsResponse>(StatusCodes.Status200OK)]
public async Task<IActionResult> UrlsAsync(CancellationToken ct = default) public async Task<IActionResult> UrlsAsync(CancellationToken ct = default)
{ {
_logger.Debug("Generating auth URLs for Discord: {Discord}, Google: {Google}, Tumblr: {Tumblr}", _logger.Debug(
"Generating auth URLs for Discord: {Discord}, Google: {Google}, Tumblr: {Tumblr}",
config.DiscordAuth.Enabled, config.DiscordAuth.Enabled,
config.GoogleAuth.Enabled, config.GoogleAuth.Enabled,
config.TumblrAuth.Enabled); config.TumblrAuth.Enabled
);
var state = HttpUtility.UrlEncode(await keyCache.GenerateAuthStateAsync(ct)); var state = HttpUtility.UrlEncode(await keyCache.GenerateAuthStateAsync(ct));
string? discord = null; string? discord = null;
if (config.DiscordAuth is { ClientId: not null, ClientSecret: not null }) if (config.DiscordAuth is { ClientId: not null, ClientSecret: not null })
discord = discord =
$"https://discord.com/oauth2/authorize?response_type=code" + $"https://discord.com/oauth2/authorize?response_type=code"
$"&client_id={config.DiscordAuth.ClientId}&scope=identify" + + $"&client_id={config.DiscordAuth.ClientId}&scope=identify"
$"&prompt=none&state={state}" + + $"&prompt=none&state={state}"
$"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}"; + $"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}";
return Ok(new UrlsResponse(discord, null, null)); return Ok(new UrlsResponse(discord, null, null));
} }
private record UrlsResponse( private record UrlsResponse(string? Discord, string? Google, string? Tumblr);
string? Discord,
string? Google,
string? Tumblr
);
public record AuthResponse( public record AuthResponse(
UserRendererService.UserResponse User, UserRendererService.UserResponse User,
@ -50,16 +52,13 @@ public class AuthController(Config config, DatabaseContext db, KeyCacheService k
public record CallbackResponse( public record CallbackResponse(
bool HasAccount, bool HasAccount,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] string? Ticket,
string? Ticket,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
string? RemoteUsername, string? RemoteUsername,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
UserRendererService.UserResponse? User, UserRendererService.UserResponse? User,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] string? Token,
string? Token, [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] Instant? ExpiresAt
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
Instant? ExpiresAt
); );
public record OauthRegisterRequest(string Ticket, string Username); public record OauthRegisterRequest(string Ticket, string Username);
@ -71,7 +70,8 @@ public class AuthController(Config config, DatabaseContext db, KeyCacheService k
public async Task<IActionResult> ForceLogoutAsync() public async Task<IActionResult> ForceLogoutAsync()
{ {
_logger.Information("Invalidating all tokens for user {UserId}", CurrentUser!.Id); _logger.Information("Invalidating all tokens for user {UserId}", CurrentUser!.Id);
await db.Tokens.Where(t => t.UserId == CurrentUser.Id) await db
.Tokens.Where(t => t.UserId == CurrentUser.Id)
.ExecuteUpdateAsync(s => s.SetProperty(t => t.ManuallyExpired, true)); .ExecuteUpdateAsync(s => s.SetProperty(t => t.ManuallyExpired, true));
return NoContent(); return NoContent();

View file

@ -19,7 +19,8 @@ public class DiscordAuthController(
KeyCacheService keyCacheService, KeyCacheService keyCacheService,
AuthService authService, AuthService authService,
RemoteAuthService remoteAuthService, RemoteAuthService remoteAuthService,
UserRendererService userRenderer) : ApiControllerBase UserRendererService userRenderer
) : ApiControllerBase
{ {
private readonly ILogger _logger = logger.ForContext<DiscordAuthController>(); private readonly ILogger _logger = logger.ForContext<DiscordAuthController>();
@ -27,59 +28,93 @@ public class DiscordAuthController(
// TODO: duplicating attribute doesn't work, find another way to mark both as possible response // TODO: duplicating attribute doesn't work, find another way to mark both as possible response
// leaving it here for documentation purposes // leaving it here for documentation purposes
[ProducesResponseType<AuthController.CallbackResponse>(StatusCodes.Status200OK)] [ProducesResponseType<AuthController.CallbackResponse>(StatusCodes.Status200OK)]
public async Task<IActionResult> CallbackAsync([FromBody] AuthController.CallbackRequest req, public async Task<IActionResult> CallbackAsync(
CancellationToken ct = default) [FromBody] AuthController.CallbackRequest req,
CancellationToken ct = default
)
{ {
CheckRequirements(); CheckRequirements();
await keyCacheService.ValidateAuthStateAsync(req.State, ct); await keyCacheService.ValidateAuthStateAsync(req.State, ct);
var remoteUser = await remoteAuthService.RequestDiscordTokenAsync(req.Code, req.State, ct); var remoteUser = await remoteAuthService.RequestDiscordTokenAsync(req.Code, req.State, ct);
var user = await authService.AuthenticateUserAsync(AuthType.Discord, remoteUser.Id, ct: ct); var user = await authService.AuthenticateUserAsync(AuthType.Discord, remoteUser.Id, ct: ct);
if (user != null) return Ok(await GenerateUserTokenAsync(user, ct)); if (user != null)
return Ok(await GenerateUserTokenAsync(user, ct));
_logger.Debug("Discord user {Username} ({Id}) authenticated with no local account", remoteUser.Username, _logger.Debug(
remoteUser.Id); "Discord user {Username} ({Id}) authenticated with no local account",
remoteUser.Username,
remoteUser.Id
);
var ticket = AuthUtils.RandomToken(); var ticket = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync($"discord:{ticket}", remoteUser, Duration.FromMinutes(20), ct); await keyCacheService.SetKeyAsync(
$"discord:{ticket}",
remoteUser,
Duration.FromMinutes(20),
ct
);
return Ok(new AuthController.CallbackResponse( return Ok(
new AuthController.CallbackResponse(
HasAccount: false, HasAccount: false,
Ticket: ticket, Ticket: ticket,
RemoteUsername: remoteUser.Username, RemoteUsername: remoteUser.Username,
User: null, User: null,
Token: null, Token: null,
ExpiresAt: null ExpiresAt: null
)); )
);
} }
[HttpPost("register")] [HttpPost("register")]
[ProducesResponseType<AuthController.AuthResponse>(StatusCodes.Status200OK)] [ProducesResponseType<AuthController.AuthResponse>(StatusCodes.Status200OK)]
public async Task<IActionResult> RegisterAsync([FromBody] AuthController.OauthRegisterRequest req) public async Task<IActionResult> RegisterAsync(
[FromBody] AuthController.OauthRegisterRequest req
)
{ {
var remoteUser = await keyCacheService.GetKeyAsync<RemoteAuthService.RemoteUser>($"discord:{req.Ticket}"); var remoteUser = await keyCacheService.GetKeyAsync<RemoteAuthService.RemoteUser>(
if (remoteUser == null) throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket); $"discord:{req.Ticket}"
if (await db.AuthMethods.AnyAsync(a => a.AuthType == AuthType.Discord && a.RemoteId == remoteUser.Id)) );
if (remoteUser == null)
throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket);
if (
await db.AuthMethods.AnyAsync(a =>
a.AuthType == AuthType.Discord && a.RemoteId == remoteUser.Id
)
)
{ {
_logger.Error("Discord user {Id} has valid ticket but is already linked to an existing account", _logger.Error(
remoteUser.Id); "Discord user {Id} has valid ticket but is already linked to an existing account",
remoteUser.Id
);
throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket); throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket);
} }
var user = await authService.CreateUserWithRemoteAuthAsync(req.Username, AuthType.Discord, remoteUser.Id, var user = await authService.CreateUserWithRemoteAuthAsync(
remoteUser.Username); req.Username,
AuthType.Discord,
remoteUser.Id,
remoteUser.Username
);
return Ok(await GenerateUserTokenAsync(user)); return Ok(await GenerateUserTokenAsync(user));
} }
private async Task<AuthController.CallbackResponse> GenerateUserTokenAsync(User user, private async Task<AuthController.CallbackResponse> GenerateUserTokenAsync(
CancellationToken ct = default) User user,
CancellationToken ct = default
)
{ {
var frontendApp = await db.GetFrontendApplicationAsync(ct); var frontendApp = await db.GetFrontendApplicationAsync(ct);
_logger.Debug("Logging user {Id} in with Discord", user.Id); _logger.Debug("Logging user {Id} in with Discord", user.Id);
var (tokenStr, token) = var (tokenStr, token) = authService.GenerateToken(
authService.GenerateToken(user, frontendApp, ["*"], clock.GetCurrentInstant() + Duration.FromDays(365)); user,
frontendApp,
["*"],
clock.GetCurrentInstant() + Duration.FromDays(365)
);
db.Add(token); db.Add(token);
_logger.Debug("Generated token {TokenId} for {UserId}", user.Id, token.Id); _logger.Debug("Generated token {TokenId} for {UserId}", user.Id, token.Id);
@ -90,7 +125,12 @@ public class DiscordAuthController(
HasAccount: true, HasAccount: true,
Ticket: null, Ticket: null,
RemoteUsername: null, RemoteUsername: null,
User: await userRenderer.RenderUserAsync(user, selfUser: user, renderMembers: false, ct: ct), User: await userRenderer.RenderUserAsync(
user,
selfUser: user,
renderMembers: false,
ct: ct
),
Token: tokenStr, Token: tokenStr,
ExpiresAt: token.ExpiresAt ExpiresAt: token.ExpiresAt
); );
@ -99,6 +139,8 @@ public class DiscordAuthController(
private void CheckRequirements() private void CheckRequirements()
{ {
if (!config.DiscordAuth.Enabled) if (!config.DiscordAuth.Enabled)
throw new ApiError.BadRequest("Discord authentication is not enabled on this instance."); throw new ApiError.BadRequest(
"Discord authentication is not enabled on this instance."
);
} }
} }

View file

@ -20,21 +20,35 @@ public class EmailAuthController(
KeyCacheService keyCacheService, KeyCacheService keyCacheService,
UserRendererService userRenderer, UserRendererService userRenderer,
IClock clock, IClock clock,
ILogger logger) : ApiControllerBase ILogger logger
) : ApiControllerBase
{ {
private readonly ILogger _logger = logger.ForContext<EmailAuthController>(); private readonly ILogger _logger = logger.ForContext<EmailAuthController>();
[HttpPost("register")] [HttpPost("register")]
public async Task<IActionResult> RegisterAsync([FromBody] RegisterRequest req, CancellationToken ct = default) public async Task<IActionResult> RegisterAsync(
[FromBody] RegisterRequest req,
CancellationToken ct = default
)
{ {
CheckRequirements(); CheckRequirements();
if (!req.Email.Contains('@')) throw new ApiError.BadRequest("Email is invalid", "email", req.Email); if (!req.Email.Contains('@'))
throw new ApiError.BadRequest("Email is invalid", "email", req.Email);
var state = await keyCacheService.GenerateRegisterEmailStateAsync(req.Email, userId: null, ct); var state = await keyCacheService.GenerateRegisterEmailStateAsync(
req.Email,
userId: null,
ct
);
// If there's already a user with that email address, pretend we sent an email but actually ignore it // If there's already a user with that email address, pretend we sent an email but actually ignore it
if (await db.AuthMethods.AnyAsync(a => a.AuthType == AuthType.Email && a.RemoteId == req.Email, ct)) if (
await db.AuthMethods.AnyAsync(
a => a.AuthType == AuthType.Email && a.RemoteId == req.Email,
ct
)
)
return NoContent(); return NoContent();
mailService.QueueAccountCreationEmail(req.Email, state); mailService.QueueAccountCreationEmail(req.Email, state);
@ -47,29 +61,48 @@ public class EmailAuthController(
CheckRequirements(); CheckRequirements();
var state = await keyCacheService.GetRegisterEmailStateAsync(req.State); var state = await keyCacheService.GetRegisterEmailStateAsync(req.State);
if (state == null) throw new ApiError.BadRequest("Invalid state", "state", req.State); if (state == null)
throw new ApiError.BadRequest("Invalid state", "state", req.State);
// If this callback is for an existing user, add the email address to their auth methods // If this callback is for an existing user, add the email address to their auth methods
if (state.ExistingUserId != null) if (state.ExistingUserId != null)
{ {
var authMethod = var authMethod = await authService.AddAuthMethodAsync(
await authService.AddAuthMethodAsync(state.ExistingUserId.Value, AuthType.Email, state.Email); state.ExistingUserId.Value,
_logger.Debug("Added email auth {AuthId} for user {UserId}", authMethod.Id, state.ExistingUserId); AuthType.Email,
state.Email
);
_logger.Debug(
"Added email auth {AuthId} for user {UserId}",
authMethod.Id,
state.ExistingUserId
);
return NoContent(); return NoContent();
} }
var ticket = AuthUtils.RandomToken(); var ticket = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync($"email:{ticket}", state.Email, Duration.FromMinutes(20)); await keyCacheService.SetKeyAsync($"email:{ticket}", state.Email, Duration.FromMinutes(20));
return Ok(new AuthController.CallbackResponse(HasAccount: false, Ticket: ticket, RemoteUsername: state.Email, return Ok(
User: null, Token: null, ExpiresAt: null)); new AuthController.CallbackResponse(
HasAccount: false,
Ticket: ticket,
RemoteUsername: state.Email,
User: null,
Token: null,
ExpiresAt: null
)
);
} }
[HttpPost("complete-registration")] [HttpPost("complete-registration")]
public async Task<IActionResult> CompleteRegistrationAsync([FromBody] CompleteRegistrationRequest req) public async Task<IActionResult> CompleteRegistrationAsync(
[FromBody] CompleteRegistrationRequest req
)
{ {
var email = await keyCacheService.GetKeyAsync($"email:{req.Ticket}"); var email = await keyCacheService.GetKeyAsync($"email:{req.Ticket}");
if (email == null) throw new ApiError.BadRequest("Unknown ticket", "ticket", req.Ticket); if (email == null)
throw new ApiError.BadRequest("Unknown ticket", "ticket", req.Ticket);
// Check if username is valid at all // Check if username is valid at all
ValidationUtils.Validate([("username", ValidationUtils.ValidateUsername(req.Username))]); ValidationUtils.Validate([("username", ValidationUtils.ValidateUsername(req.Username))]);
@ -80,28 +113,41 @@ public class EmailAuthController(
var user = await authService.CreateUserWithPasswordAsync(req.Username, email, req.Password); var user = await authService.CreateUserWithPasswordAsync(req.Username, email, req.Password);
var frontendApp = await db.GetFrontendApplicationAsync(); var frontendApp = await db.GetFrontendApplicationAsync();
var (tokenStr, token) = var (tokenStr, token) = authService.GenerateToken(
authService.GenerateToken(user, frontendApp, ["*"], clock.GetCurrentInstant() + Duration.FromDays(365)); user,
frontendApp,
["*"],
clock.GetCurrentInstant() + Duration.FromDays(365)
);
db.Add(token); db.Add(token);
await db.SaveChangesAsync(); await db.SaveChangesAsync();
await keyCacheService.DeleteKeyAsync($"email:{req.Ticket}"); await keyCacheService.DeleteKeyAsync($"email:{req.Ticket}");
return Ok(new AuthController.AuthResponse( return Ok(
new AuthController.AuthResponse(
await userRenderer.RenderUserAsync(user, selfUser: user, renderMembers: false), await userRenderer.RenderUserAsync(user, selfUser: user, renderMembers: false),
tokenStr, tokenStr,
token.ExpiresAt token.ExpiresAt
)); )
);
} }
[HttpPost("login")] [HttpPost("login")]
[ProducesResponseType<AuthController.AuthResponse>(StatusCodes.Status200OK)] [ProducesResponseType<AuthController.AuthResponse>(StatusCodes.Status200OK)]
public async Task<IActionResult> LoginAsync([FromBody] LoginRequest req, CancellationToken ct = default) public async Task<IActionResult> LoginAsync(
[FromBody] LoginRequest req,
CancellationToken ct = default
)
{ {
CheckRequirements(); CheckRequirements();
var (user, authenticationResult) = await authService.AuthenticateUserAsync(req.Email, req.Password, ct); var (user, authenticationResult) = await authService.AuthenticateUserAsync(
req.Email,
req.Password,
ct
);
if (authenticationResult == AuthService.EmailAuthenticationResult.MfaRequired) if (authenticationResult == AuthService.EmailAuthenticationResult.MfaRequired)
throw new NotImplementedException("MFA is not implemented yet"); throw new NotImplementedException("MFA is not implemented yet");
@ -109,19 +155,30 @@ public class EmailAuthController(
_logger.Debug("Logging user {Id} in with email and password", user.Id); _logger.Debug("Logging user {Id} in with email and password", user.Id);
var (tokenStr, token) = var (tokenStr, token) = authService.GenerateToken(
authService.GenerateToken(user, frontendApp, ["*"], clock.GetCurrentInstant() + Duration.FromDays(365)); user,
frontendApp,
["*"],
clock.GetCurrentInstant() + Duration.FromDays(365)
);
db.Add(token); db.Add(token);
_logger.Debug("Generated token {TokenId} for {UserId}", token.Id, user.Id); _logger.Debug("Generated token {TokenId} for {UserId}", token.Id, user.Id);
await db.SaveChangesAsync(ct); await db.SaveChangesAsync(ct);
return Ok(new AuthController.AuthResponse( return Ok(
await userRenderer.RenderUserAsync(user, selfUser: user, renderMembers: false, ct: ct), new AuthController.AuthResponse(
await userRenderer.RenderUserAsync(
user,
selfUser: user,
renderMembers: false,
ct: ct
),
tokenStr, tokenStr,
token.ExpiresAt token.ExpiresAt
)); )
);
} }
[HttpPost("add")] [HttpPost("add")]

View file

@ -18,13 +18,16 @@ public class FlagsController(
UserRendererService userRenderer, UserRendererService userRenderer,
ObjectStorageService objectStorageService, ObjectStorageService objectStorageService,
ISnowflakeGenerator snowflakeGenerator, ISnowflakeGenerator snowflakeGenerator,
IQueue queue) : ApiControllerBase IQueue queue
) : ApiControllerBase
{ {
private readonly ILogger _logger = logger.ForContext<FlagsController>(); private readonly ILogger _logger = logger.ForContext<FlagsController>();
[HttpGet] [HttpGet]
[Authorize("identify")] [Authorize("identify")]
[ProducesResponseType<IEnumerable<UserRendererService.PrideFlagResponse>>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<IEnumerable<UserRendererService.PrideFlagResponse>>(
statusCode: StatusCodes.Status200OK
)]
public async Task<IActionResult> GetFlagsAsync(CancellationToken ct = default) public async Task<IActionResult> GetFlagsAsync(CancellationToken ct = default)
{ {
var flags = await db.PrideFlags.Where(f => f.UserId == CurrentUser!.Id).ToListAsync(ct); var flags = await db.PrideFlags.Where(f => f.UserId == CurrentUser!.Id).ToListAsync(ct);
@ -34,7 +37,9 @@ public class FlagsController(
[HttpPost] [HttpPost]
[Authorize("user.update")] [Authorize("user.update")]
[ProducesResponseType<UserRendererService.PrideFlagResponse>(statusCode: StatusCodes.Status202Accepted)] [ProducesResponseType<UserRendererService.PrideFlagResponse>(
statusCode: StatusCodes.Status202Accepted
)]
public IActionResult CreateFlag([FromBody] CreateFlagRequest req) public IActionResult CreateFlag([FromBody] CreateFlagRequest req)
{ {
ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, req.Image)); ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, req.Image));
@ -42,7 +47,8 @@ public class FlagsController(
var id = snowflakeGenerator.GenerateSnowflake(); var id = snowflakeGenerator.GenerateSnowflake();
queue.QueueInvocableWithPayload<CreateFlagInvocable, CreateFlagPayload>( queue.QueueInvocableWithPayload<CreateFlagInvocable, CreateFlagPayload>(
new CreateFlagPayload(id, CurrentUser!.Id, req.Name, req.Image, req.Description)); new CreateFlagPayload(id, CurrentUser!.Id, req.Name, req.Image, req.Description)
);
return Accepted(new CreateFlagResponse(id, req.Name, req.Description)); return Accepted(new CreateFlagResponse(id, req.Name, req.Description));
} }
@ -57,10 +63,14 @@ public class FlagsController(
{ {
ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, null)); ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, null));
var flag = await db.PrideFlags.FirstOrDefaultAsync(f => f.Id == id && f.UserId == CurrentUser!.Id); var flag = await db.PrideFlags.FirstOrDefaultAsync(f =>
if (flag == null) throw new ApiError.NotFound("Unknown flag ID, or it's not your flag."); f.Id == id && f.UserId == CurrentUser!.Id
);
if (flag == null)
throw new ApiError.NotFound("Unknown flag ID, or it's not your flag.");
if (req.Name != null) flag.Name = req.Name; if (req.Name != null)
flag.Name = req.Name;
if (req.HasProperty(nameof(req.Description))) if (req.HasProperty(nameof(req.Description)))
flag.Description = req.Description; flag.Description = req.Description;
@ -83,8 +93,11 @@ public class FlagsController(
{ {
await using var tx = await db.Database.BeginTransactionAsync(); await using var tx = await db.Database.BeginTransactionAsync();
var flag = await db.PrideFlags.FirstOrDefaultAsync(f => f.Id == id && f.UserId == CurrentUser!.Id); var flag = await db.PrideFlags.FirstOrDefaultAsync(f =>
if (flag == null) throw new ApiError.NotFound("Unknown flag ID, or it's not your flag."); f.Id == id && f.UserId == CurrentUser!.Id
);
if (flag == null)
throw new ApiError.NotFound("Unknown flag ID, or it's not your flag.");
var hash = flag.Hash; var hash = flag.Hash;
@ -96,7 +109,10 @@ public class FlagsController(
{ {
try try
{ {
_logger.Information("Deleting flag file {Hash} as it is no longer used by any flags", hash); _logger.Information(
"Deleting flag file {Hash} as it is no longer used by any flags",
hash
);
await objectStorageService.DeleteFlagAsync(hash); await objectStorageService.DeleteFlagAsync(hash);
} }
catch (Exception e) catch (Exception e)
@ -104,14 +120,19 @@ public class FlagsController(
_logger.Error(e, "Error deleting flag file {Hash}", hash); _logger.Error(e, "Error deleting flag file {Hash}", hash);
} }
} }
else _logger.Debug("Flag file {Hash} is used by other flags, not deleting", hash); else
_logger.Debug("Flag file {Hash} is used by other flags, not deleting", hash);
await tx.CommitAsync(); await tx.CommitAsync();
return NoContent(); return NoContent();
} }
private static List<(string, ValidationError?)> ValidateFlag(string? name, string? description, string? imageData) private static List<(string, ValidationError?)> ValidateFlag(
string? name,
string? description,
string? imageData
)
{ {
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
@ -120,10 +141,20 @@ public class FlagsController(
switch (name.Length) switch (name.Length)
{ {
case < 1: case < 1:
errors.Add(("name", ValidationError.LengthError("Name is too short", 1, 100, name.Length))); errors.Add(
(
"name",
ValidationError.LengthError("Name is too short", 1, 100, name.Length)
)
);
break; break;
case > 100: case > 100:
errors.Add(("name", ValidationError.LengthError("Name is too long", 1, 100, name.Length))); errors.Add(
(
"name",
ValidationError.LengthError("Name is too long", 1, 100, name.Length)
)
);
break; break;
} }
} }
@ -133,12 +164,30 @@ public class FlagsController(
switch (description.Length) switch (description.Length)
{ {
case < 1: case < 1:
errors.Add(("description", errors.Add(
ValidationError.LengthError("Description is too short", 1, 100, description.Length))); (
"description",
ValidationError.LengthError(
"Description is too short",
1,
100,
description.Length
)
)
);
break; break;
case > 500: case > 500:
errors.Add(("description", errors.Add(
ValidationError.LengthError("Description is too long", 1, 100, description.Length))); (
"description",
ValidationError.LengthError(
"Description is too long",
1,
100,
description.Length
)
)
);
break; break;
} }
} }
@ -148,10 +197,20 @@ public class FlagsController(
switch (imageData.Length) switch (imageData.Length)
{ {
case 0: case 0:
errors.Add(("image", ValidationError.GenericValidationError("Image cannot be empty", null))); errors.Add(
(
"image",
ValidationError.GenericValidationError("Image cannot be empty", null)
)
);
break; break;
case > 1_500_000: case > 1_500_000:
errors.Add(("image", ValidationError.GenericValidationError("Image is too large", null))); errors.Add(
(
"image",
ValidationError.GenericValidationError("Image is too large", null)
)
);
break; break;
} }
} }

View file

@ -17,11 +17,13 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
private static string GetCleanedTemplate(string template) private static string GetCleanedTemplate(string template)
{ {
if (template.StartsWith("api/v2")) template = template["api/v2".Length..]; if (template.StartsWith("api/v2"))
template = template["api/v2".Length..];
template = PathVarRegex() template = PathVarRegex()
.Replace(template, "{id}") // Replace all path variables (almost always IDs) with `{id}` .Replace(template, "{id}") // Replace all path variables (almost always IDs) with `{id}`
.Replace("@me", "{id}"); // Also replace hardcoded `@me` with `{id}` .Replace("@me", "{id}"); // Also replace hardcoded `@me` with `{id}`
if (template.Contains("{id}")) return template.Split("{id}")[0] + "{id}"; if (template.Contains("{id}"))
return template.Split("{id}")[0] + "{id}";
return template; return template;
} }
@ -29,11 +31,13 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
public async Task<IActionResult> GetRequestDataAsync([FromBody] RequestDataRequest req) public async Task<IActionResult> GetRequestDataAsync([FromBody] RequestDataRequest req)
{ {
var endpoint = GetEndpoint(HttpContext, req.Path, req.Method); var endpoint = GetEndpoint(HttpContext, req.Path, req.Method);
if (endpoint == null) throw new ApiError.BadRequest("Path/method combination is invalid"); if (endpoint == null)
throw new ApiError.BadRequest("Path/method combination is invalid");
var actionDescriptor = endpoint.Metadata.GetMetadata<ControllerActionDescriptor>(); var actionDescriptor = endpoint.Metadata.GetMetadata<ControllerActionDescriptor>();
var template = actionDescriptor?.AttributeRouteInfo?.Template; var template = actionDescriptor?.AttributeRouteInfo?.Template;
if (template == null) throw new FoxnounsError("Template value was null on valid endpoint"); if (template == null)
throw new FoxnounsError("Template value was null on valid endpoint");
template = GetCleanedTemplate(template); template = GetCleanedTemplate(template);
// If no token was supplied, or it isn't valid base 64, return a null user ID (limiting by IP) // If no token was supplied, or it isn't valid base 64, return a null user ID (limiting by IP)
@ -46,26 +50,37 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
public record RequestDataRequest(string? Token, string Method, string Path); public record RequestDataRequest(string? Token, string Method, string Path);
public record RequestDataResponse( public record RequestDataResponse(Snowflake? UserId, string Template);
Snowflake? UserId,
string Template);
private static RouteEndpoint? GetEndpoint(HttpContext httpContext, string url, string requestMethod) private static RouteEndpoint? GetEndpoint(
HttpContext httpContext,
string url,
string requestMethod
)
{ {
var endpointDataSource = httpContext.RequestServices.GetService<EndpointDataSource>(); var endpointDataSource = httpContext.RequestServices.GetService<EndpointDataSource>();
if (endpointDataSource == null) return null; if (endpointDataSource == null)
return null;
var endpoints = endpointDataSource.Endpoints.OfType<RouteEndpoint>(); var endpoints = endpointDataSource.Endpoints.OfType<RouteEndpoint>();
foreach (var endpoint in endpoints) foreach (var endpoint in endpoints)
{ {
if (endpoint.RoutePattern.RawText == null) continue; if (endpoint.RoutePattern.RawText == null)
continue;
var templateMatcher = new TemplateMatcher(TemplateParser.Parse(endpoint.RoutePattern.RawText), var templateMatcher = new TemplateMatcher(
new RouteValueDictionary()); TemplateParser.Parse(endpoint.RoutePattern.RawText),
if (!templateMatcher.TryMatch(url, new())) continue; new RouteValueDictionary()
);
if (!templateMatcher.TryMatch(url, new()))
continue;
var httpMethodAttribute = endpoint.Metadata.GetMetadata<HttpMethodAttribute>(); var httpMethodAttribute = endpoint.Metadata.GetMetadata<HttpMethodAttribute>();
if (httpMethodAttribute != null && if (
!httpMethodAttribute.HttpMethods.Any(x => x.Equals(requestMethod, StringComparison.OrdinalIgnoreCase))) httpMethodAttribute != null
&& !httpMethodAttribute.HttpMethods.Any(x =>
x.Equals(requestMethod, StringComparison.OrdinalIgnoreCase)
)
)
continue; continue;
return endpoint; return endpoint;
} }

View file

@ -21,12 +21,15 @@ public class MembersController(
ISnowflakeGenerator snowflakeGenerator, ISnowflakeGenerator snowflakeGenerator,
ObjectStorageService objectStorageService, ObjectStorageService objectStorageService,
IQueue queue, IQueue queue,
IClock clock) : ApiControllerBase IClock clock
) : ApiControllerBase
{ {
private readonly ILogger _logger = logger.ForContext<MembersController>(); private readonly ILogger _logger = logger.ForContext<MembersController>();
[HttpGet] [HttpGet]
[ProducesResponseType<IEnumerable<MemberRendererService.PartialMember>>(StatusCodes.Status200OK)] [ProducesResponseType<IEnumerable<MemberRendererService.PartialMember>>(
StatusCodes.Status200OK
)]
public async Task<IActionResult> GetMembersAsync(string userRef, CancellationToken ct = default) public async Task<IActionResult> GetMembersAsync(string userRef, CancellationToken ct = default)
{ {
var user = await db.ResolveUserAsync(userRef, CurrentToken, ct); var user = await db.ResolveUserAsync(userRef, CurrentToken, ct);
@ -35,7 +38,11 @@ public class MembersController(
[HttpGet("{memberRef}")] [HttpGet("{memberRef}")]
[ProducesResponseType<MemberRendererService.MemberResponse>(StatusCodes.Status200OK)] [ProducesResponseType<MemberRendererService.MemberResponse>(StatusCodes.Status200OK)]
public async Task<IActionResult> GetMemberAsync(string userRef, string memberRef, CancellationToken ct = default) public async Task<IActionResult> GetMemberAsync(
string userRef,
string memberRef,
CancellationToken ct = default
)
{ {
var member = await db.ResolveMemberAsync(userRef, memberRef, CurrentToken, ct); var member = await db.ResolveMemberAsync(userRef, memberRef, CurrentToken, ct);
return Ok(memberRenderer.RenderMember(member, CurrentToken)); return Ok(memberRenderer.RenderMember(member, CurrentToken));
@ -46,19 +53,30 @@ public class MembersController(
[HttpPost("/api/v2/users/@me/members")] [HttpPost("/api/v2/users/@me/members")]
[ProducesResponseType<MemberRendererService.MemberResponse>(StatusCodes.Status200OK)] [ProducesResponseType<MemberRendererService.MemberResponse>(StatusCodes.Status200OK)]
[Authorize("member.create")] [Authorize("member.create")]
public async Task<IActionResult> CreateMemberAsync([FromBody] CreateMemberRequest req, public async Task<IActionResult> CreateMemberAsync(
CancellationToken ct = default) [FromBody] CreateMemberRequest req,
CancellationToken ct = default
)
{ {
ValidationUtils.Validate([ ValidationUtils.Validate(
[
("name", ValidationUtils.ValidateMemberName(req.Name)), ("name", ValidationUtils.ValidateMemberName(req.Name)),
("display_name", ValidationUtils.ValidateDisplayName(req.DisplayName)), ("display_name", ValidationUtils.ValidateDisplayName(req.DisplayName)),
("bio", ValidationUtils.ValidateBio(req.Bio)), ("bio", ValidationUtils.ValidateBio(req.Bio)),
("avatar", ValidationUtils.ValidateAvatar(req.Avatar)), ("avatar", ValidationUtils.ValidateAvatar(req.Avatar)),
.. ValidationUtils.ValidateFields(req.Fields, CurrentUser!.CustomPreferences), .. ValidationUtils.ValidateFields(req.Fields, CurrentUser!.CustomPreferences),
.. ValidationUtils.ValidateFieldEntries(req.Names?.ToArray(), CurrentUser!.CustomPreferences, "names"), .. ValidationUtils.ValidateFieldEntries(
.. ValidationUtils.ValidatePronouns(req.Pronouns?.ToArray(), CurrentUser!.CustomPreferences), req.Names?.ToArray(),
.. ValidationUtils.ValidateLinks(req.Links) CurrentUser!.CustomPreferences,
]); "names"
),
.. ValidationUtils.ValidatePronouns(
req.Pronouns?.ToArray(),
CurrentUser!.CustomPreferences
),
.. ValidationUtils.ValidateLinks(req.Links),
]
);
var memberCount = await db.Members.CountAsync(m => m.UserId == CurrentUser.Id, ct); var memberCount = await db.Members.CountAsync(m => m.UserId == CurrentUser.Id, ct);
if (memberCount >= MaxMemberCount) if (memberCount >= MaxMemberCount)
@ -75,11 +93,16 @@ public class MembersController(
Fields = req.Fields ?? [], Fields = req.Fields ?? [],
Names = req.Names ?? [], Names = req.Names ?? [],
Pronouns = req.Pronouns ?? [], Pronouns = req.Pronouns ?? [],
Unlisted = req.Unlisted ?? false Unlisted = req.Unlisted ?? false,
}; };
db.Add(member); db.Add(member);
_logger.Debug("Creating member {MemberName} ({Id}) for {UserId}", member.Name, member.Id, CurrentUser!.Id); _logger.Debug(
"Creating member {MemberName} ({Id}) for {UserId}",
member.Name,
member.Id,
CurrentUser!.Id
);
try try
{ {
@ -88,19 +111,27 @@ public class MembersController(
catch (UniqueConstraintException) catch (UniqueConstraintException)
{ {
_logger.Debug("Could not create member {Id} due to name conflict", member.Id); _logger.Debug("Could not create member {Id} due to name conflict", member.Id);
throw new ApiError.BadRequest("A member with that name already exists", "name", req.Name); throw new ApiError.BadRequest(
"A member with that name already exists",
"name",
req.Name
);
} }
if (req.Avatar != null) if (req.Avatar != null)
queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>( queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(member.Id, req.Avatar)); new AvatarUpdatePayload(member.Id, req.Avatar)
);
return Ok(memberRenderer.RenderMember(member, CurrentToken)); return Ok(memberRenderer.RenderMember(member, CurrentToken));
} }
[HttpPatch("/api/v2/users/@me/members/{memberRef}")] [HttpPatch("/api/v2/users/@me/members/{memberRef}")]
[Authorize("member.update")] [Authorize("member.update")]
public async Task<IActionResult> UpdateMemberAsync(string memberRef, [FromBody] UpdateMemberRequest req) public async Task<IActionResult> UpdateMemberAsync(
string memberRef,
[FromBody] UpdateMemberRequest req
)
{ {
await using var tx = await db.Database.BeginTransactionAsync(); await using var tx = await db.Database.BeginTransactionAsync();
var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef); var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);
@ -134,26 +165,37 @@ public class MembersController(
if (req.Names != null) if (req.Names != null)
{ {
errors.AddRange(ValidationUtils.ValidateFieldEntries(req.Names, CurrentUser!.CustomPreferences, "names")); errors.AddRange(
ValidationUtils.ValidateFieldEntries(
req.Names,
CurrentUser!.CustomPreferences,
"names"
)
);
member.Names = req.Names.ToList(); member.Names = req.Names.ToList();
} }
if (req.Pronouns != null) if (req.Pronouns != null)
{ {
errors.AddRange(ValidationUtils.ValidatePronouns(req.Pronouns, CurrentUser!.CustomPreferences)); errors.AddRange(
ValidationUtils.ValidatePronouns(req.Pronouns, CurrentUser!.CustomPreferences)
);
member.Pronouns = req.Pronouns.ToList(); member.Pronouns = req.Pronouns.ToList();
} }
if (req.Fields != null) if (req.Fields != null)
{ {
errors.AddRange(ValidationUtils.ValidateFields(req.Fields.ToList(), CurrentUser!.CustomPreferences)); errors.AddRange(
ValidationUtils.ValidateFields(req.Fields.ToList(), CurrentUser!.CustomPreferences)
);
member.Fields = req.Fields.ToList(); member.Fields = req.Fields.ToList();
} }
if (req.Flags != null) if (req.Flags != null)
{ {
var flagError = await db.SetMemberFlagsAsync(CurrentUser!.Id, member.Id, req.Flags); var flagError = await db.SetMemberFlagsAsync(CurrentUser!.Id, member.Id, req.Flags);
if (flagError != null) errors.Add(("flags", flagError)); if (flagError != null)
errors.Add(("flags", flagError));
} }
if (req.HasProperty(nameof(req.Avatar))) if (req.HasProperty(nameof(req.Avatar)))
@ -165,16 +207,25 @@ public class MembersController(
// so it's in a separate block to the validation above. // so it's in a separate block to the validation above.
if (req.HasProperty(nameof(req.Avatar))) if (req.HasProperty(nameof(req.Avatar)))
queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>( queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(member.Id, req.Avatar)); new AvatarUpdatePayload(member.Id, req.Avatar)
);
try try
{ {
await db.SaveChangesAsync(); await db.SaveChangesAsync();
} }
catch (UniqueConstraintException) catch (UniqueConstraintException)
{ {
_logger.Debug("Could not update member {Id} due to name conflict ({CurrentName} / {NewName})", member.Id, _logger.Debug(
member.Name, req.Name); "Could not update member {Id} due to name conflict ({CurrentName} / {NewName})",
throw new ApiError.BadRequest("A member with that name already exists", "name", req.Name!); member.Id,
member.Name,
req.Name
);
throw new ApiError.BadRequest(
"A member with that name already exists",
"name",
req.Name!
);
} }
await tx.CommitAsync(); await tx.CommitAsync();
@ -199,15 +250,20 @@ public class MembersController(
public async Task<IActionResult> DeleteMemberAsync(string memberRef) public async Task<IActionResult> DeleteMemberAsync(string memberRef)
{ {
var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef); var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);
var deleteCount = await db.Members.Where(m => m.UserId == CurrentUser!.Id && m.Id == member.Id) var deleteCount = await db
.Members.Where(m => m.UserId == CurrentUser!.Id && m.Id == member.Id)
.ExecuteDeleteAsync(); .ExecuteDeleteAsync();
if (deleteCount == 0) if (deleteCount == 0)
{ {
_logger.Warning("Successfully resolved member {Id} but could not delete them", member.Id); _logger.Warning(
"Successfully resolved member {Id} but could not delete them",
member.Id
);
return NoContent(); return NoContent();
} }
if (member.Avatar != null) await objectStorageService.DeleteMemberAvatarAsync(member.Id, member.Avatar); if (member.Avatar != null)
await objectStorageService.DeleteMemberAvatarAsync(member.Id, member.Avatar);
return NoContent(); return NoContent();
} }
@ -220,7 +276,8 @@ public class MembersController(
string[]? Links, string[]? Links,
List<FieldEntry>? Names, List<FieldEntry>? Names,
List<Pronoun>? Pronouns, List<Pronoun>? Pronouns,
List<Field>? Fields); List<Field>? Fields
);
[HttpPost("/api/v2/users/@me/members/{memberRef}/reroll-sid")] [HttpPost("/api/v2/users/@me/members/{memberRef}/reroll-sid")]
[Authorize("member.update")] [Authorize("member.update")]
@ -234,14 +291,16 @@ public class MembersController(
throw new ApiError.BadRequest("Cannot reroll short ID yet"); throw new ApiError.BadRequest("Cannot reroll short ID yet");
// Using ExecuteUpdateAsync here as the new short ID is generated by the database // Using ExecuteUpdateAsync here as the new short ID is generated by the database
await db.Members.Where(m => m.Id == member.Id) await db
.ExecuteUpdateAsync(s => s .Members.Where(m => m.Id == member.Id)
.SetProperty(m => m.Sid, _ => db.FindFreeMemberSid())); .ExecuteUpdateAsync(s => s.SetProperty(m => m.Sid, _ => db.FindFreeMemberSid()));
await db.Users.Where(u => u.Id == CurrentUser.Id) await db
.ExecuteUpdateAsync(s => s .Users.Where(u => u.Id == CurrentUser.Id)
.SetProperty(u => u.LastSidReroll, clock.GetCurrentInstant()) .ExecuteUpdateAsync(s =>
.SetProperty(u => u.LastActive, clock.GetCurrentInstant())); s.SetProperty(u => u.LastSidReroll, clock.GetCurrentInstant())
.SetProperty(u => u.LastActive, clock.GetCurrentInstant())
);
// Re-fetch member to fetch the new sid // Re-fetch member to fetch the new sid
var updatedMember = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef); var updatedMember = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);

View file

@ -12,8 +12,12 @@ public class MetaController : ApiControllerBase
[ProducesResponseType<MetaResponse>(StatusCodes.Status200OK)] [ProducesResponseType<MetaResponse>(StatusCodes.Status200OK)]
public IActionResult GetMeta() public IActionResult GetMeta()
{ {
return Ok(new MetaResponse( return Ok(
Repository, BuildInfo.Version, BuildInfo.Hash, (int)FoxnounsMetrics.MemberCount.Value, new MetaResponse(
Repository,
BuildInfo.Version,
BuildInfo.Hash,
(int)FoxnounsMetrics.MemberCount.Value,
new UserInfo( new UserInfo(
(int)FoxnounsMetrics.UsersCount.Value, (int)FoxnounsMetrics.UsersCount.Value,
(int)FoxnounsMetrics.UsersActiveMonthCount.Value, (int)FoxnounsMetrics.UsersActiveMonthCount.Value,
@ -23,12 +27,15 @@ public class MetaController : ApiControllerBase
new Limits( new Limits(
MemberCount: MembersController.MaxMemberCount, MemberCount: MembersController.MaxMemberCount,
BioLength: ValidationUtils.MaxBioLength, BioLength: ValidationUtils.MaxBioLength,
CustomPreferences: UsersController.MaxCustomPreferences)) CustomPreferences: UsersController.MaxCustomPreferences
)
)
); );
} }
[HttpGet("/api/v2/coffee")] [HttpGet("/api/v2/coffee")]
public IActionResult BrewCoffee() => Problem("Sorry, I'm a teapot!", statusCode: StatusCodes.Status418ImATeapot); public IActionResult BrewCoffee() =>
Problem("Sorry, I'm a teapot!", statusCode: StatusCodes.Status418ImATeapot);
private record MetaResponse( private record MetaResponse(
string Repository, string Repository,
@ -36,13 +43,11 @@ public class MetaController : ApiControllerBase
string Hash, string Hash,
int Members, int Members,
UserInfo Users, UserInfo Users,
Limits Limits); Limits Limits
);
private record UserInfo(int Total, int ActiveMonth, int ActiveWeek, int ActiveDay); private record UserInfo(int Total, int ActiveMonth, int ActiveWeek, int ActiveDay);
// All limits that the frontend should know about (for UI purposes) // All limits that the frontend should know about (for UI purposes)
private record Limits( private record Limits(int MemberCount, int BioLength, int CustomPreferences);
int MemberCount,
int BioLength,
int CustomPreferences);
} }

View file

@ -20,7 +20,8 @@ public class UsersController(
UserRendererService userRenderer, UserRendererService userRenderer,
ISnowflakeGenerator snowflakeGenerator, ISnowflakeGenerator snowflakeGenerator,
IQueue queue, IQueue queue,
IClock clock) : ApiControllerBase IClock clock
) : ApiControllerBase
{ {
private readonly ILogger _logger = logger.ForContext<UsersController>(); private readonly ILogger _logger = logger.ForContext<UsersController>();
@ -29,20 +30,25 @@ public class UsersController(
public async Task<IActionResult> GetUserAsync(string userRef, CancellationToken ct = default) public async Task<IActionResult> GetUserAsync(string userRef, CancellationToken ct = default)
{ {
var user = await db.ResolveUserAsync(userRef, CurrentToken, ct); var user = await db.ResolveUserAsync(userRef, CurrentToken, ct);
return Ok(await userRenderer.RenderUserAsync( return Ok(
await userRenderer.RenderUserAsync(
user, user,
selfUser: CurrentUser, selfUser: CurrentUser,
token: CurrentToken, token: CurrentToken,
renderMembers: true, renderMembers: true,
renderAuthMethods: true, renderAuthMethods: true,
ct: ct ct: ct
)); )
);
} }
[HttpPatch("@me")] [HttpPatch("@me")]
[Authorize("user.update")] [Authorize("user.update")]
[ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> UpdateUserAsync([FromBody] UpdateUserRequest req, CancellationToken ct = default) public async Task<IActionResult> UpdateUserAsync(
[FromBody] UpdateUserRequest req,
CancellationToken ct = default
)
{ {
await using var tx = await db.Database.BeginTransactionAsync(ct); await using var tx = await db.Database.BeginTransactionAsync(ct);
var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct); var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct);
@ -74,26 +80,37 @@ public class UsersController(
if (req.Names != null) if (req.Names != null)
{ {
errors.AddRange(ValidationUtils.ValidateFieldEntries(req.Names, CurrentUser!.CustomPreferences, "names")); errors.AddRange(
ValidationUtils.ValidateFieldEntries(
req.Names,
CurrentUser!.CustomPreferences,
"names"
)
);
user.Names = req.Names.ToList(); user.Names = req.Names.ToList();
} }
if (req.Pronouns != null) if (req.Pronouns != null)
{ {
errors.AddRange(ValidationUtils.ValidatePronouns(req.Pronouns, CurrentUser!.CustomPreferences)); errors.AddRange(
ValidationUtils.ValidatePronouns(req.Pronouns, CurrentUser!.CustomPreferences)
);
user.Pronouns = req.Pronouns.ToList(); user.Pronouns = req.Pronouns.ToList();
} }
if (req.Fields != null) if (req.Fields != null)
{ {
errors.AddRange(ValidationUtils.ValidateFields(req.Fields.ToList(), CurrentUser!.CustomPreferences)); errors.AddRange(
ValidationUtils.ValidateFields(req.Fields.ToList(), CurrentUser!.CustomPreferences)
);
user.Fields = req.Fields.ToList(); user.Fields = req.Fields.ToList();
} }
if (req.Flags != null) if (req.Flags != null)
{ {
var flagError = await db.SetUserFlagsAsync(CurrentUser!.Id, req.Flags); var flagError = await db.SetUserFlagsAsync(CurrentUser!.Id, req.Flags);
if (flagError != null) errors.Add(("flags", flagError)); if (flagError != null)
errors.Add(("flags", flagError));
} }
if (req.HasProperty(nameof(req.Avatar))) if (req.HasProperty(nameof(req.Avatar)))
@ -105,7 +122,8 @@ public class UsersController(
// so it's in a separate block to the validation above. // so it's in a separate block to the validation above.
if (req.HasProperty(nameof(req.Avatar))) if (req.HasProperty(nameof(req.Avatar)))
queue.QueueInvocableWithPayload<UserAvatarUpdateInvocable, AvatarUpdatePayload>( queue.QueueInvocableWithPayload<UserAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(CurrentUser!.Id, req.Avatar)); new AvatarUpdatePayload(CurrentUser!.Id, req.Avatar)
);
try try
{ {
@ -113,26 +131,45 @@ public class UsersController(
} }
catch (UniqueConstraintException) catch (UniqueConstraintException)
{ {
_logger.Debug("Could not update user {Id} due to name conflict ({CurrentName} / {NewName})", user.Id, _logger.Debug(
user.Username, req.Username); "Could not update user {Id} due to name conflict ({CurrentName} / {NewName})",
throw new ApiError.BadRequest("That username is already taken.", "username", req.Username!); user.Id,
user.Username,
req.Username
);
throw new ApiError.BadRequest(
"That username is already taken.",
"username",
req.Username!
);
} }
await tx.CommitAsync(ct); await tx.CommitAsync(ct);
return Ok(await userRenderer.RenderUserAsync(user, CurrentUser, renderMembers: false, return Ok(
renderAuthMethods: false, ct: ct)); await userRenderer.RenderUserAsync(
user,
CurrentUser,
renderMembers: false,
renderAuthMethods: false,
ct: ct
)
);
} }
[HttpPatch("@me/custom-preferences")] [HttpPatch("@me/custom-preferences")]
[Authorize("user.update")] [Authorize("user.update")]
[ProducesResponseType<Dictionary<Snowflake, User.CustomPreference>>(StatusCodes.Status200OK)] [ProducesResponseType<Dictionary<Snowflake, User.CustomPreference>>(StatusCodes.Status200OK)]
public async Task<IActionResult> UpdateCustomPreferencesAsync([FromBody] List<CustomPreferencesUpdateRequest> req, public async Task<IActionResult> UpdateCustomPreferencesAsync(
CancellationToken ct = default) [FromBody] List<CustomPreferencesUpdateRequest> req,
CancellationToken ct = default
)
{ {
ValidationUtils.Validate(ValidateCustomPreferences(req)); ValidationUtils.Validate(ValidateCustomPreferences(req));
var user = await db.ResolveUserAsync(CurrentUser!.Id, ct); var user = await db.ResolveUserAsync(CurrentUser!.Id, ct);
var preferences = user.CustomPreferences.Where(x => req.Any(r => r.Id == x.Key)).ToDictionary(); var preferences = user
.CustomPreferences.Where(x => req.Any(r => r.Id == x.Key))
.ToDictionary();
foreach (var r in req) foreach (var r in req)
{ {
@ -144,7 +181,7 @@ public class UsersController(
Icon = r.Icon, Icon = r.Icon,
Muted = r.Muted, Muted = r.Muted,
Size = r.Size, Size = r.Size,
Tooltip = r.Tooltip Tooltip = r.Tooltip,
}; };
} }
else else
@ -155,7 +192,7 @@ public class UsersController(
Icon = r.Icon, Icon = r.Icon,
Muted = r.Muted, Muted = r.Muted,
Size = r.Size, Size = r.Size,
Tooltip = r.Tooltip Tooltip = r.Tooltip,
}; };
} }
} }
@ -180,15 +217,25 @@ public class UsersController(
public const int MaxCustomPreferences = 25; public const int MaxCustomPreferences = 25;
private static List<(string, ValidationError?)> ValidateCustomPreferences( private static List<(string, ValidationError?)> ValidateCustomPreferences(
List<CustomPreferencesUpdateRequest> preferences) List<CustomPreferencesUpdateRequest> preferences
)
{ {
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (preferences.Count > MaxCustomPreferences) if (preferences.Count > MaxCustomPreferences)
errors.Add(("custom_preferences", errors.Add(
ValidationError.LengthError("Too many custom preferences", 0, MaxCustomPreferences, (
preferences.Count))); "custom_preferences",
if (preferences.Count > 50) return errors; ValidationError.LengthError(
"Too many custom preferences",
0,
MaxCustomPreferences,
preferences.Count
)
)
);
if (preferences.Count > 50)
return errors;
// TODO: validate individual preferences // TODO: validate individual preferences
@ -208,7 +255,6 @@ public class UsersController(
public Snowflake[]? Flags { get; init; } public Snowflake[]? Flags { get; init; }
} }
[HttpGet("@me/settings")] [HttpGet("@me/settings")]
[Authorize("user.read_hidden")] [Authorize("user.read_hidden")]
[ProducesResponseType<UserSettings>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<UserSettings>(statusCode: StatusCodes.Status200OK)]
@ -221,8 +267,10 @@ public class UsersController(
[HttpPatch("@me/settings")] [HttpPatch("@me/settings")]
[Authorize("user.read_hidden", "user.update")] [Authorize("user.read_hidden", "user.update")]
[ProducesResponseType<UserSettings>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<UserSettings>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> UpdateUserSettingsAsync([FromBody] UpdateUserSettingsRequest req, public async Task<IActionResult> UpdateUserSettingsAsync(
CancellationToken ct = default) [FromBody] UpdateUserSettingsRequest req,
CancellationToken ct = default
)
{ {
var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct); var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct);
@ -250,13 +298,22 @@ public class UsersController(
throw new ApiError.BadRequest("Cannot reroll short ID yet"); throw new ApiError.BadRequest("Cannot reroll short ID yet");
// Using ExecuteUpdateAsync here as the new short ID is generated by the database // Using ExecuteUpdateAsync here as the new short ID is generated by the database
await db.Users.Where(u => u.Id == CurrentUser.Id) await db
.ExecuteUpdateAsync(s => s .Users.Where(u => u.Id == CurrentUser.Id)
.SetProperty(u => u.Sid, _ => db.FindFreeUserSid()) .ExecuteUpdateAsync(s =>
s.SetProperty(u => u.Sid, _ => db.FindFreeUserSid())
.SetProperty(u => u.LastSidReroll, clock.GetCurrentInstant()) .SetProperty(u => u.LastSidReroll, clock.GetCurrentInstant())
.SetProperty(u => u.LastActive, clock.GetCurrentInstant())); .SetProperty(u => u.LastActive, clock.GetCurrentInstant())
);
var user = await db.ResolveUserAsync(CurrentUser.Id); var user = await db.ResolveUserAsync(CurrentUser.Id);
return Ok(await userRenderer.RenderUserAsync(user, CurrentUser, CurrentToken, renderMembers: false)); return Ok(
await userRenderer.RenderUserAsync(
user,
CurrentUser,
CurrentToken,
renderMembers: false
)
);
} }
} }

View file

@ -45,11 +45,12 @@ public class DatabaseContext : DbContext
_loggerFactory = loggerFactory; _loggerFactory = loggerFactory;
} }
protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) =>
=> optionsBuilder optionsBuilder
.ConfigureWarnings(c => .ConfigureWarnings(c =>
c.Ignore(CoreEventId.ManyServiceProvidersCreatedWarning) c.Ignore(CoreEventId.ManyServiceProvidersCreatedWarning)
.Ignore(CoreEventId.SaveChangesFailed)) .Ignore(CoreEventId.SaveChangesFailed)
)
.UseNpgsql(_dataSource, o => o.UseNodaTime()) .UseNpgsql(_dataSource, o => o.UseNodaTime())
.UseSnakeCaseNamingConvention() .UseSnakeCaseNamingConvention()
.UseLoggerFactory(_loggerFactory) .UseLoggerFactory(_loggerFactory)
@ -76,7 +77,10 @@ public class DatabaseContext : DbContext
modelBuilder.Entity<User>().Property(u => u.CustomPreferences).HasColumnType("jsonb"); modelBuilder.Entity<User>().Property(u => u.CustomPreferences).HasColumnType("jsonb");
modelBuilder.Entity<User>().Property(u => u.Settings).HasColumnType("jsonb"); modelBuilder.Entity<User>().Property(u => u.Settings).HasColumnType("jsonb");
modelBuilder.Entity<Member>().Property(m => m.Sid).HasDefaultValueSql("find_free_member_sid()"); modelBuilder
.Entity<Member>()
.Property(m => m.Sid)
.HasDefaultValueSql("find_free_member_sid()");
modelBuilder.Entity<Member>().Property(m => m.Fields).HasColumnType("jsonb"); modelBuilder.Entity<Member>().Property(m => m.Fields).HasColumnType("jsonb");
modelBuilder.Entity<Member>().Property(m => m.Names).HasColumnType("jsonb"); modelBuilder.Entity<Member>().Property(m => m.Names).HasColumnType("jsonb");
modelBuilder.Entity<Member>().Property(m => m.Pronouns).HasColumnType("jsonb"); modelBuilder.Entity<Member>().Property(m => m.Pronouns).HasColumnType("jsonb");
@ -84,10 +88,12 @@ public class DatabaseContext : DbContext
modelBuilder.Entity<UserFlag>().Navigation(f => f.PrideFlag).AutoInclude(); modelBuilder.Entity<UserFlag>().Navigation(f => f.PrideFlag).AutoInclude();
modelBuilder.Entity<MemberFlag>().Navigation(f => f.PrideFlag).AutoInclude(); modelBuilder.Entity<MemberFlag>().Navigation(f => f.PrideFlag).AutoInclude();
modelBuilder.HasDbFunction(typeof(DatabaseContext).GetMethod(nameof(FindFreeUserSid))!) modelBuilder
.HasDbFunction(typeof(DatabaseContext).GetMethod(nameof(FindFreeUserSid))!)
.HasName("find_free_user_sid"); .HasName("find_free_user_sid");
modelBuilder.HasDbFunction(typeof(DatabaseContext).GetMethod(nameof(FindFreeMemberSid))!) modelBuilder
.HasDbFunction(typeof(DatabaseContext).GetMethod(nameof(FindFreeMemberSid))!)
.HasName("find_free_member_sid"); .HasName("find_free_member_sid");
} }
@ -102,13 +108,18 @@ public class DatabaseContext : DbContext
public string FindFreeMemberSid() => throw new NotSupportedException(); public string FindFreeMemberSid() => throw new NotSupportedException();
} }
[SuppressMessage("ReSharper", "UnusedType.Global", Justification = "Used by EF Core's migration generator")] [SuppressMessage(
"ReSharper",
"UnusedType.Global",
Justification = "Used by EF Core's migration generator"
)]
public class DesignTimeDatabaseContextFactory : IDesignTimeDbContextFactory<DatabaseContext> public class DesignTimeDatabaseContextFactory : IDesignTimeDbContextFactory<DatabaseContext>
{ {
public DatabaseContext CreateDbContext(string[] args) public DatabaseContext CreateDbContext(string[] args)
{ {
// Read the configuration file // Read the configuration file
var config = new ConfigurationBuilder() var config =
new ConfigurationBuilder()
.AddConfiguration() .AddConfiguration()
.Build() .Build()
// Get the configuration as our config class // Get the configuration as our config class

View file

@ -8,89 +8,128 @@ namespace Foxnouns.Backend.Database;
public static class DatabaseQueryExtensions public static class DatabaseQueryExtensions
{ {
public static async Task<User> ResolveUserAsync(this DatabaseContext context, string userRef, Token? token, public static async Task<User> ResolveUserAsync(
CancellationToken ct = default) this DatabaseContext context,
string userRef,
Token? token,
CancellationToken ct = default
)
{ {
if (userRef == "@me") if (userRef == "@me")
{ {
return token != null return token != null
? await context.Users.FirstAsync(u => u.Id == token.UserId, ct) ? await context.Users.FirstAsync(u => u.Id == token.UserId, ct)
: throw new ApiError.Unauthorized("This endpoint requires an authenticated user.", : throw new ApiError.Unauthorized(
ErrorCode.AuthenticationRequired); "This endpoint requires an authenticated user.",
ErrorCode.AuthenticationRequired
);
} }
User? user; User? user;
if (Snowflake.TryParse(userRef, out var snowflake)) if (Snowflake.TryParse(userRef, out var snowflake))
{ {
user = await context.Users user = await context
.Where(u => !u.Deleted) .Users.Where(u => !u.Deleted)
.FirstOrDefaultAsync(u => u.Id == snowflake, ct); .FirstOrDefaultAsync(u => u.Id == snowflake, ct);
if (user != null) return user; if (user != null)
return user;
} }
user = await context.Users user = await context
.Where(u => !u.Deleted) .Users.Where(u => !u.Deleted)
.FirstOrDefaultAsync(u => u.Username == userRef, ct); .FirstOrDefaultAsync(u => u.Username == userRef, ct);
if (user != null) return user; if (user != null)
throw new ApiError.NotFound("No user with that ID or username found.", code: ErrorCode.UserNotFound); return user;
throw new ApiError.NotFound(
"No user with that ID or username found.",
code: ErrorCode.UserNotFound
);
} }
public static async Task<User> ResolveUserAsync(this DatabaseContext context, Snowflake id, public static async Task<User> ResolveUserAsync(
CancellationToken ct = default) this DatabaseContext context,
Snowflake id,
CancellationToken ct = default
)
{ {
var user = await context.Users var user = await context
.Where(u => !u.Deleted) .Users.Where(u => !u.Deleted)
.FirstOrDefaultAsync(u => u.Id == id, ct); .FirstOrDefaultAsync(u => u.Id == id, ct);
if (user != null) return user; if (user != null)
return user;
throw new ApiError.NotFound("No user with that ID found.", code: ErrorCode.UserNotFound); throw new ApiError.NotFound("No user with that ID found.", code: ErrorCode.UserNotFound);
} }
public static async Task<Member> ResolveMemberAsync(this DatabaseContext context, Snowflake id, public static async Task<Member> ResolveMemberAsync(
CancellationToken ct = default) this DatabaseContext context,
Snowflake id,
CancellationToken ct = default
)
{ {
var member = await context.Members var member = await context
.Include(m => m.User) .Members.Include(m => m.User)
.Where(m => !m.User.Deleted) .Where(m => !m.User.Deleted)
.FirstOrDefaultAsync(m => m.Id == id, ct); .FirstOrDefaultAsync(m => m.Id == id, ct);
if (member != null) return member; if (member != null)
throw new ApiError.NotFound("No member with that ID found.", code: ErrorCode.MemberNotFound); return member;
throw new ApiError.NotFound(
"No member with that ID found.",
code: ErrorCode.MemberNotFound
);
} }
public static async Task<Member> ResolveMemberAsync(this DatabaseContext context, string userRef, string memberRef, public static async Task<Member> ResolveMemberAsync(
Token? token, CancellationToken ct = default) this DatabaseContext context,
string userRef,
string memberRef,
Token? token,
CancellationToken ct = default
)
{ {
var user = await context.ResolveUserAsync(userRef, token, ct); var user = await context.ResolveUserAsync(userRef, token, ct);
return await context.ResolveMemberAsync(user.Id, memberRef, ct); return await context.ResolveMemberAsync(user.Id, memberRef, ct);
} }
public static async Task<Member> ResolveMemberAsync(this DatabaseContext context, Snowflake userId, public static async Task<Member> ResolveMemberAsync(
string memberRef, CancellationToken ct = default) this DatabaseContext context,
Snowflake userId,
string memberRef,
CancellationToken ct = default
)
{ {
Member? member; Member? member;
if (Snowflake.TryParse(memberRef, out var snowflake)) if (Snowflake.TryParse(memberRef, out var snowflake))
{ {
member = await context.Members member = await context
.Include(m => m.User) .Members.Include(m => m.User)
.Include(m => m.ProfileFlags) .Include(m => m.ProfileFlags)
.Where(m => !m.User.Deleted) .Where(m => !m.User.Deleted)
.FirstOrDefaultAsync(m => m.Id == snowflake && m.UserId == userId, ct); .FirstOrDefaultAsync(m => m.Id == snowflake && m.UserId == userId, ct);
if (member != null) return member; if (member != null)
return member;
} }
member = await context.Members member = await context
.Include(m => m.User) .Members.Include(m => m.User)
.Include(m => m.ProfileFlags) .Include(m => m.ProfileFlags)
.Where(m => !m.User.Deleted) .Where(m => !m.User.Deleted)
.FirstOrDefaultAsync(m => m.Name == memberRef && m.UserId == userId, ct); .FirstOrDefaultAsync(m => m.Name == memberRef && m.UserId == userId, ct);
if (member != null) return member; if (member != null)
throw new ApiError.NotFound("No member with that ID or name found.", code: ErrorCode.MemberNotFound); return member;
throw new ApiError.NotFound(
"No member with that ID or name found.",
code: ErrorCode.MemberNotFound
);
} }
public static async Task<Application> GetFrontendApplicationAsync(this DatabaseContext context, public static async Task<Application> GetFrontendApplicationAsync(
CancellationToken ct = default) this DatabaseContext context,
CancellationToken ct = default
)
{ {
var app = await context.Applications.FirstOrDefaultAsync(a => a.Id == new Snowflake(0), ct); var app = await context.Applications.FirstOrDefaultAsync(a => a.Id == new Snowflake(0), ct);
if (app != null) return app; if (app != null)
return app;
app = new Application app = new Application
{ {
@ -107,27 +146,42 @@ public static class DatabaseQueryExtensions
return app; return app;
} }
public static async Task<Token?> GetToken(this DatabaseContext context, byte[] rawToken, public static async Task<Token?> GetToken(
CancellationToken ct = default) this DatabaseContext context,
byte[] rawToken,
CancellationToken ct = default
)
{ {
var hash = SHA512.HashData(rawToken); var hash = SHA512.HashData(rawToken);
var oauthToken = await context.Tokens var oauthToken = await context
.Include(t => t.Application) .Tokens.Include(t => t.Application)
.Include(t => t.User) .Include(t => t.User)
.FirstOrDefaultAsync( .FirstOrDefaultAsync(
t => t.Hash == hash && t.ExpiresAt > SystemClock.Instance.GetCurrentInstant() && !t.ManuallyExpired, t =>
ct); t.Hash == hash
&& t.ExpiresAt > SystemClock.Instance.GetCurrentInstant()
&& !t.ManuallyExpired,
ct
);
return oauthToken; return oauthToken;
} }
public static async Task<Snowflake?> GetTokenUserId(this DatabaseContext context, byte[] rawToken, public static async Task<Snowflake?> GetTokenUserId(
CancellationToken ct = default) this DatabaseContext context,
byte[] rawToken,
CancellationToken ct = default
)
{ {
var hash = SHA512.HashData(rawToken); var hash = SHA512.HashData(rawToken);
return await context.Tokens return await context
.Where(t => t.Hash == hash && t.ExpiresAt > SystemClock.Instance.GetCurrentInstant() && !t.ManuallyExpired) .Tokens.Where(t =>
.Select(t => t.UserId).FirstOrDefaultAsync(ct); t.Hash == hash
&& t.ExpiresAt > SystemClock.Instance.GetCurrentInstant()
&& !t.ManuallyExpired
)
.Select(t => t.UserId)
.FirstOrDefaultAsync(ct);
} }
} }

View file

@ -5,23 +5,30 @@ namespace Foxnouns.Backend.Database;
public static class FlagQueryExtensions public static class FlagQueryExtensions
{ {
private static async Task<List<PrideFlag>> GetFlagsAsync(this DatabaseContext db, Snowflake userId) => private static async Task<List<PrideFlag>> GetFlagsAsync(
await db.PrideFlags.Where(f => f.UserId == userId).OrderBy(f => f.Id).ToListAsync(); this DatabaseContext db,
Snowflake userId
) => await db.PrideFlags.Where(f => f.UserId == userId).OrderBy(f => f.Id).ToListAsync();
/// <summary> /// <summary>
/// Sets the user's profile flags to the given IDs. Returns a validation error if any of the flag IDs are unknown /// Sets the user's profile flags to the given IDs. Returns a validation error if any of the flag IDs are unknown
/// or if too many IDs are given. Duplicates are allowed. /// or if too many IDs are given. Duplicates are allowed.
/// </summary> /// </summary>
public static async Task<ValidationError?> SetUserFlagsAsync(this DatabaseContext db, Snowflake userId, public static async Task<ValidationError?> SetUserFlagsAsync(
Snowflake[] flagIds) this DatabaseContext db,
Snowflake userId,
Snowflake[] flagIds
)
{ {
var currentFlags = await db.UserFlags.Where(f => f.UserId == userId).ToListAsync(); var currentFlags = await db.UserFlags.Where(f => f.UserId == userId).ToListAsync();
foreach (var flag in currentFlags) foreach (var flag in currentFlags)
db.UserFlags.Remove(flag); db.UserFlags.Remove(flag);
// If there's no new flags to set, we're done // If there's no new flags to set, we're done
if (flagIds.Length == 0) return null; if (flagIds.Length == 0)
if (flagIds.Length > 100) return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length); return null;
if (flagIds.Length > 100)
return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length);
var flags = await db.GetFlagsAsync(userId); var flags = await db.GetFlagsAsync(userId);
var unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray(); var unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray();
@ -34,22 +41,32 @@ public static class FlagQueryExtensions
return null; return null;
} }
public static async Task<ValidationError?> SetMemberFlagsAsync(this DatabaseContext db, Snowflake userId, public static async Task<ValidationError?> SetMemberFlagsAsync(
Snowflake memberId, Snowflake[] flagIds) this DatabaseContext db,
Snowflake userId,
Snowflake memberId,
Snowflake[] flagIds
)
{ {
var currentFlags = await db.MemberFlags.Where(f => f.MemberId == memberId).ToListAsync(); var currentFlags = await db.MemberFlags.Where(f => f.MemberId == memberId).ToListAsync();
foreach (var flag in currentFlags) foreach (var flag in currentFlags)
db.MemberFlags.Remove(flag); db.MemberFlags.Remove(flag);
if (flagIds.Length == 0) return null; if (flagIds.Length == 0)
if (flagIds.Length > 100) return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length); return null;
if (flagIds.Length > 100)
return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length);
var flags = await db.GetFlagsAsync(userId); var flags = await db.GetFlagsAsync(userId);
var unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray(); var unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray();
if (unknownFlagIds.Length != 0) if (unknownFlagIds.Length != 0)
return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds); return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds);
var memberFlags = flagIds.Select(id => new MemberFlag { PrideFlagId = id, MemberId = memberId }); var memberFlags = flagIds.Select(id => new MemberFlag
{
PrideFlagId = id,
MemberId = memberId,
});
db.MemberFlags.AddRange(memberFlags); db.MemberFlags.AddRange(memberFlags);
return null; return null;

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime; using NodaTime;
#nullable disable #nullable disable
@ -22,12 +22,13 @@ namespace Foxnouns.Backend.Database.Migrations
domain = table.Column<string>(type: "text", nullable: false), domain = table.Column<string>(type: "text", nullable: false),
client_id = table.Column<string>(type: "text", nullable: false), client_id = table.Column<string>(type: "text", nullable: false),
client_secret = table.Column<string>(type: "text", nullable: false), client_secret = table.Column<string>(type: "text", nullable: false),
instance_type = table.Column<int>(type: "integer", nullable: false) instance_type = table.Column<int>(type: "integer", nullable: false),
}, },
constraints: table => constraints: table =>
{ {
table.PrimaryKey("pk_fediverse_applications", x => x.id); table.PrimaryKey("pk_fediverse_applications", x => x.id);
}); }
);
migrationBuilder.CreateTable( migrationBuilder.CreateTable(
name: "users", name: "users",
@ -43,12 +44,13 @@ namespace Foxnouns.Backend.Database.Migrations
role = table.Column<int>(type: "integer", nullable: false), role = table.Column<int>(type: "integer", nullable: false),
fields = table.Column<string>(type: "jsonb", nullable: false), fields = table.Column<string>(type: "jsonb", nullable: false),
names = table.Column<string>(type: "jsonb", nullable: false), names = table.Column<string>(type: "jsonb", nullable: false),
pronouns = table.Column<string>(type: "jsonb", nullable: false) pronouns = table.Column<string>(type: "jsonb", nullable: false),
}, },
constraints: table => constraints: table =>
{ {
table.PrimaryKey("pk_users", x => x.id); table.PrimaryKey("pk_users", x => x.id);
}); }
);
migrationBuilder.CreateTable( migrationBuilder.CreateTable(
name: "auth_methods", name: "auth_methods",
@ -59,7 +61,7 @@ namespace Foxnouns.Backend.Database.Migrations
remote_id = table.Column<string>(type: "text", nullable: false), remote_id = table.Column<string>(type: "text", nullable: false),
remote_username = table.Column<string>(type: "text", nullable: true), remote_username = table.Column<string>(type: "text", nullable: true),
user_id = table.Column<long>(type: "bigint", nullable: false), user_id = table.Column<long>(type: "bigint", nullable: false),
fediverse_application_id = table.Column<long>(type: "bigint", nullable: true) fediverse_application_id = table.Column<long>(type: "bigint", nullable: true),
}, },
constraints: table => constraints: table =>
{ {
@ -68,14 +70,17 @@ namespace Foxnouns.Backend.Database.Migrations
name: "fk_auth_methods_fediverse_applications_fediverse_application_id", name: "fk_auth_methods_fediverse_applications_fediverse_application_id",
column: x => x.fediverse_application_id, column: x => x.fediverse_application_id,
principalTable: "fediverse_applications", principalTable: "fediverse_applications",
principalColumn: "id"); principalColumn: "id"
);
table.ForeignKey( table.ForeignKey(
name: "fk_auth_methods_users_user_id", name: "fk_auth_methods_users_user_id",
column: x => x.user_id, column: x => x.user_id,
principalTable: "users", principalTable: "users",
principalColumn: "id", principalColumn: "id",
onDelete: ReferentialAction.Cascade); onDelete: ReferentialAction.Cascade
}); );
}
);
migrationBuilder.CreateTable( migrationBuilder.CreateTable(
name: "members", name: "members",
@ -91,7 +96,7 @@ namespace Foxnouns.Backend.Database.Migrations
user_id = table.Column<long>(type: "bigint", nullable: false), user_id = table.Column<long>(type: "bigint", nullable: false),
fields = table.Column<string>(type: "jsonb", nullable: false), fields = table.Column<string>(type: "jsonb", nullable: false),
names = table.Column<string>(type: "jsonb", nullable: false), names = table.Column<string>(type: "jsonb", nullable: false),
pronouns = table.Column<string>(type: "jsonb", nullable: false) pronouns = table.Column<string>(type: "jsonb", nullable: false),
}, },
constraints: table => constraints: table =>
{ {
@ -101,18 +106,23 @@ namespace Foxnouns.Backend.Database.Migrations
column: x => x.user_id, column: x => x.user_id,
principalTable: "users", principalTable: "users",
principalColumn: "id", principalColumn: "id",
onDelete: ReferentialAction.Cascade); onDelete: ReferentialAction.Cascade
}); );
}
);
migrationBuilder.CreateTable( migrationBuilder.CreateTable(
name: "tokens", name: "tokens",
columns: table => new columns: table => new
{ {
id = table.Column<long>(type: "bigint", nullable: false), id = table.Column<long>(type: "bigint", nullable: false),
expires_at = table.Column<Instant>(type: "timestamp with time zone", nullable: false), expires_at = table.Column<Instant>(
type: "timestamp with time zone",
nullable: false
),
scopes = table.Column<string[]>(type: "text[]", nullable: false), scopes = table.Column<string[]>(type: "text[]", nullable: false),
manually_expired = table.Column<bool>(type: "boolean", nullable: false), manually_expired = table.Column<bool>(type: "boolean", nullable: false),
user_id = table.Column<long>(type: "bigint", nullable: false) user_id = table.Column<long>(type: "bigint", nullable: false),
}, },
constraints: table => constraints: table =>
{ {
@ -122,53 +132,56 @@ namespace Foxnouns.Backend.Database.Migrations
column: x => x.user_id, column: x => x.user_id,
principalTable: "users", principalTable: "users",
principalColumn: "id", principalColumn: "id",
onDelete: ReferentialAction.Cascade); onDelete: ReferentialAction.Cascade
}); );
}
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_auth_methods_fediverse_application_id", name: "ix_auth_methods_fediverse_application_id",
table: "auth_methods", table: "auth_methods",
column: "fediverse_application_id"); column: "fediverse_application_id"
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_auth_methods_user_id", name: "ix_auth_methods_user_id",
table: "auth_methods", table: "auth_methods",
column: "user_id"); column: "user_id"
);
// EF Core doesn't support creating indexes on arbitrary expressions, so we have to create it manually. // EF Core doesn't support creating indexes on arbitrary expressions, so we have to create it manually.
// Due to historical reasons (I made a mistake while writing the initial migration for the Go version) // Due to historical reasons (I made a mistake while writing the initial migration for the Go version)
// only members have case-insensitive names. // only members have case-insensitive names.
migrationBuilder.Sql("CREATE UNIQUE INDEX ix_members_user_id_name ON members (user_id, lower(name))"); migrationBuilder.Sql(
"CREATE UNIQUE INDEX ix_members_user_id_name ON members (user_id, lower(name))"
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_tokens_user_id", name: "ix_tokens_user_id",
table: "tokens", table: "tokens",
column: "user_id"); column: "user_id"
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_users_username", name: "ix_users_username",
table: "users", table: "users",
column: "username", column: "username",
unique: true); unique: true
);
} }
/// <inheritdoc /> /// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder) protected override void Down(MigrationBuilder migrationBuilder)
{ {
migrationBuilder.DropTable( migrationBuilder.DropTable(name: "auth_methods");
name: "auth_methods");
migrationBuilder.DropTable( migrationBuilder.DropTable(name: "members");
name: "members");
migrationBuilder.DropTable( migrationBuilder.DropTable(name: "tokens");
name: "tokens");
migrationBuilder.DropTable( migrationBuilder.DropTable(name: "fediverse_applications");
name: "fediverse_applications");
migrationBuilder.DropTable( migrationBuilder.DropTable(name: "users");
name: "users");
} }
} }
} }

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable #nullable disable
@ -18,14 +18,16 @@ namespace Foxnouns.Backend.Database.Migrations
table: "tokens", table: "tokens",
type: "bigint", type: "bigint",
nullable: false, nullable: false,
defaultValue: 0L); defaultValue: 0L
);
migrationBuilder.AddColumn<byte[]>( migrationBuilder.AddColumn<byte[]>(
name: "hash", name: "hash",
table: "tokens", table: "tokens",
type: "bytea", type: "bytea",
nullable: false, nullable: false,
defaultValue: new byte[0]); defaultValue: new byte[0]
);
migrationBuilder.CreateTable( migrationBuilder.CreateTable(
name: "applications", name: "applications",
@ -36,17 +38,19 @@ namespace Foxnouns.Backend.Database.Migrations
client_secret = table.Column<string>(type: "text", nullable: false), client_secret = table.Column<string>(type: "text", nullable: false),
name = table.Column<string>(type: "text", nullable: false), name = table.Column<string>(type: "text", nullable: false),
scopes = table.Column<string[]>(type: "text[]", nullable: false), scopes = table.Column<string[]>(type: "text[]", nullable: false),
redirect_uris = table.Column<string[]>(type: "text[]", nullable: false) redirect_uris = table.Column<string[]>(type: "text[]", nullable: false),
}, },
constraints: table => constraints: table =>
{ {
table.PrimaryKey("pk_applications", x => x.id); table.PrimaryKey("pk_applications", x => x.id);
}); }
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_tokens_application_id", name: "ix_tokens_application_id",
table: "tokens", table: "tokens",
column: "application_id"); column: "application_id"
);
migrationBuilder.AddForeignKey( migrationBuilder.AddForeignKey(
name: "fk_tokens_applications_application_id", name: "fk_tokens_applications_application_id",
@ -54,7 +58,8 @@ namespace Foxnouns.Backend.Database.Migrations
column: "application_id", column: "application_id",
principalTable: "applications", principalTable: "applications",
principalColumn: "id", principalColumn: "id",
onDelete: ReferentialAction.Cascade); onDelete: ReferentialAction.Cascade
);
} }
/// <inheritdoc /> /// <inheritdoc />
@ -62,22 +67,16 @@ namespace Foxnouns.Backend.Database.Migrations
{ {
migrationBuilder.DropForeignKey( migrationBuilder.DropForeignKey(
name: "fk_tokens_applications_application_id", name: "fk_tokens_applications_application_id",
table: "tokens"); table: "tokens"
);
migrationBuilder.DropTable( migrationBuilder.DropTable(name: "applications");
name: "applications");
migrationBuilder.DropIndex( migrationBuilder.DropIndex(name: "ix_tokens_application_id", table: "tokens");
name: "ix_tokens_application_id",
table: "tokens");
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "application_id", table: "tokens");
name: "application_id",
table: "tokens");
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "hash", table: "tokens");
name: "hash",
table: "tokens");
} }
} }
} }

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable #nullable disable
@ -18,15 +18,14 @@ namespace Foxnouns.Backend.Database.Migrations
table: "users", table: "users",
type: "boolean", type: "boolean",
nullable: false, nullable: false,
defaultValue: false); defaultValue: false
);
} }
/// <inheritdoc /> /// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder) protected override void Down(MigrationBuilder migrationBuilder)
{ {
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "list_hidden", table: "users");
name: "list_hidden",
table: "users");
} }
} }
} }

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable #nullable disable
@ -17,15 +17,14 @@ namespace Foxnouns.Backend.Database.Migrations
name: "password", name: "password",
table: "users", table: "users",
type: "text", type: "text",
nullable: true); nullable: true
);
} }
/// <inheritdoc /> /// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder) protected override void Down(MigrationBuilder migrationBuilder)
{ {
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "password", table: "users");
name: "password",
table: "users");
} }
} }
} }

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime; using NodaTime;
using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata;
@ -19,29 +19,37 @@ namespace Foxnouns.Backend.Database.Migrations
name: "temporary_keys", name: "temporary_keys",
columns: table => new columns: table => new
{ {
id = table.Column<long>(type: "bigint", nullable: false) id = table
.Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), .Column<long>(type: "bigint", nullable: false)
.Annotation(
"Npgsql:ValueGenerationStrategy",
NpgsqlValueGenerationStrategy.IdentityByDefaultColumn
),
key = table.Column<string>(type: "text", nullable: false), key = table.Column<string>(type: "text", nullable: false),
value = table.Column<string>(type: "text", nullable: false), value = table.Column<string>(type: "text", nullable: false),
expires = table.Column<Instant>(type: "timestamp with time zone", nullable: false) expires = table.Column<Instant>(
type: "timestamp with time zone",
nullable: false
),
}, },
constraints: table => constraints: table =>
{ {
table.PrimaryKey("pk_temporary_keys", x => x.id); table.PrimaryKey("pk_temporary_keys", x => x.id);
}); }
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_temporary_keys_key", name: "ix_temporary_keys_key",
table: "temporary_keys", table: "temporary_keys",
column: "key", column: "key",
unique: true); unique: true
);
} }
/// <inheritdoc /> /// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder) protected override void Down(MigrationBuilder migrationBuilder)
{ {
migrationBuilder.DropTable( migrationBuilder.DropTable(name: "temporary_keys");
name: "temporary_keys");
} }
} }
} }

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime; using NodaTime;
#nullable disable #nullable disable
@ -19,15 +19,14 @@ namespace Foxnouns.Backend.Database.Migrations
table: "users", table: "users",
type: "timestamp with time zone", type: "timestamp with time zone",
nullable: false, nullable: false,
defaultValueSql: "now()"); defaultValueSql: "now()"
);
} }
/// <inheritdoc /> /// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder) protected override void Down(MigrationBuilder migrationBuilder)
{ {
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "last_active", table: "users");
name: "last_active",
table: "users");
} }
} }
} }

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime; using NodaTime;
#nullable disable #nullable disable
@ -19,35 +19,32 @@ namespace Foxnouns.Backend.Database.Migrations
table: "users", table: "users",
type: "boolean", type: "boolean",
nullable: false, nullable: false,
defaultValue: false); defaultValue: false
);
migrationBuilder.AddColumn<Instant>( migrationBuilder.AddColumn<Instant>(
name: "deleted_at", name: "deleted_at",
table: "users", table: "users",
type: "timestamp with time zone", type: "timestamp with time zone",
nullable: true); nullable: true
);
migrationBuilder.AddColumn<long>( migrationBuilder.AddColumn<long>(
name: "deleted_by", name: "deleted_by",
table: "users", table: "users",
type: "bigint", type: "bigint",
nullable: true); nullable: true
);
} }
/// <inheritdoc /> /// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder) protected override void Down(MigrationBuilder migrationBuilder)
{ {
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "deleted", table: "users");
name: "deleted",
table: "users");
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "deleted_at", table: "users");
name: "deleted_at",
table: "users");
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "deleted_by", table: "users");
name: "deleted_by",
table: "users");
} }
} }
} }

View file

@ -1,7 +1,7 @@
using System; using System;
using Microsoft.EntityFrameworkCore.Infrastructure;
using System.Collections.Generic; using System.Collections.Generic;
using Foxnouns.Backend.Database.Models; using Foxnouns.Backend.Database.Models;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable #nullable disable
@ -21,15 +21,14 @@ namespace Foxnouns.Backend.Database.Migrations
table: "users", table: "users",
type: "jsonb", type: "jsonb",
nullable: false, nullable: false,
defaultValueSql: "'{}'"); defaultValueSql: "'{}'"
);
} }
/// <inheritdoc /> /// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder) protected override void Down(MigrationBuilder migrationBuilder)
{ {
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "custom_preferences", table: "users");
name: "custom_preferences",
table: "users");
} }
} }
} }

View file

@ -19,15 +19,14 @@ namespace Foxnouns.Backend.Database.Migrations
table: "users", table: "users",
type: "jsonb", type: "jsonb",
nullable: false, nullable: false,
defaultValueSql: "'{}'"); defaultValueSql: "'{}'"
);
} }
/// <inheritdoc /> /// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder) protected override void Down(MigrationBuilder migrationBuilder)
{ {
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "settings", table: "users");
name: "settings",
table: "users");
} }
} }
} }

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime; using NodaTime;
#nullable disable #nullable disable
@ -18,38 +18,46 @@ namespace Foxnouns.Backend.Database.Migrations
name: "sid", name: "sid",
table: "users", table: "users",
type: "text", type: "text",
nullable: true); nullable: true
);
migrationBuilder.AddColumn<Instant>( migrationBuilder.AddColumn<Instant>(
name: "last_sid_reroll", name: "last_sid_reroll",
table: "users", table: "users",
type: "timestamp with time zone", type: "timestamp with time zone",
nullable: false, nullable: false,
defaultValueSql: "now() - '1 hour'::interval"); defaultValueSql: "now() - '1 hour'::interval"
);
migrationBuilder.AddColumn<string>( migrationBuilder.AddColumn<string>(
name: "sid", name: "sid",
table: "members", table: "members",
type: "text", type: "text",
nullable: true); nullable: true
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_users_sid", name: "ix_users_sid",
table: "users", table: "users",
column: "sid", column: "sid",
unique: true); unique: true
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_members_sid", name: "ix_members_sid",
table: "members", table: "members",
column: "sid", column: "sid",
unique: true); unique: true
);
migrationBuilder.Sql(@"create function generate_sid(len int) returns text as $$ migrationBuilder.Sql(
@"create function generate_sid(len int) returns text as $$
select string_agg(substr('abcdefghijklmnopqrstuvwxyz', ceil(random() * 26)::integer, 1), '') from generate_series(1, len) select string_agg(substr('abcdefghijklmnopqrstuvwxyz', ceil(random() * 26)::integer, 1), '') from generate_series(1, len)
$$ language sql volatile; $$ language sql volatile;
"); "
migrationBuilder.Sql(@"create function find_free_user_sid() returns text as $$ );
migrationBuilder.Sql(
@"create function find_free_user_sid() returns text as $$
declare new_sid text; declare new_sid text;
begin begin
loop loop
@ -58,8 +66,10 @@ begin
end loop; end loop;
end end
$$ language plpgsql volatile; $$ language plpgsql volatile;
"); "
migrationBuilder.Sql(@"create function find_free_member_sid() returns text as $$ );
migrationBuilder.Sql(
@"create function find_free_member_sid() returns text as $$
declare new_sid text; declare new_sid text;
begin begin
loop loop
@ -67,7 +77,8 @@ begin
if not exists (select 1 from members where sid = new_sid) then return new_sid; end if; if not exists (select 1 from members where sid = new_sid) then return new_sid; end if;
end loop; end loop;
end end
$$ language plpgsql volatile;"); $$ language plpgsql volatile;"
);
} }
/// <inheritdoc /> /// <inheritdoc />
@ -77,25 +88,15 @@ $$ language plpgsql volatile;");
migrationBuilder.Sql("drop function find_free_user_sid;"); migrationBuilder.Sql("drop function find_free_user_sid;");
migrationBuilder.Sql("drop function generate_sid;"); migrationBuilder.Sql("drop function generate_sid;");
migrationBuilder.DropIndex( migrationBuilder.DropIndex(name: "ix_users_sid", table: "users");
name: "ix_users_sid",
table: "users");
migrationBuilder.DropIndex( migrationBuilder.DropIndex(name: "ix_members_sid", table: "members");
name: "ix_members_sid",
table: "members");
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "sid", table: "users");
name: "sid",
table: "users");
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "last_sid_reroll", table: "users");
name: "last_sid_reroll",
table: "users");
migrationBuilder.DropColumn( migrationBuilder.DropColumn(name: "sid", table: "members");
name: "sid",
table: "members");
} }
} }
} }

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime; using NodaTime;
#nullable disable #nullable disable
@ -22,7 +22,8 @@ namespace Foxnouns.Backend.Database.Migrations
defaultValueSql: "find_free_user_sid()", defaultValueSql: "find_free_user_sid()",
oldClrType: typeof(string), oldClrType: typeof(string),
oldType: "text", oldType: "text",
oldNullable: true); oldNullable: true
);
migrationBuilder.AlterColumn<string>( migrationBuilder.AlterColumn<string>(
name: "sid", name: "sid",
@ -32,7 +33,8 @@ namespace Foxnouns.Backend.Database.Migrations
defaultValueSql: "find_free_member_sid()", defaultValueSql: "find_free_member_sid()",
oldClrType: typeof(string), oldClrType: typeof(string),
oldType: "text", oldType: "text",
oldNullable: true); oldNullable: true
);
} }
/// <inheritdoc /> /// <inheritdoc />
@ -45,7 +47,8 @@ namespace Foxnouns.Backend.Database.Migrations
nullable: true, nullable: true,
oldClrType: typeof(string), oldClrType: typeof(string),
oldType: "text", oldType: "text",
oldDefaultValueSql: "find_free_user_sid()"); oldDefaultValueSql: "find_free_user_sid()"
);
migrationBuilder.AlterColumn<string>( migrationBuilder.AlterColumn<string>(
name: "sid", name: "sid",
@ -54,7 +57,8 @@ namespace Foxnouns.Backend.Database.Migrations
nullable: true, nullable: true,
oldClrType: typeof(string), oldClrType: typeof(string),
oldType: "text", oldType: "text",
oldDefaultValueSql: "find_free_member_sid()"); oldDefaultValueSql: "find_free_member_sid()"
);
} }
} }
} }

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Migrations;
using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata;
#nullable disable #nullable disable
@ -22,7 +22,7 @@ namespace Foxnouns.Backend.Database.Migrations
user_id = table.Column<long>(type: "bigint", nullable: false), user_id = table.Column<long>(type: "bigint", nullable: false),
hash = table.Column<string>(type: "text", nullable: false), hash = table.Column<string>(type: "text", nullable: false),
name = table.Column<string>(type: "text", nullable: false), name = table.Column<string>(type: "text", nullable: false),
description = table.Column<string>(type: "text", nullable: true) description = table.Column<string>(type: "text", nullable: true),
}, },
constraints: table => constraints: table =>
{ {
@ -32,17 +32,23 @@ namespace Foxnouns.Backend.Database.Migrations
column: x => x.user_id, column: x => x.user_id,
principalTable: "users", principalTable: "users",
principalColumn: "id", principalColumn: "id",
onDelete: ReferentialAction.Cascade); onDelete: ReferentialAction.Cascade
}); );
}
);
migrationBuilder.CreateTable( migrationBuilder.CreateTable(
name: "member_flags", name: "member_flags",
columns: table => new columns: table => new
{ {
id = table.Column<long>(type: "bigint", nullable: false) id = table
.Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), .Column<long>(type: "bigint", nullable: false)
.Annotation(
"Npgsql:ValueGenerationStrategy",
NpgsqlValueGenerationStrategy.IdentityByDefaultColumn
),
member_id = table.Column<long>(type: "bigint", nullable: false), member_id = table.Column<long>(type: "bigint", nullable: false),
pride_flag_id = table.Column<long>(type: "bigint", nullable: false) pride_flag_id = table.Column<long>(type: "bigint", nullable: false),
}, },
constraints: table => constraints: table =>
{ {
@ -52,23 +58,30 @@ namespace Foxnouns.Backend.Database.Migrations
column: x => x.member_id, column: x => x.member_id,
principalTable: "members", principalTable: "members",
principalColumn: "id", principalColumn: "id",
onDelete: ReferentialAction.Cascade); onDelete: ReferentialAction.Cascade
);
table.ForeignKey( table.ForeignKey(
name: "fk_member_flags_pride_flags_pride_flag_id", name: "fk_member_flags_pride_flags_pride_flag_id",
column: x => x.pride_flag_id, column: x => x.pride_flag_id,
principalTable: "pride_flags", principalTable: "pride_flags",
principalColumn: "id", principalColumn: "id",
onDelete: ReferentialAction.Cascade); onDelete: ReferentialAction.Cascade
}); );
}
);
migrationBuilder.CreateTable( migrationBuilder.CreateTable(
name: "user_flags", name: "user_flags",
columns: table => new columns: table => new
{ {
id = table.Column<long>(type: "bigint", nullable: false) id = table
.Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), .Column<long>(type: "bigint", nullable: false)
.Annotation(
"Npgsql:ValueGenerationStrategy",
NpgsqlValueGenerationStrategy.IdentityByDefaultColumn
),
user_id = table.Column<long>(type: "bigint", nullable: false), user_id = table.Column<long>(type: "bigint", nullable: false),
pride_flag_id = table.Column<long>(type: "bigint", nullable: false) pride_flag_id = table.Column<long>(type: "bigint", nullable: false),
}, },
constraints: table => constraints: table =>
{ {
@ -78,52 +91,57 @@ namespace Foxnouns.Backend.Database.Migrations
column: x => x.pride_flag_id, column: x => x.pride_flag_id,
principalTable: "pride_flags", principalTable: "pride_flags",
principalColumn: "id", principalColumn: "id",
onDelete: ReferentialAction.Cascade); onDelete: ReferentialAction.Cascade
);
table.ForeignKey( table.ForeignKey(
name: "fk_user_flags_users_user_id", name: "fk_user_flags_users_user_id",
column: x => x.user_id, column: x => x.user_id,
principalTable: "users", principalTable: "users",
principalColumn: "id", principalColumn: "id",
onDelete: ReferentialAction.Cascade); onDelete: ReferentialAction.Cascade
}); );
}
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_member_flags_member_id", name: "ix_member_flags_member_id",
table: "member_flags", table: "member_flags",
column: "member_id"); column: "member_id"
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_member_flags_pride_flag_id", name: "ix_member_flags_pride_flag_id",
table: "member_flags", table: "member_flags",
column: "pride_flag_id"); column: "pride_flag_id"
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_pride_flags_user_id", name: "ix_pride_flags_user_id",
table: "pride_flags", table: "pride_flags",
column: "user_id"); column: "user_id"
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_user_flags_pride_flag_id", name: "ix_user_flags_pride_flag_id",
table: "user_flags", table: "user_flags",
column: "pride_flag_id"); column: "pride_flag_id"
);
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(
name: "ix_user_flags_user_id", name: "ix_user_flags_user_id",
table: "user_flags", table: "user_flags",
column: "user_id"); column: "user_id"
);
} }
/// <inheritdoc /> /// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder) protected override void Down(MigrationBuilder migrationBuilder)
{ {
migrationBuilder.DropTable( migrationBuilder.DropTable(name: "member_flags");
name: "member_flags");
migrationBuilder.DropTable( migrationBuilder.DropTable(name: "user_flags");
name: "user_flags");
migrationBuilder.DropTable( migrationBuilder.DropTable(name: "pride_flags");
name: "pride_flags");
} }
} }
} }

View file

@ -11,20 +11,30 @@ public class Application : BaseModel
public required string[] Scopes { get; init; } public required string[] Scopes { get; init; }
public required string[] RedirectUris { get; init; } public required string[] RedirectUris { get; init; }
public static Application Create(ISnowflakeGenerator snowflakeGenerator, string name, string[] scopes, public static Application Create(
string[] redirectUrls) ISnowflakeGenerator snowflakeGenerator,
string name,
string[] scopes,
string[] redirectUrls
)
{ {
var clientId = RandomNumberGenerator.GetHexString(32, true); var clientId = RandomNumberGenerator.GetHexString(32, true);
var clientSecret = AuthUtils.RandomToken(); var clientSecret = AuthUtils.RandomToken();
if (scopes.Except(AuthUtils.ApplicationScopes).Any()) if (scopes.Except(AuthUtils.ApplicationScopes).Any())
{ {
throw new ArgumentException("Invalid scopes passed to Application.Create", nameof(scopes)); throw new ArgumentException(
"Invalid scopes passed to Application.Create",
nameof(scopes)
);
} }
if (redirectUrls.Any(s => !AuthUtils.ValidateRedirectUri(s))) if (redirectUrls.Any(s => !AuthUtils.ValidateRedirectUri(s)))
{ {
throw new ArgumentException("Invalid redirect URLs passed to Application.Create", nameof(redirectUrls)); throw new ArgumentException(
"Invalid redirect URLs passed to Application.Create",
nameof(redirectUrls)
);
} }
return new Application return new Application
@ -34,7 +44,7 @@ public class Application : BaseModel
ClientSecret = clientSecret, ClientSecret = clientSecret,
Name = name, Name = name,
Scopes = scopes, Scopes = scopes,
RedirectUris = redirectUrls RedirectUris = redirectUrls,
}; };
} }
} }

View file

@ -11,5 +11,5 @@ public class FediverseApplication : BaseModel
public enum FediverseInstanceType public enum FediverseInstanceType
{ {
MastodonApi, MastodonApi,
MisskeyApi MisskeyApi,
} }

View file

@ -37,7 +37,9 @@ public class User : BaseModel
public bool Deleted { get; set; } public bool Deleted { get; set; }
public Instant? DeletedAt { get; set; } public Instant? DeletedAt { get; set; }
public Snowflake? DeletedBy { get; set; } public Snowflake? DeletedBy { get; set; }
[NotMapped] public bool? SelfDelete => Deleted ? DeletedBy != null : null;
[NotMapped]
public bool? SelfDelete => Deleted ? DeletedBy != null : null;
public class CustomPreference public class CustomPreference
{ {

View file

@ -41,19 +41,26 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
public short Increment => (short)(Value & 0xFFF); public short Increment => (short)(Value & 0xFFF);
public static bool operator <(Snowflake arg1, Snowflake arg2) => arg1.Value < arg2.Value; public static bool operator <(Snowflake arg1, Snowflake arg2) => arg1.Value < arg2.Value;
public static bool operator >(Snowflake arg1, Snowflake arg2) => arg1.Value > arg2.Value; public static bool operator >(Snowflake arg1, Snowflake arg2) => arg1.Value > arg2.Value;
public static bool operator ==(Snowflake arg1, Snowflake arg2) => arg1.Value == arg2.Value; public static bool operator ==(Snowflake arg1, Snowflake arg2) => arg1.Value == arg2.Value;
public static bool operator !=(Snowflake arg1, Snowflake arg2) => arg1.Value != arg2.Value; public static bool operator !=(Snowflake arg1, Snowflake arg2) => arg1.Value != arg2.Value;
public static implicit operator ulong(Snowflake s) => s.Value; public static implicit operator ulong(Snowflake s) => s.Value;
public static implicit operator long(Snowflake s) => (long)s.Value; public static implicit operator long(Snowflake s) => (long)s.Value;
public static implicit operator Snowflake(ulong n) => new(n); public static implicit operator Snowflake(ulong n) => new(n);
public static implicit operator Snowflake(long n) => new((ulong)n); public static implicit operator Snowflake(long n) => new((ulong)n);
public static bool TryParse(string input, [NotNullWhen(true)] out Snowflake? snowflake) public static bool TryParse(string input, [NotNullWhen(true)] out Snowflake? snowflake)
{ {
snowflake = null; snowflake = null;
if (!ulong.TryParse(input, out var res)) return false; if (!ulong.TryParse(input, out var res))
return false;
snowflake = new Snowflake(res); snowflake = new Snowflake(res);
return true; return true;
} }
@ -66,27 +73,37 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
} }
public override int GetHashCode() => Value.GetHashCode(); public override int GetHashCode() => Value.GetHashCode();
public override string ToString() => Value.ToString(); public override string ToString() => Value.ToString();
/// <summary> /// <summary>
/// An Entity Framework ValueConverter for Snowflakes to longs. /// An Entity Framework ValueConverter for Snowflakes to longs.
/// </summary> /// </summary>
// ReSharper disable once ClassNeverInstantiated.Global // ReSharper disable once ClassNeverInstantiated.Global
public class ValueConverter() : ValueConverter<Snowflake, long>( public class ValueConverter()
: ValueConverter<Snowflake, long>(
convertToProviderExpression: x => x, convertToProviderExpression: x => x,
convertFromProviderExpression: x => x convertFromProviderExpression: x => x
); );
private class JsonConverter : JsonConverter<Snowflake> private class JsonConverter : JsonConverter<Snowflake>
{ {
public override void WriteJson(JsonWriter writer, Snowflake value, JsonSerializer serializer) public override void WriteJson(
JsonWriter writer,
Snowflake value,
JsonSerializer serializer
)
{ {
writer.WriteValue(value.Value.ToString()); writer.WriteValue(value.Value.ToString());
} }
public override Snowflake ReadJson(JsonReader reader, Type objectType, Snowflake existingValue, public override Snowflake ReadJson(
JsonReader reader,
Type objectType,
Snowflake existingValue,
bool hasExistingValue, bool hasExistingValue,
JsonSerializer serializer) JsonSerializer serializer
)
{ {
return ulong.Parse((string)reader.Value!); return ulong.Parse((string)reader.Value!);
} }
@ -97,10 +114,16 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
public override bool CanConvertFrom(ITypeDescriptorContext? context, Type sourceType) => public override bool CanConvertFrom(ITypeDescriptorContext? context, Type sourceType) =>
sourceType == typeof(string); sourceType == typeof(string);
public override bool CanConvertTo(ITypeDescriptorContext? context, [NotNullWhen(true)] Type? destinationType) => public override bool CanConvertTo(
destinationType == typeof(Snowflake); ITypeDescriptorContext? context,
[NotNullWhen(true)] Type? destinationType
) => destinationType == typeof(Snowflake);
public override object? ConvertFrom(ITypeDescriptorContext? context, CultureInfo? culture, object value) public override object? ConvertFrom(
ITypeDescriptorContext? context,
CultureInfo? culture,
object value
)
{ {
return TryParse((string)value, out var snowflake) ? snowflake : null; return TryParse((string)value, out var snowflake) ? snowflake : null;
} }

View file

@ -32,13 +32,19 @@ public class SnowflakeGenerator : ISnowflakeGenerator
var threadId = Environment.CurrentManagedThreadId % 32; var threadId = Environment.CurrentManagedThreadId % 32;
var timestamp = time.Value.ToUnixTimeMilliseconds() - Snowflake.Epoch; var timestamp = time.Value.ToUnixTimeMilliseconds() - Snowflake.Epoch;
return (timestamp << 22) | (uint)(_processId << 17) | (uint)(threadId << 12) | (increment % 4096); return (timestamp << 22)
| (uint)(_processId << 17)
| (uint)(threadId << 12)
| (increment % 4096);
} }
} }
public static class SnowflakeGeneratorServiceExtensions public static class SnowflakeGeneratorServiceExtensions
{ {
public static IServiceCollection AddSnowflakeGenerator(this IServiceCollection services, int? processId = null) public static IServiceCollection AddSnowflakeGenerator(
this IServiceCollection services,
int? processId = null
)
{ {
return services.AddSingleton<ISnowflakeGenerator>(new SnowflakeGenerator(processId)); return services.AddSingleton<ISnowflakeGenerator>(new SnowflakeGenerator(processId));
} }

View file

@ -9,39 +9,47 @@ public class FoxnounsError(string message, Exception? inner = null) : Exception(
{ {
public Exception? Inner => inner; public Exception? Inner => inner;
public class DatabaseError(string message, Exception? inner = null) : FoxnounsError(message, inner); public class DatabaseError(string message, Exception? inner = null)
: FoxnounsError(message, inner);
public class UnknownEntityError(Type entityType, Exception? inner = null) public class UnknownEntityError(Type entityType, Exception? inner = null)
: DatabaseError($"Entity of type {entityType.Name} not found", inner); : DatabaseError($"Entity of type {entityType.Name} not found", inner);
} }
public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCode? errorCode = null) public class ApiError(
: FoxnounsError(message) string message,
HttpStatusCode? statusCode = null,
ErrorCode? errorCode = null
) : FoxnounsError(message)
{ {
public readonly HttpStatusCode StatusCode = statusCode ?? HttpStatusCode.InternalServerError; public readonly HttpStatusCode StatusCode = statusCode ?? HttpStatusCode.InternalServerError;
public readonly ErrorCode ErrorCode = errorCode ?? ErrorCode.InternalServerError; public readonly ErrorCode ErrorCode = errorCode ?? ErrorCode.InternalServerError;
public class Unauthorized(string message, ErrorCode errorCode = ErrorCode.AuthenticationError) : ApiError(message, public class Unauthorized(string message, ErrorCode errorCode = ErrorCode.AuthenticationError)
statusCode: HttpStatusCode.Unauthorized, : ApiError(message, statusCode: HttpStatusCode.Unauthorized, errorCode: errorCode);
errorCode: errorCode);
public class Forbidden( public class Forbidden(
string message, string message,
IEnumerable<string>? scopes = null, IEnumerable<string>? scopes = null,
ErrorCode errorCode = ErrorCode.Forbidden) ErrorCode errorCode = ErrorCode.Forbidden
: ApiError(message, statusCode: HttpStatusCode.Forbidden, errorCode: errorCode) ) : ApiError(message, statusCode: HttpStatusCode.Forbidden, errorCode: errorCode)
{ {
public readonly string[] Scopes = scopes?.ToArray() ?? []; public readonly string[] Scopes = scopes?.ToArray() ?? [];
} }
public class BadRequest(string message, IReadOnlyDictionary<string, IEnumerable<ValidationError>>? errors = null) public class BadRequest(
: ApiError(message, statusCode: HttpStatusCode.BadRequest) string message,
IReadOnlyDictionary<string, IEnumerable<ValidationError>>? errors = null
) : ApiError(message, statusCode: HttpStatusCode.BadRequest)
{ {
public BadRequest(string message, string field, object actualValue) : this("Error validating input", public BadRequest(string message, string field, object actualValue)
: this(
"Error validating input",
new Dictionary<string, IEnumerable<ValidationError>> new Dictionary<string, IEnumerable<ValidationError>>
{ { field, [ValidationError.GenericValidationError(message, actualValue)] } })
{ {
{ field, [ValidationError.GenericValidationError(message, actualValue)] },
} }
) { }
public JObject ToJson() public JObject ToJson()
{ {
@ -49,9 +57,10 @@ public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCo
{ {
{ "status", (int)HttpStatusCode.BadRequest }, { "status", (int)HttpStatusCode.BadRequest },
{ "message", Message }, { "message", Message },
{ "code", "BAD_REQUEST" } { "code", "BAD_REQUEST" },
}; };
if (errors == null) return o; if (errors == null)
return o;
var a = new JArray(); var a = new JArray();
foreach (var error in errors) foreach (var error in errors)
@ -59,7 +68,7 @@ public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCo
var errorObj = new JObject var errorObj = new JObject
{ {
{ "key", error.Key }, { "key", error.Key },
{ "errors", JArray.FromObject(error.Value) } { "errors", JArray.FromObject(error.Value) },
}; };
a.Add(errorObj); a.Add(errorObj);
} }
@ -82,9 +91,10 @@ public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCo
{ {
{ "status", (int)HttpStatusCode.BadRequest }, { "status", (int)HttpStatusCode.BadRequest },
{ "message", Message }, { "message", Message },
{ "code", "BAD_REQUEST" } { "code", "BAD_REQUEST" },
}; };
if (modelState == null) return o; if (modelState == null)
return o;
var a = new JArray(); var a = new JArray();
foreach (var error in modelState.Where(e => e.Value is { Errors.Count: > 0 })) foreach (var error in modelState.Where(e => e.Value is { Errors.Count: > 0 }))
@ -94,8 +104,13 @@ public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCo
{ "key", error.Key }, { "key", error.Key },
{ {
"errors", "errors",
new JArray(error.Value!.Errors.Select(e => new JObject { { "message", e.ErrorMessage } })) new JArray(
} error.Value!.Errors.Select(e => new JObject
{
{ "message", e.ErrorMessage },
})
)
},
}; };
a.Add(errorObj); a.Add(errorObj);
} }
@ -108,7 +123,8 @@ public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCo
public class NotFound(string message, ErrorCode? code = null) public class NotFound(string message, ErrorCode? code = null)
: ApiError(message, statusCode: HttpStatusCode.NotFound, errorCode: code); : ApiError(message, statusCode: HttpStatusCode.NotFound, errorCode: code);
public class AuthenticationError(string message) : ApiError(message, statusCode: HttpStatusCode.BadRequest); public class AuthenticationError(string message)
: ApiError(message, statusCode: HttpStatusCode.BadRequest);
} }
public enum ErrorCode public enum ErrorCode
@ -143,34 +159,38 @@ public class ValidationError
[JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
public object? ActualValue { get; init; } public object? ActualValue { get; init; }
public static ValidationError LengthError(string message, int minLength, int maxLength, int actualLength) public static ValidationError LengthError(
string message,
int minLength,
int maxLength,
int actualLength
)
{ {
return new ValidationError return new ValidationError
{ {
Message = message, Message = message,
MinLength = minLength, MinLength = minLength,
MaxLength = maxLength, MaxLength = maxLength,
ActualLength = actualLength ActualLength = actualLength,
}; };
} }
public static ValidationError DisallowedValueError(string message, IEnumerable<object> allowedValues, public static ValidationError DisallowedValueError(
object actualValue) string message,
IEnumerable<object> allowedValues,
object actualValue
)
{ {
return new ValidationError return new ValidationError
{ {
Message = message, Message = message,
AllowedValues = allowedValues, AllowedValues = allowedValues,
ActualValue = actualValue ActualValue = actualValue,
}; };
} }
public static ValidationError GenericValidationError(string message, object? actualValue) public static ValidationError GenericValidationError(string message, object? actualValue)
{ {
return new ValidationError return new ValidationError { Message = message, ActualValue = actualValue };
{
Message = message,
ActualValue = actualValue
};
} }
} }

View file

@ -14,21 +14,35 @@ public static class AvatarObjectExtensions
{ {
private static readonly string[] ValidContentTypes = ["image/png", "image/webp", "image/jpeg"]; private static readonly string[] ValidContentTypes = ["image/png", "image/webp", "image/jpeg"];
public static async Task public static async Task DeleteMemberAvatarAsync(
DeleteMemberAvatarAsync(this ObjectStorageService objectStorageService, Snowflake id, string hash, this ObjectStorageService objectStorageService,
CancellationToken ct = default) => Snowflake id,
await objectStorageService.RemoveObjectAsync(MemberAvatarUpdateInvocable.Path(id, hash), ct); string hash,
CancellationToken ct = default
) =>
await objectStorageService.RemoveObjectAsync(
MemberAvatarUpdateInvocable.Path(id, hash),
ct
);
public static async Task public static async Task DeleteUserAvatarAsync(
DeleteUserAvatarAsync(this ObjectStorageService objectStorageService, Snowflake id, string hash, this ObjectStorageService objectStorageService,
CancellationToken ct = default) => Snowflake id,
await objectStorageService.RemoveObjectAsync(UserAvatarUpdateInvocable.Path(id, hash), ct); string hash,
CancellationToken ct = default
) => await objectStorageService.RemoveObjectAsync(UserAvatarUpdateInvocable.Path(id, hash), ct);
public static async Task DeleteFlagAsync(this ObjectStorageService objectStorageService, string hash, public static async Task DeleteFlagAsync(
CancellationToken ct = default) => this ObjectStorageService objectStorageService,
await objectStorageService.RemoveObjectAsync(CreateFlagInvocable.Path(hash), ct); string hash,
CancellationToken ct = default
) => await objectStorageService.RemoveObjectAsync(CreateFlagInvocable.Path(hash), ct);
public static async Task<(string Hash, Stream Image)> ConvertBase64UriToImage(this string uri, int size, bool crop) public static async Task<(string Hash, Stream Image)> ConvertBase64UriToImage(
this string uri,
int size,
bool crop
)
{ {
if (!uri.StartsWith("data:image/")) if (!uri.StartsWith("data:image/"))
throw new ArgumentException("Not a data URI", nameof(uri)); throw new ArgumentException("Not a data URI", nameof(uri));
@ -49,7 +63,7 @@ public static class AvatarObjectExtensions
{ {
Size = new Size(size), Size = new Size(size),
Mode = crop ? ResizeMode.Crop : ResizeMode.Max, Mode = crop ? ResizeMode.Crop : ResizeMode.Max,
Position = AnchorPositionMode.Center Position = AnchorPositionMode.Center,
}, },
image.Size image.Size
); );

View file

@ -8,37 +8,58 @@ namespace Foxnouns.Backend.Extensions;
public static class KeyCacheExtensions public static class KeyCacheExtensions
{ {
public static async Task<string> GenerateAuthStateAsync(this KeyCacheService keyCacheService, public static async Task<string> GenerateAuthStateAsync(
CancellationToken ct = default) this KeyCacheService keyCacheService,
CancellationToken ct = default
)
{ {
var state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_'); var state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_');
await keyCacheService.SetKeyAsync($"oauth_state:{state}", "", Duration.FromMinutes(10), ct); await keyCacheService.SetKeyAsync($"oauth_state:{state}", "", Duration.FromMinutes(10), ct);
return state; return state;
} }
public static async Task ValidateAuthStateAsync(this KeyCacheService keyCacheService, string state, public static async Task ValidateAuthStateAsync(
CancellationToken ct = default) this KeyCacheService keyCacheService,
string state,
CancellationToken ct = default
)
{ {
var val = await keyCacheService.GetKeyAsync($"oauth_state:{state}", delete: true, ct); var val = await keyCacheService.GetKeyAsync($"oauth_state:{state}", delete: true, ct);
if (val == null) throw new ApiError.BadRequest("Invalid OAuth state"); if (val == null)
throw new ApiError.BadRequest("Invalid OAuth state");
} }
public static async Task<string> GenerateRegisterEmailStateAsync(this KeyCacheService keyCacheService, string email, public static async Task<string> GenerateRegisterEmailStateAsync(
Snowflake? userId = null, CancellationToken ct = default) this KeyCacheService keyCacheService,
string email,
Snowflake? userId = null,
CancellationToken ct = default
)
{ {
// This state is used in links, not just as JSON values, so make it URL-safe // This state is used in links, not just as JSON values, so make it URL-safe
var state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_'); var state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_');
await keyCacheService.SetKeyAsync($"email_state:{state}", new RegisterEmailState(email, userId), await keyCacheService.SetKeyAsync(
Duration.FromDays(1), ct); $"email_state:{state}",
new RegisterEmailState(email, userId),
Duration.FromDays(1),
ct
);
return state; return state;
} }
public static async Task<RegisterEmailState?> GetRegisterEmailStateAsync(this KeyCacheService keyCacheService, public static async Task<RegisterEmailState?> GetRegisterEmailStateAsync(
string state, CancellationToken ct = default) => this KeyCacheService keyCacheService,
await keyCacheService.GetKeyAsync<RegisterEmailState>($"email_state:{state}", delete: true, ct); string state,
CancellationToken ct = default
) =>
await keyCacheService.GetKeyAsync<RegisterEmailState>(
$"email_state:{state}",
delete: true,
ct
);
} }
public record RegisterEmailState( public record RegisterEmailState(
string Email, string Email,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] Snowflake? ExistingUserId
Snowflake? ExistingUserId); );

View file

@ -29,8 +29,10 @@ public static class WebApplicationExtensions
// ASP.NET's built in request logs are extremely verbose, so we use Serilog's instead. // ASP.NET's built in request logs are extremely verbose, so we use Serilog's instead.
// Serilog doesn't disable the built-in logs, so we do it here. // Serilog doesn't disable the built-in logs, so we do it here.
.MinimumLevel.Override("Microsoft", LogEventLevel.Information) .MinimumLevel.Override("Microsoft", LogEventLevel.Information)
.MinimumLevel.Override("Microsoft.EntityFrameworkCore.Database.Command", .MinimumLevel.Override(
config.Logging.LogQueries ? LogEventLevel.Information : LogEventLevel.Fatal) "Microsoft.EntityFrameworkCore.Database.Command",
config.Logging.LogQueries ? LogEventLevel.Information : LogEventLevel.Fatal
)
.MinimumLevel.Override("Microsoft.AspNetCore.Hosting", LogEventLevel.Warning) .MinimumLevel.Override("Microsoft.AspNetCore.Hosting", LogEventLevel.Warning)
.MinimumLevel.Override("Microsoft.AspNetCore.Mvc", LogEventLevel.Warning) .MinimumLevel.Override("Microsoft.AspNetCore.Mvc", LogEventLevel.Warning)
.MinimumLevel.Override("Microsoft.AspNetCore.Routing", LogEventLevel.Warning) .MinimumLevel.Override("Microsoft.AspNetCore.Routing", LogEventLevel.Warning)
@ -38,7 +40,10 @@ public static class WebApplicationExtensions
if (config.Logging.SeqLogUrl != null) if (config.Logging.SeqLogUrl != null)
{ {
logCfg.WriteTo.Seq(config.Logging.SeqLogUrl, restrictedToMinimumLevel: LogEventLevel.Verbose); logCfg.WriteTo.Seq(
config.Logging.SeqLogUrl,
restrictedToMinimumLevel: LogEventLevel.Verbose
);
} }
// AddSerilog doesn't seem to add an ILogger to the service collection, so add that manually. // AddSerilog doesn't seem to add an ILogger to the service collection, so add that manually.
@ -74,7 +79,8 @@ public static class WebApplicationExtensions
/// </summary> /// </summary>
public static IServiceCollection AddServices(this WebApplicationBuilder builder, Config config) public static IServiceCollection AddServices(this WebApplicationBuilder builder, Config config)
{ {
builder.Host.ConfigureServices((ctx, services) => builder.Host.ConfigureServices(
(ctx, services) =>
{ {
services services
.AddQueue() .AddQueue()
@ -84,7 +90,8 @@ public static class WebApplicationExtensions
.AddMinio(c => .AddMinio(c =>
c.WithEndpoint(config.Storage.Endpoint) c.WithEndpoint(config.Storage.Endpoint)
.WithCredentials(config.Storage.AccessKey, config.Storage.SecretKey) .WithCredentials(config.Storage.AccessKey, config.Storage.SecretKey)
.Build()) .Build()
)
.AddSingleton<MetricsCollectionService>() .AddSingleton<MetricsCollectionService>()
.AddSingleton<IClock>(SystemClock.Instance) .AddSingleton<IClock>(SystemClock.Instance)
.AddSnowflakeGenerator() .AddSnowflakeGenerator()
@ -104,18 +111,20 @@ public static class WebApplicationExtensions
if (!config.Logging.EnableMetrics) if (!config.Logging.EnableMetrics)
services.AddHostedService<BackgroundMetricsCollectionService>(); services.AddHostedService<BackgroundMetricsCollectionService>();
}); }
);
return builder.Services; return builder.Services;
} }
public static IServiceCollection AddCustomMiddleware(this IServiceCollection services) => services public static IServiceCollection AddCustomMiddleware(this IServiceCollection services) =>
services
.AddScoped<ErrorHandlerMiddleware>() .AddScoped<ErrorHandlerMiddleware>()
.AddScoped<AuthenticationMiddleware>() .AddScoped<AuthenticationMiddleware>()
.AddScoped<AuthorizationMiddleware>(); .AddScoped<AuthorizationMiddleware>();
public static IApplicationBuilder UseCustomMiddleware(this IApplicationBuilder app) => app public static IApplicationBuilder UseCustomMiddleware(this IApplicationBuilder app) =>
.UseMiddleware<ErrorHandlerMiddleware>() app.UseMiddleware<ErrorHandlerMiddleware>()
.UseMiddleware<AuthenticationMiddleware>() .UseMiddleware<AuthenticationMiddleware>()
.UseMiddleware<AuthorizationMiddleware>(); .UseMiddleware<AuthorizationMiddleware>();
@ -124,13 +133,20 @@ public static class WebApplicationExtensions
// Read version information from .version in the repository root // Read version information from .version in the repository root
await BuildInfo.ReadBuildInfo(); await BuildInfo.ReadBuildInfo();
app.Services.ConfigureQueue().LogQueuedTaskProgress(app.Services.GetRequiredService<ILogger<IQueue>>()); app.Services.ConfigureQueue()
.LogQueuedTaskProgress(app.Services.GetRequiredService<ILogger<IQueue>>());
await using var scope = app.Services.CreateAsyncScope(); await using var scope = app.Services.CreateAsyncScope();
var logger = scope.ServiceProvider.GetRequiredService<ILogger>().ForContext<WebApplication>(); var logger = scope
.ServiceProvider.GetRequiredService<ILogger>()
.ForContext<WebApplication>();
var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
logger.Information("Starting Foxnouns.NET {Version} ({Hash})", BuildInfo.Version, BuildInfo.Hash); logger.Information(
"Starting Foxnouns.NET {Version} ({Hash})",
BuildInfo.Version,
BuildInfo.Hash
);
var pendingMigrations = (await db.Database.GetPendingMigrationsAsync()).ToList(); var pendingMigrations = (await db.Database.GetPendingMigrationsAsync()).ToList();
if (args.Contains("--migrate") || args.Contains("--migrate-and-start")) if (args.Contains("--migrate") || args.Contains("--migrate-and-start"))
@ -146,13 +162,15 @@ public static class WebApplicationExtensions
logger.Information("Successfully migrated database"); logger.Information("Successfully migrated database");
} }
if (!args.Contains("--migrate-and-start")) Environment.Exit(0); if (!args.Contains("--migrate-and-start"))
Environment.Exit(0);
} }
else if (pendingMigrations.Count > 0) else if (pendingMigrations.Count > 0)
{ {
logger.Fatal( logger.Fatal(
"There are {Count} pending migrations, run server with --migrate or --migrate-and-start to run migrations.", "There are {Count} pending migrations, run server with --migrate or --migrate-and-start to run migrations.",
pendingMigrations.Count); pendingMigrations.Count
);
Environment.Exit(1); Environment.Exit(1);
} }

View file

@ -4,23 +4,35 @@ namespace Foxnouns.Backend;
public static class FoxnounsMetrics public static class FoxnounsMetrics
{ {
public static readonly Gauge UsersCount = public static readonly Gauge UsersCount = Metrics.CreateGauge(
Metrics.CreateGauge("foxnouns_user_count", "Number of total users"); "foxnouns_user_count",
"Number of total users"
);
public static readonly Gauge UsersActiveMonthCount = public static readonly Gauge UsersActiveMonthCount = Metrics.CreateGauge(
Metrics.CreateGauge("foxnouns_user_count_active_month", "Number of users active in the last month"); "foxnouns_user_count_active_month",
"Number of users active in the last month"
);
public static readonly Gauge UsersActiveWeekCount = public static readonly Gauge UsersActiveWeekCount = Metrics.CreateGauge(
Metrics.CreateGauge("foxnouns_user_count_active_week", "Number of users active in the last week"); "foxnouns_user_count_active_week",
"Number of users active in the last week"
);
public static readonly Gauge UsersActiveDayCount = public static readonly Gauge UsersActiveDayCount = Metrics.CreateGauge(
Metrics.CreateGauge("foxnouns_user_count_active_day", "Number of users active in the last day"); "foxnouns_user_count_active_day",
"Number of users active in the last day"
);
public static readonly Gauge MemberCount = public static readonly Gauge MemberCount = Metrics.CreateGauge(
Metrics.CreateGauge("foxnouns_member_count", "Number of total members"); "foxnouns_member_count",
"Number of total members"
);
public static readonly Summary MetricsCollectionTime = public static readonly Summary MetricsCollectionTime = Metrics.CreateSummary(
Metrics.CreateSummary("foxnouns_time_metrics", "Time it took to collect metrics"); "foxnouns_time_metrics",
"Time it took to collect metrics"
);
public static Gauge ProcessPhysicalMemory => public static Gauge ProcessPhysicalMemory =>
Metrics.CreateGauge("foxnouns_process_physical_memory", "Process physical memory"); Metrics.CreateGauge("foxnouns_process_physical_memory", "Process physical memory");
@ -31,7 +43,9 @@ public static class FoxnounsMetrics
public static Gauge ProcessPrivateMemory => public static Gauge ProcessPrivateMemory =>
Metrics.CreateGauge("foxnouns_process_private_memory", "Process private memory"); Metrics.CreateGauge("foxnouns_process_private_memory", "Process private memory");
public static Gauge ProcessThreads => Metrics.CreateGauge("foxnouns_process_threads", "Process thread count"); public static Gauge ProcessThreads =>
Metrics.CreateGauge("foxnouns_process_threads", "Process thread count");
public static Gauge ProcessHandles => Metrics.CreateGauge("foxnouns_process_handles", "Process handle count"); public static Gauge ProcessHandles =>
Metrics.CreateGauge("foxnouns_process_handles", "Process handle count");
} }

View file

@ -6,20 +6,30 @@ using Foxnouns.Backend.Services;
namespace Foxnouns.Backend.Jobs; namespace Foxnouns.Backend.Jobs;
public class CreateFlagInvocable(DatabaseContext db, ObjectStorageService objectStorageService, ILogger logger) public class CreateFlagInvocable(
: IInvocable, IInvocableWithPayload<CreateFlagPayload> DatabaseContext db,
ObjectStorageService objectStorageService,
ILogger logger
) : IInvocable, IInvocableWithPayload<CreateFlagPayload>
{ {
private readonly ILogger _logger = logger.ForContext<CreateFlagInvocable>(); private readonly ILogger _logger = logger.ForContext<CreateFlagInvocable>();
public required CreateFlagPayload Payload { get; set; } public required CreateFlagPayload Payload { get; set; }
public async Task Invoke() public async Task Invoke()
{ {
_logger.Information("Creating flag {FlagId} for user {UserId} with image data length {DataLength}", Payload.Id, _logger.Information(
Payload.UserId, Payload.ImageData.Length); "Creating flag {FlagId} for user {UserId} with image data length {DataLength}",
Payload.Id,
Payload.UserId,
Payload.ImageData.Length
);
try try
{ {
var (hash, image) = await Payload.ImageData.ConvertBase64UriToImage(size: 256, crop: false); var (hash, image) = await Payload.ImageData.ConvertBase64UriToImage(
size: 256,
crop: false
);
await objectStorageService.PutObjectAsync(Path(hash), image, "image/webp"); await objectStorageService.PutObjectAsync(Path(hash), image, "image/webp");
var flag = new PrideFlag var flag = new PrideFlag
@ -28,7 +38,7 @@ public class CreateFlagInvocable(DatabaseContext db, ObjectStorageService object
UserId = Payload.UserId, UserId = Payload.UserId,
Hash = hash, Hash = hash,
Name = Payload.Name, Name = Payload.Name,
Description = Payload.Description Description = Payload.Description,
}; };
db.Add(flag); db.Add(flag);

View file

@ -6,16 +6,21 @@ using Foxnouns.Backend.Services;
namespace Foxnouns.Backend.Jobs; namespace Foxnouns.Backend.Jobs;
public class MemberAvatarUpdateInvocable(DatabaseContext db, ObjectStorageService objectStorageService, ILogger logger) public class MemberAvatarUpdateInvocable(
: IInvocable, IInvocableWithPayload<AvatarUpdatePayload> DatabaseContext db,
ObjectStorageService objectStorageService,
ILogger logger
) : IInvocable, IInvocableWithPayload<AvatarUpdatePayload>
{ {
private readonly ILogger _logger = logger.ForContext<UserAvatarUpdateInvocable>(); private readonly ILogger _logger = logger.ForContext<UserAvatarUpdateInvocable>();
public required AvatarUpdatePayload Payload { get; set; } public required AvatarUpdatePayload Payload { get; set; }
public async Task Invoke() public async Task Invoke()
{ {
if (Payload.NewAvatar != null) await UpdateMemberAvatarAsync(Payload.Id, Payload.NewAvatar); if (Payload.NewAvatar != null)
else await ClearMemberAvatarAsync(Payload.Id); await UpdateMemberAvatarAsync(Payload.Id, Payload.NewAvatar);
else
await ClearMemberAvatarAsync(Payload.Id);
} }
private async Task UpdateMemberAvatarAsync(Snowflake id, string newAvatar) private async Task UpdateMemberAvatarAsync(Snowflake id, string newAvatar)
@ -25,7 +30,10 @@ public class MemberAvatarUpdateInvocable(DatabaseContext db, ObjectStorageServic
var member = await db.Members.FindAsync(id); var member = await db.Members.FindAsync(id);
if (member == null) if (member == null)
{ {
_logger.Warning("Update avatar job queued for {MemberId} but no member with that ID exists", id); _logger.Warning(
"Update avatar job queued for {MemberId} but no member with that ID exists",
id
);
return; return;
} }
@ -46,7 +54,11 @@ public class MemberAvatarUpdateInvocable(DatabaseContext db, ObjectStorageServic
} }
catch (ArgumentException ae) catch (ArgumentException ae)
{ {
_logger.Warning("Invalid data URI for new avatar for member {MemberId}: {Reason}", id, ae.Message); _logger.Warning(
"Invalid data URI for new avatar for member {MemberId}: {Reason}",
id,
ae.Message
);
} }
} }
@ -57,7 +69,10 @@ public class MemberAvatarUpdateInvocable(DatabaseContext db, ObjectStorageServic
var member = await db.Members.FindAsync(id); var member = await db.Members.FindAsync(id);
if (member == null) if (member == null)
{ {
_logger.Warning("Clear avatar job queued for {MemberId} but no member with that ID exists", id); _logger.Warning(
"Clear avatar job queued for {MemberId} but no member with that ID exists",
id
);
return; return;
} }

View file

@ -4,4 +4,10 @@ namespace Foxnouns.Backend.Jobs;
public record AvatarUpdatePayload(Snowflake Id, string? NewAvatar); public record AvatarUpdatePayload(Snowflake Id, string? NewAvatar);
public record CreateFlagPayload(Snowflake Id, Snowflake UserId, string Name, string ImageData, string? Description); public record CreateFlagPayload(
Snowflake Id,
Snowflake UserId,
string Name,
string ImageData,
string? Description
);

View file

@ -6,16 +6,21 @@ using Foxnouns.Backend.Services;
namespace Foxnouns.Backend.Jobs; namespace Foxnouns.Backend.Jobs;
public class UserAvatarUpdateInvocable(DatabaseContext db, ObjectStorageService objectStorageService, ILogger logger) public class UserAvatarUpdateInvocable(
: IInvocable, IInvocableWithPayload<AvatarUpdatePayload> DatabaseContext db,
ObjectStorageService objectStorageService,
ILogger logger
) : IInvocable, IInvocableWithPayload<AvatarUpdatePayload>
{ {
private readonly ILogger _logger = logger.ForContext<UserAvatarUpdateInvocable>(); private readonly ILogger _logger = logger.ForContext<UserAvatarUpdateInvocable>();
public required AvatarUpdatePayload Payload { get; set; } public required AvatarUpdatePayload Payload { get; set; }
public async Task Invoke() public async Task Invoke()
{ {
if (Payload.NewAvatar != null) await UpdateUserAvatarAsync(Payload.Id, Payload.NewAvatar); if (Payload.NewAvatar != null)
else await ClearUserAvatarAsync(Payload.Id); await UpdateUserAvatarAsync(Payload.Id, Payload.NewAvatar);
else
await ClearUserAvatarAsync(Payload.Id);
} }
private async Task UpdateUserAvatarAsync(Snowflake id, string newAvatar) private async Task UpdateUserAvatarAsync(Snowflake id, string newAvatar)
@ -25,7 +30,10 @@ public class UserAvatarUpdateInvocable(DatabaseContext db, ObjectStorageService
var user = await db.Users.FindAsync(id); var user = await db.Users.FindAsync(id);
if (user == null) if (user == null)
{ {
_logger.Warning("Update avatar job queued for {UserId} but no user with that ID exists", id); _logger.Warning(
"Update avatar job queued for {UserId} but no user with that ID exists",
id
);
return; return;
} }
@ -47,7 +55,11 @@ public class UserAvatarUpdateInvocable(DatabaseContext db, ObjectStorageService
} }
catch (ArgumentException ae) catch (ArgumentException ae)
{ {
_logger.Warning("Invalid data URI for new avatar for user {UserId}: {Reason}", id, ae.Message); _logger.Warning(
"Invalid data URI for new avatar for user {UserId}: {Reason}",
id,
ae.Message
);
} }
} }
@ -58,7 +70,10 @@ public class UserAvatarUpdateInvocable(DatabaseContext db, ObjectStorageService
var user = await db.Users.FindAsync(id); var user = await db.Users.FindAsync(id);
if (user == null) if (user == null)
{ {
_logger.Warning("Clear avatar job queued for {UserId} but no user with that ID exists", id); _logger.Warning(
"Clear avatar job queued for {UserId} but no user with that ID exists",
id
);
return; return;
} }

View file

@ -7,9 +7,7 @@ public class AccountCreationMailable(Config config, AccountCreationMailableView
{ {
public override void Build() public override void Build()
{ {
To(view.To) To(view.To).From(config.EmailAuth.From!).View("~/Views/Mail/AccountCreation.cshtml", view);
.From(config.EmailAuth.From!)
.View("~/Views/Mail/AccountCreation.cshtml", view);
} }
} }

View file

@ -17,7 +17,9 @@ public class AuthenticationMiddleware(DatabaseContext db) : IMiddleware
return; return;
} }
if (!AuthUtils.TryParseToken(ctx.Request.Headers.Authorization.ToString(), out var rawToken)) if (
!AuthUtils.TryParseToken(ctx.Request.Headers.Authorization.ToString(), out var rawToken)
)
{ {
await next(ctx); await next(ctx);
return; return;
@ -40,6 +42,7 @@ public static class HttpContextExtensions
private const string Key = "token"; private const string Key = "token";
public static void SetToken(this HttpContext ctx, Token token) => ctx.Items.Add(Key, token); public static void SetToken(this HttpContext ctx, Token token) => ctx.Items.Add(Key, token);
public static User? GetUser(this HttpContext ctx) => ctx.GetToken()?.User; public static User? GetUser(this HttpContext ctx) => ctx.GetToken()?.User;
public static User GetUserOrThrow(this HttpContext ctx) => public static User GetUserOrThrow(this HttpContext ctx) =>

View file

@ -18,14 +18,26 @@ public class AuthorizationMiddleware : IMiddleware
var token = ctx.GetToken(); var token = ctx.GetToken();
if (token == null) if (token == null)
throw new ApiError.Unauthorized("This endpoint requires an authenticated user.", throw new ApiError.Unauthorized(
ErrorCode.AuthenticationRequired); "This endpoint requires an authenticated user.",
if (attribute.Scopes.Length > 0 && attribute.Scopes.Except(token.Scopes.ExpandScopes()).Any()) ErrorCode.AuthenticationRequired
throw new ApiError.Forbidden("This endpoint requires ungranted scopes.", );
attribute.Scopes.Except(token.Scopes.ExpandScopes()), ErrorCode.MissingScopes); if (
attribute.Scopes.Length > 0
&& attribute.Scopes.Except(token.Scopes.ExpandScopes()).Any()
)
throw new ApiError.Forbidden(
"This endpoint requires ungranted scopes.",
attribute.Scopes.Except(token.Scopes.ExpandScopes()),
ErrorCode.MissingScopes
);
if (attribute.RequireAdmin && token.User.Role != UserRole.Admin) if (attribute.RequireAdmin && token.User.Role != UserRole.Admin)
throw new ApiError.Forbidden("This endpoint can only be used by admins."); throw new ApiError.Forbidden("This endpoint can only be used by admins.");
if (attribute.RequireModerator && token.User.Role != UserRole.Admin && token.User.Role != UserRole.Moderator) if (
attribute.RequireModerator
&& token.User.Role != UserRole.Admin
&& token.User.Role != UserRole.Moderator
)
throw new ApiError.Forbidden("This endpoint can only be used by moderators."); throw new ApiError.Forbidden("This endpoint can only be used by moderators.");
await next(ctx); await next(ctx);

View file

@ -21,19 +21,26 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
if (ctx.Response.HasStarted) if (ctx.Response.HasStarted)
{ {
logger.Error(e, "Error in {ClassName} ({Path}) after response started being sent", typeName, logger.Error(
ctx.Request.Path); e,
"Error in {ClassName} ({Path}) after response started being sent",
typeName,
ctx.Request.Path
);
sentry.CaptureException(e, scope => sentry.CaptureException(
e,
scope =>
{ {
var user = ctx.GetUser(); var user = ctx.GetUser();
if (user != null) if (user != null)
scope.User = new SentryUser scope.User = new SentryUser
{ {
Id = user.Id.ToString(), Id = user.Id.ToString(),
Username = user.Username Username = user.Username,
}; };
}); }
);
return; return;
} }
@ -45,13 +52,17 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
ctx.Response.ContentType = "application/json; charset=utf-8"; ctx.Response.ContentType = "application/json; charset=utf-8";
if (ae is ApiError.Forbidden fe) if (ae is ApiError.Forbidden fe)
{ {
await ctx.Response.WriteAsync(JsonConvert.SerializeObject(new HttpApiError await ctx.Response.WriteAsync(
JsonConvert.SerializeObject(
new HttpApiError
{ {
Status = (int)fe.StatusCode, Status = (int)fe.StatusCode,
Code = ErrorCode.Forbidden, Code = ErrorCode.Forbidden,
Message = fe.Message, Message = fe.Message,
Scopes = fe.Scopes.Length > 0 ? fe.Scopes : null Scopes = fe.Scopes.Length > 0 ? fe.Scopes : null,
})); }
)
);
return; return;
} }
@ -61,45 +72,61 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
return; return;
} }
await ctx.Response.WriteAsync(JsonConvert.SerializeObject(new HttpApiError await ctx.Response.WriteAsync(
JsonConvert.SerializeObject(
new HttpApiError
{ {
Status = (int)ae.StatusCode, Status = (int)ae.StatusCode,
Code = ae.ErrorCode, Code = ae.ErrorCode,
Message = ae.Message, Message = ae.Message,
})); }
)
);
return; return;
} }
if (e is FoxnounsError fce) if (e is FoxnounsError fce)
{ {
logger.Error(fce.Inner ?? fce, "Exception in {ClassName} ({Path})", typeName, ctx.Request.Path); logger.Error(
fce.Inner ?? fce,
"Exception in {ClassName} ({Path})",
typeName,
ctx.Request.Path
);
} }
else else
{ {
logger.Error(e, "Exception in {ClassName} ({Path})", typeName, ctx.Request.Path); logger.Error(e, "Exception in {ClassName} ({Path})", typeName, ctx.Request.Path);
} }
var errorId = sentry.CaptureException(e, scope => var errorId = sentry.CaptureException(
e,
scope =>
{ {
var user = ctx.GetUser(); var user = ctx.GetUser();
if (user != null) if (user != null)
scope.User = new SentryUser scope.User = new SentryUser
{ {
Id = user.Id.ToString(), Id = user.Id.ToString(),
Username = user.Username Username = user.Username,
}; };
}); }
);
ctx.Response.StatusCode = (int)HttpStatusCode.InternalServerError; ctx.Response.StatusCode = (int)HttpStatusCode.InternalServerError;
ctx.Response.Headers.RequestId = ctx.TraceIdentifier; ctx.Response.Headers.RequestId = ctx.TraceIdentifier;
ctx.Response.ContentType = "application/json; charset=utf-8"; ctx.Response.ContentType = "application/json; charset=utf-8";
await ctx.Response.WriteAsync(JsonConvert.SerializeObject(new HttpApiError await ctx.Response.WriteAsync(
JsonConvert.SerializeObject(
new HttpApiError
{ {
Status = (int)HttpStatusCode.InternalServerError, Status = (int)HttpStatusCode.InternalServerError,
Code = ErrorCode.InternalServerError, Code = ErrorCode.InternalServerError,
ErrorId = errorId.ToString(), ErrorId = errorId.ToString(),
Message = "Internal server error", Message = "Internal server error",
})); }
)
);
} }
} }
} }

View file

@ -1,5 +1,4 @@
using Foxnouns.Backend; using Foxnouns.Backend;
using Serilog;
using Foxnouns.Backend.Extensions; using Foxnouns.Backend.Extensions;
using Foxnouns.Backend.Services; using Foxnouns.Backend.Services;
using Foxnouns.Backend.Utils; using Foxnouns.Backend.Utils;
@ -8,6 +7,7 @@ using Newtonsoft.Json;
using Newtonsoft.Json.Serialization; using Newtonsoft.Json.Serialization;
using Prometheus; using Prometheus;
using Sentry.Extensibility; using Sentry.Extensibility;
using Serilog;
var builder = WebApplication.CreateBuilder(args); var builder = WebApplication.CreateBuilder(args);
@ -15,8 +15,8 @@ var config = builder.AddConfiguration();
builder.AddSerilog(); builder.AddSerilog();
builder.WebHost builder
.UseSentry(opts => .WebHost.UseSentry(opts =>
{ {
opts.Dsn = config.Logging.SentryUrl; opts.Dsn = config.Logging.SentryUrl;
opts.TracesSampleRate = config.Logging.SentryTracesSampleRate; opts.TracesSampleRate = config.Logging.SentryTracesSampleRate;
@ -30,13 +30,13 @@ builder.WebHost
opts.Limits.MaxRequestBodySize = 2 * 1024 * 1024; opts.Limits.MaxRequestBodySize = 2 * 1024 * 1024;
}); });
builder.Services builder
.AddControllers() .Services.AddControllers()
.AddNewtonsoftJson(options => .AddNewtonsoftJson(options =>
{ {
options.SerializerSettings.ContractResolver = new PatchRequestContractResolver options.SerializerSettings.ContractResolver = new PatchRequestContractResolver
{ {
NamingStrategy = new SnakeCaseNamingStrategy() NamingStrategy = new SnakeCaseNamingStrategy(),
}; };
}) })
.ConfigureApiBehaviorOptions(options => .ConfigureApiBehaviorOptions(options =>
@ -47,18 +47,16 @@ builder.Services
}); });
// Set the default converter to snake case as we use it in a couple places. // Set the default converter to snake case as we use it in a couple places.
JsonConvert.DefaultSettings = () => new JsonSerializerSettings JsonConvert.DefaultSettings = () =>
new JsonSerializerSettings
{ {
ContractResolver = new DefaultContractResolver ContractResolver = new DefaultContractResolver
{ {
NamingStrategy = new SnakeCaseNamingStrategy() NamingStrategy = new SnakeCaseNamingStrategy(),
} },
}; };
builder.AddServices(config) builder.AddServices(config).AddCustomMiddleware().AddEndpointsApiExplorer().AddSwaggerGen();
.AddCustomMiddleware()
.AddEndpointsApiExplorer()
.AddSwaggerGen();
var app = builder.Build(); var app = builder.Build();
@ -66,9 +64,11 @@ await app.Initialize(args);
app.UseSerilogRequestLogging(); app.UseSerilogRequestLogging();
app.UseRouting(); app.UseRouting();
// Not all environments will want tracing (from experience, it's expensive to use in production, even with a low sample rate), // Not all environments will want tracing (from experience, it's expensive to use in production, even with a low sample rate),
// so it's locked behind a config option. // so it's locked behind a config option.
if (config.Logging.SentryTracing) app.UseSentryTracing(); if (config.Logging.SentryTracing)
app.UseSentryTracing();
app.UseSwagger(); app.UseSwagger();
app.UseSwaggerUI(); app.UseSwaggerUI();
app.UseCors(); app.UseCors();
@ -80,7 +80,8 @@ app.Urls.Add(config.Address);
// Make sure metrics are updated whenever Prometheus scrapes them // Make sure metrics are updated whenever Prometheus scrapes them
Metrics.DefaultRegistry.AddBeforeCollectCallback(async ct => Metrics.DefaultRegistry.AddBeforeCollectCallback(async ct =>
await app.Services.GetRequiredService<MetricsCollectionService>().CollectMetricsAsync(ct)); await app.Services.GetRequiredService<MetricsCollectionService>().CollectMetricsAsync(ct)
);
app.Run(); app.Run();
Log.CloseAndFlush(); Log.CloseAndFlush();

View file

@ -16,8 +16,12 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
/// Creates a new user with the given email address and password. /// Creates a new user with the given email address and password.
/// This method does <i>not</i> save the resulting user, the caller must still call <see cref="M:Microsoft.EntityFrameworkCore.DbContext.SaveChanges" />. /// This method does <i>not</i> save the resulting user, the caller must still call <see cref="M:Microsoft.EntityFrameworkCore.DbContext.SaveChanges" />.
/// </summary> /// </summary>
public async Task<User> CreateUserWithPasswordAsync(string username, string email, string password, public async Task<User> CreateUserWithPasswordAsync(
CancellationToken ct = default) string username,
string email,
string password,
CancellationToken ct = default
)
{ {
var user = new User var user = new User
{ {
@ -26,9 +30,13 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
AuthMethods = AuthMethods =
{ {
new AuthMethod new AuthMethod
{ Id = snowflakeGenerator.GenerateSnowflake(), AuthType = AuthType.Email, RemoteId = email } {
Id = snowflakeGenerator.GenerateSnowflake(),
AuthType = AuthType.Email,
RemoteId = email,
}, },
LastActive = clock.GetCurrentInstant() },
LastActive = clock.GetCurrentInstant(),
}; };
db.Add(user); db.Add(user);
@ -42,8 +50,14 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
/// To create a user with email authentication, use <see cref="CreateUserWithPasswordAsync" /> /// To create a user with email authentication, use <see cref="CreateUserWithPasswordAsync" />
/// This method does <i>not</i> save the resulting user, the caller must still call <see cref="M:Microsoft.EntityFrameworkCore.DbContext.SaveChanges" />. /// This method does <i>not</i> save the resulting user, the caller must still call <see cref="M:Microsoft.EntityFrameworkCore.DbContext.SaveChanges" />.
/// </summary> /// </summary>
public async Task<User> CreateUserWithRemoteAuthAsync(string username, AuthType authType, string remoteId, public async Task<User> CreateUserWithRemoteAuthAsync(
string remoteUsername, FediverseApplication? instance = null, CancellationToken ct = default) string username,
AuthType authType,
string remoteId,
string remoteUsername,
FediverseApplication? instance = null,
CancellationToken ct = default
)
{ {
AssertValidAuthType(authType, instance); AssertValidAuthType(authType, instance);
@ -58,11 +72,14 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
{ {
new AuthMethod new AuthMethod
{ {
Id = snowflakeGenerator.GenerateSnowflake(), AuthType = authType, RemoteId = remoteId, Id = snowflakeGenerator.GenerateSnowflake(),
RemoteUsername = remoteUsername, FediverseApplication = instance AuthType = authType,
} RemoteId = remoteId,
RemoteUsername = remoteUsername,
FediverseApplication = instance,
}, },
LastActive = clock.GetCurrentInstant() },
LastActive = clock.GetCurrentInstant(),
}; };
db.Add(user); db.Add(user);
@ -78,19 +95,31 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
/// <returns>A tuple of the authenticated user and whether multi-factor authentication is required</returns> /// <returns>A tuple of the authenticated user and whether multi-factor authentication is required</returns>
/// <exception cref="ApiError.NotFound">Thrown if the email address is not associated with any user /// <exception cref="ApiError.NotFound">Thrown if the email address is not associated with any user
/// or if the password is incorrect</exception> /// or if the password is incorrect</exception>
public async Task<(User, EmailAuthenticationResult)> AuthenticateUserAsync(string email, string password, public async Task<(User, EmailAuthenticationResult)> AuthenticateUserAsync(
CancellationToken ct = default) string email,
string password,
CancellationToken ct = default
)
{ {
var user = await db.Users.FirstOrDefaultAsync(u => var user = await db.Users.FirstOrDefaultAsync(
u.AuthMethods.Any(a => a.AuthType == AuthType.Email && a.RemoteId == email), ct); u => u.AuthMethods.Any(a => a.AuthType == AuthType.Email && a.RemoteId == email),
ct
);
if (user == null) if (user == null)
throw new ApiError.NotFound("No user with that email address found, or password is incorrect", throw new ApiError.NotFound(
ErrorCode.UserNotFound); "No user with that email address found, or password is incorrect",
ErrorCode.UserNotFound
);
var pwResult = await Task.Run(() => _passwordHasher.VerifyHashedPassword(user, user.Password!, password), ct); var pwResult = await Task.Run(
() => _passwordHasher.VerifyHashedPassword(user, user.Password!, password),
ct
);
if (pwResult == PasswordVerificationResult.Failed) // TODO: this seems to fail on some valid passwords? if (pwResult == PasswordVerificationResult.Failed) // TODO: this seems to fail on some valid passwords?
throw new ApiError.NotFound("No user with that email address found, or password is incorrect", throw new ApiError.NotFound(
ErrorCode.UserNotFound); "No user with that email address found, or password is incorrect",
ErrorCode.UserNotFound
);
if (pwResult == PasswordVerificationResult.SuccessRehashNeeded) if (pwResult == PasswordVerificationResult.SuccessRehashNeeded)
{ {
user.Password = await Task.Run(() => _passwordHasher.HashPassword(user, password), ct); user.Password = await Task.Run(() => _passwordHasher.HashPassword(user, password), ct);
@ -117,19 +146,33 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
/// <returns>A user object, or null if the remote account isn't linked to any user.</returns> /// <returns>A user object, or null if the remote account isn't linked to any user.</returns>
/// <exception cref="FoxnounsError">Thrown if <c>instance</c> is passed when not required, /// <exception cref="FoxnounsError">Thrown if <c>instance</c> is passed when not required,
/// or not passed when required</exception> /// or not passed when required</exception>
public async Task<User?> AuthenticateUserAsync(AuthType authType, string remoteId, public async Task<User?> AuthenticateUserAsync(
FediverseApplication? instance = null, CancellationToken ct = default) AuthType authType,
string remoteId,
FediverseApplication? instance = null,
CancellationToken ct = default
)
{ {
AssertValidAuthType(authType, instance); AssertValidAuthType(authType, instance);
return await db.Users.FirstOrDefaultAsync(u => return await db.Users.FirstOrDefaultAsync(
u =>
u.AuthMethods.Any(a => u.AuthMethods.Any(a =>
a.AuthType == authType && a.RemoteId == remoteId && a.FediverseApplication == instance), ct); a.AuthType == authType
&& a.RemoteId == remoteId
&& a.FediverseApplication == instance
),
ct
);
} }
public async Task<AuthMethod> AddAuthMethodAsync(Snowflake userId, AuthType authType, string remoteId, public async Task<AuthMethod> AddAuthMethodAsync(
Snowflake userId,
AuthType authType,
string remoteId,
string? remoteUsername = null, string? remoteUsername = null,
CancellationToken ct = default) CancellationToken ct = default
)
{ {
AssertValidAuthType(authType, null); AssertValidAuthType(authType, null);
@ -139,7 +182,7 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
AuthType = authType, AuthType = authType,
RemoteId = remoteId, RemoteId = remoteId,
RemoteUsername = remoteUsername, RemoteUsername = remoteUsername,
UserId = userId UserId = userId,
}; };
db.Add(authMethod); db.Add(authMethod);
@ -147,21 +190,33 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
return authMethod; return authMethod;
} }
public (string, Token) GenerateToken(User user, Application application, string[] scopes, Instant expires) public (string, Token) GenerateToken(
User user,
Application application,
string[] scopes,
Instant expires
)
{ {
if (!AuthUtils.ValidateScopes(application, scopes)) if (!AuthUtils.ValidateScopes(application, scopes))
throw new ApiError.BadRequest("Invalid scopes requested for this token", "scopes", scopes); throw new ApiError.BadRequest(
"Invalid scopes requested for this token",
"scopes",
scopes
);
var (token, hash) = GenerateToken(); var (token, hash) = GenerateToken();
return (token, new Token return (
token,
new Token
{ {
Id = snowflakeGenerator.GenerateSnowflake(), Id = snowflakeGenerator.GenerateSnowflake(),
Hash = hash, Hash = hash,
Application = application, Application = application,
User = user, User = user,
ExpiresAt = expires, ExpiresAt = expires,
Scopes = scopes Scopes = scopes,
}); }
);
} }
private static (string, byte[]) GenerateToken() private static (string, byte[]) GenerateToken()

View file

@ -10,26 +10,43 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
{ {
private readonly ILogger _logger = logger.ForContext<KeyCacheService>(); private readonly ILogger _logger = logger.ForContext<KeyCacheService>();
public Task SetKeyAsync(string key, string value, Duration expireAfter, CancellationToken ct = default) => public Task SetKeyAsync(
SetKeyAsync(key, value, clock.GetCurrentInstant() + expireAfter, ct); string key,
string value,
Duration expireAfter,
CancellationToken ct = default
) => SetKeyAsync(key, value, clock.GetCurrentInstant() + expireAfter, ct);
public async Task SetKeyAsync(string key, string value, Instant expires, CancellationToken ct = default) public async Task SetKeyAsync(
string key,
string value,
Instant expires,
CancellationToken ct = default
)
{ {
db.TemporaryKeys.Add(new TemporaryKey db.TemporaryKeys.Add(
new TemporaryKey
{ {
Expires = expires, Expires = expires,
Key = key, Key = key,
Value = value, Value = value,
}); }
);
await db.SaveChangesAsync(ct); await db.SaveChangesAsync(ct);
} }
public async Task<string?> GetKeyAsync(string key, bool delete = false, CancellationToken ct = default) public async Task<string?> GetKeyAsync(
string key,
bool delete = false,
CancellationToken ct = default
)
{ {
var value = await db.TemporaryKeys.FirstOrDefaultAsync(k => k.Key == key, ct); var value = await db.TemporaryKeys.FirstOrDefaultAsync(k => k.Key == key, ct);
if (value == null) return null; if (value == null)
return null;
if (delete) await db.TemporaryKeys.Where(k => k.Key == key).ExecuteDeleteAsync(ct); if (delete)
await db.TemporaryKeys.Where(k => k.Key == key).ExecuteDeleteAsync(ct);
return value.Value; return value.Value;
} }
@ -39,20 +56,38 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
public async Task DeleteExpiredKeysAsync(CancellationToken ct) public async Task DeleteExpiredKeysAsync(CancellationToken ct)
{ {
var count = await db.TemporaryKeys.Where(k => k.Expires < clock.GetCurrentInstant()).ExecuteDeleteAsync(ct); var count = await db
if (count != 0) _logger.Information("Removed {Count} expired keys from the database", count); .TemporaryKeys.Where(k => k.Expires < clock.GetCurrentInstant())
.ExecuteDeleteAsync(ct);
if (count != 0)
_logger.Information("Removed {Count} expired keys from the database", count);
} }
public Task SetKeyAsync<T>(string key, T obj, Duration expiresAt, CancellationToken ct = default) where T : class => public Task SetKeyAsync<T>(
SetKeyAsync(key, obj, clock.GetCurrentInstant() + expiresAt, ct); string key,
T obj,
Duration expiresAt,
CancellationToken ct = default
)
where T : class => SetKeyAsync(key, obj, clock.GetCurrentInstant() + expiresAt, ct);
public async Task SetKeyAsync<T>(string key, T obj, Instant expires, CancellationToken ct = default) where T : class public async Task SetKeyAsync<T>(
string key,
T obj,
Instant expires,
CancellationToken ct = default
)
where T : class
{ {
var value = JsonConvert.SerializeObject(obj); var value = JsonConvert.SerializeObject(obj);
await SetKeyAsync(key, value, expires, ct); await SetKeyAsync(key, value, expires, ct);
} }
public async Task<T?> GetKeyAsync<T>(string key, bool delete = false, CancellationToken ct = default) public async Task<T?> GetKeyAsync<T>(
string key,
bool delete = false,
CancellationToken ct = default
)
where T : class where T : class
{ {
var value = await GetKeyAsync(key, delete, ct); var value = await GetKeyAsync(key, delete, ct);

View file

@ -15,12 +15,17 @@ public class MailService(ILogger logger, IMailer mailer, IQueue queue, Config co
_logger.Debug("Sending account creation email to {ToEmail}", to); _logger.Debug("Sending account creation email to {ToEmail}", to);
try try
{ {
await mailer.SendAsync(new AccountCreationMailable(config, new AccountCreationMailableView await mailer.SendAsync(
new AccountCreationMailable(
config,
new AccountCreationMailableView
{ {
BaseUrl = config.BaseUrl, BaseUrl = config.BaseUrl,
To = to, To = to,
Code = code Code = code,
})); }
)
);
} }
catch (Exception exc) catch (Exception exc)
{ {

View file

@ -10,17 +10,17 @@ public class MemberRendererService(DatabaseContext db, Config config)
{ {
public async Task<IEnumerable<PartialMember>> RenderUserMembersAsync(User user, Token? token) public async Task<IEnumerable<PartialMember>> RenderUserMembersAsync(User user, Token? token)
{ {
var canReadHiddenMembers = token != null && token.UserId == user.Id && token.HasScope("member.read"); var canReadHiddenMembers =
var renderUnlisted = token != null && token.UserId == user.Id && token.HasScope("user.read_hidden"); token != null && token.UserId == user.Id && token.HasScope("member.read");
var renderUnlisted =
token != null && token.UserId == user.Id && token.HasScope("user.read_hidden");
var canReadMemberList = !user.ListHidden || canReadHiddenMembers; var canReadMemberList = !user.ListHidden || canReadHiddenMembers;
IEnumerable<Member> members = canReadMemberList IEnumerable<Member> members = canReadMemberList
? await db.Members ? await db.Members.Where(m => m.UserId == user.Id).OrderBy(m => m.Name).ToListAsync()
.Where(m => m.UserId == user.Id)
.OrderBy(m => m.Name)
.ToListAsync()
: []; : [];
if (!canReadHiddenMembers) members = members.Where(m => !m.Unlisted); if (!canReadHiddenMembers)
members = members.Where(m => !m.Unlisted);
return members.Select(m => RenderPartialMember(m, renderUnlisted)); return members.Select(m => RenderPartialMember(m, renderUnlisted));
} }
@ -29,25 +29,54 @@ public class MemberRendererService(DatabaseContext db, Config config)
var renderUnlisted = token?.UserId == member.UserId && token.HasScope("user.read_hidden"); var renderUnlisted = token?.UserId == member.UserId && token.HasScope("user.read_hidden");
return new MemberResponse( return new MemberResponse(
member.Id, member.Sid, member.Name, member.DisplayName, member.Bio, member.Id,
AvatarUrlFor(member), member.Links, member.Names, member.Pronouns, member.Fields, member.Sid,
member.Name,
member.DisplayName,
member.Bio,
AvatarUrlFor(member),
member.Links,
member.Names,
member.Pronouns,
member.Fields,
member.ProfileFlags.Select(f => RenderPrideFlag(f.PrideFlag)), member.ProfileFlags.Select(f => RenderPrideFlag(f.PrideFlag)),
RenderPartialUser(member.User), renderUnlisted ? member.Unlisted : null); RenderPartialUser(member.User),
renderUnlisted ? member.Unlisted : null
);
} }
private UserRendererService.PartialUser RenderPartialUser(User user) => private UserRendererService.PartialUser RenderPartialUser(User user) =>
new(user.Id, user.Sid, user.Username, user.DisplayName, AvatarUrlFor(user), user.CustomPreferences); new(
user.Id,
user.Sid,
user.Username,
user.DisplayName,
AvatarUrlFor(user),
user.CustomPreferences
);
public PartialMember RenderPartialMember(Member member, bool renderUnlisted = false) => new(member.Id, member.Sid, public PartialMember RenderPartialMember(Member member, bool renderUnlisted = false) =>
new(
member.Id,
member.Sid,
member.Name, member.Name,
member.DisplayName, member.Bio, AvatarUrlFor(member), member.Names, member.Pronouns, member.DisplayName,
renderUnlisted ? member.Unlisted : null); member.Bio,
AvatarUrlFor(member),
member.Names,
member.Pronouns,
renderUnlisted ? member.Unlisted : null
);
private string? AvatarUrlFor(Member member) => private string? AvatarUrlFor(Member member) =>
member.Avatar != null ? $"{config.MediaBaseUrl}/members/{member.Id}/avatars/{member.Avatar}.webp" : null; member.Avatar != null
? $"{config.MediaBaseUrl}/members/{member.Id}/avatars/{member.Avatar}.webp"
: null;
private string? AvatarUrlFor(User user) => private string? AvatarUrlFor(User user) =>
user.Avatar != null ? $"{config.MediaBaseUrl}/users/{user.Id}/avatars/{user.Avatar}.webp" : null; user.Avatar != null
? $"{config.MediaBaseUrl}/users/{user.Id}/avatars/{user.Avatar}.webp"
: null;
private string ImageUrlFor(PrideFlag flag) => $"{config.MediaBaseUrl}/flags/{flag.Hash}.webp"; private string ImageUrlFor(PrideFlag flag) => $"{config.MediaBaseUrl}/flags/{flag.Hash}.webp";
@ -63,8 +92,8 @@ public class MemberRendererService(DatabaseContext db, Config config)
string? AvatarUrl, string? AvatarUrl,
IEnumerable<FieldEntry> Names, IEnumerable<FieldEntry> Names,
IEnumerable<Pronoun> Pronouns, IEnumerable<Pronoun> Pronouns,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] bool? Unlisted
bool? Unlisted); );
public record MemberResponse( public record MemberResponse(
Snowflake Id, Snowflake Id,
@ -79,6 +108,6 @@ public class MemberRendererService(DatabaseContext db, Config config)
IEnumerable<Field> Fields, IEnumerable<Field> Fields,
IEnumerable<UserRendererService.PrideFlagResponse> Flags, IEnumerable<UserRendererService.PrideFlagResponse> Flags,
UserRendererService.PartialUser User, UserRendererService.PartialUser User,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] bool? Unlisted
bool? Unlisted); );
} }

View file

@ -6,10 +6,7 @@ using Prometheus;
namespace Foxnouns.Backend.Services; namespace Foxnouns.Backend.Services;
public class MetricsCollectionService( public class MetricsCollectionService(ILogger logger, IServiceProvider services, IClock clock)
ILogger logger,
IServiceProvider services,
IClock clock)
{ {
private readonly ILogger _logger = logger.ForContext<MetricsCollectionService>(); private readonly ILogger _logger = logger.ForContext<MetricsCollectionService>();
@ -31,8 +28,10 @@ public class MetricsCollectionService(
FoxnounsMetrics.UsersActiveWeekCount.Set(users.Count(i => i > now - Week)); FoxnounsMetrics.UsersActiveWeekCount.Set(users.Count(i => i > now - Week));
FoxnounsMetrics.UsersActiveDayCount.Set(users.Count(i => i > now - Day)); FoxnounsMetrics.UsersActiveDayCount.Set(users.Count(i => i > now - Day));
var memberCount = await db.Members.Include(m => m.User) var memberCount = await db
.Where(m => !m.Unlisted && !m.User.ListHidden && !m.User.Deleted).CountAsync(ct); .Members.Include(m => m.User)
.Where(m => !m.Unlisted && !m.User.ListHidden && !m.User.Deleted)
.CountAsync(ct);
FoxnounsMetrics.MemberCount.Set(memberCount); FoxnounsMetrics.MemberCount.Set(memberCount);
var process = Process.GetCurrentProcess(); var process = Process.GetCurrentProcess();
@ -42,13 +41,17 @@ public class MetricsCollectionService(
FoxnounsMetrics.ProcessThreads.Set(process.Threads.Count); FoxnounsMetrics.ProcessThreads.Set(process.Threads.Count);
FoxnounsMetrics.ProcessHandles.Set(process.HandleCount); FoxnounsMetrics.ProcessHandles.Set(process.HandleCount);
_logger.Information("Collected metrics in {DurationMilliseconds} ms", _logger.Information(
timer.ObserveDuration().TotalMilliseconds); "Collected metrics in {DurationMilliseconds} ms",
timer.ObserveDuration().TotalMilliseconds
);
} }
} }
public class BackgroundMetricsCollectionService(ILogger logger, MetricsCollectionService metricsCollectionService) public class BackgroundMetricsCollectionService(
: BackgroundService ILogger logger,
MetricsCollectionService metricsCollectionService
) : BackgroundService
{ {
private readonly ILogger _logger = logger.ForContext<BackgroundMetricsCollectionService>(); private readonly ILogger _logger = logger.ForContext<BackgroundMetricsCollectionService>();

View file

@ -15,7 +15,8 @@ public class ObjectStorageService(ILogger logger, Config config, IMinioClient mi
{ {
await minioClient.RemoveObjectAsync( await minioClient.RemoveObjectAsync(
new RemoveObjectArgs().WithBucket(config.Storage.Bucket).WithObject(path), new RemoveObjectArgs().WithBucket(config.Storage.Bucket).WithObject(path),
ct); ct
);
} }
catch (InvalidObjectNameException) catch (InvalidObjectNameException)
{ {
@ -23,17 +24,28 @@ public class ObjectStorageService(ILogger logger, Config config, IMinioClient mi
} }
} }
public async Task PutObjectAsync(string path, Stream data, string contentType, CancellationToken ct = default) public async Task PutObjectAsync(
string path,
Stream data,
string contentType,
CancellationToken ct = default
)
{ {
_logger.Debug("Putting object at path {Path} with length {Length} and content type {ContentType}", path, _logger.Debug(
data.Length, contentType); "Putting object at path {Path} with length {Length} and content type {ContentType}",
path,
data.Length,
contentType
);
await minioClient.PutObjectAsync(new PutObjectArgs() await minioClient.PutObjectAsync(
new PutObjectArgs()
.WithBucket(config.Storage.Bucket) .WithBucket(config.Storage.Bucket)
.WithObject(path) .WithObject(path)
.WithObjectSize(data.Length) .WithObjectSize(data.Length)
.WithStreamData(data) .WithStreamData(data)
.WithContentType(contentType), ct .WithContentType(contentType),
ct
); );
} }
} }

View file

@ -11,30 +11,42 @@ public class RemoteAuthService(Config config, ILogger logger)
private readonly Uri _discordTokenUri = new("https://discord.com/api/oauth2/token"); private readonly Uri _discordTokenUri = new("https://discord.com/api/oauth2/token");
private readonly Uri _discordUserUri = new("https://discord.com/api/v10/users/@me"); private readonly Uri _discordUserUri = new("https://discord.com/api/v10/users/@me");
public async Task<RemoteUser> RequestDiscordTokenAsync(string code, string state, CancellationToken ct = default) public async Task<RemoteUser> RequestDiscordTokenAsync(
string code,
string state,
CancellationToken ct = default
)
{ {
var redirectUri = $"{config.BaseUrl}/auth/callback/discord"; var redirectUri = $"{config.BaseUrl}/auth/callback/discord";
var resp = await _httpClient.PostAsync(_discordTokenUri, new FormUrlEncodedContent( var resp = await _httpClient.PostAsync(
_discordTokenUri,
new FormUrlEncodedContent(
new Dictionary<string, string> new Dictionary<string, string>
{ {
{ "client_id", config.DiscordAuth.ClientId! }, { "client_id", config.DiscordAuth.ClientId! },
{ "client_secret", config.DiscordAuth.ClientSecret! }, { "client_secret", config.DiscordAuth.ClientSecret! },
{ "grant_type", "authorization_code" }, { "grant_type", "authorization_code" },
{ "code", code }, { "code", code },
{ "redirect_uri", redirectUri } { "redirect_uri", redirectUri },
} }
), ct); ),
ct
);
if (!resp.IsSuccessStatusCode) if (!resp.IsSuccessStatusCode)
{ {
var respBody = await resp.Content.ReadAsStringAsync(ct); var respBody = await resp.Content.ReadAsStringAsync(ct);
_logger.Error("Received error status {StatusCode} when exchanging OAuth token: {ErrorBody}", _logger.Error(
(int)resp.StatusCode, respBody); "Received error status {StatusCode} when exchanging OAuth token: {ErrorBody}",
(int)resp.StatusCode,
respBody
);
throw new FoxnounsError("Invalid Discord OAuth response"); throw new FoxnounsError("Invalid Discord OAuth response");
} }
resp.EnsureSuccessStatusCode(); resp.EnsureSuccessStatusCode();
var token = await resp.Content.ReadFromJsonAsync<DiscordTokenResponse>(ct); var token = await resp.Content.ReadFromJsonAsync<DiscordTokenResponse>(ct);
if (token == null) throw new FoxnounsError("Discord token response was null"); if (token == null)
throw new FoxnounsError("Discord token response was null");
var req = new HttpRequestMessage(HttpMethod.Get, _discordUserUri); var req = new HttpRequestMessage(HttpMethod.Get, _discordUserUri);
req.Headers.Add("Authorization", $"{token.token_type} {token.access_token}"); req.Headers.Add("Authorization", $"{token.token_type} {token.access_token}");
@ -42,18 +54,25 @@ public class RemoteAuthService(Config config, ILogger logger)
var resp2 = await _httpClient.SendAsync(req, ct); var resp2 = await _httpClient.SendAsync(req, ct);
resp2.EnsureSuccessStatusCode(); resp2.EnsureSuccessStatusCode();
var user = await resp2.Content.ReadFromJsonAsync<DiscordUserResponse>(ct); var user = await resp2.Content.ReadFromJsonAsync<DiscordUserResponse>(ct);
if (user == null) throw new FoxnounsError("Discord user response was null"); if (user == null)
throw new FoxnounsError("Discord user response was null");
return new RemoteUser(user.id, user.username); return new RemoteUser(user.id, user.username);
} }
[SuppressMessage("ReSharper", "InconsistentNaming", [SuppressMessage(
Justification = "Easier to use snake_case here, rather than passing in JSON converter options")] "ReSharper",
"InconsistentNaming",
Justification = "Easier to use snake_case here, rather than passing in JSON converter options"
)]
[UsedImplicitly] [UsedImplicitly]
private record DiscordTokenResponse(string access_token, string token_type); private record DiscordTokenResponse(string access_token, string token_type);
[SuppressMessage("ReSharper", "InconsistentNaming", [SuppressMessage(
Justification = "Easier to use snake_case here, rather than passing in JSON converter options")] "ReSharper",
"InconsistentNaming",
Justification = "Easier to use snake_case here, rather than passing in JSON converter options"
)]
[UsedImplicitly] [UsedImplicitly]
private record DiscordUserResponse(string id, string username); private record DiscordUserResponse(string id, string username);

View file

@ -7,48 +7,73 @@ using NodaTime;
namespace Foxnouns.Backend.Services; namespace Foxnouns.Backend.Services;
public class UserRendererService(DatabaseContext db, MemberRendererService memberRenderer, Config config) public class UserRendererService(
DatabaseContext db,
MemberRendererService memberRenderer,
Config config
)
{ {
public async Task<UserResponse> RenderUserAsync(User user, User? selfUser = null, public async Task<UserResponse> RenderUserAsync(
User user,
User? selfUser = null,
Token? token = null, Token? token = null,
bool renderMembers = true, bool renderMembers = true,
bool renderAuthMethods = false, bool renderAuthMethods = false,
CancellationToken ct = default) CancellationToken ct = default
)
{ {
var isSelfUser = selfUser?.Id == user.Id; var isSelfUser = selfUser?.Id == user.Id;
var tokenCanReadHiddenMembers = token.HasScope("member.read") && isSelfUser; var tokenCanReadHiddenMembers = token.HasScope("member.read") && isSelfUser;
var tokenHidden = token.HasScope("user.read_hidden") && isSelfUser; var tokenHidden = token.HasScope("user.read_hidden") && isSelfUser;
var tokenPrivileged = token.HasScope("user.read_privileged") && isSelfUser; var tokenPrivileged = token.HasScope("user.read_privileged") && isSelfUser;
renderMembers = renderMembers && renderMembers = renderMembers && (!user.ListHidden || tokenCanReadHiddenMembers);
(!user.ListHidden || tokenCanReadHiddenMembers);
renderAuthMethods = renderAuthMethods && tokenPrivileged; renderAuthMethods = renderAuthMethods && tokenPrivileged;
IEnumerable<Member> members = IEnumerable<Member> members = renderMembers
renderMembers ? await db.Members.Where(m => m.UserId == user.Id).OrderBy(m => m.Name).ToListAsync(ct) : []; ? await db.Members.Where(m => m.UserId == user.Id).OrderBy(m => m.Name).ToListAsync(ct)
: [];
// Unless the user is requesting their own members AND the token can read hidden members, we filter out unlisted members. // Unless the user is requesting their own members AND the token can read hidden members, we filter out unlisted members.
if (!(isSelfUser && tokenCanReadHiddenMembers)) members = members.Where(m => !m.Unlisted); if (!(isSelfUser && tokenCanReadHiddenMembers))
members = members.Where(m => !m.Unlisted);
var flags = await db.UserFlags.Where(f => f.UserId == user.Id).OrderBy(f => f.Id).ToListAsync(ct); var flags = await db
.UserFlags.Where(f => f.UserId == user.Id)
.OrderBy(f => f.Id)
.ToListAsync(ct);
var authMethods = renderAuthMethods var authMethods = renderAuthMethods
? await db.AuthMethods ? await db
.Where(a => a.UserId == user.Id) .AuthMethods.Where(a => a.UserId == user.Id)
.Include(a => a.FediverseApplication) .Include(a => a.FediverseApplication)
.ToListAsync(ct) .ToListAsync(ct)
: []; : [];
return new UserResponse( return new UserResponse(
user.Id, user.Sid, user.Username, user.DisplayName, user.Bio, user.MemberTitle, AvatarUrlFor(user), user.Id,
user.Sid,
user.Username,
user.DisplayName,
user.Bio,
user.MemberTitle,
AvatarUrlFor(user),
user.Links, user.Links,
user.Names, user.Pronouns, user.Fields, user.CustomPreferences, user.Names,
user.Pronouns,
user.Fields,
user.CustomPreferences,
flags.Select(f => RenderPrideFlag(f.PrideFlag)), flags.Select(f => RenderPrideFlag(f.PrideFlag)),
user.Role, user.Role,
renderMembers ? members.Select(m => memberRenderer.RenderPartialMember(m, tokenHidden)) : null, renderMembers
? members.Select(m => memberRenderer.RenderPartialMember(m, tokenHidden))
: null,
renderAuthMethods renderAuthMethods
? authMethods.Select(a => new AuthenticationMethodResponse( ? authMethods.Select(a => new AuthenticationMethodResponse(
a.Id, a.AuthType, a.RemoteId, a.Id,
a.RemoteUsername, a.FediverseApplication?.Domain a.AuthType,
a.RemoteId,
a.RemoteUsername,
a.FediverseApplication?.Domain
)) ))
: null, : null,
tokenHidden ? user.ListHidden : null, tokenHidden ? user.ListHidden : null,
@ -58,10 +83,19 @@ public class UserRendererService(DatabaseContext db, MemberRendererService membe
} }
public PartialUser RenderPartialUser(User user) => public PartialUser RenderPartialUser(User user) =>
new(user.Id, user.Sid, user.Username, user.DisplayName, AvatarUrlFor(user), user.CustomPreferences); new(
user.Id,
user.Sid,
user.Username,
user.DisplayName,
AvatarUrlFor(user),
user.CustomPreferences
);
private string? AvatarUrlFor(User user) => private string? AvatarUrlFor(User user) =>
user.Avatar != null ? $"{config.MediaBaseUrl}/users/{user.Id}/avatars/{user.Avatar}.webp" : null; user.Avatar != null
? $"{config.MediaBaseUrl}/users/{user.Id}/avatars/{user.Avatar}.webp"
: null;
public string ImageUrlFor(PrideFlag flag) => $"{config.MediaBaseUrl}/flags/{flag.Hash}.webp"; public string ImageUrlFor(PrideFlag flag) => $"{config.MediaBaseUrl}/flags/{flag.Hash}.webp";
@ -79,24 +113,21 @@ public class UserRendererService(DatabaseContext db, MemberRendererService membe
IEnumerable<Field> Fields, IEnumerable<Field> Fields,
Dictionary<Snowflake, User.CustomPreference> CustomPreferences, Dictionary<Snowflake, User.CustomPreference> CustomPreferences,
IEnumerable<PrideFlagResponse> Flags, IEnumerable<PrideFlagResponse> Flags,
[property: JsonConverter(typeof(ScreamingSnakeCaseEnumConverter))] [property: JsonConverter(typeof(ScreamingSnakeCaseEnumConverter))] UserRole Role,
UserRole Role,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
IEnumerable<MemberRendererService.PartialMember>? Members, IEnumerable<MemberRendererService.PartialMember>? Members,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
IEnumerable<AuthenticationMethodResponse>? AuthMethods, IEnumerable<AuthenticationMethodResponse>? AuthMethods,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
bool? MemberListHidden, bool? MemberListHidden,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] Instant? LastActive,
Instant? LastActive,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
Instant? LastSidReroll Instant? LastSidReroll
); );
public record AuthenticationMethodResponse( public record AuthenticationMethodResponse(
Snowflake Id, Snowflake Id,
[property: JsonConverter(typeof(ScreamingSnakeCaseEnumConverter))] [property: JsonConverter(typeof(ScreamingSnakeCaseEnumConverter))] AuthType Type,
AuthType Type,
string RemoteId, string RemoteId,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] [property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
string? RemoteUsername, string? RemoteUsername,
@ -120,5 +151,6 @@ public class UserRendererService(DatabaseContext db, MemberRendererService membe
Snowflake Id, Snowflake Id,
string ImageUrl, string ImageUrl,
string Name, string Name,
string? Description); string? Description
);
} }

View file

@ -7,12 +7,28 @@ public static class AuthUtils
{ {
public const string ClientCredentials = "client_credentials"; public const string ClientCredentials = "client_credentials";
public const string AuthorizationCode = "authorization_code"; public const string AuthorizationCode = "authorization_code";
private static readonly string[] ForbiddenSchemes = ["javascript", "file", "data", "mailto", "tel"]; private static readonly string[] ForbiddenSchemes =
[
"javascript",
"file",
"data",
"mailto",
"tel",
];
public static readonly string[] UserScopes = public static readonly string[] UserScopes =
["user.read_hidden", "user.read_privileged", "user.update"]; [
"user.read_hidden",
"user.read_privileged",
"user.update",
];
public static readonly string[] MemberScopes = ["member.read", "member.update", "member.create"]; public static readonly string[] MemberScopes =
[
"member.read",
"member.update",
"member.create",
];
/// <summary> /// <summary>
/// All scopes endpoints can be secured by. This does *not* include the catch-all token scopes. /// All scopes endpoints can be secured by. This does *not* include the catch-all token scopes.
@ -27,10 +43,13 @@ public static class AuthUtils
public static string[] ExpandScopes(this string[] scopes) public static string[] ExpandScopes(this string[] scopes)
{ {
if (scopes.Contains("*")) return ["*", ..Scopes]; if (scopes.Contains("*"))
return ["*", .. Scopes];
List<string> expandedScopes = ["identify"]; List<string> expandedScopes = ["identify"];
if (scopes.Contains("user")) expandedScopes.AddRange(UserScopes); if (scopes.Contains("user"))
if (scopes.Contains("member")) expandedScopes.AddRange(MemberScopes); expandedScopes.AddRange(UserScopes);
if (scopes.Contains("member"))
expandedScopes.AddRange(MemberScopes);
return expandedScopes.ToArray(); return expandedScopes.ToArray();
} }
@ -41,8 +60,10 @@ public static class AuthUtils
private static string[] ExpandAppScopes(this string[] scopes) private static string[] ExpandAppScopes(this string[] scopes)
{ {
var expandedScopes = scopes.ExpandScopes().ToList(); var expandedScopes = scopes.ExpandScopes().ToList();
if (scopes.Contains("user")) expandedScopes.Add("user"); if (scopes.Contains("user"))
if (scopes.Contains("member")) expandedScopes.Add("member"); expandedScopes.Add("user");
if (scopes.Contains("member"))
expandedScopes.Add("member");
return expandedScopes.ToArray(); return expandedScopes.ToArray();
} }
@ -84,7 +105,8 @@ public static class AuthUtils
{ {
rawToken = []; rawToken = [];
if (string.IsNullOrWhiteSpace(input)) return false; if (string.IsNullOrWhiteSpace(input))
return false;
if (input.StartsWith("bearer ", StringComparison.InvariantCultureIgnoreCase)) if (input.StartsWith("bearer ", StringComparison.InvariantCultureIgnoreCase))
input = input["bearer ".Length..]; input = input["bearer ".Length..];

View file

@ -13,7 +13,9 @@ namespace Foxnouns.Backend.Utils;
public abstract class PatchRequest public abstract class PatchRequest
{ {
private readonly HashSet<string> _properties = []; private readonly HashSet<string> _properties = [];
public bool HasProperty(string propertyName) => _properties.Contains(propertyName); public bool HasProperty(string propertyName) => _properties.Contains(propertyName);
public void SetHasProperty(string propertyName) => _properties.Add(propertyName); public void SetHasProperty(string propertyName) => _properties.Add(propertyName);
} }
@ -23,13 +25,17 @@ public abstract class PatchRequest
/// </summary> /// </summary>
public class PatchRequestContractResolver : DefaultContractResolver public class PatchRequestContractResolver : DefaultContractResolver
{ {
protected override JsonProperty CreateProperty(MemberInfo member, MemberSerialization memberSerialization) protected override JsonProperty CreateProperty(
MemberInfo member,
MemberSerialization memberSerialization
)
{ {
var prop = base.CreateProperty(member, memberSerialization); var prop = base.CreateProperty(member, memberSerialization);
prop.SetIsSpecified += (o, _) => prop.SetIsSpecified += (o, _) =>
{ {
if (o is not PatchRequest patchRequest) return; if (o is not PatchRequest patchRequest)
return;
patchRequest.SetHasProperty(prop.UnderlyingName!); patchRequest.SetHasProperty(prop.UnderlyingName!);
}; };

View file

@ -8,7 +8,8 @@ namespace Foxnouns.Backend.Utils;
/// A custom StringEnumConverter that converts enum members to SCREAMING_SNAKE_CASE, rather than CamelCase as is the default. /// A custom StringEnumConverter that converts enum members to SCREAMING_SNAKE_CASE, rather than CamelCase as is the default.
/// Newtonsoft.Json doesn't provide a screaming snake case naming strategy, so we just wrap the normal snake case one and convert it to uppercase. /// Newtonsoft.Json doesn't provide a screaming snake case naming strategy, so we just wrap the normal snake case one and convert it to uppercase.
/// </summary> /// </summary>
public class ScreamingSnakeCaseEnumConverter() : StringEnumConverter(new ScreamingSnakeCaseNamingStrategy(), false) public class ScreamingSnakeCaseEnumConverter()
: StringEnumConverter(new ScreamingSnakeCaseNamingStrategy(), false)
{ {
private class ScreamingSnakeCaseNamingStrategy : SnakeCaseNamingStrategy private class ScreamingSnakeCaseNamingStrategy : SnakeCaseNamingStrategy
{ {

View file

@ -21,7 +21,7 @@ public static partial class ValidationUtils
"pronouns", "pronouns",
"settings", "settings",
"pronouns.cc", "pronouns.cc",
"pronounscc" "pronounscc",
]; ];
private static readonly string[] InvalidMemberNames = private static readonly string[] InvalidMemberNames =
@ -30,7 +30,7 @@ public static partial class ValidationUtils
".", ".",
"..", "..",
// the user edit page lives at `/@{username}/edit`, so a member named "edit" would be inaccessible // the user edit page lives at `/@{username}/edit`, so a member named "edit" would be inaccessible
"edit" "edit",
]; ];
public static ValidationError? ValidateUsername(string username) public static ValidationError? ValidateUsername(string username)
@ -42,10 +42,15 @@ public static partial class ValidationUtils
> 40 => ValidationError.LengthError("Username is too long", 2, 40, username.Length), > 40 => ValidationError.LengthError("Username is too long", 2, 40, username.Length),
_ => ValidationError.GenericValidationError( _ => ValidationError.GenericValidationError(
"Username is invalid, can only contain alphanumeric characters, dashes, underscores, and periods", "Username is invalid, can only contain alphanumeric characters, dashes, underscores, and periods",
username) username
),
}; };
if (InvalidUsernames.Any(u => string.Equals(u, username, StringComparison.InvariantCultureIgnoreCase))) if (
InvalidUsernames.Any(u =>
string.Equals(u, username, StringComparison.InvariantCultureIgnoreCase)
)
)
return ValidationError.GenericValidationError("Username is not allowed", username); return ValidationError.GenericValidationError("Username is not allowed", username);
return null; return null;
} }
@ -58,13 +63,18 @@ public static partial class ValidationUtils
< 1 => ValidationError.LengthError("Name is too short", 1, 100, memberName.Length), < 1 => ValidationError.LengthError("Name is too short", 1, 100, memberName.Length),
> 100 => ValidationError.LengthError("Name is too long", 1, 100, memberName.Length), > 100 => ValidationError.LengthError("Name is too long", 1, 100, memberName.Length),
_ => ValidationError.GenericValidationError( _ => ValidationError.GenericValidationError(
"Member name cannot contain any of the following: " + "Member name cannot contain any of the following: "
" @, ?, !, #, /, \\, [, ], \", ', $, %, &, (, ), {, }, +, <, =, >, ^, |, ~, `, , " + + " @, ?, !, #, /, \\, [, ], \", ', $, %, &, (, ), {, }, +, <, =, >, ^, |, ~, `, , "
"and cannot be one or two periods", + "and cannot be one or two periods",
memberName) memberName
),
}; };
if (InvalidMemberNames.Any(u => string.Equals(u, memberName, StringComparison.InvariantCultureIgnoreCase))) if (
InvalidMemberNames.Any(u =>
string.Equals(u, memberName, StringComparison.InvariantCultureIgnoreCase)
)
)
return ValidationError.GenericValidationError("Name is not allowed", memberName); return ValidationError.GenericValidationError("Name is not allowed", memberName);
return null; return null;
} }
@ -72,12 +82,14 @@ public static partial class ValidationUtils
public static void Validate(IEnumerable<(string, ValidationError?)> errors) public static void Validate(IEnumerable<(string, ValidationError?)> errors)
{ {
errors = errors.Where(e => e.Item2 != null).ToList(); errors = errors.Where(e => e.Item2 != null).ToList();
if (!errors.Any()) return; if (!errors.Any())
return;
var errorDict = new Dictionary<string, IEnumerable<ValidationError>>(); var errorDict = new Dictionary<string, IEnumerable<ValidationError>>();
foreach (var error in errors) foreach (var error in errors)
{ {
if (errorDict.TryGetValue(error.Item1, out var value)) errorDict[error.Item1] = value.Append(error.Item2!); if (errorDict.TryGetValue(error.Item1, out var value))
errorDict[error.Item1] = value.Append(error.Item2!);
errorDict.Add(error.Item1, [error.Item2!]); errorDict.Add(error.Item1, [error.Item2!]);
} }
@ -88,9 +100,19 @@ public static partial class ValidationUtils
{ {
return displayName?.Length switch return displayName?.Length switch
{ {
0 => ValidationError.LengthError("Display name is too short", 1, 100, displayName.Length), 0 => ValidationError.LengthError(
> 100 => ValidationError.LengthError("Display name is too long", 1, 100, displayName.Length), "Display name is too short",
_ => null 1,
100,
displayName.Length
),
> 100 => ValidationError.LengthError(
"Display name is too long",
1,
100,
displayName.Length
),
_ => null,
}; };
} }
@ -99,9 +121,13 @@ public static partial class ValidationUtils
public static IEnumerable<(string, ValidationError?)> ValidateLinks(string[]? links) public static IEnumerable<(string, ValidationError?)> ValidateLinks(string[]? links)
{ {
if (links == null) return []; if (links == null)
return [];
if (links.Length > MaxLinks) if (links.Length > MaxLinks)
return [("links", ValidationError.LengthError("Too many links", 0, MaxLinks, links.Length))]; return
[
("links", ValidationError.LengthError("Too many links", 0, MaxLinks, links.Length)),
];
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
foreach (var (link, idx) in links.Select((l, i) => (l, i))) foreach (var (link, idx) in links.Select((l, i) => (l, i)))
@ -109,12 +135,25 @@ public static partial class ValidationUtils
switch (link.Length) switch (link.Length)
{ {
case 0: case 0:
errors.Add(($"links.{idx}", errors.Add(
ValidationError.LengthError("Link cannot be empty", 1, 256, 0))); (
$"links.{idx}",
ValidationError.LengthError("Link cannot be empty", 1, 256, 0)
)
);
break; break;
case > MaxLinkLength: case > MaxLinkLength:
errors.Add(($"links.{idx}", errors.Add(
ValidationError.LengthError("Link is too long", 1, MaxLinkLength, link.Length))); (
$"links.{idx}",
ValidationError.LengthError(
"Link is too long",
1,
MaxLinkLength,
link.Length
)
)
);
break; break;
} }
} }
@ -129,8 +168,13 @@ public static partial class ValidationUtils
return bio?.Length switch return bio?.Length switch
{ {
0 => ValidationError.LengthError("Bio is too short", 1, MaxBioLength, bio.Length), 0 => ValidationError.LengthError("Bio is too short", 1, MaxBioLength, bio.Length),
> MaxBioLength => ValidationError.LengthError("Bio is too long", 1, MaxBioLength, bio.Length), > MaxBioLength => ValidationError.LengthError(
_ => null "Bio is too long",
1,
MaxBioLength,
bio.Length
),
_ => null,
}; };
} }
@ -140,121 +184,222 @@ public static partial class ValidationUtils
{ {
0 => ValidationError.GenericValidationError("Avatar cannot be empty", null), 0 => ValidationError.GenericValidationError("Avatar cannot be empty", null),
> 1_500_000 => ValidationError.GenericValidationError("Avatar is too large", null), > 1_500_000 => ValidationError.GenericValidationError("Avatar is too large", null),
_ => null _ => null,
}; };
} }
private static readonly string[] DefaultStatusOptions = private static readonly string[] DefaultStatusOptions =
[ [
"favourite", "favourite",
"okay", "okay",
"jokingly", "jokingly",
"friends_only", "friends_only",
"avoid" "avoid",
]; ];
public static IEnumerable<(string, ValidationError?)> ValidateFields(List<Field>? fields, public static IEnumerable<(string, ValidationError?)> ValidateFields(
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences) List<Field>? fields,
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences
)
{ {
if (fields == null) return []; if (fields == null)
return [];
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (fields.Count > 25) if (fields.Count > 25)
errors.Add(("fields", ValidationError.LengthError("Too many fields", 0, Limits.FieldLimit, fields.Count))); errors.Add(
(
"fields",
ValidationError.LengthError(
"Too many fields",
0,
Limits.FieldLimit,
fields.Count
)
)
);
// No overwhelming this function, thank you // No overwhelming this function, thank you
if (fields.Count > 100) return errors; if (fields.Count > 100)
return errors;
foreach (var (field, index) in fields.Select((field, index) => (field, index))) foreach (var (field, index) in fields.Select((field, index) => (field, index)))
{ {
switch (field.Name.Length) switch (field.Name.Length)
{ {
case > Limits.FieldNameLimit: case > Limits.FieldNameLimit:
errors.Add(($"fields.{index}.name", errors.Add(
ValidationError.LengthError("Field name is too long", 1, Limits.FieldNameLimit, (
field.Name.Length))); $"fields.{index}.name",
ValidationError.LengthError(
"Field name is too long",
1,
Limits.FieldNameLimit,
field.Name.Length
)
)
);
break; break;
case < 1: case < 1:
errors.Add(($"fields.{index}.name", errors.Add(
ValidationError.LengthError("Field name is too short", 1, Limits.FieldNameLimit, (
field.Name.Length))); $"fields.{index}.name",
ValidationError.LengthError(
"Field name is too short",
1,
Limits.FieldNameLimit,
field.Name.Length
)
)
);
break; break;
} }
errors = errors.Concat(ValidateFieldEntries(field.Entries, customPreferences, $"fields.{index}.entries")) errors = errors
.Concat(
ValidateFieldEntries(
field.Entries,
customPreferences,
$"fields.{index}.entries"
)
)
.ToList(); .ToList();
} }
return errors; return errors;
} }
public static IEnumerable<(string, ValidationError?)> ValidateFieldEntries(FieldEntry[]? entries, public static IEnumerable<(string, ValidationError?)> ValidateFieldEntries(
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences, string errorPrefix = "fields") FieldEntry[]? entries,
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences,
string errorPrefix = "fields"
)
{ {
if (entries == null || entries.Length == 0) return []; if (entries == null || entries.Length == 0)
return [];
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (entries.Length > Limits.FieldEntriesLimit) if (entries.Length > Limits.FieldEntriesLimit)
errors.Add((errorPrefix, errors.Add(
ValidationError.LengthError("Field has too many entries", 0, Limits.FieldEntriesLimit, (
entries.Length))); errorPrefix,
ValidationError.LengthError(
"Field has too many entries",
0,
Limits.FieldEntriesLimit,
entries.Length
)
)
);
// Same as above, no overwhelming this function with a ridiculous amount of entries // Same as above, no overwhelming this function with a ridiculous amount of entries
if (entries.Length > Limits.FieldEntriesLimit + 50) return errors; if (entries.Length > Limits.FieldEntriesLimit + 50)
return errors;
foreach (var (entry, entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx))) foreach (var (entry, entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx)))
{ {
switch (entry.Value.Length) switch (entry.Value.Length)
{ {
case > Limits.FieldEntryTextLimit: case > Limits.FieldEntryTextLimit:
errors.Add(($"{errorPrefix}.{entryIdx}.value", errors.Add(
ValidationError.LengthError("Field value is too long", 1, Limits.FieldEntryTextLimit, (
entry.Value.Length))); $"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Field value is too long",
1,
Limits.FieldEntryTextLimit,
entry.Value.Length
)
)
);
break; break;
case < 1: case < 1:
errors.Add(($"{errorPrefix}.{entryIdx}.value", errors.Add(
ValidationError.LengthError("Field value is too short", 1, Limits.FieldEntryTextLimit, (
entry.Value.Length))); $"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Field value is too short",
1,
Limits.FieldEntryTextLimit,
entry.Value.Length
)
)
);
break; break;
} }
var customPreferenceIds = customPreferences?.Keys.Select(id => id.ToString()) ?? []; var customPreferenceIds = customPreferences?.Keys.Select(id => id.ToString()) ?? [];
if (!DefaultStatusOptions.Contains(entry.Status) && !customPreferenceIds.Contains(entry.Status)) if (
errors.Add(($"{errorPrefix}.{entryIdx}.status", !DefaultStatusOptions.Contains(entry.Status)
ValidationError.GenericValidationError("Invalid status", entry.Status))); && !customPreferenceIds.Contains(entry.Status)
)
errors.Add(
(
$"{errorPrefix}.{entryIdx}.status",
ValidationError.GenericValidationError("Invalid status", entry.Status)
)
);
} }
return errors; return errors;
} }
public static IEnumerable<(string, ValidationError?)> ValidatePronouns(Pronoun[]? entries, public static IEnumerable<(string, ValidationError?)> ValidatePronouns(
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences, string errorPrefix = "pronouns") Pronoun[]? entries,
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences,
string errorPrefix = "pronouns"
)
{ {
if (entries == null || entries.Length == 0) return []; if (entries == null || entries.Length == 0)
return [];
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (entries.Length > Limits.FieldEntriesLimit) if (entries.Length > Limits.FieldEntriesLimit)
errors.Add((errorPrefix, errors.Add(
ValidationError.LengthError("Too many pronouns", 0, Limits.FieldEntriesLimit, (
entries.Length))); errorPrefix,
ValidationError.LengthError(
"Too many pronouns",
0,
Limits.FieldEntriesLimit,
entries.Length
)
)
);
// Same as above, no overwhelming this function with a ridiculous amount of entries // Same as above, no overwhelming this function with a ridiculous amount of entries
if (entries.Length > Limits.FieldEntriesLimit + 50) return errors; if (entries.Length > Limits.FieldEntriesLimit + 50)
return errors;
foreach (var (entry, entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx))) foreach (var (entry, entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx)))
{ {
switch (entry.Value.Length) switch (entry.Value.Length)
{ {
case > Limits.FieldEntryTextLimit: case > Limits.FieldEntryTextLimit:
errors.Add(($"{errorPrefix}.{entryIdx}.value", errors.Add(
ValidationError.LengthError("Pronoun value is too long", 1, Limits.FieldEntryTextLimit, (
entry.Value.Length))); $"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Pronoun value is too long",
1,
Limits.FieldEntryTextLimit,
entry.Value.Length
)
)
);
break; break;
case < 1: case < 1:
errors.Add(($"{errorPrefix}.{entryIdx}.value", errors.Add(
ValidationError.LengthError("Pronoun value is too short", 1, Limits.FieldEntryTextLimit, (
entry.Value.Length))); $"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Pronoun value is too short",
1,
Limits.FieldEntryTextLimit,
entry.Value.Length
)
)
);
break; break;
} }
@ -263,25 +408,46 @@ public static partial class ValidationUtils
switch (entry.DisplayText.Length) switch (entry.DisplayText.Length)
{ {
case > Limits.FieldEntryTextLimit: case > Limits.FieldEntryTextLimit:
errors.Add(($"{errorPrefix}.{entryIdx}.value", errors.Add(
ValidationError.LengthError("Pronoun display text is too long", 1, (
$"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Pronoun display text is too long",
1,
Limits.FieldEntryTextLimit, Limits.FieldEntryTextLimit,
entry.Value.Length))); entry.Value.Length
)
)
);
break; break;
case < 1: case < 1:
errors.Add(($"{errorPrefix}.{entryIdx}.value", errors.Add(
ValidationError.LengthError("Pronoun display text is too short", 1, (
$"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Pronoun display text is too short",
1,
Limits.FieldEntryTextLimit, Limits.FieldEntryTextLimit,
entry.Value.Length))); entry.Value.Length
)
)
);
break; break;
} }
} }
var customPreferenceIds = customPreferences?.Keys.Select(id => id.ToString()) ?? []; var customPreferenceIds = customPreferences?.Keys.Select(id => id.ToString()) ?? [];
if (!DefaultStatusOptions.Contains(entry.Status) && !customPreferenceIds.Contains(entry.Status)) if (
errors.Add(($"{errorPrefix}.{entryIdx}.status", !DefaultStatusOptions.Contains(entry.Status)
ValidationError.GenericValidationError("Invalid status", entry.Status))); && !customPreferenceIds.Contains(entry.Status)
)
errors.Add(
(
$"{errorPrefix}.{entryIdx}.status",
ValidationError.GenericValidationError("Invalid status", entry.Status)
)
);
} }
return errors; return errors;
@ -290,6 +456,10 @@ public static partial class ValidationUtils
[GeneratedRegex(@"^[a-zA-Z_0-9\-\.]{2,40}$", RegexOptions.IgnoreCase, "en-NL")] [GeneratedRegex(@"^[a-zA-Z_0-9\-\.]{2,40}$", RegexOptions.IgnoreCase, "en-NL")]
private static partial Regex UsernameRegex(); private static partial Regex UsernameRegex();
[GeneratedRegex("""^[^@'$%&()+<=>^|~`,*!#/\\\[\]""\{\}\?]{1,100}$""", RegexOptions.IgnoreCase, "en-NL")] [GeneratedRegex(
"""^[^@'$%&()+<=>^|~`,*!#/\\\[\]""\{\}\?]{1,100}$""",
RegexOptions.IgnoreCase,
"en-NL"
)]
private static partial Regex MemberRegex(); private static partial Regex MemberRegex();
} }

View file

@ -19,12 +19,19 @@ public static class Users
var stopwatch = new Stopwatch(); var stopwatch = new Stopwatch();
stopwatch.Start(); stopwatch.Start();
var users = NetImporter.ReadFromFile<ImportUser>(filename).Output.Select(ConvertUser).ToList(); var users = NetImporter
.ReadFromFile<ImportUser>(filename)
.Output.Select(ConvertUser)
.ToList();
db.AddRange(users); db.AddRange(users);
await db.SaveChangesAsync(); await db.SaveChangesAsync();
stopwatch.Stop(); stopwatch.Stop();
Log.Information("Imported {Count} users in {Duration}", users.Count, stopwatch.ElapsedDuration()); Log.Information(
"Imported {Count} users in {Duration}",
users.Count,
stopwatch.ElapsedDuration()
);
} }
private static User ConvertUser(ImportUser oldUser) private static User ConvertUser(ImportUser oldUser)
@ -43,40 +50,46 @@ public static class Users
Role = oldUser.ParseRole(), Role = oldUser.ParseRole(),
Deleted = oldUser.Deleted, Deleted = oldUser.Deleted,
DeletedAt = oldUser.DeletedAt?.ToInstant(), DeletedAt = oldUser.DeletedAt?.ToInstant(),
DeletedBy = null DeletedBy = null,
}; };
if (oldUser is { DiscordId: not null, DiscordUsername: not null }) if (oldUser is { DiscordId: not null, DiscordUsername: not null })
{ {
user.AuthMethods.Add(new AuthMethod user.AuthMethods.Add(
new AuthMethod
{ {
Id = SnowflakeGenerator.Instance.GenerateSnowflake(), Id = SnowflakeGenerator.Instance.GenerateSnowflake(),
AuthType = AuthType.Discord, AuthType = AuthType.Discord,
RemoteId = oldUser.DiscordId, RemoteId = oldUser.DiscordId,
RemoteUsername = oldUser.DiscordUsername RemoteUsername = oldUser.DiscordUsername,
}); }
);
} }
if (oldUser is { TumblrId: not null, TumblrUsername: not null }) if (oldUser is { TumblrId: not null, TumblrUsername: not null })
{ {
user.AuthMethods.Add(new AuthMethod user.AuthMethods.Add(
new AuthMethod
{ {
Id = SnowflakeGenerator.Instance.GenerateSnowflake(), Id = SnowflakeGenerator.Instance.GenerateSnowflake(),
AuthType = AuthType.Tumblr, AuthType = AuthType.Tumblr,
RemoteId = oldUser.TumblrId, RemoteId = oldUser.TumblrId,
RemoteUsername = oldUser.TumblrUsername RemoteUsername = oldUser.TumblrUsername,
}); }
);
} }
if (oldUser is { GoogleId: not null, GoogleUsername: not null }) if (oldUser is { GoogleId: not null, GoogleUsername: not null })
{ {
user.AuthMethods.Add(new AuthMethod user.AuthMethods.Add(
new AuthMethod
{ {
Id = SnowflakeGenerator.Instance.GenerateSnowflake(), Id = SnowflakeGenerator.Instance.GenerateSnowflake(),
AuthType = AuthType.Google, AuthType = AuthType.Google,
RemoteId = oldUser.GoogleId, RemoteId = oldUser.GoogleId,
RemoteUsername = oldUser.GoogleUsername RemoteUsername = oldUser.GoogleUsername,
}); }
);
} }
// Convert all custom preference UUIDs to snowflakes // Convert all custom preference UUIDs to snowflakes
@ -90,28 +103,35 @@ public static class Users
foreach (var name in oldUser.Names ?? []) foreach (var name in oldUser.Names ?? [])
{ {
user.Names.Add(new FieldEntry user.Names.Add(
new FieldEntry
{ {
Value = name.Value, Value = name.Value,
Status = prefMapping.TryGetValue(name.Status, out var newStatus) ? newStatus.ToString() : name.Status, Status = prefMapping.TryGetValue(name.Status, out var newStatus)
}); ? newStatus.ToString()
: name.Status,
}
);
} }
foreach (var pronoun in oldUser.Pronouns ?? []) foreach (var pronoun in oldUser.Pronouns ?? [])
{ {
user.Pronouns.Add(new Pronoun user.Pronouns.Add(
new Pronoun
{ {
Value = pronoun.Value, Value = pronoun.Value,
DisplayText = pronoun.DisplayText, DisplayText = pronoun.DisplayText,
Status = prefMapping.TryGetValue(pronoun.Status, out var newStatus) Status = prefMapping.TryGetValue(pronoun.Status, out var newStatus)
? newStatus.ToString() ? newStatus.ToString()
: pronoun.Status, : pronoun.Status,
}); }
);
} }
foreach (var field in oldUser.Fields ?? []) foreach (var field in oldUser.Fields ?? [])
{ {
var entries = field.Entries.Select(entry => new FieldEntry var entries = field
.Entries.Select(entry => new FieldEntry
{ {
Value = entry.Value, Value = entry.Value,
Status = prefMapping.TryGetValue(entry.Status, out var newStatus) Status = prefMapping.TryGetValue(entry.Status, out var newStatus)
@ -120,11 +140,7 @@ public static class Users
}) })
.ToList(); .ToList();
user.Fields.Add(new Field user.Fields.Add(new Field { Name = field.Name, Entries = entries.ToArray() });
{
Name = field.Name,
Entries = entries.ToArray()
});
} }
Log.Debug("Converted user {UserId}", oldUser.Id); Log.Debug("Converted user {UserId}", oldUser.Id);
@ -161,14 +177,16 @@ public static class Users
bool Deleted, bool Deleted,
OffsetDateTime? DeletedAt, OffsetDateTime? DeletedAt,
string? DeleteReason, string? DeleteReason,
Dictionary<string, User.CustomPreference> CustomPreferences) Dictionary<string, User.CustomPreference> CustomPreferences
)
{ {
public UserRole ParseRole() => Role switch public UserRole ParseRole() =>
Role switch
{ {
"USER" => UserRole.User, "USER" => UserRole.User,
"MODERATOR" => UserRole.Moderator, "MODERATOR" => UserRole.Moderator,
"ADMIN" => UserRole.Admin, "ADMIN" => UserRole.Admin,
_ => UserRole.User _ => UserRole.User,
}; };
} }
} }

View file

@ -19,7 +19,10 @@ internal static class NetImporter
.Enrich.FromLogContext() .Enrich.FromLogContext()
.MinimumLevel.Debug() .MinimumLevel.Debug()
.MinimumLevel.Override("Microsoft", LogEventLevel.Information) .MinimumLevel.Override("Microsoft", LogEventLevel.Information)
.MinimumLevel.Override("Microsoft.EntityFrameworkCore.Database.Command", LogEventLevel.Information) .MinimumLevel.Override(
"Microsoft.EntityFrameworkCore.Database.Command",
LogEventLevel.Information
)
.WriteTo.Console() .WriteTo.Console()
.CreateLogger(); .CreateLogger();
@ -47,16 +50,11 @@ internal static class NetImporter
internal static async Task<DatabaseContext> GetContextAsync() internal static async Task<DatabaseContext> GetContextAsync()
{ {
var connString = Environment.GetEnvironmentVariable("DATABASE"); var connString = Environment.GetEnvironmentVariable("DATABASE");
if (connString == null) throw new Exception("$DATABASE not set, must be an ADO.NET connection string"); if (connString == null)
throw new Exception("$DATABASE not set, must be an ADO.NET connection string");
var loggerFactory = new LoggerFactory().AddSerilog(Log.Logger); var loggerFactory = new LoggerFactory().AddSerilog(Log.Logger);
var config = new Config var config = new Config { Database = new Config.DatabaseConfig { Url = connString } };
{
Database = new Config.DatabaseConfig
{
Url = connString
}
};
var db = new DatabaseContext(config, loggerFactory); var db = new DatabaseContext(config, loggerFactory);
@ -70,13 +68,17 @@ internal static class NetImporter
private static readonly JsonSerializerSettings Settings = new JsonSerializerSettings private static readonly JsonSerializerSettings Settings = new JsonSerializerSettings
{ {
ContractResolver = new DefaultContractResolver { NamingStrategy = new SnakeCaseNamingStrategy() } ContractResolver = new DefaultContractResolver
{
NamingStrategy = new SnakeCaseNamingStrategy(),
},
}.ConfigureForNodaTime(DateTimeZoneProviders.Tzdb); }.ConfigureForNodaTime(DateTimeZoneProviders.Tzdb);
internal static Input<T> ReadFromFile<T>(string path) internal static Input<T> ReadFromFile<T>(string path)
{ {
var data = File.ReadAllText(path); var data = File.ReadAllText(path);
return JsonConvert.DeserializeObject<Input<T>>(data, Settings) ?? throw new Exception("Invalid input file"); return JsonConvert.DeserializeObject<Input<T>>(data, Settings)
?? throw new Exception("Invalid input file");
} }
} }