feat: add token IDs, store tokens in db for early invalidation
This commit is contained in:
		
							parent
							
								
									58c1c1794e
								
							
						
					
					
						commit
						e5723360a7
					
				
					 7 changed files with 248 additions and 9 deletions
				
			
		
							
								
								
									
										98
									
								
								backend/db/tokens.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								backend/db/tokens.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,98 @@ | |||
| package db | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"emperror.dev/errors" | ||||
| 	"github.com/georgysavva/scany/pgxscan" | ||||
| 	"github.com/jackc/pgx/v4" | ||||
| 	"github.com/rs/xid" | ||||
| ) | ||||
| 
 | ||||
| type Token struct { | ||||
| 	UserID      xid.ID | ||||
| 	TokenID     xid.ID | ||||
| 	Invalidated bool | ||||
| 	Created     time.Time | ||||
| 	Expires     time.Time | ||||
| } | ||||
| 
 | ||||
| func (db *DB) TokenValid(ctx context.Context, userID, tokenID xid.ID) (valid bool, err error) { | ||||
| 	sql, args, err := sq.Select("*").From("tokens"). | ||||
| 		Where("user_id = ?", userID). | ||||
| 		Where("token_id = ?", tokenID). | ||||
| 		ToSql() | ||||
| 	if err != nil { | ||||
| 		return false, errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 
 | ||||
| 	var t Token | ||||
| 	err = pgxscan.Get(ctx, db, &t, sql, args...) | ||||
| 	if err != nil { | ||||
| 		if errors.Cause(err) == pgx.ErrNoRows { | ||||
| 			return false, nil | ||||
| 		} | ||||
| 
 | ||||
| 		return false, errors.Wrap(err, "getting from database") | ||||
| 	} | ||||
| 
 | ||||
| 	now := time.Now() | ||||
| 	return !t.Invalidated && t.Created.Before(now) && t.Expires.After(now), nil | ||||
| } | ||||
| 
 | ||||
| func (db *DB) Tokens(ctx context.Context, userID xid.ID) (ts []Token, err error) { | ||||
| 	sql, args, err := sq.Select("*").From("tokens"). | ||||
| 		Where("user_id = ?", userID). | ||||
| 		Where("expires > ?", time.Now()). | ||||
| 		OrderBy("created"). | ||||
| 		ToSql() | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 
 | ||||
| 	err = pgxscan.Select(ctx, db, &ts, sql, args...) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "getting from database") | ||||
| 	} | ||||
| 	return ts, nil | ||||
| } | ||||
| 
 | ||||
| // 3 months, might be customizable later | ||||
| const ExpiryTime = 3 * 30 * 24 * time.Hour | ||||
| 
 | ||||
| // SaveToken saves a token to the database. | ||||
| func (db *DB) SaveToken(ctx context.Context, userID xid.ID, tokenID xid.ID) (t Token, err error) { | ||||
| 	sql, args, err := sq.Insert("tokens"). | ||||
| 		Columns("user_id", "token_id", "expires"). | ||||
| 		Values(userID, tokenID, time.Now().Add(ExpiryTime)). | ||||
| 		Suffix("RETURNING *"). | ||||
| 		ToSql() | ||||
| 	if err != nil { | ||||
| 		return t, errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 
 | ||||
| 	err = pgxscan.Get(ctx, db, &t, sql, args...) | ||||
| 	if err != nil { | ||||
| 		return t, errors.Wrap(err, "inserting token") | ||||
| 	} | ||||
| 	return t, nil | ||||
| } | ||||
| 
 | ||||
| func (db *DB) InvalidateToken(ctx context.Context, userID xid.ID, tokenID xid.ID) (t Token, err error) { | ||||
| 	sql, args, err := sq.Update("tokens"). | ||||
| 		Where("user_id = ?"). | ||||
| 		Where("token_id = ?"). | ||||
| 		Set("invalidated", true). | ||||
| 		Suffix("RETURNING *"). | ||||
| 		ToSql() | ||||
| 	if err != nil { | ||||
| 		return t, errors.Wrap(err, "building sql") | ||||
| 	} | ||||
| 
 | ||||
| 	err = pgxscan.Get(ctx, db, &t, sql, args...) | ||||
| 	if err != nil { | ||||
| 		return t, errors.Wrap(err, "invalidating token") | ||||
| 	} | ||||
| 	return t, nil | ||||
| } | ||||
|  | @ -11,6 +11,7 @@ import ( | |||
| 	"github.com/bwmarrin/discordgo" | ||||
| 	"github.com/go-chi/render" | ||||
| 	"github.com/mediocregopher/radix/v4" | ||||
| 	"github.com/rs/xid" | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
| 
 | ||||
|  | @ -81,11 +82,19 @@ func (s *Server) discordCallback(w http.ResponseWriter, r *http.Request) error { | |||
| 			log.Errorf("updating user %v with Discord info: %v", u.ID, err) | ||||
| 		} | ||||
| 
 | ||||
| 		token, err := s.Auth.CreateToken(u.ID) | ||||
| 		// TODO: implement user + token permissions | ||||
| 		tokenID := xid.New() | ||||
| 		token, err := s.Auth.CreateToken(u.ID, tokenID, false, true) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		// save token to database | ||||
| 		_, err = s.DB.SaveToken(ctx, u.ID, tokenID) | ||||
| 		if err != nil { | ||||
| 			return errors.Wrap(err, "saving token to database") | ||||
| 		} | ||||
| 
 | ||||
| 		render.JSON(w, r, discordCallbackResponse{ | ||||
| 			HasAccount: true, | ||||
| 			Token:      token, | ||||
|  | @ -206,11 +215,19 @@ func (s *Server) discordSignup(w http.ResponseWriter, r *http.Request) error { | |||
| 	} | ||||
| 
 | ||||
| 	// create token | ||||
| 	token, err := s.Auth.CreateToken(u.ID) | ||||
| 	// TODO: implement user + token permissions | ||||
| 	tokenID := xid.New() | ||||
| 	token, err := s.Auth.CreateToken(u.ID, tokenID, false, true) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "creating token") | ||||
| 	} | ||||
| 
 | ||||
| 	// save token to database | ||||
| 	_, err = s.DB.SaveToken(ctx, u.ID, tokenID) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "saving token to database") | ||||
| 	} | ||||
| 
 | ||||
| 	// return user | ||||
| 	render.JSON(w, r, signupResponse{ | ||||
| 		User:  *dbUserToUserResponse(u), | ||||
|  |  | |||
|  | @ -67,6 +67,11 @@ func Mount(srv *server.Server, r chi.Router) { | |||
| 		// invite routes | ||||
| 		r.With(server.MustAuth).Get("/invites", server.WrapHandler(s.getInvites)) | ||||
| 		r.With(server.MustAuth).Post("/invites", server.WrapHandler(s.createInvite)) | ||||
| 
 | ||||
| 		// tokens | ||||
| 		r.With(server.MustAuth).Get("/tokens", server.WrapHandler(s.getTokens)) | ||||
| 		r.With(server.MustAuth).Post("/tokens", server.WrapHandler(s.createToken)) | ||||
| 		r.With(server.MustAuth).Delete("/tokens", server.WrapHandler(s.deleteToken)) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										72
									
								
								backend/routes/auth/tokens.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								backend/routes/auth/tokens.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,72 @@ | |||
| package auth | ||||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/db" | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/server" | ||||
| 	"emperror.dev/errors" | ||||
| 	"github.com/go-chi/render" | ||||
| 	"github.com/rs/xid" | ||||
| ) | ||||
| 
 | ||||
| type getTokenResponse struct { | ||||
| 	TokenID xid.ID    `json:"id"` | ||||
| 	Created time.Time `json:"created"` | ||||
| 	Expires time.Time `json:"expires"` | ||||
| } | ||||
| 
 | ||||
| func dbTokenToGetResponse(t db.Token) getTokenResponse { | ||||
| 	return getTokenResponse{ | ||||
| 		TokenID: t.TokenID, | ||||
| 		Created: t.Created, | ||||
| 		Expires: t.Expires, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *Server) getTokens(w http.ResponseWriter, r *http.Request) error { | ||||
| 	ctx := r.Context() | ||||
| 	claims, _ := server.ClaimsFromContext(ctx) | ||||
| 
 | ||||
| 	tokens, err := s.DB.Tokens(ctx, claims.UserID) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "getting tokens") | ||||
| 	} | ||||
| 
 | ||||
| 	resps := make([]getTokenResponse, len(tokens)) | ||||
| 	for i := range tokens { | ||||
| 		resps[i] = dbTokenToGetResponse(tokens[i]) | ||||
| 	} | ||||
| 
 | ||||
| 	render.JSON(w, r, resps) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type deleteTokenResponse struct { | ||||
| 	TokenID     xid.ID    `json:"id"` | ||||
| 	Invalidated bool      `json:"invalidated"` | ||||
| 	Created     time.Time `json:"time"` | ||||
| } | ||||
| 
 | ||||
| func (s *Server) deleteToken(w http.ResponseWriter, r *http.Request) error { | ||||
| 	ctx := r.Context() | ||||
| 	claims, _ := server.ClaimsFromContext(ctx) | ||||
| 
 | ||||
| 	t, err := s.DB.InvalidateToken(ctx, claims.UserID, claims.TokenID) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "invalidating token") | ||||
| 	} | ||||
| 
 | ||||
| 	render.JSON(w, r, deleteTokenResponse{ | ||||
| 		TokenID:     t.TokenID, | ||||
| 		Invalidated: t.Invalidated, | ||||
| 		Created:     t.Created, | ||||
| 	}) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) createToken(w http.ResponseWriter, r *http.Request) error { | ||||
| 	// unimplemented right now | ||||
| 	return server.APIError{Code: server.ErrForbidden} | ||||
| } | ||||
|  | @ -5,6 +5,7 @@ import ( | |||
| 	"net/http" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/log" | ||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/server/auth" | ||||
| 	"github.com/go-chi/render" | ||||
| ) | ||||
|  | @ -28,6 +29,27 @@ func (s *Server) maybeAuth(next http.Handler) http.Handler { | |||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		// "valid" here refers to existence and expiry date, not whether the token is known | ||||
| 		valid, err := s.DB.TokenValid(r.Context(), claims.UserID, claims.TokenID) | ||||
| 		if err != nil { | ||||
| 			log.Errorf("validating token for user %v: %v", claims.UserID, err) | ||||
| 			render.Status(r, errCodeStatuses[ErrInternalServerError]) | ||||
| 			render.JSON(w, r, APIError{ | ||||
| 				Code:    ErrInternalServerError, | ||||
| 				Message: errCodeMessages[ErrInternalServerError], | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if !valid { | ||||
| 			render.Status(r, errCodeStatuses[ErrInvalidToken]) | ||||
| 			render.JSON(w, r, APIError{ | ||||
| 				Code:    ErrInvalidToken, | ||||
| 				Message: errCodeMessages[ErrInvalidToken], | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		ctx := context.WithValue(r.Context(), ctxKeyClaims, claims) | ||||
| 
 | ||||
| 		next.ServeHTTP(w, r.WithContext(ctx)) | ||||
|  |  | |||
|  | @ -14,7 +14,13 @@ import ( | |||
| 
 | ||||
| // Claims are the claims used in a token. | ||||
| type Claims struct { | ||||
| 	UserID xid.ID `json:"sub"` | ||||
| 	UserID      xid.ID `json:"sub"` | ||||
| 	TokenID     xid.ID `json:"jti"` | ||||
| 	UserIsAdmin bool   `json:"adm"` | ||||
| 
 | ||||
| 	// TokenWrite specifies whether this token can be used for write actions. | ||||
| 	// If set to false, this token can only be used for read actions. | ||||
| 	TokenWrite bool `json:"twr"` | ||||
| 
 | ||||
| 	jwt.RegisteredClaims | ||||
| } | ||||
|  | @ -37,16 +43,20 @@ func New() *Verifier { | |||
| 	return &Verifier{key: key} | ||||
| } | ||||
| 
 | ||||
| const expireDays = 30 | ||||
| // ExpireDays is after how many days the token will expire. | ||||
| const ExpireDays = 30 | ||||
| 
 | ||||
| // CreateToken creates a token for the given user ID. | ||||
| // It expires after 30 days. | ||||
| func (v *Verifier) CreateToken(userID xid.ID) (string, error) { | ||||
| func (v *Verifier) CreateToken(userID, tokenID xid.ID, isAdmin bool, isWriteToken bool) (token string, err error) { | ||||
| 	now := time.Now() | ||||
| 	expires := now.Add(expireDays * 24 * time.Hour) | ||||
| 	expires := now.Add(ExpireDays * 24 * time.Hour) | ||||
| 
 | ||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{ | ||||
| 		UserID: userID, | ||||
| 	t := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{ | ||||
| 		UserID:      userID, | ||||
| 		TokenID:     tokenID, | ||||
| 		UserIsAdmin: isAdmin, | ||||
| 		TokenWrite:  isWriteToken, | ||||
| 		RegisteredClaims: jwt.RegisteredClaims{ | ||||
| 			Issuer:    "pronouns", | ||||
| 			ExpiresAt: jwt.NewNumericDate(expires), | ||||
|  | @ -55,7 +65,7 @@ func (v *Verifier) CreateToken(userID xid.ID) (string, error) { | |||
| 		}, | ||||
| 	}) | ||||
| 
 | ||||
| 	return token.SignedString(v.key) | ||||
| 	return t.SignedString(v.key) | ||||
| } | ||||
| 
 | ||||
| // Claims parses the given token and returns its Claims. | ||||
|  |  | |||
							
								
								
									
										15
									
								
								scripts/migrate/003_add_tokens.sql
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								scripts/migrate/003_add_tokens.sql
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,15 @@ | |||
| -- +migrate Up | ||||
| 
 | ||||
| -- 2022-12-23: Add database-backed tokens | ||||
| create table tokens ( | ||||
|     user_id     text not null references users (id) on delete cascade, | ||||
|     token_id    text primary key, | ||||
|     invalidated boolean not null default false, | ||||
|     created     timestamptz not null default now(), | ||||
|     expires     timestamptz not null | ||||
| ); | ||||
| 
 | ||||
| -- Unrelatedly, this migration also changes the column type for invites.created to timestamptz (from plain timestamp) | ||||
| -- This does not change anything code-wise, but it's recommended over plain timestamp because plain timestamp does not handle timezones correctly | ||||
| alter table invites alter column created type timestamptz; | ||||
| alter table invites alter column created set default now(); | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue