chore: add csharpier to husky, format backend with csharpier

This commit is contained in:
sam 2024-10-02 00:28:07 +02:00
parent 5fab66444f
commit 7f971e8549
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
73 changed files with 2098 additions and 1048 deletions

View file

@ -4,7 +4,16 @@
"tools": {
"husky": {
"version": "0.7.1",
"commands": ["husky"],
"commands": [
"husky"
],
"rollForward": false
},
"csharpier": {
"version": "0.29.2",
"commands": [
"dotnet-csharpier"
],
"rollForward": false
}
}

View file

@ -8,9 +8,10 @@
"pathMode": "absolute"
},
{
"name": "dotnet-format",
"name": "run-csharpier",
"command": "dotnet",
"args": ["format"]
"args": [ "csharpier", "${staged}" ],
"include": [ "**/*.cs" ]
}
]
}

View file

@ -8,19 +8,24 @@ public static class BuildInfo
public static async Task ReadBuildInfo()
{
await using var stream = typeof(BuildInfo).Assembly.GetManifestResourceStream("version");
if (stream == null) return;
if (stream == null)
return;
using var reader = new StreamReader(stream);
var data = (await reader.ReadToEndAsync()).Trim().Split("\n");
if (data.Length < 3) return;
if (data.Length < 3)
return;
Hash = data[0];
var dirty = data[2] == "dirty";
var versionData = data[1].Split("-");
if (versionData.Length < 3) return;
if (versionData.Length < 3)
return;
Version = versionData[0];
if (versionData[1] != "0" || dirty) Version += $"+{versionData[2]}";
if (dirty) Version += ".dirty";
if (versionData[1] != "0" || dirty)
Version += $"+{versionData[2]}";
if (dirty)
Version += ".dirty";
}
}

View file

@ -11,8 +11,12 @@ using NodaTime;
namespace Foxnouns.Backend.Controllers.Authentication;
[Route("/api/internal/auth")]
public class AuthController(Config config, DatabaseContext db, KeyCacheService keyCache, ILogger logger)
: ApiControllerBase
public class AuthController(
Config config,
DatabaseContext db,
KeyCacheService keyCache,
ILogger logger
) : ApiControllerBase
{
private readonly ILogger _logger = logger.ForContext<AuthController>();
@ -20,27 +24,25 @@ public class AuthController(Config config, DatabaseContext db, KeyCacheService k
[ProducesResponseType<UrlsResponse>(StatusCodes.Status200OK)]
public async Task<IActionResult> UrlsAsync(CancellationToken ct = default)
{
_logger.Debug("Generating auth URLs for Discord: {Discord}, Google: {Google}, Tumblr: {Tumblr}",
_logger.Debug(
"Generating auth URLs for Discord: {Discord}, Google: {Google}, Tumblr: {Tumblr}",
config.DiscordAuth.Enabled,
config.GoogleAuth.Enabled,
config.TumblrAuth.Enabled);
config.TumblrAuth.Enabled
);
var state = HttpUtility.UrlEncode(await keyCache.GenerateAuthStateAsync(ct));
string? discord = null;
if (config.DiscordAuth is { ClientId: not null, ClientSecret: not null })
discord =
$"https://discord.com/oauth2/authorize?response_type=code" +
$"&client_id={config.DiscordAuth.ClientId}&scope=identify" +
$"&prompt=none&state={state}" +
$"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}";
$"https://discord.com/oauth2/authorize?response_type=code"
+ $"&client_id={config.DiscordAuth.ClientId}&scope=identify"
+ $"&prompt=none&state={state}"
+ $"&redirect_uri={HttpUtility.UrlEncode($"{config.BaseUrl}/auth/callback/discord")}";
return Ok(new UrlsResponse(discord, null, null));
}
private record UrlsResponse(
string? Discord,
string? Google,
string? Tumblr
);
private record UrlsResponse(string? Discord, string? Google, string? Tumblr);
public record AuthResponse(
UserRendererService.UserResponse User,
@ -50,16 +52,13 @@ public class AuthController(Config config, DatabaseContext db, KeyCacheService k
public record CallbackResponse(
bool HasAccount,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] string? Ticket,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
string? Ticket,
string? RemoteUsername,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
string? RemoteUsername,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
UserRendererService.UserResponse? User,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
string? Token,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
Instant? ExpiresAt
UserRendererService.UserResponse? User,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] string? Token,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] Instant? ExpiresAt
);
public record OauthRegisterRequest(string Ticket, string Username);
@ -71,7 +70,8 @@ public class AuthController(Config config, DatabaseContext db, KeyCacheService k
public async Task<IActionResult> ForceLogoutAsync()
{
_logger.Information("Invalidating all tokens for user {UserId}", CurrentUser!.Id);
await db.Tokens.Where(t => t.UserId == CurrentUser.Id)
await db
.Tokens.Where(t => t.UserId == CurrentUser.Id)
.ExecuteUpdateAsync(s => s.SetProperty(t => t.ManuallyExpired, true));
return NoContent();

View file

@ -19,7 +19,8 @@ public class DiscordAuthController(
KeyCacheService keyCacheService,
AuthService authService,
RemoteAuthService remoteAuthService,
UserRendererService userRenderer) : ApiControllerBase
UserRendererService userRenderer
) : ApiControllerBase
{
private readonly ILogger _logger = logger.ForContext<DiscordAuthController>();
@ -27,59 +28,93 @@ public class DiscordAuthController(
// TODO: duplicating attribute doesn't work, find another way to mark both as possible response
// leaving it here for documentation purposes
[ProducesResponseType<AuthController.CallbackResponse>(StatusCodes.Status200OK)]
public async Task<IActionResult> CallbackAsync([FromBody] AuthController.CallbackRequest req,
CancellationToken ct = default)
public async Task<IActionResult> CallbackAsync(
[FromBody] AuthController.CallbackRequest req,
CancellationToken ct = default
)
{
CheckRequirements();
await keyCacheService.ValidateAuthStateAsync(req.State, ct);
var remoteUser = await remoteAuthService.RequestDiscordTokenAsync(req.Code, req.State, ct);
var user = await authService.AuthenticateUserAsync(AuthType.Discord, remoteUser.Id, ct: ct);
if (user != null) return Ok(await GenerateUserTokenAsync(user, ct));
if (user != null)
return Ok(await GenerateUserTokenAsync(user, ct));
_logger.Debug("Discord user {Username} ({Id}) authenticated with no local account", remoteUser.Username,
remoteUser.Id);
_logger.Debug(
"Discord user {Username} ({Id}) authenticated with no local account",
remoteUser.Username,
remoteUser.Id
);
var ticket = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync($"discord:{ticket}", remoteUser, Duration.FromMinutes(20), ct);
await keyCacheService.SetKeyAsync(
$"discord:{ticket}",
remoteUser,
Duration.FromMinutes(20),
ct
);
return Ok(new AuthController.CallbackResponse(
HasAccount: false,
Ticket: ticket,
RemoteUsername: remoteUser.Username,
User: null,
Token: null,
ExpiresAt: null
));
return Ok(
new AuthController.CallbackResponse(
HasAccount: false,
Ticket: ticket,
RemoteUsername: remoteUser.Username,
User: null,
Token: null,
ExpiresAt: null
)
);
}
[HttpPost("register")]
[ProducesResponseType<AuthController.AuthResponse>(StatusCodes.Status200OK)]
public async Task<IActionResult> RegisterAsync([FromBody] AuthController.OauthRegisterRequest req)
public async Task<IActionResult> RegisterAsync(
[FromBody] AuthController.OauthRegisterRequest req
)
{
var remoteUser = await keyCacheService.GetKeyAsync<RemoteAuthService.RemoteUser>($"discord:{req.Ticket}");
if (remoteUser == null) throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket);
if (await db.AuthMethods.AnyAsync(a => a.AuthType == AuthType.Discord && a.RemoteId == remoteUser.Id))
var remoteUser = await keyCacheService.GetKeyAsync<RemoteAuthService.RemoteUser>(
$"discord:{req.Ticket}"
);
if (remoteUser == null)
throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket);
if (
await db.AuthMethods.AnyAsync(a =>
a.AuthType == AuthType.Discord && a.RemoteId == remoteUser.Id
)
)
{
_logger.Error("Discord user {Id} has valid ticket but is already linked to an existing account",
remoteUser.Id);
_logger.Error(
"Discord user {Id} has valid ticket but is already linked to an existing account",
remoteUser.Id
);
throw new ApiError.BadRequest("Invalid ticket", "ticket", req.Ticket);
}
var user = await authService.CreateUserWithRemoteAuthAsync(req.Username, AuthType.Discord, remoteUser.Id,
remoteUser.Username);
var user = await authService.CreateUserWithRemoteAuthAsync(
req.Username,
AuthType.Discord,
remoteUser.Id,
remoteUser.Username
);
return Ok(await GenerateUserTokenAsync(user));
}
private async Task<AuthController.CallbackResponse> GenerateUserTokenAsync(User user,
CancellationToken ct = default)
private async Task<AuthController.CallbackResponse> GenerateUserTokenAsync(
User user,
CancellationToken ct = default
)
{
var frontendApp = await db.GetFrontendApplicationAsync(ct);
_logger.Debug("Logging user {Id} in with Discord", user.Id);
var (tokenStr, token) =
authService.GenerateToken(user, frontendApp, ["*"], clock.GetCurrentInstant() + Duration.FromDays(365));
var (tokenStr, token) = authService.GenerateToken(
user,
frontendApp,
["*"],
clock.GetCurrentInstant() + Duration.FromDays(365)
);
db.Add(token);
_logger.Debug("Generated token {TokenId} for {UserId}", user.Id, token.Id);
@ -90,7 +125,12 @@ public class DiscordAuthController(
HasAccount: true,
Ticket: null,
RemoteUsername: null,
User: await userRenderer.RenderUserAsync(user, selfUser: user, renderMembers: false, ct: ct),
User: await userRenderer.RenderUserAsync(
user,
selfUser: user,
renderMembers: false,
ct: ct
),
Token: tokenStr,
ExpiresAt: token.ExpiresAt
);
@ -99,6 +139,8 @@ public class DiscordAuthController(
private void CheckRequirements()
{
if (!config.DiscordAuth.Enabled)
throw new ApiError.BadRequest("Discord authentication is not enabled on this instance.");
throw new ApiError.BadRequest(
"Discord authentication is not enabled on this instance."
);
}
}

View file

@ -20,21 +20,35 @@ public class EmailAuthController(
KeyCacheService keyCacheService,
UserRendererService userRenderer,
IClock clock,
ILogger logger) : ApiControllerBase
ILogger logger
) : ApiControllerBase
{
private readonly ILogger _logger = logger.ForContext<EmailAuthController>();
[HttpPost("register")]
public async Task<IActionResult> RegisterAsync([FromBody] RegisterRequest req, CancellationToken ct = default)
public async Task<IActionResult> RegisterAsync(
[FromBody] RegisterRequest req,
CancellationToken ct = default
)
{
CheckRequirements();
if (!req.Email.Contains('@')) throw new ApiError.BadRequest("Email is invalid", "email", req.Email);
if (!req.Email.Contains('@'))
throw new ApiError.BadRequest("Email is invalid", "email", req.Email);
var state = await keyCacheService.GenerateRegisterEmailStateAsync(req.Email, userId: null, ct);
var state = await keyCacheService.GenerateRegisterEmailStateAsync(
req.Email,
userId: null,
ct
);
// If there's already a user with that email address, pretend we sent an email but actually ignore it
if (await db.AuthMethods.AnyAsync(a => a.AuthType == AuthType.Email && a.RemoteId == req.Email, ct))
if (
await db.AuthMethods.AnyAsync(
a => a.AuthType == AuthType.Email && a.RemoteId == req.Email,
ct
)
)
return NoContent();
mailService.QueueAccountCreationEmail(req.Email, state);
@ -47,29 +61,48 @@ public class EmailAuthController(
CheckRequirements();
var state = await keyCacheService.GetRegisterEmailStateAsync(req.State);
if (state == null) throw new ApiError.BadRequest("Invalid state", "state", req.State);
if (state == null)
throw new ApiError.BadRequest("Invalid state", "state", req.State);
// If this callback is for an existing user, add the email address to their auth methods
if (state.ExistingUserId != null)
{
var authMethod =
await authService.AddAuthMethodAsync(state.ExistingUserId.Value, AuthType.Email, state.Email);
_logger.Debug("Added email auth {AuthId} for user {UserId}", authMethod.Id, state.ExistingUserId);
var authMethod = await authService.AddAuthMethodAsync(
state.ExistingUserId.Value,
AuthType.Email,
state.Email
);
_logger.Debug(
"Added email auth {AuthId} for user {UserId}",
authMethod.Id,
state.ExistingUserId
);
return NoContent();
}
var ticket = AuthUtils.RandomToken();
await keyCacheService.SetKeyAsync($"email:{ticket}", state.Email, Duration.FromMinutes(20));
return Ok(new AuthController.CallbackResponse(HasAccount: false, Ticket: ticket, RemoteUsername: state.Email,
User: null, Token: null, ExpiresAt: null));
return Ok(
new AuthController.CallbackResponse(
HasAccount: false,
Ticket: ticket,
RemoteUsername: state.Email,
User: null,
Token: null,
ExpiresAt: null
)
);
}
[HttpPost("complete-registration")]
public async Task<IActionResult> CompleteRegistrationAsync([FromBody] CompleteRegistrationRequest req)
public async Task<IActionResult> CompleteRegistrationAsync(
[FromBody] CompleteRegistrationRequest req
)
{
var email = await keyCacheService.GetKeyAsync($"email:{req.Ticket}");
if (email == null) throw new ApiError.BadRequest("Unknown ticket", "ticket", req.Ticket);
if (email == null)
throw new ApiError.BadRequest("Unknown ticket", "ticket", req.Ticket);
// Check if username is valid at all
ValidationUtils.Validate([("username", ValidationUtils.ValidateUsername(req.Username))]);
@ -80,28 +113,41 @@ public class EmailAuthController(
var user = await authService.CreateUserWithPasswordAsync(req.Username, email, req.Password);
var frontendApp = await db.GetFrontendApplicationAsync();
var (tokenStr, token) =
authService.GenerateToken(user, frontendApp, ["*"], clock.GetCurrentInstant() + Duration.FromDays(365));
var (tokenStr, token) = authService.GenerateToken(
user,
frontendApp,
["*"],
clock.GetCurrentInstant() + Duration.FromDays(365)
);
db.Add(token);
await db.SaveChangesAsync();
await keyCacheService.DeleteKeyAsync($"email:{req.Ticket}");
return Ok(new AuthController.AuthResponse(
await userRenderer.RenderUserAsync(user, selfUser: user, renderMembers: false),
tokenStr,
token.ExpiresAt
));
return Ok(
new AuthController.AuthResponse(
await userRenderer.RenderUserAsync(user, selfUser: user, renderMembers: false),
tokenStr,
token.ExpiresAt
)
);
}
[HttpPost("login")]
[ProducesResponseType<AuthController.AuthResponse>(StatusCodes.Status200OK)]
public async Task<IActionResult> LoginAsync([FromBody] LoginRequest req, CancellationToken ct = default)
public async Task<IActionResult> LoginAsync(
[FromBody] LoginRequest req,
CancellationToken ct = default
)
{
CheckRequirements();
var (user, authenticationResult) = await authService.AuthenticateUserAsync(req.Email, req.Password, ct);
var (user, authenticationResult) = await authService.AuthenticateUserAsync(
req.Email,
req.Password,
ct
);
if (authenticationResult == AuthService.EmailAuthenticationResult.MfaRequired)
throw new NotImplementedException("MFA is not implemented yet");
@ -109,19 +155,30 @@ public class EmailAuthController(
_logger.Debug("Logging user {Id} in with email and password", user.Id);
var (tokenStr, token) =
authService.GenerateToken(user, frontendApp, ["*"], clock.GetCurrentInstant() + Duration.FromDays(365));
var (tokenStr, token) = authService.GenerateToken(
user,
frontendApp,
["*"],
clock.GetCurrentInstant() + Duration.FromDays(365)
);
db.Add(token);
_logger.Debug("Generated token {TokenId} for {UserId}", token.Id, user.Id);
await db.SaveChangesAsync(ct);
return Ok(new AuthController.AuthResponse(
await userRenderer.RenderUserAsync(user, selfUser: user, renderMembers: false, ct: ct),
tokenStr,
token.ExpiresAt
));
return Ok(
new AuthController.AuthResponse(
await userRenderer.RenderUserAsync(
user,
selfUser: user,
renderMembers: false,
ct: ct
),
tokenStr,
token.ExpiresAt
)
);
}
[HttpPost("add")]

View file

@ -18,13 +18,16 @@ public class FlagsController(
UserRendererService userRenderer,
ObjectStorageService objectStorageService,
ISnowflakeGenerator snowflakeGenerator,
IQueue queue) : ApiControllerBase
IQueue queue
) : ApiControllerBase
{
private readonly ILogger _logger = logger.ForContext<FlagsController>();
[HttpGet]
[Authorize("identify")]
[ProducesResponseType<IEnumerable<UserRendererService.PrideFlagResponse>>(statusCode: StatusCodes.Status200OK)]
[ProducesResponseType<IEnumerable<UserRendererService.PrideFlagResponse>>(
statusCode: StatusCodes.Status200OK
)]
public async Task<IActionResult> GetFlagsAsync(CancellationToken ct = default)
{
var flags = await db.PrideFlags.Where(f => f.UserId == CurrentUser!.Id).ToListAsync(ct);
@ -34,7 +37,9 @@ public class FlagsController(
[HttpPost]
[Authorize("user.update")]
[ProducesResponseType<UserRendererService.PrideFlagResponse>(statusCode: StatusCodes.Status202Accepted)]
[ProducesResponseType<UserRendererService.PrideFlagResponse>(
statusCode: StatusCodes.Status202Accepted
)]
public IActionResult CreateFlag([FromBody] CreateFlagRequest req)
{
ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, req.Image));
@ -42,7 +47,8 @@ public class FlagsController(
var id = snowflakeGenerator.GenerateSnowflake();
queue.QueueInvocableWithPayload<CreateFlagInvocable, CreateFlagPayload>(
new CreateFlagPayload(id, CurrentUser!.Id, req.Name, req.Image, req.Description));
new CreateFlagPayload(id, CurrentUser!.Id, req.Name, req.Image, req.Description)
);
return Accepted(new CreateFlagResponse(id, req.Name, req.Description));
}
@ -57,10 +63,14 @@ public class FlagsController(
{
ValidationUtils.Validate(ValidateFlag(req.Name, req.Description, null));
var flag = await db.PrideFlags.FirstOrDefaultAsync(f => f.Id == id && f.UserId == CurrentUser!.Id);
if (flag == null) throw new ApiError.NotFound("Unknown flag ID, or it's not your flag.");
var flag = await db.PrideFlags.FirstOrDefaultAsync(f =>
f.Id == id && f.UserId == CurrentUser!.Id
);
if (flag == null)
throw new ApiError.NotFound("Unknown flag ID, or it's not your flag.");
if (req.Name != null) flag.Name = req.Name;
if (req.Name != null)
flag.Name = req.Name;
if (req.HasProperty(nameof(req.Description)))
flag.Description = req.Description;
@ -83,8 +93,11 @@ public class FlagsController(
{
await using var tx = await db.Database.BeginTransactionAsync();
var flag = await db.PrideFlags.FirstOrDefaultAsync(f => f.Id == id && f.UserId == CurrentUser!.Id);
if (flag == null) throw new ApiError.NotFound("Unknown flag ID, or it's not your flag.");
var flag = await db.PrideFlags.FirstOrDefaultAsync(f =>
f.Id == id && f.UserId == CurrentUser!.Id
);
if (flag == null)
throw new ApiError.NotFound("Unknown flag ID, or it's not your flag.");
var hash = flag.Hash;
@ -96,7 +109,10 @@ public class FlagsController(
{
try
{
_logger.Information("Deleting flag file {Hash} as it is no longer used by any flags", hash);
_logger.Information(
"Deleting flag file {Hash} as it is no longer used by any flags",
hash
);
await objectStorageService.DeleteFlagAsync(hash);
}
catch (Exception e)
@ -104,14 +120,19 @@ public class FlagsController(
_logger.Error(e, "Error deleting flag file {Hash}", hash);
}
}
else _logger.Debug("Flag file {Hash} is used by other flags, not deleting", hash);
else
_logger.Debug("Flag file {Hash} is used by other flags, not deleting", hash);
await tx.CommitAsync();
return NoContent();
}
private static List<(string, ValidationError?)> ValidateFlag(string? name, string? description, string? imageData)
private static List<(string, ValidationError?)> ValidateFlag(
string? name,
string? description,
string? imageData
)
{
var errors = new List<(string, ValidationError?)>();
@ -120,10 +141,20 @@ public class FlagsController(
switch (name.Length)
{
case < 1:
errors.Add(("name", ValidationError.LengthError("Name is too short", 1, 100, name.Length)));
errors.Add(
(
"name",
ValidationError.LengthError("Name is too short", 1, 100, name.Length)
)
);
break;
case > 100:
errors.Add(("name", ValidationError.LengthError("Name is too long", 1, 100, name.Length)));
errors.Add(
(
"name",
ValidationError.LengthError("Name is too long", 1, 100, name.Length)
)
);
break;
}
}
@ -133,12 +164,30 @@ public class FlagsController(
switch (description.Length)
{
case < 1:
errors.Add(("description",
ValidationError.LengthError("Description is too short", 1, 100, description.Length)));
errors.Add(
(
"description",
ValidationError.LengthError(
"Description is too short",
1,
100,
description.Length
)
)
);
break;
case > 500:
errors.Add(("description",
ValidationError.LengthError("Description is too long", 1, 100, description.Length)));
errors.Add(
(
"description",
ValidationError.LengthError(
"Description is too long",
1,
100,
description.Length
)
)
);
break;
}
}
@ -148,10 +197,20 @@ public class FlagsController(
switch (imageData.Length)
{
case 0:
errors.Add(("image", ValidationError.GenericValidationError("Image cannot be empty", null)));
errors.Add(
(
"image",
ValidationError.GenericValidationError("Image cannot be empty", null)
)
);
break;
case > 1_500_000:
errors.Add(("image", ValidationError.GenericValidationError("Image is too large", null)));
errors.Add(
(
"image",
ValidationError.GenericValidationError("Image is too large", null)
)
);
break;
}
}

View file

@ -17,11 +17,13 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
private static string GetCleanedTemplate(string template)
{
if (template.StartsWith("api/v2")) template = template["api/v2".Length..];
if (template.StartsWith("api/v2"))
template = template["api/v2".Length..];
template = PathVarRegex()
.Replace(template, "{id}") // Replace all path variables (almost always IDs) with `{id}`
.Replace("@me", "{id}"); // Also replace hardcoded `@me` with `{id}`
if (template.Contains("{id}")) return template.Split("{id}")[0] + "{id}";
if (template.Contains("{id}"))
return template.Split("{id}")[0] + "{id}";
return template;
}
@ -29,11 +31,13 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
public async Task<IActionResult> GetRequestDataAsync([FromBody] RequestDataRequest req)
{
var endpoint = GetEndpoint(HttpContext, req.Path, req.Method);
if (endpoint == null) throw new ApiError.BadRequest("Path/method combination is invalid");
if (endpoint == null)
throw new ApiError.BadRequest("Path/method combination is invalid");
var actionDescriptor = endpoint.Metadata.GetMetadata<ControllerActionDescriptor>();
var template = actionDescriptor?.AttributeRouteInfo?.Template;
if (template == null) throw new FoxnounsError("Template value was null on valid endpoint");
if (template == null)
throw new FoxnounsError("Template value was null on valid endpoint");
template = GetCleanedTemplate(template);
// If no token was supplied, or it isn't valid base 64, return a null user ID (limiting by IP)
@ -46,26 +50,37 @@ public partial class InternalController(DatabaseContext db) : ControllerBase
public record RequestDataRequest(string? Token, string Method, string Path);
public record RequestDataResponse(
Snowflake? UserId,
string Template);
public record RequestDataResponse(Snowflake? UserId, string Template);
private static RouteEndpoint? GetEndpoint(HttpContext httpContext, string url, string requestMethod)
private static RouteEndpoint? GetEndpoint(
HttpContext httpContext,
string url,
string requestMethod
)
{
var endpointDataSource = httpContext.RequestServices.GetService<EndpointDataSource>();
if (endpointDataSource == null) return null;
if (endpointDataSource == null)
return null;
var endpoints = endpointDataSource.Endpoints.OfType<RouteEndpoint>();
foreach (var endpoint in endpoints)
{
if (endpoint.RoutePattern.RawText == null) continue;
if (endpoint.RoutePattern.RawText == null)
continue;
var templateMatcher = new TemplateMatcher(TemplateParser.Parse(endpoint.RoutePattern.RawText),
new RouteValueDictionary());
if (!templateMatcher.TryMatch(url, new())) continue;
var templateMatcher = new TemplateMatcher(
TemplateParser.Parse(endpoint.RoutePattern.RawText),
new RouteValueDictionary()
);
if (!templateMatcher.TryMatch(url, new()))
continue;
var httpMethodAttribute = endpoint.Metadata.GetMetadata<HttpMethodAttribute>();
if (httpMethodAttribute != null &&
!httpMethodAttribute.HttpMethods.Any(x => x.Equals(requestMethod, StringComparison.OrdinalIgnoreCase)))
if (
httpMethodAttribute != null
&& !httpMethodAttribute.HttpMethods.Any(x =>
x.Equals(requestMethod, StringComparison.OrdinalIgnoreCase)
)
)
continue;
return endpoint;
}

View file

@ -21,12 +21,15 @@ public class MembersController(
ISnowflakeGenerator snowflakeGenerator,
ObjectStorageService objectStorageService,
IQueue queue,
IClock clock) : ApiControllerBase
IClock clock
) : ApiControllerBase
{
private readonly ILogger _logger = logger.ForContext<MembersController>();
[HttpGet]
[ProducesResponseType<IEnumerable<MemberRendererService.PartialMember>>(StatusCodes.Status200OK)]
[ProducesResponseType<IEnumerable<MemberRendererService.PartialMember>>(
StatusCodes.Status200OK
)]
public async Task<IActionResult> GetMembersAsync(string userRef, CancellationToken ct = default)
{
var user = await db.ResolveUserAsync(userRef, CurrentToken, ct);
@ -35,7 +38,11 @@ public class MembersController(
[HttpGet("{memberRef}")]
[ProducesResponseType<MemberRendererService.MemberResponse>(StatusCodes.Status200OK)]
public async Task<IActionResult> GetMemberAsync(string userRef, string memberRef, CancellationToken ct = default)
public async Task<IActionResult> GetMemberAsync(
string userRef,
string memberRef,
CancellationToken ct = default
)
{
var member = await db.ResolveMemberAsync(userRef, memberRef, CurrentToken, ct);
return Ok(memberRenderer.RenderMember(member, CurrentToken));
@ -46,19 +53,30 @@ public class MembersController(
[HttpPost("/api/v2/users/@me/members")]
[ProducesResponseType<MemberRendererService.MemberResponse>(StatusCodes.Status200OK)]
[Authorize("member.create")]
public async Task<IActionResult> CreateMemberAsync([FromBody] CreateMemberRequest req,
CancellationToken ct = default)
public async Task<IActionResult> CreateMemberAsync(
[FromBody] CreateMemberRequest req,
CancellationToken ct = default
)
{
ValidationUtils.Validate([
("name", ValidationUtils.ValidateMemberName(req.Name)),
("display_name", ValidationUtils.ValidateDisplayName(req.DisplayName)),
("bio", ValidationUtils.ValidateBio(req.Bio)),
("avatar", ValidationUtils.ValidateAvatar(req.Avatar)),
.. ValidationUtils.ValidateFields(req.Fields, CurrentUser!.CustomPreferences),
.. ValidationUtils.ValidateFieldEntries(req.Names?.ToArray(), CurrentUser!.CustomPreferences, "names"),
.. ValidationUtils.ValidatePronouns(req.Pronouns?.ToArray(), CurrentUser!.CustomPreferences),
.. ValidationUtils.ValidateLinks(req.Links)
]);
ValidationUtils.Validate(
[
("name", ValidationUtils.ValidateMemberName(req.Name)),
("display_name", ValidationUtils.ValidateDisplayName(req.DisplayName)),
("bio", ValidationUtils.ValidateBio(req.Bio)),
("avatar", ValidationUtils.ValidateAvatar(req.Avatar)),
.. ValidationUtils.ValidateFields(req.Fields, CurrentUser!.CustomPreferences),
.. ValidationUtils.ValidateFieldEntries(
req.Names?.ToArray(),
CurrentUser!.CustomPreferences,
"names"
),
.. ValidationUtils.ValidatePronouns(
req.Pronouns?.ToArray(),
CurrentUser!.CustomPreferences
),
.. ValidationUtils.ValidateLinks(req.Links),
]
);
var memberCount = await db.Members.CountAsync(m => m.UserId == CurrentUser.Id, ct);
if (memberCount >= MaxMemberCount)
@ -75,11 +93,16 @@ public class MembersController(
Fields = req.Fields ?? [],
Names = req.Names ?? [],
Pronouns = req.Pronouns ?? [],
Unlisted = req.Unlisted ?? false
Unlisted = req.Unlisted ?? false,
};
db.Add(member);
_logger.Debug("Creating member {MemberName} ({Id}) for {UserId}", member.Name, member.Id, CurrentUser!.Id);
_logger.Debug(
"Creating member {MemberName} ({Id}) for {UserId}",
member.Name,
member.Id,
CurrentUser!.Id
);
try
{
@ -88,19 +111,27 @@ public class MembersController(
catch (UniqueConstraintException)
{
_logger.Debug("Could not create member {Id} due to name conflict", member.Id);
throw new ApiError.BadRequest("A member with that name already exists", "name", req.Name);
throw new ApiError.BadRequest(
"A member with that name already exists",
"name",
req.Name
);
}
if (req.Avatar != null)
queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(member.Id, req.Avatar));
new AvatarUpdatePayload(member.Id, req.Avatar)
);
return Ok(memberRenderer.RenderMember(member, CurrentToken));
}
[HttpPatch("/api/v2/users/@me/members/{memberRef}")]
[Authorize("member.update")]
public async Task<IActionResult> UpdateMemberAsync(string memberRef, [FromBody] UpdateMemberRequest req)
public async Task<IActionResult> UpdateMemberAsync(
string memberRef,
[FromBody] UpdateMemberRequest req
)
{
await using var tx = await db.Database.BeginTransactionAsync();
var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);
@ -134,26 +165,37 @@ public class MembersController(
if (req.Names != null)
{
errors.AddRange(ValidationUtils.ValidateFieldEntries(req.Names, CurrentUser!.CustomPreferences, "names"));
errors.AddRange(
ValidationUtils.ValidateFieldEntries(
req.Names,
CurrentUser!.CustomPreferences,
"names"
)
);
member.Names = req.Names.ToList();
}
if (req.Pronouns != null)
{
errors.AddRange(ValidationUtils.ValidatePronouns(req.Pronouns, CurrentUser!.CustomPreferences));
errors.AddRange(
ValidationUtils.ValidatePronouns(req.Pronouns, CurrentUser!.CustomPreferences)
);
member.Pronouns = req.Pronouns.ToList();
}
if (req.Fields != null)
{
errors.AddRange(ValidationUtils.ValidateFields(req.Fields.ToList(), CurrentUser!.CustomPreferences));
errors.AddRange(
ValidationUtils.ValidateFields(req.Fields.ToList(), CurrentUser!.CustomPreferences)
);
member.Fields = req.Fields.ToList();
}
if (req.Flags != null)
{
var flagError = await db.SetMemberFlagsAsync(CurrentUser!.Id, member.Id, req.Flags);
if (flagError != null) errors.Add(("flags", flagError));
if (flagError != null)
errors.Add(("flags", flagError));
}
if (req.HasProperty(nameof(req.Avatar)))
@ -165,16 +207,25 @@ public class MembersController(
// so it's in a separate block to the validation above.
if (req.HasProperty(nameof(req.Avatar)))
queue.QueueInvocableWithPayload<MemberAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(member.Id, req.Avatar));
new AvatarUpdatePayload(member.Id, req.Avatar)
);
try
{
await db.SaveChangesAsync();
}
catch (UniqueConstraintException)
{
_logger.Debug("Could not update member {Id} due to name conflict ({CurrentName} / {NewName})", member.Id,
member.Name, req.Name);
throw new ApiError.BadRequest("A member with that name already exists", "name", req.Name!);
_logger.Debug(
"Could not update member {Id} due to name conflict ({CurrentName} / {NewName})",
member.Id,
member.Name,
req.Name
);
throw new ApiError.BadRequest(
"A member with that name already exists",
"name",
req.Name!
);
}
await tx.CommitAsync();
@ -199,15 +250,20 @@ public class MembersController(
public async Task<IActionResult> DeleteMemberAsync(string memberRef)
{
var member = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);
var deleteCount = await db.Members.Where(m => m.UserId == CurrentUser!.Id && m.Id == member.Id)
var deleteCount = await db
.Members.Where(m => m.UserId == CurrentUser!.Id && m.Id == member.Id)
.ExecuteDeleteAsync();
if (deleteCount == 0)
{
_logger.Warning("Successfully resolved member {Id} but could not delete them", member.Id);
_logger.Warning(
"Successfully resolved member {Id} but could not delete them",
member.Id
);
return NoContent();
}
if (member.Avatar != null) await objectStorageService.DeleteMemberAvatarAsync(member.Id, member.Avatar);
if (member.Avatar != null)
await objectStorageService.DeleteMemberAvatarAsync(member.Id, member.Avatar);
return NoContent();
}
@ -220,7 +276,8 @@ public class MembersController(
string[]? Links,
List<FieldEntry>? Names,
List<Pronoun>? Pronouns,
List<Field>? Fields);
List<Field>? Fields
);
[HttpPost("/api/v2/users/@me/members/{memberRef}/reroll-sid")]
[Authorize("member.update")]
@ -234,14 +291,16 @@ public class MembersController(
throw new ApiError.BadRequest("Cannot reroll short ID yet");
// Using ExecuteUpdateAsync here as the new short ID is generated by the database
await db.Members.Where(m => m.Id == member.Id)
.ExecuteUpdateAsync(s => s
.SetProperty(m => m.Sid, _ => db.FindFreeMemberSid()));
await db
.Members.Where(m => m.Id == member.Id)
.ExecuteUpdateAsync(s => s.SetProperty(m => m.Sid, _ => db.FindFreeMemberSid()));
await db.Users.Where(u => u.Id == CurrentUser.Id)
.ExecuteUpdateAsync(s => s
.SetProperty(u => u.LastSidReroll, clock.GetCurrentInstant())
.SetProperty(u => u.LastActive, clock.GetCurrentInstant()));
await db
.Users.Where(u => u.Id == CurrentUser.Id)
.ExecuteUpdateAsync(s =>
s.SetProperty(u => u.LastSidReroll, clock.GetCurrentInstant())
.SetProperty(u => u.LastActive, clock.GetCurrentInstant())
);
// Re-fetch member to fetch the new sid
var updatedMember = await db.ResolveMemberAsync(CurrentUser!.Id, memberRef);

View file

@ -12,23 +12,30 @@ public class MetaController : ApiControllerBase
[ProducesResponseType<MetaResponse>(StatusCodes.Status200OK)]
public IActionResult GetMeta()
{
return Ok(new MetaResponse(
Repository, BuildInfo.Version, BuildInfo.Hash, (int)FoxnounsMetrics.MemberCount.Value,
new UserInfo(
(int)FoxnounsMetrics.UsersCount.Value,
(int)FoxnounsMetrics.UsersActiveMonthCount.Value,
(int)FoxnounsMetrics.UsersActiveWeekCount.Value,
(int)FoxnounsMetrics.UsersActiveDayCount.Value
),
new Limits(
MemberCount: MembersController.MaxMemberCount,
BioLength: ValidationUtils.MaxBioLength,
CustomPreferences: UsersController.MaxCustomPreferences))
return Ok(
new MetaResponse(
Repository,
BuildInfo.Version,
BuildInfo.Hash,
(int)FoxnounsMetrics.MemberCount.Value,
new UserInfo(
(int)FoxnounsMetrics.UsersCount.Value,
(int)FoxnounsMetrics.UsersActiveMonthCount.Value,
(int)FoxnounsMetrics.UsersActiveWeekCount.Value,
(int)FoxnounsMetrics.UsersActiveDayCount.Value
),
new Limits(
MemberCount: MembersController.MaxMemberCount,
BioLength: ValidationUtils.MaxBioLength,
CustomPreferences: UsersController.MaxCustomPreferences
)
)
);
}
[HttpGet("/api/v2/coffee")]
public IActionResult BrewCoffee() => Problem("Sorry, I'm a teapot!", statusCode: StatusCodes.Status418ImATeapot);
public IActionResult BrewCoffee() =>
Problem("Sorry, I'm a teapot!", statusCode: StatusCodes.Status418ImATeapot);
private record MetaResponse(
string Repository,
@ -36,13 +43,11 @@ public class MetaController : ApiControllerBase
string Hash,
int Members,
UserInfo Users,
Limits Limits);
Limits Limits
);
private record UserInfo(int Total, int ActiveMonth, int ActiveWeek, int ActiveDay);
// All limits that the frontend should know about (for UI purposes)
private record Limits(
int MemberCount,
int BioLength,
int CustomPreferences);
private record Limits(int MemberCount, int BioLength, int CustomPreferences);
}

View file

@ -20,7 +20,8 @@ public class UsersController(
UserRendererService userRenderer,
ISnowflakeGenerator snowflakeGenerator,
IQueue queue,
IClock clock) : ApiControllerBase
IClock clock
) : ApiControllerBase
{
private readonly ILogger _logger = logger.ForContext<UsersController>();
@ -29,20 +30,25 @@ public class UsersController(
public async Task<IActionResult> GetUserAsync(string userRef, CancellationToken ct = default)
{
var user = await db.ResolveUserAsync(userRef, CurrentToken, ct);
return Ok(await userRenderer.RenderUserAsync(
user,
selfUser: CurrentUser,
token: CurrentToken,
renderMembers: true,
renderAuthMethods: true,
ct: ct
));
return Ok(
await userRenderer.RenderUserAsync(
user,
selfUser: CurrentUser,
token: CurrentToken,
renderMembers: true,
renderAuthMethods: true,
ct: ct
)
);
}
[HttpPatch("@me")]
[Authorize("user.update")]
[ProducesResponseType<UserRendererService.UserResponse>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> UpdateUserAsync([FromBody] UpdateUserRequest req, CancellationToken ct = default)
public async Task<IActionResult> UpdateUserAsync(
[FromBody] UpdateUserRequest req,
CancellationToken ct = default
)
{
await using var tx = await db.Database.BeginTransactionAsync(ct);
var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct);
@ -74,26 +80,37 @@ public class UsersController(
if (req.Names != null)
{
errors.AddRange(ValidationUtils.ValidateFieldEntries(req.Names, CurrentUser!.CustomPreferences, "names"));
errors.AddRange(
ValidationUtils.ValidateFieldEntries(
req.Names,
CurrentUser!.CustomPreferences,
"names"
)
);
user.Names = req.Names.ToList();
}
if (req.Pronouns != null)
{
errors.AddRange(ValidationUtils.ValidatePronouns(req.Pronouns, CurrentUser!.CustomPreferences));
errors.AddRange(
ValidationUtils.ValidatePronouns(req.Pronouns, CurrentUser!.CustomPreferences)
);
user.Pronouns = req.Pronouns.ToList();
}
if (req.Fields != null)
{
errors.AddRange(ValidationUtils.ValidateFields(req.Fields.ToList(), CurrentUser!.CustomPreferences));
errors.AddRange(
ValidationUtils.ValidateFields(req.Fields.ToList(), CurrentUser!.CustomPreferences)
);
user.Fields = req.Fields.ToList();
}
if (req.Flags != null)
{
var flagError = await db.SetUserFlagsAsync(CurrentUser!.Id, req.Flags);
if (flagError != null) errors.Add(("flags", flagError));
if (flagError != null)
errors.Add(("flags", flagError));
}
if (req.HasProperty(nameof(req.Avatar)))
@ -105,7 +122,8 @@ public class UsersController(
// so it's in a separate block to the validation above.
if (req.HasProperty(nameof(req.Avatar)))
queue.QueueInvocableWithPayload<UserAvatarUpdateInvocable, AvatarUpdatePayload>(
new AvatarUpdatePayload(CurrentUser!.Id, req.Avatar));
new AvatarUpdatePayload(CurrentUser!.Id, req.Avatar)
);
try
{
@ -113,26 +131,45 @@ public class UsersController(
}
catch (UniqueConstraintException)
{
_logger.Debug("Could not update user {Id} due to name conflict ({CurrentName} / {NewName})", user.Id,
user.Username, req.Username);
throw new ApiError.BadRequest("That username is already taken.", "username", req.Username!);
_logger.Debug(
"Could not update user {Id} due to name conflict ({CurrentName} / {NewName})",
user.Id,
user.Username,
req.Username
);
throw new ApiError.BadRequest(
"That username is already taken.",
"username",
req.Username!
);
}
await tx.CommitAsync(ct);
return Ok(await userRenderer.RenderUserAsync(user, CurrentUser, renderMembers: false,
renderAuthMethods: false, ct: ct));
return Ok(
await userRenderer.RenderUserAsync(
user,
CurrentUser,
renderMembers: false,
renderAuthMethods: false,
ct: ct
)
);
}
[HttpPatch("@me/custom-preferences")]
[Authorize("user.update")]
[ProducesResponseType<Dictionary<Snowflake, User.CustomPreference>>(StatusCodes.Status200OK)]
public async Task<IActionResult> UpdateCustomPreferencesAsync([FromBody] List<CustomPreferencesUpdateRequest> req,
CancellationToken ct = default)
public async Task<IActionResult> UpdateCustomPreferencesAsync(
[FromBody] List<CustomPreferencesUpdateRequest> req,
CancellationToken ct = default
)
{
ValidationUtils.Validate(ValidateCustomPreferences(req));
var user = await db.ResolveUserAsync(CurrentUser!.Id, ct);
var preferences = user.CustomPreferences.Where(x => req.Any(r => r.Id == x.Key)).ToDictionary();
var preferences = user
.CustomPreferences.Where(x => req.Any(r => r.Id == x.Key))
.ToDictionary();
foreach (var r in req)
{
@ -144,7 +181,7 @@ public class UsersController(
Icon = r.Icon,
Muted = r.Muted,
Size = r.Size,
Tooltip = r.Tooltip
Tooltip = r.Tooltip,
};
}
else
@ -155,7 +192,7 @@ public class UsersController(
Icon = r.Icon,
Muted = r.Muted,
Size = r.Size,
Tooltip = r.Tooltip
Tooltip = r.Tooltip,
};
}
}
@ -180,15 +217,25 @@ public class UsersController(
public const int MaxCustomPreferences = 25;
private static List<(string, ValidationError?)> ValidateCustomPreferences(
List<CustomPreferencesUpdateRequest> preferences)
List<CustomPreferencesUpdateRequest> preferences
)
{
var errors = new List<(string, ValidationError?)>();
if (preferences.Count > MaxCustomPreferences)
errors.Add(("custom_preferences",
ValidationError.LengthError("Too many custom preferences", 0, MaxCustomPreferences,
preferences.Count)));
if (preferences.Count > 50) return errors;
errors.Add(
(
"custom_preferences",
ValidationError.LengthError(
"Too many custom preferences",
0,
MaxCustomPreferences,
preferences.Count
)
)
);
if (preferences.Count > 50)
return errors;
// TODO: validate individual preferences
@ -208,7 +255,6 @@ public class UsersController(
public Snowflake[]? Flags { get; init; }
}
[HttpGet("@me/settings")]
[Authorize("user.read_hidden")]
[ProducesResponseType<UserSettings>(statusCode: StatusCodes.Status200OK)]
@ -221,8 +267,10 @@ public class UsersController(
[HttpPatch("@me/settings")]
[Authorize("user.read_hidden", "user.update")]
[ProducesResponseType<UserSettings>(statusCode: StatusCodes.Status200OK)]
public async Task<IActionResult> UpdateUserSettingsAsync([FromBody] UpdateUserSettingsRequest req,
CancellationToken ct = default)
public async Task<IActionResult> UpdateUserSettingsAsync(
[FromBody] UpdateUserSettingsRequest req,
CancellationToken ct = default
)
{
var user = await db.Users.FirstAsync(u => u.Id == CurrentUser!.Id, ct);
@ -250,13 +298,22 @@ public class UsersController(
throw new ApiError.BadRequest("Cannot reroll short ID yet");
// Using ExecuteUpdateAsync here as the new short ID is generated by the database
await db.Users.Where(u => u.Id == CurrentUser.Id)
.ExecuteUpdateAsync(s => s
.SetProperty(u => u.Sid, _ => db.FindFreeUserSid())
.SetProperty(u => u.LastSidReroll, clock.GetCurrentInstant())
.SetProperty(u => u.LastActive, clock.GetCurrentInstant()));
await db
.Users.Where(u => u.Id == CurrentUser.Id)
.ExecuteUpdateAsync(s =>
s.SetProperty(u => u.Sid, _ => db.FindFreeUserSid())
.SetProperty(u => u.LastSidReroll, clock.GetCurrentInstant())
.SetProperty(u => u.LastActive, clock.GetCurrentInstant())
);
var user = await db.ResolveUserAsync(CurrentUser.Id);
return Ok(await userRenderer.RenderUserAsync(user, CurrentUser, CurrentToken, renderMembers: false));
return Ok(
await userRenderer.RenderUserAsync(
user,
CurrentUser,
CurrentToken,
renderMembers: false
)
);
}
}

View file

@ -45,11 +45,12 @@ public class DatabaseContext : DbContext
_loggerFactory = loggerFactory;
}
protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
=> optionsBuilder
protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) =>
optionsBuilder
.ConfigureWarnings(c =>
c.Ignore(CoreEventId.ManyServiceProvidersCreatedWarning)
.Ignore(CoreEventId.SaveChangesFailed))
.Ignore(CoreEventId.SaveChangesFailed)
)
.UseNpgsql(_dataSource, o => o.UseNodaTime())
.UseSnakeCaseNamingConvention()
.UseLoggerFactory(_loggerFactory)
@ -76,7 +77,10 @@ public class DatabaseContext : DbContext
modelBuilder.Entity<User>().Property(u => u.CustomPreferences).HasColumnType("jsonb");
modelBuilder.Entity<User>().Property(u => u.Settings).HasColumnType("jsonb");
modelBuilder.Entity<Member>().Property(m => m.Sid).HasDefaultValueSql("find_free_member_sid()");
modelBuilder
.Entity<Member>()
.Property(m => m.Sid)
.HasDefaultValueSql("find_free_member_sid()");
modelBuilder.Entity<Member>().Property(m => m.Fields).HasColumnType("jsonb");
modelBuilder.Entity<Member>().Property(m => m.Names).HasColumnType("jsonb");
modelBuilder.Entity<Member>().Property(m => m.Pronouns).HasColumnType("jsonb");
@ -84,10 +88,12 @@ public class DatabaseContext : DbContext
modelBuilder.Entity<UserFlag>().Navigation(f => f.PrideFlag).AutoInclude();
modelBuilder.Entity<MemberFlag>().Navigation(f => f.PrideFlag).AutoInclude();
modelBuilder.HasDbFunction(typeof(DatabaseContext).GetMethod(nameof(FindFreeUserSid))!)
modelBuilder
.HasDbFunction(typeof(DatabaseContext).GetMethod(nameof(FindFreeUserSid))!)
.HasName("find_free_user_sid");
modelBuilder.HasDbFunction(typeof(DatabaseContext).GetMethod(nameof(FindFreeMemberSid))!)
modelBuilder
.HasDbFunction(typeof(DatabaseContext).GetMethod(nameof(FindFreeMemberSid))!)
.HasName("find_free_member_sid");
}
@ -102,17 +108,22 @@ public class DatabaseContext : DbContext
public string FindFreeMemberSid() => throw new NotSupportedException();
}
[SuppressMessage("ReSharper", "UnusedType.Global", Justification = "Used by EF Core's migration generator")]
[SuppressMessage(
"ReSharper",
"UnusedType.Global",
Justification = "Used by EF Core's migration generator"
)]
public class DesignTimeDatabaseContextFactory : IDesignTimeDbContextFactory<DatabaseContext>
{
public DatabaseContext CreateDbContext(string[] args)
{
// Read the configuration file
var config = new ConfigurationBuilder()
.AddConfiguration()
.Build()
// Get the configuration as our config class
.Get<Config>() ?? new();
var config =
new ConfigurationBuilder()
.AddConfiguration()
.Build()
// Get the configuration as our config class
.Get<Config>() ?? new();
return new DatabaseContext(config, null);
}

View file

@ -8,89 +8,128 @@ namespace Foxnouns.Backend.Database;
public static class DatabaseQueryExtensions
{
public static async Task<User> ResolveUserAsync(this DatabaseContext context, string userRef, Token? token,
CancellationToken ct = default)
public static async Task<User> ResolveUserAsync(
this DatabaseContext context,
string userRef,
Token? token,
CancellationToken ct = default
)
{
if (userRef == "@me")
{
return token != null
? await context.Users.FirstAsync(u => u.Id == token.UserId, ct)
: throw new ApiError.Unauthorized("This endpoint requires an authenticated user.",
ErrorCode.AuthenticationRequired);
: throw new ApiError.Unauthorized(
"This endpoint requires an authenticated user.",
ErrorCode.AuthenticationRequired
);
}
User? user;
if (Snowflake.TryParse(userRef, out var snowflake))
{
user = await context.Users
.Where(u => !u.Deleted)
user = await context
.Users.Where(u => !u.Deleted)
.FirstOrDefaultAsync(u => u.Id == snowflake, ct);
if (user != null) return user;
if (user != null)
return user;
}
user = await context.Users
.Where(u => !u.Deleted)
user = await context
.Users.Where(u => !u.Deleted)
.FirstOrDefaultAsync(u => u.Username == userRef, ct);
if (user != null) return user;
throw new ApiError.NotFound("No user with that ID or username found.", code: ErrorCode.UserNotFound);
if (user != null)
return user;
throw new ApiError.NotFound(
"No user with that ID or username found.",
code: ErrorCode.UserNotFound
);
}
public static async Task<User> ResolveUserAsync(this DatabaseContext context, Snowflake id,
CancellationToken ct = default)
public static async Task<User> ResolveUserAsync(
this DatabaseContext context,
Snowflake id,
CancellationToken ct = default
)
{
var user = await context.Users
.Where(u => !u.Deleted)
var user = await context
.Users.Where(u => !u.Deleted)
.FirstOrDefaultAsync(u => u.Id == id, ct);
if (user != null) return user;
if (user != null)
return user;
throw new ApiError.NotFound("No user with that ID found.", code: ErrorCode.UserNotFound);
}
public static async Task<Member> ResolveMemberAsync(this DatabaseContext context, Snowflake id,
CancellationToken ct = default)
public static async Task<Member> ResolveMemberAsync(
this DatabaseContext context,
Snowflake id,
CancellationToken ct = default
)
{
var member = await context.Members
.Include(m => m.User)
var member = await context
.Members.Include(m => m.User)
.Where(m => !m.User.Deleted)
.FirstOrDefaultAsync(m => m.Id == id, ct);
if (member != null) return member;
throw new ApiError.NotFound("No member with that ID found.", code: ErrorCode.MemberNotFound);
if (member != null)
return member;
throw new ApiError.NotFound(
"No member with that ID found.",
code: ErrorCode.MemberNotFound
);
}
public static async Task<Member> ResolveMemberAsync(this DatabaseContext context, string userRef, string memberRef,
Token? token, CancellationToken ct = default)
public static async Task<Member> ResolveMemberAsync(
this DatabaseContext context,
string userRef,
string memberRef,
Token? token,
CancellationToken ct = default
)
{
var user = await context.ResolveUserAsync(userRef, token, ct);
return await context.ResolveMemberAsync(user.Id, memberRef, ct);
}
public static async Task<Member> ResolveMemberAsync(this DatabaseContext context, Snowflake userId,
string memberRef, CancellationToken ct = default)
public static async Task<Member> ResolveMemberAsync(
this DatabaseContext context,
Snowflake userId,
string memberRef,
CancellationToken ct = default
)
{
Member? member;
if (Snowflake.TryParse(memberRef, out var snowflake))
{
member = await context.Members
.Include(m => m.User)
member = await context
.Members.Include(m => m.User)
.Include(m => m.ProfileFlags)
.Where(m => !m.User.Deleted)
.FirstOrDefaultAsync(m => m.Id == snowflake && m.UserId == userId, ct);
if (member != null) return member;
if (member != null)
return member;
}
member = await context.Members
.Include(m => m.User)
member = await context
.Members.Include(m => m.User)
.Include(m => m.ProfileFlags)
.Where(m => !m.User.Deleted)
.FirstOrDefaultAsync(m => m.Name == memberRef && m.UserId == userId, ct);
if (member != null) return member;
throw new ApiError.NotFound("No member with that ID or name found.", code: ErrorCode.MemberNotFound);
if (member != null)
return member;
throw new ApiError.NotFound(
"No member with that ID or name found.",
code: ErrorCode.MemberNotFound
);
}
public static async Task<Application> GetFrontendApplicationAsync(this DatabaseContext context,
CancellationToken ct = default)
public static async Task<Application> GetFrontendApplicationAsync(
this DatabaseContext context,
CancellationToken ct = default
)
{
var app = await context.Applications.FirstOrDefaultAsync(a => a.Id == new Snowflake(0), ct);
if (app != null) return app;
if (app != null)
return app;
app = new Application
{
@ -107,27 +146,42 @@ public static class DatabaseQueryExtensions
return app;
}
public static async Task<Token?> GetToken(this DatabaseContext context, byte[] rawToken,
CancellationToken ct = default)
public static async Task<Token?> GetToken(
this DatabaseContext context,
byte[] rawToken,
CancellationToken ct = default
)
{
var hash = SHA512.HashData(rawToken);
var oauthToken = await context.Tokens
.Include(t => t.Application)
var oauthToken = await context
.Tokens.Include(t => t.Application)
.Include(t => t.User)
.FirstOrDefaultAsync(
t => t.Hash == hash && t.ExpiresAt > SystemClock.Instance.GetCurrentInstant() && !t.ManuallyExpired,
ct);
t =>
t.Hash == hash
&& t.ExpiresAt > SystemClock.Instance.GetCurrentInstant()
&& !t.ManuallyExpired,
ct
);
return oauthToken;
}
public static async Task<Snowflake?> GetTokenUserId(this DatabaseContext context, byte[] rawToken,
CancellationToken ct = default)
public static async Task<Snowflake?> GetTokenUserId(
this DatabaseContext context,
byte[] rawToken,
CancellationToken ct = default
)
{
var hash = SHA512.HashData(rawToken);
return await context.Tokens
.Where(t => t.Hash == hash && t.ExpiresAt > SystemClock.Instance.GetCurrentInstant() && !t.ManuallyExpired)
.Select(t => t.UserId).FirstOrDefaultAsync(ct);
return await context
.Tokens.Where(t =>
t.Hash == hash
&& t.ExpiresAt > SystemClock.Instance.GetCurrentInstant()
&& !t.ManuallyExpired
)
.Select(t => t.UserId)
.FirstOrDefaultAsync(ct);
}
}

View file

@ -5,23 +5,30 @@ namespace Foxnouns.Backend.Database;
public static class FlagQueryExtensions
{
private static async Task<List<PrideFlag>> GetFlagsAsync(this DatabaseContext db, Snowflake userId) =>
await db.PrideFlags.Where(f => f.UserId == userId).OrderBy(f => f.Id).ToListAsync();
private static async Task<List<PrideFlag>> GetFlagsAsync(
this DatabaseContext db,
Snowflake userId
) => await db.PrideFlags.Where(f => f.UserId == userId).OrderBy(f => f.Id).ToListAsync();
/// <summary>
/// Sets the user's profile flags to the given IDs. Returns a validation error if any of the flag IDs are unknown
/// or if too many IDs are given. Duplicates are allowed.
/// </summary>
public static async Task<ValidationError?> SetUserFlagsAsync(this DatabaseContext db, Snowflake userId,
Snowflake[] flagIds)
public static async Task<ValidationError?> SetUserFlagsAsync(
this DatabaseContext db,
Snowflake userId,
Snowflake[] flagIds
)
{
var currentFlags = await db.UserFlags.Where(f => f.UserId == userId).ToListAsync();
foreach (var flag in currentFlags)
db.UserFlags.Remove(flag);
// If there's no new flags to set, we're done
if (flagIds.Length == 0) return null;
if (flagIds.Length > 100) return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length);
if (flagIds.Length == 0)
return null;
if (flagIds.Length > 100)
return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length);
var flags = await db.GetFlagsAsync(userId);
var unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray();
@ -34,22 +41,32 @@ public static class FlagQueryExtensions
return null;
}
public static async Task<ValidationError?> SetMemberFlagsAsync(this DatabaseContext db, Snowflake userId,
Snowflake memberId, Snowflake[] flagIds)
public static async Task<ValidationError?> SetMemberFlagsAsync(
this DatabaseContext db,
Snowflake userId,
Snowflake memberId,
Snowflake[] flagIds
)
{
var currentFlags = await db.MemberFlags.Where(f => f.MemberId == memberId).ToListAsync();
foreach (var flag in currentFlags)
db.MemberFlags.Remove(flag);
if (flagIds.Length == 0) return null;
if (flagIds.Length > 100) return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length);
if (flagIds.Length == 0)
return null;
if (flagIds.Length > 100)
return ValidationError.LengthError("Too many profile flags", 0, 100, flagIds.Length);
var flags = await db.GetFlagsAsync(userId);
var unknownFlagIds = flagIds.Where(id => flags.All(f => f.Id != id)).ToArray();
if (unknownFlagIds.Length != 0)
return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds);
var memberFlags = flagIds.Select(id => new MemberFlag { PrideFlagId = id, MemberId = memberId });
var memberFlags = flagIds.Select(id => new MemberFlag
{
PrideFlagId = id,
MemberId = memberId,
});
db.MemberFlags.AddRange(memberFlags);
return null;

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime;
#nullable disable
@ -22,12 +22,13 @@ namespace Foxnouns.Backend.Database.Migrations
domain = table.Column<string>(type: "text", nullable: false),
client_id = table.Column<string>(type: "text", nullable: false),
client_secret = table.Column<string>(type: "text", nullable: false),
instance_type = table.Column<int>(type: "integer", nullable: false)
instance_type = table.Column<int>(type: "integer", nullable: false),
},
constraints: table =>
{
table.PrimaryKey("pk_fediverse_applications", x => x.id);
});
}
);
migrationBuilder.CreateTable(
name: "users",
@ -43,12 +44,13 @@ namespace Foxnouns.Backend.Database.Migrations
role = table.Column<int>(type: "integer", nullable: false),
fields = table.Column<string>(type: "jsonb", nullable: false),
names = table.Column<string>(type: "jsonb", nullable: false),
pronouns = table.Column<string>(type: "jsonb", nullable: false)
pronouns = table.Column<string>(type: "jsonb", nullable: false),
},
constraints: table =>
{
table.PrimaryKey("pk_users", x => x.id);
});
}
);
migrationBuilder.CreateTable(
name: "auth_methods",
@ -59,7 +61,7 @@ namespace Foxnouns.Backend.Database.Migrations
remote_id = table.Column<string>(type: "text", nullable: false),
remote_username = table.Column<string>(type: "text", nullable: true),
user_id = table.Column<long>(type: "bigint", nullable: false),
fediverse_application_id = table.Column<long>(type: "bigint", nullable: true)
fediverse_application_id = table.Column<long>(type: "bigint", nullable: true),
},
constraints: table =>
{
@ -68,14 +70,17 @@ namespace Foxnouns.Backend.Database.Migrations
name: "fk_auth_methods_fediverse_applications_fediverse_application_id",
column: x => x.fediverse_application_id,
principalTable: "fediverse_applications",
principalColumn: "id");
principalColumn: "id"
);
table.ForeignKey(
name: "fk_auth_methods_users_user_id",
column: x => x.user_id,
principalTable: "users",
principalColumn: "id",
onDelete: ReferentialAction.Cascade);
});
onDelete: ReferentialAction.Cascade
);
}
);
migrationBuilder.CreateTable(
name: "members",
@ -91,7 +96,7 @@ namespace Foxnouns.Backend.Database.Migrations
user_id = table.Column<long>(type: "bigint", nullable: false),
fields = table.Column<string>(type: "jsonb", nullable: false),
names = table.Column<string>(type: "jsonb", nullable: false),
pronouns = table.Column<string>(type: "jsonb", nullable: false)
pronouns = table.Column<string>(type: "jsonb", nullable: false),
},
constraints: table =>
{
@ -101,18 +106,23 @@ namespace Foxnouns.Backend.Database.Migrations
column: x => x.user_id,
principalTable: "users",
principalColumn: "id",
onDelete: ReferentialAction.Cascade);
});
onDelete: ReferentialAction.Cascade
);
}
);
migrationBuilder.CreateTable(
name: "tokens",
columns: table => new
{
id = table.Column<long>(type: "bigint", nullable: false),
expires_at = table.Column<Instant>(type: "timestamp with time zone", nullable: false),
expires_at = table.Column<Instant>(
type: "timestamp with time zone",
nullable: false
),
scopes = table.Column<string[]>(type: "text[]", nullable: false),
manually_expired = table.Column<bool>(type: "boolean", nullable: false),
user_id = table.Column<long>(type: "bigint", nullable: false)
user_id = table.Column<long>(type: "bigint", nullable: false),
},
constraints: table =>
{
@ -122,53 +132,56 @@ namespace Foxnouns.Backend.Database.Migrations
column: x => x.user_id,
principalTable: "users",
principalColumn: "id",
onDelete: ReferentialAction.Cascade);
});
onDelete: ReferentialAction.Cascade
);
}
);
migrationBuilder.CreateIndex(
name: "ix_auth_methods_fediverse_application_id",
table: "auth_methods",
column: "fediverse_application_id");
column: "fediverse_application_id"
);
migrationBuilder.CreateIndex(
name: "ix_auth_methods_user_id",
table: "auth_methods",
column: "user_id");
column: "user_id"
);
// EF Core doesn't support creating indexes on arbitrary expressions, so we have to create it manually.
// Due to historical reasons (I made a mistake while writing the initial migration for the Go version)
// only members have case-insensitive names.
migrationBuilder.Sql("CREATE UNIQUE INDEX ix_members_user_id_name ON members (user_id, lower(name))");
migrationBuilder.Sql(
"CREATE UNIQUE INDEX ix_members_user_id_name ON members (user_id, lower(name))"
);
migrationBuilder.CreateIndex(
name: "ix_tokens_user_id",
table: "tokens",
column: "user_id");
column: "user_id"
);
migrationBuilder.CreateIndex(
name: "ix_users_username",
table: "users",
column: "username",
unique: true);
unique: true
);
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropTable(
name: "auth_methods");
migrationBuilder.DropTable(name: "auth_methods");
migrationBuilder.DropTable(
name: "members");
migrationBuilder.DropTable(name: "members");
migrationBuilder.DropTable(
name: "tokens");
migrationBuilder.DropTable(name: "tokens");
migrationBuilder.DropTable(
name: "fediverse_applications");
migrationBuilder.DropTable(name: "fediverse_applications");
migrationBuilder.DropTable(
name: "users");
migrationBuilder.DropTable(name: "users");
}
}
}

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
@ -18,14 +18,16 @@ namespace Foxnouns.Backend.Database.Migrations
table: "tokens",
type: "bigint",
nullable: false,
defaultValue: 0L);
defaultValue: 0L
);
migrationBuilder.AddColumn<byte[]>(
name: "hash",
table: "tokens",
type: "bytea",
nullable: false,
defaultValue: new byte[0]);
defaultValue: new byte[0]
);
migrationBuilder.CreateTable(
name: "applications",
@ -36,17 +38,19 @@ namespace Foxnouns.Backend.Database.Migrations
client_secret = table.Column<string>(type: "text", nullable: false),
name = table.Column<string>(type: "text", nullable: false),
scopes = table.Column<string[]>(type: "text[]", nullable: false),
redirect_uris = table.Column<string[]>(type: "text[]", nullable: false)
redirect_uris = table.Column<string[]>(type: "text[]", nullable: false),
},
constraints: table =>
{
table.PrimaryKey("pk_applications", x => x.id);
});
}
);
migrationBuilder.CreateIndex(
name: "ix_tokens_application_id",
table: "tokens",
column: "application_id");
column: "application_id"
);
migrationBuilder.AddForeignKey(
name: "fk_tokens_applications_application_id",
@ -54,7 +58,8 @@ namespace Foxnouns.Backend.Database.Migrations
column: "application_id",
principalTable: "applications",
principalColumn: "id",
onDelete: ReferentialAction.Cascade);
onDelete: ReferentialAction.Cascade
);
}
/// <inheritdoc />
@ -62,22 +67,16 @@ namespace Foxnouns.Backend.Database.Migrations
{
migrationBuilder.DropForeignKey(
name: "fk_tokens_applications_application_id",
table: "tokens");
table: "tokens"
);
migrationBuilder.DropTable(
name: "applications");
migrationBuilder.DropTable(name: "applications");
migrationBuilder.DropIndex(
name: "ix_tokens_application_id",
table: "tokens");
migrationBuilder.DropIndex(name: "ix_tokens_application_id", table: "tokens");
migrationBuilder.DropColumn(
name: "application_id",
table: "tokens");
migrationBuilder.DropColumn(name: "application_id", table: "tokens");
migrationBuilder.DropColumn(
name: "hash",
table: "tokens");
migrationBuilder.DropColumn(name: "hash", table: "tokens");
}
}
}

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
@ -18,15 +18,14 @@ namespace Foxnouns.Backend.Database.Migrations
table: "users",
type: "boolean",
nullable: false,
defaultValue: false);
defaultValue: false
);
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropColumn(
name: "list_hidden",
table: "users");
migrationBuilder.DropColumn(name: "list_hidden", table: "users");
}
}
}

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
@ -17,15 +17,14 @@ namespace Foxnouns.Backend.Database.Migrations
name: "password",
table: "users",
type: "text",
nullable: true);
nullable: true
);
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropColumn(
name: "password",
table: "users");
migrationBuilder.DropColumn(name: "password", table: "users");
}
}
}

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime;
using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata;
@ -19,29 +19,37 @@ namespace Foxnouns.Backend.Database.Migrations
name: "temporary_keys",
columns: table => new
{
id = table.Column<long>(type: "bigint", nullable: false)
.Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn),
id = table
.Column<long>(type: "bigint", nullable: false)
.Annotation(
"Npgsql:ValueGenerationStrategy",
NpgsqlValueGenerationStrategy.IdentityByDefaultColumn
),
key = table.Column<string>(type: "text", nullable: false),
value = table.Column<string>(type: "text", nullable: false),
expires = table.Column<Instant>(type: "timestamp with time zone", nullable: false)
expires = table.Column<Instant>(
type: "timestamp with time zone",
nullable: false
),
},
constraints: table =>
{
table.PrimaryKey("pk_temporary_keys", x => x.id);
});
}
);
migrationBuilder.CreateIndex(
name: "ix_temporary_keys_key",
table: "temporary_keys",
column: "key",
unique: true);
unique: true
);
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropTable(
name: "temporary_keys");
migrationBuilder.DropTable(name: "temporary_keys");
}
}
}

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime;
#nullable disable
@ -19,15 +19,14 @@ namespace Foxnouns.Backend.Database.Migrations
table: "users",
type: "timestamp with time zone",
nullable: false,
defaultValueSql: "now()");
defaultValueSql: "now()"
);
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropColumn(
name: "last_active",
table: "users");
migrationBuilder.DropColumn(name: "last_active", table: "users");
}
}
}

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime;
#nullable disable
@ -19,35 +19,32 @@ namespace Foxnouns.Backend.Database.Migrations
table: "users",
type: "boolean",
nullable: false,
defaultValue: false);
defaultValue: false
);
migrationBuilder.AddColumn<Instant>(
name: "deleted_at",
table: "users",
type: "timestamp with time zone",
nullable: true);
nullable: true
);
migrationBuilder.AddColumn<long>(
name: "deleted_by",
table: "users",
type: "bigint",
nullable: true);
nullable: true
);
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropColumn(
name: "deleted",
table: "users");
migrationBuilder.DropColumn(name: "deleted", table: "users");
migrationBuilder.DropColumn(
name: "deleted_at",
table: "users");
migrationBuilder.DropColumn(name: "deleted_at", table: "users");
migrationBuilder.DropColumn(
name: "deleted_by",
table: "users");
migrationBuilder.DropColumn(name: "deleted_by", table: "users");
}
}
}

View file

@ -1,7 +1,7 @@
using System;
using Microsoft.EntityFrameworkCore.Infrastructure;
using System.Collections.Generic;
using Foxnouns.Backend.Database.Models;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
@ -21,15 +21,14 @@ namespace Foxnouns.Backend.Database.Migrations
table: "users",
type: "jsonb",
nullable: false,
defaultValueSql: "'{}'");
defaultValueSql: "'{}'"
);
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropColumn(
name: "custom_preferences",
table: "users");
migrationBuilder.DropColumn(name: "custom_preferences", table: "users");
}
}
}

View file

@ -19,15 +19,14 @@ namespace Foxnouns.Backend.Database.Migrations
table: "users",
type: "jsonb",
nullable: false,
defaultValueSql: "'{}'");
defaultValueSql: "'{}'"
);
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropColumn(
name: "settings",
table: "users");
migrationBuilder.DropColumn(name: "settings", table: "users");
}
}
}

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime;
#nullable disable
@ -18,38 +18,46 @@ namespace Foxnouns.Backend.Database.Migrations
name: "sid",
table: "users",
type: "text",
nullable: true);
nullable: true
);
migrationBuilder.AddColumn<Instant>(
name: "last_sid_reroll",
table: "users",
type: "timestamp with time zone",
nullable: false,
defaultValueSql: "now() - '1 hour'::interval");
defaultValueSql: "now() - '1 hour'::interval"
);
migrationBuilder.AddColumn<string>(
name: "sid",
table: "members",
type: "text",
nullable: true);
nullable: true
);
migrationBuilder.CreateIndex(
name: "ix_users_sid",
table: "users",
column: "sid",
unique: true);
unique: true
);
migrationBuilder.CreateIndex(
name: "ix_members_sid",
table: "members",
column: "sid",
unique: true);
unique: true
);
migrationBuilder.Sql(@"create function generate_sid(len int) returns text as $$
migrationBuilder.Sql(
@"create function generate_sid(len int) returns text as $$
select string_agg(substr('abcdefghijklmnopqrstuvwxyz', ceil(random() * 26)::integer, 1), '') from generate_series(1, len)
$$ language sql volatile;
");
migrationBuilder.Sql(@"create function find_free_user_sid() returns text as $$
"
);
migrationBuilder.Sql(
@"create function find_free_user_sid() returns text as $$
declare new_sid text;
begin
loop
@ -58,8 +66,10 @@ begin
end loop;
end
$$ language plpgsql volatile;
");
migrationBuilder.Sql(@"create function find_free_member_sid() returns text as $$
"
);
migrationBuilder.Sql(
@"create function find_free_member_sid() returns text as $$
declare new_sid text;
begin
loop
@ -67,7 +77,8 @@ begin
if not exists (select 1 from members where sid = new_sid) then return new_sid; end if;
end loop;
end
$$ language plpgsql volatile;");
$$ language plpgsql volatile;"
);
}
/// <inheritdoc />
@ -77,25 +88,15 @@ $$ language plpgsql volatile;");
migrationBuilder.Sql("drop function find_free_user_sid;");
migrationBuilder.Sql("drop function generate_sid;");
migrationBuilder.DropIndex(
name: "ix_users_sid",
table: "users");
migrationBuilder.DropIndex(name: "ix_users_sid", table: "users");
migrationBuilder.DropIndex(
name: "ix_members_sid",
table: "members");
migrationBuilder.DropIndex(name: "ix_members_sid", table: "members");
migrationBuilder.DropColumn(
name: "sid",
table: "users");
migrationBuilder.DropColumn(name: "sid", table: "users");
migrationBuilder.DropColumn(
name: "last_sid_reroll",
table: "users");
migrationBuilder.DropColumn(name: "last_sid_reroll", table: "users");
migrationBuilder.DropColumn(
name: "sid",
table: "members");
migrationBuilder.DropColumn(name: "sid", table: "members");
}
}
}

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
using NodaTime;
#nullable disable
@ -22,7 +22,8 @@ namespace Foxnouns.Backend.Database.Migrations
defaultValueSql: "find_free_user_sid()",
oldClrType: typeof(string),
oldType: "text",
oldNullable: true);
oldNullable: true
);
migrationBuilder.AlterColumn<string>(
name: "sid",
@ -32,7 +33,8 @@ namespace Foxnouns.Backend.Database.Migrations
defaultValueSql: "find_free_member_sid()",
oldClrType: typeof(string),
oldType: "text",
oldNullable: true);
oldNullable: true
);
}
/// <inheritdoc />
@ -45,7 +47,8 @@ namespace Foxnouns.Backend.Database.Migrations
nullable: true,
oldClrType: typeof(string),
oldType: "text",
oldDefaultValueSql: "find_free_user_sid()");
oldDefaultValueSql: "find_free_user_sid()"
);
migrationBuilder.AlterColumn<string>(
name: "sid",
@ -54,7 +57,8 @@ namespace Foxnouns.Backend.Database.Migrations
nullable: true,
oldClrType: typeof(string),
oldType: "text",
oldDefaultValueSql: "find_free_member_sid()");
oldDefaultValueSql: "find_free_member_sid()"
);
}
}
}

View file

@ -1,5 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata;
#nullable disable
@ -22,7 +22,7 @@ namespace Foxnouns.Backend.Database.Migrations
user_id = table.Column<long>(type: "bigint", nullable: false),
hash = table.Column<string>(type: "text", nullable: false),
name = table.Column<string>(type: "text", nullable: false),
description = table.Column<string>(type: "text", nullable: true)
description = table.Column<string>(type: "text", nullable: true),
},
constraints: table =>
{
@ -32,17 +32,23 @@ namespace Foxnouns.Backend.Database.Migrations
column: x => x.user_id,
principalTable: "users",
principalColumn: "id",
onDelete: ReferentialAction.Cascade);
});
onDelete: ReferentialAction.Cascade
);
}
);
migrationBuilder.CreateTable(
name: "member_flags",
columns: table => new
{
id = table.Column<long>(type: "bigint", nullable: false)
.Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn),
id = table
.Column<long>(type: "bigint", nullable: false)
.Annotation(
"Npgsql:ValueGenerationStrategy",
NpgsqlValueGenerationStrategy.IdentityByDefaultColumn
),
member_id = table.Column<long>(type: "bigint", nullable: false),
pride_flag_id = table.Column<long>(type: "bigint", nullable: false)
pride_flag_id = table.Column<long>(type: "bigint", nullable: false),
},
constraints: table =>
{
@ -52,23 +58,30 @@ namespace Foxnouns.Backend.Database.Migrations
column: x => x.member_id,
principalTable: "members",
principalColumn: "id",
onDelete: ReferentialAction.Cascade);
onDelete: ReferentialAction.Cascade
);
table.ForeignKey(
name: "fk_member_flags_pride_flags_pride_flag_id",
column: x => x.pride_flag_id,
principalTable: "pride_flags",
principalColumn: "id",
onDelete: ReferentialAction.Cascade);
});
onDelete: ReferentialAction.Cascade
);
}
);
migrationBuilder.CreateTable(
name: "user_flags",
columns: table => new
{
id = table.Column<long>(type: "bigint", nullable: false)
.Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn),
id = table
.Column<long>(type: "bigint", nullable: false)
.Annotation(
"Npgsql:ValueGenerationStrategy",
NpgsqlValueGenerationStrategy.IdentityByDefaultColumn
),
user_id = table.Column<long>(type: "bigint", nullable: false),
pride_flag_id = table.Column<long>(type: "bigint", nullable: false)
pride_flag_id = table.Column<long>(type: "bigint", nullable: false),
},
constraints: table =>
{
@ -78,52 +91,57 @@ namespace Foxnouns.Backend.Database.Migrations
column: x => x.pride_flag_id,
principalTable: "pride_flags",
principalColumn: "id",
onDelete: ReferentialAction.Cascade);
onDelete: ReferentialAction.Cascade
);
table.ForeignKey(
name: "fk_user_flags_users_user_id",
column: x => x.user_id,
principalTable: "users",
principalColumn: "id",
onDelete: ReferentialAction.Cascade);
});
onDelete: ReferentialAction.Cascade
);
}
);
migrationBuilder.CreateIndex(
name: "ix_member_flags_member_id",
table: "member_flags",
column: "member_id");
column: "member_id"
);
migrationBuilder.CreateIndex(
name: "ix_member_flags_pride_flag_id",
table: "member_flags",
column: "pride_flag_id");
column: "pride_flag_id"
);
migrationBuilder.CreateIndex(
name: "ix_pride_flags_user_id",
table: "pride_flags",
column: "user_id");
column: "user_id"
);
migrationBuilder.CreateIndex(
name: "ix_user_flags_pride_flag_id",
table: "user_flags",
column: "pride_flag_id");
column: "pride_flag_id"
);
migrationBuilder.CreateIndex(
name: "ix_user_flags_user_id",
table: "user_flags",
column: "user_id");
column: "user_id"
);
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropTable(
name: "member_flags");
migrationBuilder.DropTable(name: "member_flags");
migrationBuilder.DropTable(
name: "user_flags");
migrationBuilder.DropTable(name: "user_flags");
migrationBuilder.DropTable(
name: "pride_flags");
migrationBuilder.DropTable(name: "pride_flags");
}
}
}

View file

@ -11,20 +11,30 @@ public class Application : BaseModel
public required string[] Scopes { get; init; }
public required string[] RedirectUris { get; init; }
public static Application Create(ISnowflakeGenerator snowflakeGenerator, string name, string[] scopes,
string[] redirectUrls)
public static Application Create(
ISnowflakeGenerator snowflakeGenerator,
string name,
string[] scopes,
string[] redirectUrls
)
{
var clientId = RandomNumberGenerator.GetHexString(32, true);
var clientSecret = AuthUtils.RandomToken();
if (scopes.Except(AuthUtils.ApplicationScopes).Any())
{
throw new ArgumentException("Invalid scopes passed to Application.Create", nameof(scopes));
throw new ArgumentException(
"Invalid scopes passed to Application.Create",
nameof(scopes)
);
}
if (redirectUrls.Any(s => !AuthUtils.ValidateRedirectUri(s)))
{
throw new ArgumentException("Invalid redirect URLs passed to Application.Create", nameof(redirectUrls));
throw new ArgumentException(
"Invalid redirect URLs passed to Application.Create",
nameof(redirectUrls)
);
}
return new Application
@ -34,7 +44,7 @@ public class Application : BaseModel
ClientSecret = clientSecret,
Name = name,
Scopes = scopes,
RedirectUris = redirectUrls
RedirectUris = redirectUrls,
};
}
}

View file

@ -11,5 +11,5 @@ public class FediverseApplication : BaseModel
public enum FediverseInstanceType
{
MastodonApi,
MisskeyApi
MisskeyApi,
}

View file

@ -37,7 +37,9 @@ public class User : BaseModel
public bool Deleted { get; set; }
public Instant? DeletedAt { get; set; }
public Snowflake? DeletedBy { get; set; }
[NotMapped] public bool? SelfDelete => Deleted ? DeletedBy != null : null;
[NotMapped]
public bool? SelfDelete => Deleted ? DeletedBy != null : null;
public class CustomPreference
{

View file

@ -41,19 +41,26 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
public short Increment => (short)(Value & 0xFFF);
public static bool operator <(Snowflake arg1, Snowflake arg2) => arg1.Value < arg2.Value;
public static bool operator >(Snowflake arg1, Snowflake arg2) => arg1.Value > arg2.Value;
public static bool operator ==(Snowflake arg1, Snowflake arg2) => arg1.Value == arg2.Value;
public static bool operator !=(Snowflake arg1, Snowflake arg2) => arg1.Value != arg2.Value;
public static implicit operator ulong(Snowflake s) => s.Value;
public static implicit operator long(Snowflake s) => (long)s.Value;
public static implicit operator Snowflake(ulong n) => new(n);
public static implicit operator Snowflake(long n) => new((ulong)n);
public static bool TryParse(string input, [NotNullWhen(true)] out Snowflake? snowflake)
{
snowflake = null;
if (!ulong.TryParse(input, out var res)) return false;
if (!ulong.TryParse(input, out var res))
return false;
snowflake = new Snowflake(res);
return true;
}
@ -66,27 +73,37 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
}
public override int GetHashCode() => Value.GetHashCode();
public override string ToString() => Value.ToString();
/// <summary>
/// An Entity Framework ValueConverter for Snowflakes to longs.
/// </summary>
// ReSharper disable once ClassNeverInstantiated.Global
public class ValueConverter() : ValueConverter<Snowflake, long>(
convertToProviderExpression: x => x,
convertFromProviderExpression: x => x
);
public class ValueConverter()
: ValueConverter<Snowflake, long>(
convertToProviderExpression: x => x,
convertFromProviderExpression: x => x
);
private class JsonConverter : JsonConverter<Snowflake>
{
public override void WriteJson(JsonWriter writer, Snowflake value, JsonSerializer serializer)
public override void WriteJson(
JsonWriter writer,
Snowflake value,
JsonSerializer serializer
)
{
writer.WriteValue(value.Value.ToString());
}
public override Snowflake ReadJson(JsonReader reader, Type objectType, Snowflake existingValue,
public override Snowflake ReadJson(
JsonReader reader,
Type objectType,
Snowflake existingValue,
bool hasExistingValue,
JsonSerializer serializer)
JsonSerializer serializer
)
{
return ulong.Parse((string)reader.Value!);
}
@ -97,10 +114,16 @@ public readonly struct Snowflake(ulong value) : IEquatable<Snowflake>
public override bool CanConvertFrom(ITypeDescriptorContext? context, Type sourceType) =>
sourceType == typeof(string);
public override bool CanConvertTo(ITypeDescriptorContext? context, [NotNullWhen(true)] Type? destinationType) =>
destinationType == typeof(Snowflake);
public override bool CanConvertTo(
ITypeDescriptorContext? context,
[NotNullWhen(true)] Type? destinationType
) => destinationType == typeof(Snowflake);
public override object? ConvertFrom(ITypeDescriptorContext? context, CultureInfo? culture, object value)
public override object? ConvertFrom(
ITypeDescriptorContext? context,
CultureInfo? culture,
object value
)
{
return TryParse((string)value, out var snowflake) ? snowflake : null;
}

View file

@ -32,13 +32,19 @@ public class SnowflakeGenerator : ISnowflakeGenerator
var threadId = Environment.CurrentManagedThreadId % 32;
var timestamp = time.Value.ToUnixTimeMilliseconds() - Snowflake.Epoch;
return (timestamp << 22) | (uint)(_processId << 17) | (uint)(threadId << 12) | (increment % 4096);
return (timestamp << 22)
| (uint)(_processId << 17)
| (uint)(threadId << 12)
| (increment % 4096);
}
}
public static class SnowflakeGeneratorServiceExtensions
{
public static IServiceCollection AddSnowflakeGenerator(this IServiceCollection services, int? processId = null)
public static IServiceCollection AddSnowflakeGenerator(
this IServiceCollection services,
int? processId = null
)
{
return services.AddSingleton<ISnowflakeGenerator>(new SnowflakeGenerator(processId));
}

View file

@ -9,39 +9,47 @@ public class FoxnounsError(string message, Exception? inner = null) : Exception(
{
public Exception? Inner => inner;
public class DatabaseError(string message, Exception? inner = null) : FoxnounsError(message, inner);
public class DatabaseError(string message, Exception? inner = null)
: FoxnounsError(message, inner);
public class UnknownEntityError(Type entityType, Exception? inner = null)
: DatabaseError($"Entity of type {entityType.Name} not found", inner);
}
public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCode? errorCode = null)
: FoxnounsError(message)
public class ApiError(
string message,
HttpStatusCode? statusCode = null,
ErrorCode? errorCode = null
) : FoxnounsError(message)
{
public readonly HttpStatusCode StatusCode = statusCode ?? HttpStatusCode.InternalServerError;
public readonly ErrorCode ErrorCode = errorCode ?? ErrorCode.InternalServerError;
public class Unauthorized(string message, ErrorCode errorCode = ErrorCode.AuthenticationError) : ApiError(message,
statusCode: HttpStatusCode.Unauthorized,
errorCode: errorCode);
public class Unauthorized(string message, ErrorCode errorCode = ErrorCode.AuthenticationError)
: ApiError(message, statusCode: HttpStatusCode.Unauthorized, errorCode: errorCode);
public class Forbidden(
string message,
IEnumerable<string>? scopes = null,
ErrorCode errorCode = ErrorCode.Forbidden)
: ApiError(message, statusCode: HttpStatusCode.Forbidden, errorCode: errorCode)
ErrorCode errorCode = ErrorCode.Forbidden
) : ApiError(message, statusCode: HttpStatusCode.Forbidden, errorCode: errorCode)
{
public readonly string[] Scopes = scopes?.ToArray() ?? [];
}
public class BadRequest(string message, IReadOnlyDictionary<string, IEnumerable<ValidationError>>? errors = null)
: ApiError(message, statusCode: HttpStatusCode.BadRequest)
public class BadRequest(
string message,
IReadOnlyDictionary<string, IEnumerable<ValidationError>>? errors = null
) : ApiError(message, statusCode: HttpStatusCode.BadRequest)
{
public BadRequest(string message, string field, object actualValue) : this("Error validating input",
new Dictionary<string, IEnumerable<ValidationError>>
{ { field, [ValidationError.GenericValidationError(message, actualValue)] } })
{
}
public BadRequest(string message, string field, object actualValue)
: this(
"Error validating input",
new Dictionary<string, IEnumerable<ValidationError>>
{
{ field, [ValidationError.GenericValidationError(message, actualValue)] },
}
) { }
public JObject ToJson()
{
@ -49,9 +57,10 @@ public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCo
{
{ "status", (int)HttpStatusCode.BadRequest },
{ "message", Message },
{ "code", "BAD_REQUEST" }
{ "code", "BAD_REQUEST" },
};
if (errors == null) return o;
if (errors == null)
return o;
var a = new JArray();
foreach (var error in errors)
@ -59,7 +68,7 @@ public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCo
var errorObj = new JObject
{
{ "key", error.Key },
{ "errors", JArray.FromObject(error.Value) }
{ "errors", JArray.FromObject(error.Value) },
};
a.Add(errorObj);
}
@ -82,9 +91,10 @@ public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCo
{
{ "status", (int)HttpStatusCode.BadRequest },
{ "message", Message },
{ "code", "BAD_REQUEST" }
{ "code", "BAD_REQUEST" },
};
if (modelState == null) return o;
if (modelState == null)
return o;
var a = new JArray();
foreach (var error in modelState.Where(e => e.Value is { Errors.Count: > 0 }))
@ -94,8 +104,13 @@ public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCo
{ "key", error.Key },
{
"errors",
new JArray(error.Value!.Errors.Select(e => new JObject { { "message", e.ErrorMessage } }))
}
new JArray(
error.Value!.Errors.Select(e => new JObject
{
{ "message", e.ErrorMessage },
})
)
},
};
a.Add(errorObj);
}
@ -108,7 +123,8 @@ public class ApiError(string message, HttpStatusCode? statusCode = null, ErrorCo
public class NotFound(string message, ErrorCode? code = null)
: ApiError(message, statusCode: HttpStatusCode.NotFound, errorCode: code);
public class AuthenticationError(string message) : ApiError(message, statusCode: HttpStatusCode.BadRequest);
public class AuthenticationError(string message)
: ApiError(message, statusCode: HttpStatusCode.BadRequest);
}
public enum ErrorCode
@ -143,34 +159,38 @@ public class ValidationError
[JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
public object? ActualValue { get; init; }
public static ValidationError LengthError(string message, int minLength, int maxLength, int actualLength)
public static ValidationError LengthError(
string message,
int minLength,
int maxLength,
int actualLength
)
{
return new ValidationError
{
Message = message,
MinLength = minLength,
MaxLength = maxLength,
ActualLength = actualLength
ActualLength = actualLength,
};
}
public static ValidationError DisallowedValueError(string message, IEnumerable<object> allowedValues,
object actualValue)
public static ValidationError DisallowedValueError(
string message,
IEnumerable<object> allowedValues,
object actualValue
)
{
return new ValidationError
{
Message = message,
AllowedValues = allowedValues,
ActualValue = actualValue
ActualValue = actualValue,
};
}
public static ValidationError GenericValidationError(string message, object? actualValue)
{
return new ValidationError
{
Message = message,
ActualValue = actualValue
};
return new ValidationError { Message = message, ActualValue = actualValue };
}
}

View file

@ -14,21 +14,35 @@ public static class AvatarObjectExtensions
{
private static readonly string[] ValidContentTypes = ["image/png", "image/webp", "image/jpeg"];
public static async Task
DeleteMemberAvatarAsync(this ObjectStorageService objectStorageService, Snowflake id, string hash,
CancellationToken ct = default) =>
await objectStorageService.RemoveObjectAsync(MemberAvatarUpdateInvocable.Path(id, hash), ct);
public static async Task DeleteMemberAvatarAsync(
this ObjectStorageService objectStorageService,
Snowflake id,
string hash,
CancellationToken ct = default
) =>
await objectStorageService.RemoveObjectAsync(
MemberAvatarUpdateInvocable.Path(id, hash),
ct
);
public static async Task
DeleteUserAvatarAsync(this ObjectStorageService objectStorageService, Snowflake id, string hash,
CancellationToken ct = default) =>
await objectStorageService.RemoveObjectAsync(UserAvatarUpdateInvocable.Path(id, hash), ct);
public static async Task DeleteUserAvatarAsync(
this ObjectStorageService objectStorageService,
Snowflake id,
string hash,
CancellationToken ct = default
) => await objectStorageService.RemoveObjectAsync(UserAvatarUpdateInvocable.Path(id, hash), ct);
public static async Task DeleteFlagAsync(this ObjectStorageService objectStorageService, string hash,
CancellationToken ct = default) =>
await objectStorageService.RemoveObjectAsync(CreateFlagInvocable.Path(hash), ct);
public static async Task DeleteFlagAsync(
this ObjectStorageService objectStorageService,
string hash,
CancellationToken ct = default
) => await objectStorageService.RemoveObjectAsync(CreateFlagInvocable.Path(hash), ct);
public static async Task<(string Hash, Stream Image)> ConvertBase64UriToImage(this string uri, int size, bool crop)
public static async Task<(string Hash, Stream Image)> ConvertBase64UriToImage(
this string uri,
int size,
bool crop
)
{
if (!uri.StartsWith("data:image/"))
throw new ArgumentException("Not a data URI", nameof(uri));
@ -49,7 +63,7 @@ public static class AvatarObjectExtensions
{
Size = new Size(size),
Mode = crop ? ResizeMode.Crop : ResizeMode.Max,
Position = AnchorPositionMode.Center
Position = AnchorPositionMode.Center,
},
image.Size
);

View file

@ -8,37 +8,58 @@ namespace Foxnouns.Backend.Extensions;
public static class KeyCacheExtensions
{
public static async Task<string> GenerateAuthStateAsync(this KeyCacheService keyCacheService,
CancellationToken ct = default)
public static async Task<string> GenerateAuthStateAsync(
this KeyCacheService keyCacheService,
CancellationToken ct = default
)
{
var state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_');
await keyCacheService.SetKeyAsync($"oauth_state:{state}", "", Duration.FromMinutes(10), ct);
return state;
}
public static async Task ValidateAuthStateAsync(this KeyCacheService keyCacheService, string state,
CancellationToken ct = default)
public static async Task ValidateAuthStateAsync(
this KeyCacheService keyCacheService,
string state,
CancellationToken ct = default
)
{
var val = await keyCacheService.GetKeyAsync($"oauth_state:{state}", delete: true, ct);
if (val == null) throw new ApiError.BadRequest("Invalid OAuth state");
if (val == null)
throw new ApiError.BadRequest("Invalid OAuth state");
}
public static async Task<string> GenerateRegisterEmailStateAsync(this KeyCacheService keyCacheService, string email,
Snowflake? userId = null, CancellationToken ct = default)
public static async Task<string> GenerateRegisterEmailStateAsync(
this KeyCacheService keyCacheService,
string email,
Snowflake? userId = null,
CancellationToken ct = default
)
{
// This state is used in links, not just as JSON values, so make it URL-safe
var state = AuthUtils.RandomToken().Replace('+', '-').Replace('/', '_');
await keyCacheService.SetKeyAsync($"email_state:{state}", new RegisterEmailState(email, userId),
Duration.FromDays(1), ct);
await keyCacheService.SetKeyAsync(
$"email_state:{state}",
new RegisterEmailState(email, userId),
Duration.FromDays(1),
ct
);
return state;
}
public static async Task<RegisterEmailState?> GetRegisterEmailStateAsync(this KeyCacheService keyCacheService,
string state, CancellationToken ct = default) =>
await keyCacheService.GetKeyAsync<RegisterEmailState>($"email_state:{state}", delete: true, ct);
public static async Task<RegisterEmailState?> GetRegisterEmailStateAsync(
this KeyCacheService keyCacheService,
string state,
CancellationToken ct = default
) =>
await keyCacheService.GetKeyAsync<RegisterEmailState>(
$"email_state:{state}",
delete: true,
ct
);
}
public record RegisterEmailState(
string Email,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
Snowflake? ExistingUserId);
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] Snowflake? ExistingUserId
);

View file

@ -29,8 +29,10 @@ public static class WebApplicationExtensions
// ASP.NET's built in request logs are extremely verbose, so we use Serilog's instead.
// Serilog doesn't disable the built-in logs, so we do it here.
.MinimumLevel.Override("Microsoft", LogEventLevel.Information)
.MinimumLevel.Override("Microsoft.EntityFrameworkCore.Database.Command",
config.Logging.LogQueries ? LogEventLevel.Information : LogEventLevel.Fatal)
.MinimumLevel.Override(
"Microsoft.EntityFrameworkCore.Database.Command",
config.Logging.LogQueries ? LogEventLevel.Information : LogEventLevel.Fatal
)
.MinimumLevel.Override("Microsoft.AspNetCore.Hosting", LogEventLevel.Warning)
.MinimumLevel.Override("Microsoft.AspNetCore.Mvc", LogEventLevel.Warning)
.MinimumLevel.Override("Microsoft.AspNetCore.Routing", LogEventLevel.Warning)
@ -38,7 +40,10 @@ public static class WebApplicationExtensions
if (config.Logging.SeqLogUrl != null)
{
logCfg.WriteTo.Seq(config.Logging.SeqLogUrl, restrictedToMinimumLevel: LogEventLevel.Verbose);
logCfg.WriteTo.Seq(
config.Logging.SeqLogUrl,
restrictedToMinimumLevel: LogEventLevel.Verbose
);
}
// AddSerilog doesn't seem to add an ILogger to the service collection, so add that manually.
@ -74,63 +79,74 @@ public static class WebApplicationExtensions
/// </summary>
public static IServiceCollection AddServices(this WebApplicationBuilder builder, Config config)
{
builder.Host.ConfigureServices((ctx, services) =>
{
services
.AddQueue()
.AddSmtpMailer(ctx.Configuration)
.AddDbContext<DatabaseContext>()
.AddMetricServer(o => o.Port = config.Logging.MetricsPort)
.AddMinio(c =>
c.WithEndpoint(config.Storage.Endpoint)
.WithCredentials(config.Storage.AccessKey, config.Storage.SecretKey)
.Build())
.AddSingleton<MetricsCollectionService>()
.AddSingleton<IClock>(SystemClock.Instance)
.AddSnowflakeGenerator()
.AddSingleton<MailService>()
.AddScoped<UserRendererService>()
.AddScoped<MemberRendererService>()
.AddScoped<AuthService>()
.AddScoped<KeyCacheService>()
.AddScoped<RemoteAuthService>()
.AddScoped<ObjectStorageService>()
// Background services
.AddHostedService<PeriodicTasksService>()
// Transient jobs
.AddTransient<MemberAvatarUpdateInvocable>()
.AddTransient<UserAvatarUpdateInvocable>()
.AddTransient<CreateFlagInvocable>();
builder.Host.ConfigureServices(
(ctx, services) =>
{
services
.AddQueue()
.AddSmtpMailer(ctx.Configuration)
.AddDbContext<DatabaseContext>()
.AddMetricServer(o => o.Port = config.Logging.MetricsPort)
.AddMinio(c =>
c.WithEndpoint(config.Storage.Endpoint)
.WithCredentials(config.Storage.AccessKey, config.Storage.SecretKey)
.Build()
)
.AddSingleton<MetricsCollectionService>()
.AddSingleton<IClock>(SystemClock.Instance)
.AddSnowflakeGenerator()
.AddSingleton<MailService>()
.AddScoped<UserRendererService>()
.AddScoped<MemberRendererService>()
.AddScoped<AuthService>()
.AddScoped<KeyCacheService>()
.AddScoped<RemoteAuthService>()
.AddScoped<ObjectStorageService>()
// Background services
.AddHostedService<PeriodicTasksService>()
// Transient jobs
.AddTransient<MemberAvatarUpdateInvocable>()
.AddTransient<UserAvatarUpdateInvocable>()
.AddTransient<CreateFlagInvocable>();
if (!config.Logging.EnableMetrics)
services.AddHostedService<BackgroundMetricsCollectionService>();
});
if (!config.Logging.EnableMetrics)
services.AddHostedService<BackgroundMetricsCollectionService>();
}
);
return builder.Services;
}
public static IServiceCollection AddCustomMiddleware(this IServiceCollection services) => services
.AddScoped<ErrorHandlerMiddleware>()
.AddScoped<AuthenticationMiddleware>()
.AddScoped<AuthorizationMiddleware>();
public static IServiceCollection AddCustomMiddleware(this IServiceCollection services) =>
services
.AddScoped<ErrorHandlerMiddleware>()
.AddScoped<AuthenticationMiddleware>()
.AddScoped<AuthorizationMiddleware>();
public static IApplicationBuilder UseCustomMiddleware(this IApplicationBuilder app) => app
.UseMiddleware<ErrorHandlerMiddleware>()
.UseMiddleware<AuthenticationMiddleware>()
.UseMiddleware<AuthorizationMiddleware>();
public static IApplicationBuilder UseCustomMiddleware(this IApplicationBuilder app) =>
app.UseMiddleware<ErrorHandlerMiddleware>()
.UseMiddleware<AuthenticationMiddleware>()
.UseMiddleware<AuthorizationMiddleware>();
public static async Task Initialize(this WebApplication app, string[] args)
{
// Read version information from .version in the repository root
await BuildInfo.ReadBuildInfo();
app.Services.ConfigureQueue().LogQueuedTaskProgress(app.Services.GetRequiredService<ILogger<IQueue>>());
app.Services.ConfigureQueue()
.LogQueuedTaskProgress(app.Services.GetRequiredService<ILogger<IQueue>>());
await using var scope = app.Services.CreateAsyncScope();
var logger = scope.ServiceProvider.GetRequiredService<ILogger>().ForContext<WebApplication>();
var logger = scope
.ServiceProvider.GetRequiredService<ILogger>()
.ForContext<WebApplication>();
var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
logger.Information("Starting Foxnouns.NET {Version} ({Hash})", BuildInfo.Version, BuildInfo.Hash);
logger.Information(
"Starting Foxnouns.NET {Version} ({Hash})",
BuildInfo.Version,
BuildInfo.Hash
);
var pendingMigrations = (await db.Database.GetPendingMigrationsAsync()).ToList();
if (args.Contains("--migrate") || args.Contains("--migrate-and-start"))
@ -146,13 +162,15 @@ public static class WebApplicationExtensions
logger.Information("Successfully migrated database");
}
if (!args.Contains("--migrate-and-start")) Environment.Exit(0);
if (!args.Contains("--migrate-and-start"))
Environment.Exit(0);
}
else if (pendingMigrations.Count > 0)
{
logger.Fatal(
"There are {Count} pending migrations, run server with --migrate or --migrate-and-start to run migrations.",
pendingMigrations.Count);
pendingMigrations.Count
);
Environment.Exit(1);
}

View file

@ -4,23 +4,35 @@ namespace Foxnouns.Backend;
public static class FoxnounsMetrics
{
public static readonly Gauge UsersCount =
Metrics.CreateGauge("foxnouns_user_count", "Number of total users");
public static readonly Gauge UsersCount = Metrics.CreateGauge(
"foxnouns_user_count",
"Number of total users"
);
public static readonly Gauge UsersActiveMonthCount =
Metrics.CreateGauge("foxnouns_user_count_active_month", "Number of users active in the last month");
public static readonly Gauge UsersActiveMonthCount = Metrics.CreateGauge(
"foxnouns_user_count_active_month",
"Number of users active in the last month"
);
public static readonly Gauge UsersActiveWeekCount =
Metrics.CreateGauge("foxnouns_user_count_active_week", "Number of users active in the last week");
public static readonly Gauge UsersActiveWeekCount = Metrics.CreateGauge(
"foxnouns_user_count_active_week",
"Number of users active in the last week"
);
public static readonly Gauge UsersActiveDayCount =
Metrics.CreateGauge("foxnouns_user_count_active_day", "Number of users active in the last day");
public static readonly Gauge UsersActiveDayCount = Metrics.CreateGauge(
"foxnouns_user_count_active_day",
"Number of users active in the last day"
);
public static readonly Gauge MemberCount =
Metrics.CreateGauge("foxnouns_member_count", "Number of total members");
public static readonly Gauge MemberCount = Metrics.CreateGauge(
"foxnouns_member_count",
"Number of total members"
);
public static readonly Summary MetricsCollectionTime =
Metrics.CreateSummary("foxnouns_time_metrics", "Time it took to collect metrics");
public static readonly Summary MetricsCollectionTime = Metrics.CreateSummary(
"foxnouns_time_metrics",
"Time it took to collect metrics"
);
public static Gauge ProcessPhysicalMemory =>
Metrics.CreateGauge("foxnouns_process_physical_memory", "Process physical memory");
@ -31,7 +43,9 @@ public static class FoxnounsMetrics
public static Gauge ProcessPrivateMemory =>
Metrics.CreateGauge("foxnouns_process_private_memory", "Process private memory");
public static Gauge ProcessThreads => Metrics.CreateGauge("foxnouns_process_threads", "Process thread count");
public static Gauge ProcessThreads =>
Metrics.CreateGauge("foxnouns_process_threads", "Process thread count");
public static Gauge ProcessHandles => Metrics.CreateGauge("foxnouns_process_handles", "Process handle count");
public static Gauge ProcessHandles =>
Metrics.CreateGauge("foxnouns_process_handles", "Process handle count");
}

View file

@ -6,20 +6,30 @@ using Foxnouns.Backend.Services;
namespace Foxnouns.Backend.Jobs;
public class CreateFlagInvocable(DatabaseContext db, ObjectStorageService objectStorageService, ILogger logger)
: IInvocable, IInvocableWithPayload<CreateFlagPayload>
public class CreateFlagInvocable(
DatabaseContext db,
ObjectStorageService objectStorageService,
ILogger logger
) : IInvocable, IInvocableWithPayload<CreateFlagPayload>
{
private readonly ILogger _logger = logger.ForContext<CreateFlagInvocable>();
public required CreateFlagPayload Payload { get; set; }
public async Task Invoke()
{
_logger.Information("Creating flag {FlagId} for user {UserId} with image data length {DataLength}", Payload.Id,
Payload.UserId, Payload.ImageData.Length);
_logger.Information(
"Creating flag {FlagId} for user {UserId} with image data length {DataLength}",
Payload.Id,
Payload.UserId,
Payload.ImageData.Length
);
try
{
var (hash, image) = await Payload.ImageData.ConvertBase64UriToImage(size: 256, crop: false);
var (hash, image) = await Payload.ImageData.ConvertBase64UriToImage(
size: 256,
crop: false
);
await objectStorageService.PutObjectAsync(Path(hash), image, "image/webp");
var flag = new PrideFlag
@ -28,7 +38,7 @@ public class CreateFlagInvocable(DatabaseContext db, ObjectStorageService object
UserId = Payload.UserId,
Hash = hash,
Name = Payload.Name,
Description = Payload.Description
Description = Payload.Description,
};
db.Add(flag);

View file

@ -6,16 +6,21 @@ using Foxnouns.Backend.Services;
namespace Foxnouns.Backend.Jobs;
public class MemberAvatarUpdateInvocable(DatabaseContext db, ObjectStorageService objectStorageService, ILogger logger)
: IInvocable, IInvocableWithPayload<AvatarUpdatePayload>
public class MemberAvatarUpdateInvocable(
DatabaseContext db,
ObjectStorageService objectStorageService,
ILogger logger
) : IInvocable, IInvocableWithPayload<AvatarUpdatePayload>
{
private readonly ILogger _logger = logger.ForContext<UserAvatarUpdateInvocable>();
public required AvatarUpdatePayload Payload { get; set; }
public async Task Invoke()
{
if (Payload.NewAvatar != null) await UpdateMemberAvatarAsync(Payload.Id, Payload.NewAvatar);
else await ClearMemberAvatarAsync(Payload.Id);
if (Payload.NewAvatar != null)
await UpdateMemberAvatarAsync(Payload.Id, Payload.NewAvatar);
else
await ClearMemberAvatarAsync(Payload.Id);
}
private async Task UpdateMemberAvatarAsync(Snowflake id, string newAvatar)
@ -25,7 +30,10 @@ public class MemberAvatarUpdateInvocable(DatabaseContext db, ObjectStorageServic
var member = await db.Members.FindAsync(id);
if (member == null)
{
_logger.Warning("Update avatar job queued for {MemberId} but no member with that ID exists", id);
_logger.Warning(
"Update avatar job queued for {MemberId} but no member with that ID exists",
id
);
return;
}
@ -46,7 +54,11 @@ public class MemberAvatarUpdateInvocable(DatabaseContext db, ObjectStorageServic
}
catch (ArgumentException ae)
{
_logger.Warning("Invalid data URI for new avatar for member {MemberId}: {Reason}", id, ae.Message);
_logger.Warning(
"Invalid data URI for new avatar for member {MemberId}: {Reason}",
id,
ae.Message
);
}
}
@ -57,7 +69,10 @@ public class MemberAvatarUpdateInvocable(DatabaseContext db, ObjectStorageServic
var member = await db.Members.FindAsync(id);
if (member == null)
{
_logger.Warning("Clear avatar job queued for {MemberId} but no member with that ID exists", id);
_logger.Warning(
"Clear avatar job queued for {MemberId} but no member with that ID exists",
id
);
return;
}

View file

@ -4,4 +4,10 @@ namespace Foxnouns.Backend.Jobs;
public record AvatarUpdatePayload(Snowflake Id, string? NewAvatar);
public record CreateFlagPayload(Snowflake Id, Snowflake UserId, string Name, string ImageData, string? Description);
public record CreateFlagPayload(
Snowflake Id,
Snowflake UserId,
string Name,
string ImageData,
string? Description
);

View file

@ -6,16 +6,21 @@ using Foxnouns.Backend.Services;
namespace Foxnouns.Backend.Jobs;
public class UserAvatarUpdateInvocable(DatabaseContext db, ObjectStorageService objectStorageService, ILogger logger)
: IInvocable, IInvocableWithPayload<AvatarUpdatePayload>
public class UserAvatarUpdateInvocable(
DatabaseContext db,
ObjectStorageService objectStorageService,
ILogger logger
) : IInvocable, IInvocableWithPayload<AvatarUpdatePayload>
{
private readonly ILogger _logger = logger.ForContext<UserAvatarUpdateInvocable>();
public required AvatarUpdatePayload Payload { get; set; }
public async Task Invoke()
{
if (Payload.NewAvatar != null) await UpdateUserAvatarAsync(Payload.Id, Payload.NewAvatar);
else await ClearUserAvatarAsync(Payload.Id);
if (Payload.NewAvatar != null)
await UpdateUserAvatarAsync(Payload.Id, Payload.NewAvatar);
else
await ClearUserAvatarAsync(Payload.Id);
}
private async Task UpdateUserAvatarAsync(Snowflake id, string newAvatar)
@ -25,7 +30,10 @@ public class UserAvatarUpdateInvocable(DatabaseContext db, ObjectStorageService
var user = await db.Users.FindAsync(id);
if (user == null)
{
_logger.Warning("Update avatar job queued for {UserId} but no user with that ID exists", id);
_logger.Warning(
"Update avatar job queued for {UserId} but no user with that ID exists",
id
);
return;
}
@ -47,7 +55,11 @@ public class UserAvatarUpdateInvocable(DatabaseContext db, ObjectStorageService
}
catch (ArgumentException ae)
{
_logger.Warning("Invalid data URI for new avatar for user {UserId}: {Reason}", id, ae.Message);
_logger.Warning(
"Invalid data URI for new avatar for user {UserId}: {Reason}",
id,
ae.Message
);
}
}
@ -58,7 +70,10 @@ public class UserAvatarUpdateInvocable(DatabaseContext db, ObjectStorageService
var user = await db.Users.FindAsync(id);
if (user == null)
{
_logger.Warning("Clear avatar job queued for {UserId} but no user with that ID exists", id);
_logger.Warning(
"Clear avatar job queued for {UserId} but no user with that ID exists",
id
);
return;
}

View file

@ -7,9 +7,7 @@ public class AccountCreationMailable(Config config, AccountCreationMailableView
{
public override void Build()
{
To(view.To)
.From(config.EmailAuth.From!)
.View("~/Views/Mail/AccountCreation.cshtml", view);
To(view.To).From(config.EmailAuth.From!).View("~/Views/Mail/AccountCreation.cshtml", view);
}
}

View file

@ -17,7 +17,9 @@ public class AuthenticationMiddleware(DatabaseContext db) : IMiddleware
return;
}
if (!AuthUtils.TryParseToken(ctx.Request.Headers.Authorization.ToString(), out var rawToken))
if (
!AuthUtils.TryParseToken(ctx.Request.Headers.Authorization.ToString(), out var rawToken)
)
{
await next(ctx);
return;
@ -40,6 +42,7 @@ public static class HttpContextExtensions
private const string Key = "token";
public static void SetToken(this HttpContext ctx, Token token) => ctx.Items.Add(Key, token);
public static User? GetUser(this HttpContext ctx) => ctx.GetToken()?.User;
public static User GetUserOrThrow(this HttpContext ctx) =>

View file

@ -18,14 +18,26 @@ public class AuthorizationMiddleware : IMiddleware
var token = ctx.GetToken();
if (token == null)
throw new ApiError.Unauthorized("This endpoint requires an authenticated user.",
ErrorCode.AuthenticationRequired);
if (attribute.Scopes.Length > 0 && attribute.Scopes.Except(token.Scopes.ExpandScopes()).Any())
throw new ApiError.Forbidden("This endpoint requires ungranted scopes.",
attribute.Scopes.Except(token.Scopes.ExpandScopes()), ErrorCode.MissingScopes);
throw new ApiError.Unauthorized(
"This endpoint requires an authenticated user.",
ErrorCode.AuthenticationRequired
);
if (
attribute.Scopes.Length > 0
&& attribute.Scopes.Except(token.Scopes.ExpandScopes()).Any()
)
throw new ApiError.Forbidden(
"This endpoint requires ungranted scopes.",
attribute.Scopes.Except(token.Scopes.ExpandScopes()),
ErrorCode.MissingScopes
);
if (attribute.RequireAdmin && token.User.Role != UserRole.Admin)
throw new ApiError.Forbidden("This endpoint can only be used by admins.");
if (attribute.RequireModerator && token.User.Role != UserRole.Admin && token.User.Role != UserRole.Moderator)
if (
attribute.RequireModerator
&& token.User.Role != UserRole.Admin
&& token.User.Role != UserRole.Moderator
)
throw new ApiError.Forbidden("This endpoint can only be used by moderators.");
await next(ctx);

View file

@ -21,19 +21,26 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
if (ctx.Response.HasStarted)
{
logger.Error(e, "Error in {ClassName} ({Path}) after response started being sent", typeName,
ctx.Request.Path);
logger.Error(
e,
"Error in {ClassName} ({Path}) after response started being sent",
typeName,
ctx.Request.Path
);
sentry.CaptureException(e, scope =>
{
var user = ctx.GetUser();
if (user != null)
scope.User = new SentryUser
{
Id = user.Id.ToString(),
Username = user.Username
};
});
sentry.CaptureException(
e,
scope =>
{
var user = ctx.GetUser();
if (user != null)
scope.User = new SentryUser
{
Id = user.Id.ToString(),
Username = user.Username,
};
}
);
return;
}
@ -45,13 +52,17 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
ctx.Response.ContentType = "application/json; charset=utf-8";
if (ae is ApiError.Forbidden fe)
{
await ctx.Response.WriteAsync(JsonConvert.SerializeObject(new HttpApiError
{
Status = (int)fe.StatusCode,
Code = ErrorCode.Forbidden,
Message = fe.Message,
Scopes = fe.Scopes.Length > 0 ? fe.Scopes : null
}));
await ctx.Response.WriteAsync(
JsonConvert.SerializeObject(
new HttpApiError
{
Status = (int)fe.StatusCode,
Code = ErrorCode.Forbidden,
Message = fe.Message,
Scopes = fe.Scopes.Length > 0 ? fe.Scopes : null,
}
)
);
return;
}
@ -61,45 +72,61 @@ public class ErrorHandlerMiddleware(ILogger baseLogger, IHub sentry) : IMiddlewa
return;
}
await ctx.Response.WriteAsync(JsonConvert.SerializeObject(new HttpApiError
{
Status = (int)ae.StatusCode,
Code = ae.ErrorCode,
Message = ae.Message,
}));
await ctx.Response.WriteAsync(
JsonConvert.SerializeObject(
new HttpApiError
{
Status = (int)ae.StatusCode,
Code = ae.ErrorCode,
Message = ae.Message,
}
)
);
return;
}
if (e is FoxnounsError fce)
{
logger.Error(fce.Inner ?? fce, "Exception in {ClassName} ({Path})", typeName, ctx.Request.Path);
logger.Error(
fce.Inner ?? fce,
"Exception in {ClassName} ({Path})",
typeName,
ctx.Request.Path
);
}
else
{
logger.Error(e, "Exception in {ClassName} ({Path})", typeName, ctx.Request.Path);
}
var errorId = sentry.CaptureException(e, scope =>
{
var user = ctx.GetUser();
if (user != null)
scope.User = new SentryUser
{
Id = user.Id.ToString(),
Username = user.Username
};
});
var errorId = sentry.CaptureException(
e,
scope =>
{
var user = ctx.GetUser();
if (user != null)
scope.User = new SentryUser
{
Id = user.Id.ToString(),
Username = user.Username,
};
}
);
ctx.Response.StatusCode = (int)HttpStatusCode.InternalServerError;
ctx.Response.Headers.RequestId = ctx.TraceIdentifier;
ctx.Response.ContentType = "application/json; charset=utf-8";
await ctx.Response.WriteAsync(JsonConvert.SerializeObject(new HttpApiError
{
Status = (int)HttpStatusCode.InternalServerError,
Code = ErrorCode.InternalServerError,
ErrorId = errorId.ToString(),
Message = "Internal server error",
}));
await ctx.Response.WriteAsync(
JsonConvert.SerializeObject(
new HttpApiError
{
Status = (int)HttpStatusCode.InternalServerError,
Code = ErrorCode.InternalServerError,
ErrorId = errorId.ToString(),
Message = "Internal server error",
}
)
);
}
}
}

View file

@ -1,5 +1,4 @@
using Foxnouns.Backend;
using Serilog;
using Foxnouns.Backend.Extensions;
using Foxnouns.Backend.Services;
using Foxnouns.Backend.Utils;
@ -8,6 +7,7 @@ using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using Prometheus;
using Sentry.Extensibility;
using Serilog;
var builder = WebApplication.CreateBuilder(args);
@ -15,8 +15,8 @@ var config = builder.AddConfiguration();
builder.AddSerilog();
builder.WebHost
.UseSentry(opts =>
builder
.WebHost.UseSentry(opts =>
{
opts.Dsn = config.Logging.SentryUrl;
opts.TracesSampleRate = config.Logging.SentryTracesSampleRate;
@ -30,13 +30,13 @@ builder.WebHost
opts.Limits.MaxRequestBodySize = 2 * 1024 * 1024;
});
builder.Services
.AddControllers()
builder
.Services.AddControllers()
.AddNewtonsoftJson(options =>
{
options.SerializerSettings.ContractResolver = new PatchRequestContractResolver
{
NamingStrategy = new SnakeCaseNamingStrategy()
NamingStrategy = new SnakeCaseNamingStrategy(),
};
})
.ConfigureApiBehaviorOptions(options =>
@ -47,18 +47,16 @@ builder.Services
});
// Set the default converter to snake case as we use it in a couple places.
JsonConvert.DefaultSettings = () => new JsonSerializerSettings
{
ContractResolver = new DefaultContractResolver
JsonConvert.DefaultSettings = () =>
new JsonSerializerSettings
{
NamingStrategy = new SnakeCaseNamingStrategy()
}
};
ContractResolver = new DefaultContractResolver
{
NamingStrategy = new SnakeCaseNamingStrategy(),
},
};
builder.AddServices(config)
.AddCustomMiddleware()
.AddEndpointsApiExplorer()
.AddSwaggerGen();
builder.AddServices(config).AddCustomMiddleware().AddEndpointsApiExplorer().AddSwaggerGen();
var app = builder.Build();
@ -66,9 +64,11 @@ await app.Initialize(args);
app.UseSerilogRequestLogging();
app.UseRouting();
// Not all environments will want tracing (from experience, it's expensive to use in production, even with a low sample rate),
// so it's locked behind a config option.
if (config.Logging.SentryTracing) app.UseSentryTracing();
if (config.Logging.SentryTracing)
app.UseSentryTracing();
app.UseSwagger();
app.UseSwaggerUI();
app.UseCors();
@ -80,7 +80,8 @@ app.Urls.Add(config.Address);
// Make sure metrics are updated whenever Prometheus scrapes them
Metrics.DefaultRegistry.AddBeforeCollectCallback(async ct =>
await app.Services.GetRequiredService<MetricsCollectionService>().CollectMetricsAsync(ct));
await app.Services.GetRequiredService<MetricsCollectionService>().CollectMetricsAsync(ct)
);
app.Run();
Log.CloseAndFlush();

View file

@ -16,8 +16,12 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
/// Creates a new user with the given email address and password.
/// This method does <i>not</i> save the resulting user, the caller must still call <see cref="M:Microsoft.EntityFrameworkCore.DbContext.SaveChanges" />.
/// </summary>
public async Task<User> CreateUserWithPasswordAsync(string username, string email, string password,
CancellationToken ct = default)
public async Task<User> CreateUserWithPasswordAsync(
string username,
string email,
string password,
CancellationToken ct = default
)
{
var user = new User
{
@ -26,9 +30,13 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
AuthMethods =
{
new AuthMethod
{ Id = snowflakeGenerator.GenerateSnowflake(), AuthType = AuthType.Email, RemoteId = email }
{
Id = snowflakeGenerator.GenerateSnowflake(),
AuthType = AuthType.Email,
RemoteId = email,
},
},
LastActive = clock.GetCurrentInstant()
LastActive = clock.GetCurrentInstant(),
};
db.Add(user);
@ -42,8 +50,14 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
/// To create a user with email authentication, use <see cref="CreateUserWithPasswordAsync" />
/// This method does <i>not</i> save the resulting user, the caller must still call <see cref="M:Microsoft.EntityFrameworkCore.DbContext.SaveChanges" />.
/// </summary>
public async Task<User> CreateUserWithRemoteAuthAsync(string username, AuthType authType, string remoteId,
string remoteUsername, FediverseApplication? instance = null, CancellationToken ct = default)
public async Task<User> CreateUserWithRemoteAuthAsync(
string username,
AuthType authType,
string remoteId,
string remoteUsername,
FediverseApplication? instance = null,
CancellationToken ct = default
)
{
AssertValidAuthType(authType, instance);
@ -58,11 +72,14 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
{
new AuthMethod
{
Id = snowflakeGenerator.GenerateSnowflake(), AuthType = authType, RemoteId = remoteId,
RemoteUsername = remoteUsername, FediverseApplication = instance
}
Id = snowflakeGenerator.GenerateSnowflake(),
AuthType = authType,
RemoteId = remoteId,
RemoteUsername = remoteUsername,
FediverseApplication = instance,
},
},
LastActive = clock.GetCurrentInstant()
LastActive = clock.GetCurrentInstant(),
};
db.Add(user);
@ -78,19 +95,31 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
/// <returns>A tuple of the authenticated user and whether multi-factor authentication is required</returns>
/// <exception cref="ApiError.NotFound">Thrown if the email address is not associated with any user
/// or if the password is incorrect</exception>
public async Task<(User, EmailAuthenticationResult)> AuthenticateUserAsync(string email, string password,
CancellationToken ct = default)
public async Task<(User, EmailAuthenticationResult)> AuthenticateUserAsync(
string email,
string password,
CancellationToken ct = default
)
{
var user = await db.Users.FirstOrDefaultAsync(u =>
u.AuthMethods.Any(a => a.AuthType == AuthType.Email && a.RemoteId == email), ct);
var user = await db.Users.FirstOrDefaultAsync(
u => u.AuthMethods.Any(a => a.AuthType == AuthType.Email && a.RemoteId == email),
ct
);
if (user == null)
throw new ApiError.NotFound("No user with that email address found, or password is incorrect",
ErrorCode.UserNotFound);
throw new ApiError.NotFound(
"No user with that email address found, or password is incorrect",
ErrorCode.UserNotFound
);
var pwResult = await Task.Run(() => _passwordHasher.VerifyHashedPassword(user, user.Password!, password), ct);
var pwResult = await Task.Run(
() => _passwordHasher.VerifyHashedPassword(user, user.Password!, password),
ct
);
if (pwResult == PasswordVerificationResult.Failed) // TODO: this seems to fail on some valid passwords?
throw new ApiError.NotFound("No user with that email address found, or password is incorrect",
ErrorCode.UserNotFound);
throw new ApiError.NotFound(
"No user with that email address found, or password is incorrect",
ErrorCode.UserNotFound
);
if (pwResult == PasswordVerificationResult.SuccessRehashNeeded)
{
user.Password = await Task.Run(() => _passwordHasher.HashPassword(user, password), ct);
@ -117,19 +146,33 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
/// <returns>A user object, or null if the remote account isn't linked to any user.</returns>
/// <exception cref="FoxnounsError">Thrown if <c>instance</c> is passed when not required,
/// or not passed when required</exception>
public async Task<User?> AuthenticateUserAsync(AuthType authType, string remoteId,
FediverseApplication? instance = null, CancellationToken ct = default)
public async Task<User?> AuthenticateUserAsync(
AuthType authType,
string remoteId,
FediverseApplication? instance = null,
CancellationToken ct = default
)
{
AssertValidAuthType(authType, instance);
return await db.Users.FirstOrDefaultAsync(u =>
u.AuthMethods.Any(a =>
a.AuthType == authType && a.RemoteId == remoteId && a.FediverseApplication == instance), ct);
return await db.Users.FirstOrDefaultAsync(
u =>
u.AuthMethods.Any(a =>
a.AuthType == authType
&& a.RemoteId == remoteId
&& a.FediverseApplication == instance
),
ct
);
}
public async Task<AuthMethod> AddAuthMethodAsync(Snowflake userId, AuthType authType, string remoteId,
public async Task<AuthMethod> AddAuthMethodAsync(
Snowflake userId,
AuthType authType,
string remoteId,
string? remoteUsername = null,
CancellationToken ct = default)
CancellationToken ct = default
)
{
AssertValidAuthType(authType, null);
@ -139,7 +182,7 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
AuthType = authType,
RemoteId = remoteId,
RemoteUsername = remoteUsername,
UserId = userId
UserId = userId,
};
db.Add(authMethod);
@ -147,21 +190,33 @@ public class AuthService(IClock clock, DatabaseContext db, ISnowflakeGenerator s
return authMethod;
}
public (string, Token) GenerateToken(User user, Application application, string[] scopes, Instant expires)
public (string, Token) GenerateToken(
User user,
Application application,
string[] scopes,
Instant expires
)
{
if (!AuthUtils.ValidateScopes(application, scopes))
throw new ApiError.BadRequest("Invalid scopes requested for this token", "scopes", scopes);
throw new ApiError.BadRequest(
"Invalid scopes requested for this token",
"scopes",
scopes
);
var (token, hash) = GenerateToken();
return (token, new Token
{
Id = snowflakeGenerator.GenerateSnowflake(),
Hash = hash,
Application = application,
User = user,
ExpiresAt = expires,
Scopes = scopes
});
return (
token,
new Token
{
Id = snowflakeGenerator.GenerateSnowflake(),
Hash = hash,
Application = application,
User = user,
ExpiresAt = expires,
Scopes = scopes,
}
);
}
private static (string, byte[]) GenerateToken()

View file

@ -10,26 +10,43 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
{
private readonly ILogger _logger = logger.ForContext<KeyCacheService>();
public Task SetKeyAsync(string key, string value, Duration expireAfter, CancellationToken ct = default) =>
SetKeyAsync(key, value, clock.GetCurrentInstant() + expireAfter, ct);
public Task SetKeyAsync(
string key,
string value,
Duration expireAfter,
CancellationToken ct = default
) => SetKeyAsync(key, value, clock.GetCurrentInstant() + expireAfter, ct);
public async Task SetKeyAsync(string key, string value, Instant expires, CancellationToken ct = default)
public async Task SetKeyAsync(
string key,
string value,
Instant expires,
CancellationToken ct = default
)
{
db.TemporaryKeys.Add(new TemporaryKey
{
Expires = expires,
Key = key,
Value = value,
});
db.TemporaryKeys.Add(
new TemporaryKey
{
Expires = expires,
Key = key,
Value = value,
}
);
await db.SaveChangesAsync(ct);
}
public async Task<string?> GetKeyAsync(string key, bool delete = false, CancellationToken ct = default)
public async Task<string?> GetKeyAsync(
string key,
bool delete = false,
CancellationToken ct = default
)
{
var value = await db.TemporaryKeys.FirstOrDefaultAsync(k => k.Key == key, ct);
if (value == null) return null;
if (value == null)
return null;
if (delete) await db.TemporaryKeys.Where(k => k.Key == key).ExecuteDeleteAsync(ct);
if (delete)
await db.TemporaryKeys.Where(k => k.Key == key).ExecuteDeleteAsync(ct);
return value.Value;
}
@ -39,20 +56,38 @@ public class KeyCacheService(DatabaseContext db, IClock clock, ILogger logger)
public async Task DeleteExpiredKeysAsync(CancellationToken ct)
{
var count = await db.TemporaryKeys.Where(k => k.Expires < clock.GetCurrentInstant()).ExecuteDeleteAsync(ct);
if (count != 0) _logger.Information("Removed {Count} expired keys from the database", count);
var count = await db
.TemporaryKeys.Where(k => k.Expires < clock.GetCurrentInstant())
.ExecuteDeleteAsync(ct);
if (count != 0)
_logger.Information("Removed {Count} expired keys from the database", count);
}
public Task SetKeyAsync<T>(string key, T obj, Duration expiresAt, CancellationToken ct = default) where T : class =>
SetKeyAsync(key, obj, clock.GetCurrentInstant() + expiresAt, ct);
public Task SetKeyAsync<T>(
string key,
T obj,
Duration expiresAt,
CancellationToken ct = default
)
where T : class => SetKeyAsync(key, obj, clock.GetCurrentInstant() + expiresAt, ct);
public async Task SetKeyAsync<T>(string key, T obj, Instant expires, CancellationToken ct = default) where T : class
public async Task SetKeyAsync<T>(
string key,
T obj,
Instant expires,
CancellationToken ct = default
)
where T : class
{
var value = JsonConvert.SerializeObject(obj);
await SetKeyAsync(key, value, expires, ct);
}
public async Task<T?> GetKeyAsync<T>(string key, bool delete = false, CancellationToken ct = default)
public async Task<T?> GetKeyAsync<T>(
string key,
bool delete = false,
CancellationToken ct = default
)
where T : class
{
var value = await GetKeyAsync(key, delete, ct);

View file

@ -15,12 +15,17 @@ public class MailService(ILogger logger, IMailer mailer, IQueue queue, Config co
_logger.Debug("Sending account creation email to {ToEmail}", to);
try
{
await mailer.SendAsync(new AccountCreationMailable(config, new AccountCreationMailableView
{
BaseUrl = config.BaseUrl,
To = to,
Code = code
}));
await mailer.SendAsync(
new AccountCreationMailable(
config,
new AccountCreationMailableView
{
BaseUrl = config.BaseUrl,
To = to,
Code = code,
}
)
);
}
catch (Exception exc)
{

View file

@ -10,17 +10,17 @@ public class MemberRendererService(DatabaseContext db, Config config)
{
public async Task<IEnumerable<PartialMember>> RenderUserMembersAsync(User user, Token? token)
{
var canReadHiddenMembers = token != null && token.UserId == user.Id && token.HasScope("member.read");
var renderUnlisted = token != null && token.UserId == user.Id && token.HasScope("user.read_hidden");
var canReadHiddenMembers =
token != null && token.UserId == user.Id && token.HasScope("member.read");
var renderUnlisted =
token != null && token.UserId == user.Id && token.HasScope("user.read_hidden");
var canReadMemberList = !user.ListHidden || canReadHiddenMembers;
IEnumerable<Member> members = canReadMemberList
? await db.Members
.Where(m => m.UserId == user.Id)
.OrderBy(m => m.Name)
.ToListAsync()
? await db.Members.Where(m => m.UserId == user.Id).OrderBy(m => m.Name).ToListAsync()
: [];
if (!canReadHiddenMembers) members = members.Where(m => !m.Unlisted);
if (!canReadHiddenMembers)
members = members.Where(m => !m.Unlisted);
return members.Select(m => RenderPartialMember(m, renderUnlisted));
}
@ -29,25 +29,54 @@ public class MemberRendererService(DatabaseContext db, Config config)
var renderUnlisted = token?.UserId == member.UserId && token.HasScope("user.read_hidden");
return new MemberResponse(
member.Id, member.Sid, member.Name, member.DisplayName, member.Bio,
AvatarUrlFor(member), member.Links, member.Names, member.Pronouns, member.Fields,
member.Id,
member.Sid,
member.Name,
member.DisplayName,
member.Bio,
AvatarUrlFor(member),
member.Links,
member.Names,
member.Pronouns,
member.Fields,
member.ProfileFlags.Select(f => RenderPrideFlag(f.PrideFlag)),
RenderPartialUser(member.User), renderUnlisted ? member.Unlisted : null);
RenderPartialUser(member.User),
renderUnlisted ? member.Unlisted : null
);
}
private UserRendererService.PartialUser RenderPartialUser(User user) =>
new(user.Id, user.Sid, user.Username, user.DisplayName, AvatarUrlFor(user), user.CustomPreferences);
new(
user.Id,
user.Sid,
user.Username,
user.DisplayName,
AvatarUrlFor(user),
user.CustomPreferences
);
public PartialMember RenderPartialMember(Member member, bool renderUnlisted = false) => new(member.Id, member.Sid,
member.Name,
member.DisplayName, member.Bio, AvatarUrlFor(member), member.Names, member.Pronouns,
renderUnlisted ? member.Unlisted : null);
public PartialMember RenderPartialMember(Member member, bool renderUnlisted = false) =>
new(
member.Id,
member.Sid,
member.Name,
member.DisplayName,
member.Bio,
AvatarUrlFor(member),
member.Names,
member.Pronouns,
renderUnlisted ? member.Unlisted : null
);
private string? AvatarUrlFor(Member member) =>
member.Avatar != null ? $"{config.MediaBaseUrl}/members/{member.Id}/avatars/{member.Avatar}.webp" : null;
member.Avatar != null
? $"{config.MediaBaseUrl}/members/{member.Id}/avatars/{member.Avatar}.webp"
: null;
private string? AvatarUrlFor(User user) =>
user.Avatar != null ? $"{config.MediaBaseUrl}/users/{user.Id}/avatars/{user.Avatar}.webp" : null;
user.Avatar != null
? $"{config.MediaBaseUrl}/users/{user.Id}/avatars/{user.Avatar}.webp"
: null;
private string ImageUrlFor(PrideFlag flag) => $"{config.MediaBaseUrl}/flags/{flag.Hash}.webp";
@ -63,8 +92,8 @@ public class MemberRendererService(DatabaseContext db, Config config)
string? AvatarUrl,
IEnumerable<FieldEntry> Names,
IEnumerable<Pronoun> Pronouns,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
bool? Unlisted);
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] bool? Unlisted
);
public record MemberResponse(
Snowflake Id,
@ -79,6 +108,6 @@ public class MemberRendererService(DatabaseContext db, Config config)
IEnumerable<Field> Fields,
IEnumerable<UserRendererService.PrideFlagResponse> Flags,
UserRendererService.PartialUser User,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
bool? Unlisted);
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] bool? Unlisted
);
}

View file

@ -6,10 +6,7 @@ using Prometheus;
namespace Foxnouns.Backend.Services;
public class MetricsCollectionService(
ILogger logger,
IServiceProvider services,
IClock clock)
public class MetricsCollectionService(ILogger logger, IServiceProvider services, IClock clock)
{
private readonly ILogger _logger = logger.ForContext<MetricsCollectionService>();
@ -31,8 +28,10 @@ public class MetricsCollectionService(
FoxnounsMetrics.UsersActiveWeekCount.Set(users.Count(i => i > now - Week));
FoxnounsMetrics.UsersActiveDayCount.Set(users.Count(i => i > now - Day));
var memberCount = await db.Members.Include(m => m.User)
.Where(m => !m.Unlisted && !m.User.ListHidden && !m.User.Deleted).CountAsync(ct);
var memberCount = await db
.Members.Include(m => m.User)
.Where(m => !m.Unlisted && !m.User.ListHidden && !m.User.Deleted)
.CountAsync(ct);
FoxnounsMetrics.MemberCount.Set(memberCount);
var process = Process.GetCurrentProcess();
@ -42,13 +41,17 @@ public class MetricsCollectionService(
FoxnounsMetrics.ProcessThreads.Set(process.Threads.Count);
FoxnounsMetrics.ProcessHandles.Set(process.HandleCount);
_logger.Information("Collected metrics in {DurationMilliseconds} ms",
timer.ObserveDuration().TotalMilliseconds);
_logger.Information(
"Collected metrics in {DurationMilliseconds} ms",
timer.ObserveDuration().TotalMilliseconds
);
}
}
public class BackgroundMetricsCollectionService(ILogger logger, MetricsCollectionService metricsCollectionService)
: BackgroundService
public class BackgroundMetricsCollectionService(
ILogger logger,
MetricsCollectionService metricsCollectionService
) : BackgroundService
{
private readonly ILogger _logger = logger.ForContext<BackgroundMetricsCollectionService>();

View file

@ -15,7 +15,8 @@ public class ObjectStorageService(ILogger logger, Config config, IMinioClient mi
{
await minioClient.RemoveObjectAsync(
new RemoveObjectArgs().WithBucket(config.Storage.Bucket).WithObject(path),
ct);
ct
);
}
catch (InvalidObjectNameException)
{
@ -23,17 +24,28 @@ public class ObjectStorageService(ILogger logger, Config config, IMinioClient mi
}
}
public async Task PutObjectAsync(string path, Stream data, string contentType, CancellationToken ct = default)
public async Task PutObjectAsync(
string path,
Stream data,
string contentType,
CancellationToken ct = default
)
{
_logger.Debug("Putting object at path {Path} with length {Length} and content type {ContentType}", path,
data.Length, contentType);
_logger.Debug(
"Putting object at path {Path} with length {Length} and content type {ContentType}",
path,
data.Length,
contentType
);
await minioClient.PutObjectAsync(new PutObjectArgs()
await minioClient.PutObjectAsync(
new PutObjectArgs()
.WithBucket(config.Storage.Bucket)
.WithObject(path)
.WithObjectSize(data.Length)
.WithStreamData(data)
.WithContentType(contentType), ct
.WithContentType(contentType),
ct
);
}
}

View file

@ -11,30 +11,42 @@ public class RemoteAuthService(Config config, ILogger logger)
private readonly Uri _discordTokenUri = new("https://discord.com/api/oauth2/token");
private readonly Uri _discordUserUri = new("https://discord.com/api/v10/users/@me");
public async Task<RemoteUser> RequestDiscordTokenAsync(string code, string state, CancellationToken ct = default)
public async Task<RemoteUser> RequestDiscordTokenAsync(
string code,
string state,
CancellationToken ct = default
)
{
var redirectUri = $"{config.BaseUrl}/auth/callback/discord";
var resp = await _httpClient.PostAsync(_discordTokenUri, new FormUrlEncodedContent(
new Dictionary<string, string>
{
{ "client_id", config.DiscordAuth.ClientId! },
{ "client_secret", config.DiscordAuth.ClientSecret! },
{ "grant_type", "authorization_code" },
{ "code", code },
{ "redirect_uri", redirectUri }
}
), ct);
var resp = await _httpClient.PostAsync(
_discordTokenUri,
new FormUrlEncodedContent(
new Dictionary<string, string>
{
{ "client_id", config.DiscordAuth.ClientId! },
{ "client_secret", config.DiscordAuth.ClientSecret! },
{ "grant_type", "authorization_code" },
{ "code", code },
{ "redirect_uri", redirectUri },
}
),
ct
);
if (!resp.IsSuccessStatusCode)
{
var respBody = await resp.Content.ReadAsStringAsync(ct);
_logger.Error("Received error status {StatusCode} when exchanging OAuth token: {ErrorBody}",
(int)resp.StatusCode, respBody);
_logger.Error(
"Received error status {StatusCode} when exchanging OAuth token: {ErrorBody}",
(int)resp.StatusCode,
respBody
);
throw new FoxnounsError("Invalid Discord OAuth response");
}
resp.EnsureSuccessStatusCode();
var token = await resp.Content.ReadFromJsonAsync<DiscordTokenResponse>(ct);
if (token == null) throw new FoxnounsError("Discord token response was null");
if (token == null)
throw new FoxnounsError("Discord token response was null");
var req = new HttpRequestMessage(HttpMethod.Get, _discordUserUri);
req.Headers.Add("Authorization", $"{token.token_type} {token.access_token}");
@ -42,18 +54,25 @@ public class RemoteAuthService(Config config, ILogger logger)
var resp2 = await _httpClient.SendAsync(req, ct);
resp2.EnsureSuccessStatusCode();
var user = await resp2.Content.ReadFromJsonAsync<DiscordUserResponse>(ct);
if (user == null) throw new FoxnounsError("Discord user response was null");
if (user == null)
throw new FoxnounsError("Discord user response was null");
return new RemoteUser(user.id, user.username);
}
[SuppressMessage("ReSharper", "InconsistentNaming",
Justification = "Easier to use snake_case here, rather than passing in JSON converter options")]
[SuppressMessage(
"ReSharper",
"InconsistentNaming",
Justification = "Easier to use snake_case here, rather than passing in JSON converter options"
)]
[UsedImplicitly]
private record DiscordTokenResponse(string access_token, string token_type);
[SuppressMessage("ReSharper", "InconsistentNaming",
Justification = "Easier to use snake_case here, rather than passing in JSON converter options")]
[SuppressMessage(
"ReSharper",
"InconsistentNaming",
Justification = "Easier to use snake_case here, rather than passing in JSON converter options"
)]
[UsedImplicitly]
private record DiscordUserResponse(string id, string username);

View file

@ -7,48 +7,73 @@ using NodaTime;
namespace Foxnouns.Backend.Services;
public class UserRendererService(DatabaseContext db, MemberRendererService memberRenderer, Config config)
public class UserRendererService(
DatabaseContext db,
MemberRendererService memberRenderer,
Config config
)
{
public async Task<UserResponse> RenderUserAsync(User user, User? selfUser = null,
public async Task<UserResponse> RenderUserAsync(
User user,
User? selfUser = null,
Token? token = null,
bool renderMembers = true,
bool renderAuthMethods = false,
CancellationToken ct = default)
CancellationToken ct = default
)
{
var isSelfUser = selfUser?.Id == user.Id;
var tokenCanReadHiddenMembers = token.HasScope("member.read") && isSelfUser;
var tokenHidden = token.HasScope("user.read_hidden") && isSelfUser;
var tokenPrivileged = token.HasScope("user.read_privileged") && isSelfUser;
renderMembers = renderMembers &&
(!user.ListHidden || tokenCanReadHiddenMembers);
renderMembers = renderMembers && (!user.ListHidden || tokenCanReadHiddenMembers);
renderAuthMethods = renderAuthMethods && tokenPrivileged;
IEnumerable<Member> members =
renderMembers ? await db.Members.Where(m => m.UserId == user.Id).OrderBy(m => m.Name).ToListAsync(ct) : [];
IEnumerable<Member> members = renderMembers
? await db.Members.Where(m => m.UserId == user.Id).OrderBy(m => m.Name).ToListAsync(ct)
: [];
// Unless the user is requesting their own members AND the token can read hidden members, we filter out unlisted members.
if (!(isSelfUser && tokenCanReadHiddenMembers)) members = members.Where(m => !m.Unlisted);
if (!(isSelfUser && tokenCanReadHiddenMembers))
members = members.Where(m => !m.Unlisted);
var flags = await db.UserFlags.Where(f => f.UserId == user.Id).OrderBy(f => f.Id).ToListAsync(ct);
var flags = await db
.UserFlags.Where(f => f.UserId == user.Id)
.OrderBy(f => f.Id)
.ToListAsync(ct);
var authMethods = renderAuthMethods
? await db.AuthMethods
.Where(a => a.UserId == user.Id)
? await db
.AuthMethods.Where(a => a.UserId == user.Id)
.Include(a => a.FediverseApplication)
.ToListAsync(ct)
: [];
return new UserResponse(
user.Id, user.Sid, user.Username, user.DisplayName, user.Bio, user.MemberTitle, AvatarUrlFor(user),
user.Id,
user.Sid,
user.Username,
user.DisplayName,
user.Bio,
user.MemberTitle,
AvatarUrlFor(user),
user.Links,
user.Names, user.Pronouns, user.Fields, user.CustomPreferences,
user.Names,
user.Pronouns,
user.Fields,
user.CustomPreferences,
flags.Select(f => RenderPrideFlag(f.PrideFlag)),
user.Role,
renderMembers ? members.Select(m => memberRenderer.RenderPartialMember(m, tokenHidden)) : null,
renderMembers
? members.Select(m => memberRenderer.RenderPartialMember(m, tokenHidden))
: null,
renderAuthMethods
? authMethods.Select(a => new AuthenticationMethodResponse(
a.Id, a.AuthType, a.RemoteId,
a.RemoteUsername, a.FediverseApplication?.Domain
a.Id,
a.AuthType,
a.RemoteId,
a.RemoteUsername,
a.FediverseApplication?.Domain
))
: null,
tokenHidden ? user.ListHidden : null,
@ -58,10 +83,19 @@ public class UserRendererService(DatabaseContext db, MemberRendererService membe
}
public PartialUser RenderPartialUser(User user) =>
new(user.Id, user.Sid, user.Username, user.DisplayName, AvatarUrlFor(user), user.CustomPreferences);
new(
user.Id,
user.Sid,
user.Username,
user.DisplayName,
AvatarUrlFor(user),
user.CustomPreferences
);
private string? AvatarUrlFor(User user) =>
user.Avatar != null ? $"{config.MediaBaseUrl}/users/{user.Id}/avatars/{user.Avatar}.webp" : null;
user.Avatar != null
? $"{config.MediaBaseUrl}/users/{user.Id}/avatars/{user.Avatar}.webp"
: null;
public string ImageUrlFor(PrideFlag flag) => $"{config.MediaBaseUrl}/flags/{flag.Hash}.webp";
@ -79,29 +113,26 @@ public class UserRendererService(DatabaseContext db, MemberRendererService membe
IEnumerable<Field> Fields,
Dictionary<Snowflake, User.CustomPreference> CustomPreferences,
IEnumerable<PrideFlagResponse> Flags,
[property: JsonConverter(typeof(ScreamingSnakeCaseEnumConverter))]
UserRole Role,
[property: JsonConverter(typeof(ScreamingSnakeCaseEnumConverter))] UserRole Role,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
IEnumerable<MemberRendererService.PartialMember>? Members,
IEnumerable<MemberRendererService.PartialMember>? Members,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
IEnumerable<AuthenticationMethodResponse>? AuthMethods,
IEnumerable<AuthenticationMethodResponse>? AuthMethods,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
bool? MemberListHidden,
bool? MemberListHidden,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)] Instant? LastActive,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
Instant? LastActive,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
Instant? LastSidReroll
Instant? LastSidReroll
);
public record AuthenticationMethodResponse(
Snowflake Id,
[property: JsonConverter(typeof(ScreamingSnakeCaseEnumConverter))]
AuthType Type,
[property: JsonConverter(typeof(ScreamingSnakeCaseEnumConverter))] AuthType Type,
string RemoteId,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
string? RemoteUsername,
string? RemoteUsername,
[property: JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
string? FediverseInstance
string? FediverseInstance
);
public record PartialUser(
@ -120,5 +151,6 @@ public class UserRendererService(DatabaseContext db, MemberRendererService membe
Snowflake Id,
string ImageUrl,
string Name,
string? Description);
string? Description
);
}

View file

@ -7,12 +7,28 @@ public static class AuthUtils
{
public const string ClientCredentials = "client_credentials";
public const string AuthorizationCode = "authorization_code";
private static readonly string[] ForbiddenSchemes = ["javascript", "file", "data", "mailto", "tel"];
private static readonly string[] ForbiddenSchemes =
[
"javascript",
"file",
"data",
"mailto",
"tel",
];
public static readonly string[] UserScopes =
["user.read_hidden", "user.read_privileged", "user.update"];
[
"user.read_hidden",
"user.read_privileged",
"user.update",
];
public static readonly string[] MemberScopes = ["member.read", "member.update", "member.create"];
public static readonly string[] MemberScopes =
[
"member.read",
"member.update",
"member.create",
];
/// <summary>
/// All scopes endpoints can be secured by. This does *not* include the catch-all token scopes.
@ -27,10 +43,13 @@ public static class AuthUtils
public static string[] ExpandScopes(this string[] scopes)
{
if (scopes.Contains("*")) return ["*", ..Scopes];
if (scopes.Contains("*"))
return ["*", .. Scopes];
List<string> expandedScopes = ["identify"];
if (scopes.Contains("user")) expandedScopes.AddRange(UserScopes);
if (scopes.Contains("member")) expandedScopes.AddRange(MemberScopes);
if (scopes.Contains("user"))
expandedScopes.AddRange(UserScopes);
if (scopes.Contains("member"))
expandedScopes.AddRange(MemberScopes);
return expandedScopes.ToArray();
}
@ -41,8 +60,10 @@ public static class AuthUtils
private static string[] ExpandAppScopes(this string[] scopes)
{
var expandedScopes = scopes.ExpandScopes().ToList();
if (scopes.Contains("user")) expandedScopes.Add("user");
if (scopes.Contains("member")) expandedScopes.Add("member");
if (scopes.Contains("user"))
expandedScopes.Add("user");
if (scopes.Contains("member"))
expandedScopes.Add("member");
return expandedScopes.ToArray();
}
@ -84,7 +105,8 @@ public static class AuthUtils
{
rawToken = [];
if (string.IsNullOrWhiteSpace(input)) return false;
if (string.IsNullOrWhiteSpace(input))
return false;
if (input.StartsWith("bearer ", StringComparison.InvariantCultureIgnoreCase))
input = input["bearer ".Length..];

View file

@ -13,7 +13,9 @@ namespace Foxnouns.Backend.Utils;
public abstract class PatchRequest
{
private readonly HashSet<string> _properties = [];
public bool HasProperty(string propertyName) => _properties.Contains(propertyName);
public void SetHasProperty(string propertyName) => _properties.Add(propertyName);
}
@ -23,13 +25,17 @@ public abstract class PatchRequest
/// </summary>
public class PatchRequestContractResolver : DefaultContractResolver
{
protected override JsonProperty CreateProperty(MemberInfo member, MemberSerialization memberSerialization)
protected override JsonProperty CreateProperty(
MemberInfo member,
MemberSerialization memberSerialization
)
{
var prop = base.CreateProperty(member, memberSerialization);
prop.SetIsSpecified += (o, _) =>
{
if (o is not PatchRequest patchRequest) return;
if (o is not PatchRequest patchRequest)
return;
patchRequest.SetHasProperty(prop.UnderlyingName!);
};

View file

@ -8,7 +8,8 @@ namespace Foxnouns.Backend.Utils;
/// A custom StringEnumConverter that converts enum members to SCREAMING_SNAKE_CASE, rather than CamelCase as is the default.
/// Newtonsoft.Json doesn't provide a screaming snake case naming strategy, so we just wrap the normal snake case one and convert it to uppercase.
/// </summary>
public class ScreamingSnakeCaseEnumConverter() : StringEnumConverter(new ScreamingSnakeCaseNamingStrategy(), false)
public class ScreamingSnakeCaseEnumConverter()
: StringEnumConverter(new ScreamingSnakeCaseNamingStrategy(), false)
{
private class ScreamingSnakeCaseNamingStrategy : SnakeCaseNamingStrategy
{

View file

@ -21,7 +21,7 @@ public static partial class ValidationUtils
"pronouns",
"settings",
"pronouns.cc",
"pronounscc"
"pronounscc",
];
private static readonly string[] InvalidMemberNames =
@ -30,7 +30,7 @@ public static partial class ValidationUtils
".",
"..",
// the user edit page lives at `/@{username}/edit`, so a member named "edit" would be inaccessible
"edit"
"edit",
];
public static ValidationError? ValidateUsername(string username)
@ -42,10 +42,15 @@ public static partial class ValidationUtils
> 40 => ValidationError.LengthError("Username is too long", 2, 40, username.Length),
_ => ValidationError.GenericValidationError(
"Username is invalid, can only contain alphanumeric characters, dashes, underscores, and periods",
username)
username
),
};
if (InvalidUsernames.Any(u => string.Equals(u, username, StringComparison.InvariantCultureIgnoreCase)))
if (
InvalidUsernames.Any(u =>
string.Equals(u, username, StringComparison.InvariantCultureIgnoreCase)
)
)
return ValidationError.GenericValidationError("Username is not allowed", username);
return null;
}
@ -58,13 +63,18 @@ public static partial class ValidationUtils
< 1 => ValidationError.LengthError("Name is too short", 1, 100, memberName.Length),
> 100 => ValidationError.LengthError("Name is too long", 1, 100, memberName.Length),
_ => ValidationError.GenericValidationError(
"Member name cannot contain any of the following: " +
" @, ?, !, #, /, \\, [, ], \", ', $, %, &, (, ), {, }, +, <, =, >, ^, |, ~, `, , " +
"and cannot be one or two periods",
memberName)
"Member name cannot contain any of the following: "
+ " @, ?, !, #, /, \\, [, ], \", ', $, %, &, (, ), {, }, +, <, =, >, ^, |, ~, `, , "
+ "and cannot be one or two periods",
memberName
),
};
if (InvalidMemberNames.Any(u => string.Equals(u, memberName, StringComparison.InvariantCultureIgnoreCase)))
if (
InvalidMemberNames.Any(u =>
string.Equals(u, memberName, StringComparison.InvariantCultureIgnoreCase)
)
)
return ValidationError.GenericValidationError("Name is not allowed", memberName);
return null;
}
@ -72,12 +82,14 @@ public static partial class ValidationUtils
public static void Validate(IEnumerable<(string, ValidationError?)> errors)
{
errors = errors.Where(e => e.Item2 != null).ToList();
if (!errors.Any()) return;
if (!errors.Any())
return;
var errorDict = new Dictionary<string, IEnumerable<ValidationError>>();
foreach (var error in errors)
{
if (errorDict.TryGetValue(error.Item1, out var value)) errorDict[error.Item1] = value.Append(error.Item2!);
if (errorDict.TryGetValue(error.Item1, out var value))
errorDict[error.Item1] = value.Append(error.Item2!);
errorDict.Add(error.Item1, [error.Item2!]);
}
@ -88,9 +100,19 @@ public static partial class ValidationUtils
{
return displayName?.Length switch
{
0 => ValidationError.LengthError("Display name is too short", 1, 100, displayName.Length),
> 100 => ValidationError.LengthError("Display name is too long", 1, 100, displayName.Length),
_ => null
0 => ValidationError.LengthError(
"Display name is too short",
1,
100,
displayName.Length
),
> 100 => ValidationError.LengthError(
"Display name is too long",
1,
100,
displayName.Length
),
_ => null,
};
}
@ -99,9 +121,13 @@ public static partial class ValidationUtils
public static IEnumerable<(string, ValidationError?)> ValidateLinks(string[]? links)
{
if (links == null) return [];
if (links == null)
return [];
if (links.Length > MaxLinks)
return [("links", ValidationError.LengthError("Too many links", 0, MaxLinks, links.Length))];
return
[
("links", ValidationError.LengthError("Too many links", 0, MaxLinks, links.Length)),
];
var errors = new List<(string, ValidationError?)>();
foreach (var (link, idx) in links.Select((l, i) => (l, i)))
@ -109,12 +135,25 @@ public static partial class ValidationUtils
switch (link.Length)
{
case 0:
errors.Add(($"links.{idx}",
ValidationError.LengthError("Link cannot be empty", 1, 256, 0)));
errors.Add(
(
$"links.{idx}",
ValidationError.LengthError("Link cannot be empty", 1, 256, 0)
)
);
break;
case > MaxLinkLength:
errors.Add(($"links.{idx}",
ValidationError.LengthError("Link is too long", 1, MaxLinkLength, link.Length)));
errors.Add(
(
$"links.{idx}",
ValidationError.LengthError(
"Link is too long",
1,
MaxLinkLength,
link.Length
)
)
);
break;
}
}
@ -129,8 +168,13 @@ public static partial class ValidationUtils
return bio?.Length switch
{
0 => ValidationError.LengthError("Bio is too short", 1, MaxBioLength, bio.Length),
> MaxBioLength => ValidationError.LengthError("Bio is too long", 1, MaxBioLength, bio.Length),
_ => null
> MaxBioLength => ValidationError.LengthError(
"Bio is too long",
1,
MaxBioLength,
bio.Length
),
_ => null,
};
}
@ -140,121 +184,222 @@ public static partial class ValidationUtils
{
0 => ValidationError.GenericValidationError("Avatar cannot be empty", null),
> 1_500_000 => ValidationError.GenericValidationError("Avatar is too large", null),
_ => null
_ => null,
};
}
private static readonly string[] DefaultStatusOptions =
[
"favourite",
"okay",
"jokingly",
"friends_only",
"avoid"
"avoid",
];
public static IEnumerable<(string, ValidationError?)> ValidateFields(List<Field>? fields,
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences)
public static IEnumerable<(string, ValidationError?)> ValidateFields(
List<Field>? fields,
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences
)
{
if (fields == null) return [];
if (fields == null)
return [];
var errors = new List<(string, ValidationError?)>();
if (fields.Count > 25)
errors.Add(("fields", ValidationError.LengthError("Too many fields", 0, Limits.FieldLimit, fields.Count)));
errors.Add(
(
"fields",
ValidationError.LengthError(
"Too many fields",
0,
Limits.FieldLimit,
fields.Count
)
)
);
// No overwhelming this function, thank you
if (fields.Count > 100) return errors;
if (fields.Count > 100)
return errors;
foreach (var (field, index) in fields.Select((field, index) => (field, index)))
{
switch (field.Name.Length)
{
case > Limits.FieldNameLimit:
errors.Add(($"fields.{index}.name",
ValidationError.LengthError("Field name is too long", 1, Limits.FieldNameLimit,
field.Name.Length)));
errors.Add(
(
$"fields.{index}.name",
ValidationError.LengthError(
"Field name is too long",
1,
Limits.FieldNameLimit,
field.Name.Length
)
)
);
break;
case < 1:
errors.Add(($"fields.{index}.name",
ValidationError.LengthError("Field name is too short", 1, Limits.FieldNameLimit,
field.Name.Length)));
errors.Add(
(
$"fields.{index}.name",
ValidationError.LengthError(
"Field name is too short",
1,
Limits.FieldNameLimit,
field.Name.Length
)
)
);
break;
}
errors = errors.Concat(ValidateFieldEntries(field.Entries, customPreferences, $"fields.{index}.entries"))
errors = errors
.Concat(
ValidateFieldEntries(
field.Entries,
customPreferences,
$"fields.{index}.entries"
)
)
.ToList();
}
return errors;
}
public static IEnumerable<(string, ValidationError?)> ValidateFieldEntries(FieldEntry[]? entries,
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences, string errorPrefix = "fields")
public static IEnumerable<(string, ValidationError?)> ValidateFieldEntries(
FieldEntry[]? entries,
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences,
string errorPrefix = "fields"
)
{
if (entries == null || entries.Length == 0) return [];
if (entries == null || entries.Length == 0)
return [];
var errors = new List<(string, ValidationError?)>();
if (entries.Length > Limits.FieldEntriesLimit)
errors.Add((errorPrefix,
ValidationError.LengthError("Field has too many entries", 0, Limits.FieldEntriesLimit,
entries.Length)));
errors.Add(
(
errorPrefix,
ValidationError.LengthError(
"Field has too many entries",
0,
Limits.FieldEntriesLimit,
entries.Length
)
)
);
// Same as above, no overwhelming this function with a ridiculous amount of entries
if (entries.Length > Limits.FieldEntriesLimit + 50) return errors;
if (entries.Length > Limits.FieldEntriesLimit + 50)
return errors;
foreach (var (entry, entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx)))
{
switch (entry.Value.Length)
{
case > Limits.FieldEntryTextLimit:
errors.Add(($"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError("Field value is too long", 1, Limits.FieldEntryTextLimit,
entry.Value.Length)));
errors.Add(
(
$"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Field value is too long",
1,
Limits.FieldEntryTextLimit,
entry.Value.Length
)
)
);
break;
case < 1:
errors.Add(($"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError("Field value is too short", 1, Limits.FieldEntryTextLimit,
entry.Value.Length)));
errors.Add(
(
$"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Field value is too short",
1,
Limits.FieldEntryTextLimit,
entry.Value.Length
)
)
);
break;
}
var customPreferenceIds = customPreferences?.Keys.Select(id => id.ToString()) ?? [];
if (!DefaultStatusOptions.Contains(entry.Status) && !customPreferenceIds.Contains(entry.Status))
errors.Add(($"{errorPrefix}.{entryIdx}.status",
ValidationError.GenericValidationError("Invalid status", entry.Status)));
if (
!DefaultStatusOptions.Contains(entry.Status)
&& !customPreferenceIds.Contains(entry.Status)
)
errors.Add(
(
$"{errorPrefix}.{entryIdx}.status",
ValidationError.GenericValidationError("Invalid status", entry.Status)
)
);
}
return errors;
}
public static IEnumerable<(string, ValidationError?)> ValidatePronouns(Pronoun[]? entries,
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences, string errorPrefix = "pronouns")
public static IEnumerable<(string, ValidationError?)> ValidatePronouns(
Pronoun[]? entries,
IReadOnlyDictionary<Snowflake, User.CustomPreference> customPreferences,
string errorPrefix = "pronouns"
)
{
if (entries == null || entries.Length == 0) return [];
if (entries == null || entries.Length == 0)
return [];
var errors = new List<(string, ValidationError?)>();
if (entries.Length > Limits.FieldEntriesLimit)
errors.Add((errorPrefix,
ValidationError.LengthError("Too many pronouns", 0, Limits.FieldEntriesLimit,
entries.Length)));
errors.Add(
(
errorPrefix,
ValidationError.LengthError(
"Too many pronouns",
0,
Limits.FieldEntriesLimit,
entries.Length
)
)
);
// Same as above, no overwhelming this function with a ridiculous amount of entries
if (entries.Length > Limits.FieldEntriesLimit + 50) return errors;
if (entries.Length > Limits.FieldEntriesLimit + 50)
return errors;
foreach (var (entry, entryIdx) in entries.Select((entry, entryIdx) => (entry, entryIdx)))
{
switch (entry.Value.Length)
{
case > Limits.FieldEntryTextLimit:
errors.Add(($"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError("Pronoun value is too long", 1, Limits.FieldEntryTextLimit,
entry.Value.Length)));
errors.Add(
(
$"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Pronoun value is too long",
1,
Limits.FieldEntryTextLimit,
entry.Value.Length
)
)
);
break;
case < 1:
errors.Add(($"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError("Pronoun value is too short", 1, Limits.FieldEntryTextLimit,
entry.Value.Length)));
errors.Add(
(
$"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Pronoun value is too short",
1,
Limits.FieldEntryTextLimit,
entry.Value.Length
)
)
);
break;
}
@ -263,25 +408,46 @@ public static partial class ValidationUtils
switch (entry.DisplayText.Length)
{
case > Limits.FieldEntryTextLimit:
errors.Add(($"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError("Pronoun display text is too long", 1,
Limits.FieldEntryTextLimit,
entry.Value.Length)));
errors.Add(
(
$"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Pronoun display text is too long",
1,
Limits.FieldEntryTextLimit,
entry.Value.Length
)
)
);
break;
case < 1:
errors.Add(($"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError("Pronoun display text is too short", 1,
Limits.FieldEntryTextLimit,
entry.Value.Length)));
errors.Add(
(
$"{errorPrefix}.{entryIdx}.value",
ValidationError.LengthError(
"Pronoun display text is too short",
1,
Limits.FieldEntryTextLimit,
entry.Value.Length
)
)
);
break;
}
}
var customPreferenceIds = customPreferences?.Keys.Select(id => id.ToString()) ?? [];
if (!DefaultStatusOptions.Contains(entry.Status) && !customPreferenceIds.Contains(entry.Status))
errors.Add(($"{errorPrefix}.{entryIdx}.status",
ValidationError.GenericValidationError("Invalid status", entry.Status)));
if (
!DefaultStatusOptions.Contains(entry.Status)
&& !customPreferenceIds.Contains(entry.Status)
)
errors.Add(
(
$"{errorPrefix}.{entryIdx}.status",
ValidationError.GenericValidationError("Invalid status", entry.Status)
)
);
}
return errors;
@ -290,6 +456,10 @@ public static partial class ValidationUtils
[GeneratedRegex(@"^[a-zA-Z_0-9\-\.]{2,40}$", RegexOptions.IgnoreCase, "en-NL")]
private static partial Regex UsernameRegex();
[GeneratedRegex("""^[^@'$%&()+<=>^|~`,*!#/\\\[\]""\{\}\?]{1,100}$""", RegexOptions.IgnoreCase, "en-NL")]
[GeneratedRegex(
"""^[^@'$%&()+<=>^|~`,*!#/\\\[\]""\{\}\?]{1,100}$""",
RegexOptions.IgnoreCase,
"en-NL"
)]
private static partial Regex MemberRegex();
}

View file

@ -19,12 +19,19 @@ public static class Users
var stopwatch = new Stopwatch();
stopwatch.Start();
var users = NetImporter.ReadFromFile<ImportUser>(filename).Output.Select(ConvertUser).ToList();
var users = NetImporter
.ReadFromFile<ImportUser>(filename)
.Output.Select(ConvertUser)
.ToList();
db.AddRange(users);
await db.SaveChangesAsync();
stopwatch.Stop();
Log.Information("Imported {Count} users in {Duration}", users.Count, stopwatch.ElapsedDuration());
Log.Information(
"Imported {Count} users in {Duration}",
users.Count,
stopwatch.ElapsedDuration()
);
}
private static User ConvertUser(ImportUser oldUser)
@ -43,40 +50,46 @@ public static class Users
Role = oldUser.ParseRole(),
Deleted = oldUser.Deleted,
DeletedAt = oldUser.DeletedAt?.ToInstant(),
DeletedBy = null
DeletedBy = null,
};
if (oldUser is { DiscordId: not null, DiscordUsername: not null })
{
user.AuthMethods.Add(new AuthMethod
{
Id = SnowflakeGenerator.Instance.GenerateSnowflake(),
AuthType = AuthType.Discord,
RemoteId = oldUser.DiscordId,
RemoteUsername = oldUser.DiscordUsername
});
user.AuthMethods.Add(
new AuthMethod
{
Id = SnowflakeGenerator.Instance.GenerateSnowflake(),
AuthType = AuthType.Discord,
RemoteId = oldUser.DiscordId,
RemoteUsername = oldUser.DiscordUsername,
}
);
}
if (oldUser is { TumblrId: not null, TumblrUsername: not null })
{
user.AuthMethods.Add(new AuthMethod
{
Id = SnowflakeGenerator.Instance.GenerateSnowflake(),
AuthType = AuthType.Tumblr,
RemoteId = oldUser.TumblrId,
RemoteUsername = oldUser.TumblrUsername
});
user.AuthMethods.Add(
new AuthMethod
{
Id = SnowflakeGenerator.Instance.GenerateSnowflake(),
AuthType = AuthType.Tumblr,
RemoteId = oldUser.TumblrId,
RemoteUsername = oldUser.TumblrUsername,
}
);
}
if (oldUser is { GoogleId: not null, GoogleUsername: not null })
{
user.AuthMethods.Add(new AuthMethod
{
Id = SnowflakeGenerator.Instance.GenerateSnowflake(),
AuthType = AuthType.Google,
RemoteId = oldUser.GoogleId,
RemoteUsername = oldUser.GoogleUsername
});
user.AuthMethods.Add(
new AuthMethod
{
Id = SnowflakeGenerator.Instance.GenerateSnowflake(),
AuthType = AuthType.Google,
RemoteId = oldUser.GoogleId,
RemoteUsername = oldUser.GoogleUsername,
}
);
}
// Convert all custom preference UUIDs to snowflakes
@ -90,41 +103,44 @@ public static class Users
foreach (var name in oldUser.Names ?? [])
{
user.Names.Add(new FieldEntry
{
Value = name.Value,
Status = prefMapping.TryGetValue(name.Status, out var newStatus) ? newStatus.ToString() : name.Status,
});
user.Names.Add(
new FieldEntry
{
Value = name.Value,
Status = prefMapping.TryGetValue(name.Status, out var newStatus)
? newStatus.ToString()
: name.Status,
}
);
}
foreach (var pronoun in oldUser.Pronouns ?? [])
{
user.Pronouns.Add(new Pronoun
{
Value = pronoun.Value,
DisplayText = pronoun.DisplayText,
Status = prefMapping.TryGetValue(pronoun.Status, out var newStatus)
? newStatus.ToString()
: pronoun.Status,
});
user.Pronouns.Add(
new Pronoun
{
Value = pronoun.Value,
DisplayText = pronoun.DisplayText,
Status = prefMapping.TryGetValue(pronoun.Status, out var newStatus)
? newStatus.ToString()
: pronoun.Status,
}
);
}
foreach (var field in oldUser.Fields ?? [])
{
var entries = field.Entries.Select(entry => new FieldEntry
{
Value = entry.Value,
Status = prefMapping.TryGetValue(entry.Status, out var newStatus)
var entries = field
.Entries.Select(entry => new FieldEntry
{
Value = entry.Value,
Status = prefMapping.TryGetValue(entry.Status, out var newStatus)
? newStatus.ToString()
: entry.Status,
})
})
.ToList();
user.Fields.Add(new Field
{
Name = field.Name,
Entries = entries.ToArray()
});
user.Fields.Add(new Field { Name = field.Name, Entries = entries.ToArray() });
}
Log.Debug("Converted user {UserId}", oldUser.Id);
@ -161,14 +177,16 @@ public static class Users
bool Deleted,
OffsetDateTime? DeletedAt,
string? DeleteReason,
Dictionary<string, User.CustomPreference> CustomPreferences)
Dictionary<string, User.CustomPreference> CustomPreferences
)
{
public UserRole ParseRole() => Role switch
{
"USER" => UserRole.User,
"MODERATOR" => UserRole.Moderator,
"ADMIN" => UserRole.Admin,
_ => UserRole.User
};
public UserRole ParseRole() =>
Role switch
{
"USER" => UserRole.User,
"MODERATOR" => UserRole.Moderator,
"ADMIN" => UserRole.Admin,
_ => UserRole.User,
};
}
}

View file

@ -19,7 +19,10 @@ internal static class NetImporter
.Enrich.FromLogContext()
.MinimumLevel.Debug()
.MinimumLevel.Override("Microsoft", LogEventLevel.Information)
.MinimumLevel.Override("Microsoft.EntityFrameworkCore.Database.Command", LogEventLevel.Information)
.MinimumLevel.Override(
"Microsoft.EntityFrameworkCore.Database.Command",
LogEventLevel.Information
)
.WriteTo.Console()
.CreateLogger();
@ -47,16 +50,11 @@ internal static class NetImporter
internal static async Task<DatabaseContext> GetContextAsync()
{
var connString = Environment.GetEnvironmentVariable("DATABASE");
if (connString == null) throw new Exception("$DATABASE not set, must be an ADO.NET connection string");
if (connString == null)
throw new Exception("$DATABASE not set, must be an ADO.NET connection string");
var loggerFactory = new LoggerFactory().AddSerilog(Log.Logger);
var config = new Config
{
Database = new Config.DatabaseConfig
{
Url = connString
}
};
var config = new Config { Database = new Config.DatabaseConfig { Url = connString } };
var db = new DatabaseContext(config, loggerFactory);
@ -70,13 +68,17 @@ internal static class NetImporter
private static readonly JsonSerializerSettings Settings = new JsonSerializerSettings
{
ContractResolver = new DefaultContractResolver { NamingStrategy = new SnakeCaseNamingStrategy() }
ContractResolver = new DefaultContractResolver
{
NamingStrategy = new SnakeCaseNamingStrategy(),
},
}.ConfigureForNodaTime(DateTimeZoneProviders.Tzdb);
internal static Input<T> ReadFromFile<T>(string path)
{
var data = File.ReadAllText(path);
return JsonConvert.DeserializeObject<Input<T>>(data, Settings) ?? throw new Exception("Invalid input file");
return JsonConvert.DeserializeObject<Input<T>>(data, Settings)
?? throw new Exception("Invalid input file");
}
}