add frontend template + GET /users/{userRef} route
This commit is contained in:
parent
5a75f99720
commit
580449440a
28 changed files with 1393 additions and 12 deletions
|
@ -2,8 +2,11 @@ package db
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"emperror.dev/errors"
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/mediocregopher/radix/v4"
|
||||
|
@ -47,3 +50,69 @@ func (db *DB) MultiCmd(ctx context.Context, cmds ...radix.Action) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetJSON sets the given key to v marshaled as JSON.
|
||||
func (db *DB) SetJSON(ctx context.Context, key string, v any, args ...string) error {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "marshaling json")
|
||||
}
|
||||
|
||||
cmdArgs := make([]string, 0, len(args)+2)
|
||||
cmdArgs = append(cmdArgs, key, string(b))
|
||||
cmdArgs = append(cmdArgs, args...)
|
||||
|
||||
err = db.Redis.Do(ctx, radix.Cmd(nil, "SET", cmdArgs...))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "writing to Redis")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetJSON gets the given key as a JSON object.
|
||||
func (db *DB) GetJSON(ctx context.Context, key string, v any) error {
|
||||
var b []byte
|
||||
|
||||
err := db.Redis.Do(ctx, radix.Cmd(&b, "GET", key))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "reading from Redis")
|
||||
}
|
||||
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if v == nil {
|
||||
return fmt.Errorf("nil pointer passed into GetJSON")
|
||||
}
|
||||
|
||||
err = json.Unmarshal(b, v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "unmarshaling json")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDelJSON gets the given key as a JSON object and deletes it.
|
||||
func (db *DB) GetDelJSON(ctx context.Context, key string, v any) error {
|
||||
var b []byte
|
||||
|
||||
err := db.Redis.Do(ctx, radix.Cmd(&b, "GETDEL", key))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "reading from Redis")
|
||||
}
|
||||
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if v == nil {
|
||||
return fmt.Errorf("nil pointer passed into GetDelJSON")
|
||||
}
|
||||
|
||||
err = json.Unmarshal(b, v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "unmarshaling json")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,8 +5,10 @@ import (
|
|||
"regexp"
|
||||
|
||||
"emperror.dev/errors"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/georgysavva/scany/pgxscan"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
|
@ -27,8 +29,9 @@ type User struct {
|
|||
var usernameRegex = regexp.MustCompile(`[\w-.]{2,40}`)
|
||||
|
||||
const (
|
||||
ErrUsernameTaken = errors.Sentinel("username is already taken")
|
||||
ErrUserNotFound = errors.Sentinel("user not found")
|
||||
|
||||
ErrUsernameTaken = errors.Sentinel("username is already taken")
|
||||
ErrInvalidUsername = errors.Sentinel("username contains invalid characters")
|
||||
ErrUsernameTooShort = errors.Sentinel("username is too short")
|
||||
ErrUsernameTooLong = errors.Sentinel("username is too long")
|
||||
|
@ -66,3 +69,69 @@ func (db *DB) CreateUser(ctx context.Context, username string) (u User, err erro
|
|||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// DiscordUser fetches a user by Discord user ID.
|
||||
func (db *DB) DiscordUser(ctx context.Context, discordID string) (u User, err error) {
|
||||
sql, args, err := sq.Select("*").From("users").Where("discord = ?", discordID).ToSql()
|
||||
if err != nil {
|
||||
return u, errors.Wrap(err, "building sql")
|
||||
}
|
||||
|
||||
err = pgxscan.Get(ctx, db, &u, sql, args...)
|
||||
if err != nil {
|
||||
if errors.Cause(err) == pgx.ErrNoRows {
|
||||
return u, ErrUserNotFound
|
||||
}
|
||||
return u, errors.Cause(err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (u *User) UpdateFromDiscord(ctx context.Context, db pgxscan.Querier, du *discordgo.User) error {
|
||||
builder := sq.Update("users").
|
||||
Set("discord_username", du.String()).
|
||||
Where("id = ?", u.ID).
|
||||
Suffix("RETURNING *")
|
||||
|
||||
if u.AvatarSource == nil || *u.AvatarSource == "discord" {
|
||||
builder = builder.
|
||||
Set("avatar_source", "discord").
|
||||
Set("avatar_url", du.AvatarURL("1024"))
|
||||
}
|
||||
|
||||
sql, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "building sql")
|
||||
}
|
||||
|
||||
return pgxscan.Get(ctx, db, u, sql, args...)
|
||||
}
|
||||
|
||||
// User gets a user by ID.
|
||||
func (db *DB) User(ctx context.Context, id xid.ID) (u User, err error) {
|
||||
err = pgxscan.Get(ctx, db, &u, "select * from users where id = $1", id)
|
||||
if err != nil {
|
||||
if errors.Cause(err) == pgx.ErrNoRows {
|
||||
return u, ErrUserNotFound
|
||||
}
|
||||
|
||||
return u, errors.Cause(err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Username gets a user by username.
|
||||
func (db *DB) Username(ctx context.Context, name string) (u User, err error) {
|
||||
err = pgxscan.Get(ctx, db, &u, "select * from users where username = $1", name)
|
||||
if err != nil {
|
||||
if errors.Cause(err) == pgx.ErrNoRows {
|
||||
return u, ErrUserNotFound
|
||||
}
|
||||
|
||||
return u, errors.Cause(err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
|
|
@ -3,13 +3,16 @@ package main
|
|||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"gitlab.com/1f320/pronouns/backend/routes/auth"
|
||||
"gitlab.com/1f320/pronouns/backend/routes/user"
|
||||
"gitlab.com/1f320/pronouns/backend/server"
|
||||
)
|
||||
|
||||
// mountRoutes mounts all API routes on the server's router.
|
||||
// they are all mounted under /v1/
|
||||
func mountRoutes(s *server.Server) {
|
||||
// future-proofing for API versions
|
||||
s.Router.Route("/v1", func(r chi.Router) {
|
||||
auth.Mount(s, r)
|
||||
user.Mount(s, r)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,12 +1,18 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/go-chi/render"
|
||||
"gitlab.com/1f320/pronouns/backend/db"
|
||||
"gitlab.com/1f320/pronouns/backend/log"
|
||||
"gitlab.com/1f320/pronouns/backend/server"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var oauthConfig = oauth2.Config{
|
||||
var discordOAuthConfig = oauth2.Config{
|
||||
ClientID: os.Getenv("DISCORD_CLIENT_ID"),
|
||||
ClientSecret: os.Getenv("DISCORD_CLIENT_SECRET"),
|
||||
Endpoint: oauth2.Endpoint{
|
||||
|
@ -16,3 +22,80 @@ var oauthConfig = oauth2.Config{
|
|||
},
|
||||
Scopes: []string{"identify"},
|
||||
}
|
||||
|
||||
type oauthCallbackRequest struct {
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
type discordCallbackResponse struct {
|
||||
HasAccount bool `json:"has_account"` // if true, Token and User will be set. if false, Ticket and Discord will be set
|
||||
|
||||
Token string `json:"token,omitempty"`
|
||||
User *db.User `json:"user,omitempty"`
|
||||
|
||||
Discord string `json:"discord,omitempty"` // username, for UI purposes
|
||||
Ticket string `json:"ticket,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) discordCallback(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx := r.Context()
|
||||
|
||||
decoded, err := Decode[oauthCallbackRequest](r)
|
||||
if err != nil {
|
||||
return server.APIError{Code: server.ErrBadRequest}
|
||||
}
|
||||
|
||||
// if the state can't be validated, return
|
||||
if valid, err := s.validateCSRFState(ctx, decoded.State); !valid {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return server.APIError{Code: server.ErrInvalidState}
|
||||
}
|
||||
|
||||
token, err := discordOAuthConfig.Exchange(r.Context(), decoded.Code)
|
||||
if err != nil {
|
||||
log.Errorf("exchanging oauth code: %v", err)
|
||||
|
||||
return server.APIError{Code: server.ErrInvalidOAuthCode}
|
||||
}
|
||||
|
||||
dg, _ := discordgo.New(token.Type() + " " + token.AccessToken)
|
||||
du, err := dg.User("@me")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u, err := s.DB.DiscordUser(ctx, du.ID)
|
||||
if err == nil {
|
||||
err = u.UpdateFromDiscord(ctx, s.DB, du)
|
||||
if err != nil {
|
||||
log.Errorf("updating user %v with Discord info: %v", u.ID, err)
|
||||
}
|
||||
|
||||
token, err := s.Auth.CreateToken(u.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
render.JSON(w, r, discordCallbackResponse{
|
||||
HasAccount: true,
|
||||
Token: token,
|
||||
User: &u,
|
||||
})
|
||||
} else if err != db.ErrUserNotFound { // internal error
|
||||
return err
|
||||
}
|
||||
|
||||
// no user found, so save a ticket
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Decode[T any](r *http.Request) (T, error) {
|
||||
decoded := *new(T)
|
||||
|
||||
return decoded, render.Decode(r, &decoded)
|
||||
}
|
||||
|
|
41
backend/routes/auth/oauth.go
Normal file
41
backend/routes/auth/oauth.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/mediocregopher/radix/v4"
|
||||
)
|
||||
|
||||
// numStates is the number of CSRF states stored in Redis at any one time.
|
||||
// This must be an integer.
|
||||
const numStates = "1000"
|
||||
|
||||
// setCSRFState generates a random string to use as state, then stores that in Redis.
|
||||
func (s *Server) setCSRFState(ctx context.Context) (string, error) {
|
||||
b := make([]byte, 32)
|
||||
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
err = s.DB.MultiCmd(ctx,
|
||||
radix.Cmd(nil, "LPUSH", "csrf", state),
|
||||
radix.Cmd(nil, "LTRIM", "csrf", "0", numStates),
|
||||
)
|
||||
return state, err
|
||||
}
|
||||
|
||||
// validateCSRFState checks if the given state exists in Redis.
|
||||
func (s *Server) validateCSRFState(ctx context.Context, state string) (matched bool, err error) {
|
||||
var num int
|
||||
err = s.DB.Redis.Do(ctx, radix.Cmd(&num, "LREM", "csrf", "1", state))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return num > 0, nil
|
||||
}
|
|
@ -3,7 +3,9 @@ package auth
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"emperror.dev/errors"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
"gitlab.com/1f320/pronouns/backend/server"
|
||||
)
|
||||
|
||||
|
@ -13,15 +15,46 @@ type Server struct {
|
|||
|
||||
func Mount(srv *server.Server, r chi.Router) {
|
||||
s := &Server{srv}
|
||||
_ = s
|
||||
|
||||
r.Route("/auth/discord", func(r chi.Router) {
|
||||
r.Get("/authorize", nil) // generate csrf token, returns URL
|
||||
r.Get("/callback", nil) // takes code + state, validates it, returns token OR discord signup ticket
|
||||
r.Get("/signup", nil) // takes discord signup ticket to register account
|
||||
r.Route("/auth", func(r chi.Router) {
|
||||
// generate csrf token, returns all supported OAuth provider URLs
|
||||
r.Get("/urls", server.WrapHandler(s.oauthURLs))
|
||||
|
||||
r.Get("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("hello world!"))
|
||||
r.Route("/discord", func(r chi.Router) {
|
||||
// takes code + state, validates it, returns token OR discord signup ticket
|
||||
r.Post("/callback", nil)
|
||||
// takes discord signup ticket to register account
|
||||
r.Post("/signup", nil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type oauthURLsRequest struct {
|
||||
CallbackURL string `json:"callback_url"`
|
||||
}
|
||||
|
||||
type oauthURLsResponse struct {
|
||||
Discord string `json:"discord"`
|
||||
}
|
||||
|
||||
func (s *Server) oauthURLs(w http.ResponseWriter, r *http.Request) error {
|
||||
req, err := Decode[oauthURLsRequest](r)
|
||||
if err != nil {
|
||||
return server.APIError{Code: server.ErrBadRequest}
|
||||
}
|
||||
|
||||
// generate CSRF state
|
||||
state, err := s.setCSRFState(r.Context())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "setting CSRF state")
|
||||
}
|
||||
|
||||
// copy Discord config and set redirect url
|
||||
discordCfg := discordOAuthConfig
|
||||
discordCfg.RedirectURL = req.CallbackURL
|
||||
|
||||
render.JSON(w, r, oauthURLsResponse{
|
||||
Discord: discordCfg.AuthCodeURL(state),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
|
70
backend/routes/user/get_user.go
Normal file
70
backend/routes/user/get_user.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/rs/xid"
|
||||
"gitlab.com/1f320/pronouns/backend/db"
|
||||
"gitlab.com/1f320/pronouns/backend/log"
|
||||
"gitlab.com/1f320/pronouns/backend/server"
|
||||
)
|
||||
|
||||
type GetUserResponse struct {
|
||||
ID xid.ID `json:"id"`
|
||||
Username string `json:"username"`
|
||||
DisplayName *string `json:"display_name"`
|
||||
Bio *string `json:"bio"`
|
||||
AvatarURL *string `json:"avatar_url"`
|
||||
Links []string `json:"links"`
|
||||
}
|
||||
|
||||
type PartialMember struct {
|
||||
ID xid.ID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
AvatarURL *string `json:"avatar_url"`
|
||||
}
|
||||
|
||||
func dbUserToResponse(u db.User) GetUserResponse {
|
||||
return GetUserResponse{
|
||||
ID: u.ID,
|
||||
Username: u.Username,
|
||||
DisplayName: u.DisplayName,
|
||||
Bio: u.Bio,
|
||||
AvatarURL: u.AvatarURL,
|
||||
Links: u.Links,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) getUser(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx := r.Context()
|
||||
|
||||
userRef := chi.URLParamFromCtx(ctx, "userRef")
|
||||
|
||||
if id, err := xid.FromString(userRef); err == nil {
|
||||
u, err := s.DB.User(ctx, id)
|
||||
if err == nil {
|
||||
render.JSON(w, r, dbUserToResponse(u))
|
||||
return nil
|
||||
} else if err != db.ErrUserNotFound {
|
||||
log.Errorf("Error getting user by ID: %v", err)
|
||||
return err
|
||||
}
|
||||
// otherwise, we fall back to checking usernames
|
||||
}
|
||||
|
||||
u, err := s.DB.Username(ctx, userRef)
|
||||
if err == db.ErrUserNotFound {
|
||||
return server.APIError{
|
||||
Code: server.ErrUserNotFound,
|
||||
}
|
||||
|
||||
} else if err != nil {
|
||||
log.Errorf("Error getting user by username: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
render.JSON(w, r, dbUserToResponse(u))
|
||||
return nil
|
||||
}
|
20
backend/routes/user/routes.go
Normal file
20
backend/routes/user/routes.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"gitlab.com/1f320/pronouns/backend/server"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
*server.Server
|
||||
}
|
||||
|
||||
func Mount(srv *server.Server, r chi.Router) {
|
||||
s := &Server{srv}
|
||||
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.With(server.MustAuth).Get("/@me", server.WrapHandler(nil))
|
||||
|
||||
r.Get("/{userRef}", server.WrapHandler(s.getUser))
|
||||
})
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/render"
|
||||
"gitlab.com/1f320/pronouns/backend/server/auth"
|
||||
)
|
||||
|
||||
|
@ -34,7 +35,12 @@ func MustAuth(next http.Handler) http.Handler {
|
|||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
_, ok := ClaimsFromContext(r.Context())
|
||||
if !ok {
|
||||
|
||||
render.Status(r, errCodeStatuses[ErrForbidden])
|
||||
render.JSON(w, r, APIError{
|
||||
Code: ErrForbidden,
|
||||
Message: errCodeMessages[ErrForbidden],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
|
|
|
@ -62,13 +62,36 @@ func (e *APIError) prepare() {
|
|||
|
||||
// Error code constants
|
||||
const (
|
||||
ErrBadRequest = 400
|
||||
ErrForbidden = 403
|
||||
ErrInternalServerError = 500 // catch-all code for unknown errors
|
||||
|
||||
// Login/authorize error codes
|
||||
ErrInvalidState = 1001
|
||||
ErrInvalidOAuthCode = 1002
|
||||
|
||||
// User-related error codes
|
||||
ErrUserNotFound = 2001
|
||||
)
|
||||
|
||||
var errCodeMessages = map[int]string{
|
||||
ErrBadRequest: "Bad request",
|
||||
ErrForbidden: "Forbidden",
|
||||
ErrInternalServerError: "Internal server error",
|
||||
|
||||
ErrInvalidState: "Invalid OAuth state",
|
||||
ErrInvalidOAuthCode: "Invalid OAuth code",
|
||||
|
||||
ErrUserNotFound: "User not found",
|
||||
}
|
||||
|
||||
var errCodeStatuses = map[int]int{
|
||||
ErrBadRequest: http.StatusBadRequest,
|
||||
ErrForbidden: http.StatusForbidden,
|
||||
ErrInternalServerError: http.StatusInternalServerError,
|
||||
|
||||
ErrInvalidState: http.StatusBadRequest,
|
||||
ErrInvalidOAuthCode: http.StatusForbidden,
|
||||
|
||||
ErrUserNotFound: http.StatusNotFound,
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue