feat: add token IDs, store tokens in db for early invalidation
This commit is contained in:
		
							parent
							
								
									58c1c1794e
								
							
						
					
					
						commit
						e5723360a7
					
				
					 7 changed files with 248 additions and 9 deletions
				
			
		
							
								
								
									
										98
									
								
								backend/db/tokens.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								backend/db/tokens.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,98 @@ | ||||||
|  | package db | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"emperror.dev/errors" | ||||||
|  | 	"github.com/georgysavva/scany/pgxscan" | ||||||
|  | 	"github.com/jackc/pgx/v4" | ||||||
|  | 	"github.com/rs/xid" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Token struct { | ||||||
|  | 	UserID      xid.ID | ||||||
|  | 	TokenID     xid.ID | ||||||
|  | 	Invalidated bool | ||||||
|  | 	Created     time.Time | ||||||
|  | 	Expires     time.Time | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (db *DB) TokenValid(ctx context.Context, userID, tokenID xid.ID) (valid bool, err error) { | ||||||
|  | 	sql, args, err := sq.Select("*").From("tokens"). | ||||||
|  | 		Where("user_id = ?", userID). | ||||||
|  | 		Where("token_id = ?", tokenID). | ||||||
|  | 		ToSql() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, errors.Wrap(err, "building sql") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var t Token | ||||||
|  | 	err = pgxscan.Get(ctx, db, &t, sql, args...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if errors.Cause(err) == pgx.ErrNoRows { | ||||||
|  | 			return false, nil | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return false, errors.Wrap(err, "getting from database") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	now := time.Now() | ||||||
|  | 	return !t.Invalidated && t.Created.Before(now) && t.Expires.After(now), nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (db *DB) Tokens(ctx context.Context, userID xid.ID) (ts []Token, err error) { | ||||||
|  | 	sql, args, err := sq.Select("*").From("tokens"). | ||||||
|  | 		Where("user_id = ?", userID). | ||||||
|  | 		Where("expires > ?", time.Now()). | ||||||
|  | 		OrderBy("created"). | ||||||
|  | 		ToSql() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, errors.Wrap(err, "building sql") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err = pgxscan.Select(ctx, db, &ts, sql, args...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, errors.Wrap(err, "getting from database") | ||||||
|  | 	} | ||||||
|  | 	return ts, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // 3 months, might be customizable later | ||||||
|  | const ExpiryTime = 3 * 30 * 24 * time.Hour | ||||||
|  | 
 | ||||||
|  | // SaveToken saves a token to the database. | ||||||
|  | func (db *DB) SaveToken(ctx context.Context, userID xid.ID, tokenID xid.ID) (t Token, err error) { | ||||||
|  | 	sql, args, err := sq.Insert("tokens"). | ||||||
|  | 		Columns("user_id", "token_id", "expires"). | ||||||
|  | 		Values(userID, tokenID, time.Now().Add(ExpiryTime)). | ||||||
|  | 		Suffix("RETURNING *"). | ||||||
|  | 		ToSql() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return t, errors.Wrap(err, "building sql") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err = pgxscan.Get(ctx, db, &t, sql, args...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return t, errors.Wrap(err, "inserting token") | ||||||
|  | 	} | ||||||
|  | 	return t, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (db *DB) InvalidateToken(ctx context.Context, userID xid.ID, tokenID xid.ID) (t Token, err error) { | ||||||
|  | 	sql, args, err := sq.Update("tokens"). | ||||||
|  | 		Where("user_id = ?"). | ||||||
|  | 		Where("token_id = ?"). | ||||||
|  | 		Set("invalidated", true). | ||||||
|  | 		Suffix("RETURNING *"). | ||||||
|  | 		ToSql() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return t, errors.Wrap(err, "building sql") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err = pgxscan.Get(ctx, db, &t, sql, args...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return t, errors.Wrap(err, "invalidating token") | ||||||
|  | 	} | ||||||
|  | 	return t, nil | ||||||
|  | } | ||||||
|  | @ -11,6 +11,7 @@ import ( | ||||||
| 	"github.com/bwmarrin/discordgo" | 	"github.com/bwmarrin/discordgo" | ||||||
| 	"github.com/go-chi/render" | 	"github.com/go-chi/render" | ||||||
| 	"github.com/mediocregopher/radix/v4" | 	"github.com/mediocregopher/radix/v4" | ||||||
|  | 	"github.com/rs/xid" | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -81,11 +82,19 @@ func (s *Server) discordCallback(w http.ResponseWriter, r *http.Request) error { | ||||||
| 			log.Errorf("updating user %v with Discord info: %v", u.ID, err) | 			log.Errorf("updating user %v with Discord info: %v", u.ID, err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		token, err := s.Auth.CreateToken(u.ID) | 		// TODO: implement user + token permissions | ||||||
|  | 		tokenID := xid.New() | ||||||
|  | 		token, err := s.Auth.CreateToken(u.ID, tokenID, false, true) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		// save token to database | ||||||
|  | 		_, err = s.DB.SaveToken(ctx, u.ID, tokenID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errors.Wrap(err, "saving token to database") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		render.JSON(w, r, discordCallbackResponse{ | 		render.JSON(w, r, discordCallbackResponse{ | ||||||
| 			HasAccount: true, | 			HasAccount: true, | ||||||
| 			Token:      token, | 			Token:      token, | ||||||
|  | @ -206,11 +215,19 @@ func (s *Server) discordSignup(w http.ResponseWriter, r *http.Request) error { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// create token | 	// create token | ||||||
| 	token, err := s.Auth.CreateToken(u.ID) | 	// TODO: implement user + token permissions | ||||||
|  | 	tokenID := xid.New() | ||||||
|  | 	token, err := s.Auth.CreateToken(u.ID, tokenID, false, true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errors.Wrap(err, "creating token") | 		return errors.Wrap(err, "creating token") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// save token to database | ||||||
|  | 	_, err = s.DB.SaveToken(ctx, u.ID, tokenID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errors.Wrap(err, "saving token to database") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// return user | 	// return user | ||||||
| 	render.JSON(w, r, signupResponse{ | 	render.JSON(w, r, signupResponse{ | ||||||
| 		User:  *dbUserToUserResponse(u), | 		User:  *dbUserToUserResponse(u), | ||||||
|  |  | ||||||
|  | @ -67,6 +67,11 @@ func Mount(srv *server.Server, r chi.Router) { | ||||||
| 		// invite routes | 		// invite routes | ||||||
| 		r.With(server.MustAuth).Get("/invites", server.WrapHandler(s.getInvites)) | 		r.With(server.MustAuth).Get("/invites", server.WrapHandler(s.getInvites)) | ||||||
| 		r.With(server.MustAuth).Post("/invites", server.WrapHandler(s.createInvite)) | 		r.With(server.MustAuth).Post("/invites", server.WrapHandler(s.createInvite)) | ||||||
|  | 
 | ||||||
|  | 		// tokens | ||||||
|  | 		r.With(server.MustAuth).Get("/tokens", server.WrapHandler(s.getTokens)) | ||||||
|  | 		r.With(server.MustAuth).Post("/tokens", server.WrapHandler(s.createToken)) | ||||||
|  | 		r.With(server.MustAuth).Delete("/tokens", server.WrapHandler(s.deleteToken)) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										72
									
								
								backend/routes/auth/tokens.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								backend/routes/auth/tokens.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,72 @@ | ||||||
|  | package auth | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"codeberg.org/u1f320/pronouns.cc/backend/db" | ||||||
|  | 	"codeberg.org/u1f320/pronouns.cc/backend/server" | ||||||
|  | 	"emperror.dev/errors" | ||||||
|  | 	"github.com/go-chi/render" | ||||||
|  | 	"github.com/rs/xid" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type getTokenResponse struct { | ||||||
|  | 	TokenID xid.ID    `json:"id"` | ||||||
|  | 	Created time.Time `json:"created"` | ||||||
|  | 	Expires time.Time `json:"expires"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func dbTokenToGetResponse(t db.Token) getTokenResponse { | ||||||
|  | 	return getTokenResponse{ | ||||||
|  | 		TokenID: t.TokenID, | ||||||
|  | 		Created: t.Created, | ||||||
|  | 		Expires: t.Expires, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) getTokens(w http.ResponseWriter, r *http.Request) error { | ||||||
|  | 	ctx := r.Context() | ||||||
|  | 	claims, _ := server.ClaimsFromContext(ctx) | ||||||
|  | 
 | ||||||
|  | 	tokens, err := s.DB.Tokens(ctx, claims.UserID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errors.Wrap(err, "getting tokens") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	resps := make([]getTokenResponse, len(tokens)) | ||||||
|  | 	for i := range tokens { | ||||||
|  | 		resps[i] = dbTokenToGetResponse(tokens[i]) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	render.JSON(w, r, resps) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type deleteTokenResponse struct { | ||||||
|  | 	TokenID     xid.ID    `json:"id"` | ||||||
|  | 	Invalidated bool      `json:"invalidated"` | ||||||
|  | 	Created     time.Time `json:"time"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) deleteToken(w http.ResponseWriter, r *http.Request) error { | ||||||
|  | 	ctx := r.Context() | ||||||
|  | 	claims, _ := server.ClaimsFromContext(ctx) | ||||||
|  | 
 | ||||||
|  | 	t, err := s.DB.InvalidateToken(ctx, claims.UserID, claims.TokenID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errors.Wrap(err, "invalidating token") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	render.JSON(w, r, deleteTokenResponse{ | ||||||
|  | 		TokenID:     t.TokenID, | ||||||
|  | 		Invalidated: t.Invalidated, | ||||||
|  | 		Created:     t.Created, | ||||||
|  | 	}) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) createToken(w http.ResponseWriter, r *http.Request) error { | ||||||
|  | 	// unimplemented right now | ||||||
|  | 	return server.APIError{Code: server.ErrForbidden} | ||||||
|  | } | ||||||
|  | @ -5,6 +5,7 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
|  | 	"codeberg.org/u1f320/pronouns.cc/backend/log" | ||||||
| 	"codeberg.org/u1f320/pronouns.cc/backend/server/auth" | 	"codeberg.org/u1f320/pronouns.cc/backend/server/auth" | ||||||
| 	"github.com/go-chi/render" | 	"github.com/go-chi/render" | ||||||
| ) | ) | ||||||
|  | @ -28,6 +29,27 @@ func (s *Server) maybeAuth(next http.Handler) http.Handler { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		// "valid" here refers to existence and expiry date, not whether the token is known | ||||||
|  | 		valid, err := s.DB.TokenValid(r.Context(), claims.UserID, claims.TokenID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Errorf("validating token for user %v: %v", claims.UserID, err) | ||||||
|  | 			render.Status(r, errCodeStatuses[ErrInternalServerError]) | ||||||
|  | 			render.JSON(w, r, APIError{ | ||||||
|  | 				Code:    ErrInternalServerError, | ||||||
|  | 				Message: errCodeMessages[ErrInternalServerError], | ||||||
|  | 			}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if !valid { | ||||||
|  | 			render.Status(r, errCodeStatuses[ErrInvalidToken]) | ||||||
|  | 			render.JSON(w, r, APIError{ | ||||||
|  | 				Code:    ErrInvalidToken, | ||||||
|  | 				Message: errCodeMessages[ErrInvalidToken], | ||||||
|  | 			}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		ctx := context.WithValue(r.Context(), ctxKeyClaims, claims) | 		ctx := context.WithValue(r.Context(), ctxKeyClaims, claims) | ||||||
| 
 | 
 | ||||||
| 		next.ServeHTTP(w, r.WithContext(ctx)) | 		next.ServeHTTP(w, r.WithContext(ctx)) | ||||||
|  |  | ||||||
|  | @ -14,7 +14,13 @@ import ( | ||||||
| 
 | 
 | ||||||
| // Claims are the claims used in a token. | // Claims are the claims used in a token. | ||||||
| type Claims struct { | type Claims struct { | ||||||
| 	UserID xid.ID `json:"sub"` | 	UserID      xid.ID `json:"sub"` | ||||||
|  | 	TokenID     xid.ID `json:"jti"` | ||||||
|  | 	UserIsAdmin bool   `json:"adm"` | ||||||
|  | 
 | ||||||
|  | 	// TokenWrite specifies whether this token can be used for write actions. | ||||||
|  | 	// If set to false, this token can only be used for read actions. | ||||||
|  | 	TokenWrite bool `json:"twr"` | ||||||
| 
 | 
 | ||||||
| 	jwt.RegisteredClaims | 	jwt.RegisteredClaims | ||||||
| } | } | ||||||
|  | @ -37,16 +43,20 @@ func New() *Verifier { | ||||||
| 	return &Verifier{key: key} | 	return &Verifier{key: key} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| const expireDays = 30 | // ExpireDays is after how many days the token will expire. | ||||||
|  | const ExpireDays = 30 | ||||||
| 
 | 
 | ||||||
| // CreateToken creates a token for the given user ID. | // CreateToken creates a token for the given user ID. | ||||||
| // It expires after 30 days. | // It expires after 30 days. | ||||||
| func (v *Verifier) CreateToken(userID xid.ID) (string, error) { | func (v *Verifier) CreateToken(userID, tokenID xid.ID, isAdmin bool, isWriteToken bool) (token string, err error) { | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	expires := now.Add(expireDays * 24 * time.Hour) | 	expires := now.Add(ExpireDays * 24 * time.Hour) | ||||||
| 
 | 
 | ||||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{ | 	t := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{ | ||||||
| 		UserID: userID, | 		UserID:      userID, | ||||||
|  | 		TokenID:     tokenID, | ||||||
|  | 		UserIsAdmin: isAdmin, | ||||||
|  | 		TokenWrite:  isWriteToken, | ||||||
| 		RegisteredClaims: jwt.RegisteredClaims{ | 		RegisteredClaims: jwt.RegisteredClaims{ | ||||||
| 			Issuer:    "pronouns", | 			Issuer:    "pronouns", | ||||||
| 			ExpiresAt: jwt.NewNumericDate(expires), | 			ExpiresAt: jwt.NewNumericDate(expires), | ||||||
|  | @ -55,7 +65,7 @@ func (v *Verifier) CreateToken(userID xid.ID) (string, error) { | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	return token.SignedString(v.key) | 	return t.SignedString(v.key) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Claims parses the given token and returns its Claims. | // Claims parses the given token and returns its Claims. | ||||||
|  |  | ||||||
							
								
								
									
										15
									
								
								scripts/migrate/003_add_tokens.sql
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								scripts/migrate/003_add_tokens.sql
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,15 @@ | ||||||
|  | -- +migrate Up | ||||||
|  | 
 | ||||||
|  | -- 2022-12-23: Add database-backed tokens | ||||||
|  | create table tokens ( | ||||||
|  |     user_id     text not null references users (id) on delete cascade, | ||||||
|  |     token_id    text primary key, | ||||||
|  |     invalidated boolean not null default false, | ||||||
|  |     created     timestamptz not null default now(), | ||||||
|  |     expires     timestamptz not null | ||||||
|  | ); | ||||||
|  | 
 | ||||||
|  | -- Unrelatedly, this migration also changes the column type for invites.created to timestamptz (from plain timestamp) | ||||||
|  | -- This does not change anything code-wise, but it's recommended over plain timestamp because plain timestamp does not handle timezones correctly | ||||||
|  | alter table invites alter column created type timestamptz; | ||||||
|  | alter table invites alter column created set default now(); | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue