add a couple post endpoints + /timelines/home

This commit is contained in:
sam 2023-09-06 16:32:33 +02:00
parent dd72a1f4c1
commit 9f052dc9ef
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
24 changed files with 462 additions and 32 deletions

View file

@ -5,6 +5,7 @@ import (
"emperror.dev/errors"
"git.sleepycat.moe/sam/mercury/config"
"git.sleepycat.moe/sam/mercury/internal/database"
"git.sleepycat.moe/sam/mercury/internal/database/sql"
"git.sleepycat.moe/sam/mercury/web/app"
"github.com/rs/zerolog/log"
@ -52,7 +53,7 @@ func run(c *cli.Context) error {
log.Debug().Msg("Created blog")
log.Debug().Msg("Creating post")
post, err := a.Post().Create(c.Context, blog, "Hello world!")
post, err := a.Post().Create(c.Context, blog, "Hello world!", database.PublicVisibility)
if err != nil {
log.Err(err).Msg("Creating post")
return nil

View file

@ -6,3 +6,9 @@ secret_key = "" # generate with `openssl rand -base64 48`
[web]
domain = "http://mercury.local"
port = 8000
[security]
# whether the API should be entirely unavailable to unauthenticated users (except for OAuth endpoints)
restrict_api = false
# whether the local and global timelines should be available to unauthenticated users
public_timelines = false

View file

@ -3,6 +3,7 @@ package config
type Config struct {
Core CoreConfig `toml:"core"`
Web WebConfig `toml:"web"`
Security SecurityConfig `toml:"security"`
}
type WebConfig struct {
@ -18,3 +19,8 @@ type CoreConfig struct {
Dev bool `toml:"dev"`
SecretKey string `toml:"secret_key"`
}
type SecurityConfig struct {
RestrictAPI bool `toml:"restrict_api"`
PublicTimelines bool `toml:"public_timelines"`
}

View file

@ -6,6 +6,7 @@ type Config struct {
Name string
AdminID *ulid.ULID
InternalApplication *ulid.ULID
PostCharacterLimit int
// ID is always 1
ID int

View file

@ -0,0 +1,12 @@
-- 2023-09-06: Account->blog follows
-- +migrate Up
create table account_follows (
account_id text not null references accounts (id) on delete cascade,
blog_id text not null references blogs (id) on delete cascade
);
-- +migrate Down
drop table account_follows;

View file

@ -0,0 +1,19 @@
-- 2023-09-06: Some improvements to posts
-- +migrate Up
alter table config add column post_character_limit integer not null default 5000;
create index posts_blog_id_idx on posts (blog_id);
create table post_recipients (
post_id TEXT NOT NULL REFERENCES posts (id) ON DELETE CASCADE,
blog_id TEXT NOT NULL REFERENCES blogs (id) ON DELETE CASCADE,
PRIMARY KEY (post_id, blog_id)
);
-- +migrate Down
drop table post_recipients;
drop index posts_blog_id_idx;
alter table config drop column post_character_limit;

View file

@ -2,10 +2,20 @@ package database
import "github.com/oklog/ulid/v2"
type PostVisibility string
const (
PublicVisibility PostVisibility = "public"
UnlistedVisibility PostVisibility = "unlisted"
FollowersVisibility PostVisibility = "followers"
DirectVisibility PostVisibility = "direct"
)
// Post is a single post.
type Post struct {
ID ulid.ULID
BlogID ulid.ULID
Content *string
Source *string
Visibility PostVisibility
}

View file

@ -86,3 +86,13 @@ func (s *AccountStore) CreateLocal(
}
return a, nil
}
func (s *AccountStore) Follows(ctx context.Context, id ulid.ULID) ([]ulid.ULID, error) {
q := sqlf.Sprintf("SELECT blog_id FROM account_follows WHERE account_id = %s", id)
follows, err := Select[ulid.ULID](ctx, s.q, q)
if err != nil {
return nil, errors.Wrap(err, "executing query")
}
return follows, nil
}

View file

@ -84,3 +84,13 @@ func (s *BlogStore) LocalCount(ctx context.Context) (count int64, err error) {
}
return count, nil
}
func (s *BlogStore) Followers(ctx context.Context, id ulid.ULID) ([]ulid.ULID, error) {
q := sqlf.Sprintf("SELECT account_id FROM account_follows WHERE blog_id = %s", id)
follows, err := Select[ulid.ULID](ctx, s.q, q)
if err != nil {
return nil, errors.Wrap(err, "executing query")
}
return follows, nil
}

View file

@ -45,6 +45,9 @@ func (s *ConfigStore) Set(ctx context.Context, cur, new database.Config) (databa
if cur.InternalApplication != new.InternalApplication {
q = sqlf.Sprintf("%v internal_application = %v,", q, new.InternalApplication)
}
if cur.PostCharacterLimit != new.PostCharacterLimit {
q = sqlf.Sprintf("%v post_character_limit = %v,", q, new.PostCharacterLimit)
}
q = sqlf.Sprintf("%v id = %v WHERE id = %v RETURNING *", q, cur.ID, cur.ID)
return Get[database.Config](ctx, s.q, q)

View file

@ -34,10 +34,10 @@ func (s *PostStore) ByID(ctx context.Context, id ulid.ULID) (p database.Post, er
return p, nil
}
func (s *PostStore) Create(ctx context.Context, blog database.Blog, content string) (p database.Post, err error) {
func (s *PostStore) Create(ctx context.Context, blog database.Blog, content string, visibility database.PostVisibility) (p database.Post, err error) {
q := sqlf.Sprintf(
"INSERT INTO posts (id, blog_id, content, remote) VALUES (%s, %s, %s, %v) RETURNING *",
makeULID(), blog.ID, content, blog.Remote())
"INSERT INTO posts (id, blog_id, content, source, visibility) VALUES (%s, %s, %s, %s, %s) RETURNING *",
makeULID(), blog.ID, content, content, visibility)
p, err = Get[database.Post](ctx, s.q, q)
if err != nil {

View file

@ -8,8 +8,11 @@ import (
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/keegancsmith/sqlf"
"github.com/rs/zerolog/log"
)
var LogQueries = true
type Querier interface {
Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
@ -32,6 +35,10 @@ type Tx interface {
func Select[T any](ctx context.Context, querier Querier, query *sqlf.Query) ([]T, error) {
dst := make([]T, 0)
if LogQueries {
log.Debug().Str("query", query.Query(sqlf.PostgresBindVar)).Msg("executing select query")
}
err := pgxscan.Select(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...)
if err != nil {
return []T{}, errors.Wrap(err, "executing query")
@ -42,6 +49,10 @@ func Select[T any](ctx context.Context, querier Querier, query *sqlf.Query) ([]T
func Get[T any](ctx context.Context, querier Querier, query *sqlf.Query) (T, error) {
var dst T
if LogQueries {
log.Debug().Str("query", query.Query(sqlf.PostgresBindVar)).Msg("executing get query")
}
err := pgxscan.Get(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...)
if err != nil {
return dst, errors.Wrap(err, "executing query")
@ -50,6 +61,10 @@ func Get[T any](ctx context.Context, querier Querier, query *sqlf.Query) (T, err
}
func Exec(ctx context.Context, querier Querier, query *sqlf.Query) error {
if LogQueries {
log.Debug().Str("query", query.Query(sqlf.PostgresBindVar)).Msg("executing exec query")
}
_, err := querier.Exec(ctx, query.Query(sqlf.PostgresBindVar), query.Args()...)
return err
}

View file

@ -0,0 +1,49 @@
package sql
import (
"context"
"git.sleepycat.moe/sam/mercury/internal/database"
"github.com/keegancsmith/sqlf"
"github.com/oklog/ulid/v2"
)
// TimelineStore is the interface to post timelines in the database.
type TimelineStore struct {
q Querier
}
// NewTimelineStore creates a new TimelineStore instance.
func NewTimelineStore(q Querier) *TimelineStore {
return &TimelineStore{q: q}
}
type TimelinePost struct {
database.Post
database.Blog
}
func (s *TimelineStore) Home(ctx context.Context, accountID ulid.ULID, limit int, before, after *ulid.ULID) ([]TimelinePost, error) {
q := sqlf.Sprintf("SELECT p.*, b.account_id, b.name, b.domain FROM posts p JOIN blogs b ON b.id = p.blog_id")
q = sqlf.Sprintf("%v WHERE (blog_id IN (%s) OR blog_id IN (%s))", q,
sqlf.Sprintf("SELECT id FROM blogs WHERE account_id = %s", accountID),
sqlf.Sprintf("SELECT blog_id FROM account_follows WHERE account_id = %s", accountID))
if before != nil {
q = sqlf.Sprintf("%v AND id < %s", q, *before)
}
if after != nil {
q = sqlf.Sprintf("%v AND id > %s", q, *after)
}
q = sqlf.Sprintf("%v AND (visibility != %s OR (b.account_id = %s OR %s IN (%s)))", q, database.DirectVisibility, accountID, accountID,
sqlf.Sprintf("SELECT account_id FROM blogs WHERE id IN (SELECT blog_id FROM post_recipients WHERE post_id = p.id)"))
if limit <= 0 || limit > 100 {
limit = 100
}
q = sqlf.Sprintf("%v ORDER BY id DESC LIMIT %d", q, limit)
return Select[TimelinePost](ctx, s.q, q)
}

View file

@ -55,6 +55,9 @@ const (
TokenScopeAccountsWrite TokenScope = "accounts.write"
TokenScopeBlogsRead TokenScope = "blogs.read"
TokenScopeBlogsWrite TokenScope = "blogs.write"
TokenScopePostsRead TokenScope = "posts.read"
TokenScopePostsWrite TokenScope = "posts.write"
TokenScopeTimeline TokenScope = "timeline"
)
func (s TokenScope) IsValid() bool {

View file

@ -0,0 +1,8 @@
package utils
import "unicode/utf8"
// StringLength returns the length of a string by counting runes.
func StringLength(s string) int {
return utf8.RuneCountInString(s)
}

View file

@ -124,12 +124,16 @@ const (
// Auth related
ErrInvalidToken = 1001
ErrMissingScope = 1002
ErrNotYourObject = 1003
// Account related
ErrAccountNotFound = 2001
// Blog related
ErrBlogNotFound = 3001
// Post related
ErrPostNotFound = 4001
)
func ErrCodeMessage(code int) string {
@ -146,10 +150,13 @@ var errCodeMessages = map[int]string{
ErrInvalidToken: "No token supplied, or token is invalid",
ErrMissingScope: "Token is missing required scope for this endpoint",
ErrNotYourObject: "Object you are trying to perform action on is not owned by you",
ErrAccountNotFound: "Account not found",
ErrBlogNotFound: "Blog not found",
ErrPostNotFound: "Post not found",
}
func ErrCodeStatus(code int) int {
@ -166,8 +173,11 @@ var errCodeStatuses = map[int]int{
ErrInvalidToken: http.StatusUnauthorized,
ErrMissingScope: http.StatusForbidden,
ErrNotYourObject: http.StatusForbidden,
ErrAccountNotFound: http.StatusNotFound,
ErrBlogNotFound: http.StatusNotFound,
ErrPostNotFound: http.StatusNotFound,
}

35
web/api/post.go Normal file
View file

@ -0,0 +1,35 @@
package api
import (
"git.sleepycat.moe/sam/mercury/internal/database"
"github.com/oklog/ulid/v2"
)
type Post struct {
ID ulid.ULID `json:"id"`
Content *string `json:"content"`
Source *string `json:"source"`
Visibility database.PostVisibility `json:"visibility"`
Blog postPartialBlog `json:"blog"`
}
type postPartialBlog struct {
ID ulid.ULID `json:"id"`
Name string `json:"name"`
Domain *string `json:"domain"`
}
func DBPostToPost(p database.Post, b database.Blog) Post {
return Post{
ID: p.ID,
Content: p.Content,
Source: p.Source,
Visibility: p.Visibility,
Blog: postPartialBlog{
ID: p.BlogID,
Name: b.Name,
Domain: b.Domain,
},
}
}

View file

@ -0,0 +1,91 @@
package posts
import (
"net/http"
"git.sleepycat.moe/sam/mercury/internal/database"
"git.sleepycat.moe/sam/mercury/internal/database/sql"
"git.sleepycat.moe/sam/mercury/internal/utils"
"git.sleepycat.moe/sam/mercury/web/api"
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
"github.com/oklog/ulid/v2"
"github.com/rs/zerolog/log"
)
type createPostParams struct {
Content string `json:"content"`
Visibility database.PostVisibility `json:"visibility"`
}
func (p createPostParams) Validate(cfg database.Config) bool {
if p.Content == "" { // TODO: allow empty content if the post has attachments
return false
}
if p.Visibility == "" {
return false
}
if utils.StringLength(p.Content) > cfg.PostCharacterLimit {
return false
}
switch p.Visibility {
case database.PublicVisibility, database.UnlistedVisibility, database.FollowersVisibility, database.DirectVisibility:
default:
return false
}
// everything checks out
return true
}
func (app *App) Create(w http.ResponseWriter, r *http.Request) (api.Post, error) {
ctx := r.Context()
token, _ := app.TokenFromContext(ctx)
var req createPostParams
err := render.Decode(r, &req)
if err != nil {
return api.Post{}, api.Error{Code: api.ErrBadRequest}
}
if !req.Validate(app.DBConfig.Get()) {
return api.Post{}, api.Error{Code: api.ErrBadRequest}
}
blogID, err := ulid.Parse(chi.URLParam(r, "blogID"))
if err != nil {
return api.Post{}, api.Error{Code: api.ErrBlogNotFound}
}
conn, err := app.Database.Acquire(ctx)
if err != nil {
log.Err(err).Msg("acquiring connection")
return api.Post{}, err
}
defer conn.Release()
blog, err := app.Blog(conn).ByID(ctx, blogID)
if err != nil {
if err == sql.ErrNotFound {
return api.Post{}, api.Error{Code: api.ErrBlogNotFound}
}
log.Err(err).Msg("fetching blog")
return api.Post{}, err
}
if blog.AccountID != token.UserID {
return api.Post{}, api.Error{Code: api.ErrNotYourObject}
}
// create post
post, err := app.Post(conn).Create(ctx, blog, req.Content, req.Visibility)
if err != nil {
log.Err(err).Msg("creating post")
return api.Post{}, err
}
// TODO: federate post + push to websockets
return api.DBPostToPost(post, blog), nil
}

44
web/api/posts/get_post.go Normal file
View file

@ -0,0 +1,44 @@
package posts
import (
"net/http"
"git.sleepycat.moe/sam/mercury/internal/database/sql"
"git.sleepycat.moe/sam/mercury/web/api"
"github.com/go-chi/chi/v5"
"github.com/oklog/ulid/v2"
"github.com/rs/zerolog/log"
)
func (app *App) GetID(w http.ResponseWriter, r *http.Request) (api.Post, error) {
ctx := r.Context()
id, err := ulid.Parse(chi.URLParamFromCtx(ctx, "postID"))
if err != nil {
return api.Post{}, api.Error{Code: api.ErrBlogNotFound}
}
conn, err := app.Database.Acquire(ctx)
if err != nil {
log.Err(err).Msg("acquiring connection")
return api.Post{}, err
}
defer conn.Release()
post, err := app.Post(conn).ByID(ctx, id)
if err != nil {
if err == sql.ErrNotFound {
return api.Post{}, api.Error{Code: api.ErrBlogNotFound}
}
log.Err(err).Str("id", id.String()).Msg("fetching post from database")
return api.Post{}, err
}
blog, err := app.Blog(conn).ByID(ctx, post.BlogID)
if err != nil {
log.Err(err).Str("id", post.BlogID.String()).Msg("fetching blog from database")
return api.Post{}, err
}
return api.DBPostToPost(post, blog), nil
}

13
web/api/posts/module.go Normal file
View file

@ -0,0 +1,13 @@
package posts
import "git.sleepycat.moe/sam/mercury/web/app"
type App struct {
*app.App
}
func New(app *app.App) *App {
return &App{
App: app,
}
}

View file

@ -0,0 +1,47 @@
package timelines
import (
"net/http"
"strconv"
"git.sleepycat.moe/sam/mercury/web/api"
"github.com/oklog/ulid/v2"
"github.com/rs/zerolog/log"
)
type timelineResponse struct {
Posts []api.Post `json:"posts"`
}
func (app *App) Home(w http.ResponseWriter, r *http.Request) (timelineResponse, error) {
ctx := r.Context()
token, _ := app.TokenFromContext(ctx)
var before, after *ulid.ULID
if id, err := ulid.Parse(r.FormValue("before")); err == nil {
before = &id
}
if id, err := ulid.Parse(r.FormValue("after")); err == nil {
after = &id
}
if before != nil && after != nil {
return timelineResponse{}, api.Error{Code: api.ErrBadRequest, Details: "`before` and `after` are mutually exclusive"}
}
limit, err := strconv.Atoi(r.FormValue("limit"))
if err != nil || limit <= 0 || limit > 100 {
limit = 100
}
posts, err := app.Timeline().Home(ctx, token.UserID, limit, before, after)
if err != nil {
log.Err(err).Msg("getting posts from database")
}
resp := timelineResponse{Posts: make([]api.Post, len(posts))}
for i := range posts {
resp.Posts[i] = api.DBPostToPost(posts[i].Post, posts[i].Blog)
}
return resp, nil
}

View file

@ -0,0 +1,13 @@
package timelines
import "git.sleepycat.moe/sam/mercury/web/app"
type App struct {
*app.App
}
func New(app *app.App) *App {
return &App{
App: app,
}
}

View file

@ -94,6 +94,13 @@ func (a *App) Post(q ...sql.Querier) *sql.PostStore {
return sql.NewPostStore(q[0])
}
func (a *App) Timeline(q ...sql.Querier) *sql.TimelineStore {
if len(q) == 0 || q[0] == nil {
return sql.NewTimelineStore(a.Database.PoolQuerier())
}
return sql.NewTimelineStore(q[0])
}
func (a *App) Token(q ...sql.Querier) *sql.TokenStore {
if len(q) == 0 || q[0] == nil {
return sql.NewTokenStore(a.Database.PoolQuerier())

View file

@ -5,6 +5,8 @@ import (
"git.sleepycat.moe/sam/mercury/web/api"
"git.sleepycat.moe/sam/mercury/web/api/accounts"
"git.sleepycat.moe/sam/mercury/web/api/blogs"
"git.sleepycat.moe/sam/mercury/web/api/posts"
"git.sleepycat.moe/sam/mercury/web/api/timelines"
"git.sleepycat.moe/sam/mercury/web/app"
"git.sleepycat.moe/sam/mercury/web/auth"
"git.sleepycat.moe/sam/mercury/web/frontend"
@ -22,6 +24,37 @@ func Routes(app *app.App) {
r.Post("/sign_up", auth.PostSignup)
})
// APIv1 handlers
app.Router.Route("/api/v1", func(r chi.Router) {
unauthedAccess := !app.AppConfig.Security.RestrictAPI
unauthedTimelineAccess := app.AppConfig.Security.PublicTimelines && !app.AppConfig.Security.RestrictAPI
_ = unauthedTimelineAccess
// account handlers
accounts := accounts.New(app)
r.With(app.APIAuth(database.TokenScopeAccountsRead, unauthedAccess)).
Get("/accounts/{accountID}", api.WrapHandlerT(accounts.GetID))
r.With(app.APIAuth(database.TokenScopeAccountsMe, false)).
Get("/accounts/@me", api.WrapHandlerT(accounts.GetMe))
blogs := blogs.New(app)
r.With(app.APIAuth(database.TokenScopeBlogsRead, unauthedAccess)).
Get("/blogs/{blogID}", api.WrapHandlerT(blogs.GetID))
r.With(app.APIAuth(database.TokenScopeBlogsRead, unauthedAccess)).
Get("/blogs/lookup/{blogName}", api.WrapHandlerT(blogs.LookupName))
posts := posts.New(app)
r.With(app.APIAuth(database.TokenScopePostsRead, unauthedAccess)).
Get("/posts/{postID}", api.WrapHandlerT(posts.GetID))
r.With(app.APIAuth(database.TokenScopePostsWrite, false)).
Post("/blogs/{blogID}/posts", api.WrapHandlerT(posts.Create))
timelines := timelines.New(app)
r.With(app.APIAuth(database.TokenScopeTimeline, false)).
Get("/timelines/home", api.WrapHandlerT(timelines.Home))
})
// web app handlers
// also assets
app.Router.Group(func(r chi.Router) {
@ -35,20 +68,4 @@ func Routes(app *app.App) {
r.HandleFunc("/web/@{username}", frontend.ServeUser)
r.HandleFunc("/web/@{username}/posts/{postID}", frontend.ServeStatus)
})
// APIv1 handlers
app.Router.Route("/api/v1", func(r chi.Router) {
// account handlers
accounts := accounts.New(app)
r.With(app.APIAuth(database.TokenScopeAccountsRead, true)).
Get("/accounts/{accountID}", api.WrapHandlerT(accounts.GetID))
r.With(app.APIAuth(database.TokenScopeAccountsMe, false)).
Get("/accounts/@me", api.WrapHandlerT(accounts.GetMe))
blogs := blogs.New(app)
r.With(app.APIAuth(database.TokenScopeBlogsRead, true)).
Get("/blogs/{blogID}", api.WrapHandlerT(blogs.GetID))
r.With(app.APIAuth(database.TokenScopeBlogsRead, true)).
Get("/blogs/lookup/{blogName}", api.WrapHandlerT(blogs.LookupName))
})
}