add a couple post endpoints + /timelines/home
This commit is contained in:
parent
dd72a1f4c1
commit
9f052dc9ef
24 changed files with 462 additions and 32 deletions
|
@ -86,3 +86,13 @@ func (s *AccountStore) CreateLocal(
|
|||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (s *AccountStore) Follows(ctx context.Context, id ulid.ULID) ([]ulid.ULID, error) {
|
||||
q := sqlf.Sprintf("SELECT blog_id FROM account_follows WHERE account_id = %s", id)
|
||||
|
||||
follows, err := Select[ulid.ULID](ctx, s.q, q)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "executing query")
|
||||
}
|
||||
return follows, nil
|
||||
}
|
||||
|
|
|
@ -84,3 +84,13 @@ func (s *BlogStore) LocalCount(ctx context.Context) (count int64, err error) {
|
|||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *BlogStore) Followers(ctx context.Context, id ulid.ULID) ([]ulid.ULID, error) {
|
||||
q := sqlf.Sprintf("SELECT account_id FROM account_follows WHERE blog_id = %s", id)
|
||||
|
||||
follows, err := Select[ulid.ULID](ctx, s.q, q)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "executing query")
|
||||
}
|
||||
return follows, nil
|
||||
}
|
||||
|
|
|
@ -45,6 +45,9 @@ func (s *ConfigStore) Set(ctx context.Context, cur, new database.Config) (databa
|
|||
if cur.InternalApplication != new.InternalApplication {
|
||||
q = sqlf.Sprintf("%v internal_application = %v,", q, new.InternalApplication)
|
||||
}
|
||||
if cur.PostCharacterLimit != new.PostCharacterLimit {
|
||||
q = sqlf.Sprintf("%v post_character_limit = %v,", q, new.PostCharacterLimit)
|
||||
}
|
||||
q = sqlf.Sprintf("%v id = %v WHERE id = %v RETURNING *", q, cur.ID, cur.ID)
|
||||
|
||||
return Get[database.Config](ctx, s.q, q)
|
||||
|
|
|
@ -34,10 +34,10 @@ func (s *PostStore) ByID(ctx context.Context, id ulid.ULID) (p database.Post, er
|
|||
return p, nil
|
||||
}
|
||||
|
||||
func (s *PostStore) Create(ctx context.Context, blog database.Blog, content string) (p database.Post, err error) {
|
||||
func (s *PostStore) Create(ctx context.Context, blog database.Blog, content string, visibility database.PostVisibility) (p database.Post, err error) {
|
||||
q := sqlf.Sprintf(
|
||||
"INSERT INTO posts (id, blog_id, content, remote) VALUES (%s, %s, %s, %v) RETURNING *",
|
||||
makeULID(), blog.ID, content, blog.Remote())
|
||||
"INSERT INTO posts (id, blog_id, content, source, visibility) VALUES (%s, %s, %s, %s, %s) RETURNING *",
|
||||
makeULID(), blog.ID, content, content, visibility)
|
||||
|
||||
p, err = Get[database.Post](ctx, s.q, q)
|
||||
if err != nil {
|
||||
|
|
|
@ -8,8 +8,11 @@ import (
|
|||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/keegancsmith/sqlf"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var LogQueries = true
|
||||
|
||||
type Querier interface {
|
||||
Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
|
||||
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
||||
|
@ -32,6 +35,10 @@ type Tx interface {
|
|||
func Select[T any](ctx context.Context, querier Querier, query *sqlf.Query) ([]T, error) {
|
||||
dst := make([]T, 0)
|
||||
|
||||
if LogQueries {
|
||||
log.Debug().Str("query", query.Query(sqlf.PostgresBindVar)).Msg("executing select query")
|
||||
}
|
||||
|
||||
err := pgxscan.Select(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...)
|
||||
if err != nil {
|
||||
return []T{}, errors.Wrap(err, "executing query")
|
||||
|
@ -42,6 +49,10 @@ func Select[T any](ctx context.Context, querier Querier, query *sqlf.Query) ([]T
|
|||
func Get[T any](ctx context.Context, querier Querier, query *sqlf.Query) (T, error) {
|
||||
var dst T
|
||||
|
||||
if LogQueries {
|
||||
log.Debug().Str("query", query.Query(sqlf.PostgresBindVar)).Msg("executing get query")
|
||||
}
|
||||
|
||||
err := pgxscan.Get(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...)
|
||||
if err != nil {
|
||||
return dst, errors.Wrap(err, "executing query")
|
||||
|
@ -50,6 +61,10 @@ func Get[T any](ctx context.Context, querier Querier, query *sqlf.Query) (T, err
|
|||
}
|
||||
|
||||
func Exec(ctx context.Context, querier Querier, query *sqlf.Query) error {
|
||||
if LogQueries {
|
||||
log.Debug().Str("query", query.Query(sqlf.PostgresBindVar)).Msg("executing exec query")
|
||||
}
|
||||
|
||||
_, err := querier.Exec(ctx, query.Query(sqlf.PostgresBindVar), query.Args()...)
|
||||
return err
|
||||
}
|
||||
|
|
49
internal/database/sql/timeline.go
Normal file
49
internal/database/sql/timeline.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.sleepycat.moe/sam/mercury/internal/database"
|
||||
"github.com/keegancsmith/sqlf"
|
||||
"github.com/oklog/ulid/v2"
|
||||
)
|
||||
|
||||
// TimelineStore is the interface to post timelines in the database.
|
||||
type TimelineStore struct {
|
||||
q Querier
|
||||
}
|
||||
|
||||
// NewTimelineStore creates a new TimelineStore instance.
|
||||
func NewTimelineStore(q Querier) *TimelineStore {
|
||||
return &TimelineStore{q: q}
|
||||
}
|
||||
|
||||
type TimelinePost struct {
|
||||
database.Post
|
||||
database.Blog
|
||||
}
|
||||
|
||||
func (s *TimelineStore) Home(ctx context.Context, accountID ulid.ULID, limit int, before, after *ulid.ULID) ([]TimelinePost, error) {
|
||||
q := sqlf.Sprintf("SELECT p.*, b.account_id, b.name, b.domain FROM posts p JOIN blogs b ON b.id = p.blog_id")
|
||||
|
||||
q = sqlf.Sprintf("%v WHERE (blog_id IN (%s) OR blog_id IN (%s))", q,
|
||||
sqlf.Sprintf("SELECT id FROM blogs WHERE account_id = %s", accountID),
|
||||
sqlf.Sprintf("SELECT blog_id FROM account_follows WHERE account_id = %s", accountID))
|
||||
|
||||
if before != nil {
|
||||
q = sqlf.Sprintf("%v AND id < %s", q, *before)
|
||||
}
|
||||
if after != nil {
|
||||
q = sqlf.Sprintf("%v AND id > %s", q, *after)
|
||||
}
|
||||
|
||||
q = sqlf.Sprintf("%v AND (visibility != %s OR (b.account_id = %s OR %s IN (%s)))", q, database.DirectVisibility, accountID, accountID,
|
||||
sqlf.Sprintf("SELECT account_id FROM blogs WHERE id IN (SELECT blog_id FROM post_recipients WHERE post_id = p.id)"))
|
||||
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
q = sqlf.Sprintf("%v ORDER BY id DESC LIMIT %d", q, limit)
|
||||
|
||||
return Select[TimelinePost](ctx, s.q, q)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue