feat(api): add PATCH /users/@me/fields, finish POST /auth/discord/callback
This commit is contained in:
		
							parent
							
								
									020ac15a00
								
							
						
					
					
						commit
						52a03b4aa6
					
				
					 9 changed files with 261 additions and 17 deletions
				
			
		|  | @ -2,12 +2,21 @@ package db | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"emperror.dev/errors" | ||||
| 	"github.com/georgysavva/scany/pgxscan" | ||||
| 	"github.com/jackc/pgx/v4" | ||||
| 	"github.com/rs/xid" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	MaxFields           = 25 | ||||
| 	FieldNameMaxLength  = 100 | ||||
| 	FieldEntriesLimit   = 100 | ||||
| 	FieldEntryMaxLength = 50 | ||||
| ) | ||||
| 
 | ||||
| type Field struct { | ||||
| 	ID          int64    `json:"-"` | ||||
| 	Name        string   `json:"name"` | ||||
|  | @ -18,6 +27,53 @@ type Field struct { | |||
| 	Avoid       []string `json:"avoid"` | ||||
| } | ||||
| 
 | ||||
| // Validate validates this field. If it is invalid, a non-empty string is returned as error message. | ||||
| func (f Field) Validate() string { | ||||
| 	if f.Name == "" { | ||||
| 		return "name cannot be empty" | ||||
| 	} | ||||
| 
 | ||||
| 	if length := len([]rune(f.Name)); length > FieldNameMaxLength { | ||||
| 		return fmt.Sprintf("name max length is %d characters, length is %d", FieldNameMaxLength, length) | ||||
| 	} | ||||
| 
 | ||||
| 	if length := len(f.Favourite) + len(f.Okay) + len(f.Jokingly) + len(f.FriendsOnly) + len(f.Avoid); length > FieldEntriesLimit { | ||||
| 		return fmt.Sprintf("max number of entries is %d, current number is %d", FieldEntriesLimit, length) | ||||
| 	} | ||||
| 
 | ||||
| 	for i, entry := range f.Favourite { | ||||
| 		if length := len([]rune(entry)); length > FieldEntryMaxLength { | ||||
| 			return fmt.Sprintf("favourite.%d: name max length is %d characters, length is %d", i, FieldEntryMaxLength, length) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for i, entry := range f.Okay { | ||||
| 		if length := len([]rune(entry)); length > FieldEntryMaxLength { | ||||
| 			return fmt.Sprintf("okay.%d: name max length is %d characters, length is %d", i, FieldEntryMaxLength, length) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for i, entry := range f.Jokingly { | ||||
| 		if length := len([]rune(entry)); length > FieldEntryMaxLength { | ||||
| 			return fmt.Sprintf("jokingly.%d: name max length is %d characters, length is %d", i, FieldEntryMaxLength, length) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for i, entry := range f.FriendsOnly { | ||||
| 		if length := len([]rune(entry)); length > FieldEntryMaxLength { | ||||
| 			return fmt.Sprintf("friends_only.%d: name max length is %d characters, length is %d", i, FieldEntryMaxLength, length) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for i, entry := range f.Avoid { | ||||
| 		if length := len([]rune(entry)); length > FieldEntryMaxLength { | ||||
| 			return fmt.Sprintf("avoid.%d: name max length is %d characters, length is %d", i, FieldEntryMaxLength, length) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // UserFields returns the fields associated with the given user ID. | ||||
| func (db *DB) UserFields(ctx context.Context, id xid.ID) (fs []Field, err error) { | ||||
| 	sql, args, err := sq. | ||||
|  | @ -33,3 +89,45 @@ func (db *DB) UserFields(ctx context.Context, id xid.ID) (fs []Field, err error) | |||
| 	} | ||||
| 	return fs, nil | ||||
| } | ||||
| 
 | ||||
| // SetUserFields updates the fields for the given user. | ||||
| func (db *DB) SetUserFields(ctx context.Context, userID xid.ID, fields []Field) (err error) { | ||||
| 	tx, err := db.Begin(ctx) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 	defer tx.Rollback(ctx) | ||||
| 
 | ||||
| 	sql, args, err := sq.Delete("user_fields").Where("user_id = ?", userID).ToSql() | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = tx.Exec(ctx, sql, args...) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "deleting existing fields") | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = tx.CopyFrom(ctx, | ||||
| 		pgx.Identifier{"user_fields"}, | ||||
| 		[]string{"name", "favourite", "okay", "jokingly", "friends_only", "avoid"}, | ||||
| 		pgx.CopyFromSlice(len(fields), func(i int) ([]any, error) { | ||||
| 			return []any{ | ||||
| 				fields[i].Name, | ||||
| 				fields[i].Favourite, | ||||
| 				fields[i].Okay, | ||||
| 				fields[i].Jokingly, | ||||
| 				fields[i].FriendsOnly, | ||||
| 				fields[i].Avoid, | ||||
| 			}, nil | ||||
| 		})) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "inserting new fields") | ||||
| 	} | ||||
| 
 | ||||
| 	err = tx.Commit(ctx) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "committing transaction") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  |  | |||
|  | @ -27,7 +27,7 @@ type User struct { | |||
| } | ||||
| 
 | ||||
| // usernames must match this regex | ||||
| var usernameRegex = regexp.MustCompile(`[\w-.]{2,40}`) | ||||
| var usernameRegex = regexp.MustCompile(`^[\w-.]{2,40}$`) | ||||
| 
 | ||||
| const ( | ||||
| 	ErrUserNotFound = errors.Sentinel("user not found") | ||||
|  | @ -136,3 +136,13 @@ func (db *DB) Username(ctx context.Context, name string) (u User, err error) { | |||
| 
 | ||||
| 	return u, nil | ||||
| } | ||||
| 
 | ||||
| // UsernameTaken checks if the given username is already taken. | ||||
| func (db *DB) UsernameTaken(ctx context.Context, username string) (valid, taken bool, err error) { | ||||
| 	if !usernameRegex.MatchString(username) { | ||||
| 		return false, false, nil | ||||
| 	} | ||||
| 
 | ||||
| 	err = db.QueryRow(ctx, "select exists (select id from users where username = $1)", username).Scan(&taken) | ||||
| 	return true, taken, err | ||||
| } | ||||
|  |  | |||
|  | @ -36,8 +36,9 @@ type discordCallbackResponse struct { | |||
| 	Token string   `json:"token,omitempty"` | ||||
| 	User  *db.User `json:"user,omitempty"` | ||||
| 
 | ||||
| 	Discord string `json:"discord,omitempty"` // username, for UI purposes | ||||
| 	Ticket  string `json:"ticket,omitempty"` | ||||
| 	Discord       string `json:"discord,omitempty"` // username, for UI purposes | ||||
| 	Ticket        string `json:"ticket,omitempty"` | ||||
| 	RequireInvite bool   `json:"require_invite,omitempty"` // require an invite for signing up | ||||
| } | ||||
| 
 | ||||
| func (s *Server) discordCallback(w http.ResponseWriter, r *http.Request) error { | ||||
|  | @ -95,7 +96,20 @@ func (s *Server) discordCallback(w http.ResponseWriter, r *http.Request) error { | |||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	// no user found, so save a ticket | ||||
| 	// no user found, so save a ticket + save their Discord info in Redis | ||||
| 	ticket := RandBase64(32) | ||||
| 	err = s.DB.SetJSON(ctx, "discord:"+ticket, du, "EX", "600") | ||||
| 	if err != nil { | ||||
| 		log.Errorf("setting Discord user for ticket %q: %v", ticket, err) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	render.JSON(w, r, discordCallbackResponse{ | ||||
| 		HasAccount:    false, | ||||
| 		Discord:       du.String(), | ||||
| 		Ticket:        ticket, | ||||
| 		RequireInvite: s.RequireInvite, | ||||
| 	}) | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
|  |  | |||
|  | @ -14,16 +14,9 @@ const numStates = "1000" | |||
| 
 | ||||
| // setCSRFState generates a random string to use as state, then stores that in Redis. | ||||
| func (s *Server) setCSRFState(ctx context.Context) (string, error) { | ||||
| 	b := make([]byte, 32) | ||||
| 	state := RandBase64(32) | ||||
| 
 | ||||
| 	_, err := rand.Read(b) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 
 | ||||
| 	state := base64.URLEncoding.EncodeToString(b) | ||||
| 
 | ||||
| 	err = s.DB.MultiCmd(ctx, | ||||
| 	err := s.DB.MultiCmd(ctx, | ||||
| 		radix.Cmd(nil, "LPUSH", "csrf", state), | ||||
| 		radix.Cmd(nil, "LTRIM", "csrf", "0", numStates), | ||||
| 	) | ||||
|  | @ -39,3 +32,14 @@ func (s *Server) validateCSRFState(ctx context.Context, state string) (matched b | |||
| 	} | ||||
| 	return num > 0, nil | ||||
| } | ||||
| 
 | ||||
| // RandBase64 returns a string of random bytes encoded in raw base 64. | ||||
| func RandBase64(size int) string { | ||||
| 	b := make([]byte, size) | ||||
| 	_, err := rand.Read(b) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 
 | ||||
| 	return base64.RawURLEncoding.EncodeToString(b) | ||||
| } | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ package auth | |||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 
 | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/log" | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/server" | ||||
|  | @ -12,12 +13,20 @@ import ( | |||
| 
 | ||||
| type Server struct { | ||||
| 	*server.Server | ||||
| 
 | ||||
| 	RequireInvite bool | ||||
| } | ||||
| 
 | ||||
| func Mount(srv *server.Server, r chi.Router) { | ||||
| 	s := &Server{srv} | ||||
| 	s := &Server{ | ||||
| 		Server:        srv, | ||||
| 		RequireInvite: os.Getenv("REQUIRE_INVITE") == "true", | ||||
| 	} | ||||
| 
 | ||||
| 	r.Route("/auth", func(r chi.Router) { | ||||
| 		// check if username is taken | ||||
| 		r.Get("/username", server.WrapHandler(s.usernameTaken)) | ||||
| 
 | ||||
| 		// generate csrf token, returns all supported OAuth provider URLs | ||||
| 		r.Post("/urls", server.WrapHandler(s.oauthURLs)) | ||||
| 
 | ||||
|  | @ -61,3 +70,29 @@ func (s *Server) oauthURLs(w http.ResponseWriter, r *http.Request) error { | |||
| 	}) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) usernameTaken(w http.ResponseWriter, r *http.Request) error { | ||||
| 	type Response struct { | ||||
| 		Valid bool `json:"valid"` | ||||
| 		Taken bool `json:"taken"` | ||||
| 	} | ||||
| 
 | ||||
| 	name := r.FormValue("username") | ||||
| 	if name == "" { | ||||
| 		render.JSON(w, r, Response{ | ||||
| 			Valid: false, | ||||
| 		}) | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	valid, taken, err := s.DB.UsernameTaken(r.Context(), name) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	render.JSON(w, r, Response{ | ||||
| 		Valid: valid, | ||||
| 		Taken: taken, | ||||
| 	}) | ||||
| 	return nil | ||||
| } | ||||
|  |  | |||
							
								
								
									
										54
									
								
								backend/routes/user/fields.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								backend/routes/user/fields.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,54 @@ | |||
| package user | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/db" | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/log" | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/server" | ||||
| 	"github.com/go-chi/render" | ||||
| ) | ||||
| 
 | ||||
| type PatchFieldsRequest struct { | ||||
| 	Fields []db.Field `json:"fields"` | ||||
| } | ||||
| 
 | ||||
| func (s *Server) patchUserFields(w http.ResponseWriter, r *http.Request) error { | ||||
| 	ctx := r.Context() | ||||
| 	claims, _ := server.ClaimsFromContext(ctx) | ||||
| 
 | ||||
| 	var req PatchFieldsRequest | ||||
| 	err := render.Decode(r, &req) | ||||
| 	if err != nil { | ||||
| 		return server.APIError{Code: server.ErrBadRequest} | ||||
| 	} | ||||
| 
 | ||||
| 	// max 25 fields | ||||
| 	if len(req.Fields) > db.MaxFields { | ||||
| 		return server.APIError{ | ||||
| 			Code:    server.ErrBadRequest, | ||||
| 			Details: fmt.Sprintf("Too many fields (max %d, current %d)", db.MaxFields, len(req.Fields)), | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// validate all fields | ||||
| 	for i, field := range req.Fields { | ||||
| 		if s := field.Validate(); s != "" { | ||||
| 			return server.APIError{ | ||||
| 				Code:    server.ErrBadRequest, | ||||
| 				Details: fmt.Sprintf("field %d: %s", i, s), | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	err = s.DB.SetUserFields(ctx, claims.UserID, req.Fields) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("setting fields for user %v: %v", claims.UserID, err) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	// echo the fields back on success | ||||
| 	render.JSON(w, r, req) | ||||
| 	return nil | ||||
| } | ||||
|  | @ -13,8 +13,11 @@ func Mount(srv *server.Server, r chi.Router) { | |||
| 	s := &Server{srv} | ||||
| 
 | ||||
| 	r.Route("/users", func(r chi.Router) { | ||||
| 		r.With(server.MustAuth).Get("/@me", server.WrapHandler(s.getMeUser)) | ||||
| 
 | ||||
| 		r.Get("/{userRef}", server.WrapHandler(s.getUser)) | ||||
| 
 | ||||
| 		r.With(server.MustAuth).Group(func(r chi.Router) { | ||||
| 			r.Get("/@me", server.WrapHandler(s.getMeUser)) | ||||
| 			r.Patch("/@me/fields", server.WrapHandler(s.patchUserFields)) | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
|  |  | |||
|  | @ -41,8 +41,9 @@ func WrapHandler(hn func(w http.ResponseWriter, r *http.Request) error) http.Han | |||
| type APIError struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message,omitempty"` | ||||
| 	Details string `json:"details,omitempty"` | ||||
| 
 | ||||
| 	// Status is | ||||
| 	// Status is set as the HTTP status code. | ||||
| 	Status int `json:"-"` | ||||
| } | ||||
| 
 | ||||
|  | @ -64,6 +65,8 @@ func (e *APIError) prepare() { | |||
| const ( | ||||
| 	ErrBadRequest          = 400 | ||||
| 	ErrForbidden           = 403 | ||||
| 	ErrNotFound            = 404 | ||||
| 	ErrMethodNotAllowed    = 405 | ||||
| 	ErrInternalServerError = 500 // catch-all code for unknown errors | ||||
| 
 | ||||
| 	// Login/authorize error codes | ||||
|  | @ -79,6 +82,8 @@ var errCodeMessages = map[int]string{ | |||
| 	ErrBadRequest:          "Bad request", | ||||
| 	ErrForbidden:           "Forbidden", | ||||
| 	ErrInternalServerError: "Internal server error", | ||||
| 	ErrNotFound:            "Not found", | ||||
| 	ErrMethodNotAllowed:    "Method not allowed", | ||||
| 
 | ||||
| 	ErrInvalidState:     "Invalid OAuth state", | ||||
| 	ErrInvalidOAuthCode: "Invalid OAuth code", | ||||
|  | @ -91,6 +96,8 @@ var errCodeStatuses = map[int]int{ | |||
| 	ErrBadRequest:          http.StatusBadRequest, | ||||
| 	ErrForbidden:           http.StatusForbidden, | ||||
| 	ErrInternalServerError: http.StatusInternalServerError, | ||||
| 	ErrNotFound:            http.StatusNotFound, | ||||
| 	ErrMethodNotAllowed:    http.StatusMethodNotAllowed, | ||||
| 
 | ||||
| 	ErrInvalidState:     http.StatusBadRequest, | ||||
| 	ErrInvalidOAuthCode: http.StatusForbidden, | ||||
|  |  | |||
|  | @ -1,12 +1,14 @@ | |||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 
 | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/db" | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/server/auth" | ||||
| 	"github.com/go-chi/chi/v5" | ||||
| 	"github.com/go-chi/chi/v5/middleware" | ||||
| 	"github.com/go-chi/render" | ||||
| ) | ||||
| 
 | ||||
| // Revision is the git commit, filled at build time | ||||
|  | @ -39,6 +41,23 @@ func New() (*Server, error) { | |||
| 	// enable authentication for all routes (but don't require it) | ||||
| 	s.Router.Use(s.maybeAuth) | ||||
| 
 | ||||
| 	// return an API error for not found + method not allowed | ||||
| 	s.Router.NotFound(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		render.Status(r, errCodeStatuses[ErrNotFound]) | ||||
| 		render.JSON(w, r, APIError{ | ||||
| 			Code:    ErrNotFound, | ||||
| 			Message: errCodeMessages[ErrNotFound], | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	s.Router.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		render.Status(r, errCodeStatuses[ErrMethodNotAllowed]) | ||||
| 		render.JSON(w, r, APIError{ | ||||
| 			Code:    ErrMethodNotAllowed, | ||||
| 			Message: errCodeMessages[ErrMethodNotAllowed], | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	return s, nil | ||||
| } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue