diff --git a/go.mod b/go.mod index 2b2506c..e11343a 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/go-chi/chi/v5 v5.0.8 github.com/go-chi/render v1.0.2 github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/gorilla/websocket v1.4.0 github.com/jackc/pgx/v5 v5.0.0 github.com/keegancsmith/sqlf v1.1.1 github.com/oklog/ulid/v2 v2.1.0 diff --git a/go.sum b/go.sum index 8a8509b..e2059d2 100644 --- a/go.sum +++ b/go.sum @@ -207,6 +207,7 @@ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= diff --git a/internal/database/token.go b/internal/database/token.go index 9df4b03..ca330ec 100644 --- a/internal/database/token.go +++ b/internal/database/token.go @@ -58,6 +58,7 @@ const ( TokenScopePostsRead TokenScope = "posts.read" TokenScopePostsWrite TokenScope = "posts.write" TokenScopeTimeline TokenScope = "timeline" + TokenScopeStreaming TokenScope = "streaming" ) func (s TokenScope) IsValid() bool { diff --git a/web/api/error.go b/web/api/error.go index 9cc2b6e..3642fbe 100644 --- a/web/api/error.go +++ b/web/api/error.go @@ -134,6 +134,9 @@ const ( // Post related ErrPostNotFound = 4001 + + // Streaming related + ErrTooManyStreams = 5001 ) func ErrCodeMessage(code int) string { @@ -157,6 +160,8 @@ var errCodeMessages = map[int]string{ ErrBlogNotFound: "Blog not found", ErrPostNotFound: "Post not found", + + ErrTooManyStreams: "Too many streams open", } func ErrCodeStatus(code int) int { @@ -180,4 +185,6 @@ var errCodeStatuses = map[int]int{ ErrBlogNotFound: http.StatusNotFound, ErrPostNotFound: http.StatusNotFound, + + ErrTooManyStreams: http.StatusBadRequest, } diff --git a/web/api/streaming/messages.go b/web/api/streaming/messages.go new file mode 100644 index 0000000..0ef1697 --- /dev/null +++ b/web/api/streaming/messages.go @@ -0,0 +1,64 @@ +package streaming + +import "encoding/json" + +type EventType int8 + +const ( + EventTypeError EventType = 1 + EventTypePost EventType = 2 + + EventTypeSubscribe EventType = 126 + EventTypeUnsubscribe EventType = 127 +) + +func (et EventType) Valid() bool { + switch et { + case EventTypeError: + return true + case EventTypePost: + return true + case EventTypeSubscribe: + return true + case EventTypeUnsubscribe: + return true + default: + return false + } +} + +// Returns true if this event can be subscribed to/unsubscribed from +func (et EventType) ValidReceive() bool { + if !et.Valid() { + return false + } + + switch et { + case EventTypeError, EventTypeSubscribe, EventTypeUnsubscribe: + return false + default: + return true + } +} + +type Event struct { + Type EventType `json:"t"` + Data any `json:"d"` +} + +func newEvent(mt EventType, data any) *Event { + return &Event{ + Type: mt, + Data: data, + } +} + +type ErrorEvent struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type IncomingEvent struct { + Type EventType `json:"t"` + Data json.RawMessage `json:"d"` // this is a RawMessage so we can easily unmarshal it later +} diff --git a/web/api/streaming/module.go b/web/api/streaming/module.go new file mode 100644 index 0000000..2402598 --- /dev/null +++ b/web/api/streaming/module.go @@ -0,0 +1,110 @@ +package streaming + +import ( + "context" + "encoding/json" + "errors" + "net/http" + + "git.sleepycat.moe/sam/mercury/web/api" + "git.sleepycat.moe/sam/mercury/web/app" + "github.com/gorilla/websocket" + "github.com/rs/zerolog/log" +) + +type App struct { + *app.App +} + +func New(app *app.App) *App { + return &App{ + App: app, + } +} + +var upgrader = websocket.Upgrader{} + +func (app *App) Streaming(w http.ResponseWriter, r *http.Request) error { + token, _ := app.TokenFromContext(r.Context()) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Err(err).Msg("performing websocket handshake") + hse := websocket.HandshakeError{} + if errors.As(err, &hse) { + return nil // the upgrader already sent a response for us + } + + return err + } + + ctx, cancel := context.WithCancel(context.Background()) + socket, ok := SocketHolder.socketsFor(token.UserID).newSocket(ctx, cancel) + if !ok { + err := conn.WriteJSON(newEvent(EventTypeError, ErrorEvent{Code: api.ErrTooManyStreams, Message: "Too many streams open"})) + if err != nil { + log.Err(err).Msg("writing stream rejection message to socket") + } + + return nil + } + + go app.writeStream(conn, socket) + go app.readStream(conn, socket) + + return nil +} + +func (app *App) writeStream(conn *websocket.Conn, socket *socket) { + defer conn.Close() + + for { + select { + case <-socket.ctx.Done(): + return + case ev := <-socket.ch: + // at this point, the type should already have been filtered, so just send the event + err := conn.WriteJSON(ev) + if err != nil { + // write failed, bail and make client reconnect + log.Err(err).Msg("error writing JSON to socket") + socket.cancel() + } + } + } +} + +func (app *App) readStream(conn *websocket.Conn, socket *socket) { + for { + select { + case <-socket.ctx.Done(): + return + default: + var e IncomingEvent + err := conn.ReadJSON(&e) + if err != nil { + // read failed, bail and make client reconnect + log.Err(err).Msg("error reading JSON from socket") + socket.cancel() + return + } + + switch e.Type { + case EventTypeSubscribe, EventTypeUnsubscribe: + var et EventType + err = json.Unmarshal(e.Data, &et) + if err != nil { + // invalid event type, log but don't disconnect + log.Err(err).Msg("reading event type to subscribe to") + continue + } + + if !et.ValidReceive() { + // if it's not a valid event, ignore silently + continue + } + + socket.setEvent(et, e.Type != EventTypeSubscribe) + } + } + } +} diff --git a/web/api/streaming/sockets.go b/web/api/streaming/sockets.go new file mode 100644 index 0000000..ef6fa41 --- /dev/null +++ b/web/api/streaming/sockets.go @@ -0,0 +1,103 @@ +package streaming + +import ( + "context" + "sync" + + "github.com/oklog/ulid/v2" +) + +var SocketHolder socketHolder + +type socketHolder struct { + // map of sockets to + sockets map[ulid.ULID]*userSockets + mu sync.Mutex +} + +func (sh *socketHolder) Send(acctID ulid.ULID, et EventType, data any) { + userSockets := sh.socketsFor(acctID) + + userSockets.mu.Lock() + sockets := make([]*socket, len(userSockets.sockets)) + copy(sockets, userSockets.sockets) + userSockets.mu.Unlock() + + for _, s := range sockets { + if s.willAcceptEvent(et) { + // the socket might block for a bit, so spin this off into a separate goroutine + go func(s *socket) { + s.ch <- Event{Type: et, Data: data} + }(s) + } + } +} + +func (s *socketHolder) socketsFor(acct ulid.ULID) *userSockets { + s.mu.Lock() + defer s.mu.Unlock() + + us, ok := s.sockets[acct] + if !ok { + us = &userSockets{} + s.sockets[acct] = us + } + return us +} + +const sessionCountLimit = 50 // no more than 50 concurrent sessions per user + +type userSockets struct { + mu sync.Mutex + sockets []*socket +} + +func (s *userSockets) newSocket(ctx context.Context, cancel context.CancelFunc) (*socket, bool) { + s.mu.Lock() + if len(s.sockets) >= sessionCountLimit { + return nil, false + } + socket := newSocket(ctx, cancel) + s.sockets = append(s.sockets, socket) + return socket, true +} + +type socket struct { + ctx context.Context + cancel context.CancelFunc + + ch chan Event + types map[EventType]struct{} + + mu sync.RWMutex +} + +func (s *socket) willAcceptEvent(mt EventType) bool { + if mt == EventTypeError { + return true + } + + s.mu.RLock() + _, ok := s.types[mt] + s.mu.RUnlock() + return ok +} + +func (s *socket) setEvent(mt EventType, add bool) { + s.mu.Lock() + if add { + s.types[mt] = struct{}{} + } else { + delete(s.types, mt) + } + s.mu.Unlock() +} + +func newSocket(ctx context.Context, cancel context.CancelFunc) *socket { + return &socket{ + ctx: ctx, + cancel: cancel, + ch: make(chan Event), + types: make(map[EventType]struct{}), + } +} diff --git a/web/routes.go b/web/routes.go index fff5135..0792ceb 100644 --- a/web/routes.go +++ b/web/routes.go @@ -6,6 +6,7 @@ import ( "git.sleepycat.moe/sam/mercury/web/api/accounts" "git.sleepycat.moe/sam/mercury/web/api/blogs" "git.sleepycat.moe/sam/mercury/web/api/posts" + "git.sleepycat.moe/sam/mercury/web/api/streaming" "git.sleepycat.moe/sam/mercury/web/api/timelines" "git.sleepycat.moe/sam/mercury/web/app" "git.sleepycat.moe/sam/mercury/web/auth" @@ -53,6 +54,9 @@ func Routes(app *app.App) { r.With(app.APIAuth(database.TokenScopeTimeline, false)). Get("/timelines/home", api.WrapHandlerT(timelines.Home)) + streaming := streaming.New(app) + r.With(app.APIAuth(database.TokenScopeStreaming, false)). + Get("/streaming", api.WrapHandler(streaming.Streaming)) }) // web app handlers