98 lines
2.5 KiB
Go
98 lines
2.5 KiB
Go
package sql
|
|
|
|
import (
|
|
"context"
|
|
|
|
"emperror.dev/errors"
|
|
"git.sleepycat.moe/sam/mercury/internal/database"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/keegancsmith/sqlf"
|
|
"github.com/oklog/ulid/v2"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// AccountStore is the interface to accounts in the database.
type AccountStore struct {
	// q is the Querier every AccountStore method passes to Get/Select
	// to run its SQL.
	q Querier
}
|
|
|
|
// NewAccountStore creates a new AccountStore instance.
|
|
func NewAccountStore(q Querier) *AccountStore {
|
|
return &AccountStore{q: q}
|
|
}
|
|
|
|
// ByID gets an account by its ID.
//
// It returns ErrNotFound when no row in accounts has the given id; any other
// failure is returned wrapped with "getting account".
func (s *AccountStore) ByID(ctx context.Context, id ulid.ULID) (a database.Account, err error) {
	q := sqlf.Sprintf("SELECT * FROM accounts WHERE id = %s", id)

	a, err = Get[database.Account](ctx, s.q, q)
	if err != nil {
		// Use errors.Is instead of comparing errors.Cause(err): since pgx v5.6,
		// pgx.ErrNoRows itself wraps sql.ErrNoRows, so the root cause of a
		// no-rows error is no longer pgx.ErrNoRows and an equality check on
		// the cause silently stops matching.
		if errors.Is(err, pgx.ErrNoRows) {
			return a, ErrNotFound
		}
		return a, errors.Wrap(err, "getting account")
	}
	return a, nil
}
|
|
|
|
// ByUsername gets an account by its username and domain.
//
// An empty domain matches rows whose domain column is NULL; any other value
// must match the domain column exactly. It returns ErrNotFound when no row
// matches; any other failure is returned wrapped with "getting account".
func (s *AccountStore) ByUsername(ctx context.Context, username, domain string) (a database.Account, err error) {
	q := sqlf.Sprintf("SELECT * FROM accounts WHERE username = %s", username)
	if domain == "" {
		q = sqlf.Sprintf("%v AND domain IS NULL", q)
	} else {
		q = sqlf.Sprintf("%v AND domain = %s", q, domain)
	}

	a, err = Get[database.Account](ctx, s.q, q)
	if err != nil {
		// errors.Is rather than errors.Cause(err) == ...: newer pgx releases
		// (v5.6+) make pgx.ErrNoRows wrap sql.ErrNoRows, so the root cause is
		// no longer pgx.ErrNoRows itself.
		if errors.Is(err, pgx.ErrNoRows) {
			return a, ErrNotFound
		}
		return a, errors.Wrap(err, "getting account")
	}
	return a, nil
}
|
|
|
|
// PasswordCost is the bcrypt cost passed to bcrypt.GenerateFromPassword when
// CreateLocal hashes a new password. bcrypt's work grows as 2^cost, so this is
// roughly 32x the work of bcrypt.DefaultCost (10).
// NOTE(review): confirm the resulting hashing latency is acceptable wherever
// password hashes are computed or verified.
const PasswordCost = 15
|
|
|
|
func (s *AccountStore) CreateLocal(
|
|
ctx context.Context,
|
|
username string,
|
|
email string,
|
|
password []byte,
|
|
) (a database.Account, err error) {
|
|
hash, err := bcrypt.GenerateFromPassword(password, PasswordCost)
|
|
if err != nil {
|
|
return a, errors.Wrap(err, "hashing password")
|
|
}
|
|
|
|
q := sqlf.Sprintf(
|
|
"INSERT INTO accounts (id, username, domain, email, password) VALUES (%s, %v, NULL, %v, %v) RETURNING *",
|
|
makeULID(), username, email, hash)
|
|
|
|
a, err = Get[database.Account](ctx, s.q, q)
|
|
if err != nil {
|
|
pge := &pgconn.PgError{}
|
|
if errors.As(err, &pge) {
|
|
// unique constraint violation
|
|
if pge.Code == "23505" {
|
|
return a, ErrUsernameTaken
|
|
}
|
|
}
|
|
|
|
return a, errors.Wrap(err, "creating account")
|
|
}
|
|
return a, nil
|
|
}
|
|
|
|
// Follows returns the blog_id values of every account_follows row whose
// account_id equals id. Query failures are wrapped with "executing query".
func (s *AccountStore) Follows(ctx context.Context, id ulid.ULID) ([]ulid.ULID, error) {
	query := sqlf.Sprintf("SELECT blog_id FROM account_follows WHERE account_id = %s", id)

	blogIDs, err := Select[ulid.ULID](ctx, s.q, query)
	if err == nil {
		return blogIDs, nil
	}
	return nil, errors.Wrap(err, "executing query")
}
|