From e5723360a7269940ada8d17e90d353dc3ffed1f4 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 1 Jan 2023 00:34:38 +0100 Subject: [PATCH] feat: add token IDs, store tokens in db for early invalidation --- backend/db/tokens.go | 98 ++++++++++++++++++++++++++++++ backend/routes/auth/discord.go | 21 ++++++- backend/routes/auth/routes.go | 5 ++ backend/routes/auth/tokens.go | 72 ++++++++++++++++++++++ backend/server/auth.go | 22 +++++++ backend/server/auth/auth.go | 24 +++++--- scripts/migrate/003_add_tokens.sql | 15 +++++ 7 files changed, 248 insertions(+), 9 deletions(-) create mode 100644 backend/db/tokens.go create mode 100644 backend/routes/auth/tokens.go create mode 100644 scripts/migrate/003_add_tokens.sql diff --git a/backend/db/tokens.go b/backend/db/tokens.go new file mode 100644 index 0000000..6e07367 --- /dev/null +++ b/backend/db/tokens.go @@ -0,0 +1,98 @@ +package db + +import ( + "context" + "time" + + "emperror.dev/errors" + "github.com/georgysavva/scany/pgxscan" + "github.com/jackc/pgx/v4" + "github.com/rs/xid" +) + +type Token struct { + UserID xid.ID + TokenID xid.ID + Invalidated bool + Created time.Time + Expires time.Time +} + +func (db *DB) TokenValid(ctx context.Context, userID, tokenID xid.ID) (valid bool, err error) { + sql, args, err := sq.Select("*").From("tokens"). + Where("user_id = ?", userID). + Where("token_id = ?", tokenID). + ToSql() + if err != nil { + return false, errors.Wrap(err, "building sql") + } + + var t Token + err = pgxscan.Get(ctx, db, &t, sql, args...) + if err != nil { + if errors.Cause(err) == pgx.ErrNoRows { + return false, nil + } + + return false, errors.Wrap(err, "getting from database") + } + + now := time.Now() + return !t.Invalidated && t.Created.Before(now) && t.Expires.After(now), nil +} + +func (db *DB) Tokens(ctx context.Context, userID xid.ID) (ts []Token, err error) { + sql, args, err := sq.Select("*").From("tokens"). + Where("user_id = ?", userID). + Where("expires > ?", time.Now()). + OrderBy("created"). + ToSql() + if err != nil { + return nil, errors.Wrap(err, "building sql") + } + + err = pgxscan.Select(ctx, db, &ts, sql, args...) + if err != nil { + return nil, errors.Wrap(err, "getting from database") + } + return ts, nil +} + +// 3 months, might be customizable later +const ExpiryTime = 3 * 30 * 24 * time.Hour + +// SaveToken saves a token to the database. +func (db *DB) SaveToken(ctx context.Context, userID xid.ID, tokenID xid.ID) (t Token, err error) { + sql, args, err := sq.Insert("tokens"). + Columns("user_id", "token_id", "expires"). + Values(userID, tokenID, time.Now().Add(ExpiryTime)). + Suffix("RETURNING *"). + ToSql() + if err != nil { + return t, errors.Wrap(err, "building sql") + } + + err = pgxscan.Get(ctx, db, &t, sql, args...) + if err != nil { + return t, errors.Wrap(err, "inserting token") + } + return t, nil +} + +func (db *DB) InvalidateToken(ctx context.Context, userID xid.ID, tokenID xid.ID) (t Token, err error) { + sql, args, err := sq.Update("tokens"). + Where("user_id = ?"). + Where("token_id = ?"). + Set("invalidated", true). + Suffix("RETURNING *"). + ToSql() + if err != nil { + return t, errors.Wrap(err, "building sql") + } + + err = pgxscan.Get(ctx, db, &t, sql, args...) + if err != nil { + return t, errors.Wrap(err, "invalidating token") + } + return t, nil +} diff --git a/backend/routes/auth/discord.go b/backend/routes/auth/discord.go index 819f6d1..8b1bed5 100644 --- a/backend/routes/auth/discord.go +++ b/backend/routes/auth/discord.go @@ -11,6 +11,7 @@ import ( "github.com/bwmarrin/discordgo" "github.com/go-chi/render" "github.com/mediocregopher/radix/v4" + "github.com/rs/xid" "golang.org/x/oauth2" ) @@ -81,11 +82,19 @@ func (s *Server) discordCallback(w http.ResponseWriter, r *http.Request) error { log.Errorf("updating user %v with Discord info: %v", u.ID, err) } - token, err := s.Auth.CreateToken(u.ID) + // TODO: implement user + token permissions + tokenID := xid.New() + token, err := s.Auth.CreateToken(u.ID, tokenID, false, true) if err != nil { return err } + // save token to database + _, err = s.DB.SaveToken(ctx, u.ID, tokenID) + if err != nil { + return errors.Wrap(err, "saving token to database") + } + render.JSON(w, r, discordCallbackResponse{ HasAccount: true, Token: token, @@ -206,11 +215,19 @@ func (s *Server) discordSignup(w http.ResponseWriter, r *http.Request) error { } // create token - token, err := s.Auth.CreateToken(u.ID) + // TODO: implement user + token permissions + tokenID := xid.New() + token, err := s.Auth.CreateToken(u.ID, tokenID, false, true) if err != nil { return errors.Wrap(err, "creating token") } + // save token to database + _, err = s.DB.SaveToken(ctx, u.ID, tokenID) + if err != nil { + return errors.Wrap(err, "saving token to database") + } + // return user render.JSON(w, r, signupResponse{ User: *dbUserToUserResponse(u), diff --git a/backend/routes/auth/routes.go b/backend/routes/auth/routes.go index b5324fc..67f148c 100644 --- a/backend/routes/auth/routes.go +++ b/backend/routes/auth/routes.go @@ -67,6 +67,11 @@ func Mount(srv *server.Server, r chi.Router) { // invite routes r.With(server.MustAuth).Get("/invites", server.WrapHandler(s.getInvites)) r.With(server.MustAuth).Post("/invites", server.WrapHandler(s.createInvite)) + + // tokens + r.With(server.MustAuth).Get("/tokens", server.WrapHandler(s.getTokens)) + r.With(server.MustAuth).Post("/tokens", server.WrapHandler(s.createToken)) + r.With(server.MustAuth).Delete("/tokens", server.WrapHandler(s.deleteToken)) }) } diff --git a/backend/routes/auth/tokens.go b/backend/routes/auth/tokens.go new file mode 100644 index 0000000..447d265 --- /dev/null +++ b/backend/routes/auth/tokens.go @@ -0,0 +1,72 @@ +package auth + +import ( + "net/http" + "time" + + "codeberg.org/u1f320/pronouns.cc/backend/db" + "codeberg.org/u1f320/pronouns.cc/backend/server" + "emperror.dev/errors" + "github.com/go-chi/render" + "github.com/rs/xid" +) + +type getTokenResponse struct { + TokenID xid.ID `json:"id"` + Created time.Time `json:"created"` + Expires time.Time `json:"expires"` +} + +func dbTokenToGetResponse(t db.Token) getTokenResponse { + return getTokenResponse{ + TokenID: t.TokenID, + Created: t.Created, + Expires: t.Expires, + } +} + +func (s *Server) getTokens(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + claims, _ := server.ClaimsFromContext(ctx) + + tokens, err := s.DB.Tokens(ctx, claims.UserID) + if err != nil { + return errors.Wrap(err, "getting tokens") + } + + resps := make([]getTokenResponse, len(tokens)) + for i := range tokens { + resps[i] = dbTokenToGetResponse(tokens[i]) + } + + render.JSON(w, r, resps) + return nil +} + +type deleteTokenResponse struct { + TokenID xid.ID `json:"id"` + Invalidated bool `json:"invalidated"` + Created time.Time `json:"time"` +} + +func (s *Server) deleteToken(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + claims, _ := server.ClaimsFromContext(ctx) + + t, err := s.DB.InvalidateToken(ctx, claims.UserID, claims.TokenID) + if err != nil { + return errors.Wrap(err, "invalidating token") + } + + render.JSON(w, r, deleteTokenResponse{ + TokenID: t.TokenID, + Invalidated: t.Invalidated, + Created: t.Created, + }) + return nil +} + +func (s *Server) createToken(w http.ResponseWriter, r *http.Request) error { + // unimplemented right now + return server.APIError{Code: server.ErrForbidden} +} diff --git a/backend/server/auth.go b/backend/server/auth.go index 6113ba3..915dc77 100644 --- a/backend/server/auth.go +++ b/backend/server/auth.go @@ -5,6 +5,7 @@ import ( "net/http" "strings" + "codeberg.org/u1f320/pronouns.cc/backend/log" "codeberg.org/u1f320/pronouns.cc/backend/server/auth" "github.com/go-chi/render" ) @@ -28,6 +29,27 @@ func (s *Server) maybeAuth(next http.Handler) http.Handler { return } + // "valid" here refers to existence and expiry date, not whether the token is known + valid, err := s.DB.TokenValid(r.Context(), claims.UserID, claims.TokenID) + if err != nil { + log.Errorf("validating token for user %v: %v", claims.UserID, err) + render.Status(r, errCodeStatuses[ErrInternalServerError]) + render.JSON(w, r, APIError{ + Code: ErrInternalServerError, + Message: errCodeMessages[ErrInternalServerError], + }) + return + } + + if !valid { + render.Status(r, errCodeStatuses[ErrInvalidToken]) + render.JSON(w, r, APIError{ + Code: ErrInvalidToken, + Message: errCodeMessages[ErrInvalidToken], + }) + return + } + ctx := context.WithValue(r.Context(), ctxKeyClaims, claims) next.ServeHTTP(w, r.WithContext(ctx)) diff --git a/backend/server/auth/auth.go b/backend/server/auth/auth.go index 8d756ce..07f1072 100644 --- a/backend/server/auth/auth.go +++ b/backend/server/auth/auth.go @@ -14,7 +14,13 @@ import ( // Claims are the claims used in a token. type Claims struct { - UserID xid.ID `json:"sub"` + UserID xid.ID `json:"sub"` + TokenID xid.ID `json:"jti"` + UserIsAdmin bool `json:"adm"` + + // TokenWrite specifies whether this token can be used for write actions. + // If set to false, this token can only be used for read actions. + TokenWrite bool `json:"twr"` jwt.RegisteredClaims } @@ -37,16 +43,20 @@ func New() *Verifier { return &Verifier{key: key} } -const expireDays = 30 +// ExpireDays is after how many days the token will expire. +const ExpireDays = 30 // CreateToken creates a token for the given user ID. // It expires after 30 days. -func (v *Verifier) CreateToken(userID xid.ID) (string, error) { +func (v *Verifier) CreateToken(userID, tokenID xid.ID, isAdmin bool, isWriteToken bool) (token string, err error) { now := time.Now() - expires := now.Add(expireDays * 24 * time.Hour) + expires := now.Add(ExpireDays * 24 * time.Hour) - token := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{ - UserID: userID, + t := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{ + UserID: userID, + TokenID: tokenID, + UserIsAdmin: isAdmin, + TokenWrite: isWriteToken, RegisteredClaims: jwt.RegisteredClaims{ Issuer: "pronouns", ExpiresAt: jwt.NewNumericDate(expires), @@ -55,7 +65,7 @@ func (v *Verifier) CreateToken(userID xid.ID) (string, error) { }, }) - return token.SignedString(v.key) + return t.SignedString(v.key) } // Claims parses the given token and returns its Claims. diff --git a/scripts/migrate/003_add_tokens.sql b/scripts/migrate/003_add_tokens.sql new file mode 100644 index 0000000..a6d81ef --- /dev/null +++ b/scripts/migrate/003_add_tokens.sql @@ -0,0 +1,15 @@ +-- +migrate Up + +-- 2022-12-23: Add database-backed tokens +create table tokens ( + user_id text not null references users (id) on delete cascade, + token_id text primary key, + invalidated boolean not null default false, + created timestamptz not null default now(), + expires timestamptz not null +); + +-- Unrelatedly, this migration also changes the column type for invites.created to timestamptz (from plain timestamp) +-- This does not change anything code-wise, but it's recommended over plain timestamp because plain timestamp does not handle timezones correctly +alter table invites alter column created type timestamptz; +alter table invites alter column created set default now();