Compare commits

..

No commits in common. "f8e60324491e4983a8fe9d7bb4a49025cc7b3719" and "57e1ec09c0668bfcd90209a659a710f6ec45f21e" have entirely different histories.

71 changed files with 832 additions and 1158 deletions

View file

@ -1,52 +1,8 @@
[*] [*.cs]
# We use PostgresSQL which doesn't recommend more specific string types # We use PostgresSQL which doesn't recommend more specific string types
resharper_entity_framework_model_validation_unlimited_string_length_highlighting = none resharper_entity_framework_model_validation_unlimited_string_length_highlighting = none
# This is raised for every single property of records returned by endpoints # This is raised for every single property of records returned by endpoints
resharper_not_accessed_positional_property_local_highlighting = none resharper_not_accessed_positional_property_local_highlighting = none
# Microsoft .NET properties
csharp_new_line_before_members_in_object_initializers = false
csharp_preferred_modifier_order = public, internal, protected, private, file, new, required, abstract, virtual, sealed, static, override, extern, unsafe, volatile, async, readonly:suggestion
# ReSharper properties
resharper_align_multiline_binary_expressions_chain = false
resharper_arguments_skip_single = true
resharper_blank_lines_after_start_comment = 0
resharper_blank_lines_around_single_line_invocable = 0
resharper_blank_lines_before_block_statements = 0
resharper_braces_for_foreach = required_for_multiline
resharper_braces_for_ifelse = required_for_multiline
resharper_braces_redundant = false
resharper_csharp_blank_lines_around_field = 0
resharper_csharp_empty_block_style = together_same_line
resharper_csharp_max_line_length = 166
resharper_csharp_wrap_after_declaration_lpar = true
resharper_csharp_wrap_before_binary_opsign = true
resharper_csharp_wrap_before_declaration_rpar = true
resharper_csharp_wrap_parameters_style = chop_if_long
resharper_indent_preprocessor_other = do_not_change
resharper_instance_members_qualify_declared_in =
resharper_keep_existing_attribute_arrangement = true
resharper_max_attribute_length_for_same_line = 70
resharper_place_accessorholder_attribute_on_same_line = false
resharper_place_expr_method_on_single_line = if_owner_is_single_line
resharper_place_method_attribute_on_same_line = if_owner_is_single_line
resharper_place_record_field_attribute_on_same_line = true
resharper_place_simple_embedded_statement_on_same_line = false
resharper_place_simple_initializer_on_single_line = false
resharper_place_simple_list_pattern_on_single_line = false
resharper_space_within_empty_braces = false
resharper_trailing_comma_in_multiline_lists = true
resharper_wrap_after_invocation_lpar = false
resharper_wrap_before_invocation_rpar = false
resharper_wrap_before_primary_constructor_declaration_rpar = true
resharper_wrap_chained_binary_patterns = chop_if_long
resharper_wrap_list_pattern = chop_always
resharper_wrap_object_and_collection_initializer_style = chop_always
# Roslynator properties
dotnet_diagnostic.RCS1194.severity = none
[*generated.cs] [*generated.cs]
generated_code = true generated_code = true

View file

