package sql

import (
	"context"

	"emperror.dev/errors"
	"git.sleepycat.moe/sam/mercury/internal/database"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/keegancsmith/sqlf"
	"github.com/oklog/ulid/v2"
)

// BlogStore is the interface to blogs in the database.
type BlogStore struct {
	q Querier
}

// NewBlogStore creates a new BlogStore instance.
func NewBlogStore(q Querier) *BlogStore {
	return &BlogStore{q: q}
}

// ByID gets a blog by its ID.
// It returns ErrNotFound if no blog has the given ID.
func (s *BlogStore) ByID(ctx context.Context, id ulid.ULID) (b database.Blog, err error) {
	q := sqlf.Sprintf("SELECT * FROM blogs WHERE id = %s", id)
	return s.getOne(ctx, q)
}

// ByName gets a blog by name. If domain is empty, it gets a local blog.
// It returns ErrNotFound if no matching blog exists.
func (s *BlogStore) ByName(ctx context.Context, name, domain string) (b database.Blog, err error) {
	q := sqlf.Sprintf("SELECT * FROM blogs WHERE name = %s", name)
	if domain != "" {
		q = sqlf.Sprintf("%v AND domain = %s", q, domain)
	} else {
		// Local blogs have a NULL domain; "= NULL" never matches in SQL, so use IS NULL.
		q = sqlf.Sprintf("%v AND domain IS NULL", q)
	}
	return s.getOne(ctx, q)
}

// getOne runs a single-row blog query, mapping "no rows" to ErrNotFound and
// wrapping any other error.
func (s *BlogStore) getOne(ctx context.Context, q *sqlf.Query) (database.Blog, error) {
	b, err := Get[database.Blog](ctx, s.q, q)
	if err != nil {
		// Use errors.Is rather than comparing errors.Cause(err) with ==:
		// since pgx v5.6, pgx.ErrNoRows itself wraps sql.ErrNoRows, so the
		// innermost cause may not be pgx.ErrNoRows even when no rows matched.
		if errors.Is(err, pgx.ErrNoRows) {
			return b, ErrNotFound
		}
		return b, errors.Wrap(err, "getting blog")
	}
	return b, nil
}

// Create creates a blog.
func (s *BlogStore) Create(ctx context.Context, accountID ulid.ULID, name string) (b database.Blog, err error) { q := sqlf.Sprintf( "INSERT INTO blogs (id, name, domain, bio, account_id) VALUES (%s, %s, NULL, '', %s) RETURNING *", makeULID(), name, accountID) b, err = Get[database.Blog](ctx, s.q, q) if err != nil { pge := &pgconn.PgError{} if errors.As(err, &pge) { // unique constraint violation if pge.Code == "23505" { return b, ErrUsernameTaken } } return b, errors.Wrap(err, "creating blog") } return b, nil } func (s *BlogStore) LocalCount(ctx context.Context) (count int64, err error) { q := sqlf.Sprintf("SELECT count(*) FROM blogs WHERE domain IS NULL") err = s.q.QueryRow(ctx, q.Query(sqlf.PostgresBindVar)).Scan(&count) if err != nil { return count, errors.Wrap(err, "getting count") } return count, nil }