init
This commit is contained in:
		
						commit
						2586161abd
					
				
					 49 changed files with 4171 additions and 0 deletions
				
			
		
							
								
								
									
										27
									
								
								internal/database/account.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								internal/database/account.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,27 @@
 | 
			
		|||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
	"github.com/oklog/ulid/v2"
 | 
			
		||||
	"golang.org/x/crypto/bcrypt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Account is a single account. It can own many Blogs.
 | 
			
		||||
type Account struct {
 | 
			
		||||
	ID       ulid.ULID
 | 
			
		||||
	Username string
 | 
			
		||||
	Domain   *string
 | 
			
		||||
	Email    *string
 | 
			
		||||
	Password *[]byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const ErrNoPassword = errors.Sentinel("user is a remote user and has no password")
 | 
			
		||||
 | 
			
		||||
// PasswordValid returns true if the given password is valid.
 | 
			
		||||
func (a Account) PasswordValid(input string) (bool, error) {
 | 
			
		||||
	if a.Password == nil {
 | 
			
		||||
		return false, ErrNoPassword
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return bcrypt.CompareHashAndPassword(*a.Password, []byte(input)) == nil, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										16
									
								
								internal/database/blog.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								internal/database/blog.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,16 @@
 | 
			
		|||
package database
 | 
			
		||||
 | 
			
		||||
import "github.com/oklog/ulid/v2"
 | 
			
		||||
 | 
			
		||||
// Blog is a blog, or Person in ActivityPub terms.
 | 
			
		||||
type Blog struct {
 | 
			
		||||
	ID     ulid.ULID
 | 
			
		||||
	Name   string
 | 
			
		||||
	Domain *string
 | 
			
		||||
 | 
			
		||||
	Bio string
 | 
			
		||||
 | 
			
		||||
	AccountID ulid.ULID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b Blog) Remote() bool { return b.Domain != nil }
 | 
			
		||||
							
								
								
									
										13
									
								
								internal/database/config.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								internal/database/config.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,13 @@
 | 
			
		|||
package database
 | 
			
		||||
 | 
			
		||||
import "github.com/oklog/ulid/v2"
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	Name    string
 | 
			
		||||
	AdminID *ulid.ULID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var DefaultConfig = Config{
 | 
			
		||||
	Name:    "Mercury",
 | 
			
		||||
	AdminID: nil,
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,44 @@
 | 
			
		|||
-- 2023-04-04: Create initial tables
 | 
			
		||||
 | 
			
		||||
-- +migrate Up
 | 
			
		||||
 | 
			
		||||
-- User accounts
 | 
			
		||||
-- Many fields in this table are only relevant for local users, but all blogs are associated with an account
 | 
			
		||||
-- to make the database schema simpler.
 | 
			
		||||
create table accounts (
 | 
			
		||||
    id       text primary key,
 | 
			
		||||
    username text not null,
 | 
			
		||||
    domain   text, -- if null, is a local account
 | 
			
		||||
    email    text,
 | 
			
		||||
    password bytea,
 | 
			
		||||
 | 
			
		||||
    unique (username, domain)
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
-- Blogs
 | 
			
		||||
-- These are the AP actors.
 | 
			
		||||
create table blogs (
 | 
			
		||||
    id     text primary key,
 | 
			
		||||
    name   text not null,
 | 
			
		||||
    domain text, -- if null, is a local blog
 | 
			
		||||
    bio    text not null,
 | 
			
		||||
 | 
			
		||||
    account_id text not null references accounts (id) on delete cascade,
 | 
			
		||||
 | 
			
		||||
    unique (name, domain)
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
create type post_visibility as enum ('public', 'unlisted', 'followers', 'direct');
 | 
			
		||||
 | 
			
		||||
create table posts (
 | 
			
		||||
    id         text primary key,
 | 
			
		||||
    blog_id    text not null references blogs (id) on delete cascade,
 | 
			
		||||
    content    text,
 | 
			
		||||
    source     text,
 | 
			
		||||
    visibility post_visibility not null
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
-- +migrate Down
 | 
			
		||||
drop table accounts;
 | 
			
		||||
drop table blogs;
 | 
			
		||||
drop table posts;
 | 
			
		||||
							
								
								
									
										6
									
								
								internal/database/migrations/migrations.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								internal/database/migrations/migrations.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,6 @@
 | 
			
		|||
package migrations
 | 
			
		||||
 | 
			
		||||
import "embed"
 | 
			
		||||
 | 
			
		||||
//go:embed *.sql
 | 
			
		||||
var FS embed.FS
 | 
			
		||||
							
								
								
									
										11
									
								
								internal/database/post.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								internal/database/post.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,11 @@
 | 
			
		|||
package database
 | 
			
		||||
 | 
			
		||||
import "github.com/oklog/ulid/v2"
 | 
			
		||||
 | 
			
		||||
// Post is a single post.
 | 
			
		||||
type Post struct {
 | 
			
		||||
	ID      ulid.ULID
 | 
			
		||||
	BlogID  ulid.ULID
 | 
			
		||||
	Content *string
 | 
			
		||||
	Source  *string
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										83
									
								
								internal/database/sql/account.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								internal/database/sql/account.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,83 @@
 | 
			
		|||
package sql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
	"git.sleepycat.moe/sam/mercury/internal/database"
 | 
			
		||||
	"github.com/jackc/pgx/v5"
 | 
			
		||||
	"github.com/jackc/pgx/v5/pgconn"
 | 
			
		||||
	"github.com/keegancsmith/sqlf"
 | 
			
		||||
	"github.com/oklog/ulid/v2"
 | 
			
		||||
	"golang.org/x/crypto/bcrypt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// AccountStore is the interface to accounts in the database.
 | 
			
		||||
type AccountStore struct {
 | 
			
		||||
	q Querier
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewAccountStore creates a new AccountStore instance.
 | 
			
		||||
func NewAccountStore(q Querier) *AccountStore {
 | 
			
		||||
	return &AccountStore{q: q}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ByID gets an account by its ID.
 | 
			
		||||
func (s *AccountStore) ByID(ctx context.Context, id ulid.ULID) (a database.Account, err error) {
 | 
			
		||||
	q := sqlf.Sprintf("SELECT * FROM accounts WHERE id = %s", id)
 | 
			
		||||
 | 
			
		||||
	a, err = Get[database.Account](ctx, s.q, q)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if errors.Cause(err) == pgx.ErrNoRows {
 | 
			
		||||
			return a, ErrNotFound
 | 
			
		||||
		}
 | 
			
		||||
		return a, errors.Wrap(err, "getting account")
 | 
			
		||||
	}
 | 
			
		||||
	return a, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ByUsername gets an account by its username.
 | 
			
		||||
func (s *AccountStore) ByUsername(ctx context.Context, username string, host *string) (a database.Account, err error) {
 | 
			
		||||
	q := sqlf.Sprintf("SELECT * FROM accounts WHERE username = %s AND host = %v", username, host)
 | 
			
		||||
 | 
			
		||||
	a, err = Get[database.Account](ctx, s.q, q)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if errors.Cause(err) == pgx.ErrNoRows {
 | 
			
		||||
			return a, ErrNotFound
 | 
			
		||||
		}
 | 
			
		||||
		return a, errors.Wrap(err, "getting account")
 | 
			
		||||
	}
 | 
			
		||||
	return a, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const PasswordCost = 15
 | 
			
		||||
 | 
			
		||||
func (s *AccountStore) CreateLocal(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	username string,
 | 
			
		||||
	email string,
 | 
			
		||||
	password []byte,
 | 
			
		||||
) (a database.Account, err error) {
 | 
			
		||||
	hash, err := bcrypt.GenerateFromPassword(password, PasswordCost)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return a, errors.Wrap(err, "hashing password")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	q := sqlf.Sprintf(
 | 
			
		||||
		"INSERT INTO accounts (id, username, host, email, password) VALUES (%s, %v, NULL, %v, %v) RETURNING *",
 | 
			
		||||
		makeULID(), username, email, hash)
 | 
			
		||||
 | 
			
		||||
	a, err = Get[database.Account](ctx, s.q, q)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		pge := &pgconn.PgError{}
 | 
			
		||||
		if errors.As(err, &pge) {
 | 
			
		||||
			// unique constraint violation
 | 
			
		||||
			if pge.Code == "23505" {
 | 
			
		||||
				return a, ErrUsernameTaken
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return a, errors.Wrap(err, "creating account")
 | 
			
		||||
	}
 | 
			
		||||
	return a, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										86
									
								
								internal/database/sql/blog.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								internal/database/sql/blog.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,86 @@
 | 
			
		|||
package sql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
	"git.sleepycat.moe/sam/mercury/internal/database"
 | 
			
		||||
	"github.com/jackc/pgx/v5"
 | 
			
		||||
	"github.com/jackc/pgx/v5/pgconn"
 | 
			
		||||
	"github.com/keegancsmith/sqlf"
 | 
			
		||||
	"github.com/oklog/ulid/v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// BlogStore is the interface to blogs in the database.
 | 
			
		||||
type BlogStore struct {
 | 
			
		||||
	q Querier
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewBlogStore creates a new BlogStore instance.
 | 
			
		||||
func NewBlogStore(q Querier) *BlogStore {
 | 
			
		||||
	return &BlogStore{q: q}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ByID gets a blog by its ID.
 | 
			
		||||
func (s *BlogStore) ByID(ctx context.Context, id ulid.ULID) (b database.Blog, err error) {
 | 
			
		||||
	q := sqlf.Sprintf("SELECT * FROM blogs WHERE id = %s", id)
 | 
			
		||||
 | 
			
		||||
	b, err = Get[database.Blog](ctx, s.q, q)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if errors.Cause(err) == pgx.ErrNoRows {
 | 
			
		||||
			return b, ErrNotFound
 | 
			
		||||
		}
 | 
			
		||||
		return b, errors.Wrap(err, "getting blog")
 | 
			
		||||
	}
 | 
			
		||||
	return b, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ByName gets a blog by name. If domain is empty, it gets a local blog.
 | 
			
		||||
func (s *BlogStore) ByName(ctx context.Context, name, domain string) (b database.Blog, err error) {
 | 
			
		||||
	q := sqlf.Sprintf("SELECT * FROM blogs WHERE name = %s", name)
 | 
			
		||||
	if domain != "" {
 | 
			
		||||
		q = sqlf.Sprintf("%v AND domain = %s", q, domain)
 | 
			
		||||
	} else {
 | 
			
		||||
		q = sqlf.Sprintf("%v AND domain IS NULL", q)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b, err = Get[database.Blog](ctx, s.q, q)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if errors.Cause(err) == pgx.ErrNoRows {
 | 
			
		||||
			return b, ErrNotFound
 | 
			
		||||
		}
 | 
			
		||||
		return b, errors.Wrap(err, "getting blog")
 | 
			
		||||
	}
 | 
			
		||||
	return b, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Create creates a blog.
 | 
			
		||||
func (s *BlogStore) Create(ctx context.Context, accountID ulid.ULID, name string) (b database.Blog, err error) {
 | 
			
		||||
	q := sqlf.Sprintf(
 | 
			
		||||
		"INSERT INTO blogs (id, name, domain, bio, account_id) VALUES (%s, %s, NULL, '', %s) RETURNING *",
 | 
			
		||||
		makeULID(), name, accountID)
 | 
			
		||||
 | 
			
		||||
	b, err = Get[database.Blog](ctx, s.q, q)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		pge := &pgconn.PgError{}
 | 
			
		||||
		if errors.As(err, &pge) {
 | 
			
		||||
			// unique constraint violation
 | 
			
		||||
			if pge.Code == "23505" {
 | 
			
		||||
				return b, ErrUsernameTaken
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return b, errors.Wrap(err, "creating blog")
 | 
			
		||||
	}
 | 
			
		||||
	return b, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *BlogStore) LocalCount(ctx context.Context) (count int64, err error) {
 | 
			
		||||
	q := sqlf.Sprintf("SELECT count(*) FROM blogs WHERE domain IS NULL")
 | 
			
		||||
 | 
			
		||||
	err = s.q.QueryRow(ctx, q.Query(sqlf.PostgresBindVar)).Scan(&count)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return count, errors.Wrap(err, "getting count")
 | 
			
		||||
	}
 | 
			
		||||
	return count, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										55
									
								
								internal/database/sql/database.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								internal/database/sql/database.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,55 @@
 | 
			
		|||
// Package sql implements the SQL storage layer.
 | 
			
		||||
package sql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
	"github.com/jackc/pgx/v5/pgxpool"
 | 
			
		||||
	"github.com/oklog/ulid/v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Base is the base database pool used by storage layers.
 | 
			
		||||
type Base struct {
 | 
			
		||||
	pool *pgxpool.Pool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewBase creates a new instance of Base with the specified connection string.
 | 
			
		||||
func NewBase(ctx context.Context, connString string) (*Base, error) {
 | 
			
		||||
	pool, err := pgxpool.New(ctx, connString)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "creating pool")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	base := &Base{
 | 
			
		||||
		pool: pool,
 | 
			
		||||
	}
 | 
			
		||||
	return base, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Acquire acquires a connection from the database pool.
 | 
			
		||||
// It is the caller's responsibility to call the Release method.
 | 
			
		||||
func (base *Base) Acquire(ctx context.Context) (ReleaseableQuerier, error) {
 | 
			
		||||
	conn, err := base.pool.Acquire(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "acquiring connection")
 | 
			
		||||
	}
 | 
			
		||||
	return conn, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (base *Base) BeginTx(ctx context.Context) (Tx, error) {
 | 
			
		||||
	tx, err := base.pool.Begin(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "beginning transaction")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return tx, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (base *Base) PoolQuerier() Querier {
 | 
			
		||||
	return base.pool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func makeULID() ulid.ULID {
 | 
			
		||||
	return ulid.Make()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										10
									
								
								internal/database/sql/errors.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								internal/database/sql/errors.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,10 @@
 | 
			
		|||
package sql
 | 
			
		||||
 | 
			
		||||
import "emperror.dev/errors"
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ErrRemoteBlog    = errors.Sentinel("blog is remote")
 | 
			
		||||
	ErrUsernameTaken = errors.Sentinel("username or blog name is taken")
 | 
			
		||||
 | 
			
		||||
	ErrNotFound = errors.Sentinel("not found")
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										36
									
								
								internal/database/sql/migrate.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								internal/database/sql/migrate.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,36 @@
 | 
			
		|||
package sql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
	"git.sleepycat.moe/sam/mercury/internal/database/migrations"
 | 
			
		||||
	_ "github.com/jackc/pgx/v5/stdlib"
 | 
			
		||||
	"github.com/rs/zerolog/log"
 | 
			
		||||
	migrate "github.com/rubenv/sql-migrate"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Migrate(dsn string) error {
 | 
			
		||||
	db, err := sql.Open("pgx", dsn)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.Wrap(err, "opening database")
 | 
			
		||||
	}
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	if err := db.Ping(); err != nil {
 | 
			
		||||
		return errors.Wrap(err, "pinging database")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	migrations := &migrate.EmbedFileSystemMigrationSource{
 | 
			
		||||
		FileSystem: migrations.FS,
 | 
			
		||||
		Root:       ".",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	n, err := migrate.Exec(db, "postgres", migrations, migrate.Up)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.Wrap(err, "executing migrations")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Info().Int("count", n).Msg("Performed migrations!")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										57
									
								
								internal/database/sql/post.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								internal/database/sql/post.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,57 @@
 | 
			
		|||
package sql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
	"git.sleepycat.moe/sam/mercury/internal/database"
 | 
			
		||||
	"github.com/jackc/pgx/v5"
 | 
			
		||||
	"github.com/keegancsmith/sqlf"
 | 
			
		||||
	"github.com/oklog/ulid/v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// PostStore is the interface to posts in the database.
 | 
			
		||||
type PostStore struct {
 | 
			
		||||
	q Querier
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewPostStore creates a new PostStore instance.
 | 
			
		||||
func NewPostStore(q Querier) *PostStore {
 | 
			
		||||
	return &PostStore{q: q}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ByID gets a post by its ID.
 | 
			
		||||
func (s *PostStore) ByID(ctx context.Context, id ulid.ULID) (p database.Post, err error) {
 | 
			
		||||
	q := sqlf.Sprintf("SELECT * FROM posts WHERE id = %s", id)
 | 
			
		||||
 | 
			
		||||
	p, err = Get[database.Post](ctx, s.q, q)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if errors.Cause(err) == pgx.ErrNoRows {
 | 
			
		||||
			return p, ErrNotFound
 | 
			
		||||
		}
 | 
			
		||||
		return p, errors.Wrap(err, "getting post")
 | 
			
		||||
	}
 | 
			
		||||
	return p, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *PostStore) Create(ctx context.Context, blog database.Blog, content string) (p database.Post, err error) {
 | 
			
		||||
	q := sqlf.Sprintf(
 | 
			
		||||
		"INSERT INTO posts (id, blog_id, content, remote) VALUES (%s, %s, %s, %v) RETURNING *",
 | 
			
		||||
		makeULID(), blog.ID, content, blog.Remote())
 | 
			
		||||
 | 
			
		||||
	p, err = Get[database.Post](ctx, s.q, q)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return p, errors.Wrap(err, "creating post")
 | 
			
		||||
	}
 | 
			
		||||
	return p, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *PostStore) LocalCount(ctx context.Context) (count int64, err error) {
 | 
			
		||||
	q := sqlf.Sprintf("SELECT count(*) FROM posts WHERE remote = false")
 | 
			
		||||
 | 
			
		||||
	err = s.q.QueryRow(ctx, q.Query(sqlf.PostgresBindVar)).Scan(&count)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return count, errors.Wrap(err, "getting count")
 | 
			
		||||
	}
 | 
			
		||||
	return count, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										50
									
								
								internal/database/sql/queries.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								internal/database/sql/queries.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,50 @@
 | 
			
		|||
package sql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
	"github.com/georgysavva/scany/v2/pgxscan"
 | 
			
		||||
	"github.com/jackc/pgx/v5"
 | 
			
		||||
	"github.com/jackc/pgx/v5/pgconn"
 | 
			
		||||
	"github.com/keegancsmith/sqlf"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Querier interface {
 | 
			
		||||
	Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
 | 
			
		||||
	QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
 | 
			
		||||
	Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ReleaseableQuerier interface {
 | 
			
		||||
	Querier
 | 
			
		||||
 | 
			
		||||
	Release()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Tx interface {
 | 
			
		||||
	Querier
 | 
			
		||||
 | 
			
		||||
	Commit(ctx context.Context) error
 | 
			
		||||
	Rollback(ctx context.Context) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Select[T any](ctx context.Context, querier Querier, query *sqlf.Query) ([]T, error) {
 | 
			
		||||
	dst := make([]T, 0)
 | 
			
		||||
 | 
			
		||||
	err := pgxscan.Select(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return []T{}, errors.Wrap(err, "executing query")
 | 
			
		||||
	}
 | 
			
		||||
	return dst, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Get[T any](ctx context.Context, querier Querier, query *sqlf.Query) (T, error) {
 | 
			
		||||
	var dst T
 | 
			
		||||
 | 
			
		||||
	err := pgxscan.Get(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return dst, errors.Wrap(err, "executing query")
 | 
			
		||||
	}
 | 
			
		||||
	return dst, nil
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue