initial commit
This commit is contained in:
commit
5a75f99720
20 changed files with 2239 additions and 0 deletions
49
backend/db/db.go
Normal file
49
backend/db/db.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/mediocregopher/radix/v4"
|
||||
)
|
||||
|
||||
var sq = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
type DB struct {
|
||||
*pgxpool.Pool
|
||||
|
||||
Redis radix.Client
|
||||
}
|
||||
|
||||
func New(dsn string) (*DB, error) {
|
||||
pool, err := pgxpool.Connect(context.Background(), dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redis, err := (&radix.PoolConfig{}).New(context.Background(), "tcp", os.Getenv("REDIS"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := &DB{
|
||||
Pool: pool,
|
||||
Redis: redis,
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// MultiCmd executes the given Redis commands in order.
|
||||
// If any return an error, the function is aborted.
|
||||
func (db *DB) MultiCmd(ctx context.Context, cmds ...radix.Action) error {
|
||||
for _, cmd := range cmds {
|
||||
err := db.Redis.Do(ctx, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
68
backend/db/user.go
Normal file
68
backend/db/user.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
|
||||
"emperror.dev/errors"
|
||||
"github.com/georgysavva/scany/pgxscan"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID xid.ID
|
||||
Username string
|
||||
DisplayName *string
|
||||
Bio *string
|
||||
|
||||
AvatarSource *string
|
||||
AvatarURL *string
|
||||
Links []string
|
||||
|
||||
Discord *string
|
||||
}
|
||||
|
||||
// usernames must match this regex
|
||||
var usernameRegex = regexp.MustCompile(`[\w-.]{2,40}`)
|
||||
|
||||
const (
|
||||
ErrUsernameTaken = errors.Sentinel("username is already taken")
|
||||
|
||||
ErrInvalidUsername = errors.Sentinel("username contains invalid characters")
|
||||
ErrUsernameTooShort = errors.Sentinel("username is too short")
|
||||
ErrUsernameTooLong = errors.Sentinel("username is too long")
|
||||
)
|
||||
|
||||
// CreateUser creates a user with the given username.
|
||||
func (db *DB) CreateUser(ctx context.Context, username string) (u User, err error) {
|
||||
// check if the username is valid
|
||||
// if not, return an error depending on what failed
|
||||
if !usernameRegex.MatchString(username) {
|
||||
if len(username) < 2 {
|
||||
return u, ErrUsernameTooShort
|
||||
} else if len(username) > 40 {
|
||||
return u, ErrUsernameTooLong
|
||||
}
|
||||
|
||||
return u, ErrInvalidUsername
|
||||
}
|
||||
|
||||
sql, args, err := sq.Insert("users").Columns("id", "username").Values(xid.New(), username).Suffix("RETURNING *").ToSql()
|
||||
if err != nil {
|
||||
return u, errors.Wrap(err, "building sql")
|
||||
}
|
||||
|
||||
err = pgxscan.Get(ctx, db, &u, sql, args...)
|
||||
if err != nil {
|
||||
if v, ok := errors.Cause(err).(*pgconn.PgError); ok {
|
||||
if v.Code == "23505" { // unique constraint violation
|
||||
return u, ErrUsernameTaken
|
||||
}
|
||||
}
|
||||
|
||||
return u, errors.Cause(err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
69
backend/log/log.go
Normal file
69
backend/log/log.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
// Package log contains a global Zap logger.
|
||||
package log
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
var Logger *zap.Logger
|
||||
var SugaredLogger *zap.SugaredLogger
|
||||
|
||||
func init() {
|
||||
zcfg := zap.NewProductionConfig()
|
||||
|
||||
zcfg.Level.SetLevel(zap.InfoLevel)
|
||||
zcfg.Encoding = "console"
|
||||
zcfg.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
||||
zcfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
|
||||
logger, err := zcfg.Build()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
zap.RedirectStdLog(logger)
|
||||
|
||||
Logger = logger
|
||||
SugaredLogger = Logger.WithOptions(zap.AddCallerSkip(1)).Sugar()
|
||||
}
|
||||
|
||||
func Debug(v ...any) {
|
||||
SugaredLogger.Debug(v...)
|
||||
}
|
||||
|
||||
func Info(v ...any) {
|
||||
SugaredLogger.Info(v...)
|
||||
}
|
||||
|
||||
func Warn(v ...any) {
|
||||
SugaredLogger.Warn(v...)
|
||||
}
|
||||
|
||||
func Error(v ...any) {
|
||||
SugaredLogger.Error(v...)
|
||||
}
|
||||
|
||||
func Fatal(v ...any) {
|
||||
SugaredLogger.Fatal(v...)
|
||||
}
|
||||
|
||||
func Debugf(tmpl string, v ...any) {
|
||||
SugaredLogger.Debugf(tmpl, v...)
|
||||
}
|
||||
|
||||
func Infof(tmpl string, v ...any) {
|
||||
SugaredLogger.Infof(tmpl, v...)
|
||||
}
|
||||
|
||||
func Warnf(tmpl string, v ...any) {
|
||||
SugaredLogger.Warnf(tmpl, v...)
|
||||
}
|
||||
|
||||
func Errorf(tmpl string, v ...any) {
|
||||
SugaredLogger.Errorf(tmpl, v...)
|
||||
}
|
||||
|
||||
func Fatalf(tmpl string, v ...any) {
|
||||
SugaredLogger.Fatalf(tmpl, v...)
|
||||
}
|
51
backend/main.go
Normal file
51
backend/main.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"gitlab.com/1f320/pronouns/backend/log"
|
||||
"gitlab.com/1f320/pronouns/backend/server"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
log.Fatalf("Error loading .env file: %v", err)
|
||||
}
|
||||
|
||||
port := ":" + os.Getenv("PORT")
|
||||
|
||||
s, err := server.New()
|
||||
if err != nil {
|
||||
log.Fatalf("Error creating server: %v", err)
|
||||
}
|
||||
|
||||
// mount api routes
|
||||
mountRoutes(s)
|
||||
|
||||
e := make(chan error)
|
||||
|
||||
// run server in another goroutine (for gracefully shutting down, see below)
|
||||
go func() {
|
||||
e <- http.ListenAndServe(port, s.Router)
|
||||
}()
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer stop()
|
||||
|
||||
log.Infof("API server running at %v!", port)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("Interrupt signal received, shutting down...")
|
||||
s.DB.Close()
|
||||
return
|
||||
case err := <-e:
|
||||
log.Fatalf("Error running server: %v", err)
|
||||
}
|
||||
}
|
15
backend/routes.go
Normal file
15
backend/routes.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"gitlab.com/1f320/pronouns/backend/routes/auth"
|
||||
"gitlab.com/1f320/pronouns/backend/server"
|
||||
)
|
||||
|
||||
// mountRoutes mounts all API routes on the server's router.
|
||||
// they are all mounted under /v1/
|
||||
func mountRoutes(s *server.Server) {
|
||||
s.Router.Route("/v1", func(r chi.Router) {
|
||||
auth.Mount(s, r)
|
||||
})
|
||||
}
|
18
backend/routes/auth/discord.go
Normal file
18
backend/routes/auth/discord.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var oauthConfig = oauth2.Config{
|
||||
ClientID: os.Getenv("DISCORD_CLIENT_ID"),
|
||||
ClientSecret: os.Getenv("DISCORD_CLIENT_SECRET"),
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://discord.com/api/oauth2/authorize",
|
||||
TokenURL: "https://discord.com/api/oauth2/token",
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
Scopes: []string{"identify"},
|
||||
}
|
27
backend/routes/auth/routes.go
Normal file
27
backend/routes/auth/routes.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"gitlab.com/1f320/pronouns/backend/server"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
*server.Server
|
||||
}
|
||||
|
||||
func Mount(srv *server.Server, r chi.Router) {
|
||||
s := &Server{srv}
|
||||
_ = s
|
||||
|
||||
r.Route("/auth/discord", func(r chi.Router) {
|
||||
r.Get("/authorize", nil) // generate csrf token, returns URL
|
||||
r.Get("/callback", nil) // takes code + state, validates it, returns token OR discord signup ticket
|
||||
r.Get("/signup", nil) // takes discord signup ticket to register account
|
||||
|
||||
r.Get("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("hello world!"))
|
||||
})
|
||||
})
|
||||
}
|
55
backend/server/auth.go
Normal file
55
backend/server/auth.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"gitlab.com/1f320/pronouns/backend/server/auth"
|
||||
)
|
||||
|
||||
// maybeAuth is a globally-used middleware.
|
||||
func (s *Server) maybeAuth(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get("Authorization")
|
||||
if token == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := s.Auth.Claims(token)
|
||||
if err != nil {
|
||||
// if we get here, a token was supplied but it's invalid--return an error
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), ctxKeyClaims, claims)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
// MustAuth makes a valid token required
|
||||
func MustAuth(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
_, ok := ClaimsFromContext(r.Context())
|
||||
if !ok {
|
||||
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
// ClaimsFromContext returns the auth.Claims in the context, if any.
|
||||
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
|
||||
v := ctx.Value(ctxKeyClaims)
|
||||
if v == nil {
|
||||
return auth.Claims{}, false
|
||||
}
|
||||
|
||||
claims, ok := v.(auth.Claims)
|
||||
return claims, ok
|
||||
}
|
81
backend/server/auth/auth.go
Normal file
81
backend/server/auth/auth.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"emperror.dev/errors"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/rs/xid"
|
||||
"gitlab.com/1f320/pronouns/backend/log"
|
||||
)
|
||||
|
||||
// Claims are the claims used in a token.
|
||||
type Claims struct {
|
||||
UserID xid.ID `json:"sub"`
|
||||
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type Verifier struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
func New() *Verifier {
|
||||
raw := os.Getenv("HMAC_KEY")
|
||||
if raw == "" {
|
||||
log.Fatal("$HMAC_KEY is not set")
|
||||
}
|
||||
|
||||
key, err := base64.URLEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
log.Fatal("$HMAC_KEY is not a valid base 64 string")
|
||||
}
|
||||
|
||||
return &Verifier{key: key}
|
||||
}
|
||||
|
||||
const expireDays = 30
|
||||
|
||||
// CreateToken creates a token for the given user ID.
|
||||
// It expires after 30 days.
|
||||
func (v *Verifier) CreateToken(userID xid.ID) (string, error) {
|
||||
now := time.Now()
|
||||
expires := now.Add(expireDays * 24 * time.Hour)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{
|
||||
UserID: userID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "pronouns",
|
||||
ExpiresAt: jwt.NewNumericDate(expires),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
})
|
||||
|
||||
return token.SignedString(v.key)
|
||||
}
|
||||
|
||||
// Claims parses the given token and returns its Claims.
|
||||
// If the token is invalid, returns an error.
|
||||
func (v *Verifier) Claims(token string) (c Claims, err error) {
|
||||
parsed, err := jwt.ParseWithClaims(token, &Claims{}, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf(`unexpected signing method "%v"`, t.Header["alg"])
|
||||
}
|
||||
|
||||
return v.key, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return c, errors.Wrap(err, "parsing token")
|
||||
}
|
||||
|
||||
if c, ok := parsed.Claims.(*Claims); ok && parsed.Valid {
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
return c, fmt.Errorf("unknown claims type %T", parsed.Claims)
|
||||
}
|
74
backend/server/errors.go
Normal file
74
backend/server/errors.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/render"
|
||||
"gitlab.com/1f320/pronouns/backend/log"
|
||||
)
|
||||
|
||||
// WrapHandler wraps a modified http.HandlerFunc into a stdlib-compatible one.
|
||||
// The inner HandlerFunc additionally returns an error.
|
||||
func WrapHandler(hn func(w http.ResponseWriter, r *http.Request) error) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
err := hn(w, r)
|
||||
if err != nil {
|
||||
// if the function returned an API error, just render that verbatim
|
||||
// we can assume that it also logged the error (if that was needed)
|
||||
if apiErr, ok := err.(APIError); ok {
|
||||
apiErr.prepare()
|
||||
|
||||
render.Status(r, apiErr.Status)
|
||||
render.JSON(w, r, apiErr)
|
||||
return
|
||||
}
|
||||
|
||||
// otherwise, we log the error and return an internal server error message
|
||||
log.Errorf("error in http handler: %v", err)
|
||||
|
||||
apiErr := APIError{Code: ErrInternalServerError}
|
||||
apiErr.prepare()
|
||||
|
||||
render.Status(r, apiErr.Status)
|
||||
render.JSON(w, r, apiErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// APIError is an object returned by the API when an error occurs.
|
||||
// It implements the error interface and can be returned by handlers.
|
||||
type APIError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message,omitempty"`
|
||||
|
||||
// Status is
|
||||
Status int `json:"-"`
|
||||
}
|
||||
|
||||
func (e APIError) Error() string {
|
||||
return fmt.Sprintf("%s (code: %d)", e.Message, e.Code)
|
||||
}
|
||||
|
||||
func (e *APIError) prepare() {
|
||||
if e.Status == 0 {
|
||||
e.Status = errCodeStatuses[e.Code]
|
||||
}
|
||||
|
||||
if e.Message == "" {
|
||||
e.Message = errCodeMessages[e.Code]
|
||||
}
|
||||
}
|
||||
|
||||
// Error code constants
|
||||
const (
|
||||
ErrInternalServerError = 500 // catch-all code for unknown errors
|
||||
)
|
||||
|
||||
var errCodeMessages = map[int]string{
|
||||
ErrInternalServerError: "Internal server error",
|
||||
}
|
||||
|
||||
var errCodeStatuses = map[int]int{
|
||||
ErrInternalServerError: http.StatusInternalServerError,
|
||||
}
|
49
backend/server/server.go
Normal file
49
backend/server/server.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"gitlab.com/1f320/pronouns/backend/db"
|
||||
"gitlab.com/1f320/pronouns/backend/server/auth"
|
||||
)
|
||||
|
||||
// Revision is the git commit, filled at build time
|
||||
var Revision = "[unknown]"
|
||||
|
||||
type Server struct {
|
||||
Router *chi.Mux
|
||||
|
||||
DB *db.DB
|
||||
Auth *auth.Verifier
|
||||
}
|
||||
|
||||
func New() (*Server, error) {
|
||||
db, err := db.New(os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
Router: chi.NewMux(),
|
||||
|
||||
DB: db,
|
||||
Auth: auth.New(),
|
||||
}
|
||||
|
||||
if os.Getenv("DEBUG") == "true" {
|
||||
s.Router.Use(middleware.Logger)
|
||||
}
|
||||
s.Router.Use(middleware.Recoverer)
|
||||
// enable authentication for all routes (but don't require it)
|
||||
s.Router.Use(s.maybeAuth)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const (
|
||||
ctxKeyClaims ctxKey = 1
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue