feat(api): add PATCH /users/@me, remove PATCH /users/@me/fields
This commit is contained in:
		
							parent
							
								
									f4a6296d13
								
							
						
					
					
						commit
						57c7a0f4de
					
				
					 6 changed files with 188 additions and 67 deletions
				
			
		| 
						 | 
					@ -14,6 +14,8 @@ import (
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var sq = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
 | 
					var sq = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const ErrNothingToUpdate = errors.Sentinel("nothing to update")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type DB struct {
 | 
					type DB struct {
 | 
				
			||||||
	*pgxpool.Pool
 | 
						*pgxpool.Pool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -91,13 +91,7 @@ func (db *DB) UserFields(ctx context.Context, id xid.ID) (fs []Field, err error)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SetUserFields updates the fields for the given user.
 | 
					// SetUserFields updates the fields for the given user.
 | 
				
			||||||
func (db *DB) SetUserFields(ctx context.Context, userID xid.ID, fields []Field) (err error) {
 | 
					func (db *DB) SetUserFields(ctx context.Context, tx pgx.Tx, userID xid.ID, fields []Field) (err error) {
 | 
				
			||||||
	tx, err := db.Begin(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return errors.Wrap(err, "building sql")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	defer tx.Rollback(ctx)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sql, args, err := sq.Delete("user_fields").Where("user_id = ?", userID).ToSql()
 | 
						sql, args, err := sq.Delete("user_fields").Where("user_id = ?", userID).ToSql()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errors.Wrap(err, "building sql")
 | 
							return errors.Wrap(err, "building sql")
 | 
				
			||||||
| 
						 | 
					@ -124,10 +118,5 @@ func (db *DB) SetUserFields(ctx context.Context, userID xid.ID, fields []Field)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errors.Wrap(err, "inserting new fields")
 | 
							return errors.Wrap(err, "inserting new fields")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = tx.Commit(ctx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return errors.Wrap(err, "committing transaction")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -38,6 +38,14 @@ const (
 | 
				
			||||||
	ErrUsernameTooLong  = errors.Sentinel("username is too long")
 | 
						ErrUsernameTooLong  = errors.Sentinel("username is too long")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						MaxUsernameLength    = 40
 | 
				
			||||||
 | 
						MaxDisplayNameLength = 100
 | 
				
			||||||
 | 
						MaxUserBioLength     = 1000
 | 
				
			||||||
 | 
						MaxUserLinksLength   = 25
 | 
				
			||||||
 | 
						MaxLinkLength        = 256
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CreateUser creates a user with the given username.
 | 
					// CreateUser creates a user with the given username.
 | 
				
			||||||
func (db *DB) CreateUser(ctx context.Context, username string) (u User, err error) {
 | 
					func (db *DB) CreateUser(ctx context.Context, username string) (u User, err error) {
 | 
				
			||||||
	// check if the username is valid
 | 
						// check if the username is valid
 | 
				
			||||||
| 
						 | 
					@ -146,3 +154,49 @@ func (db *DB) UsernameTaken(ctx context.Context, username string) (valid, taken
 | 
				
			||||||
	err = db.QueryRow(ctx, "select exists (select id from users where username = $1)", username).Scan(&taken)
 | 
						err = db.QueryRow(ctx, "select exists (select id from users where username = $1)", username).Scan(&taken)
 | 
				
			||||||
	return true, taken, err
 | 
						return true, taken, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (db *DB) UpdateUser(
 | 
				
			||||||
 | 
						ctx context.Context,
 | 
				
			||||||
 | 
						tx pgx.Tx, id xid.ID,
 | 
				
			||||||
 | 
						displayName, bio *string,
 | 
				
			||||||
 | 
						links *[]string,
 | 
				
			||||||
 | 
					) (u User, err error) {
 | 
				
			||||||
 | 
						if displayName == nil && bio == nil && links == nil {
 | 
				
			||||||
 | 
							return u, ErrNothingToUpdate
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						builder := sq.Update("users").Where("id = ?", id)
 | 
				
			||||||
 | 
						if displayName != nil {
 | 
				
			||||||
 | 
							if *displayName == "" {
 | 
				
			||||||
 | 
								builder = builder.Set("display_name", nil)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								builder = builder.Set("display_name", *displayName)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if bio != nil {
 | 
				
			||||||
 | 
							if *bio == "" {
 | 
				
			||||||
 | 
								builder = builder.Set("bio", nil)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								builder = builder.Set("bio", *bio)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if links != nil {
 | 
				
			||||||
 | 
							if len(*links) == 0 {
 | 
				
			||||||
 | 
								builder = builder.Set("links", nil)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								builder = builder.Set("links", *links)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sql, args, err := builder.Suffix("RETURNING *").ToSql()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return u, errors.Wrap(err, "building sql")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = pgxscan.Get(ctx, tx, &u, sql, args...)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return u, errors.Wrap(err, "executing sql")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return u, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,54 +0,0 @@
 | 
				
			||||||
package user
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"codeberg.org/u1f320/pronouns.cc/backend/db"
 | 
					 | 
				
			||||||
	"codeberg.org/u1f320/pronouns.cc/backend/log"
 | 
					 | 
				
			||||||
	"codeberg.org/u1f320/pronouns.cc/backend/server"
 | 
					 | 
				
			||||||
	"github.com/go-chi/render"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type PatchFieldsRequest struct {
 | 
					 | 
				
			||||||
	Fields []db.Field `json:"fields"`
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s *Server) patchUserFields(w http.ResponseWriter, r *http.Request) error {
 | 
					 | 
				
			||||||
	ctx := r.Context()
 | 
					 | 
				
			||||||
	claims, _ := server.ClaimsFromContext(ctx)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var req PatchFieldsRequest
 | 
					 | 
				
			||||||
	err := render.Decode(r, &req)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return server.APIError{Code: server.ErrBadRequest}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// max 25 fields
 | 
					 | 
				
			||||||
	if len(req.Fields) > db.MaxFields {
 | 
					 | 
				
			||||||
		return server.APIError{
 | 
					 | 
				
			||||||
			Code:    server.ErrBadRequest,
 | 
					 | 
				
			||||||
			Details: fmt.Sprintf("Too many fields (max %d, current %d)", db.MaxFields, len(req.Fields)),
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// validate all fields
 | 
					 | 
				
			||||||
	for i, field := range req.Fields {
 | 
					 | 
				
			||||||
		if s := field.Validate(); s != "" {
 | 
					 | 
				
			||||||
			return server.APIError{
 | 
					 | 
				
			||||||
				Code:    server.ErrBadRequest,
 | 
					 | 
				
			||||||
				Details: fmt.Sprintf("field %d: %s", i, s),
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = s.DB.SetUserFields(ctx, claims.UserID, req.Fields)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		log.Errorf("setting fields for user %v: %v", claims.UserID, err)
 | 
					 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// echo the fields back on success
 | 
					 | 
				
			||||||
	render.JSON(w, r, req)
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										130
									
								
								backend/routes/user/patch_user.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								backend/routes/user/patch_user.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,130 @@
 | 
				
			||||||
 | 
					package user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"codeberg.org/u1f320/pronouns.cc/backend/db"
 | 
				
			||||||
 | 
						"codeberg.org/u1f320/pronouns.cc/backend/log"
 | 
				
			||||||
 | 
						"codeberg.org/u1f320/pronouns.cc/backend/server"
 | 
				
			||||||
 | 
						"emperror.dev/errors"
 | 
				
			||||||
 | 
						"github.com/go-chi/render"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PatchUserRequest struct {
 | 
				
			||||||
 | 
						DisplayName *string     `json:"display_name"`
 | 
				
			||||||
 | 
						Bio         *string     `json:"bio"`
 | 
				
			||||||
 | 
						Links       *[]string   `json:"links"`
 | 
				
			||||||
 | 
						Fields      *[]db.Field `json:"fields"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// patchUser parses a PatchUserRequest and updates the user with the given ID.
 | 
				
			||||||
 | 
					func (s *Server) patchUser(w http.ResponseWriter, r *http.Request) error {
 | 
				
			||||||
 | 
						ctx := r.Context()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						claims, _ := server.ClaimsFromContext(ctx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var req PatchUserRequest
 | 
				
			||||||
 | 
						err := render.Decode(r, &req)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return server.APIError{Code: server.ErrBadRequest}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// validate that *something* is set
 | 
				
			||||||
 | 
						if req.DisplayName == nil && req.Bio == nil && req.Links == nil && req.Fields == nil {
 | 
				
			||||||
 | 
							return server.APIError{
 | 
				
			||||||
 | 
								Code:    server.ErrBadRequest,
 | 
				
			||||||
 | 
								Details: "Data must not be empty",
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// validate display name/bio
 | 
				
			||||||
 | 
						if req.DisplayName != nil && len(*req.DisplayName) > db.MaxDisplayNameLength {
 | 
				
			||||||
 | 
							return server.APIError{
 | 
				
			||||||
 | 
								Code:    server.ErrBadRequest,
 | 
				
			||||||
 | 
								Details: fmt.Sprintf("Display name too long (max %d, current %d)", db.MaxDisplayNameLength, len(*req.DisplayName)),
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if req.Bio != nil && len(*req.Bio) > db.MaxUserBioLength {
 | 
				
			||||||
 | 
							return server.APIError{
 | 
				
			||||||
 | 
								Code:    server.ErrBadRequest,
 | 
				
			||||||
 | 
								Details: fmt.Sprintf("Bio too long (max %d, current %d)", db.MaxUserBioLength, len(*req.Bio)),
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// validate links
 | 
				
			||||||
 | 
						if req.Links != nil {
 | 
				
			||||||
 | 
							if len(*req.Links) > db.MaxUserLinksLength {
 | 
				
			||||||
 | 
								return server.APIError{
 | 
				
			||||||
 | 
									Code:    server.ErrBadRequest,
 | 
				
			||||||
 | 
									Details: fmt.Sprintf("Too many links (max %d, current %d)", db.MaxUserLinksLength, len(*req.Links)),
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for i, link := range *req.Links {
 | 
				
			||||||
 | 
								if len(link) > db.MaxLinkLength {
 | 
				
			||||||
 | 
									return server.APIError{
 | 
				
			||||||
 | 
										Code:    server.ErrBadRequest,
 | 
				
			||||||
 | 
										Details: fmt.Sprintf("Link %d too long (max %d, current %d)", i, db.MaxLinkLength, len(link)),
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// max 25 fields
 | 
				
			||||||
 | 
						if len(*req.Fields) > db.MaxFields {
 | 
				
			||||||
 | 
							return server.APIError{
 | 
				
			||||||
 | 
								Code:    server.ErrBadRequest,
 | 
				
			||||||
 | 
								Details: fmt.Sprintf("Too many fields (max %d, current %d)", db.MaxFields, len(*req.Fields)),
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// validate all fields
 | 
				
			||||||
 | 
						for i, field := range *req.Fields {
 | 
				
			||||||
 | 
							if s := field.Validate(); s != "" {
 | 
				
			||||||
 | 
								return server.APIError{
 | 
				
			||||||
 | 
									Code:    server.ErrBadRequest,
 | 
				
			||||||
 | 
									Details: fmt.Sprintf("field %d: %s", i, s),
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// start transaction
 | 
				
			||||||
 | 
						tx, err := s.DB.Begin(ctx)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Errorf("creating transaction: %v", err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer tx.Rollback(ctx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						u, err := s.DB.UpdateUser(ctx, tx, claims.UserID, req.DisplayName, req.Bio, req.Links)
 | 
				
			||||||
 | 
						if err != nil && errors.Cause(err) != db.ErrNothingToUpdate {
 | 
				
			||||||
 | 
							log.Errorf("updating user: %v", err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var fields []db.Field
 | 
				
			||||||
 | 
						if req.Fields != nil {
 | 
				
			||||||
 | 
							err = s.DB.SetUserFields(ctx, tx, claims.UserID, *req.Fields)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Errorf("setting fields for user %v: %v", claims.UserID, err)
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							fields, err = s.DB.UserFields(ctx, claims.UserID)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Errorf("getting fields for user %v: %v", claims.UserID, err)
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = tx.Commit(ctx)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Errorf("committing transaction: %v", err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// echo the updated user back on success
 | 
				
			||||||
 | 
						render.JSON(w, r, dbUserToResponse(u, fields))
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -17,7 +17,7 @@ func Mount(srv *server.Server, r chi.Router) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		r.With(server.MustAuth).Group(func(r chi.Router) {
 | 
							r.With(server.MustAuth).Group(func(r chi.Router) {
 | 
				
			||||||
			r.Get("/@me", server.WrapHandler(s.getMeUser))
 | 
								r.Get("/@me", server.WrapHandler(s.getMeUser))
 | 
				
			||||||
			r.Patch("/@me/fields", server.WrapHandler(s.patchUserFields))
 | 
								r.Patch("/@me", server.WrapHandler(s.patchUser))
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue