2024-12-09 21:11:46 +01:00
|
|
|
// Copyright (C) 2023-present sam/u1f320 (vulpine.solutions)
|
|
|
|
//
|
|
|
|
// This program is free software: you can redistribute it and/or modify
|
|
|
|
// it under the terms of the GNU Affero General Public License as published
|
|
|
|
// by the Free Software Foundation, either version 3 of the License, or
|
|
|
|
// (at your option) any later version.
|
|
|
|
//
|
|
|
|
// This program is distributed in the hope that it will be useful,
|
|
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
// GNU Affero General Public License for more details.
|
|
|
|
//
|
|
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
2024-09-27 14:48:09 +02:00
|
|
|
using Foxnouns.Backend.Database.Models;
|
|
|
|
using Microsoft.EntityFrameworkCore;
|
|
|
|
|
|
|
|
namespace Foxnouns.Backend.Database;
|
|
|
|
|
|
|
|
/// <summary>
/// Extension methods on <see cref="DatabaseContext"/> for replacing the pride flags shown on user and member profiles.
/// None of these methods call <c>SaveChanges</c>; changes are only staged on the context's change tracker.
/// </summary>
public static class FlagQueryExtensions
{
    /// <summary>
    /// The maximum number of flag IDs that may be set on a single user or member profile.
    /// </summary>
    private const int MaxProfileFlags = 100;

    /// <summary>
    /// Returns the IDs of every pride flag owned by <paramref name="userId"/>.
    /// Only the ID column is queried; full <see cref="PrideFlag"/> entities are not loaded.
    /// </summary>
    private static async Task<HashSet<Snowflake>> GetFlagIdsAsync(this DatabaseContext db, Snowflake userId)
    {
        List<Snowflake> ids = await db
            .PrideFlags.Where(f => f.UserId == userId)
            .Select(f => f.Id)
            .ToListAsync();
        return new HashSet<Snowflake>(ids);
    }

    /// <summary>
    /// Validates a requested set of profile flag IDs without modifying the change tracker.
    /// </summary>
    /// <param name="db">The database context to query.</param>
    /// <param name="userId">The user whose uploaded flags the IDs must refer to.</param>
    /// <param name="flagIds">The requested flag IDs. Duplicates are allowed.</param>
    /// <returns>
    /// A length error if more than <see cref="MaxProfileFlags"/> IDs are given.
    /// Otherwise, a generic validation error listing the IDs that do not belong to the user.
    /// Otherwise, <c>null</c>.
    /// </returns>
    private static async Task<ValidationError?> ValidateFlagIdsAsync(
        this DatabaseContext db,
        Snowflake userId,
        Snowflake[] flagIds
    )
    {
        if (flagIds.Length > MaxProfileFlags)
            return ValidationError.LengthError(
                "Too many profile flags",
                0,
                MaxProfileFlags,
                flagIds.Length
            );

        // An empty list is always valid, so skip the database round trip.
        if (flagIds.Length == 0)
            return null;

        HashSet<Snowflake> ownedFlagIds = await db.GetFlagIdsAsync(userId);
        Snowflake[] unknownFlagIds = flagIds.Where(id => !ownedFlagIds.Contains(id)).ToArray();
        if (unknownFlagIds.Length != 0)
            return ValidationError.GenericValidationError("Unknown flag IDs", unknownFlagIds);

        return null;
    }

    /// <summary>
    /// Sets the user's profile flags to the given IDs. Returns a validation error if any of the flag IDs are unknown
    /// or if too many IDs are given. Duplicates are allowed.
    /// </summary>
    /// <param name="db">The database context; the caller is responsible for saving changes.</param>
    /// <param name="userId">The user whose profile flags are replaced.</param>
    /// <param name="flagIds">The new flag IDs. An empty array clears all profile flags.</param>
    /// <returns>A validation error, or <c>null</c> if the new flags were staged successfully.</returns>
    public static async Task<ValidationError?> SetUserFlagsAsync(
        this DatabaseContext db,
        Snowflake userId,
        Snowflake[] flagIds
    )
    {
        // Validate before staging any changes, so that a rejected request does not leave pending deletions
        // on the change tracker that a later SaveChanges call would persist.
        ValidationError? error = await db.ValidateFlagIdsAsync(userId, flagIds);
        if (error is not null)
            return error;

        List<UserFlag> currentFlags = await db
            .UserFlags.Where(f => f.UserId == userId)
            .ToListAsync();
        db.UserFlags.RemoveRange(currentFlags);

        db.UserFlags.AddRange(
            flagIds.Select(id => new UserFlag { PrideFlagId = id, UserId = userId })
        );

        return null;
    }

    /// <summary>
    /// Sets a member's profile flags to the given IDs. Returns a validation error if any of the flag IDs are not
    /// owned by <paramref name="userId"/> or if too many IDs are given. Duplicates are allowed.
    /// </summary>
    /// <param name="db">The database context; the caller is responsible for saving changes.</param>
    /// <param name="userId">The user whose uploaded flags the IDs must refer to.</param>
    /// <param name="memberId">
    /// The member whose profile flags are replaced.
    /// NOTE(review): this method does not check that the member belongs to <paramref name="userId"/>;
    /// callers presumably verify that.
    /// </param>
    /// <param name="flagIds">The new flag IDs. An empty array clears all profile flags.</param>
    /// <returns>A validation error, or <c>null</c> if the new flags were staged successfully.</returns>
    public static async Task<ValidationError?> SetMemberFlagsAsync(
        this DatabaseContext db,
        Snowflake userId,
        Snowflake memberId,
        Snowflake[] flagIds
    )
    {
        // Validate before staging any changes, so that a rejected request does not leave pending deletions
        // on the change tracker that a later SaveChanges call would persist.
        ValidationError? error = await db.ValidateFlagIdsAsync(userId, flagIds);
        if (error is not null)
            return error;

        List<MemberFlag> currentFlags = await db
            .MemberFlags.Where(f => f.MemberId == memberId)
            .ToListAsync();
        db.MemberFlags.RemoveRange(currentFlags);

        db.MemberFlags.AddRange(
            flagIds.Select(id => new MemberFlag { PrideFlagId = id, MemberId = memberId })
        );

        return null;
    }
}
|