feat: add user names/pronouns to GET /users/{userRef} and PATCH /users/@me
This commit is contained in:
		
							parent
							
								
									639b3373e5
								
							
						
					
					
						commit
						9e98b61472
					
				
					 3 changed files with 271 additions and 27 deletions
				
			
		
							
								
								
									
										172
									
								
								backend/db/names_pronouns.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								backend/db/names_pronouns.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,172 @@ | |||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"emperror.dev/errors" | ||||
| 	"github.com/georgysavva/scany/pgxscan" | ||||
| 	"github.com/jackc/pgx/v4" | ||||
| 	"github.com/rs/xid" | ||||
| ) | ||||
| 
 | ||||
| type WordStatus int | ||||
| 
 | ||||
| const ( | ||||
| 	StatusUnknown     WordStatus = 0 | ||||
| 	StatusFavourite   WordStatus = 1 | ||||
| 	StatusOkay        WordStatus = 2 | ||||
| 	StatusJokingly    WordStatus = 3 | ||||
| 	StatusFriendsOnly WordStatus = 4 | ||||
| 	StatusAvoid       WordStatus = 5 | ||||
| 	wordStatusMax     WordStatus = 6 | ||||
| ) | ||||
| 
 | ||||
| type Name struct { | ||||
| 	ID     int64      `json:"-"` | ||||
| 	Name   string     `json:"name"` | ||||
| 	Status WordStatus `json:"status"` | ||||
| } | ||||
| 
 | ||||
| func (n Name) Validate() string { | ||||
| 	if n.Name == "" { | ||||
| 		return "name cannot be empty" | ||||
| 	} | ||||
| 
 | ||||
| 	if len([]rune(n.Name)) > FieldEntryMaxLength { | ||||
| 		return fmt.Sprintf("name must be %d characters or less, is %d", FieldEntryMaxLength, len([]rune(n.Name))) | ||||
| 	} | ||||
| 
 | ||||
| 	if n.Status == StatusUnknown || n.Status >= wordStatusMax { | ||||
| 		return fmt.Sprintf("status is invalid, must be between 1 and %d, is %d", wordStatusMax-1, n.Status) | ||||
| 	} | ||||
| 
 | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| type Pronoun struct { | ||||
| 	ID          int64      `json:"-"` | ||||
| 	DisplayText *string    `json:"display_text"` | ||||
| 	Pronouns    string     `json:"pronouns"` | ||||
| 	Status      WordStatus `json:"status"` | ||||
| } | ||||
| 
 | ||||
| func (p Pronoun) Validate() string { | ||||
| 	if p.Pronouns == "" { | ||||
| 		return "pronouns cannot be empty" | ||||
| 	} | ||||
| 
 | ||||
| 	if p.DisplayText != nil { | ||||
| 		if len([]rune(*p.DisplayText)) > FieldEntryMaxLength { | ||||
| 			return fmt.Sprintf("display_text must be %d characters or less, is %d", FieldEntryMaxLength, len([]rune(*p.DisplayText))) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if len([]rune(p.Pronouns)) > FieldEntryMaxLength { | ||||
| 		return fmt.Sprintf("pronouns must be %d characters or less, is %d", FieldEntryMaxLength, len([]rune(p.Pronouns))) | ||||
| 	} | ||||
| 
 | ||||
| 	if p.Status == StatusUnknown || p.Status >= wordStatusMax { | ||||
| 		return fmt.Sprintf("status is invalid, must be between 1 and %d, is %d", wordStatusMax-1, p.Status) | ||||
| 	} | ||||
| 
 | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func (p Pronoun) String() string { | ||||
| 	if p.DisplayText != nil { | ||||
| 		return *p.DisplayText | ||||
| 	} | ||||
| 
 | ||||
| 	split := strings.Split(p.Pronouns, "/") | ||||
| 	if len(split) <= 2 { | ||||
| 		return strings.Join(split, "/") | ||||
| 	} | ||||
| 
 | ||||
| 	return strings.Join(split[:1], "/") | ||||
| } | ||||
| 
 | ||||
| func (db *DB) UserNames(ctx context.Context, userID xid.ID) (ns []Name, err error) { | ||||
| 	sql, args, err := sq.Select("id", "name", "status").From("user_names").Where("user_id = ?", userID).OrderBy("id").ToSql() | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 
 | ||||
| 	err = pgxscan.Select(ctx, db, &ns, sql, args...) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "executing query") | ||||
| 	} | ||||
| 	return ns, nil | ||||
| } | ||||
| 
 | ||||
| func (db *DB) UserPronouns(ctx context.Context, userID xid.ID) (ps []Pronoun, err error) { | ||||
| 	sql, args, err := sq. | ||||
| 		Select("id", "display_text", "pronouns", "status"). | ||||
| 		From("user_pronouns").Where("user_id = ?", userID). | ||||
| 		OrderBy("id").ToSql() | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 
 | ||||
| 	err = pgxscan.Select(ctx, db, &ps, sql, args...) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "executing query") | ||||
| 	} | ||||
| 	return ps, nil | ||||
| } | ||||
| 
 | ||||
