From 74614ed0e3ea9a82003b057a05dadae8c5ce6f45 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 15 Oct 2023 21:34:13 +0200 Subject: [PATCH] feat: distribute new local posts to followers via websocket --- internal/database/sql/blog.go | 12 +++++-- internal/processor/processor.go | 56 ++++++++++++++++++++++++++++- internal/streaming/event.go | 16 +++++++-- internal/streaming/sockets.go | 14 ++++++-- web/api/streaming/messages.go | 64 --------------------------------- web/api/streaming/module.go | 20 +++++++++-- 6 files changed, 107 insertions(+), 75 deletions(-) delete mode 100644 web/api/streaming/messages.go diff --git a/internal/database/sql/blog.go b/internal/database/sql/blog.go index 23200fb..44a667d 100644 --- a/internal/database/sql/blog.go +++ b/internal/database/sql/blog.go @@ -85,10 +85,16 @@ func (s *BlogStore) LocalCount(ctx context.Context) (count int64, err error) { return count, nil } -func (s *BlogStore) Followers(ctx context.Context, id ulid.ULID) ([]ulid.ULID, error) { - q := sqlf.Sprintf("SELECT account_id FROM account_follows WHERE blog_id = %s", id) +type BlogFollower struct { + AccountID ulid.ULID + IsLocal bool +} - follows, err := Select[ulid.ULID](ctx, s.q, q) +func (s *BlogStore) Followers(ctx context.Context, id ulid.ULID) ([]BlogFollower, error) { + q := sqlf.Sprintf(`SELECT f.account_id, a.domain IS NOT NULL AS is_local + FROM account_follows f JOIN accounts a ON a.id = f.account_id WHERE f.blog_id = %s`, id) + + follows, err := Select[BlogFollower](ctx, s.q, q) if err != nil { return nil, errors.Wrap(err, "executing query") } diff --git a/internal/processor/processor.go b/internal/processor/processor.go index eedf4ed..5993ee2 100644 --- a/internal/processor/processor.go +++ b/internal/processor/processor.go @@ -1,9 +1,14 @@ package processor import ( + "context" + "time" + "git.sleepycat.moe/sam/mercury/internal/database" "git.sleepycat.moe/sam/mercury/internal/database/sql" "git.sleepycat.moe/sam/mercury/internal/streaming" + "git.sleepycat.moe/sam/mercury/web/api" + "github.com/rs/zerolog/log" ) type Processor struct { @@ -14,7 +19,7 @@ type Processor struct { func New(db *sql.Base) *Processor { p := &Processor{ - SocketHolder: &streaming.SocketHolder{}, + SocketHolder: streaming.NewSocketHolder(), db: db, } @@ -23,5 +28,54 @@ func New(db *sql.Base) *Processor { } func (p *Processor) HandlePost(post database.Post) { + // this function is spun off in a separate goroutine, so we shouldn't use the same context as the parent request + // TODO: make timeout configurable? + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + conn, err := p.db.Acquire(ctx) + if err != nil { + log.Err(err).Msg("acquiring database") + return + } + defer conn.Release() + + // get the blog and their followers + blog, err := sql.NewBlogStore(conn).ByID(ctx, post.BlogID) + if err != nil { + log.Err(err).Msg("getting blog") + return + } + acct, err := sql.NewAccountStore(conn).ByID(ctx, blog.AccountID) + if err != nil { + log.Err(err).Msg("getting account") + return + } + followers, err := sql.NewBlogStore(conn).Followers(ctx, blog.ID) + if err != nil { + log.Err(err).Msg("getting followers") + return + } + + p.handlePostLocal(ctx, post, blog, acct, followers) +} + +// handlePostLocal handles a post on the local side--mostly sending notifications to websockets. +// All posts go through this function. +func (p *Processor) handlePostLocal( + ctx context.Context, post database.Post, + blog database.Blog, acct database.Account, + followers []sql.BlogFollower, +) { + // send to self + apiPost := api.DBPostToPost(post, blog, acct) + p.SocketHolder.Send(acct.ID, streaming.EventTypePost, apiPost) + + for _, follower := range followers { + if !follower.IsLocal { + continue + } + + p.SocketHolder.Send(follower.AccountID, streaming.EventTypePost, apiPost) + } } diff --git a/internal/streaming/event.go b/internal/streaming/event.go index 97b73dd..075ccd6 100644 --- a/internal/streaming/event.go +++ b/internal/streaming/event.go @@ -5,10 +5,20 @@ import "encoding/json" type EventType int8 const ( + // Sent when an error occurs + // Payload: {code:int, message:string} EventTypeError EventType = 1 - EventTypePost EventType = 2 + // Sent when echoing back received messages + // Payload: + EventTypeEcho EventType = 2 + // Sent on a new post being created + // Payload: + EventTypePost EventType = 3 - EventTypeSubscribe EventType = 126 + // Receive events + // Subscribe to a new event + EventTypeSubscribe EventType = 126 + // Unsubscribe from an event EventTypeUnsubscribe EventType = 127 ) @@ -16,6 +26,8 @@ func (et EventType) Valid() bool { switch et { case EventTypeError: return true + case EventTypeEcho: + return true case EventTypePost: return true case EventTypeSubscribe: diff --git a/internal/streaming/sockets.go b/internal/streaming/sockets.go index faf6360..fc214c5 100644 --- a/internal/streaming/sockets.go +++ b/internal/streaming/sockets.go @@ -13,6 +13,12 @@ type SocketHolder struct { mu sync.Mutex } +func NewSocketHolder() *SocketHolder { + return &SocketHolder{ + sockets: make(map[ulid.ULID]*userSockets), + } +} + func (sh *SocketHolder) Send(acctID ulid.ULID, et EventType, data any) { userSockets := sh.SocketsFor(acctID) @@ -22,7 +28,7 @@ func (sh *SocketHolder) Send(acctID ulid.ULID, et EventType, data any) { userSockets.mu.Unlock() for _, s := range sockets { - if s.willAcceptEvent(et) { + if s.WillAcceptEvent(et) { // the socket might block for a bit, so spin this off into a separate goroutine go func(s *Socket) { s.ch <- Event{Type: et, Data: data} @@ -71,7 +77,7 @@ type Socket struct { mu sync.RWMutex } -func (s *Socket) willAcceptEvent(mt EventType) bool { +func (s *Socket) WillAcceptEvent(mt EventType) bool { if mt == EventTypeError { return true } @@ -109,6 +115,8 @@ func NewSocket(ctx context.Context, cancel context.CancelFunc) *Socket { ctx: ctx, cancel: cancel, ch: make(chan Event), - types: make(map[EventType]struct{}), + types: map[EventType]struct{}{ + EventTypeEcho: {}, + }, } } diff --git a/web/api/streaming/messages.go b/web/api/streaming/messages.go deleted file mode 100644 index 0ef1697..0000000 --- a/web/api/streaming/messages.go +++ /dev/null @@ -1,64 +0,0 @@ -package streaming - -import "encoding/json" - -type EventType int8 - -const ( - EventTypeError EventType = 1 - EventTypePost EventType = 2 - - EventTypeSubscribe EventType = 126 - EventTypeUnsubscribe EventType = 127 -) - -func (et EventType) Valid() bool { - switch et { - case EventTypeError: - return true - case EventTypePost: - return true - case EventTypeSubscribe: - return true - case EventTypeUnsubscribe: - return true - default: - return false - } -} - -// Returns true if this event can be subscribed to/unsubscribed from -func (et EventType) ValidReceive() bool { - if !et.Valid() { - return false - } - - switch et { - case EventTypeError, EventTypeSubscribe, EventTypeUnsubscribe: - return false - default: - return true - } -} - -type Event struct { - Type EventType `json:"t"` - Data any `json:"d"` -} - -func newEvent(mt EventType, data any) *Event { - return &Event{ - Type: mt, - Data: data, - } -} - -type ErrorEvent struct { - Code int `json:"code"` - Message string `json:"message"` -} - -type IncomingEvent struct { - Type EventType `json:"t"` - Data json.RawMessage `json:"d"` // this is a RawMessage so we can easily unmarshal it later -} diff --git a/web/api/streaming/module.go b/web/api/streaming/module.go index cc0d81d..8024574 100644 --- a/web/api/streaming/module.go +++ b/web/api/streaming/module.go @@ -41,7 +41,10 @@ func (app *App) Streaming(w http.ResponseWriter, r *http.Request) error { ctx, cancel := context.WithCancel(context.Background()) socket, ok := app.Processor.SocketHolder.SocketsFor(token.UserID).NewSocket(ctx, cancel) if !ok { - err := conn.WriteJSON(newEvent(EventTypeError, ErrorEvent{Code: api.ErrTooManyStreams, Message: "Too many streams open"})) + err := conn.WriteJSON(streaming.Event{ + Type: streaming.EventTypeError, + Data: streaming.ErrorEvent{Code: api.ErrTooManyStreams, Message: "Too many streams open"}, + }) if err != nil { log.Err(err).Msg("writing stream rejection message to socket") } @@ -100,11 +103,24 @@ func (app *App) readStream(conn *websocket.Conn, socket *streaming.Socket) { } if !et.ValidReceive() { + log.Debug().Int("event", int(et)).Msg("invalid event for subscription") + // if it's not a valid event, ignore silently continue } - socket.SetEvent(et, e.Type != streaming.EventTypeUnsubscribe) + socket.SetEvent(et, e.Type == streaming.EventTypeSubscribe) + log.Debug(). + Int("event", int(et)). + Bool("subscribed", e.Type == streaming.EventTypeSubscribe). + Msg("toggled subscription status for event") + } + + if socket.WillAcceptEvent(streaming.EventTypeEcho) { + conn.WriteJSON(streaming.Event{ + Type: streaming.EventTypeEcho, + Data: e, + }) } } }