add (unfinished, untested!) /api/v1/streaming endpoint
This commit is contained in:
		
							parent
							
								
									0a94e2bf93
								
							
						
					
					
						commit
						6f17b59a47
					
				
					 8 changed files with 291 additions and 0 deletions
				
			
		|  | @ -134,6 +134,9 @@ const ( | |||
| 
 | ||||
| 	// Post related | ||||
| 	ErrPostNotFound = 4001 | ||||
| 
 | ||||
| 	// Streaming related | ||||
| 	ErrTooManyStreams = 5001 | ||||
| ) | ||||
| 
 | ||||
| func ErrCodeMessage(code int) string { | ||||
|  | @ -157,6 +160,8 @@ var errCodeMessages = map[int]string{ | |||
| 	ErrBlogNotFound: "Blog not found", | ||||
| 
 | ||||
| 	ErrPostNotFound: "Post not found", | ||||
| 
 | ||||
| 	ErrTooManyStreams: "Too many streams open", | ||||
| } | ||||
| 
 | ||||
| func ErrCodeStatus(code int) int { | ||||
|  | @ -180,4 +185,6 @@ var errCodeStatuses = map[int]int{ | |||
| 	ErrBlogNotFound: http.StatusNotFound, | ||||
| 
 | ||||
| 	ErrPostNotFound: http.StatusNotFound, | ||||
| 
 | ||||
| 	ErrTooManyStreams: http.StatusBadRequest, | ||||
| } | ||||
|  |  | |||
							
								
								
									
										64
									
								
								web/api/streaming/messages.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								web/api/streaming/messages.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,64 @@ | |||
| package streaming | ||||
| 
 | ||||
| import "encoding/json" | ||||
| 
 | ||||
| type EventType int8 | ||||
| 
 | ||||
| const ( | ||||
| 	EventTypeError EventType = 1 | ||||
| 	EventTypePost  EventType = 2 | ||||
| 
 | ||||
| 	EventTypeSubscribe   EventType = 126 | ||||
| 	EventTypeUnsubscribe EventType = 127 | ||||
| ) | ||||
| 
 | ||||
| func (et EventType) Valid() bool { | ||||
| 	switch et { | ||||
| 	case EventTypeError: | ||||
| 		return true | ||||
| 	case EventTypePost: | ||||
| 		return true | ||||
| 	case EventTypeSubscribe: | ||||
| 		return true | ||||
| 	case EventTypeUnsubscribe: | ||||
| 		return true | ||||
| 	default: | ||||
| 		return false | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Returns true if this event can be subscribed to/unsubscribed from | ||||
| func (et EventType) ValidReceive() bool { | ||||
| 	if !et.Valid() { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	switch et { | ||||
| 	case EventTypeError, EventTypeSubscribe, EventTypeUnsubscribe: | ||||
| 		return false | ||||
| 	default: | ||||
| 		return true | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type Event struct { | ||||
| 	Type EventType `json:"t"` | ||||
| 	Data any       `json:"d"` | ||||
| } | ||||
| 
 | ||||
| func newEvent(mt EventType, data any) *Event { | ||||
| 	return &Event{ | ||||
| 		Type: mt, | ||||
| 		Data: data, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type ErrorEvent struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
| 
 | ||||
| type IncomingEvent struct { | ||||
| 	Type EventType       `json:"t"` | ||||
| 	Data json.RawMessage `json:"d"` // this is a RawMessage so we can easily unmarshal it later | ||||
| } | ||||
							
								
								
									
										110
									
								
								web/api/streaming/module.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								web/api/streaming/module.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,110 @@ | |||
| package streaming | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"git.sleepycat.moe/sam/mercury/web/api" | ||||
| 	"git.sleepycat.moe/sam/mercury/web/app" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
| 
 | ||||
| type App struct { | ||||
| 	*app.App | ||||
| } | ||||
| 
 | ||||
| func New(app *app.App) *App { | ||||
| 	return &App{ | ||||
| 		App: app, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| var upgrader = websocket.Upgrader{} | ||||
| 
 | ||||
| func (app *App) Streaming(w http.ResponseWriter, r *http.Request) error { | ||||
| 	token, _ := app.TokenFromContext(r.Context()) | ||||
| 	conn, err := upgrader.Upgrade(w, r, nil) | ||||
| 	if err != nil { | ||||
| 		log.Err(err).Msg("performing websocket handshake") | ||||
| 		hse := websocket.HandshakeError{} | ||||
| 		if errors.As(err, &hse) { | ||||
| 			return nil // the upgrader already sent a response for us | ||||
| 		} | ||||
| 
 | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	socket, ok := SocketHolder.socketsFor(token.UserID).newSocket(ctx, cancel) | ||||
| 	if !ok { | ||||
| 		err := conn.WriteJSON(newEvent(EventTypeError, ErrorEvent{Code: api.ErrTooManyStreams, Message: "Too many streams open"})) | ||||
| 		if err != nil { | ||||
| 			log.Err(err).Msg("writing stream rejection message to socket") | ||||
| 		} | ||||
| 
 | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	go app.writeStream(conn, socket) | ||||
| 	go app.readStream(conn, socket) | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (app *App) writeStream(conn *websocket.Conn, socket *socket) { | ||||
| 	defer conn.Close() | ||||
| 
 | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-socket.ctx.Done(): | ||||
| 			return | ||||
| 		case ev := <-socket.ch: | ||||
| 			// at this point, the type should already have been filtered, so just send the event | ||||
| 			err := conn.WriteJSON(ev) | ||||
| 			if err != nil { | ||||
| 				// write failed, bail and make client reconnect | ||||
| 				log.Err(err).Msg("error writing JSON to socket") | ||||
| 				socket.cancel() | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (app *App) readStream(conn *websocket.Conn, socket *socket) { | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-socket.ctx.Done(): | ||||
| 			return | ||||
| 		default: | ||||
| 			var e IncomingEvent | ||||
| 			err := conn.ReadJSON(&e) | ||||
| 			if err != nil { | ||||
| 				// read failed, bail and make client reconnect | ||||
| 				log.Err(err).Msg("error reading JSON from socket") | ||||
| 				socket.cancel() | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| 			switch e.Type { | ||||
| 			case EventTypeSubscribe, EventTypeUnsubscribe: | ||||
| 				var et EventType | ||||
| 				err = json.Unmarshal(e.Data, &et) | ||||
| 				if err != nil { | ||||
| 					// invalid event type, log but don't disconnect | ||||
| 					log.Err(err).Msg("reading event type to subscribe to") | ||||
| 					continue | ||||
| 				} | ||||
| 
 | ||||
| 				if !et.ValidReceive() { | ||||
| 					// if it's not a valid event, ignore silently | ||||
| 					continue | ||||
| 				} | ||||
| 
 | ||||
| 				socket.setEvent(et, e.Type != EventTypeSubscribe) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										103
									
								
								web/api/streaming/sockets.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								web/api/streaming/sockets.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,103 @@ | |||
| package streaming | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/oklog/ulid/v2" | ||||
| ) | ||||
| 
 | ||||
| var SocketHolder socketHolder | ||||
| 
 | ||||
| type socketHolder struct { | ||||
| 	// map of sockets to | ||||
| 	sockets map[ulid.ULID]*userSockets | ||||
| 	mu      sync.Mutex | ||||
| } | ||||
| 
 | ||||
| func (sh *socketHolder) Send(acctID ulid.ULID, et EventType, data any) { | ||||
| 	userSockets := sh.socketsFor(acctID) | ||||
| 
 | ||||
| 	userSockets.mu.Lock() | ||||
| 	sockets := make([]*socket, len(userSockets.sockets)) | ||||
| 	copy(sockets, userSockets.sockets) | ||||
| 	userSockets.mu.Unlock() | ||||
| 
 | ||||
| 	for _, s := range sockets { | ||||
| 		if s.willAcceptEvent(et) { | ||||
| 			// the socket might block for a bit, so spin this off into a separate goroutine | ||||
| 			go func(s *socket) { | ||||
| 				s.ch <- Event{Type: et, Data: data} | ||||
| 			}(s) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *socketHolder) socketsFor(acct ulid.ULID) *userSockets { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
| 
 | ||||
| 	us, ok := s.sockets[acct] | ||||
| 	if !ok { | ||||
| 		us = &userSockets{} | ||||
| 		s.sockets[acct] = us | ||||
| 	} | ||||
| 	return us | ||||
| } | ||||
| 
 | ||||
| const sessionCountLimit = 50 // no more than 50 concurrent sessions per user | ||||
| 
 | ||||
| type userSockets struct { | ||||
| 	mu      sync.Mutex | ||||
| 	sockets []*socket | ||||
| } | ||||
| 
 | ||||
| func (s *userSockets) newSocket(ctx context.Context, cancel context.CancelFunc) (*socket, bool) { | ||||
| 	s.mu.Lock() | ||||
| 	if len(s.sockets) >= sessionCountLimit { | ||||
| 		return nil, false | ||||
| 	} | ||||
| 	socket := newSocket(ctx, cancel) | ||||
| 	s.sockets = append(s.sockets, socket) | ||||
| 	return socket, true | ||||
| } | ||||
| 
 | ||||
| type socket struct { | ||||
| 	ctx    context.Context | ||||
| 	cancel context.CancelFunc | ||||
| 
 | ||||
| 	ch    chan Event | ||||
| 	types map[EventType]struct{} | ||||
| 
 | ||||
| 	mu sync.RWMutex | ||||
| } | ||||
| 
 | ||||
| func (s *socket) willAcceptEvent(mt EventType) bool { | ||||
| 	if mt == EventTypeError { | ||||
| 		return true | ||||
| 	} | ||||
| 
 | ||||
| 	s.mu.RLock() | ||||
| 	_, ok := s.types[mt] | ||||
| 	s.mu.RUnlock() | ||||
| 	return ok | ||||
| } | ||||
| 
 | ||||
| func (s *socket) setEvent(mt EventType, add bool) { | ||||
| 	s.mu.Lock() | ||||
| 	if add { | ||||
| 		s.types[mt] = struct{}{} | ||||
| 	} else { | ||||
| 		delete(s.types, mt) | ||||
| 	} | ||||
| 	s.mu.Unlock() | ||||
| } | ||||
| 
 | ||||
| func newSocket(ctx context.Context, cancel context.CancelFunc) *socket { | ||||
| 	return &socket{ | ||||
| 		ctx:    ctx, | ||||
| 		cancel: cancel, | ||||
| 		ch:     make(chan Event), | ||||
| 		types:  make(map[EventType]struct{}), | ||||
| 	} | ||||
| } | ||||
|  | @ -6,6 +6,7 @@ import ( | |||
| 	"git.sleepycat.moe/sam/mercury/web/api/accounts" | ||||
| 	"git.sleepycat.moe/sam/mercury/web/api/blogs" | ||||
| 	"git.sleepycat.moe/sam/mercury/web/api/posts" | ||||
| 	"git.sleepycat.moe/sam/mercury/web/api/streaming" | ||||
| 	"git.sleepycat.moe/sam/mercury/web/api/timelines" | ||||
| 	"git.sleepycat.moe/sam/mercury/web/app" | ||||
| 	"git.sleepycat.moe/sam/mercury/web/auth" | ||||
|  | @ -53,6 +54,9 @@ func Routes(app *app.App) { | |||
| 		r.With(app.APIAuth(database.TokenScopeTimeline, false)). | ||||
| 			Get("/timelines/home", api.WrapHandlerT(timelines.Home)) | ||||
| 
 | ||||
| 		streaming := streaming.New(app) | ||||
| 		r.With(app.APIAuth(database.TokenScopeStreaming, false)). | ||||
| 			Get("/streaming", api.WrapHandler(streaming.Streaming)) | ||||
| 	}) | ||||
| 
 | ||||
| 	// web app handlers | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue