mercury/internal/database/sql/account.go

98 lines
2.5 KiB
Go

package sql
import (
"context"
"emperror.dev/errors"
"git.sleepycat.moe/sam/mercury/internal/database"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/keegancsmith/sqlf"
"github.com/oklog/ulid/v2"
"golang.org/x/crypto/bcrypt"
)
// AccountStore groups the database queries that deal with accounts (the
// accounts table and related tables such as account_follows). Every query is
// executed through the Querier held in q.
type AccountStore struct {
	q Querier
}

// NewAccountStore returns an AccountStore whose queries run against q.
func NewAccountStore(q Querier) *AccountStore {
	store := new(AccountStore)
	store.q = q
	return store
}
// ByID returns the account with the given ID.
//
// It returns ErrNotFound when no account has that ID; any other database
// error is wrapped with context and returned. On error the returned Account
// is the zero value.
func (s *AccountStore) ByID(ctx context.Context, id ulid.ULID) (a database.Account, err error) {
	q := sqlf.Sprintf("SELECT * FROM accounts WHERE id = %s", id)

	a, err = Get[database.Account](ctx, s.q, q)
	if err != nil {
		// Use errors.Is rather than comparing the root cause: newer pgx
		// releases make pgx.ErrNoRows itself wrap sql.ErrNoRows, so the root
		// of the chain is no longer pgx.ErrNoRows and an == check on
		// errors.Cause would miss it.
		if errors.Is(err, pgx.ErrNoRows) {
			return database.Account{}, ErrNotFound
		}
		return database.Account{}, errors.Wrap(err, "getting account")
	}
	return a, nil
}
// ByUsername returns the account with the given username on the given domain.
//
// An empty domain selects a local account, i.e. a row whose domain column is
// NULL; a non-empty domain must equal the stored domain. It returns
// ErrNotFound when no matching account exists; any other database error is
// wrapped with context and returned. On error the returned Account is the
// zero value.
func (s *AccountStore) ByUsername(ctx context.Context, username, domain string) (a database.Account, err error) {
	q := sqlf.Sprintf("SELECT * FROM accounts WHERE username = %s", username)
	if domain == "" {
		// Local accounts are inserted with a NULL domain, and NULL never
		// matches "=", so an explicit IS NULL is required.
		q = sqlf.Sprintf("%v AND domain IS NULL", q)
	} else {
		q = sqlf.Sprintf("%v AND domain = %s", q, domain)
	}

	a, err = Get[database.Account](ctx, s.q, q)
	if err != nil {
		// errors.Is checks every link of the chain; comparing errors.Cause
		// with == misses pgx.ErrNoRows in pgx releases where it wraps
		// sql.ErrNoRows.
		if errors.Is(err, pgx.ErrNoRows) {
			return database.Account{}, ErrNotFound
		}
		return database.Account{}, errors.Wrap(err, "getting account")
	}
	return a, nil
}
// PasswordCost is the bcrypt cost passed to bcrypt.GenerateFromPassword when
// hashing the passwords of new local accounts.
const PasswordCost = 15

// CreateLocal inserts a new local account (one whose domain is NULL) with the
// given username, email and password, and returns the inserted row.
//
// The password is hashed with bcrypt at PasswordCost; only the hash is sent to
// the database. If the insert violates a unique constraint, ErrUsernameTaken is
// returned. Hashing errors and other database errors are wrapped with context.
// On error the returned Account is the zero value.
func (s *AccountStore) CreateLocal(
	ctx context.Context,
	username string,
	email string,
	password []byte,
) (a database.Account, err error) {
	// SQLSTATE code for unique_violation.
	const uniqueViolation = "23505"

	hash, err := bcrypt.GenerateFromPassword(password, PasswordCost)
	if err != nil {
		return database.Account{}, errors.Wrap(err, "hashing password")
	}

	q := sqlf.Sprintf(
		"INSERT INTO accounts (id, username, domain, email, password) VALUES (%s, %s, NULL, %s, %s) RETURNING *",
		makeULID(), username, email, hash)

	a, err = Get[database.Account](ctx, s.q, q)
	if err != nil {
		// NOTE(review): every unique violation on this insert is reported as
		// ErrUsernameTaken. If accounts also has a unique constraint on email
		// (not visible here), a duplicate email would be misreported; consider
		// inspecting pge.ConstraintName to tell them apart.
		var pge *pgconn.PgError
		if errors.As(err, &pge) && pge.Code == uniqueViolation {
			return database.Account{}, ErrUsernameTaken
		}
		return database.Account{}, errors.Wrap(err, "creating account")
	}
	return a, nil
}
// Follows returns the blog IDs recorded in account_follows for the account
// with the given ID. The query has no ORDER BY, so the order of the result is
// whatever the database returns. Database errors are wrapped with context.
func (s *AccountStore) Follows(ctx context.Context, id ulid.ULID) ([]ulid.ULID, error) {
	query := sqlf.Sprintf("SELECT blog_id FROM account_follows WHERE account_id = %s", id)

	blogIDs, err := Select[ulid.ULID](ctx, s.q, query)
	if err == nil {
		return blogIDs, nil
	}
	return nil, errors.Wrap(err, "executing query")
}