feat: add invites to backend
This commit is contained in:
		
							parent
							
								
									47ed36d24c
								
							
						
					
					
						commit
						6237ea940f
					
				
					 7 changed files with 234 additions and 30 deletions
				
			
		
							
								
								
									
										111
									
								
								backend/db/invites.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								backend/db/invites.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,111 @@ | |||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"encoding/base64" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"emperror.dev/errors" | ||||
| 	"github.com/georgysavva/scany/pgxscan" | ||||
| 	"github.com/jackc/pgx/v4" | ||||
| 	"github.com/rs/xid" | ||||
| ) | ||||
| 
 | ||||
| type Invite struct { | ||||
| 	UserID  xid.ID | ||||
| 	Code    string | ||||
| 	Created time.Time | ||||
| 	Used    bool | ||||
| } | ||||
| 
 | ||||
| func (db *DB) UserInvites(ctx context.Context, userID xid.ID) (is []Invite, err error) { | ||||
| 	sql, args, err := sq.Select("*").From("invites").Where("user_id = ?", userID).OrderBy("created").ToSql() | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 
 | ||||
| 	err = pgxscan.Select(ctx, db, &is, sql, args...) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "querying database") | ||||
| 	} | ||||
| 	if len(is) == 0 { | ||||
| 		is = []Invite{} | ||||
| 	} | ||||
| 
 | ||||
| 	return is, nil | ||||
| } | ||||
| 
 | ||||
| const ErrTooManyInvites = errors.Sentinel("user invite limit reached") | ||||
| 
 | ||||
| func (db *DB) CreateInvite(ctx context.Context, userID xid.ID) (i Invite, err error) { | ||||
| 	tx, err := db.Begin(ctx) | ||||
| 	if err != nil { | ||||
| 		return i, errors.Wrap(err, "beginning transaction") | ||||
| 	} | ||||
| 	defer tx.Rollback(ctx) | ||||
| 
 | ||||
| 	var maxInvites, inviteCount int | ||||
| 	err = tx.QueryRow(ctx, "SELECT max_invites FROM users WHERE id = $1", userID).Scan(&maxInvites) | ||||
| 	if err != nil { | ||||
| 		return i, errors.Wrap(err, "querying invite limit") | ||||
| 	} | ||||
| 	err = tx.QueryRow(ctx, "SELECT count(*) FROM invites WHERE user_id = $1", userID).Scan(&inviteCount) | ||||
| 	if err != nil { | ||||
| 		return i, errors.Wrap(err, "querying current invite count") | ||||
| 	} | ||||
| 
 | ||||
| 	if inviteCount >= maxInvites { | ||||
| 		return i, ErrTooManyInvites | ||||
| 	} | ||||
| 
 | ||||
| 	b := make([]byte, 32) | ||||
| 
 | ||||
| 	_, err = rand.Read(b) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 
 | ||||
| 	code := base64.RawURLEncoding.EncodeToString(b) | ||||
| 
 | ||||
| 	sql, args, err := sq.Insert("invites").Columns("user_id", "code").Values(userID, code).Suffix("RETURNING *").ToSql() | ||||
| 	if err != nil { | ||||
| 		return i, errors.Wrap(err, "building insert invite sql") | ||||
| 	} | ||||
| 
 | ||||
| 	err = pgxscan.Get(ctx, db, &i, sql, args...) | ||||
| 	if err != nil { | ||||
| 		return i, errors.Wrap(err, "inserting invite") | ||||
| 	} | ||||
| 
 | ||||
| 	err = tx.Commit(ctx) | ||||
| 	if err != nil { | ||||
| 		return i, errors.Wrap(err, "committing transaction") | ||||
| 	} | ||||
| 	return i, nil | ||||
| } | ||||
| 
 | ||||
| func (db *DB) InvalidateInvite(ctx context.Context, tx pgx.Tx, code string) (valid, alreadyUsed bool, err error) { | ||||
| 	err = tx.QueryRow(ctx, "SELECT used FROM invites WHERE code = $1", code).Scan(&alreadyUsed) | ||||
| 	if err != nil { | ||||
| 		if errors.Cause(err) == pgx.ErrNoRows { | ||||
| 			return false, false, nil | ||||
| 		} | ||||
| 
 | ||||
| 		return false, false, errors.Wrap(err, "checking if invite exists and is used") | ||||
| 	} | ||||
| 
 | ||||
| 	// valid: true, already used: true | ||||
| 	if alreadyUsed { | ||||
| 		return true, true, nil | ||||
| 	} | ||||
| 
 | ||||
| 	// invite is valid, not already used | ||||
| 	_, err = tx.Exec(ctx, "UPDATE invites SET used = true WHERE code = $1", code) | ||||
| 	if err != nil { | ||||
| 		return false, false, errors.Wrap(err, "updating invite usage") | ||||
| 	} | ||||
| 
 | ||||
| 	// valid: true, already used: false | ||||
| 	return true, false, nil | ||||
| } | ||||
|  | @ -24,6 +24,8 @@ type User struct { | |||
| 
 | ||||
| 	Discord         *string | ||||
| 	DiscordUsername *string | ||||
| 
 | ||||
| 	MaxInvites int | ||||
| } | ||||
| 
 | ||||
| // usernames must match this regex | ||||
|  |  | |||
|  | @ -182,17 +182,18 @@ func (s *Server) discordSignup(w http.ResponseWriter, r *http.Request) error { | |||
| 	} | ||||
| 
 | ||||
| 	if s.RequireInvite { | ||||
| 		// TODO: check invites, invalidate invite when done | ||||
| 		inviteValid := true | ||||
| 
 | ||||
| 		if !inviteValid { | ||||
| 			err = tx.Rollback(ctx) | ||||
| 		valid, used, err := s.DB.InvalidateInvite(ctx, tx, req.InviteCode) | ||||
| 		if err != nil { | ||||
| 				return errors.Wrap(err, "rolling back transaction") | ||||
| 			return errors.Wrap(err, "checking and invalidating invite") | ||||
| 		} | ||||
| 
 | ||||
| 		if !valid { | ||||
| 			return server.APIError{Code: server.ErrInviteRequired} | ||||
| 		} | ||||
| 
 | ||||
| 		if used { | ||||
| 			return server.APIError{Code: server.ErrInviteAlreadyUsed} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// delete sign up ticket | ||||
|  |  | |||
							
								
								
									
										68
									
								
								backend/routes/auth/invite.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								backend/routes/auth/invite.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,68 @@ | |||
| package auth | ||||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/db" | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/server" | ||||
| 	"emperror.dev/errors" | ||||
| 	"github.com/go-chi/render" | ||||
| ) | ||||
| 
 | ||||
| type inviteResponse struct { | ||||
| 	Code    string    `json:"string"` | ||||
| 	Created time.Time `json:"created"` | ||||
| 	Used    bool      `json:"used"` | ||||
| } | ||||
| 
 | ||||
| func dbInviteToResponse(i db.Invite) inviteResponse { | ||||
| 	return inviteResponse{ | ||||
| 		Code:    i.Code, | ||||
| 		Created: i.Created, | ||||
| 		Used:    i.Used, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *Server) getInvites(w http.ResponseWriter, r *http.Request) error { | ||||
| 	if !s.RequireInvite { | ||||
| 		return server.APIError{Code: server.ErrInvitesDisabled} | ||||
| 	} | ||||
| 
 | ||||
| 	ctx := r.Context() | ||||
| 	claims, _ := server.ClaimsFromContext(ctx) | ||||
| 
 | ||||
| 	is, err := s.DB.UserInvites(ctx, claims.UserID) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "getting user invites") | ||||
| 	} | ||||
| 
 | ||||
| 	resps := make([]inviteResponse, len(is)) | ||||
| 	for i := range is { | ||||
| 		resps[i] = dbInviteToResponse(is[i]) | ||||
| 	} | ||||
| 
 | ||||
| 	render.JSON(w, r, resps) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) createInvite(w http.ResponseWriter, r *http.Request) error { | ||||
| 	if !s.RequireInvite { | ||||
| 		return server.APIError{Code: server.ErrInvitesDisabled} | ||||
| 	} | ||||
| 
 | ||||
| 	ctx := r.Context() | ||||
| 	claims, _ := server.ClaimsFromContext(ctx) | ||||
| 
 | ||||
| 	inv, err := s.DB.CreateInvite(ctx, claims.UserID) | ||||
| 	if err != nil { | ||||
| 		if err == db.ErrTooManyInvites { | ||||
| 			return server.APIError{Code: server.ErrInviteLimitReached} | ||||
| 		} | ||||
| 
 | ||||
| 		return errors.Wrap(err, "creating invite") | ||||
| 	} | ||||
| 
 | ||||
| 	render.JSON(w, r, dbInviteToResponse(inv)) | ||||
| 	return nil | ||||
| } | ||||
|  | @ -63,6 +63,10 @@ func Mount(srv *server.Server, r chi.Router) { | |||
| 			// takes discord signup ticket to register account | ||||
| 			r.Post("/signup", server.WrapHandler(s.discordSignup)) | ||||
| 		}) | ||||
| 
 | ||||
| 		// invite routes | ||||
| 		r.With(server.MustAuth).Get("/invites", server.WrapHandler(s.getInvites)) | ||||
| 		r.With(server.MustAuth).Post("/invites", server.WrapHandler(s.createInvite)) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -80,6 +80,9 @@ const ( | |||
| 	ErrInvalidTicket      = 1005 // invalid signup ticket | ||||
| 	ErrInvalidUsername    = 1006 // invalid username (when signing up) | ||||
| 	ErrUsernameTaken      = 1007 // username taken (when signing up) | ||||
| 	ErrInvitesDisabled    = 1008 // invites are disabled (unneeded) | ||||
| 	ErrInviteLimitReached = 1009 // invite limit reached (when creating invites) | ||||
| 	ErrInviteAlreadyUsed  = 1010 // invite already used (when signing up) | ||||
| 
 | ||||
| 	// User-related error codes | ||||
| 	ErrUserNotFound = 2001 | ||||
|  | @ -107,6 +110,9 @@ var errCodeMessages = map[int]string{ | |||
| 	ErrInvalidTicket:      "Invalid signup ticket", | ||||
| 	ErrInvalidUsername:    "Invalid username", | ||||
| 	ErrUsernameTaken:      "Username is already taken", | ||||
| 	ErrInvitesDisabled:    "Invites are disabled", | ||||
| 	ErrInviteLimitReached: "Your account has reached the invite limit", | ||||
| 	ErrInviteAlreadyUsed:  "That invite code has already been used", | ||||
| 
 | ||||
| 	ErrUserNotFound: "User not found", | ||||
| 
 | ||||
|  | @ -131,6 +137,9 @@ var errCodeStatuses = map[int]int{ | |||
| 	ErrInvalidTicket:      http.StatusBadRequest, | ||||
| 	ErrInvalidUsername:    http.StatusBadRequest, | ||||
| 	ErrUsernameTaken:      http.StatusBadRequest, | ||||
| 	ErrInvitesDisabled:    http.StatusForbidden, | ||||
| 	ErrInviteLimitReached: http.StatusForbidden, | ||||
| 	ErrInviteAlreadyUsed:  http.StatusBadRequest, | ||||
| 
 | ||||
| 	ErrUserNotFound: http.StatusNotFound, | ||||
| 
 | ||||
|  |  | |||
|  | @ -12,7 +12,9 @@ create table users ( | |||
|     links       text[], | ||||
| 
 | ||||
|     discord          text unique, -- for Discord oauth | ||||
|     discord_username text | ||||
|     discord_username text, | ||||
| 
 | ||||
|     max_invites int default 10 | ||||
| ); | ||||
| 
 | ||||
| create table user_names ( | ||||
|  | @ -80,3 +82,10 @@ create table member_fields ( | |||
|     friends_only text[], | ||||
|     avoid        text[] | ||||
| ); | ||||
| 
 | ||||
| create table invites ( | ||||
|     user_id text      not null references users (id) on delete cascade, | ||||
|     code    text      primary key, | ||||
|     created timestamp not null default (current_timestamp at time zone 'utc'), | ||||
|     used    boolean not null default false | ||||
| ); | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue