add a couple post endpoints + /timelines/home
This commit is contained in:
		
							parent
							
								
									dd72a1f4c1
								
							
						
					
					
						commit
						9f052dc9ef
					
				
					 24 changed files with 462 additions and 32 deletions
				
			
		|  | @ -5,6 +5,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"emperror.dev/errors" | 	"emperror.dev/errors" | ||||||
| 	"git.sleepycat.moe/sam/mercury/config" | 	"git.sleepycat.moe/sam/mercury/config" | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/internal/database" | ||||||
| 	"git.sleepycat.moe/sam/mercury/internal/database/sql" | 	"git.sleepycat.moe/sam/mercury/internal/database/sql" | ||||||
| 	"git.sleepycat.moe/sam/mercury/web/app" | 	"git.sleepycat.moe/sam/mercury/web/app" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
|  | @ -52,7 +53,7 @@ func run(c *cli.Context) error { | ||||||
| 	log.Debug().Msg("Created blog") | 	log.Debug().Msg("Created blog") | ||||||
| 
 | 
 | ||||||
| 	log.Debug().Msg("Creating post") | 	log.Debug().Msg("Creating post") | ||||||
| 	post, err := a.Post().Create(c.Context, blog, "Hello world!") | 	post, err := a.Post().Create(c.Context, blog, "Hello world!", database.PublicVisibility) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Err(err).Msg("Creating post") | 		log.Err(err).Msg("Creating post") | ||||||
| 		return nil | 		return nil | ||||||
|  |  | ||||||
|  | @ -6,3 +6,9 @@ secret_key = "" # generate with `openssl rand -base64 48` | ||||||
| [web] | [web] | ||||||
| domain = "http://mercury.local" | domain = "http://mercury.local" | ||||||
| port = 8000 | port = 8000 | ||||||
|  | 
 | ||||||
|  | [security] | ||||||
|  | # whether the API should be entirely unavailable to unauthenticated users (except for OAuth endpoints) | ||||||
|  | restrict_api = false | ||||||
|  | # whether the local and global timelines should be available to unauthenticated users | ||||||
|  | public_timelines = false | ||||||
|  |  | ||||||
|  | @ -1,8 +1,9 @@ | ||||||
| package config | package config | ||||||
| 
 | 
 | ||||||
| type Config struct { | type Config struct { | ||||||
| 	Core CoreConfig `toml:"core"` | 	Core     CoreConfig     `toml:"core"` | ||||||
| 	Web  WebConfig  `toml:"web"` | 	Web      WebConfig      `toml:"web"` | ||||||
|  | 	Security SecurityConfig `toml:"security"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type WebConfig struct { | type WebConfig struct { | ||||||
|  | @ -18,3 +19,8 @@ type CoreConfig struct { | ||||||
| 	Dev       bool   `toml:"dev"` | 	Dev       bool   `toml:"dev"` | ||||||
| 	SecretKey string `toml:"secret_key"` | 	SecretKey string `toml:"secret_key"` | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | type SecurityConfig struct { | ||||||
|  | 	RestrictAPI     bool `toml:"restrict_api"` | ||||||
|  | 	PublicTimelines bool `toml:"public_timelines"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ type Config struct { | ||||||
| 	Name                string | 	Name                string | ||||||
| 	AdminID             *ulid.ULID | 	AdminID             *ulid.ULID | ||||||
| 	InternalApplication *ulid.ULID | 	InternalApplication *ulid.ULID | ||||||
|  | 	PostCharacterLimit  int | ||||||
| 
 | 
 | ||||||
| 	// ID is always 1 | 	// ID is always 1 | ||||||
| 	ID int | 	ID int | ||||||
|  |  | ||||||
							
								
								
									
										12
									
								
								internal/database/migrations/1693961486_follows.sql
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								internal/database/migrations/1693961486_follows.sql
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,12 @@ | ||||||
|  | -- 2023-09-06: Account->blog follows | ||||||
|  | 
 | ||||||
|  | -- +migrate Up | ||||||
|  | 
 | ||||||
|  | create table account_follows ( | ||||||
|  |     account_id text not null references accounts (id) on delete cascade, | ||||||
|  |     blog_id text not null references blogs (id) on delete cascade | ||||||
|  | ); | ||||||
|  | 
 | ||||||
|  | -- +migrate Down | ||||||
|  | 
 | ||||||
|  | drop table account_follows; | ||||||
|  | @ -0,0 +1,19 @@ | ||||||
|  | -- 2023-09-06: Some improvements to posts | ||||||
|  | 
 | ||||||
|  | -- +migrate Up | ||||||
|  | 
 | ||||||
|  | alter table config add column post_character_limit integer not null default 5000; | ||||||
|  | create index posts_blog_id_idx on posts (blog_id); | ||||||
|  | 
 | ||||||
|  | create table post_recipients ( | ||||||
|  |     post_id    TEXT NOT NULL REFERENCES posts (id) ON DELETE CASCADE, | ||||||
|  |     blog_id    TEXT NOT NULL REFERENCES blogs (id) ON DELETE CASCADE, | ||||||
|  | 
 | ||||||
|  |     PRIMARY KEY (post_id, blog_id) | ||||||
|  | ); | ||||||
|  | 
 | ||||||
|  | -- +migrate Down | ||||||
|  | 
 | ||||||
|  | drop table post_recipients; | ||||||
|  | drop index posts_blog_id_idx; | ||||||
|  | alter table config drop column post_character_limit; | ||||||
|  | @ -2,10 +2,20 @@ package database | ||||||
| 
 | 
 | ||||||
| import "github.com/oklog/ulid/v2" | import "github.com/oklog/ulid/v2" | ||||||
| 
 | 
 | ||||||
|  | type PostVisibility string | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	PublicVisibility    PostVisibility = "public" | ||||||
|  | 	UnlistedVisibility  PostVisibility = "unlisted" | ||||||
|  | 	FollowersVisibility PostVisibility = "followers" | ||||||
|  | 	DirectVisibility    PostVisibility = "direct" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| // Post is a single post. | // Post is a single post. | ||||||
| type Post struct { | type Post struct { | ||||||
| 	ID      ulid.ULID | 	ID         ulid.ULID | ||||||
| 	BlogID  ulid.ULID | 	BlogID     ulid.ULID | ||||||
| 	Content *string | 	Content    *string | ||||||
| 	Source  *string | 	Source     *string | ||||||
|  | 	Visibility PostVisibility | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -86,3 +86,13 @@ func (s *AccountStore) CreateLocal( | ||||||
| 	} | 	} | ||||||
| 	return a, nil | 	return a, nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func (s *AccountStore) Follows(ctx context.Context, id ulid.ULID) ([]ulid.ULID, error) { | ||||||
|  | 	q := sqlf.Sprintf("SELECT blog_id FROM account_follows WHERE account_id = %s", id) | ||||||
|  | 
 | ||||||
|  | 	follows, err := Select[ulid.ULID](ctx, s.q, q) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, errors.Wrap(err, "executing query") | ||||||
|  | 	} | ||||||
|  | 	return follows, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -84,3 +84,13 @@ func (s *BlogStore) LocalCount(ctx context.Context) (count int64, err error) { | ||||||
| 	} | 	} | ||||||
| 	return count, nil | 	return count, nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func (s *BlogStore) Followers(ctx context.Context, id ulid.ULID) ([]ulid.ULID, error) { | ||||||
|  | 	q := sqlf.Sprintf("SELECT account_id FROM account_follows WHERE blog_id = %s", id) | ||||||
|  | 
 | ||||||
|  | 	follows, err := Select[ulid.ULID](ctx, s.q, q) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, errors.Wrap(err, "executing query") | ||||||
|  | 	} | ||||||
|  | 	return follows, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -45,6 +45,9 @@ func (s *ConfigStore) Set(ctx context.Context, cur, new database.Config) (databa | ||||||
| 	if cur.InternalApplication != new.InternalApplication { | 	if cur.InternalApplication != new.InternalApplication { | ||||||
| 		q = sqlf.Sprintf("%v internal_application = %v,", q, new.InternalApplication) | 		q = sqlf.Sprintf("%v internal_application = %v,", q, new.InternalApplication) | ||||||
| 	} | 	} | ||||||
|  | 	if cur.PostCharacterLimit != new.PostCharacterLimit { | ||||||
|  | 		q = sqlf.Sprintf("%v post_character_limit = %v,", q, new.PostCharacterLimit) | ||||||
|  | 	} | ||||||
| 	q = sqlf.Sprintf("%v id = %v WHERE id = %v RETURNING *", q, cur.ID, cur.ID) | 	q = sqlf.Sprintf("%v id = %v WHERE id = %v RETURNING *", q, cur.ID, cur.ID) | ||||||
| 
 | 
 | ||||||
| 	return Get[database.Config](ctx, s.q, q) | 	return Get[database.Config](ctx, s.q, q) | ||||||
|  |  | ||||||
|  | @ -34,10 +34,10 @@ func (s *PostStore) ByID(ctx context.Context, id ulid.ULID) (p database.Post, er | ||||||
| 	return p, nil | 	return p, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *PostStore) Create(ctx context.Context, blog database.Blog, content string) (p database.Post, err error) { | func (s *PostStore) Create(ctx context.Context, blog database.Blog, content string, visibility database.PostVisibility) (p database.Post, err error) { | ||||||
| 	q := sqlf.Sprintf( | 	q := sqlf.Sprintf( | ||||||
| 		"INSERT INTO posts (id, blog_id, content, remote) VALUES (%s, %s, %s, %v) RETURNING *", | 		"INSERT INTO posts (id, blog_id, content, source, visibility) VALUES (%s, %s, %s, %s, %s) RETURNING *", | ||||||
| 		makeULID(), blog.ID, content, blog.Remote()) | 		makeULID(), blog.ID, content, content, visibility) | ||||||
| 
 | 
 | ||||||
| 	p, err = Get[database.Post](ctx, s.q, q) | 	p, err = Get[database.Post](ctx, s.q, q) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
|  | @ -8,8 +8,11 @@ import ( | ||||||
| 	"github.com/jackc/pgx/v5" | 	"github.com/jackc/pgx/v5" | ||||||
| 	"github.com/jackc/pgx/v5/pgconn" | 	"github.com/jackc/pgx/v5/pgconn" | ||||||
| 	"github.com/keegancsmith/sqlf" | 	"github.com/keegancsmith/sqlf" | ||||||
|  | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | var LogQueries = true | ||||||
|  | 
 | ||||||
| type Querier interface { | type Querier interface { | ||||||
| 	Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error) | 	Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error) | ||||||
| 	QueryRow(ctx context.Context, sql string, args ...any) pgx.Row | 	QueryRow(ctx context.Context, sql string, args ...any) pgx.Row | ||||||
|  | @ -32,6 +35,10 @@ type Tx interface { | ||||||
| func Select[T any](ctx context.Context, querier Querier, query *sqlf.Query) ([]T, error) { | func Select[T any](ctx context.Context, querier Querier, query *sqlf.Query) ([]T, error) { | ||||||
| 	dst := make([]T, 0) | 	dst := make([]T, 0) | ||||||
| 
 | 
 | ||||||
|  | 	if LogQueries { | ||||||
|  | 		log.Debug().Str("query", query.Query(sqlf.PostgresBindVar)).Msg("executing select query") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	err := pgxscan.Select(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...) | 	err := pgxscan.Select(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return []T{}, errors.Wrap(err, "executing query") | 		return []T{}, errors.Wrap(err, "executing query") | ||||||
|  | @ -42,6 +49,10 @@ func Select[T any](ctx context.Context, querier Querier, query *sqlf.Query) ([]T | ||||||
| func Get[T any](ctx context.Context, querier Querier, query *sqlf.Query) (T, error) { | func Get[T any](ctx context.Context, querier Querier, query *sqlf.Query) (T, error) { | ||||||
| 	var dst T | 	var dst T | ||||||
| 
 | 
 | ||||||
|  | 	if LogQueries { | ||||||
|  | 		log.Debug().Str("query", query.Query(sqlf.PostgresBindVar)).Msg("executing get query") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	err := pgxscan.Get(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...) | 	err := pgxscan.Get(ctx, querier, &dst, query.Query(sqlf.PostgresBindVar), query.Args()...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return dst, errors.Wrap(err, "executing query") | 		return dst, errors.Wrap(err, "executing query") | ||||||
|  | @ -50,6 +61,10 @@ func Get[T any](ctx context.Context, querier Querier, query *sqlf.Query) (T, err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func Exec(ctx context.Context, querier Querier, query *sqlf.Query) error { | func Exec(ctx context.Context, querier Querier, query *sqlf.Query) error { | ||||||
|  | 	if LogQueries { | ||||||
|  | 		log.Debug().Str("query", query.Query(sqlf.PostgresBindVar)).Msg("executing exec query") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	_, err := querier.Exec(ctx, query.Query(sqlf.PostgresBindVar), query.Args()...) | 	_, err := querier.Exec(ctx, query.Query(sqlf.PostgresBindVar), query.Args()...) | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
							
								
								
									
										49
									
								
								internal/database/sql/timeline.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								internal/database/sql/timeline.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,49 @@ | ||||||
|  | package sql | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 
 | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/internal/database" | ||||||
|  | 	"github.com/keegancsmith/sqlf" | ||||||
|  | 	"github.com/oklog/ulid/v2" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // TimelineStore is the interface to post timelines in the database. | ||||||
|  | type TimelineStore struct { | ||||||
|  | 	q Querier | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewTimelineStore creates a new TimelineStore instance. | ||||||
|  | func NewTimelineStore(q Querier) *TimelineStore { | ||||||
|  | 	return &TimelineStore{q: q} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type TimelinePost struct { | ||||||
|  | 	database.Post | ||||||
|  | 	database.Blog | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *TimelineStore) Home(ctx context.Context, accountID ulid.ULID, limit int, before, after *ulid.ULID) ([]TimelinePost, error) { | ||||||
|  | 	q := sqlf.Sprintf("SELECT p.*, b.account_id, b.name, b.domain FROM posts p JOIN blogs b ON b.id = p.blog_id") | ||||||
|  | 
 | ||||||
|  | 	q = sqlf.Sprintf("%v WHERE (blog_id IN (%s) OR blog_id IN (%s))", q, | ||||||
|  | 		sqlf.Sprintf("SELECT id FROM blogs WHERE account_id = %s", accountID), | ||||||
|  | 		sqlf.Sprintf("SELECT blog_id FROM account_follows WHERE account_id = %s", accountID)) | ||||||
|  | 
 | ||||||
|  | 	if before != nil { | ||||||
|  | 		q = sqlf.Sprintf("%v AND id < %s", q, *before) | ||||||
|  | 	} | ||||||
|  | 	if after != nil { | ||||||
|  | 		q = sqlf.Sprintf("%v AND id > %s", q, *after) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	q = sqlf.Sprintf("%v AND (visibility != %s OR (b.account_id = %s OR %s IN (%s)))", q, database.DirectVisibility, accountID, accountID, | ||||||
|  | 		sqlf.Sprintf("SELECT account_id FROM blogs WHERE id IN (SELECT blog_id FROM post_recipients WHERE post_id = p.id)")) | ||||||
|  | 
 | ||||||
|  | 	if limit <= 0 || limit > 100 { | ||||||
|  | 		limit = 100 | ||||||
|  | 	} | ||||||
|  | 	q = sqlf.Sprintf("%v ORDER BY id DESC LIMIT %d", q, limit) | ||||||
|  | 
 | ||||||
|  | 	return Select[TimelinePost](ctx, s.q, q) | ||||||
|  | } | ||||||
|  | @ -55,6 +55,9 @@ const ( | ||||||
| 	TokenScopeAccountsWrite TokenScope = "accounts.write" | 	TokenScopeAccountsWrite TokenScope = "accounts.write" | ||||||
| 	TokenScopeBlogsRead     TokenScope = "blogs.read" | 	TokenScopeBlogsRead     TokenScope = "blogs.read" | ||||||
| 	TokenScopeBlogsWrite    TokenScope = "blogs.write" | 	TokenScopeBlogsWrite    TokenScope = "blogs.write" | ||||||
|  | 	TokenScopePostsRead     TokenScope = "posts.read" | ||||||
|  | 	TokenScopePostsWrite    TokenScope = "posts.write" | ||||||
|  | 	TokenScopeTimeline      TokenScope = "timeline" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (s TokenScope) IsValid() bool { | func (s TokenScope) IsValid() bool { | ||||||
|  |  | ||||||
							
								
								
									
										8
									
								
								internal/utils/strings.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								internal/utils/strings.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,8 @@ | ||||||
|  | package utils | ||||||
|  | 
 | ||||||
|  | import "unicode/utf8" | ||||||
|  | 
 | ||||||
|  | // StringLength returns the length of a string by counting runes. | ||||||
|  | func StringLength(s string) int { | ||||||
|  | 	return utf8.RuneCountInString(s) | ||||||
|  | } | ||||||
|  | @ -122,14 +122,18 @@ const ( | ||||||
| 	ErrInternalServerError = 500 // catch-all code for unknown errors | 	ErrInternalServerError = 500 // catch-all code for unknown errors | ||||||
| 
 | 
 | ||||||
| 	// Auth related | 	// Auth related | ||||||
| 	ErrInvalidToken = 1001 | 	ErrInvalidToken  = 1001 | ||||||
| 	ErrMissingScope = 1002 | 	ErrMissingScope  = 1002 | ||||||
|  | 	ErrNotYourObject = 1003 | ||||||
| 
 | 
 | ||||||
| 	// Account related | 	// Account related | ||||||
| 	ErrAccountNotFound = 2001 | 	ErrAccountNotFound = 2001 | ||||||
| 
 | 
 | ||||||
| 	// Blog related | 	// Blog related | ||||||
| 	ErrBlogNotFound = 3001 | 	ErrBlogNotFound = 3001 | ||||||
|  | 
 | ||||||
|  | 	// Post related | ||||||
|  | 	ErrPostNotFound = 4001 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func ErrCodeMessage(code int) string { | func ErrCodeMessage(code int) string { | ||||||
|  | @ -144,12 +148,15 @@ var errCodeMessages = map[int]string{ | ||||||
| 	ErrTooManyRequests:     "Rate limit reached", | 	ErrTooManyRequests:     "Rate limit reached", | ||||||
| 	ErrMethodNotAllowed:    "Method not allowed", | 	ErrMethodNotAllowed:    "Method not allowed", | ||||||
| 
 | 
 | ||||||
| 	ErrInvalidToken: "No token supplied, or token is invalid", | 	ErrInvalidToken:  "No token supplied, or token is invalid", | ||||||
| 	ErrMissingScope: "Token is missing required scope for this endpoint", | 	ErrMissingScope:  "Token is missing required scope for this endpoint", | ||||||
|  | 	ErrNotYourObject: "Object you are trying to perform action on is not owned by you", | ||||||
| 
 | 
 | ||||||
| 	ErrAccountNotFound: "Account not found", | 	ErrAccountNotFound: "Account not found", | ||||||
| 
 | 
 | ||||||
| 	ErrBlogNotFound: "Blog not found", | 	ErrBlogNotFound: "Blog not found", | ||||||
|  | 
 | ||||||
|  | 	ErrPostNotFound: "Post not found", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func ErrCodeStatus(code int) int { | func ErrCodeStatus(code int) int { | ||||||
|  | @ -164,10 +171,13 @@ var errCodeStatuses = map[int]int{ | ||||||
| 	ErrTooManyRequests:     http.StatusTooManyRequests, | 	ErrTooManyRequests:     http.StatusTooManyRequests, | ||||||
| 	ErrMethodNotAllowed:    http.StatusMethodNotAllowed, | 	ErrMethodNotAllowed:    http.StatusMethodNotAllowed, | ||||||
| 
 | 
 | ||||||
| 	ErrInvalidToken: http.StatusUnauthorized, | 	ErrInvalidToken:  http.StatusUnauthorized, | ||||||
| 	ErrMissingScope: http.StatusForbidden, | 	ErrMissingScope:  http.StatusForbidden, | ||||||
|  | 	ErrNotYourObject: http.StatusForbidden, | ||||||
| 
 | 
 | ||||||
| 	ErrAccountNotFound: http.StatusNotFound, | 	ErrAccountNotFound: http.StatusNotFound, | ||||||
| 
 | 
 | ||||||
| 	ErrBlogNotFound: http.StatusNotFound, | 	ErrBlogNotFound: http.StatusNotFound, | ||||||
|  | 
 | ||||||
|  | 	ErrPostNotFound: http.StatusNotFound, | ||||||
| } | } | ||||||
|  |  | ||||||
							
								
								
									
										35
									
								
								web/api/post.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								web/api/post.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,35 @@ | ||||||
|  | package api | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/internal/database" | ||||||
|  | 	"github.com/oklog/ulid/v2" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Post struct { | ||||||
|  | 	ID         ulid.ULID               `json:"id"` | ||||||
|  | 	Content    *string                 `json:"content"` | ||||||
|  | 	Source     *string                 `json:"source"` | ||||||
|  | 	Visibility database.PostVisibility `json:"visibility"` | ||||||
|  | 
 | ||||||
|  | 	Blog postPartialBlog `json:"blog"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type postPartialBlog struct { | ||||||
|  | 	ID     ulid.ULID `json:"id"` | ||||||
|  | 	Name   string    `json:"name"` | ||||||
|  | 	Domain *string   `json:"domain"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func DBPostToPost(p database.Post, b database.Blog) Post { | ||||||
|  | 	return Post{ | ||||||
|  | 		ID:         p.ID, | ||||||
|  | 		Content:    p.Content, | ||||||
|  | 		Source:     p.Source, | ||||||
|  | 		Visibility: p.Visibility, | ||||||
|  | 		Blog: postPartialBlog{ | ||||||
|  | 			ID:     p.BlogID, | ||||||
|  | 			Name:   b.Name, | ||||||
|  | 			Domain: b.Domain, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										91
									
								
								web/api/posts/create_post.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								web/api/posts/create_post.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,91 @@ | ||||||
|  | package posts | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/internal/database" | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/internal/database/sql" | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/internal/utils" | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/web/api" | ||||||
|  | 	"github.com/go-chi/chi/v5" | ||||||
|  | 	"github.com/go-chi/render" | ||||||
|  | 	"github.com/oklog/ulid/v2" | ||||||
|  | 	"github.com/rs/zerolog/log" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type createPostParams struct { | ||||||
|  | 	Content    string                  `json:"content"` | ||||||
|  | 	Visibility database.PostVisibility `json:"visibility"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p createPostParams) Validate(cfg database.Config) bool { | ||||||
|  | 	if p.Content == "" { // TODO: allow empty content if the post has attachments | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	if p.Visibility == "" { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if utils.StringLength(p.Content) > cfg.PostCharacterLimit { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch p.Visibility { | ||||||
|  | 	case database.PublicVisibility, database.UnlistedVisibility, database.FollowersVisibility, database.DirectVisibility: | ||||||
|  | 	default: | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// everything checks out | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (app *App) Create(w http.ResponseWriter, r *http.Request) (api.Post, error) { | ||||||
|  | 	ctx := r.Context() | ||||||
|  | 	token, _ := app.TokenFromContext(ctx) | ||||||
|  | 
 | ||||||
|  | 	var req createPostParams | ||||||
|  | 	err := render.Decode(r, &req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return api.Post{}, api.Error{Code: api.ErrBadRequest} | ||||||
|  | 	} | ||||||
|  | 	if !req.Validate(app.DBConfig.Get()) { | ||||||
|  | 		return api.Post{}, api.Error{Code: api.ErrBadRequest} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	blogID, err := ulid.Parse(chi.URLParam(r, "blogID")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return api.Post{}, api.Error{Code: api.ErrBlogNotFound} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	conn, err := app.Database.Acquire(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Err(err).Msg("acquiring connection") | ||||||
|  | 		return api.Post{}, err | ||||||
|  | 	} | ||||||
|  | 	defer conn.Release() | ||||||
|  | 
 | ||||||
|  | 	blog, err := app.Blog(conn).ByID(ctx, blogID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if err == sql.ErrNotFound { | ||||||
|  | 			return api.Post{}, api.Error{Code: api.ErrBlogNotFound} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		log.Err(err).Msg("fetching blog") | ||||||
|  | 		return api.Post{}, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if blog.AccountID != token.UserID { | ||||||
|  | 		return api.Post{}, api.Error{Code: api.ErrNotYourObject} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// create post | ||||||
|  | 	post, err := app.Post(conn).Create(ctx, blog, req.Content, req.Visibility) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Err(err).Msg("creating post") | ||||||
|  | 		return api.Post{}, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// TODO: federate post + push to websockets | ||||||
|  | 	return api.DBPostToPost(post, blog), nil | ||||||
|  | } | ||||||
							
								
								
									
										44
									
								
								web/api/posts/get_post.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								web/api/posts/get_post.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,44 @@ | ||||||
|  | package posts | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/internal/database/sql" | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/web/api" | ||||||
|  | 	"github.com/go-chi/chi/v5" | ||||||
|  | 	"github.com/oklog/ulid/v2" | ||||||
|  | 	"github.com/rs/zerolog/log" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (app *App) GetID(w http.ResponseWriter, r *http.Request) (api.Post, error) { | ||||||
|  | 	ctx := r.Context() | ||||||
|  | 	id, err := ulid.Parse(chi.URLParamFromCtx(ctx, "postID")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return api.Post{}, api.Error{Code: api.ErrBlogNotFound} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	conn, err := app.Database.Acquire(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Err(err).Msg("acquiring connection") | ||||||
|  | 		return api.Post{}, err | ||||||
|  | 	} | ||||||
|  | 	defer conn.Release() | ||||||
|  | 
 | ||||||
|  | 	post, err := app.Post(conn).ByID(ctx, id) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if err == sql.ErrNotFound { | ||||||
|  | 			return api.Post{}, api.Error{Code: api.ErrBlogNotFound} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		log.Err(err).Str("id", id.String()).Msg("fetching post from database") | ||||||
|  | 		return api.Post{}, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	blog, err := app.Blog(conn).ByID(ctx, post.BlogID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Err(err).Str("id", post.BlogID.String()).Msg("fetching blog from database") | ||||||
|  | 		return api.Post{}, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return api.DBPostToPost(post, blog), nil | ||||||
|  | } | ||||||
							
								
								
									
										13
									
								
								web/api/posts/module.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								web/api/posts/module.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,13 @@ | ||||||
|  | package posts | ||||||
|  | 
 | ||||||
|  | import "git.sleepycat.moe/sam/mercury/web/app" | ||||||
|  | 
 | ||||||
|  | type App struct { | ||||||
|  | 	*app.App | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func New(app *app.App) *App { | ||||||
|  | 	return &App{ | ||||||
|  | 		App: app, | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										47
									
								
								web/api/timelines/home_timeline.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								web/api/timelines/home_timeline.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,47 @@ | ||||||
|  | package timelines | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 	"strconv" | ||||||
|  | 
 | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/web/api" | ||||||
|  | 	"github.com/oklog/ulid/v2" | ||||||
|  | 	"github.com/rs/zerolog/log" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type timelineResponse struct { | ||||||
|  | 	Posts []api.Post `json:"posts"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (app *App) Home(w http.ResponseWriter, r *http.Request) (timelineResponse, error) { | ||||||
|  | 	ctx := r.Context() | ||||||
|  | 	token, _ := app.TokenFromContext(ctx) | ||||||
|  | 
 | ||||||
|  | 	var before, after *ulid.ULID | ||||||
|  | 	if id, err := ulid.Parse(r.FormValue("before")); err == nil { | ||||||
|  | 		before = &id | ||||||
|  | 	} | ||||||
|  | 	if id, err := ulid.Parse(r.FormValue("after")); err == nil { | ||||||
|  | 		after = &id | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if before != nil && after != nil { | ||||||
|  | 		return timelineResponse{}, api.Error{Code: api.ErrBadRequest, Details: "`before` and `after` are mutually exclusive"} | ||||||
|  | 	} | ||||||
|  | 	limit, err := strconv.Atoi(r.FormValue("limit")) | ||||||
|  | 	if err != nil || limit <= 0 || limit > 100 { | ||||||
|  | 		limit = 100 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	posts, err := app.Timeline().Home(ctx, token.UserID, limit, before, after) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Err(err).Msg("getting posts from database") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	resp := timelineResponse{Posts: make([]api.Post, len(posts))} | ||||||
|  | 	for i := range posts { | ||||||
|  | 		resp.Posts[i] = api.DBPostToPost(posts[i].Post, posts[i].Blog) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return resp, nil | ||||||
|  | } | ||||||
							
								
								
									
										13
									
								
								web/api/timelines/module.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								web/api/timelines/module.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,13 @@ | ||||||
|  | package timelines | ||||||
|  | 
 | ||||||
|  | import "git.sleepycat.moe/sam/mercury/web/app" | ||||||
|  | 
 | ||||||
|  | type App struct { | ||||||
|  | 	*app.App | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func New(app *app.App) *App { | ||||||
|  | 	return &App{ | ||||||
|  | 		App: app, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -94,6 +94,13 @@ func (a *App) Post(q ...sql.Querier) *sql.PostStore { | ||||||
| 	return sql.NewPostStore(q[0]) | 	return sql.NewPostStore(q[0]) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (a *App) Timeline(q ...sql.Querier) *sql.TimelineStore { | ||||||
|  | 	if len(q) == 0 || q[0] == nil { | ||||||
|  | 		return sql.NewTimelineStore(a.Database.PoolQuerier()) | ||||||
|  | 	} | ||||||
|  | 	return sql.NewTimelineStore(q[0]) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (a *App) Token(q ...sql.Querier) *sql.TokenStore { | func (a *App) Token(q ...sql.Querier) *sql.TokenStore { | ||||||
| 	if len(q) == 0 || q[0] == nil { | 	if len(q) == 0 || q[0] == nil { | ||||||
| 		return sql.NewTokenStore(a.Database.PoolQuerier()) | 		return sql.NewTokenStore(a.Database.PoolQuerier()) | ||||||
|  |  | ||||||
|  | @ -5,6 +5,8 @@ import ( | ||||||
| 	"git.sleepycat.moe/sam/mercury/web/api" | 	"git.sleepycat.moe/sam/mercury/web/api" | ||||||
| 	"git.sleepycat.moe/sam/mercury/web/api/accounts" | 	"git.sleepycat.moe/sam/mercury/web/api/accounts" | ||||||
| 	"git.sleepycat.moe/sam/mercury/web/api/blogs" | 	"git.sleepycat.moe/sam/mercury/web/api/blogs" | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/web/api/posts" | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/web/api/timelines" | ||||||
| 	"git.sleepycat.moe/sam/mercury/web/app" | 	"git.sleepycat.moe/sam/mercury/web/app" | ||||||
| 	"git.sleepycat.moe/sam/mercury/web/auth" | 	"git.sleepycat.moe/sam/mercury/web/auth" | ||||||
| 	"git.sleepycat.moe/sam/mercury/web/frontend" | 	"git.sleepycat.moe/sam/mercury/web/frontend" | ||||||
|  | @ -22,6 +24,37 @@ func Routes(app *app.App) { | ||||||
| 		r.Post("/sign_up", auth.PostSignup) | 		r.Post("/sign_up", auth.PostSignup) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
|  | 	// APIv1 handlers | ||||||
|  | 	app.Router.Route("/api/v1", func(r chi.Router) { | ||||||
|  | 		unauthedAccess := !app.AppConfig.Security.RestrictAPI | ||||||
|  | 		unauthedTimelineAccess := app.AppConfig.Security.PublicTimelines && !app.AppConfig.Security.RestrictAPI | ||||||
|  | 		_ = unauthedTimelineAccess | ||||||
|  | 
 | ||||||
|  | 		// account handlers | ||||||
|  | 		accounts := accounts.New(app) | ||||||
|  | 		r.With(app.APIAuth(database.TokenScopeAccountsRead, unauthedAccess)). | ||||||
|  | 			Get("/accounts/{accountID}", api.WrapHandlerT(accounts.GetID)) | ||||||
|  | 		r.With(app.APIAuth(database.TokenScopeAccountsMe, false)). | ||||||
|  | 			Get("/accounts/@me", api.WrapHandlerT(accounts.GetMe)) | ||||||
|  | 
 | ||||||
|  | 		blogs := blogs.New(app) | ||||||
|  | 		r.With(app.APIAuth(database.TokenScopeBlogsRead, unauthedAccess)). | ||||||
|  | 			Get("/blogs/{blogID}", api.WrapHandlerT(blogs.GetID)) | ||||||
|  | 		r.With(app.APIAuth(database.TokenScopeBlogsRead, unauthedAccess)). | ||||||
|  | 			Get("/blogs/lookup/{blogName}", api.WrapHandlerT(blogs.LookupName)) | ||||||
|  | 
 | ||||||
|  | 		posts := posts.New(app) | ||||||
|  | 		r.With(app.APIAuth(database.TokenScopePostsRead, unauthedAccess)). | ||||||
|  | 			Get("/posts/{postID}", api.WrapHandlerT(posts.GetID)) | ||||||
|  | 		r.With(app.APIAuth(database.TokenScopePostsWrite, false)). | ||||||
|  | 			Post("/blogs/{blogID}/posts", api.WrapHandlerT(posts.Create)) | ||||||
|  | 
 | ||||||
|  | 		timelines := timelines.New(app) | ||||||
|  | 		r.With(app.APIAuth(database.TokenScopeTimeline, false)). | ||||||
|  | 			Get("/timelines/home", api.WrapHandlerT(timelines.Home)) | ||||||
|  | 
 | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
| 	// web app handlers | 	// web app handlers | ||||||
| 	// also assets | 	// also assets | ||||||
| 	app.Router.Group(func(r chi.Router) { | 	app.Router.Group(func(r chi.Router) { | ||||||
|  | @ -35,20 +68,4 @@ func Routes(app *app.App) { | ||||||
| 		r.HandleFunc("/web/@{username}", frontend.ServeUser) | 		r.HandleFunc("/web/@{username}", frontend.ServeUser) | ||||||
| 		r.HandleFunc("/web/@{username}/posts/{postID}", frontend.ServeStatus) | 		r.HandleFunc("/web/@{username}/posts/{postID}", frontend.ServeStatus) | ||||||
| 	}) | 	}) | ||||||
| 
 |  | ||||||
| 	// APIv1 handlers |  | ||||||
| 	app.Router.Route("/api/v1", func(r chi.Router) { |  | ||||||
| 		// account handlers |  | ||||||
| 		accounts := accounts.New(app) |  | ||||||
| 		r.With(app.APIAuth(database.TokenScopeAccountsRead, true)). |  | ||||||
| 			Get("/accounts/{accountID}", api.WrapHandlerT(accounts.GetID)) |  | ||||||
| 		r.With(app.APIAuth(database.TokenScopeAccountsMe, false)). |  | ||||||
| 			Get("/accounts/@me", api.WrapHandlerT(accounts.GetMe)) |  | ||||||
| 
 |  | ||||||
| 		blogs := blogs.New(app) |  | ||||||
| 		r.With(app.APIAuth(database.TokenScopeBlogsRead, true)). |  | ||||||
| 			Get("/blogs/{blogID}", api.WrapHandlerT(blogs.GetID)) |  | ||||||
| 		r.With(app.APIAuth(database.TokenScopeBlogsRead, true)). |  | ||||||
| 			Get("/blogs/lookup/{blogName}", api.WrapHandlerT(blogs.LookupName)) |  | ||||||
| 	}) |  | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue