feat: add token IDs, store tokens in db for early invalidation
This commit is contained in:
parent
58c1c1794e
commit
e5723360a7
7 changed files with 248 additions and 9 deletions
98
backend/db/tokens.go
Normal file
98
backend/db/tokens.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"emperror.dev/errors"
|
||||||
|
"github.com/georgysavva/scany/pgxscan"
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Token struct {
|
||||||
|
UserID xid.ID
|
||||||
|
TokenID xid.ID
|
||||||
|
Invalidated bool
|
||||||
|
Created time.Time
|
||||||
|
Expires time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) TokenValid(ctx context.Context, userID, tokenID xid.ID) (valid bool, err error) {
|
||||||
|
sql, args, err := sq.Select("*").From("tokens").
|
||||||
|
Where("user_id = ?", userID).
|
||||||
|
Where("token_id = ?", tokenID).
|
||||||
|
ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "building sql")
|
||||||
|
}
|
||||||
|
|
||||||
|
var t Token
|
||||||
|
err = pgxscan.Get(ctx, db, &t, sql, args...)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Cause(err) == pgx.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, errors.Wrap(err, "getting from database")
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
return !t.Invalidated && t.Created.Before(now) && t.Expires.After(now), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Tokens(ctx context.Context, userID xid.ID) (ts []Token, err error) {
|
||||||
|
sql, args, err := sq.Select("*").From("tokens").
|
||||||
|
Where("user_id = ?", userID).
|
||||||
|
Where("expires > ?", time.Now()).
|
||||||
|
OrderBy("created").
|
||||||
|
ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "building sql")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pgxscan.Select(ctx, db, &ts, sql, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getting from database")
|
||||||
|
}
|
||||||
|
return ts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3 months, might be customizable later
|
||||||
|
const ExpiryTime = 3 * 30 * 24 * time.Hour
|
||||||
|
|
||||||
|
// SaveToken saves a token to the database.
|
||||||
|
func (db *DB) SaveToken(ctx context.Context, userID xid.ID, tokenID xid.ID) (t Token, err error) {
|
||||||
|
sql, args, err := sq.Insert("tokens").
|
||||||
|
Columns("user_id", "token_id", "expires").
|
||||||
|
Values(userID, tokenID, time.Now().Add(ExpiryTime)).
|
||||||
|
Suffix("RETURNING *").
|
||||||
|
ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return t, errors.Wrap(err, "building sql")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pgxscan.Get(ctx, db, &t, sql, args...)
|
||||||
|
if err != nil {
|
||||||
|
return t, errors.Wrap(err, "inserting token")
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) InvalidateToken(ctx context.Context, userID xid.ID, tokenID xid.ID) (t Token, err error) {
|
||||||
|
sql, args, err := sq.Update("tokens").
|
||||||
|
Where("user_id = ?").
|
||||||
|
Where("token_id = ?").
|
||||||
|
Set("invalidated", true).
|
||||||
|
Suffix("RETURNING *").
|
||||||
|
ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return t, errors.Wrap(err, "building sql")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pgxscan.Get(ctx, db, &t, sql, args...)
|
||||||
|
if err != nil {
|
||||||
|
return t, errors.Wrap(err, "invalidating token")
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/bwmarrin/discordgo"
|
"github.com/bwmarrin/discordgo"
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
"github.com/mediocregopher/radix/v4"
|
"github.com/mediocregopher/radix/v4"
|
||||||
|
"github.com/rs/xid"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -81,11 +82,19 @@ func (s *Server) discordCallback(w http.ResponseWriter, r *http.Request) error {
|
||||||
log.Errorf("updating user %v with Discord info: %v", u.ID, err)
|
log.Errorf("updating user %v with Discord info: %v", u.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := s.Auth.CreateToken(u.ID)
|
// TODO: implement user + token permissions
|
||||||
|
tokenID := xid.New()
|
||||||
|
token, err := s.Auth.CreateToken(u.ID, tokenID, false, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// save token to database
|
||||||
|
_, err = s.DB.SaveToken(ctx, u.ID, tokenID)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "saving token to database")
|
||||||
|
}
|
||||||
|
|
||||||
render.JSON(w, r, discordCallbackResponse{
|
render.JSON(w, r, discordCallbackResponse{
|
||||||
HasAccount: true,
|
HasAccount: true,
|
||||||
Token: token,
|
Token: token,
|
||||||
|
@ -206,11 +215,19 @@ func (s *Server) discordSignup(w http.ResponseWriter, r *http.Request) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// create token
|
// create token
|
||||||
token, err := s.Auth.CreateToken(u.ID)
|
// TODO: implement user + token permissions
|
||||||
|
tokenID := xid.New()
|
||||||
|
token, err := s.Auth.CreateToken(u.ID, tokenID, false, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "creating token")
|
return errors.Wrap(err, "creating token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// save token to database
|
||||||
|
_, err = s.DB.SaveToken(ctx, u.ID, tokenID)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "saving token to database")
|
||||||
|
}
|
||||||
|
|
||||||
// return user
|
// return user
|
||||||
render.JSON(w, r, signupResponse{
|
render.JSON(w, r, signupResponse{
|
||||||
User: *dbUserToUserResponse(u),
|
User: *dbUserToUserResponse(u),
|
||||||
|
|
|
@ -67,6 +67,11 @@ func Mount(srv *server.Server, r chi.Router) {
|
||||||
// invite routes
|
// invite routes
|
||||||
r.With(server.MustAuth).Get("/invites", server.WrapHandler(s.getInvites))
|
r.With(server.MustAuth).Get("/invites", server.WrapHandler(s.getInvites))
|
||||||
r.With(server.MustAuth).Post("/invites", server.WrapHandler(s.createInvite))
|
r.With(server.MustAuth).Post("/invites", server.WrapHandler(s.createInvite))
|
||||||
|
|
||||||
|
// tokens
|
||||||
|
r.With(server.MustAuth).Get("/tokens", server.WrapHandler(s.getTokens))
|
||||||
|
r.With(server.MustAuth).Post("/tokens", server.WrapHandler(s.createToken))
|
||||||
|
r.With(server.MustAuth).Delete("/tokens", server.WrapHandler(s.deleteToken))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
72
backend/routes/auth/tokens.go
Normal file
72
backend/routes/auth/tokens.go
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"codeberg.org/u1f320/pronouns.cc/backend/db"
|
||||||
|
"codeberg.org/u1f320/pronouns.cc/backend/server"
|
||||||
|
"emperror.dev/errors"
|
||||||
|
"github.com/go-chi/render"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type getTokenResponse struct {
|
||||||
|
TokenID xid.ID `json:"id"`
|
||||||
|
Created time.Time `json:"created"`
|
||||||
|
Expires time.Time `json:"expires"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func dbTokenToGetResponse(t db.Token) getTokenResponse {
|
||||||
|
return getTokenResponse{
|
||||||
|
TokenID: t.TokenID,
|
||||||
|
Created: t.Created,
|
||||||
|
Expires: t.Expires,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getTokens(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
ctx := r.Context()
|
||||||
|
claims, _ := server.ClaimsFromContext(ctx)
|
||||||
|
|
||||||
|
tokens, err := s.DB.Tokens(ctx, claims.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "getting tokens")
|
||||||
|
}
|
||||||
|
|
||||||
|
resps := make([]getTokenResponse, len(tokens))
|
||||||
|
for i := range tokens {
|
||||||
|
resps[i] = dbTokenToGetResponse(tokens[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
render.JSON(w, r, resps)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type deleteTokenResponse struct {
|
||||||
|
TokenID xid.ID `json:"id"`
|
||||||
|
Invalidated bool `json:"invalidated"`
|
||||||
|
Created time.Time `json:"time"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) deleteToken(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
ctx := r.Context()
|
||||||
|
claims, _ := server.ClaimsFromContext(ctx)
|
||||||
|
|
||||||
|
t, err := s.DB.InvalidateToken(ctx, claims.UserID, claims.TokenID)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "invalidating token")
|
||||||
|
}
|
||||||
|
|
||||||
|
render.JSON(w, r, deleteTokenResponse{
|
||||||
|
TokenID: t.TokenID,
|
||||||
|
Invalidated: t.Invalidated,
|
||||||
|
Created: t.Created,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) createToken(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
// unimplemented right now
|
||||||
|
return server.APIError{Code: server.ErrForbidden}
|
||||||
|
}
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"codeberg.org/u1f320/pronouns.cc/backend/log"
|
||||||
"codeberg.org/u1f320/pronouns.cc/backend/server/auth"
|
"codeberg.org/u1f320/pronouns.cc/backend/server/auth"
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
)
|
)
|
||||||
|
@ -28,6 +29,27 @@ func (s *Server) maybeAuth(next http.Handler) http.Handler {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// "valid" here refers to existence and expiry date, not whether the token is known
|
||||||
|
valid, err := s.DB.TokenValid(r.Context(), claims.UserID, claims.TokenID)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("validating token for user %v: %v", claims.UserID, err)
|
||||||
|
render.Status(r, errCodeStatuses[ErrInternalServerError])
|
||||||
|
render.JSON(w, r, APIError{
|
||||||
|
Code: ErrInternalServerError,
|
||||||
|
Message: errCodeMessages[ErrInternalServerError],
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
render.Status(r, errCodeStatuses[ErrInvalidToken])
|
||||||
|
render.JSON(w, r, APIError{
|
||||||
|
Code: ErrInvalidToken,
|
||||||
|
Message: errCodeMessages[ErrInvalidToken],
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), ctxKeyClaims, claims)
|
ctx := context.WithValue(r.Context(), ctxKeyClaims, claims)
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
|
|
@ -14,7 +14,13 @@ import (
|
||||||
|
|
||||||
// Claims are the claims used in a token.
|
// Claims are the claims used in a token.
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
UserID xid.ID `json:"sub"`
|
UserID xid.ID `json:"sub"`
|
||||||
|
TokenID xid.ID `json:"jti"`
|
||||||
|
UserIsAdmin bool `json:"adm"`
|
||||||
|
|
||||||
|
// TokenWrite specifies whether this token can be used for write actions.
|
||||||
|
// If set to false, this token can only be used for read actions.
|
||||||
|
TokenWrite bool `json:"twr"`
|
||||||
|
|
||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
}
|
}
|
||||||
|
@ -37,16 +43,20 @@ func New() *Verifier {
|
||||||
return &Verifier{key: key}
|
return &Verifier{key: key}
|
||||||
}
|
}
|
||||||
|
|
||||||
const expireDays = 30
|
// ExpireDays is after how many days the token will expire.
|
||||||
|
const ExpireDays = 30
|
||||||
|
|
||||||
// CreateToken creates a token for the given user ID.
|
// CreateToken creates a token for the given user ID.
|
||||||
// It expires after 30 days.
|
// It expires after 30 days.
|
||||||
func (v *Verifier) CreateToken(userID xid.ID) (string, error) {
|
func (v *Verifier) CreateToken(userID, tokenID xid.ID, isAdmin bool, isWriteToken bool) (token string, err error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
expires := now.Add(expireDays * 24 * time.Hour)
|
expires := now.Add(ExpireDays * 24 * time.Hour)
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{
|
t := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
|
TokenID: tokenID,
|
||||||
|
UserIsAdmin: isAdmin,
|
||||||
|
TokenWrite: isWriteToken,
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
Issuer: "pronouns",
|
Issuer: "pronouns",
|
||||||
ExpiresAt: jwt.NewNumericDate(expires),
|
ExpiresAt: jwt.NewNumericDate(expires),
|
||||||
|
@ -55,7 +65,7 @@ func (v *Verifier) CreateToken(userID xid.ID) (string, error) {
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return token.SignedString(v.key)
|
return t.SignedString(v.key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claims parses the given token and returns its Claims.
|
// Claims parses the given token and returns its Claims.
|
||||||
|
|
15
scripts/migrate/003_add_tokens.sql
Normal file
15
scripts/migrate/003_add_tokens.sql
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
-- +migrate Up
|
||||||
|
|
||||||
|
-- 2022-12-23: Add database-backed tokens
|
||||||
|
create table tokens (
|
||||||
|
user_id text not null references users (id) on delete cascade,
|
||||||
|
token_id text primary key,
|
||||||
|
invalidated boolean not null default false,
|
||||||
|
created timestamptz not null default now(),
|
||||||
|
expires timestamptz not null
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Unrelatedly, this migration also changes the column type for invites.created to timestamptz (from plain timestamp)
|
||||||
|
-- This does not change anything code-wise, but it's recommended over plain timestamp because plain timestamp does not handle timezones correctly
|
||||||
|
alter table invites alter column created type timestamptz;
|
||||||
|
alter table invites alter column created set default now();
|
Loading…
Reference in a new issue