feat(api): add PATCH /users/@me, remove PATCH /users/@me/fields
This commit is contained in:
		
							parent
							
								
									f4a6296d13
								
							
						
					
					
						commit
						57c7a0f4de
					
				
					 6 changed files with 188 additions and 67 deletions
				
			
		|  | @ -14,6 +14,8 @@ import ( | |||
| 
 | ||||
| var sq = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) | ||||
| 
 | ||||
| const ErrNothingToUpdate = errors.Sentinel("nothing to update") | ||||
| 
 | ||||
| type DB struct { | ||||
| 	*pgxpool.Pool | ||||
| 
 | ||||
|  |  | |||
|  | @ -91,13 +91,7 @@ func (db *DB) UserFields(ctx context.Context, id xid.ID) (fs []Field, err error) | |||
| } | ||||
| 
 | ||||
| // SetUserFields updates the fields for the given user. | ||||
| func (db *DB) SetUserFields(ctx context.Context, userID xid.ID, fields []Field) (err error) { | ||||
| 	tx, err := db.Begin(ctx) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 	defer tx.Rollback(ctx) | ||||
| 
 | ||||
| func (db *DB) SetUserFields(ctx context.Context, tx pgx.Tx, userID xid.ID, fields []Field) (err error) { | ||||
| 	sql, args, err := sq.Delete("user_fields").Where("user_id = ?", userID).ToSql() | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "building sql") | ||||
|  | @ -124,10 +118,5 @@ func (db *DB) SetUserFields(ctx context.Context, userID xid.ID, fields []Field) | |||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "inserting new fields") | ||||
| 	} | ||||
| 
 | ||||
| 	err = tx.Commit(ctx) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "committing transaction") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  |  | |||
|  | @ -38,6 +38,14 @@ const ( | |||
| 	ErrUsernameTooLong  = errors.Sentinel("username is too long") | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	MaxUsernameLength    = 40 | ||||
| 	MaxDisplayNameLength = 100 | ||||
| 	MaxUserBioLength     = 1000 | ||||
| 	MaxUserLinksLength   = 25 | ||||
| 	MaxLinkLength        = 256 | ||||
| ) | ||||
| 
 | ||||
| // CreateUser creates a user with the given username. | ||||
| func (db *DB) CreateUser(ctx context.Context, username string) (u User, err error) { | ||||
| 	// check if the username is valid | ||||
|  | @ -146,3 +154,49 @@ func (db *DB) UsernameTaken(ctx context.Context, username string) (valid, taken | |||
| 	err = db.QueryRow(ctx, "select exists (select id from users where username = $1)", username).Scan(&taken) | ||||
| 	return true, taken, err | ||||
| } | ||||
| 
 | ||||
| func (db *DB) UpdateUser( | ||||
| 	ctx context.Context, | ||||
| 	tx pgx.Tx, id xid.ID, | ||||
| 	displayName, bio *string, | ||||
| 	links *[]string, | ||||
| ) (u User, err error) { | ||||
| 	if displayName == nil && bio == nil && links == nil { | ||||
| 		return u, ErrNothingToUpdate | ||||
| 	} | ||||
| 
 | ||||
| 	builder := sq.Update("users").Where("id = ?", id) | ||||
| 	if displayName != nil { | ||||
| 		if *displayName == "" { | ||||
| 			builder = builder.Set("display_name", nil) | ||||
| 		} else { | ||||
| 			builder = builder.Set("display_name", *displayName) | ||||
| 		} | ||||
| 	} | ||||
| 	if bio != nil { | ||||
| 		if *bio == "" { | ||||
| 			builder = builder.Set("bio", nil) | ||||
| 		} else { | ||||
| 			builder = builder.Set("bio", *bio) | ||||
| 		} | ||||
| 	} | ||||
| 	if links != nil { | ||||
| 		if len(*links) == 0 { | ||||
| 			builder = builder.Set("links", nil) | ||||
| 		} else { | ||||
| 			builder = builder.Set("links", *links) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	sql, args, err := builder.Suffix("RETURNING *").ToSql() | ||||
| 	if err != nil { | ||||
| 		return u, errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 
 | ||||
| 	err = pgxscan.Get(ctx, tx, &u, sql, args...) | ||||
| 	if err != nil { | ||||
| 		return u, errors.Wrap(err, "executing sql") | ||||
| 	} | ||||
| 
 | ||||
| 	return u, nil | ||||
| } | ||||
|  |  | |||
|  | @ -1,54 +0,0 @@ | |||
| package user | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/db" | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/log" | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/server" | ||||
| 	"github.com/go-chi/render" | ||||
| ) | ||||
| 
 | ||||
| type PatchFieldsRequest struct { | ||||
| 	Fields []db.Field `json:"fields"` | ||||
| } | ||||
| 
 | ||||
| func (s *Server) patchUserFields(w http.ResponseWriter, r *http.Request) error { | ||||
| 	ctx := r.Context() | ||||
| 	claims, _ := server.ClaimsFromContext(ctx) | ||||
| 
 | ||||
| 	var req PatchFieldsRequest | ||||
| 	err := render.Decode(r, &req) | ||||
| 	if err != nil { | ||||
| 		return server.APIError{Code: server.ErrBadRequest} | ||||
| 	} | ||||
| 
 | ||||
| 	// max 25 fields | ||||
| 	if len(req.Fields) > db.MaxFields { | ||||
| 		return server.APIError{ | ||||
| 			Code:    server.ErrBadRequest, | ||||
| 			Details: fmt.Sprintf("Too many fields (max %d, current %d)", db.MaxFields, len(req.Fields)), | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// validate all fields | ||||
| 	for i, field := range req.Fields { | ||||
| 		if s := field.Validate(); s != "" { | ||||
| 			return server.APIError{ | ||||
| 				Code:    server.ErrBadRequest, | ||||
| 				Details: fmt.Sprintf("field %d: %s", i, s), | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	err = s.DB.SetUserFields(ctx, claims.UserID, req.Fields) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("setting fields for user %v: %v", claims.UserID, err) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	// echo the fields back on success | ||||
| 	render.JSON(w, r, req) | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										130
									
								
								backend/routes/user/patch_user.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								backend/routes/user/patch_user.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,130 @@ | |||
| package user | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/db" | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/log" | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/server" | ||||
| 	"emperror.dev/errors" | ||||
| 	"github.com/go-chi/render" | ||||
| ) | ||||
| 
 | ||||
| type PatchUserRequest struct { | ||||
| 	DisplayName *string     `json:"display_name"` | ||||
| 	Bio         *string     `json:"bio"` | ||||
| 	Links       *[]string   `json:"links"` | ||||
| 	Fields      *[]db.Field `json:"fields"` | ||||
| } | ||||
| 
 | ||||
| // patchUser parses a PatchUserRequest and updates the user with the given ID. | ||||
| func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error { | ||||
| 	ctx := r.Context() | ||||
| 
 | ||||
| 	claims, _ := server.ClaimsFromContext(ctx) | ||||
| 
 | ||||
| 	var req PatchUserRequest | ||||
| 	err := render.Decode(r, &req) | ||||
| 	if err != nil { | ||||
| 		return server.APIError{Code: server.ErrBadRequest} | ||||
| 	} | ||||
| 
 | ||||
| 	// validate that *something* is set | ||||
| 	if req.DisplayName == nil && req.Bio == nil && req.Links == nil && req.Fields == nil { | ||||
| 		return server.APIError{ | ||||
| 			Code:    server.ErrBadRequest, | ||||
| 			Details: "Data must not be empty", | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// validate display name/bio | ||||
| 	if req.DisplayName != nil && len(*req.DisplayName) > db.MaxDisplayNameLength { | ||||
| 		return server.APIError{ | ||||
| 			Code:    server.ErrBadRequest, | ||||
| 			Details: fmt.Sprintf("Display name too long (max %d, current %d)", db.MaxDisplayNameLength, len(*req.DisplayName)), | ||||
| 		} | ||||
| 	} | ||||
| 	if req.Bio != nil && len(*req.Bio) > db.MaxUserBioLength { | ||||
| 		return server.APIError{ | ||||
| 			Code:    server.ErrBadRequest, | ||||
| 			Details: fmt.Sprintf("Bio too long (max %d, current %d)", db.MaxUserBioLength, len(*req.Bio)), | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// validate links | ||||
| 	if req.Links != nil { | ||||
| 		if len(*req.Links) > db.MaxUserLinksLength { | ||||
| 			return server.APIError{ | ||||
| 				Code:    server.ErrBadRequest, | ||||
| 				Details: fmt.Sprintf("Too many links (max %d, current %d)", db.MaxUserLinksLength, len(*req.Links)), | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		for i, link := range *req.Links { | ||||
| 			if len(link) > db.MaxLinkLength { | ||||
| 				return server.APIError{ | ||||
| 					Code:    server.ErrBadRequest, | ||||
| 					Details: fmt.Sprintf("Link %d too long (max %d, current %d)", i, db.MaxLinkLength, len(link)), | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// max 25 fields | ||||
| 	if len(*req.Fields) > db.MaxFields { | ||||
| 		return server.APIError{ | ||||
| 			Code:    server.ErrBadRequest, | ||||
| 			Details: fmt.Sprintf("Too many fields (max %d, current %d)", db.MaxFields, len(*req.Fields)), | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// validate all fields | ||||
| 	for i, field := range *req.Fields { | ||||
| 		if s := field.Validate(); s != "" { | ||||
| 			return server.APIError{ | ||||
| 				Code:    server.ErrBadRequest, | ||||
| 				Details: fmt.Sprintf("field %d: %s", i, s), | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// start transaction | ||||
| 	tx, err := s.DB.Begin(ctx) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("creating transaction: %v", err) | ||||
| 		return err | ||||
| 	} | ||||
| 	defer tx.Rollback(ctx) | ||||
| 
 | ||||
| 	u, err := s.DB.UpdateUser(ctx, tx, claims.UserID, req.DisplayName, req.Bio, req.Links) | ||||
| 	if err != nil && errors.Cause(err) != db.ErrNothingToUpdate { | ||||
| 		log.Errorf("updating user: %v", err) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	var fields []db.Field | ||||
| 	if req.Fields != nil { | ||||
| 		err = s.DB.SetUserFields(ctx, tx, claims.UserID, *req.Fields) | ||||
| 		if err != nil { | ||||
| 			log.Errorf("setting fields for user %v: %v", claims.UserID, err) | ||||
| 			return err | ||||
| 		} | ||||
| 	} else { | ||||
| 		fields, err = s.DB.UserFields(ctx, claims.UserID) | ||||
| 		if err != nil { | ||||
| 			log.Errorf("getting fields for user %v: %v", claims.UserID, err) | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	err = tx.Commit(ctx) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("committing transaction: %v", err) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	// echo the updated user back on success | ||||
| 	render.JSON(w, r, dbUserToResponse(u, fields)) | ||||
| 	return nil | ||||
| } | ||||
|  | @ -17,7 +17,7 @@ func Mount(srv *server.Server, r chi.Router) { | |||
| 
 | ||||
| 		r.With(server.MustAuth).Group(func(r chi.Router) { | ||||
| 			r.Get("/@me", server.WrapHandler(s.getMeUser)) | ||||
| 			r.Patch("/@me/fields", server.WrapHandler(s.patchUserFields)) | ||||
| 			r.Patch("/@me", server.WrapHandler(s.patchUser)) | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue