This commit is contained in:
sam 2023-09-03 00:23:48 +02:00
commit 2586161abd
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
49 changed files with 4171 additions and 0 deletions

View file

@ -0,0 +1,83 @@
package sql
import (
"context"
"emperror.dev/errors"
"git.sleepycat.moe/sam/mercury/internal/database"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/keegancsmith/sqlf"
"github.com/oklog/ulid/v2"
"golang.org/x/crypto/bcrypt"
)
// AccountStore is the interface to accounts in the database.
type AccountStore struct {
q Querier
}
// NewAccountStore creates a new AccountStore instance.
func NewAccountStore(q Querier) *AccountStore {
return &AccountStore{q: q}
}
// ByID gets an account by its ID.
func (s *AccountStore) ByID(ctx context.Context, id ulid.ULID) (a database.Account, err error) {
q := sqlf.Sprintf("SELECT * FROM accounts WHERE id = %s", id)
a, err = Get[database.Account](ctx, s.q, q)
if err != nil {
if errors.Cause(err) == pgx.ErrNoRows {
return a, ErrNotFound
}
return a, errors.Wrap(err, "getting account")
}
return a, nil
}
// ByUsername gets an account by its username.
func (s *AccountStore) ByUsername(ctx context.Context, username string, host *string) (a database.Account, err error) {
q := sqlf.Sprintf("SELECT * FROM accounts WHERE username = %s AND host = %v", username, host)
a, err = Get[database.Account](ctx, s.q, q)
if err != nil {
if errors.Cause(err) == pgx.ErrNoRows {
return a, ErrNotFound
}
return a, errors.Wrap(err, "getting account")
}
return a, nil
}
const PasswordCost = 15
func (s *AccountStore) CreateLocal(
ctx context.Context,
username string,
email string,
password []byte,
) (a database.Account, err error) {
hash, err := bcrypt.GenerateFromPassword(password, PasswordCost)
if err != nil {
return a, errors.Wrap(err, "hashing password")
}
q := sqlf.Sprintf(
"INSERT INTO accounts (id, username, host, email, password) VALUES (%s, %v, NULL, %v, %v) RETURNING *",
makeULID(), username, email, hash)
a, err = Get[database.Account](ctx, s.q, q)
if err != nil {
pge := &pgconn.PgError{}
if errors.As(err, &pge) {
// unique constraint violation
if pge.Code == "23505" {
return a, ErrUsernameTaken
}
}
return a, errors.Wrap(err, "creating account")
}
return a, nil
}

View file

@ -0,0 +1,86 @@
package sql
import (
"context"
"emperror.dev/errors"
"git.sleepycat.moe/sam/mercury/internal/database"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/keegancsmith/sqlf"
"github.com/oklog/ulid/v2"
)
// BlogStore is the interface to blogs in the database.
type BlogStore struct {
q Querier
}
// NewBlogStore creates a new BlogStore instance.
func NewBlogStore(q Querier) *BlogStore {
return &BlogStore{q: q}
}
// ByID gets a blog by its ID.
func (s *BlogStore) ByID(ctx context.Context, id ulid.ULID) (b database.Blog, err error) {
q := sqlf.Sprintf("SELECT * FROM blogs WHERE id = %s", id)
b, err = Get[database.Blog](ctx, s.q, q)
if err != nil {
if errors.Cause(err) == pgx.ErrNoRows {
return b, ErrNotFound
}
return b, errors.Wrap(err, "getting blog")
}
return b, nil
}
// ByName gets a blog by name. If domain is empty, it gets a local blog.
func (s *BlogStore) ByName(ctx context.Context, name, domain string) (b database.Blog, err error) {
q := sqlf.Sprintf("SELECT * FROM blogs WHERE name = %s", name)
if domain != "" {
q = sqlf.Sprintf("%v AND domain = %s", q, domain)
} else {
q = sqlf.Sprintf("%v AND domain IS NULL", q)
}
b, err = Get[database.Blog](ctx, s.q, q)
if err != nil {
if errors.Cause(err) == pgx.ErrNoRows {
return b, ErrNotFound
}
return b, errors.Wrap(err, "getting blog")
}
return b, nil
}
// Create creates a blog.
func (s *BlogStore) Create(ctx context.Context, accountID ulid.ULID, name string) (b database.Blog, err error) {
q := sqlf.Sprintf(
"INSERT INTO blogs (id, name, domain, bio, account_id) VALUES (%s, %s, NULL, '', %s) RETURNING *",
makeULID(), name, accountID)
b, err = Get[database.Blog](ctx, s.q, q)
if err != nil {
pge := &pgconn.PgError{}
if errors.As(err, &pge) {
// unique constraint violation
if pge.Code == "23505" {
return b, ErrUsernameTaken
}
}
return b, errors.Wrap(err, "creating blog")
}
return b, nil
}
func (s *BlogStore) LocalCount(ctx context.Context) (count int64, err error) {
q := sqlf.Sprintf("SELECT count(*) FROM blogs WHERE domain IS NULL")
err = s.q.QueryRow(ctx, q.Query(sqlf.PostgresBindVar)).Scan(&count)
if err != nil {
return count, errors.Wrap(err, "getting count")
}
return count, nil
}

View file

@ -0,0 +1,55 @@
// Package sql implements the SQL storage layer.
package sql
import (
"context"
"emperror.dev/errors"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/oklog/ulid/v2"
)
// Base is the base database pool used by storage layers.
type Base struct {
pool *pgxpool.Pool
}
// NewBase creates a new instance of Base with the specified connection string.
func NewBase(ctx context.Context, connString string) (*Base, error) {
pool, err := pgxpool.New(ctx, connString)
if err != nil {
return nil, errors.Wrap(err, "creating pool")
}
base := &Base{
pool: pool,
}
return base, nil
}
// Acquire acquires a connection from the database pool.
// It is the caller's responsibility to call the Release method.
func (base *Base) Acquire(ctx context.Context) (ReleaseableQuerier, error) {
conn, err := base.pool.Acquire(ctx)
if err != nil {
return nil, errors.Wrap(err, "acquiring connection")
}
return conn, nil
}
func (base *Base) BeginTx(ctx context.Context) (Tx, error) {
tx, err := base.pool.Begin(ctx)
if err != nil {
return nil, errors.Wrap(err, "beginning transaction")
}
return tx, nil
}
func (base *Base) PoolQuerier() Querier {
return base.pool
}
func makeULID() ulid.ULID {
return ulid.Make()
}

View file

@ -0,0 +1,10 @@
package sql
import "emperror.dev/errors"
const (
ErrRemoteBlog = errors.Sentinel("blog is remote")
ErrUsernameTaken = errors.Sentinel("username or blog name is taken")
ErrNotFound = errors.Sentinel("not found")
)

View file

@ -0,0 +1,36 @@
package sql
import (
"database/sql"
"emperror.dev/errors"
"git.sleepycat.moe/sam/mercury/internal/database/migrations"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/rs/zerolog/log"
migrate "github.com/rubenv/sql-migrate"
)
func Migrate(dsn string) error {
db, err := sql.Open("pgx", dsn)
if err != nil {
return errors.Wrap(err, "opening database")
}
defer db.Close()
if err := db.Ping(); err != nil {
return errors.Wrap(err, "pinging database")
}
migrations := &migrate.EmbedFileSystemMigrationSource{
FileSystem: migrations.FS,
Root: ".",
}
n, err := migrate.Exec(db, "postgres", migrations, migrate.Up)
if err != nil {
return errors.Wrap(err, "executing migrations")
}
log.Info().Int("count", n).Msg("Performed migrations!")
return nil
}

View file

@ -0,0 +1,57 @@
package sql
import (
"context"
"emperror.dev/errors"
"git.sleepycat.moe/sam/mercury/internal/database"
"github.com/jackc/pgx/v5"
"github.com/keegancsmith/sqlf"
"github.com/oklog/ulid/v2"
)
// PostStore is the interface to posts in the database.
type PostStore struct {
q Querier
}
// NewPostStore creates a new PostStore instance.
func NewPostStore(q Querier) *PostStore {
return &PostStore{q: q}
}
// ByID gets a post by its ID.
func (s *PostStore) ByID(ctx context.Context, id ulid.ULID) (p database.Post, err error) {
q := sqlf.Sprintf("SELECT * FROM posts WHERE id = %s", id)
p, err = Get[database.Post](ctx, s.q, q)
if err != nil {
if errors.Cause(err) == pgx.ErrNoRows {
return p, ErrNotFound
}
return p, errors.Wrap(err, "getting post")
}
return p, nil
}
func (s *PostStore) Create(ctx context.Context, blog database.Blog, content string) (p database.Post, err error) {
q := sqlf.Sprintf(
"INSERT INTO posts (id, blog_id, content, remote) VALUES (%s, %s, %s, %v) RETURNING *",
makeULID(), blog.ID, content, blog.Remote())
p, err = Get[database.Post](ctx, s.q, q)
if err != nil {
return p, errors.Wrap(err, "creating post")
}
return p, nil
}
func (s *PostStore) LocalCount(ctx context.Context) (count int64, err error) {
q := sqlf.Sprintf("SELECT count(*) FROM posts WHERE remote = false")
err = s.q.QueryRow(ctx, q.Query(sqlf.PostgresBindVar)).Scan(&count)
if err != nil {
return count, errors.Wrap(err, "getting count")
}
return count, nil
}

View file

@ -0,0 +1,50 @@
package sql
import (
"context"
"emperror.dev/errors"
"github.com/georgysavva/scany/v2/pgxscan"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/keegancsmith/sqlf"
)
type Querier interface {
Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
}
type ReleaseableQuerier interface {
Querier
Release()
}
type Tx interface {
Querier
Commit(ctx context.Context) error
Rollback(ctx context.Context) error
}
func Select[T any](ctx context.Context, querier Querier, query *sqlf.Query) ([]T, error) {
dst := make([]T, 0)
err := pgxscan.Select(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...)
if err != nil {
return []T{}, errors.Wrap(err, "executing query")
}
return dst, nil
}
func Get[T any](ctx context.Context, querier Querier, query *sqlf.Query) (T, error) {
var dst T
err := pgxscan.Get(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...)
if err != nil {
return dst, errors.Wrap(err, "executing query")
}
return dst, nil
}