@ -7,21 +7,19 @@ public static class BuildInfo
public static async Task ReadBuildInfo() public static async Task ReadBuildInfo()
{ {
await using Stream? stream = typeof(BuildInfo).Assembly.GetManifestResourceStream( await using var stream = typeof(BuildInfo).Assembly.GetManifestResourceStream("version");
"version"
);
if (stream == null) if (stream == null)
return; return;
using var reader = new StreamReader(stream); using var reader = new StreamReader(stream);
string[] data = (await reader.ReadToEndAsync()).Trim().Split("\n"); var data = (await reader.ReadToEndAsync()).Trim().Split("\n");
if (data.Length < 3) if (data.Length < 3)
return; return;
Hash = data[0]; Hash = data[0];
bool dirty = data[2] == "dirty"; var dirty = data[2] == "dirty";
string[] versionData = data[1].Split("-"); var versionData = data[1].Split("-");
if (versionData.Length < 3) if (versionData.Length < 3)
return; return;
Version = versionData[0]; Version = versionData[0];

View file

@ -33,16 +33,14 @@ public class AuthController(
config.GoogleAuth.Enabled, config.GoogleAuth.Enabled,
config.TumblrAuth.Enabled config.TumblrAuth.Enabled
); );
string state = HttpUtility.UrlEncode(await keyCacheService.GenerateAuthStateAsync(ct)); var state = HttpUtility.UrlEncode(await keyCacheService.GenerateAuthStateAsync(ct));
string? discord = null; string? discord = null;
if (config.DiscordAuth is { ClientId: not null, ClientSecret: not null }) if (config.DiscordAuth is { ClientId: not null, ClientSecret: not null })
{
discord = discord =
"https://discord.com/oauth2/authorize?response_type=code" $"https://discord.com/oauth2/authorize?response_type=code"
+ $"&client_id={config.DiscordAuth.ClientId}&scope=identify" + $"&client_id={config.DiscordAuth.ClientId}&scope=identify"
+ $"&prompt=none&state={state}" + $"&prompt=none&state={state}"
+ $"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}"; + $"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}";
}
return Ok(new UrlsResponse(config.EmailAuth.Enabled, discord, null, null)); return Ok(new UrlsResponse(config.EmailAuth.Enabled, discord, null, null));
} }
@ -88,7 +86,7 @@ public class AuthController(
)] )]
public async Task<IActionResult> GetAuthMethodAsync(Snowflake id) public async Task<IActionResult> GetAuthMethodAsync(Snowflake id)
{ {
AuthMethod? authMethod = await db var authMethod = await db
.AuthMethods.Include(a => a.FediverseApplication) .AuthMethods.Include(a => a.FediverseApplication)
.FirstOrDefaultAsync(a => a.UserId == CurrentUser!.Id && a.Id == id); .FirstOrDefaultAsync(a => a.UserId == CurrentUser!.Id && a.Id == id);
if (authMethod == null) if (authMethod == null)
@ -101,19 +99,17 @@ public class AuthController(
[Authorize("*")] [Authorize("*")]
public async Task<IActionResult> DeleteAuthMethodAsync(Snowflake id) public async Task<IActionResult> DeleteAuthMethodAsync(Snowflake id)
{ {
List<AuthMethod> authMethods = await db var authMethods = await db
.AuthMethods.Where(a => a.UserId == CurrentUser!.Id) .AuthMethods.Where(a => a.UserId == CurrentUser!.Id)
.ToListAsync(); .ToListAsync();
if (authMethods.Count < 2) if (authMethods.Count < 2)
{
throw new ApiError( throw new ApiError(
"You cannot remove your last authentication method.", "You cannot remove your last authentication method.",
HttpStatusCode.BadRequest, HttpStatusCode.BadRequest,
ErrorCode.LastAuthMethod ErrorCode.LastAuthMethod
); );
}
AuthMethod? authMethod = authMethods.FirstOrDefault(a => a.Id == id); var authMethod = authMethods.FirstOrDefault(a => a.Id == id);
if (authMethod == null) if (authMethod == null)
throw new ApiError.NotFound("No authentication method with that ID found."); throw new ApiError.NotFound("No authentication method with that ID found.");
@ -123,20 +119,6 @@ public class AuthController(
CurrentUser!.Id CurrentUser!.Id
); );
// If this is the user's last email, we should also clear the user's password.
if (
authMethod.AuthType == AuthType.Email
&& authMethods.Count(a => a.AuthType == AuthType.Email) == 1
)
{
_logger.Debug(
"Deleted last email address for user {UserId}, resetting their password",
CurrentUser.Id
);
CurrentUser.Password = null;
db.Update(CurrentUser);
}
db.Remove(authMethod); db.Remove(authMethod);
await db.SaveChangesAsync(); await db.SaveChangesAsync();

View file

@ -34,10 +34,8 @@ public class DiscordAuthController(
CheckRequirements(); CheckRequirements();
await keyCacheService.ValidateAuthStateAsync(req.State); await keyCacheService.ValidateAuthStateAsync(req.State);
RemoteAuthService.RemoteUser remoteUser = await remoteAuthService.RequestDiscordTokenAsync( var remoteUser = await remoteAuthService.RequestDiscordTokenAsync(req.Code);
req.Code var user = await authService.AuthenticateUserAsync(AuthType.Discord, remoteUser.Id);
);
User? user = await authService.AuthenticateUserAsync(AuthType.Discord, remoteUser.Id);
if (user != null) if (user != null)
return Ok(await authService.GenerateUserTokenAsync(user)); return Ok(await authService.GenerateUserTokenAsync(user));
@ -47,14 +45,23 @@ public class DiscordAuthController(
remoteUser.Id remoteUser.Id
); );
string ticket = AuthUtils.RandomToken(); var ticket = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync( await keyCacheService.SetKeyAsync(
$"discord:{ticket}", $"discord:{ticket}",
remoteUser, remoteUser,
Duration.FromMinutes(20) Duration.FromMinutes(20)
); );
return Ok(new CallbackResponse(false, ticket, remoteUser.Username, null, null, null)); return Ok(
new CallbackResponse(
HasAccount: false,
Ticket: ticket,
RemoteUsername: remoteUser.Username,
User: null,
Token: null,
ExpiresAt: null
)
);
} }
[HttpPost("register")] [HttpPost("register")]
@ -63,8 +70,7 @@ public class DiscordAuthController(
[FromBody] AuthController.OauthRegisterRequest req [FromBody] AuthController.OauthRegisterRequest req
) )
{ {
RemoteAuthService.RemoteUser? remoteUser = var remoteUser = await keyCacheService.GetKeyAsync<RemoteAuthService.RemoteUser>(
await keyCacheService.GetKeyAsync<RemoteAuthService.RemoteUser>(
$"discord:{req.Ticket}" $"discord:{req.Ticket}"
); );
if (remoteUser == null) if (remoteUser == null)
@ -82,7 +88,7 @@ public class DiscordAuthController(
throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket); throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket);
} }
User user = await authService.CreateUserWithRemoteAuthAsync( var user = await authService.CreateUserWithRemoteAuthAsync(
req.Username, req.Username,
AuthType.Discord, AuthType.Discord,
remoteUser.Id, remoteUser.Id,
@ -98,13 +104,13 @@ public class DiscordAuthController(
{ {
CheckRequirements(); CheckRequirements();
string state = await remoteAuthService.ValidateAddAccountRequestAsync( var state = await remoteAuthService.ValidateAddAccountRequestAsync(
CurrentUser!.Id, CurrentUser!.Id,
AuthType.Discord AuthType.Discord
); );
string url = var url =
"https://discord.com/oauth2/authorize?response_type=code" $"https://discord.com/oauth2/authorize?response_type=code"
+ $"&client_id={config.DiscordAuth.ClientId}&scope=identify" + $"&client_id={config.DiscordAuth.ClientId}&scope=identify"
+ $"&prompt=none&state={state}" + $"&prompt=none&state={state}"
+ $"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}"; + $"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}";
@ -126,12 +132,10 @@ public class DiscordAuthController(
AuthType.Discord AuthType.Discord
); );
RemoteAuthService.RemoteUser remoteUser = await remoteAuthService.RequestDiscordTokenAsync( var remoteUser = await remoteAuthService.RequestDiscordTokenAsync(req.Code);
req.Code
);
try try
{ {
AuthMethod authMethod = await authService.AddAuthMethodAsync( var authMethod = await authService.AddAuthMethodAsync(
CurrentUser.Id, CurrentUser.Id,
AuthType.Discord, AuthType.Discord,
remoteUser.Id, remoteUser.Id,
@ -165,10 +169,8 @@ public class DiscordAuthController(
private void CheckRequirements() private void CheckRequirements()
{ {
if (!config.DiscordAuth.Enabled) if (!config.DiscordAuth.Enabled)
{
throw new ApiError.BadRequest( throw new ApiError.BadRequest(
"Discord authentication is not enabled on this instance." "Discord authentication is not enabled on this instance."
); );
} }
}
} }

View file

@ -1,5 +1,3 @@
using System.Net;
using EntityFramework.Exceptions.Common;
using Foxnouns.Backend.Database; using Foxnouns.Backend.Database;
using Foxnouns.Backend.Database.Models; using Foxnouns.Backend.Database.Models;
using Foxnouns.Backend.Extensions; using Foxnouns.Backend.Extensions;
@ -28,8 +26,8 @@ public class EmailAuthController(
{ {
private readonly ILogger _logger = logger.ForContext<EmailAuthController>(); private readonly ILogger _logger = logger.ForContext<EmailAuthController>();
[HttpPost("register/init")] [HttpPost("register")]
public async Task<IActionResult> RegisterInitAsync( public async Task<IActionResult> RegisterAsync(
[FromBody] RegisterRequest req, [FromBody] RegisterRequest req,
CancellationToken ct = default CancellationToken ct = default
) )
@ -39,7 +37,11 @@ public class EmailAuthController(
if (!req.Email.Contains('@')) if (!req.Email.Contains('@'))
throw new ApiError.BadRequest("Email is invalid", "email", req.Email); throw new ApiError.BadRequest("Email is invalid", "email", req.Email);
string state = await keyCacheService.GenerateRegisterEmailStateAsync(req.Email, null, ct); var state = await keyCacheService.GenerateRegisterEmailStateAsync(
req.Email,
userId: null,
ct
);
// If there's already a user with that email address, pretend we sent an email but actually ignore it // If there's already a user with that email address, pretend we sent an email but actually ignore it
if ( if (
@ -48,9 +50,7 @@ public class EmailAuthController(
ct ct
) )
) )
{
return NoContent(); return NoContent();
}
mailService.QueueAccountCreationEmail(req.Email, state); mailService.QueueAccountCreationEmail(req.Email, state);
return NoContent(); return NoContent();
@ -61,35 +61,62 @@ public class EmailAuthController(
{ {
CheckRequirements(); CheckRequirements();
RegisterEmailState? state = await keyCacheService.GetRegisterEmailStateAsync(req.State); var state = await keyCacheService.GetRegisterEmailStateAsync(req.State);
if (state is not { ExistingUserId: null }) if (state == null)
throw new ApiError.BadRequest("Invalid state", "state", req.State); throw new ApiError.BadRequest("Invalid state", "state", req.State);
string ticket = AuthUtils.RandomToken(); // If this callback is for an existing user, add the email address to their auth methods
await keyCacheService.SetKeyAsync($"email:{ticket}", state.Email, Duration.FromMinutes(20)); if (state.ExistingUserId != null)
{
return Ok(new CallbackResponse(false, ticket, state.Email, null, null, null)); var authMethod = await authService.AddAuthMethodAsync(
state.ExistingUserId.Value,
AuthType.Email,
state.Email
);
_logger.Debug(
"Added email auth {AuthId} for user {UserId}",
authMethod.Id,
state.ExistingUserId
);
return NoContent();
} }
[HttpPost("register")] var ticket = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync($"email:{ticket}", state.Email, Duration.FromMinutes(20));
return Ok(
new CallbackResponse(
HasAccount: false,
Ticket: ticket,
RemoteUsername: state.Email,
User: null,
Token: null,
ExpiresAt: null
)
);
}
[HttpPost("complete-registration")]
public async Task<IActionResult> CompleteRegistrationAsync( public async Task<IActionResult> CompleteRegistrationAsync(
[FromBody] CompleteRegistrationRequest req [FromBody] CompleteRegistrationRequest req
) )
{ {
CheckRequirements(); CheckRequirements();
string? email = await keyCacheService.GetKeyAsync($"email:{req.Ticket}"); var email = await keyCacheService.GetKeyAsync($"email:{req.Ticket}");
if (email == null) if (email == null)
throw new ApiError.BadRequest("Unknown ticket", "ticket", req.Ticket); throw new ApiError.BadRequest("Unknown ticket", "ticket", req.Ticket);
User user = await authService.CreateUserWithPasswordAsync( // Check if username is valid at all
req.Username, ValidationUtils.Validate([("username", ValidationUtils.ValidateUsername(req.Username))]);
email, // Check if username is already taken
req.Password if (await db.Users.AnyAsync(u => u.Username == req.Username))
); throw new ApiError.BadRequest("Username is already taken", "username", req.Username);
Application frontendApp = await db.GetFrontendApplicationAsync();
(string? tokenStr, Token? token) = authService.GenerateToken( var user = await authService.CreateUserWithPasswordAsync(req.Username, email, req.Password);
var frontendApp = await db.GetFrontendApplicationAsync();
var (tokenStr, token) = authService.GenerateToken(
user, user,
frontendApp, frontendApp,
["*"], ["*"],
@ -103,7 +130,7 @@ public class EmailAuthController(
return Ok( return Ok(
new AuthController.AuthResponse( new AuthController.AuthResponse(
await userRenderer.RenderUserAsync(user, user, renderMembers: false), await userRenderer.RenderUserAsync(user, selfUser: user, renderMembers: false),
tokenStr, tokenStr,
token.ExpiresAt token.ExpiresAt
) )
@ -119,16 +146,19 @@ public class EmailAuthController(
{ {
CheckRequirements(); CheckRequirements();
(User? user, AuthService.EmailAuthenticationResult authenticationResult) = var (user, authenticationResult) = await authService.AuthenticateUserAsync(
await authService.AuthenticateUserAsync(req.Email, req.Password, ct); req.Email,
req.Password,
ct
);
if (authenticationResult == AuthService.EmailAuthenticationResult.MfaRequired) if (authenticationResult == AuthService.EmailAuthenticationResult.MfaRequired)
throw new NotImplementedException("MFA is not implemented yet"); throw new NotImplementedException("MFA is not implemented yet");
Application frontendApp = await db.GetFrontendApplicationAsync(ct); var frontendApp = await db.GetFrontendApplicationAsync(ct);
_logger.Debug("Logging user {Id} in with email and password", user.Id); _logger.Debug("Logging user {Id} in with email and password", user.Id);
(string? tokenStr, Token? token) = authService.GenerateToken( var (tokenStr, token) = authService.GenerateToken(
user, user,
frontendApp, frontendApp,
["*"], ["*"],
@ -142,34 +172,25 @@ public class EmailAuthController(
return Ok( return Ok(
new AuthController.AuthResponse( new AuthController.AuthResponse(
await userRenderer.RenderUserAsync(user, user, renderMembers: false, ct: ct), await userRenderer.RenderUserAsync(
user,
selfUser: user,
renderMembers: false,
ct: ct
),
tokenStr, tokenStr,
token.ExpiresAt token.ExpiresAt
) )
); );
} }
[HttpPost("change-password")] [HttpPost("add")]
[Authorize("*")]
public async Task<IActionResult> UpdatePasswordAsync([FromBody] ChangePasswordRequest req)
{
if (!await authService.ValidatePasswordAsync(CurrentUser!, req.Current))
throw new ApiError.Forbidden("Invalid password");
ValidationUtils.Validate([("new", ValidationUtils.ValidatePassword(req.New))]);
await authService.SetUserPasswordAsync(CurrentUser!, req.New);
await db.SaveChangesAsync();
return NoContent();
}
[HttpPost("add-email")]
[Authorize("*")] [Authorize("*")]
public async Task<IActionResult> AddEmailAddressAsync([FromBody] AddEmailAddressRequest req) public async Task<IActionResult> AddEmailAddressAsync([FromBody] AddEmailAddressRequest req)
{ {
CheckRequirements(); CheckRequirements();
List<AuthMethod> emails = await db var emails = await db
.AuthMethods.Where(m => m.UserId == CurrentUser!.Id && m.AuthType == AuthType.Email) .AuthMethods.Where(m => m.UserId == CurrentUser!.Id && m.AuthType == AuthType.Email)
.ToListAsync(); .ToListAsync();
if (emails.Count > AuthUtils.MaxAuthMethodsPerType) if (emails.Count > AuthUtils.MaxAuthMethodsPerType)
@ -183,21 +204,24 @@ public class EmailAuthController(
if (emails.Count != 0) if (emails.Count != 0)
{ {
if (!await authService.ValidatePasswordAsync(CurrentUser!, req.Password)) var validPassword = await authService.ValidatePasswordAsync(CurrentUser!, req.Password);
if (!validPassword)
{
throw new ApiError.Forbidden("Invalid password"); throw new ApiError.Forbidden("Invalid password");
} }
}
else else
{ {
await authService.SetUserPasswordAsync(CurrentUser!, req.Password); await authService.SetUserPasswordAsync(CurrentUser!, req.Password);
await db.SaveChangesAsync(); await db.SaveChangesAsync();
} }
string state = await keyCacheService.GenerateRegisterEmailStateAsync( var state = await keyCacheService.GenerateRegisterEmailStateAsync(
req.Email, req.Email,
CurrentUser!.Id userId: CurrentUser!.Id
); );
bool emailExists = await db var emailExists = await db
.AuthMethods.Where(m => m.AuthType == AuthType.Email && m.RemoteId == req.Email) .AuthMethods.Where(m => m.AuthType == AuthType.Email && m.RemoteId == req.Email)
.AnyAsync(); .AnyAsync();
if (emailExists) if (emailExists)
@ -209,48 +233,6 @@ public class EmailAuthController(
return NoContent(); return NoContent();
} }
[HttpPost("add-email/callback")]
[Authorize("*")]
public async Task<IActionResult> AddEmailCallbackAsync([FromBody] CallbackRequest req)
{
CheckRequirements();
RegisterEmailState? state = await keyCacheService.GetRegisterEmailStateAsync(req.State);
if (state?.ExistingUserId != CurrentUser!.Id)
throw new ApiError.BadRequest("Invalid state", "state", req.State);
try
{
AuthMethod authMethod = await authService.AddAuthMethodAsync(
CurrentUser.Id,
AuthType.Email,
state.Email
);
_logger.Debug(
"Added email auth {AuthId} for user {UserId}",
authMethod.Id,
CurrentUser.Id
);
return Ok(
new AuthController.AddOauthAccountResponse(
authMethod.Id,
AuthType.Email,
authMethod.RemoteId,
null
)
);
}
catch (UniqueConstraintException)
{
throw new ApiError(
"That email address is already linked.",
HttpStatusCode.BadRequest,
ErrorCode.AccountAlreadyLinked
);
}
}
public record AddEmailAddressRequest(string Email, string Password); public record AddEmailAddressRequest(string Email, string Password);
private void CheckRequirements() private void CheckRequirements()
@ -266,6 +248,4 @@ public class EmailAuthController(
public record CompleteRegistrationRequest(string Ticket, string Username, string Password); public record CompleteRegistrationRequest(string Ticket, string Username, string Password);
public record CallbackRequest(string State); public record CallbackRequest(string State);
public record ChangePasswordRequest(string Current, string New);
} }

View file

@ -34,7 +34,7 @@ public class FediverseAuthController(
if (instance.Any(c => c is '@' or ':' or '/') || !instance.Contains('.')) if (instance.Any(c => c is '@' or ':' or '/') || !instance.Contains('.'))
throw new ApiError.BadRequest("Not a valid domain.", "instance", instance); throw new ApiError.BadRequest("Not a valid domain.", "instance", instance);
string url = await fediverseAuthService.GenerateAuthUrlAsync(instance, forceRefresh); var url = await fediverseAuthService.GenerateAuthUrlAsync(instance, forceRefresh);
return Ok(new AuthController.SingleUrlResponse(url)); return Ok(new AuthController.SingleUrlResponse(url));
} }
@ -42,19 +42,22 @@ public class FediverseAuthController(
[ProducesResponseType<CallbackResponse>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<CallbackResponse>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> FediverseCallbackAsync([FromBody] CallbackRequest req) public async Task<IActionResult> FediverseCallbackAsync([FromBody] CallbackRequest req)
{ {
FediverseApplication app = await fediverseAuthService.GetApplicationAsync(req.Instance); var app = await fediverseAuthService.GetApplicationAsync(req.Instance);
FediverseAuthService.FediverseUser remoteUser = var remoteUser = await fediverseAuthService.GetRemoteFediverseUserAsync(
await fediverseAuthService.GetRemoteFediverseUserAsync(app, req.Code, req.State); app,
req.Code,
req.State
);
User? user = await authService.AuthenticateUserAsync( var user = await authService.AuthenticateUserAsync(
AuthType.Fediverse, AuthType.Fediverse,
remoteUser.Id, remoteUser.Id,
app instance: app
); );
if (user != null) if (user != null)
return Ok(await authService.GenerateUserTokenAsync(user)); return Ok(await authService.GenerateUserTokenAsync(user));
string ticket = AuthUtils.RandomToken(); var ticket = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync( await keyCacheService.SetKeyAsync(
$"fediverse:{ticket}", $"fediverse:{ticket}",
new FediverseTicketData(app.Id, remoteUser), new FediverseTicketData(app.Id, remoteUser),
@ -63,12 +66,12 @@ public class FediverseAuthController(
return Ok( return Ok(
new CallbackResponse( new CallbackResponse(
false, HasAccount: false,
ticket, Ticket: ticket,
$"@{remoteUser.Username}@{app.Domain}", RemoteUsername: $"@{remoteUser.Username}@{app.Domain}",
null, User: null,
null, Token: null,
null ExpiresAt: null
) )
); );
} }
@ -79,16 +82,14 @@ public class FediverseAuthController(
[FromBody] AuthController.OauthRegisterRequest req [FromBody] AuthController.OauthRegisterRequest req
) )
{ {
FediverseTicketData? ticketData = await keyCacheService.GetKeyAsync<FediverseTicketData>( var ticketData = await keyCacheService.GetKeyAsync<FediverseTicketData>(
$"fediverse:{req.Ticket}", $"fediverse:{req.Ticket}",
true delete: true
); );
if (ticketData == null) if (ticketData == null)
throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket); throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket);
FediverseApplication? app = await db.FediverseApplications.FindAsync( var app = await db.FediverseApplications.FindAsync(ticketData.ApplicationId);
ticketData.ApplicationId
);
if (app == null) if (app == null)
throw new FoxnounsError("Null application found for ticket"); throw new FoxnounsError("Null application found for ticket");
@ -110,12 +111,12 @@ public class FediverseAuthController(
throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket); throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket);
} }
User user = await authService.CreateUserWithRemoteAuthAsync( var user = await authService.CreateUserWithRemoteAuthAsync(
req.Username, req.Username,
AuthType.Fediverse, AuthType.Fediverse,
ticketData.User.Id, ticketData.User.Id,
ticketData.User.Username, ticketData.User.Username,
app instance: app
); );
return Ok(await authService.GenerateUserTokenAsync(user)); return Ok(await authService.GenerateUserTokenAsync(user));
@ -131,13 +132,13 @@ public class FediverseAuthController(
if (instance.Any(c => c is '@' or ':' or '/') || !instance.Contains('.')) if (instance.Any(c => c is '@' or ':' or '/') || !instance.Contains('.'))
throw new ApiError.BadRequest("Not a valid domain.", "instance", instance); throw new ApiError.BadRequest("Not a valid domain.", "instance", instance);
string state = await remoteAuthService.ValidateAddAccountRequestAsync( var state = await remoteAuthService.ValidateAddAccountRequestAsync(
CurrentUser!.Id, CurrentUser!.Id,
AuthType.Fediverse, AuthType.Fediverse,
instance instance
); );
string url = await fediverseAuthService.GenerateAuthUrlAsync(instance, forceRefresh, state); var url = await fediverseAuthService.GenerateAuthUrlAsync(instance, forceRefresh, state);
return Ok(new AuthController.SingleUrlResponse(url)); return Ok(new AuthController.SingleUrlResponse(url));
} }
@ -152,12 +153,11 @@ public class FediverseAuthController(
req.Instance req.Instance
); );
FediverseApplication app = await fediverseAuthService.GetApplicationAsync(req.Instance); var app = await fediverseAuthService.GetApplicationAsync(req.Instance);
FediverseAuthService.FediverseUser remoteUser = var remoteUser = await fediverseAuthService.GetRemoteFediverseUserAsync(app, req.Code);
await fediverseAuthService.GetRemoteFediverseUserAsync(app, req.Code);
try try
{ {
AuthMethod authMethod = await authService.AddAuthMethodAsync( var authMethod = await authService.AddAuthMethodAsync(
CurrentUser.Id, CurrentUser.Id,
AuthType.Fediverse, AuthType.Fediverse,
remoteUser.Id, remoteUser.Id,

View file

@ -25,7 +25,7 @@ public class ExportsController(
[HttpGet] [HttpGet]
public async Task<IActionResult> GetDataExportsAsync() public async Task<IActionResult> GetDataExportsAsync()
{ {
DataExport? export = await db var export = await db
.DataExports.Where(d => d.UserId == CurrentUser!.Id) .DataExports.Where(d => d.UserId == CurrentUser!.Id)
.OrderByDescending(d => d.Id) .OrderByDescending(d => d.Id)
.FirstOrDefaultAsync(); .FirstOrDefaultAsync();

View file

@ -1,6 +1,5 @@
using Coravel.Queuing.Interfaces; using Coravel.Queuing.Interfaces;
using Foxnouns.Backend.Database; using Foxnouns.Backend.Database;
using Foxnouns.Backend.Database.Models;
using Foxnouns.Backend.Extensions; using Foxnouns.Backend.Extensions;
using Foxnouns.Backend.Jobs; using Foxnouns.Backend.Jobs;
using Foxnouns.Backend.Middleware; using Foxnouns.Backend.Middleware;
@ -8,7 +7,6 @@ using Foxnouns.Backend.Services;
using Foxnouns.Backend.Utils; using Foxnouns.Backend.Utils;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
namespace Foxnouns.Backend.Controllers; namespace Foxnouns.Backend.Controllers;
@ -31,9 +29,7 @@ public class FlagsController(
)] )]
public async Task<IActionResult> GetFlagsAsync(CancellationToken ct = default) public async Task<IActionResult> GetFlagsAsync(CancellationToken ct = default)
{ {
List<PrideFlag> flags = await db var flags = await db.PrideFlags.Where(f => f.UserId == CurrentUser!.Id).ToListAsync(ct);
.PrideFlags.Where(f => f.UserId == CurrentUser!.Id)
.ToListAsync(ct);
return Ok(flags.Select(userRenderer.RenderPrideFlag)); return Ok(flags.Select(userRenderer.RenderPrideFlag));
} }
@ -47,7 +43,7 @@ public class FlagsController(
{ {
ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, req.Image)); ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, req.Image));
Snowflake id = snowflakeGenerator.GenerateSnowflake(); var id = snowflakeGenerator.GenerateSnowflake();
queue.QueueInvocableWithPayload<CreateFlagInvocable, CreateFlagPayload>( queue.QueueInvocableWithPayload<CreateFlagInvocable, CreateFlagPayload>(
new CreateFlagPayload(id, CurrentUser!.Id, req.Name, req.Image, req.Description) new CreateFlagPayload(id, CurrentUser!.Id, req.Name, req.Image, req.Description)
@ -66,7 +62,7 @@ public class FlagsController(
{ {
ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, null)); ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, null));
PrideFlag? flag = await db.PrideFlags.FirstOrDefaultAsync(f => var flag = await db.PrideFlags.FirstOrDefaultAsync(f =>
f.Id == id && f.UserId == CurrentUser!.Id f.Id == id && f.UserId == CurrentUser!.Id
); );
if (flag == null) if (flag == null)
@ -94,20 +90,20 @@ public class FlagsController(
[Authorize("user.update")] [Authorize("user.update")]
public async Task<IActionResult> DeleteFlagAsync(Snowflake id) public async Task<IActionResult> DeleteFlagAsync(Snowflake id)
{ {
await using IDbContextTransaction tx = await db.Database.BeginTransactionAsync(); await using var tx = await db.Database.BeginTransactionAsync();
PrideFlag? flag = await db.PrideFlags.FirstOrDefaultAsync(f => var flag = await db.PrideFlags.FirstOrDefaultAsync(f =>
f.Id == id && f.UserId == CurrentUser!.Id f.Id == id && f.UserId == CurrentUser!.Id
); );
if (flag == null) if (flag == null)
throw new ApiError.NotFound("Unknown flag ID, or it's not your flag."); throw new ApiError.NotFound("Unknown flag ID, or it's not your flag.");
string hash = flag.Hash; var hash = flag.Hash;
db.PrideFlags.Remove(flag); db.PrideFlags.Remove(flag);
await db.SaveChangesAsync(); await db.SaveChangesAsync();
int flagCount = await db.PrideFlags.CountAsync(f => f.Hash == flag.Hash); var flagCount = await db.PrideFlags.CountAsync(f => f.Hash == flag.Hash);
if (flagCount == 0) if (flagCount == 0)
{ {
try try
@ -124,9 +120,7 @@ public class FlagsController(
} }
} }
else else
{
_logger.Debug("Flag file {Hash} is used by other flags, not deleting", hash); _logger.Debug("Flag file {Hash} is used by other flags, not deleting", hash);
}
await tx.CommitAsync(); await tx.CommitAsync();

View file

@ -44,22 +44,21 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
[HttpPost("request-data")] [HttpPost("request-data")]
public async Task<IActionResult> GetRequestDataAsync([FromBody] RequestDataRequest req) public async Task<IActionResult> GetRequestDataAsync([FromBody] RequestDataRequest req)
{ {
RouteEndpoint? endpoint = GetEndpoint(HttpContext, req.Path, req.Method); var endpoint = GetEndpoint(HttpContext, req.Path, req.Method);
if (endpoint == null) if (endpoint == null)
throw new ApiError.BadRequest("Path/method combination is invalid"); throw new ApiError.BadRequest("Path/method combination is invalid");
ControllerActionDescriptor? actionDescriptor = var actionDescriptor = endpoint.Metadata.GetMetadata<ControllerActionDescriptor>();
endpoint.Metadata.GetMetadata<ControllerActionDescriptor>(); var template = actionDescriptor?.AttributeRouteInfo?.Template;
string? template = actionDescriptor?.AttributeRouteInfo?.Template;
if (template == null) if (template == null)
throw new FoxnounsError("Template value was null on valid endpoint"); throw new FoxnounsError("Template value was null on valid endpoint");
template = GetCleanedTemplate(template); template = GetCleanedTemplate(template);
// If no token was supplied, or it isn't valid base 64, return a null user ID (limiting by IP) // If no token was supplied, or it isn't valid base 64, return a null user ID (limiting by IP)
if (!AuthUtils.TryParseToken(req.Token, out byte[]? rawToken)) if (!AuthUtils.TryParseToken(req.Token, out var rawToken))
return Ok(new RequestDataResponse(null, template)); return Ok(new RequestDataResponse(null, template));
Snowflake? userId = await db.GetTokenUserId(rawToken); var userId = await db.GetTokenUserId(rawToken);
return Ok(new RequestDataResponse(userId, template)); return Ok(new RequestDataResponse(userId, template));
} }
@ -73,13 +72,12 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
string requestMethod string requestMethod
) )
{ {
EndpointDataSource? endpointDataSource = var endpointDataSource = httpContext.RequestServices.GetService<EndpointDataSource>();
httpContext.RequestServices.GetService<EndpointDataSource>();
if (endpointDataSource == null) if (endpointDataSource == null)
return null; return null;
IEnumerable<RouteEndpoint> endpoints = endpointDataSource.Endpoints.OfType<RouteEndpoint>(); var endpoints = endpointDataSource.Endpoints.OfType<RouteEndpoint>();
foreach (RouteEndpoint? endpoint in endpoints) foreach (var endpoint in endpoints)
{ {
if (endpoint.RoutePattern.RawText == null) if (endpoint.RoutePattern.RawText == null)
continue; continue;
@ -88,19 +86,16 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
TemplateParser.Parse(endpoint.RoutePattern.RawText), TemplateParser.Parse(endpoint.RoutePattern.RawText),
new RouteValueDictionary() new RouteValueDictionary()
); );
if (!templateMatcher.TryMatch(url, new RouteValueDictionary())) if (!templateMatcher.TryMatch(url, new()))
continue; continue;
HttpMethodAttribute? httpMethodAttribute = var httpMethodAttribute = endpoint.Metadata.GetMetadata<HttpMethodAttribute>();
endpoint.Metadata.GetMetadata<HttpMethodAttribute>();
if ( if (
httpMethodAttribute?.HttpMethods.Any(x => httpMethodAttribute != null
&& !httpMethodAttribute.HttpMethods.Any(x =>
x.Equals(requestMethod, StringComparison.OrdinalIgnoreCase) x.Equals(requestMethod, StringComparison.OrdinalIgnoreCase)
) == false
) )
{ )
continue; continue;
}
return endpoint; return endpoint;
} }

View file

@ -9,7 +9,6 @@ using Foxnouns.Backend.Services;
using Foxnouns.Backend.Utils; using Foxnouns.Backend.Utils;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using NodaTime; using NodaTime;
namespace Foxnouns.Backend.Controllers; namespace Foxnouns.Backend.Controllers;
@ -33,7 +32,7 @@ public class MembersController(
)] )]
public async Task<IActionResult> GetMembersAsync(string userRef, CancellationToken ct = default) public async Task<IActionResult> GetMembersAsync(string userRef, CancellationToken ct = default)
{ {
User user = await db.ResolveUserAsync(userRef, CurrentToken, ct); var user = await db.ResolveUserAsync(userRef, CurrentToken, ct);
return Ok(await memberRenderer.RenderUserMembersAsync(user, CurrentToken)); return Ok(await memberRenderer.RenderUserMembersAsync(user, CurrentToken));
} }
@ -45,7 +44,7 @@ public class MembersController(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
Member member = await db.ResolveMemberAsync(userRef, memberRef, CurrentToken, ct); var member = await db.ResolveMemberAsync(userRef, memberRef, CurrentToken, ct);
return Ok(memberRenderer.RenderMember(member, CurrentToken)); return Ok(memberRenderer.RenderMember(member, CurrentToken));
} }
@ -79,7 +78,7 @@ public class MembersController(
] ]
); );
int memberCount = await db.Members.CountAsync(m => m.UserId == CurrentUser.Id, ct); var memberCount = await db.Members.CountAsync(m => m.UserId == CurrentUser.Id, ct);
if (memberCount >= MaxMemberCount) if (memberCount >= MaxMemberCount)
throw new ApiError.BadRequest("Maximum number of members reached"); throw new ApiError.BadRequest("Maximum number of members reached");
@ -121,11 +120,9 @@ public class MembersController(
} }
if (req.Avatar != null) if (req.Avatar != null)
{
queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>( queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(member.Id, req.Avatar) new AvatarUpdatePayload(member.Id, req.Avatar)
); );
}
return Ok(memberRenderer.RenderMember(member, CurrentToken)); return Ok(memberRenderer.RenderMember(member, CurrentToken));
} }
@ -137,8 +134,8 @@ public class MembersController(
[FromBody] UpdateMemberRequest req [FromBody] UpdateMemberRequest req
) )
{ {
await using IDbContextTransaction tx = await db.Database.BeginTransactionAsync(); await using var tx = await db.Database.BeginTransactionAsync();
Member member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef); var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
// We might add extra validations for names later down the line. // We might add extra validations for names later down the line.
@ -200,11 +197,7 @@ public class MembersController(
if (req.Flags != null) if (req.Flags != null)
{ {
ValidationError? flagError = await db.SetMemberFlagsAsync( var flagError = await db.SetMemberFlagsAsync(CurrentUser!.Id, member.Id, req.Flags);
CurrentUser!.Id,
member.Id,
req.Flags
);
if (flagError != null) if (flagError != null)
errors.Add(("flags", flagError)); errors.Add(("flags", flagError));
} }
@ -217,12 +210,9 @@ public class MembersController(
// (atomic operations are hard when combined with background jobs) // (atomic operations are hard when combined with background jobs)
// so it's in a separate block to the validation above. // so it's in a separate block to the validation above.
if (req.HasProperty(nameof(req.Avatar))) if (req.HasProperty(nameof(req.Avatar)))
{
queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>( queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(member.Id, req.Avatar) new AvatarUpdatePayload(member.Id, req.Avatar)
); );
}
try try
{ {
await db.SaveChangesAsync(); await db.SaveChangesAsync();
@ -238,7 +228,7 @@ public class MembersController(
throw new ApiError.BadRequest( throw new ApiError.BadRequest(
"A member with that name already exists", "A member with that name already exists",
"name", "name",
req.Name req.Name!
); );
} }
@ -264,8 +254,8 @@ public class MembersController(
[Authorize("member.update")] [Authorize("member.update")]
public async Task<IActionResult> DeleteMemberAsync(string memberRef) public async Task<IActionResult> DeleteMemberAsync(string memberRef)
{ {
Member member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef); var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);
int deleteCount = await db var deleteCount = await db
.Members.Where(m => m.UserId == CurrentUser!.Id && m.Id == member.Id) .Members.Where(m => m.UserId == CurrentUser!.Id && m.Id == member.Id)
.ExecuteDeleteAsync(); .ExecuteDeleteAsync();
if (deleteCount == 0) if (deleteCount == 0)
@ -299,9 +289,9 @@ public class MembersController(
[ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> RerollSidAsync(string memberRef) public async Task<IActionResult> RerollSidAsync(string memberRef)
{ {
Member member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef); var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);
Instant minTimeAgo = clock.GetCurrentInstant() - Duration.FromHours(1); var minTimeAgo = clock.GetCurrentInstant() - Duration.FromHours(1);
if (CurrentUser!.LastSidReroll > minTimeAgo) if (CurrentUser!.LastSidReroll > minTimeAgo)
throw new ApiError.BadRequest("Cannot reroll short ID yet"); throw new ApiError.BadRequest("Cannot reroll short ID yet");
@ -318,10 +308,7 @@ public class MembersController(
); );
// Fetch the new sid then pass that to RenderMember // Fetch the new sid then pass that to RenderMember
string newSid = await db var newSid = await db.Members.Where(m => m.Id == member.Id).Select(m => m.Sid).FirstAsync();
.Members.Where(m => m.Id == member.Id)
.Select(m => m.Sid)
.FirstAsync();
return Ok(memberRenderer.RenderMember(member, CurrentToken, newSid)); return Ok(memberRenderer.RenderMember(member, CurrentToken, newSid));
} }
} }

View file

@ -10,8 +10,9 @@ public class MetaController : ApiControllerBase
[HttpGet] [HttpGet]
[ProducesResponseType<MetaResponse>(StatusCodes.Status200OK)] [ProducesResponseType<MetaResponse>(StatusCodes.Status200OK)]
public IActionResult GetMeta() => public IActionResult GetMeta()
Ok( {
return Ok(
new MetaResponse( new MetaResponse(
Repository, Repository,
BuildInfo.Version, BuildInfo.Version,
@ -24,13 +25,14 @@ public class MetaController : ApiControllerBase
(int)FoxnounsMetrics.UsersActiveDayCount.Value (int)FoxnounsMetrics.UsersActiveDayCount.Value
), ),
new Limits( new Limits(
MembersController.MaxMemberCount, MemberCount: MembersController.MaxMemberCount,
ValidationUtils.MaxBioLength, BioLength: ValidationUtils.MaxBioLength,
ValidationUtils.MaxCustomPreferences, CustomPreferences: ValidationUtils.MaxCustomPreferences,
AuthUtils.MaxAuthMethodsPerType MaxAuthMethods: AuthUtils.MaxAuthMethodsPerType
) )
) )
); );
}
[HttpGet("/api/v2/coffee")] [HttpGet("/api/v2/coffee")]
public IActionResult BrewCoffee() => public IActionResult BrewCoffee() =>

View file

@ -24,7 +24,7 @@ public class SidController(Config config, DatabaseContext db) : ApiControllerBas
private async Task<IActionResult> ResolveUserSidAsync(string id, CancellationToken ct = default) private async Task<IActionResult> ResolveUserSidAsync(string id, CancellationToken ct = default)
{ {
string? username = await db var username = await db
.Users.Where(u => u.Sid == id.ToLowerInvariant() && !u.Deleted) .Users.Where(u => u.Sid == id.ToLowerInvariant() && !u.Deleted)
.Select(u => u.Username) .Select(u => u.Username)
.FirstOrDefaultAsync(ct); .FirstOrDefaultAsync(ct);

View file

@ -9,7 +9,6 @@ using Foxnouns.Backend.Services;
using Foxnouns.Backend.Utils; using Foxnouns.Backend.Utils;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using NodaTime; using NodaTime;
namespace Foxnouns.Backend.Controllers; namespace Foxnouns.Backend.Controllers;
@ -30,9 +29,16 @@ public class UsersController(
[ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> GetUserAsync(string userRef, CancellationToken ct = default) public async Task<IActionResult> GetUserAsync(string userRef, CancellationToken ct = default)
{ {
User user = await db.ResolveUserAsync(userRef, CurrentToken, ct); var user = await db.ResolveUserAsync(userRef, CurrentToken, ct);
return Ok( return Ok(
await userRenderer.RenderUserAsync(user, CurrentUser, CurrentToken, true, true, ct: ct) await userRenderer.RenderUserAsync(
user,
selfUser: CurrentUser,
token: CurrentToken,
renderMembers: true,
renderAuthMethods: true,
ct: ct
)
); );
} }
@ -44,8 +50,8 @@ public class UsersController(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
await using IDbContextTransaction tx = await db.Database.BeginTransactionAsync(ct); await using var tx = await db.Database.BeginTransactionAsync(ct);
User user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct); var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct);
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (req.Username != null && req.Username != user.Username) if (req.Username != null && req.Username != user.Username)
@ -102,7 +108,7 @@ public class UsersController(
if (req.Flags != null) if (req.Flags != null)
{ {
ValidationError? flagError = await db.SetUserFlagsAsync(CurrentUser!.Id, req.Flags); var flagError = await db.SetUserFlagsAsync(CurrentUser!.Id, req.Flags);
if (flagError != null) if (flagError != null)
errors.Add(("flags", flagError)); errors.Add(("flags", flagError));
} }
@ -135,11 +141,8 @@ public class UsersController(
else else
{ {
if (TimeZoneInfo.TryFindSystemTimeZoneById(req.Timezone, out _)) if (TimeZoneInfo.TryFindSystemTimeZoneById(req.Timezone, out _))
{
user.Timezone = req.Timezone; user.Timezone = req.Timezone;
}
else else
{
errors.Add( errors.Add(
( (
"timezone", "timezone",
@ -148,18 +151,15 @@ public class UsersController(
); );
} }
} }
}
ValidationUtils.Validate(errors); ValidationUtils.Validate(errors);
// This is fired off regardless of whether the transaction is committed // This is fired off regardless of whether the transaction is committed
// (atomic operations are hard when combined with background jobs) // (atomic operations are hard when combined with background jobs)
// so it's in a separate block to the validation above. // so it's in a separate block to the validation above.
if (req.HasProperty(nameof(req.Avatar))) if (req.HasProperty(nameof(req.Avatar)))
{
queue.QueueInvocableWithPayload<UserAvatarUpdateInvocable, AvatarUpdatePayload>( queue.QueueInvocableWithPayload<UserAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(CurrentUser!.Id, req.Avatar) new AvatarUpdatePayload(CurrentUser!.Id, req.Avatar)
); );
}
try try
{ {
@ -176,7 +176,7 @@ public class UsersController(
throw new ApiError.BadRequest( throw new ApiError.BadRequest(
"That username is already taken.", "That username is already taken.",
"username", "username",
req.Username req.Username!
); );
} }
@ -202,12 +202,12 @@ public class UsersController(
{ {
ValidationUtils.Validate(ValidationUtils.ValidateCustomPreferences(req)); ValidationUtils.Validate(ValidationUtils.ValidateCustomPreferences(req));
User user = await db.ResolveUserAsync(CurrentUser!.Id, ct); var user = await db.ResolveUserAsync(CurrentUser!.Id, ct);
var preferences = user var preferences = user
.CustomPreferences.Where(x => req.Any(r => r.Id == x.Key)) .CustomPreferences.Where(x => req.Any(r => r.Id == x.Key))
.ToDictionary(); .ToDictionary();
foreach (CustomPreferenceUpdate? r in req) foreach (var r in req)
{ {
if (r.Id != null && preferences.ContainsKey(r.Id.Value)) if (r.Id != null && preferences.ContainsKey(r.Id.Value))
{ {
@ -271,7 +271,7 @@ public class UsersController(
[ProducesResponseType<UserSettings>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<UserSettings>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> GetUserSettingsAsync(CancellationToken ct = default) public async Task<IActionResult> GetUserSettingsAsync(CancellationToken ct = default)
{ {
User user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct); var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct);
return Ok(user.Settings); return Ok(user.Settings);
} }
@ -283,7 +283,7 @@ public class UsersController(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
User user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct); var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct);
if (req.HasProperty(nameof(req.DarkMode))) if (req.HasProperty(nameof(req.DarkMode)))
user.Settings.DarkMode = req.DarkMode; user.Settings.DarkMode = req.DarkMode;
@ -304,7 +304,7 @@ public class UsersController(
[ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)] [ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> RerollSidAsync() public async Task<IActionResult> RerollSidAsync()
{ {
Instant minTimeAgo = clock.GetCurrentInstant() - Duration.FromHours(1); var minTimeAgo = clock.GetCurrentInstant() - Duration.FromHours(1);
if (CurrentUser!.LastSidReroll > minTimeAgo) if (CurrentUser!.LastSidReroll > minTimeAgo)
throw new ApiError.BadRequest("Cannot reroll short ID yet"); throw new ApiError.BadRequest("Cannot reroll short ID yet");
@ -318,18 +318,18 @@ public class UsersController(
); );
// Get the user's new sid // Get the user's new sid
string newSid = await db var newSid = await db
.Users.Where(u => u.Id == CurrentUser.Id) .Users.Where(u => u.Id == CurrentUser.Id)
.Select(u => u.Sid) .Select(u => u.Sid)
.FirstAsync(); .FirstAsync();
User user = await db.ResolveUserAsync(CurrentUser.Id); var user = await db.ResolveUserAsync(CurrentUser.Id);
return Ok( return Ok(
await userRenderer.RenderUserAsync( await userRenderer.RenderUserAsync(
user, CurrentUser,
CurrentUser, CurrentUser,
CurrentToken, CurrentToken,
false, renderMembers: false,
overrideSid: newSid overrideSid: newSid
) )
); );

View file

@ -11,8 +11,9 @@ namespace Foxnouns.Backend.Database;
public class DatabaseContext(DbContextOptions options) : DbContext(options) public class DatabaseContext(DbContextOptions options) : DbContext(options)
{ {
private static string GenerateConnectionString(Config.DatabaseConfig config) => private static string GenerateConnectionString(Config.DatabaseConfig config)
new NpgsqlConnectionStringBuilder(config.Url) {
return new NpgsqlConnectionStringBuilder(config.Url)
{ {
Pooling = config.EnablePooling ?? true, Pooling = config.EnablePooling ?? true,
Timeout = config.Timeout ?? 5, Timeout = config.Timeout ?? 5,
@ -21,6 +22,7 @@ public class DatabaseContext(DbContextOptions options) : DbContext(options)
ConnectionPruningInterval = 10, ConnectionPruningInterval = 10,
ConnectionIdleLifetime = 10, ConnectionIdleLifetime = 10,
}.ConnectionString; }.ConnectionString;
}
public static NpgsqlDataSource BuildDataSource(Config config) public static NpgsqlDataSource BuildDataSource(Config config)
{ {
@ -44,18 +46,18 @@ public class DatabaseContext(DbContextOptions options) : DbContext(options)
.UseSnakeCaseNamingConvention() .UseSnakeCaseNamingConvention()
.UseExceptionProcessor(); .UseExceptionProcessor();
public DbSet<User> Users { get; init; } = null!; public DbSet<User> Users { get; init; }
public DbSet<Member> Members { get; init; } = null!; public DbSet<Member> Members { get; init; }
public DbSet<AuthMethod> AuthMethods { get; init; } = null!; public DbSet<AuthMethod> AuthMethods { get; init; }
public DbSet<FediverseApplication> FediverseApplications { get; init; } = null!; public DbSet<FediverseApplication> FediverseApplications { get; init; }
public DbSet<Token> Tokens { get; init; } = null!; public DbSet<Token> Tokens { get; init; }
public DbSet<Application> Applications { get; init; } = null!; public DbSet<Application> Applications { get; init; }
public DbSet<TemporaryKey> TemporaryKeys { get; init; } = null!; public DbSet<TemporaryKey> TemporaryKeys { get; init; }
public DbSet<DataExport> DataExports { get; init; } = null!; public DbSet<DataExport> DataExports { get; init; }
public DbSet<PrideFlag> PrideFlags { get; init; } = null!; public DbSet<PrideFlag> PrideFlags { get; init; }
public DbSet<UserFlag> UserFlags { get; init; } = null!; public DbSet<UserFlag> UserFlags { get; init; }
public DbSet<MemberFlag> MemberFlags { get; init; } = null!; public DbSet<MemberFlag> MemberFlags { get; init; }
protected override void ConfigureConventions(ModelConfigurationBuilder configurationBuilder) protected override void ConfigureConventions(ModelConfigurationBuilder configurationBuilder)
{ {
@ -136,16 +138,16 @@ public class DesignTimeDatabaseContextFactory : IDesignTimeDbContextFactory<Data
public DatabaseContext CreateDbContext(string[] args) public DatabaseContext CreateDbContext(string[] args)
{ {
// Read the configuration file // Read the configuration file
Config config = var config =
new ConfigurationBuilder() new ConfigurationBuilder()
.AddConfiguration() .AddConfiguration()
.Build() .Build()
// Get the configuration as our config class // Get the configuration as our config class
.Get<Config>() ?? new Config(); .Get<Config>() ?? new();
NpgsqlDataSource dataSource = DatabaseContext.BuildDataSource(config); var dataSource = DatabaseContext.BuildDataSource(config);
DbContextOptions options = DatabaseContext var options = DatabaseContext
.BuildOptions(new DbContextOptionsBuilder(), dataSource, null) .BuildOptions(new DbContextOptionsBuilder(), dataSource, null)
.Options; .Options;

View file

@ -26,7 +26,7 @@ public static class DatabaseQueryExtensions
} }
User? user; User? user;
if (Snowflake.TryParse(userRef, out Snowflake? snowflake)) if (Snowflake.TryParse(userRef, out var snowflake))
{ {
user = await context user = await context
.Users.Where(u => !u.Deleted) .Users.Where(u => !u.Deleted)
@ -42,7 +42,7 @@ public static class DatabaseQueryExtensions
return user; return user;
throw new ApiError.NotFound( throw new ApiError.NotFound(
"No user with that ID or username found.", "No user with that ID or username found.",
ErrorCode.UserNotFound code: ErrorCode.UserNotFound
); );
} }
@ -52,12 +52,12 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
User? user = await context var user = await context
.Users.Where(u => !u.Deleted) .Users.Where(u => !u.Deleted)
.FirstOrDefaultAsync(u => u.Id == id, ct); .FirstOrDefaultAsync(u => u.Id == id, ct);
if (user != null) if (user != null)
return user; return user;
throw new ApiError.NotFound("No user with that ID found.", ErrorCode.UserNotFound); throw new ApiError.NotFound("No user with that ID found.", code: ErrorCode.UserNotFound);
} }
public static async Task<Member> ResolveMemberAsync( public static async Task<Member> ResolveMemberAsync(
@ -66,13 +66,16 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
Member? member = await context var member = await context
.Members.Include(m => m.User) .Members.Include(m => m.User)
.Where(m => !m.User.Deleted) .Where(m => !m.User.Deleted)
.FirstOrDefaultAsync(m => m.Id == id, ct); .FirstOrDefaultAsync(m => m.Id == id, ct);
if (member != null) if (member != null)
return member; return member;
throw new ApiError.NotFound("No member with that ID found.", ErrorCode.MemberNotFound); throw new ApiError.NotFound(
"No member with that ID found.",
code: ErrorCode.MemberNotFound
);
} }
public static async Task<Member> ResolveMemberAsync( public static async Task<Member> ResolveMemberAsync(
@ -83,7 +86,7 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
User user = await context.ResolveUserAsync(userRef, token, ct); var user = await context.ResolveUserAsync(userRef, token, ct);
return await context.ResolveMemberAsync(user.Id, memberRef, ct); return await context.ResolveMemberAsync(user.Id, memberRef, ct);
} }
@ -95,7 +98,7 @@ public static class DatabaseQueryExtensions
) )
{ {
Member? member; Member? member;
if (Snowflake.TryParse(memberRef, out Snowflake? snowflake)) if (Snowflake.TryParse(memberRef, out var snowflake))
{ {
member = await context member = await context
.Members.Include(m => m.User) .Members.Include(m => m.User)
@ -115,7 +118,7 @@ public static class DatabaseQueryExtensions
return member; return member;
throw new ApiError.NotFound( throw new ApiError.NotFound(
"No member with that ID or name found.", "No member with that ID or name found.",
ErrorCode.MemberNotFound code: ErrorCode.MemberNotFound
); );
} }
@ -124,10 +127,7 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
Application? app = await context.Applications.FirstOrDefaultAsync( var app = await context.Applications.FirstOrDefaultAsync(a => a.Id == new Snowflake(0), ct);
a => a.Id == new Snowflake(0),
ct
);
if (app != null) if (app != null)
return app; return app;
@ -152,9 +152,9 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
byte[] hash = SHA512.HashData(rawToken); var hash = SHA512.HashData(rawToken);
Token? oauthToken = await context var oauthToken = await context
.Tokens.Include(t => t.Application) .Tokens.Include(t => t.Application)
.Include(t => t.User) .Include(t => t.User)
.FirstOrDefaultAsync( .FirstOrDefaultAsync(
@ -174,7 +174,7 @@ public static class DatabaseQueryExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
byte[] hash = SHA512.HashData(rawToken); var hash = SHA512.HashData(rawToken);
return await context return await context
.Tokens.Where(t => .Tokens.Where(t =>
t.Hash == hash t.Hash == hash

View file

@ -1,4 +1,3 @@
using Npgsql;
using Serilog; using Serilog;
namespace Foxnouns.Backend.Database; namespace Foxnouns.Backend.Database;
@ -10,8 +9,8 @@ public static class DatabaseServiceExtensions
Config config Config config
) )
{ {
NpgsqlDataSource dataSource = DatabaseContext.BuildDataSource(config); var dataSource = DatabaseContext.BuildDataSource(config);
ILoggerFactory loggerFactory = new LoggerFactory().AddSerilog(dispose: false); var loggerFactory = new LoggerFactory().AddSerilog(dispose: false);
serviceCollection.AddDbContext<DatabaseContext>(options => serviceCollection.AddDbContext<DatabaseContext>(options =>
DatabaseContext.BuildOptions(options, dataSource, loggerFactory) DatabaseContext.BuildOptions(options, dataSource, loggerFactory)

View file

@ -20,10 +20,8 @@ public static class FlagQueryExtensions
Snowflake[] flagIds Snowflake[] flagIds
) )
{ {
List<UserFlag> currentFlags = await db var currentFlags = await db.UserFlags.Where(f => f.UserId == userId).ToListAsync();
.UserFlags.Where(f => f.UserId == userId) foreach (var flag in currentFlags)
.ToListAsync();
foreach (UserFlag flag in currentFlags)
db.UserFlags.Remove(flag); db.UserFlags.Remove(flag);
// If there's no new flags to set, we're done // If there's no new flags to set, we're done
@ -32,16 +30,12 @@ public static class FlagQueryExtensions
if (flagIds.Length > 100) if (flagIds.Length > 100)
return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length); return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length);
List<PrideFlag> flags = await db.GetFlagsAsync(userId); var flags = await db.GetFlagsAsync(userId);
Snowflake[] unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray(); var unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray();
if (unknownFlagIds.Length != 0) if (unknownFlagIds.Length != 0)
return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds); return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds);
IEnumerable<UserFlag> userFlags = flagIds.Select(id => new UserFlag var userFlags = flagIds.Select(id => new UserFlag { PrideFlagId = id, UserId = userId });
{
PrideFlagId = id,
UserId = userId,
});
db.UserFlags.AddRange(userFlags); db.UserFlags.AddRange(userFlags);
return null; return null;
@ -54,10 +48,8 @@ public static class FlagQueryExtensions
Snowflake[] flagIds Snowflake[] flagIds
) )
{ {
List<MemberFlag> currentFlags = await db var currentFlags = await db.MemberFlags.Where(f => f.MemberId == memberId).ToListAsync();
.MemberFlags.Where(f => f.MemberId == memberId) foreach (var flag in currentFlags)
.ToListAsync();
foreach (MemberFlag flag in currentFlags)
db.MemberFlags.Remove(flag); db.MemberFlags.Remove(flag);
if (flagIds.Length == 0) if (flagIds.Length == 0)
@ -65,12 +57,12 @@ public static class FlagQueryExtensions
if (flagIds.Length > 100) if (flagIds.Length > 100)
return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length); return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length);
List<PrideFlag> flags = await db.GetFlagsAsync(userId); var flags = await db.GetFlagsAsync(userId);
Snowflake[] unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray(); var unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray();
if (unknownFlagIds.Length != 0) if (unknownFlagIds.Length != 0)
return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds); return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds);
IEnumerable<MemberFlag> memberFlags = flagIds.Select(id => new MemberFlag var memberFlags = flagIds.Select(id => new MemberFlag
{ {
PrideFlagId = id, PrideFlagId = id,
MemberId = memberId, MemberId = memberId,

View file

@ -24,7 +24,10 @@ namespace Foxnouns.Backend.Database.Migrations
client_secret = table.Column<string>(type: "text", nullable: false), client_secret = table.Column<string>(type: "text", nullable: false),
instance_type = table.Column<int>(type: "integer", nullable: false), instance_type = table.Column<int>(type: "integer", nullable: false),
}, },
constraints: table => table.PrimaryKey("pk_fediverse_applications", x => x.id) constraints: table =>
{
table.PrimaryKey("pk_fediverse_applications", x => x.id);
}
); );
migrationBuilder.CreateTable( migrationBuilder.CreateTable(
@ -43,7 +46,10 @@ namespace Foxnouns.Backend.Database.Migrations
names = table.Column<string>(type: "jsonb", nullable: false), names = table.Column<string>(type: "jsonb", nullable: false),
pronouns = table.Column<string>(type: "jsonb", nullable: false), pronouns = table.Column<string>(type: "jsonb", nullable: false),
}, },
constraints: table => table.PrimaryKey("pk_users", x => x.id) constraints: table =>
{
table.PrimaryKey("pk_users", x => x.id);
}
); );
migrationBuilder.CreateTable( migrationBuilder.CreateTable(

View file

@ -26,7 +26,7 @@ namespace Foxnouns.Backend.Database.Migrations
table: "tokens", table: "tokens",
type: "bytea", type: "bytea",
nullable: false, nullable: false,
defaultValue: Array.Empty<byte>() defaultValue: new byte[0]
); );
migrationBuilder.CreateTable( migrationBuilder.CreateTable(
@ -40,7 +40,10 @@ namespace Foxnouns.Backend.Database.Migrations
scopes = table.Column<string[]>(type: "text[]", nullable: false), scopes = table.Column<string[]>(type: "text[]", nullable: false),
redirect_uris = table.Column<string[]>(type: "text[]", nullable: false), redirect_uris = table.Column<string[]>(type: "text[]", nullable: false),
}, },
constraints: table => table.PrimaryKey("pk_applications", x => x.id) constraints: table =>
{
table.PrimaryKey("pk_applications", x => x.id);
}
); );
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(

View file

@ -32,7 +32,10 @@ namespace Foxnouns.Backend.Database.Migrations
nullable: false nullable: false
), ),
}, },
constraints: table => table.PrimaryKey("pk_temporary_keys", x => x.id) constraints: table =>
{
table.PrimaryKey("pk_temporary_keys", x => x.id);
}
); );
migrationBuilder.CreateIndex( migrationBuilder.CreateIndex(

View file

@ -18,8 +18,8 @@ public class Application : BaseModel
string[] redirectUrls string[] redirectUrls
) )
{ {
string clientId = RandomNumberGenerator.GetHexString(32, true); var clientId = RandomNumberGenerator.GetHexString(32, true);
string clientSecret = AuthUtils.RandomToken(); var clientSecret = AuthUtils.RandomToken();
if (scopes.Except(AuthUtils.ApplicationScopes).Any()) if (scopes.Except(AuthUtils.ApplicationScopes).Any())
{ {

View file

@ -59,7 +59,7 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
public static bool TryParse(string input, [NotNullWhen(true)] out Snowflake? snowflake) public static bool TryParse(string input, [NotNullWhen(true)] out Snowflake? snowflake)
{ {
snowflake = null; snowflake = null;
if (!ulong.TryParse(input, out ulong res)) if (!ulong.TryParse(input, out var res))
return false; return false;
snowflake = new Snowflake(res); snowflake = new Snowflake(res);
return true; return true;
@ -70,7 +70,10 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
public override bool Equals(object? obj) => obj is Snowflake other && Value == other.Value; public override bool Equals(object? obj) => obj is Snowflake other && Value == other.Value;
public bool Equals(Snowflake other) => Value == other.Value; public bool Equals(Snowflake other)
{
return Value == other.Value;
}
public override int GetHashCode() => Value.GetHashCode(); public override int GetHashCode() => Value.GetHashCode();
@ -80,7 +83,11 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
/// An Entity Framework ValueConverter for Snowflakes to longs. /// An Entity Framework ValueConverter for Snowflakes to longs.
/// </summary> /// </summary>
// ReSharper disable once ClassNeverInstantiated.Global // ReSharper disable once ClassNeverInstantiated.Global
public class ValueConverter() : ValueConverter<Snowflake, long>(x => x, x => x); public class ValueConverter()
: ValueConverter<Snowflake, long>(
convertToProviderExpression: x => x,
convertFromProviderExpression: x => x
);
private class JsonConverter : JsonConverter<Snowflake> private class JsonConverter : JsonConverter<Snowflake>
{ {
@ -99,7 +106,10 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
Snowflake existingValue, Snowflake existingValue,
bool hasExistingValue, bool hasExistingValue,
JsonSerializer serializer JsonSerializer serializer
) => ulong.Parse((string)reader.Value!); )
{
return ulong.Parse((string)reader.Value!);
}
} }
private class TypeConverter : System.ComponentModel.TypeConverter private class TypeConverter : System.ComponentModel.TypeConverter
@ -116,6 +126,9 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
ITypeDescriptorContext? context, ITypeDescriptorContext? context,
CultureInfo? culture, CultureInfo? culture,
object value object value
) => TryParse((string)value, out Snowflake? snowflake) ? snowflake : null; )
{
return TryParse((string)value, out var snowflake) ? snowflake : null;
}
} }
} }

View file

@ -28,9 +28,9 @@ public class SnowflakeGenerator : ISnowflakeGenerator
public Snowflake GenerateSnowflake(Instant? time = null) public Snowflake GenerateSnowflake(Instant? time = null)
{ {
time ??= SystemClock.Instance.GetCurrentInstant(); time ??= SystemClock.Instance.GetCurrentInstant();
long increment = Interlocked.Increment(ref _increment); var increment = Interlocked.Increment(ref _increment);
int threadId = Environment.CurrentManagedThreadId % 32; var threadId = Environment.CurrentManagedThreadId % 32;
long timestamp = time.Value.ToUnixTimeMilliseconds() - Snowflake.Epoch; var timestamp = time.Value.ToUnixTimeMilliseconds() - Snowflake.Epoch;
return (timestamp << 22) return (timestamp << 22)
| (uint)(_processId << 17) | (uint)(_processId << 17)
@ -44,5 +44,8 @@ public static class SnowflakeGeneratorServiceExtensions
public static IServiceCollection AddSnowflakeGenerator( public static IServiceCollection AddSnowflakeGenerator(
this IServiceCollection services, this IServiceCollection services,
int? processId = null int? processId = null
) => services.AddSingleton<ISnowflakeGenerator>(new SnowflakeGenerator(processId)); )
{
return services.AddSingleton<ISnowflakeGenerator>(new SnowflakeGenerator(processId));
}
} }

View file

@ -35,13 +35,13 @@ public class ApiError(
public readonly ErrorCode ErrorCode = errorCode ?? ErrorCode.InternalServerError; public readonly ErrorCode ErrorCode = errorCode ?? ErrorCode.InternalServerError;
public class Unauthorized(string message, ErrorCode errorCode = ErrorCode.AuthenticationError) public class Unauthorized(string message, ErrorCode errorCode = ErrorCode.AuthenticationError)
: ApiError(message, HttpStatusCode.Unauthorized, errorCode); : ApiError(message, statusCode: HttpStatusCode.Unauthorized, errorCode: errorCode);
public class Forbidden( public class Forbidden(
string message, string message,
IEnumerable<string>? scopes = null, IEnumerable<string>? scopes = null,
ErrorCode errorCode = ErrorCode.Forbidden ErrorCode errorCode = ErrorCode.Forbidden
) : ApiError(message, HttpStatusCode.Forbidden, errorCode) ) : ApiError(message, statusCode: HttpStatusCode.Forbidden, errorCode: errorCode)
{ {
public readonly string[] Scopes = scopes?.ToArray() ?? []; public readonly string[] Scopes = scopes?.ToArray() ?? [];
} }
@ -49,7 +49,7 @@ public class ApiError(
public class BadRequest( public class BadRequest(
string message, string message,
IReadOnlyDictionary<string, IEnumerable<ValidationError>>? errors = null IReadOnlyDictionary<string, IEnumerable<ValidationError>>? errors = null
) : ApiError(message, HttpStatusCode.BadRequest) ) : ApiError(message, statusCode: HttpStatusCode.BadRequest)
{ {
public BadRequest(string message, string field, object? actualValue) public BadRequest(string message, string field, object? actualValue)
: this( : this(
@ -72,7 +72,7 @@ public class ApiError(
return o; return o;
var a = new JArray(); var a = new JArray();
foreach (KeyValuePair<string, IEnumerable<ValidationError>> error in errors) foreach (var error in errors)
{ {
var errorObj = new JObject var errorObj = new JObject
{ {
@ -92,7 +92,7 @@ public class ApiError(
/// Any other methods should use <see cref="ApiError.BadRequest" /> instead. /// Any other methods should use <see cref="ApiError.BadRequest" /> instead.
/// </summary> /// </summary>
public class AspBadRequest(string message, ModelStateDictionary? modelState = null) public class AspBadRequest(string message, ModelStateDictionary? modelState = null)
: ApiError(message, HttpStatusCode.BadRequest) : ApiError(message, statusCode: HttpStatusCode.BadRequest)
{ {
public JObject ToJson() public JObject ToJson()
{ {
@ -106,11 +106,7 @@ public class ApiError(
return o; return o;
var a = new JArray(); var a = new JArray();
foreach ( foreach (var error in modelState.Where(e => e.Value is { Errors.Count: > 0 }))
KeyValuePair<string, ModelStateEntry?> error in modelState.Where(e =>
e.Value is { Errors.Count: > 0 }
)
)
{ {
var errorObj = new JObject var errorObj = new JObject
{ {
@ -134,9 +130,10 @@ public class ApiError(
} }
public class NotFound(string message, ErrorCode? code = null) public class NotFound(string message, ErrorCode? code = null)
: ApiError(message, HttpStatusCode.NotFound, code); : ApiError(message, statusCode: HttpStatusCode.NotFound, errorCode: code);
public class AuthenticationError(string message) : ApiError(message, HttpStatusCode.BadRequest); public class AuthenticationError(string message)
: ApiError(message, statusCode: HttpStatusCode.BadRequest);
} }
public enum ErrorCode public enum ErrorCode
@ -178,27 +175,33 @@ public class ValidationError
int minLength, int minLength,
int maxLength, int maxLength,
int actualLength int actualLength
) => )
new() {
return new ValidationError
{ {
Message = message, Message = message,
MinLength = minLength, MinLength = minLength,
MaxLength = maxLength, MaxLength = maxLength,
ActualLength = actualLength, ActualLength = actualLength,
}; };
}
public static ValidationError DisallowedValueError( public static ValidationError DisallowedValueError(
string message, string message,
IEnumerable<object> allowedValues, IEnumerable<object> allowedValues,
object actualValue object actualValue
) => )
new() {
return new ValidationError
{ {
Message = message, Message = message,
AllowedValues = allowedValues, AllowedValues = allowedValues,
ActualValue = actualValue, ActualValue = actualValue,
}; };
}
public static ValidationError GenericValidationError(string message, object? actualValue) => public static ValidationError GenericValidationError(string message, object? actualValue)
new() { Message = message, ActualValue = actualValue }; {
return new ValidationError { Message = message, ActualValue = actualValue };
}
} }

View file

@ -47,13 +47,13 @@ public static class ImageObjectExtensions
if (!uri.StartsWith("data:image/")) if (!uri.StartsWith("data:image/"))
throw new ArgumentException("Not a data URI", nameof(uri)); throw new ArgumentException("Not a data URI", nameof(uri));
string[] split = uri.Remove(0, "data:".Length).Split(";base64,"); var split = uri.Remove(0, "data:".Length).Split(";base64,");
string contentType = split[0]; var contentType = split[0];
string encoded = split[1]; var encoded = split[1];
if (!ValidContentTypes.Contains(contentType)) if (!ValidContentTypes.Contains(contentType))
throw new ArgumentException("Invalid content type for image", nameof(uri)); throw new ArgumentException("Invalid content type for image", nameof(uri));
if (!AuthUtils.TryFromBase64String(encoded, out byte[]? rawImage)) if (!AuthUtils.TryFromBase64String(encoded, out var rawImage))
throw new ArgumentException("Invalid base64 string", nameof(uri)); throw new ArgumentException("Invalid base64 string", nameof(uri));
var image = Image.Load(rawImage); var image = Image.Load(rawImage);
@ -74,7 +74,7 @@ public static class ImageObjectExtensions
await image.SaveAsync(stream, new WebpEncoder { Quality = 95, NearLossless = false }); await image.SaveAsync(stream, new WebpEncoder { Quality = 95, NearLossless = false });
stream.Seek(0, SeekOrigin.Begin); stream.Seek(0, SeekOrigin.Begin);
string hash = Convert.ToHexString(await SHA256.HashDataAsync(stream)).ToLower(); var hash = Convert.ToHexString(await SHA256.HashDataAsync(stream)).ToLower();
stream.Seek(0, SeekOrigin.Begin); stream.Seek(0, SeekOrigin.Begin);
return (hash, stream); return (hash, stream);

View file

@ -14,7 +14,7 @@ public static class KeyCacheExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
string state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_'); var state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_');
await keyCacheService.SetKeyAsync($"oauth_state:{state}", "", Duration.FromMinutes(10), ct); await keyCacheService.SetKeyAsync($"oauth_state:{state}", "", Duration.FromMinutes(10), ct);
return state; return state;
} }
@ -25,7 +25,7 @@ public static class KeyCacheExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
string? val = await keyCacheService.GetKeyAsync($"oauth_state:{state}", ct: ct); var val = await keyCacheService.GetKeyAsync($"oauth_state:{state}", delete: true, ct);
if (val == null) if (val == null)
throw new ApiError.BadRequest("Invalid OAuth state"); throw new ApiError.BadRequest("Invalid OAuth state");
} }
@ -38,7 +38,7 @@ public static class KeyCacheExtensions
) )
{ {
// This state is used in links, not just as JSON values, so make it URL-safe // This state is used in links, not just as JSON values, so make it URL-safe
string state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_'); var state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_');
await keyCacheService.SetKeyAsync( await keyCacheService.SetKeyAsync(
$"email_state:{state}", $"email_state:{state}",
new RegisterEmailState(email, userId), new RegisterEmailState(email, userId),
@ -52,7 +52,12 @@ public static class KeyCacheExtensions
this KeyCacheService keyCacheService, this KeyCacheService keyCacheService,
string state, string state,
CancellationToken ct = default CancellationToken ct = default
) => await keyCacheService.GetKeyAsync<RegisterEmailState>($"email_state:{state}", ct: ct); ) =>
await keyCacheService.GetKeyAsync<RegisterEmailState>(
$"email_state:{state}",
delete: true,
ct
);
public static async Task<string> GenerateAddExtraAccountStateAsync( public static async Task<string> GenerateAddExtraAccountStateAsync(
this KeyCacheService keyCacheService, this KeyCacheService keyCacheService,
@ -62,7 +67,7 @@ public static class KeyCacheExtensions
CancellationToken ct = default CancellationToken ct = default
) )
{ {
string state = AuthUtils.RandomToken(); var state = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync( await keyCacheService.SetKeyAsync(
$"add_account:{state}", $"add_account:{state}",
new AddExtraAccountState(authType, userId, instance), new AddExtraAccountState(authType, userId, instance),
@ -76,7 +81,12 @@ public static class KeyCacheExtensions
this KeyCacheService keyCacheService, this KeyCacheService keyCacheService,
string state, string state,
CancellationToken ct = default CancellationToken ct = default
) => await keyCacheService.GetKeyAsync<AddExtraAccountState>($"add_account:{state}", true, ct); ) =>
await keyCacheService.GetKeyAsync<AddExtraAccountState>(
$"add_account:{state}",
delete: true,
ct
);
} }
public record RegisterEmailState( public record RegisterEmailState(

View file

@ -24,9 +24,9 @@ public static class WebApplicationExtensions
/// </summary> /// </summary>
public static WebApplicationBuilder AddSerilog(this WebApplicationBuilder builder) public static WebApplicationBuilder AddSerilog(this WebApplicationBuilder builder)
{ {
Config config = builder.Configuration.Get<Config>() ?? new Config(); var config = builder.Configuration.Get<Config>() ?? new();
LoggerConfiguration logCfg = new LoggerConfiguration() var logCfg = new LoggerConfiguration()
.Enrich.FromLogContext() .Enrich.FromLogContext()
.MinimumLevel.Is(config.Logging.LogEventLevel) .MinimumLevel.Is(config.Logging.LogEventLevel)
// ASP.NET's built in request logs are extremely verbose, so we use Serilog's instead. // ASP.NET's built in request logs are extremely verbose, so we use Serilog's instead.
@ -43,7 +43,10 @@ public static class WebApplicationExtensions
if (config.Logging.SeqLogUrl != null) if (config.Logging.SeqLogUrl != null)
{ {
logCfg.WriteTo.Seq(config.Logging.SeqLogUrl, LogEventLevel.Verbose); logCfg.WriteTo.Seq(
config.Logging.SeqLogUrl,
restrictedToMinimumLevel: LogEventLevel.Verbose
);
} }
// AddSerilog doesn't seem to add an ILogger to the service collection, so add that manually. // AddSerilog doesn't seem to add an ILogger to the service collection, so add that manually.
@ -57,19 +60,19 @@ public static class WebApplicationExtensions
builder.Configuration.Sources.Clear(); builder.Configuration.Sources.Clear();
builder.Configuration.AddConfiguration(); builder.Configuration.AddConfiguration();
Config config = builder.Configuration.Get<Config>() ?? new Config(); var config = builder.Configuration.Get<Config>() ?? new();
builder.Services.AddSingleton(config); builder.Services.AddSingleton(config);
return config; return config;
} }
public static IConfigurationBuilder AddConfiguration(this IConfigurationBuilder builder) public static IConfigurationBuilder AddConfiguration(this IConfigurationBuilder builder)
{ {
string file = Environment.GetEnvironmentVariable("FOXNOUNS_CONFIG_FILE") ?? "config.ini"; var file = Environment.GetEnvironmentVariable("FOXNOUNS_CONFIG_FILE") ?? "config.ini";
return builder return builder
.SetBasePath(Directory.GetCurrentDirectory()) .SetBasePath(Directory.GetCurrentDirectory())
.AddJsonFile("appSettings.json", true) .AddJsonFile("appSettings.json", true)
.AddIniFile(file, false, true) .AddIniFile(file, optional: false, reloadOnChange: true)
.AddEnvironmentVariables(); .AddEnvironmentVariables();
} }
@ -139,15 +142,11 @@ public static class WebApplicationExtensions
app.Services.ConfigureQueue() app.Services.ConfigureQueue()
.LogQueuedTaskProgress(app.Services.GetRequiredService<ILogger<IQueue>>()); .LogQueuedTaskProgress(app.Services.GetRequiredService<ILogger<IQueue>>());
await using AsyncServiceScope scope = app.Services.CreateAsyncScope(); await using var scope = app.Services.CreateAsyncScope();
// The types of these variables are obvious from the methods being called to create them
// ReSharper disable SuggestVarOrType_SimpleTypes
var logger = scope var logger = scope
.ServiceProvider.GetRequiredService<ILogger>() .ServiceProvider.GetRequiredService<ILogger>()
.ForContext<WebApplication>(); .ForContext<WebApplication>();
var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
// ReSharper restore SuggestVarOrType_SimpleTypes
logger.Information( logger.Information(
"Starting Foxnouns.NET {Version} ({Hash})", "Starting Foxnouns.NET {Version} ({Hash})",

View file

@ -30,10 +30,6 @@
<PackageReference Include="Npgsql.Json.NET" Version="8.0.3"/> <PackageReference Include="Npgsql.Json.NET" Version="8.0.3"/>
<PackageReference Include="prometheus-net" Version="8.2.1"/> <PackageReference Include="prometheus-net" Version="8.2.1"/>
<PackageReference Include="prometheus-net.AspNetCore" Version="8.2.1"/> <PackageReference Include="prometheus-net.AspNetCore" Version="8.2.1"/>
<PackageReference Include="Roslynator.Analyzers" Version="4.12.9">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Sentry.AspNetCore" Version="4.9.0"/> <PackageReference Include="Sentry.AspNetCore" Version="4.9.0"/>
<PackageReference Include="Serilog" Version="4.0.1"/> <PackageReference Include="Serilog" Version="4.0.1"/>
<PackageReference Include="Serilog.AspNetCore" Version="8.0.1"/> <PackageReference Include="Serilog.AspNetCore" Version="8.0.1"/>

View file

@ -40,7 +40,7 @@ public class CreateDataExportInvocable(
private async Task InvokeAsync() private async Task InvokeAsync()
{ {
User? user = await db var user = await db
.Users.Include(u => u.AuthMethods) .Users.Include(u => u.AuthMethods)
.Include(u => u.Flags) .Include(u => u.Flags)
.Include(u => u.ProfileFlags) .Include(u => u.ProfileFlags)
@ -57,7 +57,7 @@ public class CreateDataExportInvocable(
_logger.Information("Generating data export for user {UserId}", user.Id); _logger.Information("Generating data export for user {UserId}", user.Id);
await using var stream = new MemoryStream(); using var stream = new MemoryStream();
using var zip = new ZipArchive(stream, ZipArchiveMode.Create, true); using var zip = new ZipArchive(stream, ZipArchiveMode.Create, true);
zip.Comment = zip.Comment =
$"This archive for {user.Username} ({user.Id}) was generated at {InstantPattern.General.Format(clock.GetCurrentInstant())}"; $"This archive for {user.Username} ({user.Id}) was generated at {InstantPattern.General.Format(clock.GetCurrentInstant())}";
@ -66,19 +66,25 @@ public class CreateDataExportInvocable(
WriteJson( WriteJson(
zip, zip,
"user.json", "user.json",
await userRenderer.RenderUserInnerAsync(user, true, ["*"], false, true) await userRenderer.RenderUserInnerAsync(
user,
true,
["*"],
renderMembers: false,
renderAuthMethods: true
)
); );
await WriteS3Object(zip, "user-avatar.webp", userRenderer.AvatarUrlFor(user)); await WriteS3Object(zip, "user-avatar.webp", userRenderer.AvatarUrlFor(user));
foreach (PrideFlag? flag in user.Flags) foreach (var flag in user.Flags)
await WritePrideFlag(zip, flag); await WritePrideFlag(zip, flag);
List<Member> members = await db var members = await db
.Members.Include(m => m.User) .Members.Include(m => m.User)
.Include(m => m.ProfileFlags) .Include(m => m.ProfileFlags)
.Where(m => m.UserId == user.Id) .Where(m => m.UserId == user.Id)
.ToListAsync(); .ToListAsync();
foreach (Member? member in members) foreach (var member in members)
await WriteMember(zip, member); await WriteMember(zip, member);
// We want to dispose the ZipArchive on an error, but we need to dispose it manually to upload to object storage. // We want to dispose the ZipArchive on an error, but we need to dispose it manually to upload to object storage.
@ -88,7 +94,7 @@ public class CreateDataExportInvocable(
stream.Seek(0, SeekOrigin.Begin); stream.Seek(0, SeekOrigin.Begin);
// Upload the file! // Upload the file!
string filename = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_'); var filename = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_');
await objectStorageService.PutObjectAsync( await objectStorageService.PutObjectAsync(
ExportPath(user.Id, filename), ExportPath(user.Id, filename),
stream, stream,
@ -126,8 +132,8 @@ public class CreateDataExportInvocable(
return; return;
} }
ZipArchiveEntry entry = zip.CreateEntry($"flag-{flag.Id}/flag.txt"); var entry = zip.CreateEntry($"flag-{flag.Id}/flag.txt");
await using Stream stream = entry.Open(); await using var stream = entry.Open();
await using var writer = new StreamWriter(stream); await using var writer = new StreamWriter(stream);
await writer.WriteAsync(flagData); await writer.WriteAsync(flagData);
} }
@ -158,7 +164,7 @@ public class CreateDataExportInvocable(
private void WriteJson(ZipArchive zip, string filename, object data) private void WriteJson(ZipArchive zip, string filename, object data)
{ {
string json = JsonConvert.SerializeObject(data, Formatting.Indented); var json = JsonConvert.SerializeObject(data, Formatting.Indented);
_logger.Debug( _logger.Debug(
"Writing file {Filename} to archive with size {Length}", "Writing file {Filename} to archive with size {Length}",
@ -166,8 +172,8 @@ public class CreateDataExportInvocable(
json.Length json.Length
); );
ZipArchiveEntry entry = zip.CreateEntry(filename); var entry = zip.CreateEntry(filename);
using Stream stream = entry.Open(); using var stream = entry.Open();
using var writer = new StreamWriter(stream); using var writer = new StreamWriter(stream);
writer.Write(json); writer.Write(json);
} }
@ -177,14 +183,14 @@ public class CreateDataExportInvocable(
if (s3Path == null) if (s3Path == null)
return; return;
HttpResponseMessage resp = await Client.GetAsync(s3Path); var resp = await Client.GetAsync(s3Path);
if (resp.StatusCode != HttpStatusCode.OK) if (resp.StatusCode != HttpStatusCode.OK)
{ {
_logger.Warning("S3 path {S3Path} returned a non-200 status, not saving file", s3Path); _logger.Warning("S3 path {S3Path} returned a non-200 status, not saving file", s3Path);
return; return;
} }
await using Stream respStream = await resp.Content.ReadAsStreamAsync(); await using var respStream = await resp.Content.ReadAsStreamAsync();
_logger.Debug( _logger.Debug(
"Writing file {Filename} to archive with size {Length}", "Writing file {Filename} to archive with size {Length}",
@ -192,8 +198,8 @@ public class CreateDataExportInvocable(
respStream.Length respStream.Length
); );
ZipArchiveEntry entry = zip.CreateEntry(filename); var entry = zip.CreateEntry(filename);
await using Stream entryStream = entry.Open(); await using var entryStream = entry.Open();
respStream.Seek(0, SeekOrigin.Begin); respStream.Seek(0, SeekOrigin.Begin);
await respStream.CopyToAsync(entryStream); await respStream.CopyToAsync(entryStream);

View file

@ -26,10 +26,10 @@ public class CreateFlagInvocable(
try try
{ {
(string? hash, Stream? image) = await ImageObjectExtensions.ConvertBase64UriToImage( var (hash, image) = await ImageObjectExtensions.ConvertBase64UriToImage(
Payload.ImageData, Payload.ImageData,
256, size: 256,
false crop: false
); );
await objectStorageService.PutObjectAsync(Path(hash), image, "image/webp"); await objectStorageService.PutObjectAsync(Path(hash), image, "image/webp");

View file

@ -1,6 +1,5 @@
using Coravel.Invocable; using Coravel.Invocable;
using Foxnouns.Backend.Database; using Foxnouns.Backend.Database;
using Foxnouns.Backend.Database.Models;
using Foxnouns.Backend.Extensions; using Foxnouns.Backend.Extensions;
using Foxnouns.Backend.Services; using Foxnouns.Backend.Services;
@ -27,7 +26,7 @@ public class MemberAvatarUpdateInvocable(
{ {
_logger.Debug("Updating avatar for member {MemberId}", id); _logger.Debug("Updating avatar for member {MemberId}", id);
Member? member = await db.Members.FindAsync(id); var member = await db.Members.FindAsync(id);
if (member == null) if (member == null)
{ {
_logger.Warning( _logger.Warning(
@ -39,12 +38,12 @@ public class MemberAvatarUpdateInvocable(
try try
{ {
(string? hash, Stream? image) = await ImageObjectExtensions.ConvertBase64UriToImage( var (hash, image) = await ImageObjectExtensions.ConvertBase64UriToImage(
newAvatar, newAvatar,
512, size: 512,
true crop: true
); );
string? prevHash = member.Avatar; var prevHash = member.Avatar;
await objectStorageService.PutObjectAsync(Path(id, hash), image, "image/webp"); await objectStorageService.PutObjectAsync(Path(id, hash), image, "image/webp");
@ -70,7 +69,7 @@ public class MemberAvatarUpdateInvocable(
{ {
_logger.Debug("Clearing avatar for member {MemberId}", id); _logger.Debug("Clearing avatar for member {MemberId}", id);
Member? member = await db.Members.FindAsync(id); var member = await db.Members.FindAsync(id);
if (member == null) if (member == null)
{ {
_logger.Warning( _logger.Warning(

View file

@ -1,6 +1,5 @@
using Coravel.Invocable; using Coravel.Invocable;
using Foxnouns.Backend.Database; using Foxnouns.Backend.Database;
using Foxnouns.Backend.Database.Models;
using Foxnouns.Backend.Extensions; using Foxnouns.Backend.Extensions;
using Foxnouns.Backend.Services; using Foxnouns.Backend.Services;
@ -27,7 +26,7 @@ public class UserAvatarUpdateInvocable(
{ {
_logger.Debug("Updating avatar for user {MemberId}", id); _logger.Debug("Updating avatar for user {MemberId}", id);
User? user = await db.Users.FindAsync(id); var user = await db.Users.FindAsync(id);
if (user == null) if (user == null)
{ {
_logger.Warning( _logger.Warning(
@ -39,13 +38,13 @@ public class UserAvatarUpdateInvocable(
try try
{ {
(string? hash, Stream? image) = await ImageObjectExtensions.ConvertBase64UriToImage( var (hash, image) = await ImageObjectExtensions.ConvertBase64UriToImage(
newAvatar, newAvatar,
512, size: 512,
true crop: true
); );
image.Seek(0, SeekOrigin.Begin); image.Seek(0, SeekOrigin.Begin);
string? prevHash = user.Avatar; var prevHash = user.Avatar;
await objectStorageService.PutObjectAsync(Path(id, hash), image, "image/webp"); await objectStorageService.PutObjectAsync(Path(id, hash), image, "image/webp");
@ -71,7 +70,7 @@ public class UserAvatarUpdateInvocable(
{ {
_logger.Debug("Clearing avatar for user {MemberId}", id); _logger.Debug("Clearing avatar for user {MemberId}", id);
User? user = await db.Users.FindAsync(id); var user = await db.Users.FindAsync(id);
if (user == null) if (user == null)
{ {
_logger.Warning( _logger.Warning(

View file

@ -8,8 +8,8 @@ public class AuthenticationMiddleware(DatabaseContext db) : IMiddleware
{ {
public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) public async Task InvokeAsync(HttpContext ctx, RequestDelegate next)
{ {
Endpoint? endpoint = ctx.GetEndpoint(); var endpoint = ctx.GetEndpoint();
AuthenticateAttribute? metadata = endpoint?.Metadata.GetMetadata<AuthenticateAttribute>(); var metadata = endpoint?.Metadata.GetMetadata<AuthenticateAttribute>();
if (metadata == null) if (metadata == null)
{ {
@ -18,17 +18,14 @@ public class AuthenticationMiddleware(DatabaseContext db) : IMiddleware
} }
if ( if (
!AuthUtils.TryParseToken( !AuthUtils.TryParseToken(ctx.Request.Headers.Authorization.ToString(), out var rawToken)
ctx.Request.Headers.Authorization.ToString(),
out byte[]? rawToken
)
) )
{ {
await next(ctx); await next(ctx);
return; return;
} }
Token? oauthToken = await db.GetToken(rawToken); var oauthToken = await db.GetToken(rawToken);
if (oauthToken == null) if (oauthToken == null)
{ {
await next(ctx); await next(ctx);
@ -53,7 +50,7 @@ public static class HttpContextExtensions
public static Token? GetToken(this HttpContext ctx) public static Token? GetToken(this HttpContext ctx)
{ {
if (ctx.Items.TryGetValue(Key, out object? token)) if (ctx.Items.TryGetValue(Key, out var token))
return token as Token; return token as Token;
return null; return null;
} }

View file

@ -7,8 +7,8 @@ public class AuthorizationMiddleware : IMiddleware
{ {
public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) public async Task InvokeAsync(HttpContext ctx, RequestDelegate next)
{ {
Endpoint? endpoint = ctx.GetEndpoint(); var endpoint = ctx.GetEndpoint();
AuthorizeAttribute? attribute = endpoint?.Metadata.GetMetadata<AuthorizeAttribute>(); var attribute = endpoint?.Metadata.GetMetadata<AuthorizeAttribute>();
if (attribute == null) if (attribute == null)
{ {
@ -16,27 +16,21 @@ public class AuthorizationMiddleware : IMiddleware
return; return;
} }
Token? token = ctx.GetToken(); var token = ctx.GetToken();
if (token == null) if (token == null)
{
throw new ApiError.Unauthorized( throw new ApiError.Unauthorized(
"This endpoint requires an authenticated user.", "This endpoint requires an authenticated user.",
ErrorCode.AuthenticationRequired ErrorCode.AuthenticationRequired
); );
}
if ( if (
attribute.Scopes.Length > 0 attribute.Scopes.Length > 0
&& attribute.Scopes.Except(token.Scopes.ExpandScopes()).Any() && attribute.Scopes.Except(token.Scopes.ExpandScopes()).Any()
) )
{
throw new ApiError.Forbidden( throw new ApiError.Forbidden(
"This endpoint requires ungranted scopes.", "This endpoint requires ungranted scopes.",
attribute.Scopes.Except(token.Scopes.ExpandScopes()), attribute.Scopes.Except(token.Scopes.ExpandScopes()),
ErrorCode.MissingScopes ErrorCode.MissingScopes
); );
}
if (attribute.RequireAdmin && token.User.Role != UserRole.Admin) if (attribute.RequireAdmin && token.User.Role != UserRole.Admin)
throw new ApiError.Forbidden("This endpoint can only be used by admins."); throw new ApiError.Forbidden("This endpoint can only be used by admins.");
if ( if (
@ -44,9 +38,7 @@ public class AuthorizationMiddleware : IMiddleware
&& token.User.Role != UserRole.Admin && token.User.Role != UserRole.Admin
&& token.User.Role != UserRole.Moderator && token.User.Role != UserRole.Moderator
) )
{
throw new ApiError.Forbidden("This endpoint can only be used by moderators."); throw new ApiError.Forbidden("This endpoint can only be used by moderators.");
}
await next(ctx); await next(ctx);
} }

View file

@ -1,5 +1,4 @@
using System.Net; using System.Net;
using Foxnouns.Backend.Database.Models;
using Foxnouns.Backend.Utils; using Foxnouns.Backend.Utils;
using Newtonsoft.Json; using Newtonsoft.Json;
@ -15,9 +14,9 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
} }
catch (Exception e) catch (Exception e)
{ {
Type type = e.TargetSite?.DeclaringType ?? typeof(ErrorHandlerMiddleware); var type = e.TargetSite?.DeclaringType ?? typeof(ErrorHandlerMiddleware);
string typeName = e.TargetSite?.DeclaringType?.FullName ?? "<unknown>"; var typeName = e.TargetSite?.DeclaringType?.FullName ?? "<unknown>";
ILogger logger = baseLogger.ForContext(type); var logger = baseLogger.ForContext(type);
if (ctx.Response.HasStarted) if (ctx.Response.HasStarted)
{ {
@ -32,16 +31,14 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
e, e,
scope => scope =>
{ {
User? user = ctx.GetUser(); var user = ctx.GetUser();
if (user != null) if (user != null)
{
scope.User = new SentryUser scope.User = new SentryUser
{ {
Id = user.Id.ToString(), Id = user.Id.ToString(),
Username = user.Username, Username = user.Username,
}; };
} }
}
); );
return; return;
@ -101,20 +98,18 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
logger.Error(e, "Exception in {ClassName} ({Path})", typeName, ctx.Request.Path); logger.Error(e, "Exception in {ClassName} ({Path})", typeName, ctx.Request.Path);
} }
SentryId errorId = sentry.CaptureException( var errorId = sentry.CaptureException(
e, e,
scope => scope =>
{ {
User? user = ctx.GetUser(); var user = ctx.GetUser();
if (user != null) if (user != null)
{
scope.User = new SentryUser scope.User = new SentryUser
{ {
Id = user.Id.ToString(), Id = user.Id.ToString(),
Username = user.Username, Username = user.Username,
}; };
} }
}
); );
ctx.Response.StatusCode = (int)HttpStatusCode.InternalServerError; ctx.Response.StatusCode = (int)HttpStatusCode.InternalServerError;

View file

@ -9,9 +9,9 @@ using Prometheus;
using Sentry.Extensibility; using Sentry.Extensibility;
using Serilog; using Serilog;
WebApplicationBuilder builder = WebApplication.CreateBuilder(args); var builder = WebApplication.CreateBuilder(args);
Config config = builder.AddConfiguration(); var config = builder.AddConfiguration();
builder.AddSerilog(); builder.AddSerilog();
@ -58,7 +58,7 @@ JsonConvert.DefaultSettings = () =>
builder.AddServices(config).AddCustomMiddleware().AddEndpointsApiExplorer().AddSwaggerGen(); builder.AddServices(config).AddCustomMiddleware().AddEndpointsApiExplorer().AddSwaggerGen();
WebApplication app = builder.Build(); var app = builder.Build();
await app.Initialize(args); await app.Initialize(args);

View file

@ -31,16 +31,6 @@ public class AuthService(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
// Validate username and whether it's not taken
ValidationUtils.Validate(
[
("username", ValidationUtils.ValidateUsername(username)),
("password", ValidationUtils.ValidatePassword(password)),
]
);
if (await db.Users.AnyAsync(u => u.Username == username, ct))
throw new ApiError.BadRequest("Username is already taken", "username", username);
var user = new User var user = new User
{ {
Id = snowflakeGenerator.GenerateSnowflake(), Id = snowflakeGenerator.GenerateSnowflake(),
@ -59,7 +49,7 @@ public class AuthService(
}; };
db.Add(user); db.Add(user);
user.Password = await HashPasswordAsync(user, password, ct); user.Password = await Task.Run(() => _passwordHasher.HashPassword(user, password), ct);
return user; return user;
} }
@ -80,8 +70,6 @@ public class AuthService(
{ {
AssertValidAuthType(authType, instance); AssertValidAuthType(authType, instance);
// Validate username and whether it's not taken
ValidationUtils.Validate([("username", ValidationUtils.ValidateUsername(username))]);
if (await db.Users.AnyAsync(u => u.Username == username, ct)) if (await db.Users.AnyAsync(u => u.Username == username, ct))
throw new ApiError.BadRequest("Username is already taken", "username", username); throw new ApiError.BadRequest("Username is already taken", "username", username);
@ -123,30 +111,28 @@ public class AuthService(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
User? user = await db.Users.FirstOrDefaultAsync( var user = await db.Users.FirstOrDefaultAsync(
u => u.AuthMethods.Any(a => a.AuthType == AuthType.Email && a.RemoteId == email), u => u.AuthMethods.Any(a => a.AuthType == AuthType.Email && a.RemoteId == email),
ct ct
); );
if (user == null) if (user == null)
{
throw new ApiError.NotFound( throw new ApiError.NotFound(
"No user with that email address found, or password is incorrect", "No user with that email address found, or password is incorrect",
ErrorCode.UserNotFound ErrorCode.UserNotFound
); );
}
PasswordVerificationResult pwResult = await VerifyHashedPasswordAsync(user, password, ct); var pwResult = await Task.Run(
() => _passwordHasher.VerifyHashedPassword(user, user.Password!, password),
ct
);
if (pwResult == PasswordVerificationResult.Failed) // TODO: this seems to fail on some valid passwords? if (pwResult == PasswordVerificationResult.Failed) // TODO: this seems to fail on some valid passwords?
{
throw new ApiError.NotFound( throw new ApiError.NotFound(
"No user with that email address found, or password is incorrect", "No user with that email address found, or password is incorrect",
ErrorCode.UserNotFound ErrorCode.UserNotFound
); );
}
if (pwResult == PasswordVerificationResult.SuccessRehashNeeded) if (pwResult == PasswordVerificationResult.SuccessRehashNeeded)
{ {
user.Password = await HashPasswordAsync(user, password, ct); user.Password = await Task.Run(() => _passwordHasher.HashPassword(user, password), ct);
await db.SaveChangesAsync(ct); await db.SaveChangesAsync(ct);
} }
@ -174,7 +160,10 @@ public class AuthService(
throw new FoxnounsError("Password for user supplied to ValidatePasswordAsync was null"); throw new FoxnounsError("Password for user supplied to ValidatePasswordAsync was null");
} }
PasswordVerificationResult pwResult = await VerifyHashedPasswordAsync(user, password, ct); var pwResult = await Task.Run(
() => _passwordHasher.VerifyHashedPassword(user, user.Password!, password),
ct
);
return pwResult return pwResult
is PasswordVerificationResult.SuccessRehashNeeded is PasswordVerificationResult.SuccessRehashNeeded
or PasswordVerificationResult.Success; or PasswordVerificationResult.Success;
@ -189,7 +178,7 @@ public class AuthService(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
user.Password = await HashPasswordAsync(user, password, ct); user.Password = await Task.Run(() => _passwordHasher.HashPassword(user, password), ct);
db.Update(user); db.Update(user);
} }
@ -236,15 +225,13 @@ public class AuthService(
AssertValidAuthType(authType, app); AssertValidAuthType(authType, app);
// This is already checked when // This is already checked when
int currentCount = await db var currentCount = await db
.AuthMethods.Where(m => m.UserId == userId && m.AuthType == authType) .AuthMethods.Where(m => m.UserId == userId && m.AuthType == authType)
.CountAsync(ct); .CountAsync(ct);
if (currentCount >= AuthUtils.MaxAuthMethodsPerType) if (currentCount >= AuthUtils.MaxAuthMethodsPerType)
{
throw new ApiError.BadRequest( throw new ApiError.BadRequest(
"Too many linked accounts of this type, maximum of 3 per account." "Too many linked accounts of this type, maximum of 3 per account."
); );
}
var authMethod = new AuthMethod var authMethod = new AuthMethod
{ {
@ -269,15 +256,13 @@ public class AuthService(
) )
{ {
if (!AuthUtils.ValidateScopes(application, scopes)) if (!AuthUtils.ValidateScopes(application, scopes))
{
throw new ApiError.BadRequest( throw new ApiError.BadRequest(
"Invalid scopes requested for this token", "Invalid scopes requested for this token",
"scopes", "scopes",
scopes scopes
); );
}
(string? token, byte[]? hash) = GenerateToken(); var (token, hash) = GenerateToken();
return ( return (
token, token,
new Token new Token
@ -302,9 +287,9 @@ public class AuthService(
CancellationToken ct = default CancellationToken ct = default
) )
{ {
Application frontendApp = await db.GetFrontendApplicationAsync(ct); var frontendApp = await db.GetFrontendApplicationAsync(ct);
(string? tokenStr, Token? token) = GenerateToken( var (tokenStr, token) = GenerateToken(
user, user,
frontendApp, frontendApp,
["*"], ["*"],
@ -317,35 +302,24 @@ public class AuthService(
await db.SaveChangesAsync(ct); await db.SaveChangesAsync(ct);
return new CallbackResponse( return new CallbackResponse(
true, HasAccount: true,
null, Ticket: null,
null, RemoteUsername: null,
await userRenderer.RenderUserAsync(user, user, renderMembers: false, ct: ct), User: await userRenderer.RenderUserAsync(
tokenStr, user,
token.ExpiresAt selfUser: user,
renderMembers: false,
ct: ct
),
Token: tokenStr,
ExpiresAt: token.ExpiresAt
); );
} }
private Task<string> HashPasswordAsync(
User user,
string password,
CancellationToken ct = default
) => Task.Run(() => _passwordHasher.HashPassword(user, password), ct);
private Task<PasswordVerificationResult> VerifyHashedPasswordAsync(
User user,
string providedPassword,
CancellationToken ct = default
) =>
Task.Run(
() => _passwordHasher.VerifyHashedPassword(user, user.Password!, providedPassword),
ct
);
private static (string, byte[]) GenerateToken() private static (string, byte[]) GenerateToken()
{ {
string token = AuthUtils.RandomToken(); var token = AuthUtils.RandomToken();
byte[] hash = SHA512.HashData(Convert.FromBase64String(token)); var hash = SHA512.HashData(Convert.FromBase64String(token));
return (token, hash); return (token, hash);
} }

View file

@ -18,25 +18,22 @@ public partial class FediverseAuthService
Snowflake? existingAppId = null Snowflake? existingAppId = null
) )
{ {
HttpResponseMessage resp = await _client.PostAsJsonAsync( var resp = await _client.PostAsJsonAsync(
$"https://{instance}/api/v1/apps", $"https://{instance}/api/v1/apps",
new CreateMastodonApplicationRequest( new CreateMastodonApplicationRequest(
$"pronouns.cc (+{_config.BaseUrl})", ClientName: $"pronouns.cc (+{_config.BaseUrl})",
MastodonRedirectUri(instance), RedirectUris: MastodonRedirectUri(instance),
"read read:accounts", Scopes: "read read:accounts",
_config.BaseUrl Website: _config.BaseUrl
) )
); );
resp.EnsureSuccessStatusCode(); resp.EnsureSuccessStatusCode();
PartialMastodonApplication? mastodonApp = var mastodonApp = await resp.Content.ReadFromJsonAsync<PartialMastodonApplication>();
await resp.Content.ReadFromJsonAsync<PartialMastodonApplication>();
if (mastodonApp == null) if (mastodonApp == null)
{
throw new FoxnounsError( throw new FoxnounsError(
$"Application created on Mastodon-compatible instance {instance} was null" $"Application created on Mastodon-compatible instance {instance} was null"
); );
}
FediverseApplication app; FediverseApplication app;
@ -78,7 +75,7 @@ public partial class FediverseAuthService
if (state != null) if (state != null)
await _keyCacheService.ValidateAuthStateAsync(state); await _keyCacheService.ValidateAuthStateAsync(state);
HttpResponseMessage tokenResp = await _client.PostAsync( var tokenResp = await _client.PostAsync(
MastodonTokenUri(app.Domain), MastodonTokenUri(app.Domain),
new FormUrlEncodedContent( new FormUrlEncodedContent(
new Dictionary<string, string> new Dictionary<string, string>
@ -98,7 +95,7 @@ public partial class FediverseAuthService
} }
tokenResp.EnsureSuccessStatusCode(); tokenResp.EnsureSuccessStatusCode();
string? token = ( var token = (
await tokenResp.Content.ReadFromJsonAsync<MastodonTokenResponse>() await tokenResp.Content.ReadFromJsonAsync<MastodonTokenResponse>()
)?.AccessToken; )?.AccessToken;
if (token == null) if (token == null)
@ -109,9 +106,9 @@ public partial class FediverseAuthService
var req = new HttpRequestMessage(HttpMethod.Get, MastodonCurrentUserUri(app.Domain)); var req = new HttpRequestMessage(HttpMethod.Get, MastodonCurrentUserUri(app.Domain));
req.Headers.Add("Authorization", $"Bearer {token}"); req.Headers.Add("Authorization", $"Bearer {token}");
HttpResponseMessage currentUserResp = await _client.SendAsync(req); var currentUserResp = await _client.SendAsync(req);
currentUserResp.EnsureSuccessStatusCode(); currentUserResp.EnsureSuccessStatusCode();
FediverseUser? user = await currentUserResp.Content.ReadFromJsonAsync<FediverseUser>(); var user = await currentUserResp.Content.ReadFromJsonAsync<FediverseUser>();
if (user == null) if (user == null)
{ {
throw new FoxnounsError($"User response from instance {app.Domain} was invalid"); throw new FoxnounsError($"User response from instance {app.Domain} was invalid");
@ -134,7 +131,7 @@ public partial class FediverseAuthService
"An app credentials refresh was requested for {ApplicationId}, creating a new application", "An app credentials refresh was requested for {ApplicationId}, creating a new application",
app.Id app.Id
); );
app = await CreateMastodonApplicationAsync(app.Domain, app.Id); app = await CreateMastodonApplicationAsync(app.Domain, existingAppId: app.Id);
} }
state ??= HttpUtility.UrlEncode(await _keyCacheService.GenerateAuthStateAsync()); state ??= HttpUtility.UrlEncode(await _keyCacheService.GenerateAuthStateAsync());

View file

@ -43,7 +43,7 @@ public partial class FediverseAuthService
string? state = null string? state = null
) )
{ {
FediverseApplication app = await GetApplicationAsync(instance); var app = await GetApplicationAsync(instance);
return await GenerateAuthUrlAsync(app, forceRefresh, state); return await GenerateAuthUrlAsync(app, forceRefresh, state);
} }
@ -56,15 +56,13 @@ public partial class FediverseAuthService
public async Task<FediverseApplication> GetApplicationAsync(string instance) public async Task<FediverseApplication> GetApplicationAsync(string instance)
{ {
FediverseApplication? app = await _db.FediverseApplications.FirstOrDefaultAsync(a => var app = await _db.FediverseApplications.FirstOrDefaultAsync(a => a.Domain == instance);
a.Domain == instance
);
if (app != null) if (app != null)
return app; return app;
_logger.Debug("No application for fediverse instance {Instance}, creating it", instance); _logger.Debug("No application for fediverse instance {Instance}, creating it", instance);
string softwareName = await GetSoftwareNameAsync(instance); var softwareName = await GetSoftwareNameAsync(instance);
if (IsMastodonCompatible(softwareName)) if (IsMastodonCompatible(softwareName))
{ {
@ -78,14 +76,13 @@ public partial class FediverseAuthService
{ {
_logger.Debug("Requesting software name for fediverse instance {Instance}", instance); _logger.Debug("Requesting software name for fediverse instance {Instance}", instance);
HttpResponseMessage wellKnownResp = await _client.GetAsync( var wellKnownResp = await _client.GetAsync(
new Uri($"https://{instance}/.well-known/nodeinfo") new Uri($"https://{instance}/.well-known/nodeinfo")
); );
wellKnownResp.EnsureSuccessStatusCode(); wellKnownResp.EnsureSuccessStatusCode();
WellKnownResponse? wellKnown = var wellKnown = await wellKnownResp.Content.ReadFromJsonAsync<WellKnownResponse>();
await wellKnownResp.Content.ReadFromJsonAsync<WellKnownResponse>(); var nodeInfoUrl = wellKnown?.Links.FirstOrDefault(l => l.Rel == NodeInfoRel)?.Href;
string? nodeInfoUrl = wellKnown?.Links.FirstOrDefault(l => l.Rel == NodeInfoRel)?.Href;
if (nodeInfoUrl == null) if (nodeInfoUrl == null)
{ {
throw new FoxnounsError( throw new FoxnounsError(
@ -93,10 +90,10 @@ public partial class FediverseAuthService
); );
} }
HttpResponseMessage nodeInfoResp = await _client.GetAsync(nodeInfoUrl); var nodeInfoResp = await _client.GetAsync(nodeInfoUrl);
nodeInfoResp.EnsureSuccessStatusCode(); nodeInfoResp.EnsureSuccessStatusCode();
PartialNodeInfo? nodeInfo = await nodeInfoResp.Content.ReadFromJsonAsync<PartialNodeInfo>(); var nodeInfo = await nodeInfoResp.Content.ReadFromJsonAsync<PartialNodeInfo>();
return nodeInfo?.Software.Name return nodeInfo?.Software.Name
?? throw new FoxnounsError( ?? throw new FoxnounsError(
$"Nodeinfo response for instance {instance} was invalid, no software name" $"Nodeinfo response for instance {instance} was invalid, no software name"

View file

@ -29,7 +29,7 @@ public class RemoteAuthService(
) )
{ {
var redirectUri = $"{config.BaseUrl}/auth/callback/discord"; var redirectUri = $"{config.BaseUrl}/auth/callback/discord";
HttpResponseMessage resp = await _httpClient.PostAsync( var resp = await _httpClient.PostAsync(
_discordTokenUri, _discordTokenUri,
new FormUrlEncodedContent( new FormUrlEncodedContent(
new Dictionary<string, string> new Dictionary<string, string>
@ -45,7 +45,7 @@ public class RemoteAuthService(
); );
if (!resp.IsSuccessStatusCode) if (!resp.IsSuccessStatusCode)
{ {
string respBody = await resp.Content.ReadAsStringAsync(ct); var respBody = await resp.Content.ReadAsStringAsync(ct);
_logger.Error( _logger.Error(
"Received error status {StatusCode} when exchanging OAuth token: {ErrorBody}", "Received error status {StatusCode} when exchanging OAuth token: {ErrorBody}",
(int)resp.StatusCode, (int)resp.StatusCode,
@ -55,18 +55,16 @@ public class RemoteAuthService(
} }
resp.EnsureSuccessStatusCode(); resp.EnsureSuccessStatusCode();
DiscordTokenResponse? token = await resp.Content.ReadFromJsonAsync<DiscordTokenResponse>( var token = await resp.Content.ReadFromJsonAsync<DiscordTokenResponse>(ct);
ct
);
if (token == null) if (token == null)
throw new FoxnounsError("Discord token response was null"); throw new FoxnounsError("Discord token response was null");
var req = new HttpRequestMessage(HttpMethod.Get, _discordUserUri); var req = new HttpRequestMessage(HttpMethod.Get, _discordUserUri);
req.Headers.Add("Authorization", $"{token.token_type} {token.access_token}"); req.Headers.Add("Authorization", $"{token.token_type} {token.access_token}");
HttpResponseMessage resp2 = await _httpClient.SendAsync(req, ct); var resp2 = await _httpClient.SendAsync(req, ct);
resp2.EnsureSuccessStatusCode(); resp2.EnsureSuccessStatusCode();
DiscordUserResponse? user = await resp2.Content.ReadFromJsonAsync<DiscordUserResponse>(ct); var user = await resp2.Content.ReadFromJsonAsync<DiscordUserResponse>(ct);
if (user == null) if (user == null)
throw new FoxnounsError("Discord user response was null"); throw new FoxnounsError("Discord user response was null");
@ -106,7 +104,7 @@ public class RemoteAuthService(
string? instance = null string? instance = null
) )
{ {
int existingAccounts = await db var existingAccounts = await db
.AuthMethods.Where(m => m.UserId == userId && m.AuthType == authType) .AuthMethods.Where(m => m.UserId == userId && m.AuthType == authType)
.CountAsync(); .CountAsync();
if (existingAccounts > AuthUtils.MaxAuthMethodsPerType) if (existingAccounts > AuthUtils.MaxAuthMethodsPerType)
@ -133,17 +131,13 @@ public class RemoteAuthService(
string? instance = null string? instance = null
) )
{ {
AddExtraAccountState? accountState = await keyCacheService.GetAddExtraAccountStateAsync( var accountState = await keyCacheService.GetAddExtraAccountStateAsync(state);
state
);
if ( if (
accountState == null accountState == null
|| accountState.AuthType != authType || accountState.AuthType != authType
|| accountState.UserId != userId || accountState.UserId != userId
|| (instance != null && accountState.Instance != instance) || (instance != null && accountState.Instance != instance)
) )
{
throw new ApiError.BadRequest("Invalid state", "state", state); throw new ApiError.BadRequest("Invalid state", "state", state);
} }
}
} }

View file

@ -28,9 +28,9 @@ public class DataCleanupService(
private async Task CleanUsersAsync(CancellationToken ct = default) private async Task CleanUsersAsync(CancellationToken ct = default)
{ {
Instant selfDeleteExpires = clock.GetCurrentInstant() - User.DeleteAfter; var selfDeleteExpires = clock.GetCurrentInstant() - User.DeleteAfter;
Instant suspendExpires = clock.GetCurrentInstant() - User.DeleteSuspendedAfter; var suspendExpires = clock.GetCurrentInstant() - User.DeleteSuspendedAfter;
List<User> users = await db var users = await db
.Users.Include(u => u.Members) .Users.Include(u => u.Members)
.Include(u => u.DataExports) .Include(u => u.DataExports)
.Where(u => .Where(u =>
@ -92,15 +92,13 @@ public class DataCleanupService(
private async Task CleanExportsAsync(CancellationToken ct = default) private async Task CleanExportsAsync(CancellationToken ct = default)
{ {
var minExpiredId = Snowflake.FromInstant(clock.GetCurrentInstant() - DataExport.Expiration); var minExpiredId = Snowflake.FromInstant(clock.GetCurrentInstant() - DataExport.Expiration);
List<DataExport> exports = await db var exports = await db.DataExports.Where(d => d.Id < minExpiredId).ToListAsync(ct);
.DataExports.Where(d => d.Id < minExpiredId)
.ToListAsync(ct);
if (exports.Count == 0) if (exports.Count == 0)
return; return;
_logger.Debug("Deleting {Count} expired exports", exports.Count); _logger.Debug("Deleting {Count} expired exports", exports.Count);
foreach (DataExport? export in exports) foreach (var export in exports)
{ {
_logger.Debug("Deleting export {ExportId}", export.Id); _logger.Debug("Deleting export {ExportId}", export.Id);
await objectStorageService.RemoveObjectAsync( await objectStorageService.RemoveObjectAsync(

View file

@ -41,7 +41,7 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
CancellationToken ct = default CancellationToken ct = default
) )
{ {
TemporaryKey? value = await db.TemporaryKeys.FirstOrDefaultAsync(k => k.Key == key, ct); var value = await db.TemporaryKeys.FirstOrDefaultAsync(k => k.Key == key, ct);
if (value == null) if (value == null)
return null; return null;
@ -56,7 +56,7 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
public async Task DeleteExpiredKeysAsync(CancellationToken ct) public async Task DeleteExpiredKeysAsync(CancellationToken ct)
{ {
int count = await db var count = await db
.TemporaryKeys.Where(k => k.Expires < clock.GetCurrentInstant()) .TemporaryKeys.Where(k => k.Expires < clock.GetCurrentInstant())
.ExecuteDeleteAsync(ct); .ExecuteDeleteAsync(ct);
if (count != 0) if (count != 0)
@ -79,7 +79,7 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
) )
where T : class where T : class
{ {
string value = JsonConvert.SerializeObject(obj); var value = JsonConvert.SerializeObject(obj);
await SetKeyAsync(key, value, expires, ct); await SetKeyAsync(key, value, expires, ct);
} }
@ -90,7 +90,7 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
) )
where T : class where T : class
{ {
string? value = await GetKeyAsync(key, delete, ct); var value = await GetKeyAsync(key, delete, ct);
return value == null ? default : JsonConvert.DeserializeObject<T>(value); return value == null ? default : JsonConvert.DeserializeObject<T>(value);
} }
} }

View file

@ -10,11 +10,11 @@ public class MemberRendererService(DatabaseContext db, Config config)
{ {
public async Task<IEnumerable<PartialMember>> RenderUserMembersAsync(User user, Token? token) public async Task<IEnumerable<PartialMember>> RenderUserMembersAsync(User user, Token? token)
{ {
bool canReadHiddenMembers = var canReadHiddenMembers =
token != null && token.UserId == user.Id && token.HasScope("member.read"); token != null && token.UserId == user.Id && token.HasScope("member.read");
bool renderUnlisted = var renderUnlisted =
token != null && token.UserId == user.Id && token.HasScope("user.read_hidden"); token != null && token.UserId == user.Id && token.HasScope("user.read_hidden");
bool canReadMemberList = !user.ListHidden || canReadHiddenMembers; var canReadMemberList = !user.ListHidden || canReadHiddenMembers;
IEnumerable<Member> members = canReadMemberList IEnumerable<Member> members = canReadMemberList
? await db.Members.Where(m => m.UserId == user.Id).OrderBy(m => m.Name).ToListAsync() ? await db.Members.Where(m => m.UserId == user.Id).OrderBy(m => m.Name).ToListAsync()
@ -30,7 +30,7 @@ public class MemberRendererService(DatabaseContext db, Config config)
string? overrideSid = null string? overrideSid = null
) )
{ {
bool renderUnlisted = token?.UserId == member.UserId && token.HasScope("user.read_hidden"); var renderUnlisted = token?.UserId == member.UserId && token.HasScope("user.read_hidden");
return new MemberResponse( return new MemberResponse(
member.Id, member.Id,

View file

@ -3,7 +3,6 @@ using Foxnouns.Backend.Database;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using NodaTime; using NodaTime;
using Prometheus; using Prometheus;
using ITimer = Prometheus.ITimer;
namespace Foxnouns.Backend.Services; namespace Foxnouns.Backend.Services;
@ -17,23 +16,19 @@ public class MetricsCollectionService(ILogger logger, IServiceProvider services,
public async Task CollectMetricsAsync(CancellationToken ct = default) public async Task CollectMetricsAsync(CancellationToken ct = default)
{ {
ITimer timer = FoxnounsMetrics.MetricsCollectionTime.NewTimer(); var timer = FoxnounsMetrics.MetricsCollectionTime.NewTimer();
Instant now = clock.GetCurrentInstant(); var now = clock.GetCurrentInstant();
await using AsyncServiceScope scope = services.CreateAsyncScope(); await using var scope = services.CreateAsyncScope();
// ReSharper disable once SuggestVarOrType_SimpleTypes
await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
List<Instant>? users = await db var users = await db.Users.Where(u => !u.Deleted).Select(u => u.LastActive).ToListAsync(ct);
.Users.Where(u => !u.Deleted)
.Select(u => u.LastActive)
.ToListAsync(ct);
FoxnounsMetrics.UsersCount.Set(users.Count); FoxnounsMetrics.UsersCount.Set(users.Count);
FoxnounsMetrics.UsersActiveMonthCount.Set(users.Count(i => i > now - Month)); FoxnounsMetrics.UsersActiveMonthCount.Set(users.Count(i => i > now - Month));
FoxnounsMetrics.UsersActiveWeekCount.Set(users.Count(i => i > now - Week)); FoxnounsMetrics.UsersActiveWeekCount.Set(users.Count(i => i > now - Week));
FoxnounsMetrics.UsersActiveDayCount.Set(users.Count(i => i > now - Day)); FoxnounsMetrics.UsersActiveDayCount.Set(users.Count(i => i > now - Day));
int memberCount = await db var memberCount = await db
.Members.Include(m => m.User) .Members.Include(m => m.User)
.Where(m => !m.Unlisted && !m.User.ListHidden && !m.User.Deleted) .Where(m => !m.Unlisted && !m.User.ListHidden && !m.User.Deleted)
.CountAsync(ct); .CountAsync(ct);

View file

@ -1,5 +1,4 @@
using Minio; using Minio;
using Minio.DataModel;
using Minio.DataModel.Args; using Minio.DataModel.Args;
using Minio.Exceptions; using Minio.Exceptions;
@ -49,4 +48,13 @@ public class ObjectStorageService(ILogger logger, Config config, IMinioClient mi
ct ct
); );
} }
public async Task GetObjectAsync(string path, CancellationToken ct = default)
{
var stream = new MemoryStream();
var resp = await minioClient.GetObjectAsync(
new GetObjectArgs().WithBucket(config.Storage.Bucket).WithObject(path),
ct
);
}
} }

View file

@ -15,13 +15,10 @@ public class PeriodicTasksService(ILogger logger, IServiceProvider services) : B
{ {
_logger.Debug("Running periodic tasks"); _logger.Debug("Running periodic tasks");
await using AsyncServiceScope scope = services.CreateAsyncScope(); await using var scope = services.CreateAsyncScope();
// The type is literally written on the same line, we can just use `var`
// ReSharper disable SuggestVarOrType_SimpleTypes
var keyCacheService = scope.ServiceProvider.GetRequiredService<KeyCacheService>(); var keyCacheService = scope.ServiceProvider.GetRequiredService<KeyCacheService>();
var dataCleanupService = scope.ServiceProvider.GetRequiredService<DataCleanupService>(); var dataCleanupService = scope.ServiceProvider.GetRequiredService<DataCleanupService>();
// ReSharper restore SuggestVarOrType_SimpleTypes
await keyCacheService.DeleteExpiredKeysAsync(ct); await keyCacheService.DeleteExpiredKeysAsync(ct);
await dataCleanupService.InvokeAsync(ct); await dataCleanupService.InvokeAsync(ct);

View file

@ -43,9 +43,9 @@ public class UserRendererService(
) )
{ {
scopes = scopes.ExpandScopes(); scopes = scopes.ExpandScopes();
bool tokenCanReadHiddenMembers = scopes.Contains("member.read") && isSelfUser; var tokenCanReadHiddenMembers = scopes.Contains("member.read") && isSelfUser;
bool tokenHidden = scopes.Contains("user.read_hidden") && isSelfUser; var tokenHidden = scopes.Contains("user.read_hidden") && isSelfUser;
bool tokenPrivileged = scopes.Contains("user.read_privileged") && isSelfUser; var tokenPrivileged = scopes.Contains("user.read_privileged") && isSelfUser;
renderMembers = renderMembers && (!user.ListHidden || tokenCanReadHiddenMembers); renderMembers = renderMembers && (!user.ListHidden || tokenCanReadHiddenMembers);
renderAuthMethods = renderAuthMethods && tokenPrivileged; renderAuthMethods = renderAuthMethods && tokenPrivileged;
@ -57,12 +57,12 @@ public class UserRendererService(
if (!(isSelfUser && tokenCanReadHiddenMembers)) if (!(isSelfUser && tokenCanReadHiddenMembers))
members = members.Where(m => !m.Unlisted); members = members.Where(m => !m.Unlisted);
List<UserFlag> flags = await db var flags = await db
.UserFlags.Where(f => f.UserId == user.Id) .UserFlags.Where(f => f.UserId == user.Id)
.OrderBy(f => f.Id) .OrderBy(f => f.Id)
.ToListAsync(ct); .ToListAsync(ct);
List<AuthMethod> authMethods = renderAuthMethods var authMethods = renderAuthMethods
? await db ? await db
.AuthMethods.Where(a => a.UserId == user.Id) .AuthMethods.Where(a => a.UserId == user.Id)
.Include(a => a.FediverseApplication) .Include(a => a.FediverseApplication)
@ -72,11 +72,9 @@ public class UserRendererService(
int? utcOffset = null; int? utcOffset = null;
if ( if (
user.Timezone != null user.Timezone != null
&& TimeZoneInfo.TryFindSystemTimeZoneById(user.Timezone, out TimeZoneInfo? tz) && TimeZoneInfo.TryFindSystemTimeZoneById(user.Timezone, out var tz)
) )
{
utcOffset = (int)tz.GetUtcOffset(DateTimeOffset.UtcNow).TotalSeconds; utcOffset = (int)tz.GetUtcOffset(DateTimeOffset.UtcNow).TotalSeconds;
}
return new UserResponse( return new UserResponse(
user.Id, user.Id,

View file

@ -69,8 +69,8 @@ public static class AuthUtils
public static bool ValidateScopes(Application application, string[] scopes) public static bool ValidateScopes(Application application, string[] scopes)
{ {
string[] expandedScopes = scopes.ExpandScopes(); var expandedScopes = scopes.ExpandScopes();
string[] appScopes = application.Scopes.ExpandAppScopes(); var appScopes = application.Scopes.ExpandAppScopes();
return !expandedScopes.Except(appScopes).Any(); return !expandedScopes.Except(appScopes).Any();
} }
@ -78,7 +78,7 @@ public static class AuthUtils
{ {
try try
{ {
string scheme = new Uri(uri).Scheme; var scheme = new Uri(uri).Scheme;
return !ForbiddenSchemes.Contains(scheme); return !ForbiddenSchemes.Contains(scheme);
} }
catch catch

View file

@ -5,11 +5,10 @@ using Newtonsoft.Json.Serialization;
namespace Foxnouns.Backend.Utils; namespace Foxnouns.Backend.Utils;
/// <summary> /// <summary>
/// <para>A base class used for PATCH requests which stores information on whether a key is explicitly set to null or not passed at all.</para> /// A base class used for PATCH requests which stores information on whether a key is explicitly set to null or not passed at all.
/// <para> ///
/// HasProperty() should not be used for properties that cannot be set to null--a null value should be treated /// HasProperty() should not be used for properties that cannot be set to null--a null value should be treated
/// as an unset value in those cases. /// as an unset value in those cases.
/// </para>
/// </summary> /// </summary>
public abstract class PatchRequest public abstract class PatchRequest
{ {
@ -31,7 +30,7 @@ public class PatchRequestContractResolver : DefaultContractResolver
MemberSerialization memberSerialization MemberSerialization memberSerialization
) )
{ {
JsonProperty prop = base.CreateProperty(member, memberSerialization); var prop = base.CreateProperty(member, memberSerialization);
prop.SetIsSpecified += (o, _) => prop.SetIsSpecified += (o, _) =>
{ {

View file

@ -39,7 +39,6 @@ public static partial class ValidationUtils
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (fields.Count > 25) if (fields.Count > 25)
{
errors.Add( errors.Add(
( (
"fields", "fields",
@ -51,13 +50,11 @@ public static partial class ValidationUtils
) )
) )
); );
}
// No overwhelming this function, thank you // No overwhelming this function, thank you
if (fields.Count > 100) if (fields.Count > 100)
return errors; return errors;
foreach ((Field? field, int index) in fields.Select((field, index) => (field, index))) foreach (var (field, index) in fields.Select((field, index) => (field, index)))
{ {
switch (field.Name.Length) switch (field.Name.Length)
{ {
@ -114,7 +111,6 @@ public static partial class ValidationUtils
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (entries.Length > Limits.FieldEntriesLimit) if (entries.Length > Limits.FieldEntriesLimit)
{
errors.Add( errors.Add(
( (
errorPrefix, errorPrefix,
@ -126,19 +122,15 @@ public static partial class ValidationUtils
) )
) )
); );
}
// Same as above, no overwhelming this function with a ridiculous amount of entries // Same as above, no overwhelming this function with a ridiculous amount of entries
if (entries.Length > Limits.FieldEntriesLimit + 50) if (entries.Length > Limits.FieldEntriesLimit + 50)
return errors; return errors;
string[] customPreferenceIds = customPreferences.Keys.Select(id => id.ToString()).ToArray(); var customPreferenceIds =
customPreferences?.Keys.Select(id => id.ToString()).ToArray() ?? [];
foreach ( foreach (var (entry, entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx)))
(FieldEntry? entry, int entryIdx) in entries.Select(
(entry, entryIdx) => (entry, entryIdx)
)
)
{ {
switch (entry.Value.Length) switch (entry.Value.Length)
{ {
@ -174,7 +166,6 @@ public static partial class ValidationUtils
!DefaultStatusOptions.Contains(entry.Status) !DefaultStatusOptions.Contains(entry.Status)
&& !customPreferenceIds.Contains(entry.Status) && !customPreferenceIds.Contains(entry.Status)
) )
{
errors.Add( errors.Add(
( (
$"{errorPrefix}.{entryIdx}.status", $"{errorPrefix}.{entryIdx}.status",
@ -182,7 +173,6 @@ public static partial class ValidationUtils
) )
); );
} }
}
return errors; return errors;
} }
@ -198,7 +188,6 @@ public static partial class ValidationUtils
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (entries.Length > Limits.FieldEntriesLimit) if (entries.Length > Limits.FieldEntriesLimit)
{
errors.Add( errors.Add(
( (
errorPrefix, errorPrefix,
@ -210,17 +199,15 @@ public static partial class ValidationUtils
) )
) )
); );
}
// Same as above, no overwhelming this function with a ridiculous amount of entries // Same as above, no overwhelming this function with a ridiculous amount of entries
if (entries.Length > Limits.FieldEntriesLimit + 50) if (entries.Length > Limits.FieldEntriesLimit + 50)
return errors; return errors;
string[] customPreferenceIds = customPreferences.Keys.Select(id => id.ToString()).ToArray(); var customPreferenceIds =
customPreferences?.Keys.Select(id => id.ToString()).ToList() ?? [];
foreach ( foreach (var (entry, entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx)))
(Pronoun? entry, int entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx))
)
{ {
switch (entry.Value.Length) switch (entry.Value.Length)
{ {
@ -289,7 +276,6 @@ public static partial class ValidationUtils
!DefaultStatusOptions.Contains(entry.Status) !DefaultStatusOptions.Contains(entry.Status)
&& !customPreferenceIds.Contains(entry.Status) && !customPreferenceIds.Contains(entry.Status)
) )
{
errors.Add( errors.Add(
( (
$"{errorPrefix}.{entryIdx}.status", $"{errorPrefix}.{entryIdx}.status",
@ -297,7 +283,6 @@ public static partial class ValidationUtils
) )
); );
} }
}
return errors; return errors;
} }

View file

@ -29,7 +29,6 @@ public static partial class ValidationUtils
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
if (preferences.Count > MaxCustomPreferences) if (preferences.Count > MaxCustomPreferences)
{
errors.Add( errors.Add(
( (
"custom_preferences", "custom_preferences",
@ -41,29 +40,20 @@ public static partial class ValidationUtils
) )
) )
); );
}
if (preferences.Count > 50) if (preferences.Count > 50)
return errors; return errors;
foreach ( foreach (var (p, i) in preferences.Select((p, i) => (p, i)))
(UsersController.CustomPreferenceUpdate? p, int i) in preferences.Select(
(p, i) => (p, i)
)
)
{ {
if (!BootstrapIcons.IsValid(p.Icon)) if (!BootstrapIcons.IsValid(p.Icon))
{
errors.Add( errors.Add(
( (
$"custom_preferences.{i}.icon", $"custom_preferences.{i}.icon",
ValidationError.DisallowedValueError("Invalid icon name", [], p.Icon) ValidationError.DisallowedValueError("Invalid icon name", [], p.Icon)
) )
); );
}
if (p.Tooltip.Length is 1 or > MaxPreferenceTooltipLength) if (p.Tooltip.Length is 1 or > MaxPreferenceTooltipLength)
{
errors.Add( errors.Add(
( (
$"custom_preferences.{i}.tooltip", $"custom_preferences.{i}.tooltip",
@ -76,7 +66,6 @@ public static partial class ValidationUtils
) )
); );
} }
}
return errors; return errors;
} }

View file

@ -46,7 +46,6 @@ public static partial class ValidationUtils
public static ValidationError? ValidateUsername(string username) public static ValidationError? ValidateUsername(string username)
{ {
if (!UsernameRegex().IsMatch(username)) if (!UsernameRegex().IsMatch(username))
{
return username.Length switch return username.Length switch
{ {
< 2 => ValidationError.LengthError("Username is too short", 2, 40, username.Length), < 2 => ValidationError.LengthError("Username is too short", 2, 40, username.Length),
@ -56,24 +55,19 @@ public static partial class ValidationUtils
username username
), ),
}; };
}
if ( if (
InvalidUsernames.Any(u => InvalidUsernames.Any(u =>
string.Equals(u, username, StringComparison.InvariantCultureIgnoreCase) string.Equals(u, username, StringComparison.InvariantCultureIgnoreCase)
) )
) )
{
return ValidationError.GenericValidationError("Username is not allowed", username); return ValidationError.GenericValidationError("Username is not allowed", username);
}
return null; return null;
} }
public static ValidationError? ValidateMemberName(string memberName) public static ValidationError? ValidateMemberName(string memberName)
{ {
if (!MemberRegex().IsMatch(memberName)) if (!MemberRegex().IsMatch(memberName))
{
return memberName.Length switch return memberName.Length switch
{ {
< 1 => ValidationError.LengthError("Name is too short", 1, 100, memberName.Length), < 1 => ValidationError.LengthError("Name is too short", 1, 100, memberName.Length),
@ -85,17 +79,13 @@ public static partial class ValidationUtils
memberName memberName
), ),
}; };
}
if ( if (
InvalidMemberNames.Any(u => InvalidMemberNames.Any(u =>
string.Equals(u, memberName, StringComparison.InvariantCultureIgnoreCase) string.Equals(u, memberName, StringComparison.InvariantCultureIgnoreCase)
) )
) )
{
return ValidationError.GenericValidationError("Name is not allowed", memberName); return ValidationError.GenericValidationError("Name is not allowed", memberName);
}
return null; return null;
} }
@ -127,15 +117,13 @@ public static partial class ValidationUtils
if (links == null) if (links == null)
return []; return [];
if (links.Length > MaxLinks) if (links.Length > MaxLinks)
{
return return
[ [
("links", ValidationError.LengthError("Too many links", 0, MaxLinks, links.Length)), ("links", ValidationError.LengthError("Too many links", 0, MaxLinks, links.Length)),
]; ];
}
var errors = new List<(string, ValidationError?)>(); var errors = new List<(string, ValidationError?)>();
foreach ((string link, int idx) in links.Select((l, i) => (l, i))) foreach (var (link, idx) in links.Select((l, i) => (l, i)))
{ {
switch (link.Length) switch (link.Length)
{ {
@ -197,27 +185,6 @@ public static partial class ValidationUtils
}; };
} }
public const int MinimumPasswordLength = 12;
public const int MaximumPasswordLength = 1024;
public static ValidationError? ValidatePassword(string password) =>
password.Length switch
{
< MinimumPasswordLength => ValidationError.LengthError(
"Password is too short",
MinimumPasswordLength,
MaximumPasswordLength,
password.Length
),
> MaximumPasswordLength => ValidationError.LengthError(
"Password is too long",
MinimumPasswordLength,
MaximumPasswordLength,
password.Length
),
_ => null,
};
[GeneratedRegex(@"^[a-zA-Z_0-9\-\.]{2,40}$", RegexOptions.IgnoreCase, "en-NL")] [GeneratedRegex(@"^[a-zA-Z_0-9\-\.]{2,40}$", RegexOptions.IgnoreCase, "en-NL")]
private static partial Regex UsernameRegex(); private static partial Regex UsernameRegex();

View file

@ -12,9 +12,9 @@ public static partial class ValidationUtils
return; return;
var errorDict = new Dictionary<string, IEnumerable<ValidationError>>(); var errorDict = new Dictionary<string, IEnumerable<ValidationError>>();
foreach ((string, ValidationError?) error in errors) foreach (var error in errors)
{ {
if (errorDict.TryGetValue(error.Item1, out IEnumerable<ValidationError>? value)) if (errorDict.TryGetValue(error.Item1, out var value))
errorDict[error.Item1] = value.Append(error.Item2!); errorDict[error.Item1] = value.Append(error.Item2!);
errorDict.Add(error.Item1, [error.Item2!]); errorDict.Add(error.Item1, [error.Item2!]);
} }

View file

@ -2,9 +2,9 @@
<p> <p>
Please continue creating a new pronouns.cc account by using the following link: Please continue creating a new pronouns.cc account by using the following link:
<br /> <br/>
<a href="@Model.BaseUrl/auth/callback/email/@Model.Code">Confirm your email address</a> <a href="@Model.BaseUrl/auth/signup/confirm/@Model.Code">Confirm your email address</a>
<br /> <br/>
Note that this link will expire in one hour. Note that this link will expire in one hour.
</p> </p>
<p> <p>

View file

@ -2,9 +2,9 @@
<p> <p>
Hello @@@Model.Username, please confirm adding this email address to your account by using the following link: Hello @@@Model.Username, please confirm adding this email address to your account by using the following link:
<br /> <br/>
<a href="@Model.BaseUrl/auth/callback/email/@Model.Code">Confirm your email address</a> <a href="@Model.BaseUrl/settings/auth/confirm-email/@Model.Code">Confirm your email address</a>
<br /> <br/>
Note that this link will expire in one hour. Note that this link will expire in one hour.
</p> </p>
<p> <p>

View file

@ -2,9 +2,9 @@
<html lang="en"> <html lang="en">
<head> <head>
<title></title> <title></title>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" /> <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="viewport" content="width=device-width, initial-scale=1">
<meta http-equiv="X-UA-Compatible" content="IE=edge" /> <meta http-equiv="X-UA-Compatible" content="IE=edge"/>
<style> <style>
body { body {
font-family: sans-serif; font-family: sans-serif;

View file

@ -194,12 +194,6 @@
"prometheus-net": "8.2.1" "prometheus-net": "8.2.1"
} }
}, },
"Roslynator.Analyzers": {
"type": "Direct",
"requested": "[4.12.9, )",
"resolved": "4.12.9",
"contentHash": "X6lDpN/D5wuinq37KIx+l3GSUe9No+8bCjGBTI5sEEtxapLztkHg6gzNVhMXpXw8P+/5gFYxTXJ5Pf8O4iNz/w=="
},
"Sentry.AspNetCore": { "Sentry.AspNetCore": {
"type": "Direct", "type": "Direct",
"requested": "[4.9.0, )", "requested": "[4.9.0, )",

View file

@ -1,69 +0,0 @@
import { apiRequest } from "$api";
import ApiError, { ErrorCode } from "$api/error";
import type { AddAccountResponse, CallbackResponse } from "$api/models";
import { setToken } from "$lib";
import log from "$lib/log";
import { isRedirect, redirect, type ServerLoadEvent } from "@sveltejs/kit";
export default function createCallbackLoader(
callbackType: string,
bodyFn?: (event: ServerLoadEvent) => Promise<unknown>,
) {
return async (event: ServerLoadEvent) => {
const { url, parent, fetch, cookies } = event;
bodyFn ??= async ({ url }) => {
const code = url.searchParams.get("code") as string | null;
const state = url.searchParams.get("state") as string | null;
if (!code || !state) throw new ApiError(undefined, ErrorCode.BadRequest).obj;
return { code, state };
};
const { meUser } = await parent();
if (meUser) {
try {
const resp = await apiRequest<AddAccountResponse>(
"POST",
`/auth/${callbackType}/add-account/callback`,
{
isInternal: true,
body: await bodyFn(event),
fetch,
cookies,
},
);
return { hasAccount: true, isLinkRequest: true, newAuthMethod: resp };
} catch (e) {
if (e instanceof ApiError) return { isLinkRequest: true, error: e.obj };
log.error("error linking new %s account to user %s:", callbackType, meUser.id, e);
throw e;
}
}
try {
const resp = await apiRequest<CallbackResponse>("POST", `/auth/${callbackType}/callback`, {
body: await bodyFn(event),
isInternal: true,
fetch,
});
if (resp.has_account) {
setToken(cookies, resp.token!);
redirect(303, `/@${resp.user!.username}`);
}
return {
hasAccount: false,
isLinkRequest: false,
ticket: resp.ticket!,
remoteUser: resp.remote_username!,
};
} catch (e) {
if (isRedirect(e)) throw e;
if (e instanceof ApiError) return { isLinkRequest: false, error: e.obj };
log.error("error while requesting %s callback:", callbackType, e);
throw e;
}
};
}

View file

@ -4,8 +4,8 @@
import type { RawApiError } from "$api/error"; import type { RawApiError } from "$api/error";
import ErrorAlert from "$components/ErrorAlert.svelte"; import ErrorAlert from "$components/ErrorAlert.svelte";
type Props = { form: { error: RawApiError | null; ok: boolean } | null; successMessage?: string }; type Props = { form: { error: RawApiError | null; ok: boolean } | null };
let { form, successMessage }: Props = $props(); let { form }: Props = $props();
</script> </script>
{#if form?.error} {#if form?.error}
@ -13,6 +13,6 @@
{:else if form?.ok} {:else if form?.ok}
<p class="text-success-emphasis"> <p class="text-success-emphasis">
<Icon name="check-circle-fill" /> <Icon name="check-circle-fill" />
{successMessage ?? $t("edit-profile.saved-changes")} {$t("edit-profile.saved-changes")}
</p> </p>
{/if} {/if}

View file

@ -1,5 +1,6 @@
<script lang="ts"> <script lang="ts">
import type { RawApiError } from "$api/error"; import type { RawApiError } from "$api/error";
import { enhance } from "$app/forms";
import ErrorAlert from "$components/ErrorAlert.svelte"; import ErrorAlert from "$components/ErrorAlert.svelte";
import { t } from "$lib/i18n"; import { t } from "$lib/i18n";
import { Button, Input, Label } from "@sveltestrap/sveltestrap"; import { Button, Input, Label } from "@sveltestrap/sveltestrap";
@ -20,7 +21,7 @@
<ErrorAlert {error} /> <ErrorAlert {error} />
{/if} {/if}
<form method="POST"> <form method="POST" use:enhance>
<div class="mb-3"> <div class="mb-3">
<Label>{remoteLabel}</Label> <Label>{remoteLabel}</Label>
<Input type="text" readonly value={remoteUser} /> <Input type="text" readonly value={remoteUser} />

View file

@ -48,11 +48,7 @@
"successful-link-profile-hint": "You now can close this page, or go back to your profile:", "successful-link-profile-hint": "You now can close this page, or go back to your profile:",
"successful-link-profile-link": "Go to your profile", "successful-link-profile-link": "Go to your profile",
"remote-discord-account-label": "Your Discord account", "remote-discord-account-label": "Your Discord account",
"log-in-with-fediverse-instance-placeholder": "Your instance (i.e. mastodon.social)", "log-in-with-fediverse-instance-placeholder": "Your instance (i.e. mastodon.social)"
"register-with-email": "Register with an email address",
"email-label": "Your email address",
"confirm-password-label": "Confirm password",
"register-with-email-init-success": "Success! An email has been sent to your inbox, please press the link there to continue."
}, },
"error": { "error": {
"bad-request-header": "Something was wrong with your input", "bad-request-header": "Something was wrong with your input",

View file

@ -1,7 +1,63 @@
import createCallbackLoader from "$lib/actions/callback"; import { apiRequest } from "$api";
import createRegisterAction from "$lib/actions/register"; import ApiError, { ErrorCode } from "$api/error";
import type { AddAccountResponse, CallbackResponse } from "$api/models/auth";
import { setToken } from "$lib";
import createRegisterAction from "$lib/actions/register.js";
import log from "$lib/log.js";
import { isRedirect, redirect } from "@sveltejs/kit";
export const load = createCallbackLoader("discord"); export const load = async ({ url, parent, fetch, cookies }) => {
const code = url.searchParams.get("code") as string | null;
const state = url.searchParams.get("state") as string | null;
if (!code || !state) throw new ApiError(undefined, ErrorCode.BadRequest).obj;
const { meUser } = await parent();
if (meUser) {
try {
const resp = await apiRequest<AddAccountResponse>(
"POST",
"/auth/discord/add-account/callback",
{
isInternal: true,
body: { code, state },
fetch,
cookies,
},
);
return { hasAccount: true, isLinkRequest: true, newAuthMethod: resp };
} catch (e) {
if (e instanceof ApiError) return { isLinkRequest: true, error: e.obj };
log.error("error linking new discord account to user %s:", meUser.id, e);
throw e;
}
}
try {
const resp = await apiRequest<CallbackResponse>("POST", "/auth/discord/callback", {
body: { code, state },
isInternal: true,
fetch,
});
if (resp.has_account) {
setToken(cookies, resp.token!);
redirect(303, `/@${resp.user!.username}`);
}
return {
hasAccount: false,
isLinkRequest: false,
ticket: resp.ticket!,
remoteUser: resp.remote_username!,
};
} catch (e) {
if (isRedirect(e)) throw e;
if (e instanceof ApiError) return { isLinkRequest: false, error: e.obj };
log.error("error while requesting discord callback:", e);
throw e;
}
};
export const actions = { export const actions = {
default: createRegisterAction("/auth/discord/register"), default: createRegisterAction("/auth/discord/register"),

View file

@ -1,53 +0,0 @@
import { apiRequest } from "$api";
import ApiError, { ErrorCode, type RawApiError } from "$api/error";
import type { AuthResponse } from "$api/models/auth";
import { setToken } from "$lib";
import createCallbackLoader from "$lib/actions/callback";
import log from "$lib/log";
import { redirect, isRedirect } from "@sveltejs/kit";
export const load = createCallbackLoader("email", async ({ params }) => {
log.info("params:", params, "code:", params.code);
return { state: params.code! };
});
export const actions = {
default: async ({ request, fetch, cookies }) => {
const data = await request.formData();
const username = data.get("username") as string | null;
const ticket = data.get("ticket") as string | null;
const password = data.get("password") as string | null;
const password2 = data.get("confirm-password") as string | null;
if (!username || !ticket || !password || !password2)
return {
error: { message: "Bad request", code: ErrorCode.BadRequest, status: 400 } as RawApiError,
};
if (password !== password2)
return {
error: {
message: "Passwords do not match",
code: ErrorCode.BadRequest,
status: 400,
} as RawApiError,
};
try {
const resp = await apiRequest<AuthResponse>("POST", "/auth/email/register", {
body: { username, ticket, password },
isInternal: true,
fetch,
});
setToken(cookies, resp.token);
redirect(303, "/auth/welcome");
} catch (e) {
if (isRedirect(e)) throw e;
log.error("Could not sign up user with username %s:", username, e);
if (e instanceof ApiError) return { error: e.obj };
throw e;
}
},
};

View file

@ -1,51 +0,0 @@
<script lang="ts">
import Error from "$components/Error.svelte";
import ErrorAlert from "$components/ErrorAlert.svelte";
import NewAuthMethod from "$components/settings/NewAuthMethod.svelte";
import { t } from "$lib/i18n";
import { Label, Input, Button } from "@sveltestrap/sveltestrap";
import type { ActionData, PageData } from "./$types";
type Props = { data: PageData; form: ActionData };
let { data, form }: Props = $props();
</script>
<svelte:head>
<title>{$t("auth.register-with-email")} • pronouns.cc</title>
</svelte:head>
<div class="container">
{#if data.error}
<h1>{$t("auth.register-with-email")}</h1>
<Error error={data.error} />
{:else if data.isLinkRequest}
<NewAuthMethod method={data.newAuthMethod!} user={data.meUser!} />
{:else}
<h1>{$t("auth.register-with-email")}</h1>
{#if form?.error}
<ErrorAlert error={form.error} />
{/if}
<form method="POST">
<div class="mb-3">
<Label>{$t("auth.email-label")}</Label>
<Input type="text" readonly value={data.remoteUser} />
</div>
<div class="mb-3">
<Label>{$t("auth.register-username-label")}</Label>
<Input type="text" name="username" required />
</div>
<div class="mb-3">
<Label>{$t("auth.log-in-form-password-label")}</Label>
<Input type="password" name="password" required />
</div>
<div class="mb-3">
<Label>{$t("auth.confirm-password-label")}</Label>
<Input type="password" name="confirm-password" required />
</div>
<input type="hidden" name="ticket" value={data.ticket!} />
<Button color="primary" type="submit">{$t("auth.register-button")}</Button>
</form>
{/if}
</div>

View file

@ -1,14 +1,63 @@
import { apiRequest } from "$api";
import ApiError, { ErrorCode } from "$api/error"; import ApiError, { ErrorCode } from "$api/error";
import createCallbackLoader from "$lib/actions/callback"; import type { AddAccountResponse, CallbackResponse } from "$api/models/auth.js";
import createRegisterAction from "$lib/actions/register"; import { setToken } from "$lib";
import createRegisterAction from "$lib/actions/register.js";
import log from "$lib/log";
import { isRedirect, redirect } from "@sveltejs/kit";
export const load = createCallbackLoader("fediverse", async ({ params, url }) => { export const load = async ({ parent, params, url, fetch, cookies }) => {
const code = url.searchParams.get("code") as string | null; const code = url.searchParams.get("code") as string | null;
const state = url.searchParams.get("state") as string | null; const state = url.searchParams.get("state") as string | null;
if (!code || !state) throw new ApiError(undefined, ErrorCode.BadRequest).obj; if (!code || !state) throw new ApiError(undefined, ErrorCode.BadRequest).obj;
return { code, state, instance: params.instance! }; const { meUser } = await parent();
}); if (meUser) {
try {
const resp = await apiRequest<AddAccountResponse>(
"POST",
"/auth/fediverse/add-account/callback",
{
isInternal: true,
body: { code, state, instance: params.instance },
fetch,
cookies,
},
);
return { hasAccount: true, isLinkRequest: true, newAuthMethod: resp };
} catch (e) {
if (e instanceof ApiError) return { isLinkRequest: true, error: e.obj };
log.error("error linking new fediverse account to user %s:", meUser.id, e);
throw e;
}
}
try {
const resp = await apiRequest<CallbackResponse>("POST", "/auth/fediverse/callback", {
body: { code, state, instance: params.instance },
isInternal: true,
fetch,
});
if (resp.has_account) {
setToken(cookies, resp.token!);
redirect(303, `/@${resp.user!.username}`);
}
return {
hasAccount: false,
isLinkRequest: false,
ticket: resp.ticket!,
remoteUser: resp.remote_username!,
};
} catch (e) {
if (isRedirect(e)) throw e;
if (e instanceof ApiError) return { isLinkRequest: false, error: e.obj };
log.error("error while requesting fediverse callback:", e);
throw e;
}
};
export const actions = { export const actions = {
default: createRegisterAction("/auth/fediverse/register"), default: createRegisterAction("/auth/fediverse/register"),

View file

@ -1,35 +0,0 @@
import { apiRequest, fastRequest } from "$api";
import ApiError from "$api/error.js";
import type { AuthUrls } from "$api/models/auth";
import log from "$lib/log.js";
import { redirect } from "@sveltejs/kit";
export const load = async ({ fetch, parent }) => {
const parentData = await parent();
if (parentData.meUser) redirect(303, `/@${parentData.meUser.username}`);
const urls = await apiRequest<AuthUrls>("POST", "/auth/urls", { fetch, isInternal: true });
if (!urls.email_enabled) redirect(303, "/");
};
export const actions = {
default: async ({ request, fetch, cookies }) => {
const body = await request.formData();
const email = body.get("email") as string;
try {
await fastRequest("POST", `/auth/email/register/init`, {
body: { email },
isInternal: true,
fetch,
cookies,
});
return { ok: true, error: null };
} catch (e) {
if (e instanceof ApiError) return { ok: false, error: e.obj };
log.error("error initiating registration for email %s:", email, e);
throw e;
}
},
};

View file

@ -1,29 +0,0 @@
<script lang="ts">
import type { ActionData, PageData } from "./$types";
import { t } from "$lib/i18n";
import { enhance } from "$app/forms";
import { Button, Input, InputGroup } from "@sveltestrap/sveltestrap";
import FormStatusMarker from "$components/editor/FormStatusMarker.svelte";
type Props = { data: PageData; form: ActionData };
let { data, form }: Props = $props();
</script>
<svelte:head>
<title>{$t("auth.register-with-email")} • pronouns.cc</title>
</svelte:head>
<div class="container">
<div class="mx-auto w-lg-75">
<h2>{$t("auth.register-with-email")}</h2>
<FormStatusMarker {form} successMessage={$t("auth.register-with-email-init-success")} />
<form method="POST" use:enhance>
<InputGroup>
<Input name="email" type="email" placeholder={$t("auth.email-label")} />
<Button type="submit" color="secondary">{$t("auth.register-with-email-button")}</Button>
</InputGroup>
</form>
</div>
</div>

View file

@ -3,10 +3,9 @@
import { Button, Input, InputGroup } from "@sveltestrap/sveltestrap"; import { Button, Input, InputGroup } from "@sveltestrap/sveltestrap";
</script> </script>
<div class="mx-auto w-lg-75"> <h3>Link a new Fediverse account</h3>
<h3>Link a new Fediverse account</h3>
<form method="POST" action="?/add"> <form method="POST" action="?/add">
<InputGroup> <InputGroup>
<Input <Input
name="instance" name="instance"
@ -21,5 +20,4 @@
{$t("auth.log-in-with-fediverse-force-refresh-button")} {$t("auth.log-in-with-fediverse-force-refresh-button")}
</Button> </Button>
</p> </p>
</form> </form>
</div>

View file

@ -5,7 +5,6 @@
import { Icon } from "@sveltestrap/sveltestrap"; import { Icon } from "@sveltestrap/sveltestrap";
import { t } from "$lib/i18n"; import { t } from "$lib/i18n";
import { enhance } from "$app/forms"; import { enhance } from "$app/forms";
import FormStatusMarker from "$components/editor/FormStatusMarker.svelte";
type Props = { data: PageData; form: ActionData }; type Props = { data: PageData; form: ActionData };
let { data, form }: Props = $props(); let { data, form }: Props = $props();
@ -19,7 +18,14 @@
<div class="mx-auto w-lg-75"> <div class="mx-auto w-lg-75">
<h3>{$t("settings.export-title")}</h3> <h3>{$t("settings.export-title")}</h3>
<FormStatusMarker {form} successMessage={$t("settings.export-request-success")} /> {#if form?.ok}
<p class="text-success-emphasis">
<Icon name="check-circle-fill" />
{$t("settings.export-request-success")}
</p>
{:else if form?.error}
<ErrorAlert error={form.error} />
{/if}
<p> <p>
{$t("settings.export-info")} {$t("settings.export-info")}

View file

@ -1,5 +1,4 @@
<wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation"> <wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation">
<s:String x:Key="/Default/CodeStyle/CSharpVarKeywordUsage/ForBuiltInTypes/@EntryValue">UseVarWhenEvident</s:String>
<s:String x:Key="/Default/CodeStyle/FileHeader/FileHeaderText/@EntryValue">Copyright (C) 2023-present sam/u1f320 (vulpine.solutions) <s:String x:Key="/Default/CodeStyle/FileHeader/FileHeaderText/@EntryValue">Copyright (C) 2023-present sam/u1f320 (vulpine.solutions)
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify

View file

@ -3,8 +3,9 @@
## C# code style ## C# code style
- Code should be formatted with `dotnet format` or Rider's built-in formatter. - Code should be formatted with `dotnet format` or Rider's built-in formatter.
- Variables should always be declared with their type name, unless the type is obvious from the declaration. - Variables should *always* be declared using `var`,
(For example, `var stream = new Stream()` or `var db = services.GetRequiredService<DatabaseContext>()`) unless the correct type can't be inferred from the declaration (i.e. if the variable needs to be an `IEnumerable<T>`
instead of a `List<T>`, or if a variable is initialized as `null`).
### Naming ### Naming