package sql import ( "context" "emperror.dev/errors" "git.sleepycat.moe/sam/mercury/internal/database" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/keegancsmith/sqlf" "github.com/oklog/ulid/v2" "golang.org/x/crypto/bcrypt" ) // AccountStore is the interface to accounts in the database. type AccountStore struct { q Querier } // NewAccountStore creates a new AccountStore instance. func NewAccountStore(q Querier) *AccountStore { return &AccountStore{q: q} } // ByID gets an account by its ID. func (s *AccountStore) ByID(ctx context.Context, id ulid.ULID) (a database.Account, err error) { q := sqlf.Sprintf("SELECT * FROM accounts WHERE id = %s", id) a, err = Get[database.Account](ctx, s.q, q) if err != nil { if errors.Cause(err) == pgx.ErrNoRows { return a, ErrNotFound } return a, errors.Wrap(err, "getting account") } return a, nil } // ByUsername gets an account by its username. func (s *AccountStore) ByUsername(ctx context.Context, username string, host *string) (a database.Account, err error) { q := sqlf.Sprintf("SELECT * FROM accounts WHERE username = %s AND host = %v", username, host) a, err = Get[database.Account](ctx, s.q, q) if err != nil { if errors.Cause(err) == pgx.ErrNoRows { return a, ErrNotFound } return a, errors.Wrap(err, "getting account") } return a, nil } const PasswordCost = 15 func (s *AccountStore) CreateLocal( ctx context.Context, username string, email string, password []byte, ) (a database.Account, err error) { hash, err := bcrypt.GenerateFromPassword(password, PasswordCost) if err != nil { return a, errors.Wrap(err, "hashing password") } q := sqlf.Sprintf( "INSERT INTO accounts (id, username, host, email, password) VALUES (%s, %v, NULL, %v, %v) RETURNING *", makeULID(), username, email, hash) a, err = Get[database.Account](ctx, s.q, q) if err != nil { pge := &pgconn.PgError{} if errors.As(err, &pge) { // unique constraint violation if pge.Code == "23505" { return a, ErrUsernameTaken } } return a, errors.Wrap(err, "creating account") } return a, nil }