Compare commits

...

3 commits

71 changed files with 1158 additions and 832 deletions

View file

@ -1,8 +1,52 @@
[*.cs] [*]
# We use PostgresSQL which doesn't recommend more specific string types # We use PostgresSQL which doesn't recommend more specific string types
resharper_entity_framework_model_validation_unlimited_string_length_highlighting = none resharper_entity_framework_model_validation_unlimited_string_length_highlighting = none
# This is raised for every single property of records returned by endpoints # This is raised for every single property of records returned by endpoints
resharper_not_accessed_positional_property_local_highlighting = none resharper_not_accessed_positional_property_local_highlighting = none
# Microsoft .NET properties
csharp_new_line_before_members_in_object_initializers = false
csharp_preferred_modifier_order = public, internal, protected, private, file, new, required, abstract, virtual, sealed, static, override, extern, unsafe, volatile, async, readonly:suggestion
# ReSharper properties
resharper_align_multiline_binary_expressions_chain = false
resharper_arguments_skip_single = true
resharper_blank_lines_after_start_comment = 0
resharper_blank_lines_around_single_line_invocable = 0
resharper_blank_lines_before_block_statements = 0
resharper_braces_for_foreach = required_for_multiline
resharper_braces_for_ifelse = required_for_multiline
resharper_braces_redundant = false
resharper_csharp_blank_lines_around_field = 0
resharper_csharp_empty_block_style = together_same_line
resharper_csharp_max_line_length = 166
resharper_csharp_wrap_after_declaration_lpar = true
resharper_csharp_wrap_before_binary_opsign = true
resharper_csharp_wrap_before_declaration_rpar = true
resharper_csharp_wrap_parameters_style = chop_if_long
resharper_indent_preprocessor_other = do_not_change
resharper_instance_members_qualify_declared_in =
resharper_keep_existing_attribute_arrangement = true
resharper_max_attribute_length_for_same_line = 70
resharper_place_accessorholder_attribute_on_same_line = false
resharper_place_expr_method_on_single_line = if_owner_is_single_line
resharper_place_method_attribute_on_same_line = if_owner_is_single_line
resharper_place_record_field_attribute_on_same_line = true
resharper_place_simple_embedded_statement_on_same_line = false
resharper_place_simple_initializer_on_single_line = false
resharper_place_simple_list_pattern_on_single_line = false
resharper_space_within_empty_braces = false
resharper_trailing_comma_in_multiline_lists = true
resharper_wrap_after_invocation_lpar = false
resharper_wrap_before_invocation_rpar = false
resharper_wrap_before_primary_constructor_declaration_rpar = true
resharper_wrap_chained_binary_patterns = chop_if_long
resharper_wrap_list_pattern = chop_always
resharper_wrap_object_and_collection_initializer_style = chop_always
# Roslynator properties
dotnet_diagnostic.RCS1194.severity = none
[*generated.cs] [*generated.cs]
generated_code = true generated_code = true

View file

@ -7,19 +7,21 @@ public static class BuildInfo
public static async Task ReadBuildInfo() public static async Task ReadBuildInfo()
{ {
await using var stream = typeof(BuildInfo).Assembly.GetManifestResourceStream("version"); await using Stream? stream = typeof(BuildInfo).Assembly.GetManifestResourceStream(
"version"
);
if (stream == null) if (stream == null)
return; return;
using var reader = new StreamReader(stream); using var reader = new StreamReader(stream);
var data = (await reader.ReadToEndAsync()).Trim().Split("\n"); string[] data = (await reader.ReadToEndAsync()).Trim().Split("\n");
if (data.Length < 3) if (data.Length < 3)
return; return;
Hash = data[0]; Hash = data[0];
var dirty = data[2] == "dirty"; bool dirty = data[2] == "dirty";
var versionData = data[1].Split("-"); string[] versionData = data[1].Split("-");
if (versionData.Length < 3) if (versionData.Length < 3)
return; return;
Version = versionData[0]; Version = versionData[0];

View file

@ -33,14 +33,16 @@ public class AuthController(
config.GoogleAuth.Enabled, config.GoogleAuth.Enabled,
config.TumblrAuth.Enabled config.TumblrAuth.Enabled
); );
var state = HttpUtility.UrlEncode(await keyCacheService.GenerateAuthStateAsync(ct)); string state = HttpUtility.UrlEncode(await keyCacheService.GenerateAuthStateAsync(ct));
string? discord = null; string? discord = null;
if (config.DiscordAuth is { ClientId: not null, ClientSecret: not null }) if (config.DiscordAuth is { ClientId: not null, ClientSecret: not null })
{
discord = discord =
$"https://discord.com/oauth2/authorize?response_type=code" "https://discord.com/oauth2/authorize?response_type=code"
+ $"&client_id={config.DiscordAuth.ClientId}&scope=identify" + $"&client_id={config.DiscordAuth.ClientId}&scope=identify"
+ $"&prompt=none&state={state}" + $"&prompt=none&state={state}"
+ $"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}"; + $"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}";
}
return Ok(new UrlsResponse(config.EmailAuth.Enabled, discord, null, null)); return Ok(new UrlsResponse(config.EmailAuth.Enabled, discord, null, null));
} }
@ -86,7 +88,7 @@ public class AuthController(
)] )]
public async Task<IActionResult> GetAuthMethodAsync(Snowflake id) public async Task<IActionResult> GetAuthMethodAsync(Snowflake id)
{ {
var authMethod = await db AuthMethod? authMethod = await db
.AuthMethods.Include(a => a.FediverseApplication) .AuthMethods.Include(a => a.FediverseApplication)
.FirstOrDefaultAsync(a => a.UserId == CurrentUser!.Id && a.Id == id); .FirstOrDefaultAsync(a => a.UserId == CurrentUser!.Id && a.Id == id);
if (authMethod == null) if (authMethod == null)
@ -99,17 +101,19 @@ public class AuthController(
[Authorize("*")] [Authorize("*")]
public async Task<IActionResult> DeleteAuthMethodAsync(Snowflake id) public async Task<IActionResult> DeleteAuthMethodAsync(Snowflake id)
{ {
var authMethods = await db List<AuthMethod> authMethods = await db
.AuthMethods.Where(a => a.UserId == CurrentUser!.Id) .AuthMethods.Where(a => a.UserId == CurrentUser!.Id)
.ToListAsync(); .ToListAsync();
if (authMethods.Count < 2) if (authMethods.Count < 2)
{
throw new ApiError( throw new ApiError(
"You cannot remove your last authentication method.", "You cannot remove your last authentication method.",
HttpStatusCode.BadRequest, HttpStatusCode.BadRequest,
ErrorCode.LastAuthMethod ErrorCode.LastAuthMethod
); );
}
var authMethod = authMethods.FirstOrDefault(a => a.Id == id); AuthMethod? authMethod = authMethods.FirstOrDefault(a => a.Id == id);
if (authMethod == null) if (authMethod == null)
throw new ApiError.NotFound("No authentication method with that ID found."); throw new ApiError.NotFound("No authentication method with that ID found.");
@ -119,6 +123,20 @@ public class AuthController(
CurrentUser!.Id CurrentUser!.Id
); );
// If this is the user's last email, we should also clear the user's password.
if (
authMethod.AuthType == AuthType.Email
&& authMethods.Count(a => a.AuthType == AuthType.Email) == 1
)
{
_logger.Debug(
"Deleted last email address for user {UserId}, resetting their password",
CurrentUser.Id
);
CurrentUser.Password = null;
db.Update(CurrentUser);
}
db.Remove(authMethod); db.Remove(authMethod);
await db.SaveChangesAsync(); await db.SaveChangesAsync();

View file

@ -34,8 +34,10 @@ public class DiscordAuthController(
CheckRequirements(); CheckRequirements();
await keyCacheService.ValidateAuthStateAsync(req.State); await keyCacheService.ValidateAuthStateAsync(req.State);
var remoteUser = await remoteAuthService.RequestDiscordTokenAsync(req.Code); RemoteAuthService.RemoteUser remoteUser = await remoteAuthService.RequestDiscordTokenAsync(
var user = await authService.AuthenticateUserAsync(AuthType.Discord, remoteUser.Id); req.Code
);
User? user = await authService.AuthenticateUserAsync(AuthType.Discord, remoteUser.Id);
if (user != null) if (user != null)
return Ok(await authService.GenerateUserTokenAsync(user)); return Ok(await authService.GenerateUserTokenAsync(user));
@ -45,23 +47,14 @@ public class DiscordAuthController(
remoteUser.Id remoteUser.Id
); );
var ticket = AuthUtils.RandomToken(); string ticket = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync( await keyCacheService.SetKeyAsync(
$"discord:{ticket}", $"discord:{ticket}",
remoteUser, remoteUser,
Duration.FromMinutes(20) Duration.FromMinutes(20)
); );
return Ok( return Ok(new CallbackResponse(false, ticket, remoteUser.Username, null, null, null));
new CallbackResponse(
HasAccount: false,
Ticket: ticket,
RemoteUsername: remoteUser.Username,
User: null,
Token: null,
ExpiresAt: null
)
);
} }
[HttpPost("register")] [HttpPost("register")]
@ -70,7 +63,8 @@ public class DiscordAuthController(
[FromBody] AuthController.OauthRegisterRequest req [FromBody] AuthController.OauthRegisterRequest req
) )
{ {
var remoteUser = await keyCacheService.GetKeyAsync<RemoteAuthService.RemoteUser>( RemoteAuthService.RemoteUser? remoteUser =
await keyCacheService.GetKeyAsync<RemoteAuthService.RemoteUser>(
$"discord:{req.Ticket}" $"discord:{req.Ticket}"
); );
if (remoteUser == null) if (remoteUser == null)
@ -88,7 +82,7 @@ public class DiscordAuthController(
throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket); throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket);
} }
var user = await authService.CreateUserWithRemoteAuthAsync( User user = await authService.CreateUserWithRemoteAuthAsync(
req.Username, req.Username,
AuthType.Discord, AuthType.Discord,
remoteUser.Id, remoteUser.Id,
@ -104,13 +98,13 @@ public class DiscordAuthController(
{ {
CheckRequirements(); CheckRequirements();
var state = await remoteAuthService.ValidateAddAccountRequestAsync( string state = await remoteAuthService.ValidateAddAccountRequestAsync(
CurrentUser!.Id, CurrentUser!.Id,
AuthType.Discord AuthType.Discord
); );
var url = string url =
$"https://discord.com/oauth2/authorize?response_type=code" "https://discord.com/oauth2/authorize?response_type=code"
+ $"&client_id={config.DiscordAuth.ClientId}&scope=identify" + $"&client_id={config.DiscordAuth.ClientId}&scope=identify"
+ $"&prompt=none&state={state}" + $"&prompt=none&state={state}"
+ $"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}"; + $"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}";
@ -132,10 +126,12 @@ public class DiscordAuthController(
AuthType.Discord AuthType.Discord
); );
var remoteUser = await remoteAuthService.RequestDiscordTokenAsync(req.Code); RemoteAuthService.RemoteUser remoteUser = await remoteAuthService.RequestDiscordTokenAsync(
req.Code
);
try try
{ {
var authMethod = await authService.AddAuthMethodAsync( AuthMethod authMethod = await authService.AddAuthMethodAsync(
CurrentUser.Id, CurrentUser.Id,
AuthType.Discord, AuthType.Discord,
remoteUser.Id, remoteUser.Id,
@ -169,8 +165,10 @@ public class DiscordAuthController(
private void CheckRequirements() private void CheckRequirements()
{ {
if (!config.DiscordAuth.Enabled) if (!config.DiscordAuth.Enabled)
{
throw new ApiError.BadRequest( throw new ApiError.BadRequest(
"Discord authentication is not enabled on this instance." "Discord authentication is not enabled on this instance."
); );
} }
} }
}

View file

@ -1,3 +1,5 @@
using System.Net;
using EntityFramework.Exceptions.Common;
using Foxnouns.Backend.Database; using Foxnouns.Backend.Database;
using Foxnouns.Backend.Database.Models; using Foxnouns.Backend.Database.Models;
using Foxnouns.Backend.Extensions; using Foxnouns.Backend.Extensions;
@ -26,8 +28,8 @@ public class EmailAuthController(
{ {
private readonly ILogger _logger = logger.ForContext<EmailAuthController>(); private readonly ILogger _logger = logger.ForContext<EmailAuthController>();
[HttpPost("register")] [HttpPost("register/init")]
public async Task<IActionResult> RegisterAsync( public async Task<IActionResult> RegisterInitAsync(
[FromBody] RegisterRequest req, [FromBody] RegisterRequest req,
CancellationToken ct = default CancellationToken ct = default
) )
@ -37,11 +39,7 @@ public class EmailAuthController(
if (!req.Email.Contains('@')) if (!req.Email.Contains('@'))
throw new ApiError.BadRequest("Email is invalid", "email", req.Email); throw new ApiError.BadRequest("Email is invalid", "email", req.Email);
var state = await keyCacheService.GenerateRegisterEmailStateAsync( string state = await keyCacheService.GenerateRegisterEmailStateAsync(req.Email, null, ct);
req.Email,
userId: null,
ct
);
// If there's already a user with that email address, pretend we sent an email but actually ignore it // If there's already a user with that email address, pretend we sent an email but actually ignore it
if ( if (
@ -50,7 +48,9 @@ public class EmailAuthController(
ct ct
) )
) )
{
return NoContent(); return NoContent();
}
mailService.QueueAccountCreationEmail(req.Email, state); mailService.QueueAccountCreationEmail(req.Email, state);
return NoContent(); return NoContent();
@ -61,62 +61,35 @@ public class EmailAuthController(
{ {
CheckRequirements(); CheckRequirements();
var state = await keyCacheService.GetRegisterEmailStateAsync(req.State); RegisterEmailState? state = await keyCacheService.GetRegisterEmailStateAsync(req.State);
if (state == null) if (state is not { ExistingUserId: null })
throw new ApiError.BadRequest("Invalid state", "state", req.State); throw new ApiError.BadRequest("Invalid state", "state", req.State);
// If this callback is for an existing user, add the email address to their auth methods string ticket = AuthUtils.RandomToken();
if (state.ExistingUserId != null)
{
var authMethod = await authService.AddAuthMethodAsync(
state.ExistingUserId.Value,
AuthType.Email,
state.Email
);
_logger.Debug(
"Added email auth {AuthId} for user {UserId}",
authMethod.Id,
state.ExistingUserId
);
return NoContent();
}
var ticket = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync($"email:{ticket}", state.Email, Duration.FromMinutes(20)); await keyCacheService.SetKeyAsync($"email:{ticket}", state.Email, Duration.FromMinutes(20));
return Ok( return Ok(new CallbackResponse(false, ticket, state.Email, null, null, null));
new CallbackResponse(
HasAccount: false,
Ticket: ticket,
RemoteUsername: state.Email,
User: null,
Token: null,
ExpiresAt: null
)
);
} }
[HttpPost("complete-registration")] [HttpPost("register")]
public async Task<IActionResult> CompleteRegistrationAsync( public async Task<IActionResult> CompleteRegistrationAsync(
[FromBody] CompleteRegistrationRequest req [FromBody] CompleteRegistrationRequest req
) )
{ {
CheckRequirements(); CheckRequirements();
var email = await keyCacheService.GetKeyAsync($"email:{req.Ticket}"); string? email = await keyCacheService.GetKeyAsync($"email:{req.Ticket}");
if (email == null) if (email == null)
throw new ApiError.BadRequest("Unknown ticket", "ticket", req.Ticket); throw new ApiError.BadRequest("Unknown ticket", "ticket", req.Ticket);
// Check if username is valid at all User user = await authService.CreateUserWithPasswordAsync(
ValidationUtils.Validate([("username", ValidationUtils.ValidateUsername(req.Username))]); req.Username,
// Check if username is already taken email,
if (await db.Users.AnyAsync(u => u.Username == req.Username)) req.Password
throw new ApiError.BadRequest("Username is already taken", "username", req.Username); );
Application frontendApp = await db.GetFrontendApplicationAsync();
var user = await authService.CreateUserWithPasswordAsync(req.Username, email, req.Password); (string? tokenStr, Token? token) = authService.GenerateToken(
var frontendApp = await db.GetFrontendApplicationAsync();
var (tokenStr, token) = authService.GenerateToken(
user, user,
frontendApp, frontendApp,
["*"], ["*"],
@ -130,7 +103,7 @@ public class EmailAuthController(
return Ok( return Ok(
new AuthController.AuthResponse( new AuthController.AuthResponse(
await userRenderer.RenderUserAsync(user, selfUser: user, renderMembers: false), await userRenderer.RenderUserAsync(user, user, renderMembers: false),
tokenStr, tokenStr,
token.ExpiresAt token.ExpiresAt
) )
@ -146,19 +119,16 @@ public class EmailAuthController(
{ {
CheckRequirements(); CheckRequirements();
var (user, authenticationResult) = await authService.AuthenticateUserAsync( (User? user, AuthService.EmailAuthenticationResult authenticationResult) =
req.Email, await authService.AuthenticateUserAsync(req.Email, req.Password, ct);
req.Password,
ct
);
if (authenticationResult == AuthService.EmailAuthenticationResult.MfaRequired) if (authenticationResult == AuthService.EmailAuthenticationResult.MfaRequired)
throw new NotImplementedException("MFA is not implemented yet"); throw new NotImplementedException("MFA is not implemented yet");
var frontendApp = await db.GetFrontendApplicationAsync(ct); Application frontendApp = await db.GetFrontendApplicationAsync(ct);
_logger.Debug("Logging user {Id} in with email and password", user.Id); _logger.Debug("Logging user {Id} in with email and password", user.Id);
var (tokenStr, token) = authService.GenerateToken( (string? tokenStr, Token? token) = authService.GenerateToken(
user, user,
frontendApp, frontendApp,
["*"], ["*"],
@ -172,25 +142,34 @@ public class EmailAuthController(
return Ok( return Ok(
new AuthController.AuthResponse( new AuthController.AuthResponse(
await userRenderer.RenderUserAsync( await userRenderer.RenderUserAsync(user, user, renderMembers: false, ct: ct),
user,
selfUser: user,
renderMembers: false,
ct: ct
),
tokenStr, tokenStr,
token.ExpiresAt token.ExpiresAt
) )
); );
} }
[HttpPost("add")] [HttpPost("change-password")]
[Authorize("*")]
public async Task<IActionResult> UpdatePasswordAsync([FromBody] ChangePasswordRequest req)
{
if (!await authService.ValidatePasswordAsync(CurrentUser!, req.Current))
throw new ApiError.Forbidden("Invalid password");
ValidationUtils.Validate([("new", ValidationUtils.ValidatePassword(req.New))]);
await authService.SetUserPasswordAsync(CurrentUser!, req.New);
await db.SaveChangesAsync();
return NoContent();
}
[HttpPost("add-email")]
[Authorize("*")] [Authorize("*")]
public async Task<IActionResult> AddEmailAddressAsync([FromBody] AddEmailAddressRequest req) public async Task<IActionResult> AddEmailAddressAsync([FromBody] AddEmailAddressRequest req)
{ {
CheckRequirements(); CheckRequirements();
var emails = await db List<AuthMethod> emails = await db
.AuthMethods.Where(m => m.UserId == CurrentUser!.Id && m.AuthType == AuthType.Email) .AuthMethods.Where(m => m.UserId == CurrentUser!.Id && m.AuthType == AuthType.Email)
.ToListAsync(); .ToListAsync();
if (emails.Count > AuthUtils.MaxAuthMethodsPerType) if (emails.Count > AuthUtils.MaxAuthMethodsPerType)
@ -204,24 +183,21 @@ public class EmailAuthController(
if (emails.Count != 0) if (emails.Count != 0)
{ {
var validPassword = await authService.ValidatePasswordAsync(CurrentUser!, req.Password); if (!await authService.ValidatePasswordAsync(CurrentUser!, req.Password))
if (!validPassword)
{
throw new ApiError.Forbidden("Invalid password"); throw new ApiError.Forbidden("Invalid password");
} }
}
else else
{ {
await authService.SetUserPasswordAsync(CurrentUser!, req.Password); await authService.SetUserPasswordAsync(CurrentUser!, req.Password);
await db.SaveChangesAsync(); await db.SaveChangesAsync();
} }
var state = await keyCacheService.GenerateRegisterEmailStateAsync( string state = await keyCacheService.GenerateRegisterEmailStateAsync(
req.Email, req.Email,
userId: CurrentUser!.Id CurrentUser!.Id
); );
var emailExists = await db bool emailExists = await db
.AuthMethods.Where(m => m.AuthType == AuthType.Email && m.RemoteId == req.Email) .AuthMethods.Where(m => m.AuthType == AuthType.Email && m.RemoteId == req.Email)
.AnyAsync(); .AnyAsync();
if (emailExists) if (emailExists)
@ -233,6 +209,48 @@ public class EmailAuthController(
return NoContent(); return NoContent();
} }
[HttpPost("add-email/callback")]
[Authorize("*")]
public async Task<IActionResult> AddEmailCallbackAsync([FromBody] CallbackRequest req)
{
CheckRequirements();
RegisterEmailState? state = await keyCacheService.GetRegisterEmailStateAsync(req.State);
if (state?.ExistingUserId != CurrentUser!.Id)
throw new ApiError.BadRequest("Invalid state", "state", req.State);
try
{
AuthMethod authMethod = await authService.AddAuthMethodAsync(
CurrentUser.Id,
AuthType.Email,
state.Email
);
_logger.Debug(
"Added email auth {AuthId} for user {UserId}",
authMethod.Id,
CurrentUser.Id
);
return Ok(
new AuthController.AddOauthAccountResponse(
authMethod.Id,
AuthType.Email,
authMethod.RemoteId,
null
)
);
}
catch (UniqueConstraintException)
{
throw new ApiError(
"That email address is already linked.",
HttpStatusCode.BadRequest,
ErrorCode.AccountAlreadyLinked
);
}
}
public record AddEmailAddressRequest(string Email, string Password); public record AddEmailAddressRequest(string Email, string Password);
private void CheckRequirements() private void CheckRequirements()
@ -248,4 +266,6 @@ public class EmailAuthController(
public record CompleteRegistrationRequest(string Ticket, string Username, string Password); public record CompleteRegistrationRequest(string Ticket, string Username, string Password);
public record CallbackRequest(string State); public record CallbackRequest(string State);
public record ChangePasswordRequest(string Current, string New);
} }

View file

@ -34,7 +34,7 @@ public class FediverseAuthController(
if (instance.Any(c => c is '@' or ':' or '/') || !instance.Contains('.')) if (instance.Any(c => c is '@' or ':' or '/') || !instance.Contains('.'))
throw new ApiError.BadRequest("Not a valid domain.", "instance", instance); throw new ApiError.BadRequest("Not a valid domain.", "instance", instance);
var url = await fediverseAuthService.GenerateAuthUrlAsync(instance, forceRefresh); string url = await fediverseAuthService.GenerateAuthUrlAsync(instance, forceRefresh);
return Ok(new AuthController.SingleUrlResponse(url)); return Ok(new AuthController.SingleUrlResponse(url));
} }
@ -42,22 +42,19 @@ public class FediverseAuthController(
[ProducesResponseType<CallbackResponse>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<CallbackResponse>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> FediverseCallbackAsync([FromBody] CallbackRequest req) public async Task<IActionResult> FediverseCallbackAsync([FromBody] CallbackRequest req)
{ {
var app = await fediverseAuthService.GetApplicationAsync(req.Instance); FediverseApplication app = await fediverseAuthService.GetApplicationAsync(req.Instance);
var remoteUser = await fediverseAuthService.GetRemoteFediverseUserAsync( FediverseAuthService.FediverseUser remoteUser =
app, await fediverseAuthService.GetRemoteFediverseUserAsync(app, req.Code, req.State);
req.Code,
req.State
);
var user = await authService.AuthenticateUserAsync( User? user = await authService.AuthenticateUserAsync(
AuthType.Fediverse, AuthType.Fediverse,
remoteUser.Id, remoteUser.Id,
instance: app app
); );
if (user != null) if (user != null)
return Ok(await authService.GenerateUserTokenAsync(user)); return Ok(await authService.GenerateUserTokenAsync(user));
var ticket = AuthUtils.RandomToken(); string ticket = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync( await keyCacheService.SetKeyAsync(
$"fediverse:{ticket}", $"fediverse:{ticket}",
new FediverseTicketData(app.Id, remoteUser), new FediverseTicketData(app.Id, remoteUser),
@ -66,12 +63,12 @@ public class FediverseAuthController(
return Ok( return Ok(
new CallbackResponse( new CallbackResponse(
HasAccount: false, false,
Ticket: ticket, ticket,
RemoteUsername: $"@{remoteUser.Username}@{app.Domain}", $"@{remoteUser.Username}@{app.Domain}",
User: null, null,
Token: null, null,
ExpiresAt: null null
) )
); );
} }
@ -82,14 +79,16 @@ public class FediverseAuthController(
[FromBody] AuthController.OauthRegisterRequest req [FromBody] AuthController.OauthRegisterRequest req
) )
{ {
var ticketData = await keyCacheService.GetKeyAsync<FediverseTicketData>( FediverseTicketData? ticketData = await keyCacheService.GetKeyAsync<FediverseTicketData>(
$"fediverse:{req.Ticket}", $"fediverse:{req.Ticket}",
delete: true true
); );
if (ticketData == null) if (ticketData == null)
throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket); throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket);
var app = await db.FediverseApplications.FindAsync(ticketData.ApplicationId); FediverseApplication? app = await db.FediverseApplications.FindAsync(
ticketData.ApplicationId
);
if (app == null) if (app == null)
throw new FoxnounsError("Null application found for ticket"); throw new FoxnounsError("Null application found for ticket");
@ -111,12 +110,12 @@ public class FediverseAuthController(
throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket); throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket);
} }
var user = await authService.CreateUserWithRemoteAuthAsync( User user = await authService.CreateUserWithRemoteAuthAsync(
req.Username, req.Username,
AuthType.Fediverse, AuthType.Fediverse,
ticketData.User.Id, ticketData.User.Id,
ticketData.User.Username, ticketData.User.Username,
instance: app app
); );
return Ok(await authService.GenerateUserTokenAsync(user)); return Ok(await authService.GenerateUserTokenAsync(user));
@ -132,13 +131,13 @@ public class FediverseAuthController(
if (instance.Any(c => c is '@' or ':' or '/') || !instance.Contains('.')) if (instance.Any(c => c is '@' or ':' or '/') || !instance.Contains('.'))
throw new ApiError.BadRequest("Not a valid domain.", "instance", instance); throw new ApiError.BadRequest("Not a valid domain.", "instance", instance);
var state = await remoteAuthService.ValidateAddAccountRequestAsync( string state = await remoteAuthService.ValidateAddAccountRequestAsync(
CurrentUser!.Id, CurrentUser!.Id,
AuthType.Fediverse, AuthType.Fediverse,
instance instance
); );
var url = await fediverseAuthService.GenerateAuthUrlAsync(instance, forceRefresh, state); string url = await fediverseAuthService.GenerateAuthUrlAsync(instance, forceRefresh, state);
return Ok(new AuthController.SingleUrlResponse(url)); return Ok(new AuthController.SingleUrlResponse(url));
} }
@ -153,11 +152,12 @@ public class FediverseAuthController(
req.Instance req.Instance
); );
var app = await fediverseAuthService.GetApplicationAsync(req.Instance); FediverseApplication app = await fediverseAuthService.GetApplicationAsync(req.Instance);
var remoteUser = await fediverseAuthService.GetRemoteFediverseUserAsync(app, req.Code); FediverseAuthService.FediverseUser remoteUser =
await fediverseAuthService.GetRemoteFediverseUserAsync(app, req.Code);
try try
{ {
var authMethod = await authService.AddAuthMethodAsync( AuthMethod authMethod = await authService.AddAuthMethodAsync(
CurrentUser.Id, CurrentUser.Id,
AuthType.Fediverse, AuthType.Fediverse,
remoteUser.Id, remoteUser.Id,

View file

@ -25,7 +25,7 @@ public class ExportsController(
[HttpGet] [HttpGet]
public async Task<IActionResult> GetDataExportsAsync() public async Task<IActionResult> GetDataExportsAsync()
{ {
var export = await db DataExport? export = await db
.DataExports.Where(d => d.UserId == CurrentUser!.Id) .DataExports.Where(d => d.UserId == CurrentUser!.Id)
.OrderByDescending(d => d.Id) .OrderByDescending(d => d.Id)
.FirstOrDefaultAsync(); .FirstOrDefaultAsync();

View file

@ -1,5 +1,6 @@
using Coravel.Queuing.Interfaces; using Coravel.Queuing.Interfaces;
using Foxnouns.Backend.Database; using Foxnouns.Backend.Database;
using Foxnouns.Backend.Database.Models;
using Foxnouns.Backend.Extensions; using Foxnouns.Backend.Extensions;
using Foxnouns.Backend.Jobs; using Foxnouns.Backend.Jobs;
using Foxnouns.Backend.Middleware; using Foxnouns.Backend.Middleware;
@ -7,6 +8,7 @@ using Foxnouns.Backend.Services;
using Foxnouns.Backend.Utils; using Foxnouns.Backend.Utils;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
namespace Foxnouns.Backend.Controllers; namespace Foxnouns.Backend.Controllers;
@ -29,7 +31,9 @@ public class FlagsController(
)] )]
public async Task<IActionResult> GetFlagsAsync(CancellationToken ct = default) public async Task<IActionResult> GetFlagsAsync(CancellationToken ct = default)
{ {
var flags = await db.PrideFlags.Where(f => f.UserId == CurrentUser!.Id).ToListAsync(ct); List<PrideFlag> flags = await db
.PrideFlags.Where(f => f.UserId == CurrentUser!.Id)
.ToListAsync(ct);
return Ok(flags.Select(userRenderer.RenderPrideFlag)); return Ok(flags.Select(userRenderer.RenderPrideFlag));
} }
@ -43,7 +47,7 @@ public class FlagsController(
{ {
ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, req.Image)); ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, req.Image));
var id = snowflakeGenerator.GenerateSnowflake(); Snowflake id = snowflakeGenerator.GenerateSnowflake();
queue.QueueInvocableWithPayload<CreateFlagInvocable, CreateFlagPayload>( queue.QueueInvocableWithPayload<CreateFlagInvocable, CreateFlagPayload>(
new CreateFlagPayload(id, CurrentUser!.Id, req.Name, req.Image, req.Description) new CreateFlagPayload(id, CurrentUser!.Id, req.Name, req.Image, req.Description)
@ -62,7 +66,7 @@ public class FlagsController(
{ {
ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, null)); ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, null));
var flag = await db.PrideFlags.FirstOrDefaultAsync(f => PrideFlag? flag = await db.PrideFlags.FirstOrDefaultAsync(f =>
f.Id == id && f.UserId == CurrentUser!.Id f.Id == id && f.UserId == CurrentUser!.Id
); );
if (flag == null) if (flag == null)
@ -90,20 +94,20 @@ public class FlagsController(
[Authorize("user.update")] [Authorize("user.update")]
public async Task<IActionResult> DeleteFlagAsync(Snowflake id) public async Task<IActionResult> DeleteFlagAsync(Snowflake id)
{ {
await using var tx = await db.Database.BeginTransactionAsync(); await using IDbContextTransaction tx = await db.Database.BeginTransactionAsync();
var flag = await db.PrideFlags.FirstOrDefaultAsync(f => PrideFlag? flag = await db.PrideFlags.FirstOrDefaultAsync(f =>
f.Id == id && f.UserId == CurrentUser!.Id f.Id == id && f.UserId == CurrentUser!.Id
); );
if (flag == null) if (flag == null)
throw new ApiError.NotFound("Unknown flag ID, or it's not your flag."); throw new ApiError.NotFound("Unknown flag ID, or it's not your flag.");
var hash = flag.Hash; string hash = flag.Hash;
db.PrideFlags.Remove(flag); db.PrideFlags.Remove(flag);
await db.SaveChangesAsync(); await db.SaveChangesAsync();
var flagCount = await db.PrideFlags.CountAsync(f => f.Hash == flag.Hash); int flagCount = await db.PrideFlags.CountAsync(f => f.Hash == flag.Hash);
if (flagCount == 0) if (flagCount == 0)
{ {
try try
@ -120,7 +124,9 @@ public class FlagsController(
} }
} }
else else
{
_logger.Debug("Flag file {Hash} is used by other flags, not deleting", hash); _logger.Debug("Flag file {Hash} is used by other flags, not deleting", hash);
}
await tx.CommitAsync(); await tx.CommitAsync();

View file

@ -44,21 +44,22 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
[HttpPost("request-data")] [HttpPost("request-data")]
public async Task<IActionResult> GetRequestDataAsync([FromBody] RequestDataRequest req) public async Task<IActionResult> GetRequestDataAsync([FromBody] RequestDataRequest req)
{ {
var endpoint = GetEndpoint(HttpContext, req.Path, req.Method); RouteEndpoint? endpoint = GetEndpoint(HttpContext, req.Path, req.Method);
if (endpoint == null) if (endpoint == null)
throw new ApiError.BadRequest("Path/method combination is invalid"); throw new ApiError.BadRequest("Path/method combination is invalid");
var actionDescriptor = endpoint.Metadata.GetMetadata<ControllerActionDescriptor>(); ControllerActionDescriptor? actionDescriptor =
var template = actionDescriptor?.AttributeRouteInfo?.Template; endpoint.Metadata.GetMetadata<ControllerActionDescriptor>();
string? template = actionDescriptor?.AttributeRouteInfo?.Template;
if (template == null) if (template == null)
throw new FoxnounsError("Template value was null on valid endpoint"); throw new FoxnounsError("Template value was null on valid endpoint");
template = GetCleanedTemplate(template); template = GetCleanedTemplate(template);
// If no token was supplied, or it isn't valid base 64, return a null user ID (limiting by IP) // If no token was supplied, or it isn't valid base 64, return a null user ID (limiting by IP)
if (!AuthUtils.TryParseToken(req.Token, out var rawToken)) if (!AuthUtils.TryParseToken(req.Token, out byte[]? rawToken))
return Ok(new RequestDataResponse(null, template)); return Ok(new RequestDataResponse(null, template));
var userId = await db.GetTokenUserId(rawToken); Snowflake? userId = await db.GetTokenUserId(rawToken);
return Ok(new RequestDataResponse(userId, template)); return Ok(new RequestDataResponse(userId, template));
} }
@ -72,12 +73,13 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
string requestMethod string requestMethod
) )
{ {
var endpointDataSource = httpContext.RequestServices.GetService<EndpointDataSource>(); EndpointDataSource? endpointDataSource =
httpContext.RequestServices.GetService<EndpointDataSource>();
if (endpointDataSource == null) if (endpointDataSource == null)
return null; return null;
var endpoints = endpointDataSource.Endpoints.OfType<RouteEndpoint>(); IEnumerable<RouteEndpoint> endpoints = endpointDataSource.Endpoints.OfType<RouteEndpoint>();
foreach (var endpoint in endpoints) foreach (RouteEndpoint? endpoint in endpoints)
{ {
if (endpoint.RoutePattern.RawText == null) if (endpoint.RoutePattern.RawText == null)
continue; continue;
@ -86,16 +88,19 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
TemplateParser.Parse(endpoint.RoutePattern.RawText), TemplateParser.Parse(endpoint.RoutePattern.RawText),
new RouteValueDictionary() new RouteValueDictionary()
); );
if (!templateMatcher.TryMatch(url, new())) if (!templateMatcher.TryMatch(url, new RouteValueDictionary()))
continue; continue;
var httpMethodAttribute = endpoint.Metadata.GetMetadata<HttpMethodAttribute>(); HttpMethodAttribute? httpMethodAttribute =
endpoint.Metadata.GetMetadata<HttpMethodAttribute>();
if ( if (
httpMethodAttribute != null httpMethodAttribute?.HttpMethods.Any(x =>
&& !httpMethodAttribute.HttpMethods.Any(x =>
x.Equals(requestMethod, StringComparison.OrdinalIgnoreCase) x.Equals(requestMethod, StringComparison.OrdinalIgnoreCase)
) == false
) )
) {
continue; continue;
}
return endpoint; return endpoint;
} }

View file

@ -9,6 +9,7 @@ using Foxnouns.Backend.Services;
using Foxnouns.Backend.Utils; using Foxnouns.Backend.Utils;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using NodaTime; using NodaTime;
namespace Foxnouns.Backend.Controllers; namespace Foxnouns.Backend.Controllers;
@ -32,7 +33,7 @@ public class MembersController(
)] )]
public async Task<IActionResult> GetMembersAsync(string userRef, CancellationToken ct = default) public async Task<IActionResult> GetMembersAsync(string userRef, CancellationToken ct = default)
{ {
var user = await db.ResolveUserAsync(userRef, CurrentToken, ct); User user = await db.ResolveUserAsync(userRef, CurrentToken, ct);
return Ok(await memberRenderer.RenderUserMembersAsync(user, CurrentToken)); return Ok(await memberRenderer.RenderUserMembersAsync(user, CurrentToken));
} }
@ -44,7 +45,7 @@ public class MembersController(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var member = await db.ResolveMemberAsync(userRef, memberRef, CurrentToken, ct); Member member = await db.ResolveMemberAsync(userRef, memberRef, CurrentToken, ct);
return Ok(memberRenderer.RenderMember(member, CurrentToken)); return Ok(memberRenderer.RenderMember(member, CurrentToken));
} }
@ -78,7 +79,7 @@ public class MembersController(
] ]
); );
var memberCount = await db.Members.CountAsync(m => m.UserId == CurrentUser.Id, ct); int memberCount = await db.Members.CountAsync(m => m.UserId == CurrentUser.Id, ct);
if (memberCount >= MaxMemberCount) if (memberCount >= MaxMemberCount)
throw new ApiError.BadRequest("Maximum number of members reached"); throw new ApiError.BadRequest("Maximum number of members reached");
@ -120,9 +121,11 @@ public class MembersController(
} }
if (req.Avatar != null) if (req.Avatar != null)
{
queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>( queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(member.Id, req.Avatar) new AvatarUpdatePayload(member.Id, req.Avatar)
); );
}
return Ok(memberRenderer.RenderMember(member, CurrentToken)); return Ok(memberRenderer.RenderMember(member, CurrentToken));
} }
@ -134,8 +137,8 @@ public class MembersController(
[FromBody] UpdateMemberRequest req [FromBody] UpdateMemberRequest req
) )
{ {
await using var tx = await db.Database.BeginTransactionAsync(); await using IDbContextTransaction tx = await db.Database.BeginTransactionAsync();
var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef); Member member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
// We might add extra validations for names later down the line. // We might add extra validations for names later down the line.
@ -197,7 +200,11 @@ public class MembersController(
if (req.Flags != null) if (req.Flags != null)
{ {
var flagError = await db.SetMemberFlagsAsync(CurrentUser!.Id, member.Id, req.Flags); ValidationError? flagError = await db.SetMemberFlagsAsync(
CurrentUser!.Id,
member.Id,
req.Flags
);
if (flagError != null) if (flagError != null)
errors.Add(("flags", flagError)); errors.Add(("flags", flagError));
} }
@ -210,9 +217,12 @@ public class MembersController(
// (atomic operations are hard when combined with background jobs) // (atomic operations are hard when combined with background jobs)
// so it's in a separate block to the validation above. // so it's in a separate block to the validation above.
if (req.HasProperty(nameof(req.Avatar))) if (req.HasProperty(nameof(req.Avatar)))
{
queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>( queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(member.Id, req.Avatar) new AvatarUpdatePayload(member.Id, req.Avatar)
); );
}
try try
{ {
await db.SaveChangesAsync(); await db.SaveChangesAsync();
@ -228,7 +238,7 @@ public class MembersController(
throw new ApiError.BadRequest( throw new ApiError.BadRequest(
"A member with that name already exists", "A member with that name already exists",
"name", "name",
req.Name! req.Name
); );
} }
@ -254,8 +264,8 @@ public class MembersController(
[Authorize("member.update")] [Authorize("member.update")]
public async Task<IActionResult> DeleteMemberAsync(string memberRef) public async Task<IActionResult> DeleteMemberAsync(string memberRef)
{ {
var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef); Member member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);
var deleteCount = await db int deleteCount = await db
.Members.Where(m => m.UserId == CurrentUser!.Id && m.Id == member.Id) .Members.Where(m => m.UserId == CurrentUser!.Id && m.Id == member.Id)
.ExecuteDeleteAsync(); .ExecuteDeleteAsync();
if (deleteCount == 0) if (deleteCount == 0)
@ -289,9 +299,9 @@ public class MembersController(
[ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> RerollSidAsync(string memberRef) public async Task<IActionResult> RerollSidAsync(string memberRef)
{ {
var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef); Member member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);
var minTimeAgo = clock.GetCurrentInstant() - Duration.FromHours(1); Instant minTimeAgo = clock.GetCurrentInstant() - Duration.FromHours(1);
if (CurrentUser!.LastSidReroll > minTimeAgo) if (CurrentUser!.LastSidReroll > minTimeAgo)
throw new ApiError.BadRequest("Cannot reroll short ID yet"); throw new ApiError.BadRequest("Cannot reroll short ID yet");
@ -308,7 +318,10 @@ public class MembersController(
); );
// Fetch the new sid then pass that to RenderMember // Fetch the new sid then pass that to RenderMember
var newSid = await db.Members.Where(m => m.Id == member.Id).Select(m => m.Sid).FirstAsync(); string newSid = await db
.Members.Where(m => m.Id == member.Id)
.Select(m => m.Sid)
.FirstAsync();
return Ok(memberRenderer.RenderMember(member, CurrentToken, newSid)); return Ok(memberRenderer.RenderMember(member, CurrentToken, newSid));
} }
} }

View file

@ -10,9 +10,8 @@ public class MetaController : ApiControllerBase
[HttpGet] [HttpGet]
[ProducesResponseType<MetaResponse>(StatusCodes.Status200OK)] [ProducesResponseType<MetaResponse>(StatusCodes.Status200OK)]
public IActionResult GetMeta() public IActionResult GetMeta() =>
{ Ok(
return Ok(
new MetaResponse( new MetaResponse(
Repository, Repository,
BuildInfo.Version, BuildInfo.Version,
@ -25,14 +24,13 @@ public class MetaController : ApiControllerBase
(int)FoxnounsMetrics.UsersActiveDayCount.Value (int)FoxnounsMetrics.UsersActiveDayCount.Value
), ),
new Limits( new Limits(
MemberCount: MembersController.MaxMemberCount, MembersController.MaxMemberCount,
BioLength: ValidationUtils.MaxBioLength, ValidationUtils.MaxBioLength,
CustomPreferences: ValidationUtils.MaxCustomPreferences, ValidationUtils.MaxCustomPreferences,
MaxAuthMethods: AuthUtils.MaxAuthMethodsPerType AuthUtils.MaxAuthMethodsPerType
) )
) )
); );
}
[HttpGet("/api/v2/coffee")] [HttpGet("/api/v2/coffee")]
public IActionResult BrewCoffee() => public IActionResult BrewCoffee() =>

View file

@ -24,7 +24,7 @@ public class SidController(Config config, DatabaseContext db) : ApiControllerBas
private async Task<IActionResult> ResolveUserSidAsync(string id, CancellationToken ct = default) private async Task<IActionResult> ResolveUserSidAsync(string id, CancellationToken ct = default)
{ {
var username = await db string? username = await db
.Users.Where(u => u.Sid == id.ToLowerInvariant() && !u.Deleted) .Users.Where(u => u.Sid == id.ToLowerInvariant() && !u.Deleted)
.Select(u => u.Username) .Select(u => u.Username)
.FirstOrDefaultAsync(ct); .FirstOrDefaultAsync(ct);

View file

@ -9,6 +9,7 @@ using Foxnouns.Backend.Services;
using Foxnouns.Backend.Utils; using Foxnouns.Backend.Utils;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using NodaTime; using NodaTime;
namespace Foxnouns.Backend.Controllers; namespace Foxnouns.Backend.Controllers;
@ -29,16 +30,9 @@ public class UsersController(
[ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> GetUserAsync(string userRef, CancellationToken ct = default) public async Task<IActionResult> GetUserAsync(string userRef, CancellationToken ct = default)
{ {
var user = await db.ResolveUserAsync(userRef, CurrentToken, ct); User user = await db.ResolveUserAsync(userRef, CurrentToken, ct);
return Ok( return Ok(
await userRenderer.RenderUserAsync( await userRenderer.RenderUserAsync(user, CurrentUser, CurrentToken, true, true, ct: ct)
user,
selfUser: CurrentUser,
token: CurrentToken,
renderMembers: true,
renderAuthMethods: true,
ct: ct
)
); );
} }
@ -50,8 +44,8 @@ public class UsersController(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
await using var tx = await db.Database.BeginTransactionAsync(ct); await using IDbContextTransaction tx = await db.Database.BeginTransactionAsync(ct);
var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct); User user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct);
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (req.Username != null && req.Username != user.Username) if (req.Username != null && req.Username != user.Username)
@ -108,7 +102,7 @@ public class UsersController(
if (req.Flags != null) if (req.Flags != null)
{ {
var flagError = await db.SetUserFlagsAsync(CurrentUser!.Id, req.Flags); ValidationError? flagError = await db.SetUserFlagsAsync(CurrentUser!.Id, req.Flags);
if (flagError != null) if (flagError != null)
errors.Add(("flags", flagError)); errors.Add(("flags", flagError));
} }
@ -141,8 +135,11 @@ public class UsersController(
else else
{ {
if (TimeZoneInfo.TryFindSystemTimeZoneById(req.Timezone, out _)) if (TimeZoneInfo.TryFindSystemTimeZoneById(req.Timezone, out _))
{
user.Timezone = req.Timezone; user.Timezone = req.Timezone;
}
else else
{
errors.Add( errors.Add(
( (
"timezone", "timezone",
@ -151,15 +148,18 @@ public class UsersController(
); );
} }
} }
}
ValidationUtils.Validate(errors); ValidationUtils.Validate(errors);
// This is fired off regardless of whether the transaction is committed // This is fired off regardless of whether the transaction is committed
// (atomic operations are hard when combined with background jobs) // (atomic operations are hard when combined with background jobs)
// so it's in a separate block to the validation above. // so it's in a separate block to the validation above.
if (req.HasProperty(nameof(req.Avatar))) if (req.HasProperty(nameof(req.Avatar)))
{
queue.QueueInvocableWithPayload<UserAvatarUpdateInvocable, AvatarUpdatePayload>( queue.QueueInvocableWithPayload<UserAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(CurrentUser!.Id, req.Avatar) new AvatarUpdatePayload(CurrentUser!.Id, req.Avatar)
); );
}
try try
{ {
@ -176,7 +176,7 @@ public class UsersController(
throw new ApiError.BadRequest( throw new ApiError.BadRequest(
"That username is already taken.", "That username is already taken.",
"username", "username",
req.Username! req.Username
); );
} }
@ -202,12 +202,12 @@ public class UsersController(
{ {
ValidationUtils.Validate(ValidationUtils.ValidateCustomPreferences(req)); ValidationUtils.Validate(ValidationUtils.ValidateCustomPreferences(req));
var user = await db.ResolveUserAsync(CurrentUser!.Id, ct); User user = await db.ResolveUserAsync(CurrentUser!.Id, ct);
var preferences = user var preferences = user
.CustomPreferences.Where(x => req.Any(r => r.Id == x.Key)) .CustomPreferences.Where(x => req.Any(r => r.Id == x.Key))
.ToDictionary(); .ToDictionary();
foreach (var r in req) foreach (CustomPreferenceUpdate? r in req)
{ {
if (r.Id != null && preferences.ContainsKey(r.Id.Value)) if (r.Id != null && preferences.ContainsKey(r.Id.Value))
{ {
@ -271,7 +271,7 @@ public class UsersController(
[ProducesResponseType<UserSettings>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<UserSettings>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> GetUserSettingsAsync(CancellationToken ct = default) public async Task<IActionResult> GetUserSettingsAsync(CancellationToken ct = default)
{ {
var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct); User user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct);
return Ok(user.Settings); return Ok(user.Settings);
} }
@ -283,7 +283,7 @@ public class UsersController(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct); User user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct);
if (req.HasProperty(nameof(req.DarkMode))) if (req.HasProperty(nameof(req.DarkMode)))
user.Settings.DarkMode = req.DarkMode; user.Settings.DarkMode = req.DarkMode;
@ -304,7 +304,7 @@ public class UsersController(
[ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> RerollSidAsync() public async Task<IActionResult> RerollSidAsync()
{ {
var minTimeAgo = clock.GetCurrentInstant() - Duration.FromHours(1); Instant minTimeAgo = clock.GetCurrentInstant() - Duration.FromHours(1);
if (CurrentUser!.LastSidReroll > minTimeAgo) if (CurrentUser!.LastSidReroll > minTimeAgo)
throw new ApiError.BadRequest("Cannot reroll short ID yet"); throw new ApiError.BadRequest("Cannot reroll short ID yet");
@ -318,18 +318,18 @@ public class UsersController(
); );
// Get the user's new sid // Get the user's new sid
var newSid = await db string newSid = await db
.Users.Where(u => u.Id == CurrentUser.Id) .Users.Where(u => u.Id == CurrentUser.Id)
.Select(u => u.Sid) .Select(u => u.Sid)
.FirstAsync(); .FirstAsync();
var user = await db.ResolveUserAsync(CurrentUser.Id); User user = await db.ResolveUserAsync(CurrentUser.Id);
return Ok( return Ok(
await userRenderer.RenderUserAsync( await userRenderer.RenderUserAsync(
CurrentUser, user,
CurrentUser, CurrentUser,
CurrentToken, CurrentToken,
renderMembers: false, false,
overrideSid: newSid overrideSid: newSid
) )
); );

View file

@ -11,9 +11,8 @@ namespace Foxnouns.Backend.Database;
public class DatabaseContext(DbContextOptions options) : DbContext(options) public class DatabaseContext(DbContextOptions options) : DbContext(options)
{ {
private static string GenerateConnectionString(Config.DatabaseConfig config) private static string GenerateConnectionString(Config.DatabaseConfig config) =>
{ new NpgsqlConnectionStringBuilder(config.Url)
return new NpgsqlConnectionStringBuilder(config.Url)
{ {
Pooling = config.EnablePooling ?? true, Pooling = config.EnablePooling ?? true,
Timeout = config.Timeout ?? 5, Timeout = config.Timeout ?? 5,
@ -22,7 +21,6 @@ public class DatabaseContext(DbContextOptions options) : DbContext(options)
ConnectionPruningInterval = 10, ConnectionPruningInterval = 10,
ConnectionIdleLifetime = 10, ConnectionIdleLifetime = 10,
}.ConnectionString; }.ConnectionString;
}
public static NpgsqlDataSource BuildDataSource(Config config) public static NpgsqlDataSource BuildDataSource(Config config)
{ {
@ -46,18 +44,18 @@ public class DatabaseContext(DbContextOptions options) : DbContext(options)
.UseSnakeCaseNamingConvention() .UseSnakeCaseNamingConvention()
.UseExceptionProcessor(); .UseExceptionProcessor();
public DbSet<User> Users { get; init; } public DbSet<User> Users { get; init; } = null!;
public DbSet<Member> Members { get; init; } public DbSet<Member> Members { get; init; } = null!;
public DbSet<AuthMethod> AuthMethods { get; init; } public DbSet<AuthMethod> AuthMethods { get; init; } = null!;
public DbSet<FediverseApplication> FediverseApplications { get; init; } public DbSet<FediverseApplication> FediverseApplications { get; init; } = null!;
public DbSet<Token> Tokens { get; init; } public DbSet<Token> Tokens { get; init; } = null!;
public DbSet<Application> Applications { get; init; } public DbSet<Application> Applications { get; init; } = null!;
public DbSet<TemporaryKey> TemporaryKeys { get; init; } public DbSet<TemporaryKey> TemporaryKeys { get; init; } = null!;
public DbSet<DataExport> DataExports { get; init; } public DbSet<DataExport> DataExports { get; init; } = null!;
public DbSet<PrideFlag> PrideFlags { get; init; } public DbSet<PrideFlag> PrideFlags { get; init; } = null!;
public DbSet<UserFlag> UserFlags { get; init; } public DbSet<UserFlag> UserFlags { get; init; } = null!;
public DbSet<MemberFlag> MemberFlags { get; init; } public DbSet<MemberFlag> MemberFlags { get; init; } = null!;
protected override void ConfigureConventions(ModelConfigurationBuilder configurationBuilder) protected override void ConfigureConventions(ModelConfigurationBuilder configurationBuilder)
{ {
@ -138,16 +136,16 @@ public class DesignTimeDatabaseContextFactory : IDesignTimeDbContextFactory<Data
public DatabaseContext CreateDbContext(string[] args) public DatabaseContext CreateDbContext(string[] args)
{ {
// Read the configuration file // Read the configuration file
var config = Config config =
new ConfigurationBuilder() new ConfigurationBuilder()
.AddConfiguration() .AddConfiguration()
.Build() .Build()
// Get the configuration as our config class // Get the configuration as our config class
.Get<Config>() ?? new(); .Get<Config>() ?? new Config();
var dataSource = DatabaseContext.BuildDataSource(config); NpgsqlDataSource dataSource = DatabaseContext.BuildDataSource(config);
var options = DatabaseContext DbContextOptions options = DatabaseContext
.BuildOptions(new DbContextOptionsBuilder(), dataSource, null) .BuildOptions(new DbContextOptionsBuilder(), dataSource, null)
.Options; .Options;

View file

@ -26,7 +26,7 @@ public static class DatabaseQueryExtensions
} }
User? user; User? user;
if (Snowflake.TryParse(userRef, out var snowflake)) if (Snowflake.TryParse(userRef, out Snowflake? snowflake))
{ {
user = await context user = await context
.Users.Where(u => !u.Deleted) .Users.Where(u => !u.Deleted)
@ -42,7 +42,7 @@ public static class DatabaseQueryExtensions
return user; return user;
throw new ApiError.NotFound( throw new ApiError.NotFound(
"No user with that ID or username found.", "No user with that ID or username found.",
code: ErrorCode.UserNotFound ErrorCode.UserNotFound
); );
} }
@ -52,12 +52,12 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var user = await context User? user = await context
.Users.Where(u => !u.Deleted) .Users.Where(u => !u.Deleted)
.FirstOrDefaultAsync(u => u.Id == id, ct); .FirstOrDefaultAsync(u => u.Id == id, ct);
if (user != null) if (user != null)
return user; return user;
throw new ApiError.NotFound("No user with that ID found.", code: ErrorCode.UserNotFound); throw new ApiError.NotFound("No user with that ID found.", ErrorCode.UserNotFound);
} }
public static async Task<Member> ResolveMemberAsync( public static async Task<Member> ResolveMemberAsync(
@ -66,16 +66,13 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var member = await context Member? member = await context
.Members.Include(m => m.User) .Members.Include(m => m.User)
.Where(m => !m.User.Deleted) .Where(m => !m.User.Deleted)
.FirstOrDefaultAsync(m => m.Id == id, ct); .FirstOrDefaultAsync(m => m.Id == id, ct);
if (member != null) if (member != null)
return member; return member;
throw new ApiError.NotFound( throw new ApiError.NotFound("No member with that ID found.", ErrorCode.MemberNotFound);
"No member with that ID found.",
code: ErrorCode.MemberNotFound
);
} }
public static async Task<Member> ResolveMemberAsync( public static async Task<Member> ResolveMemberAsync(
@ -86,7 +83,7 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var user = await context.ResolveUserAsync(userRef, token, ct); User user = await context.ResolveUserAsync(userRef, token, ct);
return await context.ResolveMemberAsync(user.Id, memberRef, ct); return await context.ResolveMemberAsync(user.Id, memberRef, ct);
} }
@ -98,7 +95,7 @@ public static class DatabaseQueryExtensions
) )
{ {
Member? member; Member? member;
if (Snowflake.TryParse(memberRef, out var snowflake)) if (Snowflake.TryParse(memberRef, out Snowflake? snowflake))
{ {
member = await context member = await context
.Members.Include(m => m.User) .Members.Include(m => m.User)
@ -118,7 +115,7 @@ public static class DatabaseQueryExtensions
return member; return member;
throw new ApiError.NotFound( throw new ApiError.NotFound(
"No member with that ID or name found.", "No member with that ID or name found.",
code: ErrorCode.MemberNotFound ErrorCode.MemberNotFound
); );
} }
@ -127,7 +124,10 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var app = await context.Applications.FirstOrDefaultAsync(a => a.Id == new Snowflake(0), ct); Application? app = await context.Applications.FirstOrDefaultAsync(
a => a.Id == new Snowflake(0),
ct
);
if (app != null) if (app != null)
return app; return app;
@ -152,9 +152,9 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var hash = SHA512.HashData(rawToken); byte[] hash = SHA512.HashData(rawToken);
var oauthToken = await context Token? oauthToken = await context
.Tokens.Include(t => t.Application) .Tokens.Include(t => t.Application)
.Include(t => t.User) .Include(t => t.User)
.FirstOrDefaultAsync( .FirstOrDefaultAsync(
@ -174,7 +174,7 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var hash = SHA512.HashData(rawToken); byte[] hash = SHA512.HashData(rawToken);
return await context return await context
.Tokens.Where(t => .Tokens.Where(t =>
t.Hash == hash t.Hash == hash

View file

@ -1,3 +1,4 @@
using Npgsql;
using Serilog; using Serilog;
namespace Foxnouns.Backend.Database; namespace Foxnouns.Backend.Database;
@ -9,8 +10,8 @@ public static class DatabaseServiceExtensions
Config config Config config
) )
{ {
var dataSource = DatabaseContext.BuildDataSource(config); NpgsqlDataSource dataSource = DatabaseContext.BuildDataSource(config);
var loggerFactory = new LoggerFactory().AddSerilog(dispose: false); ILoggerFactory loggerFactory = new LoggerFactory().AddSerilog(dispose: false);
serviceCollection.AddDbContext<DatabaseContext>(options => serviceCollection.AddDbContext<DatabaseContext>(options =>
DatabaseContext.BuildOptions(options, dataSource, loggerFactory) DatabaseContext.BuildOptions(options, dataSource, loggerFactory)

View file

@ -20,8 +20,10 @@ public static class FlagQueryExtensions
Snowflake[] flagIds Snowflake[] flagIds
) )
{ {
var currentFlags = await db.UserFlags.Where(f => f.UserId == userId).ToListAsync(); List<UserFlag> currentFlags = await db
foreach (var flag in currentFlags) .UserFlags.Where(f => f.UserId == userId)
.ToListAsync();
foreach (UserFlag flag in currentFlags)
db.UserFlags.Remove(flag); db.UserFlags.Remove(flag);
// If there's no new flags to set, we're done // If there's no new flags to set, we're done
@ -30,12 +32,16 @@ public static class FlagQueryExtensions
if (flagIds.Length > 100) if (flagIds.Length > 100)
return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length); return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length);
var flags = await db.GetFlagsAsync(userId); List<PrideFlag> flags = await db.GetFlagsAsync(userId);
var unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray(); Snowflake[] unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray();
if (unknownFlagIds.Length != 0) if (unknownFlagIds.Length != 0)
return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds); return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds);
var userFlags = flagIds.Select(id => new UserFlag { PrideFlagId = id, UserId = userId }); IEnumerable<UserFlag> userFlags = flagIds.Select(id => new UserFlag
{
PrideFlagId = id,
UserId = userId,
});
db.UserFlags.AddRange(userFlags); db.UserFlags.AddRange(userFlags);
return null; return null;
@ -48,8 +54,10 @@ public static class FlagQueryExtensions
Snowflake[] flagIds Snowflake[] flagIds
) )
{ {
var currentFlags = await db.MemberFlags.Where(f => f.MemberId == memberId).ToListAsync(); List<MemberFlag> currentFlags = await db
foreach (var flag in currentFlags) .MemberFlags.Where(f => f.MemberId == memberId)
.ToListAsync();
foreach (MemberFlag flag in currentFlags)
db.MemberFlags.Remove(flag); db.MemberFlags.Remove(flag);
if (flagIds.Length == 0) if (flagIds.Length == 0)
@ -57,12 +65,12 @@ public static class FlagQueryExtensions
if (flagIds.Length > 100) if (flagIds.Length > 100)
return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length); return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length);
var flags = await db.GetFlagsAsync(userId); List<PrideFlag> flags = await db.GetFlagsAsync(userId);
var unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray(); Snowflake[] unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray();
if (unknownFlagIds.Length != 0) if (unknownFlagIds.Length != 0)
return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds); return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds);
var memberFlags = flagIds.Select(id => new MemberFlag IEnumerable<MemberFlag> memberFlags = flagIds.Select(id => new MemberFlag
{ {
PrideFlagId = id, PrideFlagId = id,
MemberId = memberId, MemberId = memberId,

View file

@ -24,10 +24,7 @@ namespace Foxnouns.Backend.Database.Migrations
client_secret = table.Column<string>(type: "text", nullable: false), client_secret = table.Column<string>(type: "text", nullable: false),
instance_type = table.Column<int>(type: "integer", nullable: false), instance_type = table.Column<int>(type: "integer", nullable: false),
}, },
constraints: table => constraints: table => table.PrimaryKey("pk_fediverse_applications", x => x.id)
{
table.PrimaryKey("pk_fediverse_applications", x => x.id);
}
); );
migrationBuilder.CreateTable( migrationBuilder.CreateTable(
@ -46,10 +43,7 @@ namespace Foxnouns.Backend.Database.Migrations
names = table.Column<string>(type: "jsonb", nullable: false), names = table.Column<string>(type: "jsonb", nullable: false),
pronouns = table.Column<string>(type: "jsonb", nullable: false), pronouns = table.Column<string>(type: "jsonb", nullable: false),
}, },
constraints: table => constraints: table => table.PrimaryKey("pk_users", x => x.id)
{
table.PrimaryKey("pk_users", x => x.id);
}
); );
migrationBuilder.CreateTable( migrationBuilder.CreateTable(

View file

@ -26,7 +26,7 @@ namespace Foxnouns.Backend.Database.Migrations
table: "tokens", table: "tokens",
type: "bytea", type: "bytea",
nullable: false, nullable: false,
defaultValue: new byte[0] defaultValue: Array.Empty<byte>()
); );
migrationBuilder.CreateTable( migrationBuilder.CreateTable(
@ -40,10 +40,7 @@ namespace Foxnouns.Backend.Database.Migrations
scopes = table.Column<string[]>(type: "text[]", nullable: false), scopes = table.Column<string[]>(type: "text[]", nullable: false),
redirect_uris = table.Column<string[]>(type: "text[]", nullable: false), redirect_uris = table.Column<string[]>(type: "text[]", nullable: false),
}, },
constraints: table => constraints: table => table.PrimaryKey("pk_applications", x => x.id)
{
table.PrimaryKey("pk_applications", x => x.id);
}
); );
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(

View file

@ -32,10 +32,7 @@ namespace Foxnouns.Backend.Database.Migrations
nullable: false nullable: false
), ),
}, },
constraints: table => constraints: table => table.PrimaryKey("pk_temporary_keys", x => x.id)
{
table.PrimaryKey("pk_temporary_keys", x => x.id);
}
); );
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(

View file

@ -18,8 +18,8 @@ public class Application : BaseModel
string[] redirectUrls string[] redirectUrls
) )
{ {
var clientId = RandomNumberGenerator.GetHexString(32, true); string clientId = RandomNumberGenerator.GetHexString(32, true);
var clientSecret = AuthUtils.RandomToken(); string clientSecret = AuthUtils.RandomToken();
if (scopes.Except(AuthUtils.ApplicationScopes).Any()) if (scopes.Except(AuthUtils.ApplicationScopes).Any())
{ {

View file

@ -59,7 +59,7 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
public static bool TryParse(string input, [NotNullWhen(true)] out Snowflake? snowflake) public static bool TryParse(string input, [NotNullWhen(true)] out Snowflake? snowflake)
{ {
snowflake = null; snowflake = null;
if (!ulong.TryParse(input, out var res)) if (!ulong.TryParse(input, out ulong res))
return false; return false;
snowflake = new Snowflake(res); snowflake = new Snowflake(res);
return true; return true;
@ -70,10 +70,7 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
public override bool Equals(object? obj) => obj is Snowflake other && Value == other.Value; public override bool Equals(object? obj) => obj is Snowflake other && Value == other.Value;
public bool Equals(Snowflake other) public bool Equals(Snowflake other) => Value == other.Value;
{
return Value == other.Value;
}
public override int GetHashCode() => Value.GetHashCode(); public override int GetHashCode() => Value.GetHashCode();
@ -83,11 +80,7 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
/// An Entity Framework ValueConverter for Snowflakes to longs. /// An Entity Framework ValueConverter for Snowflakes to longs.
/// </summary> /// </summary>
// ReSharper disable once ClassNeverInstantiated.Global // ReSharper disable once ClassNeverInstantiated.Global
public class ValueConverter() public class ValueConverter() : ValueConverter<Snowflake, long>(x => x, x => x);
: ValueConverter<Snowflake, long>(
convertToProviderExpression: x => x,
convertFromProviderExpression: x => x
);
private class JsonConverter : JsonConverter<Snowflake> private class JsonConverter : JsonConverter<Snowflake>
{ {
@ -106,10 +99,7 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
Snowflake existingValue, Snowflake existingValue,
bool hasExistingValue, bool hasExistingValue,
JsonSerializer serializer JsonSerializer serializer
) ) => ulong.Parse((string)reader.Value!);
{
return ulong.Parse((string)reader.Value!);
}
} }
private class TypeConverter : System.ComponentModel.TypeConverter private class TypeConverter : System.ComponentModel.TypeConverter
@ -126,9 +116,6 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
ITypeDescriptorContext? context, ITypeDescriptorContext? context,
CultureInfo? culture, CultureInfo? culture,
object value object value
) ) => TryParse((string)value, out Snowflake? snowflake) ? snowflake : null;
{
return TryParse((string)value, out var snowflake) ? snowflake : null;
}
} }
} }

View file

@ -28,9 +28,9 @@ public class SnowflakeGenerator : ISnowflakeGenerator
public Snowflake GenerateSnowflake(Instant? time = null) public Snowflake GenerateSnowflake(Instant? time = null)
{ {
time ??= SystemClock.Instance.GetCurrentInstant(); time ??= SystemClock.Instance.GetCurrentInstant();
var increment = Interlocked.Increment(ref _increment); long increment = Interlocked.Increment(ref _increment);
var threadId = Environment.CurrentManagedThreadId % 32; int threadId = Environment.CurrentManagedThreadId % 32;
var timestamp = time.Value.ToUnixTimeMilliseconds() - Snowflake.Epoch; long timestamp = time.Value.ToUnixTimeMilliseconds() - Snowflake.Epoch;
return (timestamp << 22) return (timestamp << 22)
| (uint)(_processId << 17) | (uint)(_processId << 17)
@ -44,8 +44,5 @@ public static class SnowflakeGeneratorServiceExtensions
public static IServiceCollection AddSnowflakeGenerator( public static IServiceCollection AddSnowflakeGenerator(
this IServiceCollection services, this IServiceCollection services,
int? processId = null int? processId = null
) ) => services.AddSingleton<ISnowflakeGenerator>(new SnowflakeGenerator(processId));
{
return services.AddSingleton<ISnowflakeGenerator>(new SnowflakeGenerator(processId));
}
} }

View file

@ -35,13 +35,13 @@ public class ApiError(
public readonly ErrorCode ErrorCode = errorCode ?? ErrorCode.InternalServerError; public readonly ErrorCode ErrorCode = errorCode ?? ErrorCode.InternalServerError;
public class Unauthorized(string message, ErrorCode errorCode = ErrorCode.AuthenticationError) public class Unauthorized(string message, ErrorCode errorCode = ErrorCode.AuthenticationError)
: ApiError(message, statusCode: HttpStatusCode.Unauthorized, errorCode: errorCode); : ApiError(message, HttpStatusCode.Unauthorized, errorCode);
public class Forbidden( public class Forbidden(
string message, string message,
IEnumerable<string>? scopes = null, IEnumerable<string>? scopes = null,
ErrorCode errorCode = ErrorCode.Forbidden ErrorCode errorCode = ErrorCode.Forbidden
) : ApiError(message, statusCode: HttpStatusCode.Forbidden, errorCode: errorCode) ) : ApiError(message, HttpStatusCode.Forbidden, errorCode)
{ {
public readonly string[] Scopes = scopes?.ToArray() ?? []; public readonly string[] Scopes = scopes?.ToArray() ?? [];
} }
@ -49,7 +49,7 @@ public class ApiError(
public class BadRequest( public class BadRequest(
string message, string message,
IReadOnlyDictionary<string, IEnumerable<ValidationError>>? errors = null IReadOnlyDictionary<string, IEnumerable<ValidationError>>? errors = null
) : ApiError(message, statusCode: HttpStatusCode.BadRequest) ) : ApiError(message, HttpStatusCode.BadRequest)
{ {
public BadRequest(string message, string field, object? actualValue) public BadRequest(string message, string field, object? actualValue)
: this( : this(
@ -72,7 +72,7 @@ public class ApiError(
return o; return o;
var a = new JArray(); var a = new JArray();
foreach (var error in errors) foreach (KeyValuePair<string, IEnumerable<ValidationError>> error in errors)
{ {
var errorObj = new JObject var errorObj = new JObject
{ {
@ -92,7 +92,7 @@ public class ApiError(
/// Any other methods should use <see cref="ApiError.BadRequest" /> instead. /// Any other methods should use <see cref="ApiError.BadRequest" /> instead.
/// </summary> /// </summary>
public class AspBadRequest(string message, ModelStateDictionary? modelState = null) public class AspBadRequest(string message, ModelStateDictionary? modelState = null)
: ApiError(message, statusCode: HttpStatusCode.BadRequest) : ApiError(message, HttpStatusCode.BadRequest)
{ {
public JObject ToJson() public JObject ToJson()
{ {
@ -106,7 +106,11 @@ public class ApiError(
return o; return o;
var a = new JArray(); var a = new JArray();
foreach (var error in modelState.Where(e => e.Value is { Errors.Count: > 0 })) foreach (
KeyValuePair<string, ModelStateEntry?> error in modelState.Where(e =>
e.Value is { Errors.Count: > 0 }
)
)
{ {
var errorObj = new JObject var errorObj = new JObject
{ {
@ -130,10 +134,9 @@ public class ApiError(
} }
public class NotFound(string message, ErrorCode? code = null) public class NotFound(string message, ErrorCode? code = null)
: ApiError(message, statusCode: HttpStatusCode.NotFound, errorCode: code); : ApiError(message, HttpStatusCode.NotFound, code);
public class AuthenticationError(string message) public class AuthenticationError(string message) : ApiError(message, HttpStatusCode.BadRequest);
: ApiError(message, statusCode: HttpStatusCode.BadRequest);
} }
public enum ErrorCode public enum ErrorCode
@ -175,33 +178,27 @@ public class ValidationError
int minLength, int minLength,
int maxLength, int maxLength,
int actualLength int actualLength
) ) =>
{ new()
return new ValidationError
{ {
Message = message, Message = message,
MinLength = minLength, MinLength = minLength,
MaxLength = maxLength, MaxLength = maxLength,
ActualLength = actualLength, ActualLength = actualLength,
}; };
}
public static ValidationError DisallowedValueError( public static ValidationError DisallowedValueError(
string message, string message,
IEnumerable<object> allowedValues, IEnumerable<object> allowedValues,
object actualValue object actualValue
) ) =>
{ new()
return new ValidationError
{ {
Message = message, Message = message,
AllowedValues = allowedValues, AllowedValues = allowedValues,
ActualValue = actualValue, ActualValue = actualValue,
}; };
}
public static ValidationError GenericValidationError(string message, object? actualValue) public static ValidationError GenericValidationError(string message, object? actualValue) =>
{ new() { Message = message, ActualValue = actualValue };
return new ValidationError { Message = message, ActualValue = actualValue };
}
} }

View file

@ -47,13 +47,13 @@ public static class ImageObjectExtensions
if (!uri.StartsWith("data:image/")) if (!uri.StartsWith("data:image/"))
throw new ArgumentException("Not a data URI", nameof(uri)); throw new ArgumentException("Not a data URI", nameof(uri));
var split = uri.Remove(0, "data:".Length).Split(";base64,"); string[] split = uri.Remove(0, "data:".Length).Split(";base64,");
var contentType = split[0]; string contentType = split[0];
var encoded = split[1]; string encoded = split[1];
if (!ValidContentTypes.Contains(contentType)) if (!ValidContentTypes.Contains(contentType))
throw new ArgumentException("Invalid content type for image", nameof(uri)); throw new ArgumentException("Invalid content type for image", nameof(uri));
if (!AuthUtils.TryFromBase64String(encoded, out var rawImage)) if (!AuthUtils.TryFromBase64String(encoded, out byte[]? rawImage))
throw new ArgumentException("Invalid base64 string", nameof(uri)); throw new ArgumentException("Invalid base64 string", nameof(uri));
var image = Image.Load(rawImage); var image = Image.Load(rawImage);
@ -74,7 +74,7 @@ public static class ImageObjectExtensions
await image.SaveAsync(stream, new WebpEncoder { Quality = 95, NearLossless = false }); await image.SaveAsync(stream, new WebpEncoder { Quality = 95, NearLossless = false });
stream.Seek(0, SeekOrigin.Begin); stream.Seek(0, SeekOrigin.Begin);
var hash = Convert.ToHexString(await SHA256.HashDataAsync(stream)).ToLower(); string hash = Convert.ToHexString(await SHA256.HashDataAsync(stream)).ToLower();
stream.Seek(0, SeekOrigin.Begin); stream.Seek(0, SeekOrigin.Begin);
return (hash, stream); return (hash, stream);

View file

@ -14,7 +14,7 @@ public static class KeyCacheExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_'); string state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_');
await keyCacheService.SetKeyAsync($"oauth_state:{state}", "", Duration.FromMinutes(10), ct); await keyCacheService.SetKeyAsync($"oauth_state:{state}", "", Duration.FromMinutes(10), ct);
return state; return state;
} }
@ -25,7 +25,7 @@ public static class KeyCacheExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var val = await keyCacheService.GetKeyAsync($"oauth_state:{state}", delete: true, ct); string? val = await keyCacheService.GetKeyAsync($"oauth_state:{state}", ct: ct);
if (val == null) if (val == null)
throw new ApiError.BadRequest("Invalid OAuth state"); throw new ApiError.BadRequest("Invalid OAuth state");
} }
@ -38,7 +38,7 @@ public static class KeyCacheExtensions
) )
{ {
// This state is used in links, not just as JSON values, so make it URL-safe // This state is used in links, not just as JSON values, so make it URL-safe
var state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_'); string state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_');
await keyCacheService.SetKeyAsync( await keyCacheService.SetKeyAsync(
$"email_state:{state}", $"email_state:{state}",
new RegisterEmailState(email, userId), new RegisterEmailState(email, userId),
@ -52,12 +52,7 @@ public static class KeyCacheExtensions
this KeyCacheService keyCacheService, this KeyCacheService keyCacheService,
string state, string state,
CancellationToken ct = default CancellationToken ct = default
) => ) => await keyCacheService.GetKeyAsync<RegisterEmailState>($"email_state:{state}", ct: ct);
await keyCacheService.GetKeyAsync<RegisterEmailState>(
$"email_state:{state}",
delete: true,
ct
);
public static async Task<string> GenerateAddExtraAccountStateAsync( public static async Task<string> GenerateAddExtraAccountStateAsync(
this KeyCacheService keyCacheService, this KeyCacheService keyCacheService,
@ -67,7 +62,7 @@ public static class KeyCacheExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var state = AuthUtils.RandomToken(); string state = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync( await keyCacheService.SetKeyAsync(
$"add_account:{state}", $"add_account:{state}",
new AddExtraAccountState(authType, userId, instance), new AddExtraAccountState(authType, userId, instance),
@ -81,12 +76,7 @@ public static class KeyCacheExtensions
this KeyCacheService keyCacheService, this KeyCacheService keyCacheService,
string state, string state,
CancellationToken ct = default CancellationToken ct = default
) => ) => await keyCacheService.GetKeyAsync<AddExtraAccountState>($"add_account:{state}", true, ct);
await keyCacheService.GetKeyAsync<AddExtraAccountState>(
$"add_account:{state}",
delete: true,
ct
);
} }
public record RegisterEmailState( public record RegisterEmailState(

View file

@ -24,9 +24,9 @@ public static class WebApplicationExtensions
/// </summary> /// </summary>
public static WebApplicationBuilder AddSerilog(this WebApplicationBuilder builder) public static WebApplicationBuilder AddSerilog(this WebApplicationBuilder builder)
{ {
var config = builder.Configuration.Get<Config>() ?? new(); Config config = builder.Configuration.Get<Config>() ?? new Config();
var logCfg = new LoggerConfiguration() LoggerConfiguration logCfg = new LoggerConfiguration()
.Enrich.FromLogContext() .Enrich.FromLogContext()
.MinimumLevel.Is(config.Logging.LogEventLevel) .MinimumLevel.Is(config.Logging.LogEventLevel)
// ASP.NET's built in request logs are extremely verbose, so we use Serilog's instead. // ASP.NET's built in request logs are extremely verbose, so we use Serilog's instead.
@ -43,10 +43,7 @@ public static class WebApplicationExtensions
if (config.Logging.SeqLogUrl != null) if (config.Logging.SeqLogUrl != null)
{ {
logCfg.WriteTo.Seq( logCfg.WriteTo.Seq(config.Logging.SeqLogUrl, LogEventLevel.Verbose);
config.Logging.SeqLogUrl,
restrictedToMinimumLevel: LogEventLevel.Verbose
);
} }
// AddSerilog doesn't seem to add an ILogger to the service collection, so add that manually. // AddSerilog doesn't seem to add an ILogger to the service collection, so add that manually.
@ -60,19 +57,19 @@ public static class WebApplicationExtensions
builder.Configuration.Sources.Clear(); builder.Configuration.Sources.Clear();
builder.Configuration.AddConfiguration(); builder.Configuration.AddConfiguration();
var config = builder.Configuration.Get<Config>() ?? new(); Config config = builder.Configuration.Get<Config>() ?? new Config();
builder.Services.AddSingleton(config); builder.Services.AddSingleton(config);
return config; return config;
} }
public static IConfigurationBuilder AddConfiguration(this IConfigurationBuilder builder) public static IConfigurationBuilder AddConfiguration(this IConfigurationBuilder builder)
{ {
var file = Environment.GetEnvironmentVariable("FOXNOUNS_CONFIG_FILE") ?? "config.ini"; string file = Environment.GetEnvironmentVariable("FOXNOUNS_CONFIG_FILE") ?? "config.ini";
return builder return builder
.SetBasePath(Directory.GetCurrentDirectory()) .SetBasePath(Directory.GetCurrentDirectory())
.AddJsonFile("appSettings.json", true) .AddJsonFile("appSettings.json", true)
.AddIniFile(file, optional: false, reloadOnChange: true) .AddIniFile(file, false, true)
.AddEnvironmentVariables(); .AddEnvironmentVariables();
} }
@ -142,11 +139,15 @@ public static class WebApplicationExtensions
app.Services.ConfigureQueue() app.Services.ConfigureQueue()
.LogQueuedTaskProgress(app.Services.GetRequiredService<ILogger<IQueue>>()); .LogQueuedTaskProgress(app.Services.GetRequiredService<ILogger<IQueue>>());
await using var scope = app.Services.CreateAsyncScope(); await using AsyncServiceScope scope = app.Services.CreateAsyncScope();
// The types of these variables are obvious from the methods being called to create them
// ReSharper disable SuggestVarOrType_SimpleTypes
var logger = scope var logger = scope
.ServiceProvider.GetRequiredService<ILogger>() .ServiceProvider.GetRequiredService<ILogger>()
.ForContext<WebApplication>(); .ForContext<WebApplication>();
var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
// ReSharper restore SuggestVarOrType_SimpleTypes
logger.Information( logger.Information(
"Starting Foxnouns.NET {Version} ({Hash})", "Starting Foxnouns.NET {Version} ({Hash})",

View file

@ -30,6 +30,10 @@
<PackageReference Include="Npgsql.Json.NET" Version="8.0.3"/> <PackageReference Include="Npgsql.Json.NET" Version="8.0.3"/>
<PackageReference Include="prometheus-net" Version="8.2.1"/> <PackageReference Include="prometheus-net" Version="8.2.1"/>
<PackageReference Include="prometheus-net.AspNetCore" Version="8.2.1"/> <PackageReference Include="prometheus-net.AspNetCore" Version="8.2.1"/>
<PackageReference Include="Roslynator.Analyzers" Version="4.12.9">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Sentry.AspNetCore" Version="4.9.0"/> <PackageReference Include="Sentry.AspNetCore" Version="4.9.0"/>
<PackageReference Include="Serilog" Version="4.0.1"/> <PackageReference Include="Serilog" Version="4.0.1"/>
<PackageReference Include="Serilog.AspNetCore" Version="8.0.1"/> <PackageReference Include="Serilog.AspNetCore" Version="8.0.1"/>

View file

@ -40,7 +40,7 @@ public class CreateDataExportInvocable(
private async Task InvokeAsync() private async Task InvokeAsync()
{ {
var user = await db User? user = await db
.Users.Include(u => u.AuthMethods) .Users.Include(u => u.AuthMethods)
.Include(u => u.Flags) .Include(u => u.Flags)
.Include(u => u.ProfileFlags) .Include(u => u.ProfileFlags)
@ -57,7 +57,7 @@ public class CreateDataExportInvocable(
_logger.Information("Generating data export for user {UserId}", user.Id); _logger.Information("Generating data export for user {UserId}", user.Id);
using var stream = new MemoryStream(); await using var stream = new MemoryStream();
using var zip = new ZipArchive(stream, ZipArchiveMode.Create, true); using var zip = new ZipArchive(stream, ZipArchiveMode.Create, true);
zip.Comment = zip.Comment =
$"This archive for {user.Username} ({user.Id}) was generated at {InstantPattern.General.Format(clock.GetCurrentInstant())}"; $"This archive for {user.Username} ({user.Id}) was generated at {InstantPattern.General.Format(clock.GetCurrentInstant())}";
@ -66,25 +66,19 @@ public class CreateDataExportInvocable(
WriteJson( WriteJson(
zip, zip,
"user.json", "user.json",
await userRenderer.RenderUserInnerAsync( await userRenderer.RenderUserInnerAsync(user, true, ["*"], false, true)
user,
true,
["*"],
renderMembers: false,
renderAuthMethods: true
)
); );
await WriteS3Object(zip, "user-avatar.webp", userRenderer.AvatarUrlFor(user)); await WriteS3Object(zip, "user-avatar.webp", userRenderer.AvatarUrlFor(user));
foreach (var flag in user.Flags) foreach (PrideFlag? flag in user.Flags)
await WritePrideFlag(zip, flag); await WritePrideFlag(zip, flag);
var members = await db List<Member> members = await db
.Members.Include(m => m.User) .Members.Include(m => m.User)
.Include(m => m.ProfileFlags) .Include(m => m.ProfileFlags)
.Where(m => m.UserId == user.Id) .Where(m => m.UserId == user.Id)
.ToListAsync(); .ToListAsync();
foreach (var member in members) foreach (Member? member in members)
await WriteMember(zip, member); await WriteMember(zip, member);
// We want to dispose the ZipArchive on an error, but we need to dispose it manually to upload to object storage. // We want to dispose the ZipArchive on an error, but we need to dispose it manually to upload to object storage.
@ -94,7 +88,7 @@ public class CreateDataExportInvocable(
stream.Seek(0, SeekOrigin.Begin); stream.Seek(0, SeekOrigin.Begin);
// Upload the file! // Upload the file!
var filename = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_'); string filename = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_');
await objectStorageService.PutObjectAsync( await objectStorageService.PutObjectAsync(
ExportPath(user.Id, filename), ExportPath(user.Id, filename),
stream, stream,
@ -132,8 +126,8 @@ public class CreateDataExportInvocable(
return; return;
} }
var entry = zip.CreateEntry($"flag-{flag.Id}/flag.txt"); ZipArchiveEntry entry = zip.CreateEntry($"flag-{flag.Id}/flag.txt");
await using var stream = entry.Open(); await using Stream stream = entry.Open();
await using var writer = new StreamWriter(stream); await using var writer = new StreamWriter(stream);
await writer.WriteAsync(flagData); await writer.WriteAsync(flagData);
} }
@ -164,7 +158,7 @@ public class CreateDataExportInvocable(
private void WriteJson(ZipArchive zip, string filename, object data) private void WriteJson(ZipArchive zip, string filename, object data)
{ {
var json = JsonConvert.SerializeObject(data, Formatting.Indented); string json = JsonConvert.SerializeObject(data, Formatting.Indented);
_logger.Debug( _logger.Debug(
"Writing file {Filename} to archive with size {Length}", "Writing file {Filename} to archive with size {Length}",
@ -172,8 +166,8 @@ public class CreateDataExportInvocable(
json.Length json.Length
); );
var entry = zip.CreateEntry(filename); ZipArchiveEntry entry = zip.CreateEntry(filename);
using var stream = entry.Open(); using Stream stream = entry.Open();
using var writer = new StreamWriter(stream); using var writer = new StreamWriter(stream);
writer.Write(json); writer.Write(json);
} }
@ -183,14 +177,14 @@ public class CreateDataExportInvocable(
if (s3Path == null) if (s3Path == null)
return; return;
var resp = await Client.GetAsync(s3Path); HttpResponseMessage resp = await Client.GetAsync(s3Path);
if (resp.StatusCode != HttpStatusCode.OK) if (resp.StatusCode != HttpStatusCode.OK)
{ {
_logger.Warning("S3 path {S3Path} returned a non-200 status, not saving file", s3Path); _logger.Warning("S3 path {S3Path} returned a non-200 status, not saving file", s3Path);
return; return;
} }
await using var respStream = await resp.Content.ReadAsStreamAsync(); await using Stream respStream = await resp.Content.ReadAsStreamAsync();
_logger.Debug( _logger.Debug(
"Writing file {Filename} to archive with size {Length}", "Writing file {Filename} to archive with size {Length}",
@ -198,8 +192,8 @@ public class CreateDataExportInvocable(
respStream.Length respStream.Length
); );
var entry = zip.CreateEntry(filename); ZipArchiveEntry entry = zip.CreateEntry(filename);
await using var entryStream = entry.Open(); await using Stream entryStream = entry.Open();
respStream.Seek(0, SeekOrigin.Begin); respStream.Seek(0, SeekOrigin.Begin);
await respStream.CopyToAsync(entryStream); await respStream.CopyToAsync(entryStream);

View file

@ -26,10 +26,10 @@ public class CreateFlagInvocable(
try try
{ {
var (hash, image) = await ImageObjectExtensions.ConvertBase64UriToImage( (string? hash, Stream? image) = await ImageObjectExtensions.ConvertBase64UriToImage(
Payload.ImageData, Payload.ImageData,
size: 256, 256,
crop: false false
); );
await objectStorageService.PutObjectAsync(Path(hash), image, "image/webp"); await objectStorageService.PutObjectAsync(Path(hash), image, "image/webp");

View file

@ -1,5 +1,6 @@
using Coravel.Invocable; using Coravel.Invocable;
using Foxnouns.Backend.Database; using Foxnouns.Backend.Database;
using Foxnouns.Backend.Database.Models;
using Foxnouns.Backend.Extensions; using Foxnouns.Backend.Extensions;
using Foxnouns.Backend.Services; using Foxnouns.Backend.Services;
@ -26,7 +27,7 @@ public class MemberAvatarUpdateInvocable(
{ {
_logger.Debug("Updating avatar for member {MemberId}", id); _logger.Debug("Updating avatar for member {MemberId}", id);
var member = await db.Members.FindAsync(id); Member? member = await db.Members.FindAsync(id);
if (member == null) if (member == null)
{ {
_logger.Warning( _logger.Warning(
@ -38,12 +39,12 @@ public class MemberAvatarUpdateInvocable(
try try
{ {
var (hash, image) = await ImageObjectExtensions.ConvertBase64UriToImage( (string? hash, Stream? image) = await ImageObjectExtensions.ConvertBase64UriToImage(
newAvatar, newAvatar,
size: 512, 512,
crop: true true
); );
var prevHash = member.Avatar; string? prevHash = member.Avatar;
await objectStorageService.PutObjectAsync(Path(id, hash), image, "image/webp"); await objectStorageService.PutObjectAsync(Path(id, hash), image, "image/webp");
@ -69,7 +70,7 @@ public class MemberAvatarUpdateInvocable(
{ {
_logger.Debug("Clearing avatar for member {MemberId}", id); _logger.Debug("Clearing avatar for member {MemberId}", id);
var member = await db.Members.FindAsync(id); Member? member = await db.Members.FindAsync(id);
if (member == null) if (member == null)
{ {
_logger.Warning( _logger.Warning(

View file

@ -1,5 +1,6 @@
using Coravel.Invocable; using Coravel.Invocable;
using Foxnouns.Backend.Database; using Foxnouns.Backend.Database;
using Foxnouns.Backend.Database.Models;
using Foxnouns.Backend.Extensions; using Foxnouns.Backend.Extensions;
using Foxnouns.Backend.Services; using Foxnouns.Backend.Services;
@ -26,7 +27,7 @@ public class UserAvatarUpdateInvocable(
{ {
_logger.Debug("Updating avatar for user {MemberId}", id); _logger.Debug("Updating avatar for user {MemberId}", id);
var user = await db.Users.FindAsync(id); User? user = await db.Users.FindAsync(id);
if (user == null) if (user == null)
{ {
_logger.Warning( _logger.Warning(
@ -38,13 +39,13 @@ public class UserAvatarUpdateInvocable(
try try
{ {
var (hash, image) = await ImageObjectExtensions.ConvertBase64UriToImage( (string? hash, Stream? image) = await ImageObjectExtensions.ConvertBase64UriToImage(
newAvatar, newAvatar,
size: 512, 512,
crop: true true
); );
image.Seek(0, SeekOrigin.Begin); image.Seek(0, SeekOrigin.Begin);
var prevHash = user.Avatar; string? prevHash = user.Avatar;
await objectStorageService.PutObjectAsync(Path(id, hash), image, "image/webp"); await objectStorageService.PutObjectAsync(Path(id, hash), image, "image/webp");
@ -70,7 +71,7 @@ public class UserAvatarUpdateInvocable(
{ {
_logger.Debug("Clearing avatar for user {MemberId}", id); _logger.Debug("Clearing avatar for user {MemberId}", id);
var user = await db.Users.FindAsync(id); User? user = await db.Users.FindAsync(id);
if (user == null) if (user == null)
{ {
_logger.Warning( _logger.Warning(

View file

@ -8,8 +8,8 @@ public class AuthenticationMiddleware(DatabaseContext db) : IMiddleware
{ {
public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) public async Task InvokeAsync(HttpContext ctx, RequestDelegate next)
{ {
var endpoint = ctx.GetEndpoint(); Endpoint? endpoint = ctx.GetEndpoint();
var metadata = endpoint?.Metadata.GetMetadata<AuthenticateAttribute>(); AuthenticateAttribute? metadata = endpoint?.Metadata.GetMetadata<AuthenticateAttribute>();
if (metadata == null) if (metadata == null)
{ {
@ -18,14 +18,17 @@ public class AuthenticationMiddleware(DatabaseContext db) : IMiddleware
} }
if ( if (
!AuthUtils.TryParseToken(ctx.Request.Headers.Authorization.ToString(), out var rawToken) !AuthUtils.TryParseToken(
ctx.Request.Headers.Authorization.ToString(),
out byte[]? rawToken
)
) )
{ {
await next(ctx); await next(ctx);
return; return;
} }
var oauthToken = await db.GetToken(rawToken); Token? oauthToken = await db.GetToken(rawToken);
if (oauthToken == null) if (oauthToken == null)
{ {
await next(ctx); await next(ctx);
@ -50,7 +53,7 @@ public static class HttpContextExtensions
public static Token? GetToken(this HttpContext ctx) public static Token? GetToken(this HttpContext ctx)
{ {
if (ctx.Items.TryGetValue(Key, out var token)) if (ctx.Items.TryGetValue(Key, out object? token))
return token as Token; return token as Token;
return null; return null;
} }

View file

@ -7,8 +7,8 @@ public class AuthorizationMiddleware : IMiddleware
{ {
public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) public async Task InvokeAsync(HttpContext ctx, RequestDelegate next)
{ {
var endpoint = ctx.GetEndpoint(); Endpoint? endpoint = ctx.GetEndpoint();
var attribute = endpoint?.Metadata.GetMetadata<AuthorizeAttribute>(); AuthorizeAttribute? attribute = endpoint?.Metadata.GetMetadata<AuthorizeAttribute>();
if (attribute == null) if (attribute == null)
{ {
@ -16,21 +16,27 @@ public class AuthorizationMiddleware : IMiddleware
return; return;
} }
var token = ctx.GetToken(); Token? token = ctx.GetToken();
if (token == null) if (token == null)
{
throw new ApiError.Unauthorized( throw new ApiError.Unauthorized(
"This endpoint requires an authenticated user.", "This endpoint requires an authenticated user.",
ErrorCode.AuthenticationRequired ErrorCode.AuthenticationRequired
); );
}
if ( if (
attribute.Scopes.Length > 0 attribute.Scopes.Length > 0
&& attribute.Scopes.Except(token.Scopes.ExpandScopes()).Any() && attribute.Scopes.Except(token.Scopes.ExpandScopes()).Any()
) )
{
throw new ApiError.Forbidden( throw new ApiError.Forbidden(
"This endpoint requires ungranted scopes.", "This endpoint requires ungranted scopes.",
attribute.Scopes.Except(token.Scopes.ExpandScopes()), attribute.Scopes.Except(token.Scopes.ExpandScopes()),
ErrorCode.MissingScopes ErrorCode.MissingScopes
); );
}
if (attribute.RequireAdmin && token.User.Role != UserRole.Admin) if (attribute.RequireAdmin && token.User.Role != UserRole.Admin)
throw new ApiError.Forbidden("This endpoint can only be used by admins."); throw new ApiError.Forbidden("This endpoint can only be used by admins.");
if ( if (
@ -38,7 +44,9 @@ public class AuthorizationMiddleware : IMiddleware
&& token.User.Role != UserRole.Admin && token.User.Role != UserRole.Admin
&& token.User.Role != UserRole.Moderator && token.User.Role != UserRole.Moderator
) )
{
throw new ApiError.Forbidden("This endpoint can only be used by moderators."); throw new ApiError.Forbidden("This endpoint can only be used by moderators.");
}
await next(ctx); await next(ctx);
} }

View file

@ -1,4 +1,5 @@
using System.Net; using System.Net;
using Foxnouns.Backend.Database.Models;
using Foxnouns.Backend.Utils; using Foxnouns.Backend.Utils;
using Newtonsoft.Json; using Newtonsoft.Json;
@ -14,9 +15,9 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
} }
catch (Exception e) catch (Exception e)
{ {
var type = e.TargetSite?.DeclaringType ?? typeof(ErrorHandlerMiddleware); Type type = e.TargetSite?.DeclaringType ?? typeof(ErrorHandlerMiddleware);
var typeName = e.TargetSite?.DeclaringType?.FullName ?? "<unknown>"; string typeName = e.TargetSite?.DeclaringType?.FullName ?? "<unknown>";
var logger = baseLogger.ForContext(type); ILogger logger = baseLogger.ForContext(type);
if (ctx.Response.HasStarted) if (ctx.Response.HasStarted)
{ {
@ -31,14 +32,16 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
e, e,
scope => scope =>
{ {
var user = ctx.GetUser(); User? user = ctx.GetUser();
if (user != null) if (user != null)
{
scope.User = new SentryUser scope.User = new SentryUser
{ {
Id = user.Id.ToString(), Id = user.Id.ToString(),
Username = user.Username, Username = user.Username,
}; };
} }
}
); );
return; return;
@ -98,18 +101,20 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
logger.Error(e, "Exception in {ClassName} ({Path})", typeName, ctx.Request.Path); logger.Error(e, "Exception in {ClassName} ({Path})", typeName, ctx.Request.Path);
} }
var errorId = sentry.CaptureException( SentryId errorId = sentry.CaptureException(
e, e,
scope => scope =>
{ {
var user = ctx.GetUser(); User? user = ctx.GetUser();
if (user != null) if (user != null)
{
scope.User = new SentryUser scope.User = new SentryUser
{ {
Id = user.Id.ToString(), Id = user.Id.ToString(),
Username = user.Username, Username = user.Username,
}; };
} }
}
); );
ctx.Response.StatusCode = (int)HttpStatusCode.InternalServerError; ctx.Response.StatusCode = (int)HttpStatusCode.InternalServerError;

View file

@ -9,9 +9,9 @@ using Prometheus;
using Sentry.Extensibility; using Sentry.Extensibility;
using Serilog; using Serilog;
var builder = WebApplication.CreateBuilder(args); WebApplicationBuilder builder = WebApplication.CreateBuilder(args);
var config = builder.AddConfiguration(); Config config = builder.AddConfiguration();
builder.AddSerilog(); builder.AddSerilog();
@ -58,7 +58,7 @@ JsonConvert.DefaultSettings = () =>
builder.AddServices(config).AddCustomMiddleware().AddEndpointsApiExplorer().AddSwaggerGen(); builder.AddServices(config).AddCustomMiddleware().AddEndpointsApiExplorer().AddSwaggerGen();
var app = builder.Build(); WebApplication app = builder.Build();
await app.Initialize(args); await app.Initialize(args);

View file

@ -31,6 +31,16 @@ public class AuthService(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
// Validate username and whether it's not taken
ValidationUtils.Validate(
[
("username", ValidationUtils.ValidateUsername(username)),
("password", ValidationUtils.ValidatePassword(password)),
]
);
if (await db.Users.AnyAsync(u => u.Username == username, ct))
throw new ApiError.BadRequest("Username is already taken", "username", username);
var user = new User var user = new User
{ {
Id = snowflakeGenerator.GenerateSnowflake(), Id = snowflakeGenerator.GenerateSnowflake(),
@ -49,7 +59,7 @@ public class AuthService(
}; };
db.Add(user); db.Add(user);
user.Password = await Task.Run(() => _passwordHasher.HashPassword(user, password), ct); user.Password = await HashPasswordAsync(user, password, ct);
return user; return user;
} }
@ -70,6 +80,8 @@ public class AuthService(
{ {
AssertValidAuthType(authType, instance); AssertValidAuthType(authType, instance);
// Validate username and whether it's not taken
ValidationUtils.Validate([("username", ValidationUtils.ValidateUsername(username))]);
if (await db.Users.AnyAsync(u => u.Username == username, ct)) if (await db.Users.AnyAsync(u => u.Username == username, ct))
throw new ApiError.BadRequest("Username is already taken", "username", username); throw new ApiError.BadRequest("Username is already taken", "username", username);
@ -111,28 +123,30 @@ public class AuthService(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var user = await db.Users.FirstOrDefaultAsync( User? user = await db.Users.FirstOrDefaultAsync(
u => u.AuthMethods.Any(a => a.AuthType == AuthType.Email && a.RemoteId == email), u => u.AuthMethods.Any(a => a.AuthType == AuthType.Email && a.RemoteId == email),
ct ct
); );
if (user == null) if (user == null)
{
throw new ApiError.NotFound( throw new ApiError.NotFound(
"No user with that email address found, or password is incorrect", "No user with that email address found, or password is incorrect",
ErrorCode.UserNotFound ErrorCode.UserNotFound
); );
}
var pwResult = await Task.Run( PasswordVerificationResult pwResult = await VerifyHashedPasswordAsync(user, password, ct);
() => _passwordHasher.VerifyHashedPassword(user, user.Password!, password),
ct
);
if (pwResult == PasswordVerificationResult.Failed) // TODO: this seems to fail on some valid passwords? if (pwResult == PasswordVerificationResult.Failed) // TODO: this seems to fail on some valid passwords?
{
throw new ApiError.NotFound( throw new ApiError.NotFound(
"No user with that email address found, or password is incorrect", "No user with that email address found, or password is incorrect",
ErrorCode.UserNotFound ErrorCode.UserNotFound
); );
}
if (pwResult == PasswordVerificationResult.SuccessRehashNeeded) if (pwResult == PasswordVerificationResult.SuccessRehashNeeded)
{ {
user.Password = await Task.Run(() => _passwordHasher.HashPassword(user, password), ct); user.Password = await HashPasswordAsync(user, password, ct);
await db.SaveChangesAsync(ct); await db.SaveChangesAsync(ct);
} }
@ -160,10 +174,7 @@ public class AuthService(
throw new FoxnounsError("Password for user supplied to ValidatePasswordAsync was null"); throw new FoxnounsError("Password for user supplied to ValidatePasswordAsync was null");
} }
var pwResult = await Task.Run( PasswordVerificationResult pwResult = await VerifyHashedPasswordAsync(user, password, ct);
() => _passwordHasher.VerifyHashedPassword(user, user.Password!, password),
ct
);
return pwResult return pwResult
is PasswordVerificationResult.SuccessRehashNeeded is PasswordVerificationResult.SuccessRehashNeeded
or PasswordVerificationResult.Success; or PasswordVerificationResult.Success;
@ -178,7 +189,7 @@ public class AuthService(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
user.Password = await Task.Run(() => _passwordHasher.HashPassword(user, password), ct); user.Password = await HashPasswordAsync(user, password, ct);
db.Update(user); db.Update(user);
} }
@ -225,13 +236,15 @@ public class AuthService(
AssertValidAuthType(authType, app); AssertValidAuthType(authType, app);
// This is already checked when // This is already checked when
var currentCount = await db int currentCount = await db
.AuthMethods.Where(m => m.UserId == userId && m.AuthType == authType) .AuthMethods.Where(m => m.UserId == userId && m.AuthType == authType)
.CountAsync(ct); .CountAsync(ct);
if (currentCount >= AuthUtils.MaxAuthMethodsPerType) if (currentCount >= AuthUtils.MaxAuthMethodsPerType)
{
throw new ApiError.BadRequest( throw new ApiError.BadRequest(
"Too many linked accounts of this type, maximum of 3 per account." "Too many linked accounts of this type, maximum of 3 per account."
); );
}
var authMethod = new AuthMethod var authMethod = new AuthMethod
{ {
@ -256,13 +269,15 @@ public class AuthService(
) )
{ {
if (!AuthUtils.ValidateScopes(application, scopes)) if (!AuthUtils.ValidateScopes(application, scopes))
{
throw new ApiError.BadRequest( throw new ApiError.BadRequest(
"Invalid scopes requested for this token", "Invalid scopes requested for this token",
"scopes", "scopes",
scopes scopes
); );
}
var (token, hash) = GenerateToken(); (string? token, byte[]? hash) = GenerateToken();
return ( return (
token, token,
new Token new Token
@ -287,9 +302,9 @@ public class AuthService(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var frontendApp = await db.GetFrontendApplicationAsync(ct); Application frontendApp = await db.GetFrontendApplicationAsync(ct);
var (tokenStr, token) = GenerateToken( (string? tokenStr, Token? token) = GenerateToken(
user, user,
frontendApp, frontendApp,
["*"], ["*"],
@ -302,24 +317,35 @@ public class AuthService(
await db.SaveChangesAsync(ct); await db.SaveChangesAsync(ct);
return new CallbackResponse( return new CallbackResponse(
HasAccount: true, true,
Ticket: null, null,
RemoteUsername: null, null,
User: await userRenderer.RenderUserAsync( await userRenderer.RenderUserAsync(user, user, renderMembers: false, ct: ct),
user, tokenStr,
selfUser: user, token.ExpiresAt
renderMembers: false,
ct: ct
),
Token: tokenStr,
ExpiresAt: token.ExpiresAt
); );
} }
private Task<string> HashPasswordAsync(
User user,
string password,
CancellationToken ct = default
) => Task.Run(() => _passwordHasher.HashPassword(user, password), ct);
private Task<PasswordVerificationResult> VerifyHashedPasswordAsync(
User user,
string providedPassword,
CancellationToken ct = default
) =>
Task.Run(
() => _passwordHasher.VerifyHashedPassword(user, user.Password!, providedPassword),
ct
);
private static (string, byte[]) GenerateToken() private static (string, byte[]) GenerateToken()
{ {
var token = AuthUtils.RandomToken(); string token = AuthUtils.RandomToken();
var hash = SHA512.HashData(Convert.FromBase64String(token)); byte[] hash = SHA512.HashData(Convert.FromBase64String(token));
return (token, hash); return (token, hash);
} }

View file

@ -18,22 +18,25 @@ public partial class FediverseAuthService
Snowflake? existingAppId = null Snowflake? existingAppId = null
) )
{ {
var resp = await _client.PostAsJsonAsync( HttpResponseMessage resp = await _client.PostAsJsonAsync(
$"https://{instance}/api/v1/apps", $"https://{instance}/api/v1/apps",
new CreateMastodonApplicationRequest( new CreateMastodonApplicationRequest(
ClientName: $"pronouns.cc (+{_config.BaseUrl})", $"pronouns.cc (+{_config.BaseUrl})",
RedirectUris: MastodonRedirectUri(instance), MastodonRedirectUri(instance),
Scopes: "read read:accounts", "read read:accounts",
Website: _config.BaseUrl _config.BaseUrl
) )
); );
resp.EnsureSuccessStatusCode(); resp.EnsureSuccessStatusCode();
var mastodonApp = await resp.Content.ReadFromJsonAsync<PartialMastodonApplication>(); PartialMastodonApplication? mastodonApp =
await resp.Content.ReadFromJsonAsync<PartialMastodonApplication>();
if (mastodonApp == null) if (mastodonApp == null)
{
throw new FoxnounsError( throw new FoxnounsError(
$"Application created on Mastodon-compatible instance {instance} was null" $"Application created on Mastodon-compatible instance {instance} was null"
); );
}
FediverseApplication app; FediverseApplication app;
@ -75,7 +78,7 @@ public partial class FediverseAuthService
if (state != null) if (state != null)
await _keyCacheService.ValidateAuthStateAsync(state); await _keyCacheService.ValidateAuthStateAsync(state);
var tokenResp = await _client.PostAsync( HttpResponseMessage tokenResp = await _client.PostAsync(
MastodonTokenUri(app.Domain), MastodonTokenUri(app.Domain),
new FormUrlEncodedContent( new FormUrlEncodedContent(
new Dictionary<string, string> new Dictionary<string, string>
@ -95,7 +98,7 @@ public partial class FediverseAuthService
} }
tokenResp.EnsureSuccessStatusCode(); tokenResp.EnsureSuccessStatusCode();
var token = ( string? token = (
await tokenResp.Content.ReadFromJsonAsync<MastodonTokenResponse>() await tokenResp.Content.ReadFromJsonAsync<MastodonTokenResponse>()
)?.AccessToken; )?.AccessToken;
if (token == null) if (token == null)
@ -106,9 +109,9 @@ public partial class FediverseAuthService
var req = new HttpRequestMessage(HttpMethod.Get, MastodonCurrentUserUri(app.Domain)); var req = new HttpRequestMessage(HttpMethod.Get, MastodonCurrentUserUri(app.Domain));
req.Headers.Add("Authorization", $"Bearer {token}"); req.Headers.Add("Authorization", $"Bearer {token}");
var currentUserResp = await _client.SendAsync(req); HttpResponseMessage currentUserResp = await _client.SendAsync(req);
currentUserResp.EnsureSuccessStatusCode(); currentUserResp.EnsureSuccessStatusCode();
var user = await currentUserResp.Content.ReadFromJsonAsync<FediverseUser>(); FediverseUser? user = await currentUserResp.Content.ReadFromJsonAsync<FediverseUser>();
if (user == null) if (user == null)
{ {
throw new FoxnounsError($"User response from instance {app.Domain} was invalid"); throw new FoxnounsError($"User response from instance {app.Domain} was invalid");
@ -131,7 +134,7 @@ public partial class FediverseAuthService
"An app credentials refresh was requested for {ApplicationId}, creating a new application", "An app credentials refresh was requested for {ApplicationId}, creating a new application",
app.Id app.Id
); );
app = await CreateMastodonApplicationAsync(app.Domain, existingAppId: app.Id); app = await CreateMastodonApplicationAsync(app.Domain, app.Id);
} }
state ??= HttpUtility.UrlEncode(await _keyCacheService.GenerateAuthStateAsync()); state ??= HttpUtility.UrlEncode(await _keyCacheService.GenerateAuthStateAsync());

View file

@ -43,7 +43,7 @@ public partial class FediverseAuthService
string? state = null string? state = null
) )
{ {
var app = await GetApplicationAsync(instance); FediverseApplication app = await GetApplicationAsync(instance);
return await GenerateAuthUrlAsync(app, forceRefresh, state); return await GenerateAuthUrlAsync(app, forceRefresh, state);
} }
@ -56,13 +56,15 @@ public partial class FediverseAuthService
public async Task<FediverseApplication> GetApplicationAsync(string instance) public async Task<FediverseApplication> GetApplicationAsync(string instance)
{ {
var app = await _db.FediverseApplications.FirstOrDefaultAsync(a => a.Domain == instance); FediverseApplication? app = await _db.FediverseApplications.FirstOrDefaultAsync(a =>
a.Domain == instance
);
if (app != null) if (app != null)
return app; return app;
_logger.Debug("No application for fediverse instance {Instance}, creating it", instance); _logger.Debug("No application for fediverse instance {Instance}, creating it", instance);
var softwareName = await GetSoftwareNameAsync(instance); string softwareName = await GetSoftwareNameAsync(instance);
if (IsMastodonCompatible(softwareName)) if (IsMastodonCompatible(softwareName))
{ {
@ -76,13 +78,14 @@ public partial class FediverseAuthService
{ {
_logger.Debug("Requesting software name for fediverse instance {Instance}", instance); _logger.Debug("Requesting software name for fediverse instance {Instance}", instance);
var wellKnownResp = await _client.GetAsync( HttpResponseMessage wellKnownResp = await _client.GetAsync(
new Uri($"https://{instance}/.well-known/nodeinfo") new Uri($"https://{instance}/.well-known/nodeinfo")
); );
wellKnownResp.EnsureSuccessStatusCode(); wellKnownResp.EnsureSuccessStatusCode();
var wellKnown = await wellKnownResp.Content.ReadFromJsonAsync<WellKnownResponse>(); WellKnownResponse? wellKnown =
var nodeInfoUrl = wellKnown?.Links.FirstOrDefault(l => l.Rel == NodeInfoRel)?.Href; await wellKnownResp.Content.ReadFromJsonAsync<WellKnownResponse>();
string? nodeInfoUrl = wellKnown?.Links.FirstOrDefault(l => l.Rel == NodeInfoRel)?.Href;
if (nodeInfoUrl == null) if (nodeInfoUrl == null)
{ {
throw new FoxnounsError( throw new FoxnounsError(
@ -90,10 +93,10 @@ public partial class FediverseAuthService
); );
} }
var nodeInfoResp = await _client.GetAsync(nodeInfoUrl); HttpResponseMessage nodeInfoResp = await _client.GetAsync(nodeInfoUrl);
nodeInfoResp.EnsureSuccessStatusCode(); nodeInfoResp.EnsureSuccessStatusCode();
var nodeInfo = await nodeInfoResp.Content.ReadFromJsonAsync<PartialNodeInfo>(); PartialNodeInfo? nodeInfo = await nodeInfoResp.Content.ReadFromJsonAsync<PartialNodeInfo>();
return nodeInfo?.Software.Name return nodeInfo?.Software.Name
?? throw new FoxnounsError( ?? throw new FoxnounsError(
$"Nodeinfo response for instance {instance} was invalid, no software name" $"Nodeinfo response for instance {instance} was invalid, no software name"

View file

@ -29,7 +29,7 @@ public class RemoteAuthService(
) )
{ {
var redirectUri = $"{config.BaseUrl}/auth/callback/discord"; var redirectUri = $"{config.BaseUrl}/auth/callback/discord";
var resp = await _httpClient.PostAsync( HttpResponseMessage resp = await _httpClient.PostAsync(
_discordTokenUri, _discordTokenUri,
new FormUrlEncodedContent( new FormUrlEncodedContent(
new Dictionary<string, string> new Dictionary<string, string>
@ -45,7 +45,7 @@ public class RemoteAuthService(
); );
if (!resp.IsSuccessStatusCode) if (!resp.IsSuccessStatusCode)
{ {
var respBody = await resp.Content.ReadAsStringAsync(ct); string respBody = await resp.Content.ReadAsStringAsync(ct);
_logger.Error( _logger.Error(
"Received error status {StatusCode} when exchanging OAuth token: {ErrorBody}", "Received error status {StatusCode} when exchanging OAuth token: {ErrorBody}",
(int)resp.StatusCode, (int)resp.StatusCode,
@ -55,16 +55,18 @@ public class RemoteAuthService(
} }
resp.EnsureSuccessStatusCode(); resp.EnsureSuccessStatusCode();
var token = await resp.Content.ReadFromJsonAsync<DiscordTokenResponse>(ct); DiscordTokenResponse? token = await resp.Content.ReadFromJsonAsync<DiscordTokenResponse>(
ct
);
if (token == null) if (token == null)
throw new FoxnounsError("Discord token response was null"); throw new FoxnounsError("Discord token response was null");
var req = new HttpRequestMessage(HttpMethod.Get, _discordUserUri); var req = new HttpRequestMessage(HttpMethod.Get, _discordUserUri);
req.Headers.Add("Authorization", $"{token.token_type} {token.access_token}"); req.Headers.Add("Authorization", $"{token.token_type} {token.access_token}");
var resp2 = await _httpClient.SendAsync(req, ct); HttpResponseMessage resp2 = await _httpClient.SendAsync(req, ct);
resp2.EnsureSuccessStatusCode(); resp2.EnsureSuccessStatusCode();
var user = await resp2.Content.ReadFromJsonAsync<DiscordUserResponse>(ct); DiscordUserResponse? user = await resp2.Content.ReadFromJsonAsync<DiscordUserResponse>(ct);
if (user == null) if (user == null)
throw new FoxnounsError("Discord user response was null"); throw new FoxnounsError("Discord user response was null");
@ -104,7 +106,7 @@ public class RemoteAuthService(
string? instance = null string? instance = null
) )
{ {
var existingAccounts = await db int existingAccounts = await db
.AuthMethods.Where(m => m.UserId == userId && m.AuthType == authType) .AuthMethods.Where(m => m.UserId == userId && m.AuthType == authType)
.CountAsync(); .CountAsync();
if (existingAccounts > AuthUtils.MaxAuthMethodsPerType) if (existingAccounts > AuthUtils.MaxAuthMethodsPerType)
@ -131,13 +133,17 @@ public class RemoteAuthService(
string? instance = null string? instance = null
) )
{ {
var accountState = await keyCacheService.GetAddExtraAccountStateAsync(state); AddExtraAccountState? accountState = await keyCacheService.GetAddExtraAccountStateAsync(
state
);
if ( if (
accountState == null accountState == null
|| accountState.AuthType != authType || accountState.AuthType != authType
|| accountState.UserId != userId || accountState.UserId != userId
|| (instance != null && accountState.Instance != instance) || (instance != null && accountState.Instance != instance)
) )
{
throw new ApiError.BadRequest("Invalid state", "state", state); throw new ApiError.BadRequest("Invalid state", "state", state);
} }
} }
}

View file

@ -28,9 +28,9 @@ public class DataCleanupService(
private async Task CleanUsersAsync(CancellationToken ct = default) private async Task CleanUsersAsync(CancellationToken ct = default)
{ {
var selfDeleteExpires = clock.GetCurrentInstant() - User.DeleteAfter; Instant selfDeleteExpires = clock.GetCurrentInstant() - User.DeleteAfter;
var suspendExpires = clock.GetCurrentInstant() - User.DeleteSuspendedAfter; Instant suspendExpires = clock.GetCurrentInstant() - User.DeleteSuspendedAfter;
var users = await db List<User> users = await db
.Users.Include(u => u.Members) .Users.Include(u => u.Members)
.Include(u => u.DataExports) .Include(u => u.DataExports)
.Where(u => .Where(u =>
@ -92,13 +92,15 @@ public class DataCleanupService(
private async Task CleanExportsAsync(CancellationToken ct = default) private async Task CleanExportsAsync(CancellationToken ct = default)
{ {
var minExpiredId = Snowflake.FromInstant(clock.GetCurrentInstant() - DataExport.Expiration); var minExpiredId = Snowflake.FromInstant(clock.GetCurrentInstant() - DataExport.Expiration);
var exports = await db.DataExports.Where(d => d.Id < minExpiredId).ToListAsync(ct); List<DataExport> exports = await db
.DataExports.Where(d => d.Id < minExpiredId)
.ToListAsync(ct);
if (exports.Count == 0) if (exports.Count == 0)
return; return;
_logger.Debug("Deleting {Count} expired exports", exports.Count); _logger.Debug("Deleting {Count} expired exports", exports.Count);
foreach (var export in exports) foreach (DataExport? export in exports)
{ {
_logger.Debug("Deleting export {ExportId}", export.Id); _logger.Debug("Deleting export {ExportId}", export.Id);
await objectStorageService.RemoveObjectAsync( await objectStorageService.RemoveObjectAsync(

View file

@ -41,7 +41,7 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
CancellationToken ct = default CancellationToken ct = default
) )
{ {
var value = await db.TemporaryKeys.FirstOrDefaultAsync(k => k.Key == key, ct); TemporaryKey? value = await db.TemporaryKeys.FirstOrDefaultAsync(k => k.Key == key, ct);
if (value == null) if (value == null)
return null; return null;
@ -56,7 +56,7 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
public async Task DeleteExpiredKeysAsync(CancellationToken ct) public async Task DeleteExpiredKeysAsync(CancellationToken ct)
{ {
var count = await db int count = await db
.TemporaryKeys.Where(k => k.Expires < clock.GetCurrentInstant()) .TemporaryKeys.Where(k => k.Expires < clock.GetCurrentInstant())
.ExecuteDeleteAsync(ct); .ExecuteDeleteAsync(ct);
if (count != 0) if (count != 0)
@ -79,7 +79,7 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
) )
where T : class where T : class
{ {
var value = JsonConvert.SerializeObject(obj); string value = JsonConvert.SerializeObject(obj);
await SetKeyAsync(key, value, expires, ct); await SetKeyAsync(key, value, expires, ct);
} }
@ -90,7 +90,7 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
) )
where T : class where T : class
{ {
var value = await GetKeyAsync(key, delete, ct); string? value = await GetKeyAsync(key, delete, ct);
return value == null ? default : JsonConvert.DeserializeObject<T>(value); return value == null ? default : JsonConvert.DeserializeObject<T>(value);
} }
} }

View file

@ -10,11 +10,11 @@ public class MemberRendererService(DatabaseContext db, Config config)
{ {
public async Task<IEnumerable<PartialMember>> RenderUserMembersAsync(User user, Token? token) public async Task<IEnumerable<PartialMember>> RenderUserMembersAsync(User user, Token? token)
{ {
var canReadHiddenMembers = bool canReadHiddenMembers =
token != null && token.UserId == user.Id && token.HasScope("member.read"); token != null && token.UserId == user.Id && token.HasScope("member.read");
var renderUnlisted = bool renderUnlisted =
token != null && token.UserId == user.Id && token.HasScope("user.read_hidden"); token != null && token.UserId == user.Id && token.HasScope("user.read_hidden");
var canReadMemberList = !user.ListHidden || canReadHiddenMembers; bool canReadMemberList = !user.ListHidden || canReadHiddenMembers;
IEnumerable<Member> members = canReadMemberList IEnumerable<Member> members = canReadMemberList
? await db.Members.Where(m => m.UserId == user.Id).OrderBy(m => m.Name).ToListAsync() ? await db.Members.Where(m => m.UserId == user.Id).OrderBy(m => m.Name).ToListAsync()
@ -30,7 +30,7 @@ public class MemberRendererService(DatabaseContext db, Config config)
string? overrideSid = null string? overrideSid = null
) )
{ {
var renderUnlisted = token?.UserId == member.UserId && token.HasScope("user.read_hidden"); bool renderUnlisted = token?.UserId == member.UserId && token.HasScope("user.read_hidden");
return new MemberResponse( return new MemberResponse(
member.Id, member.Id,

View file

@ -3,6 +3,7 @@ using Foxnouns.Backend.Database;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using NodaTime; using NodaTime;
using Prometheus; using Prometheus;
using ITimer = Prometheus.ITimer;
namespace Foxnouns.Backend.Services; namespace Foxnouns.Backend.Services;
@ -16,19 +17,23 @@ public class MetricsCollectionService(ILogger logger, IServiceProvider services,
public async Task CollectMetricsAsync(CancellationToken ct = default) public async Task CollectMetricsAsync(CancellationToken ct = default)
{ {
var timer = FoxnounsMetrics.MetricsCollectionTime.NewTimer(); ITimer timer = FoxnounsMetrics.MetricsCollectionTime.NewTimer();
var now = clock.GetCurrentInstant(); Instant now = clock.GetCurrentInstant();
await using var scope = services.CreateAsyncScope(); await using AsyncServiceScope scope = services.CreateAsyncScope();
// ReSharper disable once SuggestVarOrType_SimpleTypes
await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
var users = await db.Users.Where(u => !u.Deleted).Select(u => u.LastActive).ToListAsync(ct); List<Instant>? users = await db
.Users.Where(u => !u.Deleted)
.Select(u => u.LastActive)
.ToListAsync(ct);
FoxnounsMetrics.UsersCount.Set(users.Count); FoxnounsMetrics.UsersCount.Set(users.Count);
FoxnounsMetrics.UsersActiveMonthCount.Set(users.Count(i => i > now - Month)); FoxnounsMetrics.UsersActiveMonthCount.Set(users.Count(i => i > now - Month));
FoxnounsMetrics.UsersActiveWeekCount.Set(users.Count(i => i > now - Week)); FoxnounsMetrics.UsersActiveWeekCount.Set(users.Count(i => i > now - Week));
FoxnounsMetrics.UsersActiveDayCount.Set(users.Count(i => i > now - Day)); FoxnounsMetrics.UsersActiveDayCount.Set(users.Count(i => i > now - Day));
var memberCount = await db int memberCount = await db
.Members.Include(m => m.User) .Members.Include(m => m.User)
.Where(m => !m.Unlisted && !m.User.ListHidden && !m.User.Deleted) .Where(m => !m.Unlisted && !m.User.ListHidden && !m.User.Deleted)
.CountAsync(ct); .CountAsync(ct);

View file

@ -1,4 +1,5 @@
using Minio; using Minio;
using Minio.DataModel;
using Minio.DataModel.Args; using Minio.DataModel.Args;
using Minio.Exceptions; using Minio.Exceptions;
@ -48,13 +49,4 @@ public class ObjectStorageService(ILogger logger, Config config, IMinioClient mi
ct ct
); );
} }
public async Task GetObjectAsync(string path, CancellationToken ct = default)
{
var stream = new MemoryStream();
var resp = await minioClient.GetObjectAsync(
new GetObjectArgs().WithBucket(config.Storage.Bucket).WithObject(path),
ct
);
}
} }

View file

@ -15,10 +15,13 @@ public class PeriodicTasksService(ILogger logger, IServiceProvider services) : B
{ {
_logger.Debug("Running periodic tasks"); _logger.Debug("Running periodic tasks");
await using var scope = services.CreateAsyncScope(); await using AsyncServiceScope scope = services.CreateAsyncScope();
// The type is literally written on the same line, we can just use `var`
// ReSharper disable SuggestVarOrType_SimpleTypes
var keyCacheService = scope.ServiceProvider.GetRequiredService<KeyCacheService>(); var keyCacheService = scope.ServiceProvider.GetRequiredService<KeyCacheService>();
var dataCleanupService = scope.ServiceProvider.GetRequiredService<DataCleanupService>(); var dataCleanupService = scope.ServiceProvider.GetRequiredService<DataCleanupService>();
// ReSharper restore SuggestVarOrType_SimpleTypes
await keyCacheService.DeleteExpiredKeysAsync(ct); await keyCacheService.DeleteExpiredKeysAsync(ct);
await dataCleanupService.InvokeAsync(ct); await dataCleanupService.InvokeAsync(ct);

View file

@ -43,9 +43,9 @@ public class UserRendererService(
) )
{ {
scopes = scopes.ExpandScopes(); scopes = scopes.ExpandScopes();
var tokenCanReadHiddenMembers = scopes.Contains("member.read") && isSelfUser; bool tokenCanReadHiddenMembers = scopes.Contains("member.read") && isSelfUser;
var tokenHidden = scopes.Contains("user.read_hidden") && isSelfUser; bool tokenHidden = scopes.Contains("user.read_hidden") && isSelfUser;
var tokenPrivileged = scopes.Contains("user.read_privileged") && isSelfUser; bool tokenPrivileged = scopes.Contains("user.read_privileged") && isSelfUser;
renderMembers = renderMembers && (!user.ListHidden || tokenCanReadHiddenMembers); renderMembers = renderMembers && (!user.ListHidden || tokenCanReadHiddenMembers);
renderAuthMethods = renderAuthMethods && tokenPrivileged; renderAuthMethods = renderAuthMethods && tokenPrivileged;
@ -57,12 +57,12 @@ public class UserRendererService(
if (!(isSelfUser && tokenCanReadHiddenMembers)) if (!(isSelfUser && tokenCanReadHiddenMembers))
members = members.Where(m => !m.Unlisted); members = members.Where(m => !m.Unlisted);
var flags = await db List<UserFlag> flags = await db
.UserFlags.Where(f => f.UserId == user.Id) .UserFlags.Where(f => f.UserId == user.Id)
.OrderBy(f => f.Id) .OrderBy(f => f.Id)
.ToListAsync(ct); .ToListAsync(ct);
var authMethods = renderAuthMethods List<AuthMethod> authMethods = renderAuthMethods
? await db ? await db
.AuthMethods.Where(a => a.UserId == user.Id) .AuthMethods.Where(a => a.UserId == user.Id)
.Include(a => a.FediverseApplication) .Include(a => a.FediverseApplication)
@ -72,9 +72,11 @@ public class UserRendererService(
int? utcOffset = null; int? utcOffset = null;
if ( if (
user.Timezone != null user.Timezone != null
&& TimeZoneInfo.TryFindSystemTimeZoneById(user.Timezone, out var tz) && TimeZoneInfo.TryFindSystemTimeZoneById(user.Timezone, out TimeZoneInfo? tz)
) )
{
utcOffset = (int)tz.GetUtcOffset(DateTimeOffset.UtcNow).TotalSeconds; utcOffset = (int)tz.GetUtcOffset(DateTimeOffset.UtcNow).TotalSeconds;
}
return new UserResponse( return new UserResponse(
user.Id, user.Id,

View file

@ -69,8 +69,8 @@ public static class AuthUtils
public static bool ValidateScopes(Application application, string[] scopes) public static bool ValidateScopes(Application application, string[] scopes)
{ {
var expandedScopes = scopes.ExpandScopes(); string[] expandedScopes = scopes.ExpandScopes();
var appScopes = application.Scopes.ExpandAppScopes(); string[] appScopes = application.Scopes.ExpandAppScopes();
return !expandedScopes.Except(appScopes).Any(); return !expandedScopes.Except(appScopes).Any();
} }
@ -78,7 +78,7 @@ public static class AuthUtils
{ {
try try
{ {
var scheme = new Uri(uri).Scheme; string scheme = new Uri(uri).Scheme;
return !ForbiddenSchemes.Contains(scheme); return !ForbiddenSchemes.Contains(scheme);
} }
catch catch

View file

@ -5,10 +5,11 @@ using Newtonsoft.Json.Serialization;
namespace Foxnouns.Backend.Utils; namespace Foxnouns.Backend.Utils;
/// <summary> /// <summary>
/// A base class used for PATCH requests which stores information on whether a key is explicitly set to null or not passed at all. /// <para>A base class used for PATCH requests which stores information on whether a key is explicitly set to null or not passed at all.</para>
/// /// <para>
/// HasProperty() should not be used for properties that cannot be set to null--a null value should be treated /// HasProperty() should not be used for properties that cannot be set to null--a null value should be treated
/// as an unset value in those cases. /// as an unset value in those cases.
/// </para>
/// </summary> /// </summary>
public abstract class PatchRequest public abstract class PatchRequest
{ {
@ -30,7 +31,7 @@ public class PatchRequestContractResolver : DefaultContractResolver
MemberSerialization memberSerialization MemberSerialization memberSerialization
) )
{ {
var prop = base.CreateProperty(member, memberSerialization); JsonProperty prop = base.CreateProperty(member, memberSerialization);
prop.SetIsSpecified += (o, _) => prop.SetIsSpecified += (o, _) =>
{ {

View file

@ -39,6 +39,7 @@ public static partial class ValidationUtils
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (fields.Count > 25) if (fields.Count > 25)
{
errors.Add( errors.Add(
( (
"fields", "fields",
@ -50,11 +51,13 @@ public static partial class ValidationUtils
) )
) )
); );
}
// No overwhelming this function, thank you // No overwhelming this function, thank you
if (fields.Count > 100) if (fields.Count > 100)
return errors; return errors;
foreach (var (field, index) in fields.Select((field, index) => (field, index))) foreach ((Field? field, int index) in fields.Select((field, index) => (field, index)))
{ {
switch (field.Name.Length) switch (field.Name.Length)
{ {
@ -111,6 +114,7 @@ public static partial class ValidationUtils
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (entries.Length > Limits.FieldEntriesLimit) if (entries.Length > Limits.FieldEntriesLimit)
{
errors.Add( errors.Add(
( (
errorPrefix, errorPrefix,
@ -122,15 +126,19 @@ public static partial class ValidationUtils
) )
) )
); );
}
// Same as above, no overwhelming this function with a ridiculous amount of entries // Same as above, no overwhelming this function with a ridiculous amount of entries
if (entries.Length > Limits.FieldEntriesLimit + 50) if (entries.Length > Limits.FieldEntriesLimit + 50)
return errors; return errors;
var customPreferenceIds = string[] customPreferenceIds = customPreferences.Keys.Select(id => id.ToString()).ToArray();
customPreferences?.Keys.Select(id => id.ToString()).ToArray() ?? [];
foreach (var (entry, entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx))) foreach (
(FieldEntry? entry, int entryIdx) in entries.Select(
(entry, entryIdx) => (entry, entryIdx)
)
)
{ {
switch (entry.Value.Length) switch (entry.Value.Length)
{ {
@ -166,6 +174,7 @@ public static partial class ValidationUtils
!DefaultStatusOptions.Contains(entry.Status) !DefaultStatusOptions.Contains(entry.Status)
&& !customPreferenceIds.Contains(entry.Status) && !customPreferenceIds.Contains(entry.Status)
) )
{
errors.Add( errors.Add(
( (
$"{errorPrefix}.{entryIdx}.status", $"{errorPrefix}.{entryIdx}.status",
@ -173,6 +182,7 @@ public static partial class ValidationUtils
) )
); );
} }
}
return errors; return errors;
} }
@ -188,6 +198,7 @@ public static partial class ValidationUtils
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (entries.Length > Limits.FieldEntriesLimit) if (entries.Length > Limits.FieldEntriesLimit)
{
errors.Add( errors.Add(
( (
errorPrefix, errorPrefix,
@ -199,15 +210,17 @@ public static partial class ValidationUtils
) )
) )
); );
}
// Same as above, no overwhelming this function with a ridiculous amount of entries // Same as above, no overwhelming this function with a ridiculous amount of entries
if (entries.Length > Limits.FieldEntriesLimit + 50) if (entries.Length > Limits.FieldEntriesLimit + 50)
return errors; return errors;
var customPreferenceIds = string[] customPreferenceIds = customPreferences.Keys.Select(id => id.ToString()).ToArray();
customPreferences?.Keys.Select(id => id.ToString()).ToList() ?? [];
foreach (var (entry, entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx))) foreach (
(Pronoun? entry, int entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx))
)
{ {
switch (entry.Value.Length) switch (entry.Value.Length)
{ {
@ -276,6 +289,7 @@ public static partial class ValidationUtils
!DefaultStatusOptions.Contains(entry.Status) !DefaultStatusOptions.Contains(entry.Status)
&& !customPreferenceIds.Contains(entry.Status) && !customPreferenceIds.Contains(entry.Status)
) )
{
errors.Add( errors.Add(
( (
$"{errorPrefix}.{entryIdx}.status", $"{errorPrefix}.{entryIdx}.status",
@ -283,6 +297,7 @@ public static partial class ValidationUtils
) )
); );
} }
}
return errors; return errors;
} }

View file

@ -29,6 +29,7 @@ public static partial class ValidationUtils
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (preferences.Count > MaxCustomPreferences) if (preferences.Count > MaxCustomPreferences)
{
errors.Add( errors.Add(
( (
"custom_preferences", "custom_preferences",
@ -40,20 +41,29 @@ public static partial class ValidationUtils
) )
) )
); );
}
if (preferences.Count > 50) if (preferences.Count > 50)
return errors; return errors;
foreach (var (p, i) in preferences.Select((p, i) => (p, i))) foreach (
(UsersController.CustomPreferenceUpdate? p, int i) in preferences.Select(
(p, i) => (p, i)
)
)
{ {
if (!BootstrapIcons.IsValid(p.Icon)) if (!BootstrapIcons.IsValid(p.Icon))
{
errors.Add( errors.Add(
( (
$"custom_preferences.{i}.icon", $"custom_preferences.{i}.icon",
ValidationError.DisallowedValueError("Invalid icon name", [], p.Icon) ValidationError.DisallowedValueError("Invalid icon name", [], p.Icon)
) )
); );
}
if (p.Tooltip.Length is 1 or > MaxPreferenceTooltipLength) if (p.Tooltip.Length is 1 or > MaxPreferenceTooltipLength)
{
errors.Add( errors.Add(
( (
$"custom_preferences.{i}.tooltip", $"custom_preferences.{i}.tooltip",
@ -66,6 +76,7 @@ public static partial class ValidationUtils
) )
); );
} }
}
return errors; return errors;
} }

View file

@ -46,6 +46,7 @@ public static partial class ValidationUtils
public static ValidationError? ValidateUsername(string username) public static ValidationError? ValidateUsername(string username)
{ {
if (!UsernameRegex().IsMatch(username)) if (!UsernameRegex().IsMatch(username))
{
return username.Length switch return username.Length switch
{ {
< 2 => ValidationError.LengthError("Username is too short", 2, 40, username.Length), < 2 => ValidationError.LengthError("Username is too short", 2, 40, username.Length),
@ -55,19 +56,24 @@ public static partial class ValidationUtils
username username
), ),
}; };
}
if ( if (
InvalidUsernames.Any(u => InvalidUsernames.Any(u =>
string.Equals(u, username, StringComparison.InvariantCultureIgnoreCase) string.Equals(u, username, StringComparison.InvariantCultureIgnoreCase)
) )
) )
{
return ValidationError.GenericValidationError("Username is not allowed", username); return ValidationError.GenericValidationError("Username is not allowed", username);
}
return null; return null;
} }
public static ValidationError? ValidateMemberName(string memberName) public static ValidationError? ValidateMemberName(string memberName)
{ {
if (!MemberRegex().IsMatch(memberName)) if (!MemberRegex().IsMatch(memberName))
{
return memberName.Length switch return memberName.Length switch
{ {
< 1 => ValidationError.LengthError("Name is too short", 1, 100, memberName.Length), < 1 => ValidationError.LengthError("Name is too short", 1, 100, memberName.Length),
@ -79,13 +85,17 @@ public static partial class ValidationUtils
memberName memberName
), ),
}; };
}
if ( if (
InvalidMemberNames.Any(u => InvalidMemberNames.Any(u =>
string.Equals(u, memberName, StringComparison.InvariantCultureIgnoreCase) string.Equals(u, memberName, StringComparison.InvariantCultureIgnoreCase)
) )
) )
{
return ValidationError.GenericValidationError("Name is not allowed", memberName); return ValidationError.GenericValidationError("Name is not allowed", memberName);
}
return null; return null;
} }
@ -117,13 +127,15 @@ public static partial class ValidationUtils
if (links == null) if (links == null)
return []; return [];
if (links.Length > MaxLinks) if (links.Length > MaxLinks)
{
return return
[ [
("links", ValidationError.LengthError("Too many links", 0, MaxLinks, links.Length)), ("links", ValidationError.LengthError("Too many links", 0, MaxLinks, links.Length)),
]; ];
}
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
foreach (var (link, idx) in links.Select((l, i) => (l, i))) foreach ((string link, int idx) in links.Select((l, i) => (l, i)))
{ {
switch (link.Length) switch (link.Length)
{ {
@ -185,6 +197,27 @@ public static partial class ValidationUtils
}; };
} }
public const int MinimumPasswordLength = 12;
public const int MaximumPasswordLength = 1024;
public static ValidationError? ValidatePassword(string password) =>
password.Length switch
{
< MinimumPasswordLength => ValidationError.LengthError(
"Password is too short",
MinimumPasswordLength,
MaximumPasswordLength,
password.Length
),
> MaximumPasswordLength => ValidationError.LengthError(
"Password is too long",
MinimumPasswordLength,
MaximumPasswordLength,
password.Length
),
_ => null,
};
[GeneratedRegex(@"^[a-zA-Z_0-9\-\.]{2,40}$", RegexOptions.IgnoreCase, "en-NL")] [GeneratedRegex(@"^[a-zA-Z_0-9\-\.]{2,40}$", RegexOptions.IgnoreCase, "en-NL")]
private static partial Regex UsernameRegex(); private static partial Regex UsernameRegex();

View file

@ -12,9 +12,9 @@ public static partial class ValidationUtils
return; return;
var errorDict = new Dictionary<string, IEnumerable<ValidationError>>(); var errorDict = new Dictionary<string, IEnumerable<ValidationError>>();
foreach (var error in errors) foreach ((string, ValidationError?) error in errors)
{ {
if (errorDict.TryGetValue(error.Item1, out var value)) if (errorDict.TryGetValue(error.Item1, out IEnumerable<ValidationError>? value))
errorDict[error.Item1] = value.Append(error.Item2!); errorDict[error.Item1] = value.Append(error.Item2!);
errorDict.Add(error.Item1, [error.Item2!]); errorDict.Add(error.Item1, [error.Item2!]);
} }

View file

@ -3,7 +3,7 @@
<p> <p>
Please continue creating a new pronouns.cc account by using the following link: Please continue creating a new pronouns.cc account by using the following link:
<br /> <br />
<a href="@Model.BaseUrl/auth/signup/confirm/@Model.Code">Confirm your email address</a> <a href="@Model.BaseUrl/auth/callback/email/@Model.Code">Confirm your email address</a>
<br /> <br />
Note that this link will expire in one hour. Note that this link will expire in one hour.
</p> </p>

View file

@ -3,7 +3,7 @@
<p> <p>
Hello @@@Model.Username, please confirm adding this email address to your account by using the following link: Hello @@@Model.Username, please confirm adding this email address to your account by using the following link:
<br /> <br />
<a href="@Model.BaseUrl/settings/auth/confirm-email/@Model.Code">Confirm your email address</a> <a href="@Model.BaseUrl/auth/callback/email/@Model.Code">Confirm your email address</a>
<br /> <br />
Note that this link will expire in one hour. Note that this link will expire in one hour.
</p> </p>

View file

@ -194,6 +194,12 @@
"prometheus-net": "8.2.1" "prometheus-net": "8.2.1"
} }
}, },
"Roslynator.Analyzers": {
"type": "Direct",
"requested": "[4.12.9, )",
"resolved": "4.12.9",
"contentHash": "X6lDpN/D5wuinq37KIx+l3GSUe9No+8bCjGBTI5sEEtxapLztkHg6gzNVhMXpXw8P+/5gFYxTXJ5Pf8O4iNz/w=="
},
"Sentry.AspNetCore": { "Sentry.AspNetCore": {
"type": "Direct", "type": "Direct",
"requested": "[4.9.0, )", "requested": "[4.9.0, )",

View file

@ -0,0 +1,69 @@
import { apiRequest } from "$api";
import ApiError, { ErrorCode } from "$api/error";
import type { AddAccountResponse, CallbackResponse } from "$api/models";
import { setToken } from "$lib";
import log from "$lib/log";
import { isRedirect, redirect, type ServerLoadEvent } from "@sveltejs/kit";
export default function createCallbackLoader(
callbackType: string,
bodyFn?: (event: ServerLoadEvent) => Promise<unknown>,
) {
return async (event: ServerLoadEvent) => {
const { url, parent, fetch, cookies } = event;
bodyFn ??= async ({ url }) => {
const code = url.searchParams.get("code") as string | null;
const state = url.searchParams.get("state") as string | null;
if (!code || !state) throw new ApiError(undefined, ErrorCode.BadRequest).obj;
return { code, state };
};
const { meUser } = await parent();
if (meUser) {
try {
const resp = await apiRequest<AddAccountResponse>(
"POST",
`/auth/${callbackType}/add-account/callback`,
{
isInternal: true,
body: await bodyFn(event),
fetch,
cookies,
},
);
return { hasAccount: true, isLinkRequest: true, newAuthMethod: resp };
} catch (e) {
if (e instanceof ApiError) return { isLinkRequest: true, error: e.obj };
log.error("error linking new %s account to user %s:", callbackType, meUser.id, e);
throw e;
}
}
try {
const resp = await apiRequest<CallbackResponse>("POST", `/auth/${callbackType}/callback`, {
body: await bodyFn(event),
isInternal: true,
fetch,
});
if (resp.has_account) {
setToken(cookies, resp.token!);
redirect(303, `/@${resp.user!.username}`);
}
return {
hasAccount: false,
isLinkRequest: false,
ticket: resp.ticket!,
remoteUser: resp.remote_username!,
};
} catch (e) {
if (isRedirect(e)) throw e;
if (e instanceof ApiError) return { isLinkRequest: false, error: e.obj };
log.error("error while requesting %s callback:", callbackType, e);
throw e;
}
};
}

View file

@ -4,8 +4,8 @@
import type { RawApiError } from "$api/error"; import type { RawApiError } from "$api/error";
import ErrorAlert from "$components/ErrorAlert.svelte"; import ErrorAlert from "$components/ErrorAlert.svelte";
type Props = { form: { error: RawApiError | null; ok: boolean } | null }; type Props = { form: { error: RawApiError | null; ok: boolean } | null; successMessage?: string };
let { form }: Props = $props(); let { form, successMessage }: Props = $props();
</script> </script>
{#if form?.error} {#if form?.error}
@ -13,6 +13,6 @@
{:else if form?.ok} {:else if form?.ok}
<p class="text-success-emphasis"> <p class="text-success-emphasis">
<Icon name="check-circle-fill" /> <Icon name="check-circle-fill" />
{$t("edit-profile.saved-changes")} {successMessage ?? $t("edit-profile.saved-changes")}
</p> </p>
{/if} {/if}

View file

@ -1,6 +1,5 @@
<script lang="ts"> <script lang="ts">
import type { RawApiError } from "$api/error"; import type { RawApiError } from "$api/error";
import { enhance } from "$app/forms";
import ErrorAlert from "$components/ErrorAlert.svelte"; import ErrorAlert from "$components/ErrorAlert.svelte";
import { t } from "$lib/i18n"; import { t } from "$lib/i18n";
import { Button, Input, Label } from "@sveltestrap/sveltestrap"; import { Button, Input, Label } from "@sveltestrap/sveltestrap";
@ -21,7 +20,7 @@
<ErrorAlert {error} /> <ErrorAlert {error} />
{/if} {/if}
<form method="POST" use:enhance> <form method="POST">
<div class="mb-3"> <div class="mb-3">
<Label>{remoteLabel}</Label> <Label>{remoteLabel}</Label>
<Input type="text" readonly value={remoteUser} /> <Input type="text" readonly value={remoteUser} />

View file

@ -48,7 +48,11 @@
"successful-link-profile-hint": "You now can close this page, or go back to your profile:", "successful-link-profile-hint": "You now can close this page, or go back to your profile:",
"successful-link-profile-link": "Go to your profile", "successful-link-profile-link": "Go to your profile",
"remote-discord-account-label": "Your Discord account", "remote-discord-account-label": "Your Discord account",
"log-in-with-fediverse-instance-placeholder": "Your instance (i.e. mastodon.social)" "log-in-with-fediverse-instance-placeholder": "Your instance (i.e. mastodon.social)",
"register-with-email": "Register with an email address",
"email-label": "Your email address",
"confirm-password-label": "Confirm password",
"register-with-email-init-success": "Success! An email has been sent to your inbox, please press the link there to continue."
}, },
"error": { "error": {
"bad-request-header": "Something was wrong with your input", "bad-request-header": "Something was wrong with your input",

View file

@ -1,63 +1,7 @@
import { apiRequest } from "$api"; import createCallbackLoader from "$lib/actions/callback";
import ApiError, { ErrorCode } from "$api/error"; import createRegisterAction from "$lib/actions/register";
import type { AddAccountResponse, CallbackResponse } from "$api/models/auth";
import { setToken } from "$lib";
import createRegisterAction from "$lib/actions/register.js";
import log from "$lib/log.js";
import { isRedirect, redirect } from "@sveltejs/kit";
export const load = async ({ url, parent, fetch, cookies }) => { export const load = createCallbackLoader("discord");
const code = url.searchParams.get("code") as string | null;
const state = url.searchParams.get("state") as string | null;
if (!code || !state) throw new ApiError(undefined, ErrorCode.BadRequest).obj;
const { meUser } = await parent();
if (meUser) {
try {
const resp = await apiRequest<AddAccountResponse>(
"POST",
"/auth/discord/add-account/callback",
{
isInternal: true,
body: { code, state },
fetch,
cookies,
},
);
return { hasAccount: true, isLinkRequest: true, newAuthMethod: resp };
} catch (e) {
if (e instanceof ApiError) return { isLinkRequest: true, error: e.obj };
log.error("error linking new discord account to user %s:", meUser.id, e);
throw e;
}
}
try {
const resp = await apiRequest<CallbackResponse>("POST", "/auth/discord/callback", {
body: { code, state },
isInternal: true,
fetch,
});
if (resp.has_account) {
setToken(cookies, resp.token!);
redirect(303, `/@${resp.user!.username}`);
}
return {
hasAccount: false,
isLinkRequest: false,
ticket: resp.ticket!,
remoteUser: resp.remote_username!,
};
} catch (e) {
if (isRedirect(e)) throw e;
if (e instanceof ApiError) return { isLinkRequest: false, error: e.obj };
log.error("error while requesting discord callback:", e);
throw e;
}
};
export const actions = { export const actions = {
default: createRegisterAction("/auth/discord/register"), default: createRegisterAction("/auth/discord/register"),

View file

@ -0,0 +1,53 @@
import { apiRequest } from "$api";
import ApiError, { ErrorCode, type RawApiError } from "$api/error";
import type { AuthResponse } from "$api/models/auth";
import { setToken } from "$lib";
import createCallbackLoader from "$lib/actions/callback";
import log from "$lib/log";
import { redirect, isRedirect } from "@sveltejs/kit";
export const load = createCallbackLoader("email", async ({ params }) => {
log.info("params:", params, "code:", params.code);
return { state: params.code! };
});
export const actions = {
default: async ({ request, fetch, cookies }) => {
const data = await request.formData();
const username = data.get("username") as string | null;
const ticket = data.get("ticket") as string | null;
const password = data.get("password") as string | null;
const password2 = data.get("confirm-password") as string | null;
if (!username || !ticket || !password || !password2)
return {
error: { message: "Bad request", code: ErrorCode.BadRequest, status: 400 } as RawApiError,
};
if (password !== password2)
return {
error: {
message: "Passwords do not match",
code: ErrorCode.BadRequest,
status: 400,
} as RawApiError,
};
try {
const resp = await apiRequest<AuthResponse>("POST", "/auth/email/register", {
body: { username, ticket, password },
isInternal: true,
fetch,
});
setToken(cookies, resp.token);
redirect(303, "/auth/welcome");
} catch (e) {
if (isRedirect(e)) throw e;
log.error("Could not sign up user with username %s:", username, e);
if (e instanceof ApiError) return { error: e.obj };
throw e;
}
},
};

View file

@ -0,0 +1,51 @@
<script lang="ts">
import Error from "$components/Error.svelte";
import ErrorAlert from "$components/ErrorAlert.svelte";
import NewAuthMethod from "$components/settings/NewAuthMethod.svelte";
import { t } from "$lib/i18n";
import { Label, Input, Button } from "@sveltestrap/sveltestrap";
import type { ActionData, PageData } from "./$types";
type Props = { data: PageData; form: ActionData };
let { data, form }: Props = $props();
</script>
<svelte:head>
<title>{$t("auth.register-with-email")} • pronouns.cc</title>
</svelte:head>
<div class="container">
{#if data.error}
<h1>{$t("auth.register-with-email")}</h1>
<Error error={data.error} />
{:else if data.isLinkRequest}
<NewAuthMethod method={data.newAuthMethod!} user={data.meUser!} />
{:else}
<h1>{$t("auth.register-with-email")}</h1>
{#if form?.error}
<ErrorAlert error={form.error} />
{/if}
<form method="POST">
<div class="mb-3">
<Label>{$t("auth.email-label")}</Label>
<Input type="text" readonly value={data.remoteUser} />
</div>
<div class="mb-3">
<Label>{$t("auth.register-username-label")}</Label>
<Input type="text" name="username" required />
</div>
<div class="mb-3">
<Label>{$t("auth.log-in-form-password-label")}</Label>
<Input type="password" name="password" required />
</div>
<div class="mb-3">
<Label>{$t("auth.confirm-password-label")}</Label>
<Input type="password" name="confirm-password" required />
</div>
<input type="hidden" name="ticket" value={data.ticket!} />
<Button color="primary" type="submit">{$t("auth.register-button")}</Button>
</form>
{/if}
</div>

View file

@ -1,64 +1,15 @@
import { apiRequest } from "$api";
import ApiError, { ErrorCode } from "$api/error"; import ApiError, { ErrorCode } from "$api/error";
import type { AddAccountResponse, CallbackResponse } from "$api/models/auth.js"; import createCallbackLoader from "$lib/actions/callback";
import { setToken } from "$lib"; import createRegisterAction from "$lib/actions/register";
import createRegisterAction from "$lib/actions/register.js";
import log from "$lib/log";
import { isRedirect, redirect } from "@sveltejs/kit";
export const load = async ({ parent, params, url, fetch, cookies }) => { export const load = createCallbackLoader("fediverse", async ({ params, url }) => {
const code = url.searchParams.get("code") as string | null; const code = url.searchParams.get("code") as string | null;
const state = url.searchParams.get("state") as string | null; const state = url.searchParams.get("state") as string | null;
if (!code || !state) throw new ApiError(undefined, ErrorCode.BadRequest).obj; if (!code || !state) throw new ApiError(undefined, ErrorCode.BadRequest).obj;
const { meUser } = await parent(); return { code, state, instance: params.instance! };
if (meUser) {
try {
const resp = await apiRequest<AddAccountResponse>(
"POST",
"/auth/fediverse/add-account/callback",
{
isInternal: true,
body: { code, state, instance: params.instance },
fetch,
cookies,
},
);
return { hasAccount: true, isLinkRequest: true, newAuthMethod: resp };
} catch (e) {
if (e instanceof ApiError) return { isLinkRequest: true, error: e.obj };
log.error("error linking new fediverse account to user %s:", meUser.id, e);
throw e;
}
}
try {
const resp = await apiRequest<CallbackResponse>("POST", "/auth/fediverse/callback", {
body: { code, state, instance: params.instance },
isInternal: true,
fetch,
}); });
if (resp.has_account) {
setToken(cookies, resp.token!);
redirect(303, `/@${resp.user!.username}`);
}
return {
hasAccount: false,
isLinkRequest: false,
ticket: resp.ticket!,
remoteUser: resp.remote_username!,
};
} catch (e) {
if (isRedirect(e)) throw e;
if (e instanceof ApiError) return { isLinkRequest: false, error: e.obj };
log.error("error while requesting fediverse callback:", e);
throw e;
}
};
export const actions = { export const actions = {
default: createRegisterAction("/auth/fediverse/register"), default: createRegisterAction("/auth/fediverse/register"),
}; };

View file

@ -0,0 +1,35 @@
import { apiRequest, fastRequest } from "$api";
import ApiError from "$api/error.js";
import type { AuthUrls } from "$api/models/auth";
import log from "$lib/log.js";
import { redirect } from "@sveltejs/kit";
export const load = async ({ fetch, parent }) => {
const parentData = await parent();
if (parentData.meUser) redirect(303, `/@${parentData.meUser.username}`);
const urls = await apiRequest<AuthUrls>("POST", "/auth/urls", { fetch, isInternal: true });
if (!urls.email_enabled) redirect(303, "/");
};
export const actions = {
default: async ({ request, fetch, cookies }) => {
const body = await request.formData();
const email = body.get("email") as string;
try {
await fastRequest("POST", `/auth/email/register/init`, {
body: { email },
isInternal: true,
fetch,
cookies,
});
return { ok: true, error: null };
} catch (e) {
if (e instanceof ApiError) return { ok: false, error: e.obj };
log.error("error initiating registration for email %s:", email, e);
throw e;
}
},
};

View file

@ -0,0 +1,29 @@
<script lang="ts">
import type { ActionData, PageData } from "./$types";
import { t } from "$lib/i18n";
import { enhance } from "$app/forms";
import { Button, Input, InputGroup } from "@sveltestrap/sveltestrap";
import FormStatusMarker from "$components/editor/FormStatusMarker.svelte";
type Props = { data: PageData; form: ActionData };
let { data, form }: Props = $props();
</script>
<svelte:head>
<title>{$t("auth.register-with-email")} • pronouns.cc</title>
</svelte:head>
<div class="container">
<div class="mx-auto w-lg-75">
<h2>{$t("auth.register-with-email")}</h2>
<FormStatusMarker {form} successMessage={$t("auth.register-with-email-init-success")} />
<form method="POST" use:enhance>
<InputGroup>
<Input name="email" type="email" placeholder={$t("auth.email-label")} />
<Button type="submit" color="secondary">{$t("auth.register-with-email-button")}</Button>
</InputGroup>
</form>
</div>
</div>

View file

@ -3,6 +3,7 @@
import { Button, Input, InputGroup } from "@sveltestrap/sveltestrap"; import { Button, Input, InputGroup } from "@sveltestrap/sveltestrap";
</script> </script>
<div class="mx-auto w-lg-75">
<h3>Link a new Fediverse account</h3> <h3>Link a new Fediverse account</h3>
<form method="POST" action="?/add"> <form method="POST" action="?/add">
@ -21,3 +22,4 @@
</Button> </Button>
</p> </p>
</form> </form>
</div>

View file

@ -5,6 +5,7 @@
import { Icon } from "@sveltestrap/sveltestrap"; import { Icon } from "@sveltestrap/sveltestrap";
import { t } from "$lib/i18n"; import { t } from "$lib/i18n";
import { enhance } from "$app/forms"; import { enhance } from "$app/forms";
import FormStatusMarker from "$components/editor/FormStatusMarker.svelte";
type Props = { data: PageData; form: ActionData }; type Props = { data: PageData; form: ActionData };
let { data, form }: Props = $props(); let { data, form }: Props = $props();
@ -18,14 +19,7 @@
<div class="mx-auto w-lg-75"> <div class="mx-auto w-lg-75">
<h3>{$t("settings.export-title")}</h3> <h3>{$t("settings.export-title")}</h3>
{#if form?.ok} <FormStatusMarker {form} successMessage={$t("settings.export-request-success")} />
<p class="text-success-emphasis">
<Icon name="check-circle-fill" />
{$t("settings.export-request-success")}
</p>
{:else if form?.error}
<ErrorAlert error={form.error} />
{/if}
<p> <p>
{$t("settings.export-info")} {$t("settings.export-info")}

View file

@ -1,4 +1,5 @@
<wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation"> <wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation">
<s:String x:Key="/Default/CodeStyle/CSharpVarKeywordUsage/ForBuiltInTypes/@EntryValue">UseVarWhenEvident</s:String>
<s:String x:Key="/Default/CodeStyle/FileHeader/FileHeaderText/@EntryValue">Copyright (C) 2023-present sam/u1f320 (vulpine.solutions) <s:String x:Key="/Default/CodeStyle/FileHeader/FileHeaderText/@EntryValue">Copyright (C) 2023-present sam/u1f320 (vulpine.solutions)
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify

View file

@ -3,9 +3,8 @@
## C# code style ## C# code style
- Code should be formatted with `dotnet format` or Rider's built-in formatter. - Code should be formatted with `dotnet format` or Rider's built-in formatter.
- Variables should *always* be declared using `var`, - Variables should always be declared with their type name, unless the type is obvious from the declaration.
unless the correct type can't be inferred from the declaration (i.e. if the variable needs to be an `IEnumerable<T>` (For example, `var stream = new Stream()` or `var db = services.GetRequiredService<DatabaseContext>()`)
instead of a `List<T>`, or if a variable is initialized as `null`).
### Naming ### Naming