From 1b78462f50f663e05d9b0c2bb7a5c450ac9eae8f Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 25 May 2023 13:40:15 +0200 Subject: [PATCH] feat: add flags to PATCH /users/@me --- backend/db/db.go | 5 ++ backend/db/flags.go | 74 +++++++++++++++++++++++-- backend/db/member.go | 4 +- backend/db/user.go | 4 +- backend/routes/member/create_member.go | 2 +- backend/routes/member/get_member.go | 18 +++++-- backend/routes/member/patch_member.go | 2 +- backend/routes/user/get_user.go | 75 +++++++++++--------------- backend/routes/user/patch_user.go | 27 +++++++++- 9 files changed, 153 insertions(+), 58 deletions(-) diff --git a/backend/db/db.go b/backend/db/db.go index 620e498..75ededf 100644 --- a/backend/db/db.go +++ b/backend/db/db.go @@ -22,6 +22,11 @@ var sq = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) const ErrNothingToUpdate = errors.Sentinel("nothing to update") +const ( + uniqueViolation = "23505" + foreignKeyViolation = "23503" +) + type Execer interface { Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) } diff --git a/backend/db/flags.go b/backend/db/flags.go index cc2195a..15e11fb 100644 --- a/backend/db/flags.go +++ b/backend/db/flags.go @@ -9,10 +9,12 @@ import ( "io" "strings" + "codeberg.org/u1f320/pronouns.cc/backend/log" "emperror.dev/errors" "github.com/davidbyttow/govips/v2/vips" "github.com/georgysavva/scany/v2/pgxscan" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/minio/minio-go/v7" "github.com/rs/xid" ) @@ -51,6 +53,10 @@ const ( MaxPrideFlagDescLength = 200 ) +const ( + ErrInvalidFlagID = errors.Sentinel("invalid flag ID") +) + func (db *DB) AccountFlags(ctx context.Context, userID xid.ID) (fs []PrideFlag, err error) { sql, args, err := sq.Select("*").From("pride_flags").Where("user_id = ?", userID).OrderBy("id").ToSql() if err != nil { @@ -67,7 +73,7 @@ func (db *DB) AccountFlags(ctx context.Context, userID xid.ID) (fs []PrideFlag, func (db *DB) UserFlags(ctx context.Context, userID xid.ID) (fs []UserFlag, err error) { sql, args, err := sq.Select("u.id", "u.flag_id", "f.user_id", "f.hash", "f.name", "f.description"). From("user_flags AS u"). - Where("u.user_id = $1"). + Where("u.user_id = $1", userID). Join("pride_flags AS f ON u.flag_id = f.id"). OrderBy("u.id ASC"). ToSql() @@ -82,10 +88,10 @@ func (db *DB) UserFlags(ctx context.Context, userID xid.ID) (fs []UserFlag, err return NotNull(fs), nil } -func (db *DB) MemberFlags(ctx context.Context, userID xid.ID) (fs []MemberFlag, err error) { +func (db *DB) MemberFlags(ctx context.Context, memberID xid.ID) (fs []MemberFlag, err error) { sql, args, err := sq.Select("m.id", "m.flag_id", "m.member_id", "f.hash", "f.name", "f.description"). From("member_flags AS m"). - Where("m.member_id = $1"). + Where("m.member_id = $1", memberID). Join("pride_flags AS f ON m.flag_id = f.id"). OrderBy("m.id ASC"). ToSql() @@ -100,6 +106,68 @@ func (db *DB) MemberFlags(ctx context.Context, userID xid.ID) (fs []MemberFlag, return NotNull(fs), nil } +func (db *DB) SetUserFlags(ctx context.Context, tx pgx.Tx, userID xid.ID, flags []xid.ID) (err error) { + sql, args, err := sq.Delete("user_flags").Where("user_id = ?", userID).ToSql() + if err != nil { + return errors.Wrap(err, "building sql") + } + + _, err = tx.Exec(ctx, sql, args...) + if err != nil { + return errors.Wrap(err, "deleting existing flags") + } + + n, err := tx.CopyFrom(ctx, pgx.Identifier{"user_flags"}, []string{"user_id", "flag_id"}, + pgx.CopyFromSlice(len(flags), func(i int) ([]any, error) { + return []any{userID, flags[i]}, nil + })) + if err != nil { + pge := &pgconn.PgError{} + if errors.As(err, &pge) { + if pge.Code == foreignKeyViolation { + return ErrInvalidFlagID + } + } + + return errors.Wrap(err, "copying new flags") + } + if n > 0 { + log.Debugf("set %v flags for user %v", n, userID) + } + return nil +} + +func (db *DB) SetMemberFlags(ctx context.Context, tx pgx.Tx, memberID xid.ID, flags []xid.ID) (err error) { + sql, args, err := sq.Delete("member_flags").Where("member_id = ?", memberID).ToSql() + if err != nil { + return errors.Wrap(err, "building sql") + } + + _, err = tx.Exec(ctx, sql, args...) + if err != nil { + return errors.Wrap(err, "deleting existing flags") + } + + n, err := tx.CopyFrom(ctx, pgx.Identifier{"member_flags"}, []string{"member_id", "flag_id"}, + pgx.CopyFromSlice(len(flags), func(i int) ([]any, error) { + return []any{memberID, flags[i]}, nil + })) + if err != nil { + pge := &pgconn.PgError{} + if errors.As(err, &pge) { + if pge.Code == foreignKeyViolation { + return ErrInvalidFlagID + } + } + + return errors.Wrap(err, "copying new flags") + } + if n > 0 { + log.Debugf("set %v flags for member %v", n, memberID) + } + return nil +} + func (db *DB) CreateFlag(ctx context.Context, tx pgx.Tx, userID xid.ID, name, desc string) (f PrideFlag, err error) { description := &desc if desc == "" { diff --git a/backend/db/member.go b/backend/db/member.go index 2cb377a..43682c5 100644 --- a/backend/db/member.go +++ b/backend/db/member.go @@ -116,7 +116,7 @@ func (db *DB) CreateMember( pge := &pgconn.PgError{} if errors.As(err, &pge) { // unique constraint violation - if pge.Code == "23505" { + if pge.Code == uniqueViolation { return m, ErrMemberNameInUse } } @@ -223,7 +223,7 @@ func (db *DB) UpdateMember( if err != nil { pge := &pgconn.PgError{} if errors.As(err, &pge) { - if pge.Code == "23505" { + if pge.Code == uniqueViolation { return m, ErrMemberNameInUse } } diff --git a/backend/db/user.go b/backend/db/user.go index 63b8173..0370088 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -171,7 +171,7 @@ func (db *DB) CreateUser(ctx context.Context, tx pgx.Tx, username string) (u Use pge := &pgconn.PgError{} if errors.As(err, &pge) { // unique constraint violation - if pge.Code == "23505" { + if pge.Code == uniqueViolation { return u, ErrUsernameTaken } } @@ -494,7 +494,7 @@ func (db *DB) UpdateUsername(ctx context.Context, tx pgx.Tx, id xid.ID, newName pge := &pgconn.PgError{} if errors.As(err, &pge) { // unique constraint violation - if pge.Code == "23505" { + if pge.Code == uniqueViolation { return ErrUsernameTaken } } diff --git a/backend/routes/member/create_member.go b/backend/routes/member/create_member.go index 1156674..b2d4198 100644 --- a/backend/routes/member/create_member.go +++ b/backend/routes/member/create_member.go @@ -188,7 +188,7 @@ func (s *Server) createMember(w http.ResponseWriter, r *http.Request) (err error return errors.Wrap(err, "committing transaction") } - render.JSON(w, r, dbMemberToMember(u, m, cmr.Fields, true)) + render.JSON(w, r, dbMemberToMember(u, m, cmr.Fields, nil, true)) return nil } diff --git a/backend/routes/member/get_member.go b/backend/routes/member/get_member.go index c109b7f..d92aa63 100644 --- a/backend/routes/member/get_member.go +++ b/backend/routes/member/get_member.go @@ -23,13 +23,14 @@ type GetMemberResponse struct { Names []db.FieldEntry `json:"names"` Pronouns []db.PronounEntry `json:"pronouns"` Fields []db.Field `json:"fields"` + Flags []db.MemberFlag `json:"flags"` User PartialUser `json:"user"` Unlisted *bool `json:"unlisted,omitempty"` } -func dbMemberToMember(u db.User, m db.Member, fields []db.Field, isOwnMember bool) GetMemberResponse { +func dbMemberToMember(u db.User, m db.Member, fields []db.Field, flags []db.MemberFlag, isOwnMember bool) GetMemberResponse { r := GetMemberResponse{ ID: m.ID, Name: m.Name, @@ -41,6 +42,7 @@ func dbMemberToMember(u db.User, m db.Member, fields []db.Field, isOwnMember boo Names: db.NotNull(m.Names), Pronouns: db.NotNull(m.Pronouns), Fields: db.NotNull(fields), + Flags: flags, User: PartialUser{ ID: u.ID, @@ -102,7 +104,12 @@ func (s *Server) getMember(w http.ResponseWriter, r *http.Request) error { return err } - render.JSON(w, r, dbMemberToMember(u, m, fields, isOwnMember)) + flags, err := s.DB.MemberFlags(ctx, m.ID) + if err != nil { + return err + } + + render.JSON(w, r, dbMemberToMember(u, m, fields, flags, isOwnMember)) return nil } @@ -137,7 +144,12 @@ func (s *Server) getUserMember(w http.ResponseWriter, r *http.Request) error { return err } - render.JSON(w, r, dbMemberToMember(u, m, fields, isOwnMember)) + flags, err := s.DB.MemberFlags(ctx, m.ID) + if err != nil { + return err + } + + render.JSON(w, r, dbMemberToMember(u, m, fields, flags, isOwnMember)) return nil } diff --git a/backend/routes/member/patch_member.go b/backend/routes/member/patch_member.go index a6d0881..8e1d7af 100644 --- a/backend/routes/member/patch_member.go +++ b/backend/routes/member/patch_member.go @@ -284,6 +284,6 @@ func (s *Server) patchMember(w http.ResponseWriter, r *http.Request) error { } // echo the updated member back on success - render.JSON(w, r, dbMemberToMember(u, m, fields, true)) + render.JSON(w, r, dbMemberToMember(u, m, fields, nil, true)) return nil } diff --git a/backend/routes/user/get_user.go b/backend/routes/user/get_user.go index 2466e2e..c04f48d 100644 --- a/backend/routes/user/get_user.go +++ b/backend/routes/user/get_user.go @@ -25,6 +25,7 @@ type GetUserResponse struct { Members []PartialMember `json:"members"` Fields []db.Field `json:"fields"` CustomPreferences db.CustomPreferences `json:"custom_preferences"` + Flags []db.UserFlag `json:"flags"` } type GetMeResponse struct { @@ -61,7 +62,7 @@ type PartialMember struct { Pronouns []db.PronounEntry `json:"pronouns"` } -func dbUserToResponse(u db.User, fields []db.Field, members []db.Member) GetUserResponse { +func dbUserToResponse(u db.User, fields []db.Field, members []db.Member, flags []db.UserFlag) GetUserResponse { resp := GetUserResponse{ ID: u.ID, Username: u.Username, @@ -74,6 +75,7 @@ func dbUserToResponse(u db.User, fields []db.Field, members []db.Member) GetUser Pronouns: db.NotNull(u.Pronouns), Fields: db.NotNull(fields), CustomPreferences: u.CustomPreferences, + Flags: flags, } resp.Members = make([]PartialMember, len(members)) @@ -93,56 +95,29 @@ func dbUserToResponse(u db.User, fields []db.Field, members []db.Member) GetUser return resp } -func (s *Server) getUser(w http.ResponseWriter, r *http.Request) error { +func (s *Server) getUser(w http.ResponseWriter, r *http.Request) (err error) { ctx := r.Context() userRef := chi.URLParamFromCtx(ctx, "userRef") + var u db.User if id, err := xid.FromString(userRef); err == nil { - u, err := s.DB.User(ctx, id) - if err == nil { - if u.DeletedAt != nil { - return server.APIError{Code: server.ErrUserNotFound} - } - - isSelf := false - if claims, ok := server.ClaimsFromContext(ctx); ok && claims.UserID == u.ID { - isSelf = true - } - - fields, err := s.DB.UserFields(ctx, u.ID) - if err != nil { - log.Errorf("Error getting user fields: %v", err) - return err - } - - var members []db.Member - if !u.ListPrivate || isSelf { - members, err = s.DB.UserMembers(ctx, u.ID, isSelf) - if err != nil { - log.Errorf("Error getting user members: %v", err) - return err - } - } - - render.JSON(w, r, dbUserToResponse(u, fields, members)) - return nil - } else if err != db.ErrUserNotFound { - log.Errorf("Error getting user by ID: %v", err) - return err + u, err = s.DB.User(ctx, id) + if err != nil { + log.Errorf("getting user by ID: %v", err) } - // otherwise, we fall back to checking usernames } - u, err := s.DB.Username(ctx, userRef) - if err == db.ErrUserNotFound { - return server.APIError{ - Code: server.ErrUserNotFound, + if u.ID.IsNil() { + u, err = s.DB.Username(ctx, userRef) + if err == db.ErrUserNotFound { + return server.APIError{ + Code: server.ErrUserNotFound, + } + } else if err != nil { + log.Errorf("Error getting user by username: %v", err) + return err } - - } else if err != nil { - log.Errorf("Error getting user by username: %v", err) - return err } if u.DeletedAt != nil { @@ -160,6 +135,12 @@ func (s *Server) getUser(w http.ResponseWriter, r *http.Request) error { return err } + flags, err := s.DB.UserFlags(ctx, u.ID) + if err != nil { + log.Errorf("getting user flags: %v", err) + return err + } + var members []db.Member if !u.ListPrivate || isSelf { members, err = s.DB.UserMembers(ctx, u.ID, isSelf) @@ -169,7 +150,7 @@ func (s *Server) getUser(w http.ResponseWriter, r *http.Request) error { } } - render.JSON(w, r, dbUserToResponse(u, fields, members)) + render.JSON(w, r, dbUserToResponse(u, fields, members, flags)) return nil } @@ -195,8 +176,14 @@ func (s *Server) getMeUser(w http.ResponseWriter, r *http.Request) error { return err } + flags, err := s.DB.UserFlags(ctx, u.ID) + if err != nil { + log.Errorf("getting user flags: %v", err) + return err + } + render.JSON(w, r, GetMeResponse{ - GetUserResponse: dbUserToResponse(u, fields, members), + GetUserResponse: dbUserToResponse(u, fields, members, flags), CreatedAt: u.ID.Time(), MaxInvites: u.MaxInvites, IsAdmin: u.IsAdmin, diff --git a/backend/routes/user/patch_user.go b/backend/routes/user/patch_user.go index ab07a58..3b015bb 100644 --- a/backend/routes/user/patch_user.go +++ b/backend/routes/user/patch_user.go @@ -11,6 +11,7 @@ import ( "emperror.dev/errors" "github.com/go-chi/render" "github.com/google/uuid" + "github.com/rs/xid" ) type PatchUserRequest struct { @@ -25,6 +26,7 @@ type PatchUserRequest struct { Avatar *string `json:"avatar"` ListPrivate *bool `json:"list_private"` CustomPreferences *db.CustomPreferences `json:"custom_preferences"` + Flags *[]xid.ID `json:"flags"` } // patchUser parses a PatchUserRequest and updates the user with the given ID. @@ -60,7 +62,8 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { req.Names == nil && req.Pronouns == nil && req.Avatar == nil && - req.CustomPreferences == nil { + req.CustomPreferences == nil && + req.Flags == nil { return server.APIError{ Code: server.ErrBadRequest, Details: "Data must not be empty", @@ -252,6 +255,19 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { } } + // update flags + if req.Flags != nil { + err = s.DB.SetUserFlags(ctx, tx, claims.UserID, *req.Flags) + if err != nil { + if err == db.ErrInvalidFlagID { + return server.APIError{Code: server.ErrBadRequest, Details: "One or more flag IDs are unknown"} + } + + log.Errorf("updating flags for user %v: %v", claims.UserID, err) + return err + } + } + // update last active time err = s.DB.UpdateActiveTime(ctx, tx, claims.UserID) if err != nil { @@ -274,9 +290,16 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { } } + // get flags to return (we need to return full flag objects, not the array of IDs in the request body) + flags, err := s.DB.UserFlags(ctx, u.ID) + if err != nil { + log.Errorf("getting user flags: %v", err) + return err + } + // echo the updated user back on success render.JSON(w, r, GetMeResponse{ - GetUserResponse: dbUserToResponse(u, fields, nil), + GetUserResponse: dbUserToResponse(u, fields, nil, flags), MaxInvites: u.MaxInvites, IsAdmin: u.IsAdmin, ListPrivate: u.ListPrivate,