| func (db *DB) SetUserNames(ctx context.Context, tx pgx.Tx, userID xid.ID, names []Name) (err error) { | ||||
| 	sql, args, err := sq.Delete("user_names").Where("user_id = ?", userID).ToSql() | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = tx.Exec(ctx, sql, args...) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "deleting existing names") | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = tx.CopyFrom(ctx, | ||||
| 		pgx.Identifier{"user_names"}, | ||||
| 		[]string{"user_id", "name", "status"}, | ||||
| 		pgx.CopyFromSlice(len(names), func(i int) ([]any, error) { | ||||
| 			return []any{ | ||||
| 				userID, | ||||
| 				names[i].Name, | ||||
| 				names[i].Status, | ||||
| 			}, nil | ||||
| 		})) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "inserting new names") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (db *DB) SetUserPronouns(ctx context.Context, tx pgx.Tx, userID xid.ID, names []Pronoun) (err error) { | ||||
| 	sql, args, err := sq.Delete("user_pronouns").Where("user_id = ?", userID).ToSql() | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = tx.Exec(ctx, sql, args...) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "deleting existing pronouns") | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = tx.CopyFrom(ctx, | ||||
| 		pgx.Identifier{"user_pronouns"}, | ||||
| 		[]string{"user_id", "pronouns", "display_text", "status"}, | ||||
| 		pgx.CopyFromSlice(len(names), func(i int) ([]any, error) { | ||||
| 			return []any{ | ||||
| 				userID, | ||||
| 				names[i].Pronouns, | ||||
| 				names[i].DisplayText, | ||||
| 				names[i].Status, | ||||
| 			}, nil | ||||
| 		})) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "inserting new pronouns") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | @ -18,6 +18,8 @@ type GetUserResponse struct { | |||
| 	Bio         *string         `json:"bio"` | ||||
| 	AvatarURL   *string         `json:"avatar_url"` | ||||
| 	Links       []string        `json:"links"` | ||||
| 	Names       []db.Name       `json:"names"` | ||||
| 	Pronouns    []db.Pronoun    `json:"pronouns"` | ||||
| 	Members     []PartialMember `json:"members"` | ||||
| 	Fields      []db.Field      `json:"fields"` | ||||
| } | ||||
|  | @ -35,7 +37,7 @@ type PartialMember struct { | |||
| 	AvatarURL *string `json:"avatar_url"` | ||||
| } | ||||
| 
 | ||||
| func dbUserToResponse(u db.User, fields []db.Field) GetUserResponse { | ||||
| func dbUserToResponse(u db.User, fields []db.Field, names []db.Name, pronouns []db.Pronoun) GetUserResponse { | ||||
| 	return GetUserResponse{ | ||||
| 		ID:          u.ID, | ||||
| 		Username:    u.Username, | ||||
|  | @ -43,6 +45,8 @@ func dbUserToResponse(u db.User, fields []db.Field) GetUserResponse { | |||
| 		Bio:         u.Bio, | ||||
| 		AvatarURL:   u.AvatarURL, | ||||
| 		Links:       u.Links, | ||||
| 		Names:       names, | ||||
| 		Pronouns:    pronouns, | ||||
| 		Fields:      fields, | ||||
| 	} | ||||
| } | ||||
|  | @ -61,7 +65,19 @@ func (s *Server) getUser(w http.ResponseWriter, r *http.Request) error { | |||
| 				return err | ||||
| 			} | ||||
| 
 | ||||
| 			render.JSON(w, r, dbUserToResponse(u, fields)) | ||||
| 			names, err := s.DB.UserNames(ctx, u.ID) | ||||
| 			if err != nil { | ||||
| 				log.Errorf("getting user names: %v", err) | ||||
| 				return err | ||||
| 			} | ||||
| 
 | ||||
| 			pronouns, err := s.DB.UserPronouns(ctx, u.ID) | ||||
| 			if err != nil { | ||||
| 				log.Errorf("getting user pronouns: %v", err) | ||||
| 				return err | ||||
| 			} | ||||
| 
 | ||||
| 			render.JSON(w, r, dbUserToResponse(u, fields, names, pronouns)) | ||||
| 			return nil | ||||
| 		} else if err != db.ErrUserNotFound { | ||||
| 			log.Errorf("Error getting user by ID: %v", err) | ||||
|  | @ -81,13 +97,25 @@ func (s *Server) getUser(w http.ResponseWriter, r *http.Request) error { | |||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	names, err := s.DB.UserNames(ctx, u.ID) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("getting user names: %v", err) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	pronouns, err := s.DB.UserPronouns(ctx, u.ID) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("getting user pronouns: %v", err) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	fields, err := s.DB.UserFields(ctx, u.ID) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("Error getting user fields: %v", err) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	render.JSON(w, r, dbUserToResponse(u, fields)) | ||||
| 	render.JSON(w, r, dbUserToResponse(u, fields, names, pronouns)) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
|  | @ -101,6 +129,18 @@ func (s *Server) getMeUser(w http.ResponseWriter, r *http.Request) error { | |||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	names, err := s.DB.UserNames(ctx, u.ID) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("getting user names: %v", err) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	pronouns, err := s.DB.UserPronouns(ctx, u.ID) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("getting user pronouns: %v", err) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	fields, err := s.DB.UserFields(ctx, u.ID) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("Error getting user fields: %v", err) | ||||
|  | @ -108,7 +148,7 @@ func (s *Server) getMeUser(w http.ResponseWriter, r *http.Request) error { | |||
| 	} | ||||
| 
 | ||||
| 	render.JSON(w, r, GetMeResponse{ | ||||
| 		GetUserResponse: dbUserToResponse(u, fields), | ||||
| 		GetUserResponse: dbUserToResponse(u, fields, names, pronouns), | ||||
| 		Discord:         u.Discord, | ||||
| 		DiscordUsername: u.DiscordUsername, | ||||
| 	}) | ||||
|  |  | |||
|  | @ -12,13 +12,16 @@ import ( | |||
| ) | ||||
| 
 | ||||
| type PatchUserRequest struct { | ||||
| 	DisplayName *string     `json:"display_name"` | ||||
| 	Bio         *string     `json:"bio"` | ||||
| 	Links       *[]string   `json:"links"` | ||||
| 	Fields      *[]db.Field `json:"fields"` | ||||
| 	DisplayName *string       `json:"display_name"` | ||||
| 	Bio         *string       `json:"bio"` | ||||
| 	Links       *[]string     `json:"links"` | ||||
| 	Names       *[]db.Name    `json:"names"` | ||||
| 	Pronouns    *[]db.Pronoun `json:"pronouns"` | ||||
| 	Fields      *[]db.Field   `json:"fields"` | ||||
| } | ||||
| 
 | ||||
| // patchUser parses a PatchUserRequest and updates the user with the given ID. | ||||
| // TODO: could this be refactored to be less repetitive? names, pronouns, and fields are all validated in the same way | ||||
| func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { | ||||
| 	ctx := r.Context() | ||||
| 
 | ||||
|  | @ -71,24 +74,16 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if (req.Fields) != nil { | ||||
| 		// max 25 fields | ||||
| 		if len(*req.Fields) > db.MaxFields { | ||||
| 			return server.APIError{ | ||||
| 				Code:    server.ErrBadRequest, | ||||
| 				Details: fmt.Sprintf("Too many fields (max %d, current %d)", db.MaxFields, len(*req.Fields)), | ||||
| 			} | ||||
| 		} | ||||
| 	if err := validateSlicePtr("name", req.Names); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 		// validate all fields | ||||
| 		for i, field := range *req.Fields { | ||||
| 			if s := field.Validate(); s != "" { | ||||
| 				return server.APIError{ | ||||
| 					Code:    server.ErrBadRequest, | ||||
| 					Details: fmt.Sprintf("field %d: %s", i, s), | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	if err := validateSlicePtr("pronoun", req.Pronouns); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if err := validateSlicePtr("field", req.Fields); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	// start transaction | ||||
|  | @ -105,7 +100,12 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { | |||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	var fields []db.Field | ||||
| 	var ( | ||||
| 		names    []db.Name | ||||
| 		pronouns []db.Pronoun | ||||
| 		fields   []db.Field | ||||
| 	) | ||||
| 
 | ||||
| 	if req.Fields != nil { | ||||
| 		err = s.DB.SetUserFields(ctx, tx, claims.UserID, *req.Fields) | ||||
| 		if err != nil { | ||||
|  | @ -127,6 +127,38 @@ func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { | |||
| 	} | ||||
| 
 | ||||
| 	// echo the updated user back on success | ||||
| 	render.JSON(w, r, dbUserToResponse(u, fields)) | ||||
| 	render.JSON(w, r, dbUserToResponse(u, fields, names, pronouns)) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type validator interface { | ||||
| 	Validate() string | ||||
| } | ||||
| 
 | ||||
| // validateSlicePtr validates a slice of validators. | ||||
| // If the slice is nil, a nil error is returned (assuming that the field is not required) | ||||
| func validateSlicePtr[T validator](typ string, slice *[]T) error { | ||||
| 	if slice == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	// max 25 fields | ||||
| 	if len(*slice) > db.MaxFields { | ||||
| 		return server.APIError{ | ||||
| 			Code:    server.ErrBadRequest, | ||||
| 			Details: fmt.Sprintf("Too many %ss (max %d, current %d)", typ, db.MaxFields, len(*slice)), | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// validate all fields | ||||
| 	for i, pronouns := range *slice { | ||||
| 		if s := pronouns.Validate(); s != "" { | ||||
| 			return server.APIError{ | ||||
| 				Code:    server.ErrBadRequest, | ||||
| 				Details: fmt.Sprintf("%s %d: %s", typ, i, s), | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue