feat: distribute new local posts to followers via websocket
This commit is contained in:
		
							parent
							
								
									5c6da51234
								
							
						
					
					
						commit
						97b05d787f
					
				
					 6 changed files with 107 additions and 75 deletions
				
			
		|  | @ -85,10 +85,16 @@ func (s *BlogStore) LocalCount(ctx context.Context) (count int64, err error) { | ||||||
| 	return count, nil | 	return count, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *BlogStore) Followers(ctx context.Context, id ulid.ULID) ([]ulid.ULID, error) { | type BlogFollower struct { | ||||||
| 	q := sqlf.Sprintf("SELECT account_id FROM account_follows WHERE blog_id = %s", id) | 	AccountID ulid.ULID | ||||||
|  | 	IsLocal   bool | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| 	follows, err := Select[ulid.ULID](ctx, s.q, q) | func (s *BlogStore) Followers(ctx context.Context, id ulid.ULID) ([]BlogFollower, error) { | ||||||
|  | 	q := sqlf.Sprintf(`SELECT f.account_id, a.domain IS NOT NULL AS is_local | ||||||
|  | 	FROM account_follows f JOIN accounts a ON a.id = f.account_id WHERE f.blog_id = %s`, id) | ||||||
|  | 
 | ||||||
|  | 	follows, err := Select[BlogFollower](ctx, s.q, q) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, errors.Wrap(err, "executing query") | 		return nil, errors.Wrap(err, "executing query") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -1,9 +1,14 @@ | ||||||
| package processor | package processor | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
| 	"git.sleepycat.moe/sam/mercury/internal/database" | 	"git.sleepycat.moe/sam/mercury/internal/database" | ||||||
| 	"git.sleepycat.moe/sam/mercury/internal/database/sql" | 	"git.sleepycat.moe/sam/mercury/internal/database/sql" | ||||||
| 	"git.sleepycat.moe/sam/mercury/internal/streaming" | 	"git.sleepycat.moe/sam/mercury/internal/streaming" | ||||||
|  | 	"git.sleepycat.moe/sam/mercury/web/api" | ||||||
|  | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Processor struct { | type Processor struct { | ||||||
|  | @ -14,7 +19,7 @@ type Processor struct { | ||||||
| 
 | 
 | ||||||
| func New(db *sql.Base) *Processor { | func New(db *sql.Base) *Processor { | ||||||
| 	p := &Processor{ | 	p := &Processor{ | ||||||
| 		SocketHolder: &streaming.SocketHolder{}, | 		SocketHolder: streaming.NewSocketHolder(), | ||||||
| 
 | 
 | ||||||
| 		db: db, | 		db: db, | ||||||
| 	} | 	} | ||||||
|  | @ -23,5 +28,54 @@ func New(db *sql.Base) *Processor { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *Processor) HandlePost(post database.Post) { | func (p *Processor) HandlePost(post database.Post) { | ||||||
|  | 	// this function is spun off in a separate goroutine, so we shouldn't use the same context as the parent request | ||||||
|  | 	// TODO: make timeout configurable? | ||||||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), time.Minute) | ||||||
|  | 	defer cancel() | ||||||
| 
 | 
 | ||||||
|  | 	conn, err := p.db.Acquire(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Err(err).Msg("acquiring database") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	defer conn.Release() | ||||||
|  | 
 | ||||||
|  | 	// get the blog and their followers | ||||||
|  | 	blog, err := sql.NewBlogStore(conn).ByID(ctx, post.BlogID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Err(err).Msg("getting blog") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	acct, err := sql.NewAccountStore(conn).ByID(ctx, blog.AccountID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Err(err).Msg("getting account") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	followers, err := sql.NewBlogStore(conn).Followers(ctx, blog.ID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Err(err).Msg("getting followers") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	p.handlePostLocal(ctx, post, blog, acct, followers) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // handlePostLocal handles a post on the local side--mostly sending notifications to websockets. | ||||||
|  | // All posts go through this function. | ||||||
|  | func (p *Processor) handlePostLocal( | ||||||
|  | 	ctx context.Context, post database.Post, | ||||||
|  | 	blog database.Blog, acct database.Account, | ||||||
|  | 	followers []sql.BlogFollower, | ||||||
|  | ) { | ||||||
|  | 	// send to self | ||||||
|  | 	apiPost := api.DBPostToPost(post, blog, acct) | ||||||
|  | 	p.SocketHolder.Send(acct.ID, streaming.EventTypePost, apiPost) | ||||||
|  | 
 | ||||||
|  | 	for _, follower := range followers { | ||||||
|  | 		if !follower.IsLocal { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		p.SocketHolder.Send(follower.AccountID, streaming.EventTypePost, apiPost) | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -5,10 +5,20 @@ import "encoding/json" | ||||||
| type EventType int8 | type EventType int8 | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | 	// Sent when an error occurs | ||||||
|  | 	// Payload: {code:int, message:string} | ||||||
| 	EventTypeError EventType = 1 | 	EventTypeError EventType = 1 | ||||||
| 	EventTypePost  EventType = 2 | 	// Sent when echoing back received messages | ||||||
|  | 	// Payload: <received message> | ||||||
|  | 	EventTypeEcho EventType = 2 | ||||||
|  | 	// Sent on a new post being created | ||||||
|  | 	// Payload: <post object> | ||||||
|  | 	EventTypePost EventType = 3 | ||||||
| 
 | 
 | ||||||
| 	EventTypeSubscribe   EventType = 126 | 	// Receive events | ||||||
|  | 	// Subscribe to a new event | ||||||
|  | 	EventTypeSubscribe EventType = 126 | ||||||
|  | 	// Unsubscribe from an event | ||||||
| 	EventTypeUnsubscribe EventType = 127 | 	EventTypeUnsubscribe EventType = 127 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -16,6 +26,8 @@ func (et EventType) Valid() bool { | ||||||
| 	switch et { | 	switch et { | ||||||
| 	case EventTypeError: | 	case EventTypeError: | ||||||
| 		return true | 		return true | ||||||
|  | 	case EventTypeEcho: | ||||||
|  | 		return true | ||||||
| 	case EventTypePost: | 	case EventTypePost: | ||||||
| 		return true | 		return true | ||||||
| 	case EventTypeSubscribe: | 	case EventTypeSubscribe: | ||||||
|  |  | ||||||
|  | @ -13,6 +13,12 @@ type SocketHolder struct { | ||||||
| 	mu      sync.Mutex | 	mu      sync.Mutex | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func NewSocketHolder() *SocketHolder { | ||||||
|  | 	return &SocketHolder{ | ||||||
|  | 		sockets: make(map[ulid.ULID]*userSockets), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (sh *SocketHolder) Send(acctID ulid.ULID, et EventType, data any) { | func (sh *SocketHolder) Send(acctID ulid.ULID, et EventType, data any) { | ||||||
| 	userSockets := sh.SocketsFor(acctID) | 	userSockets := sh.SocketsFor(acctID) | ||||||
| 
 | 
 | ||||||
|  | @ -22,7 +28,7 @@ func (sh *SocketHolder) Send(acctID ulid.ULID, et EventType, data any) { | ||||||
| 	userSockets.mu.Unlock() | 	userSockets.mu.Unlock() | ||||||
| 
 | 
 | ||||||
| 	for _, s := range sockets { | 	for _, s := range sockets { | ||||||
| 		if s.willAcceptEvent(et) { | 		if s.WillAcceptEvent(et) { | ||||||
| 			// the socket might block for a bit, so spin this off into a separate goroutine | 			// the socket might block for a bit, so spin this off into a separate goroutine | ||||||
| 			go func(s *Socket) { | 			go func(s *Socket) { | ||||||
| 				s.ch <- Event{Type: et, Data: data} | 				s.ch <- Event{Type: et, Data: data} | ||||||
|  | @ -71,7 +77,7 @@ type Socket struct { | ||||||
| 	mu sync.RWMutex | 	mu sync.RWMutex | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Socket) willAcceptEvent(mt EventType) bool { | func (s *Socket) WillAcceptEvent(mt EventType) bool { | ||||||
| 	if mt == EventTypeError { | 	if mt == EventTypeError { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
|  | @ -109,6 +115,8 @@ func NewSocket(ctx context.Context, cancel context.CancelFunc) *Socket { | ||||||
| 		ctx:    ctx, | 		ctx:    ctx, | ||||||
| 		cancel: cancel, | 		cancel: cancel, | ||||||
| 		ch:     make(chan Event), | 		ch:     make(chan Event), | ||||||
| 		types:  make(map[EventType]struct{}), | 		types: map[EventType]struct{}{ | ||||||
|  | 			EventTypeEcho: {}, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,64 +0,0 @@ | ||||||
| package streaming |  | ||||||
| 
 |  | ||||||
| import "encoding/json" |  | ||||||
| 
 |  | ||||||
| type EventType int8 |  | ||||||
| 
 |  | ||||||
| const ( |  | ||||||
| 	EventTypeError EventType = 1 |  | ||||||
| 	EventTypePost  EventType = 2 |  | ||||||
| 
 |  | ||||||
| 	EventTypeSubscribe   EventType = 126 |  | ||||||
| 	EventTypeUnsubscribe EventType = 127 |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func (et EventType) Valid() bool { |  | ||||||
| 	switch et { |  | ||||||
| 	case EventTypeError: |  | ||||||
| 		return true |  | ||||||
| 	case EventTypePost: |  | ||||||
| 		return true |  | ||||||
| 	case EventTypeSubscribe: |  | ||||||
| 		return true |  | ||||||
| 	case EventTypeUnsubscribe: |  | ||||||
| 		return true |  | ||||||
| 	default: |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Returns true if this event can be subscribed to/unsubscribed from |  | ||||||
| func (et EventType) ValidReceive() bool { |  | ||||||
| 	if !et.Valid() { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	switch et { |  | ||||||
| 	case EventTypeError, EventTypeSubscribe, EventTypeUnsubscribe: |  | ||||||
| 		return false |  | ||||||
| 	default: |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type Event struct { |  | ||||||
| 	Type EventType `json:"t"` |  | ||||||
| 	Data any       `json:"d"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func newEvent(mt EventType, data any) *Event { |  | ||||||
| 	return &Event{ |  | ||||||
| 		Type: mt, |  | ||||||
| 		Data: data, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type ErrorEvent struct { |  | ||||||
| 	Code    int    `json:"code"` |  | ||||||
| 	Message string `json:"message"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type IncomingEvent struct { |  | ||||||
| 	Type EventType       `json:"t"` |  | ||||||
| 	Data json.RawMessage `json:"d"` // this is a RawMessage so we can easily unmarshal it later |  | ||||||
| } |  | ||||||
|  | @ -41,7 +41,10 @@ func (app *App) Streaming(w http.ResponseWriter, r *http.Request) error { | ||||||
| 	ctx, cancel := context.WithCancel(context.Background()) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 	socket, ok := app.Processor.SocketHolder.SocketsFor(token.UserID).NewSocket(ctx, cancel) | 	socket, ok := app.Processor.SocketHolder.SocketsFor(token.UserID).NewSocket(ctx, cancel) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		err := conn.WriteJSON(newEvent(EventTypeError, ErrorEvent{Code: api.ErrTooManyStreams, Message: "Too many streams open"})) | 		err := conn.WriteJSON(streaming.Event{ | ||||||
|  | 			Type: streaming.EventTypeError, | ||||||
|  | 			Data: streaming.ErrorEvent{Code: api.ErrTooManyStreams, Message: "Too many streams open"}, | ||||||
|  | 		}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Err(err).Msg("writing stream rejection message to socket") | 			log.Err(err).Msg("writing stream rejection message to socket") | ||||||
| 		} | 		} | ||||||
|  | @ -100,11 +103,24 @@ func (app *App) readStream(conn *websocket.Conn, socket *streaming.Socket) { | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				if !et.ValidReceive() { | 				if !et.ValidReceive() { | ||||||
|  | 					log.Debug().Int("event", int(et)).Msg("invalid event for subscription") | ||||||
|  | 
 | ||||||
| 					// if it's not a valid event, ignore silently | 					// if it's not a valid event, ignore silently | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				socket.SetEvent(et, e.Type != streaming.EventTypeUnsubscribe) | 				socket.SetEvent(et, e.Type == streaming.EventTypeSubscribe) | ||||||
|  | 				log.Debug(). | ||||||
|  | 					Int("event", int(et)). | ||||||
|  | 					Bool("subscribed", e.Type == streaming.EventTypeSubscribe). | ||||||
|  | 					Msg("toggled subscription status for event") | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if socket.WillAcceptEvent(streaming.EventTypeEcho) { | ||||||
|  | 				conn.WriteJSON(streaming.Event{ | ||||||
|  | 					Type: streaming.EventTypeEcho, | ||||||
|  | 					Data: e, | ||||||
|  | 				}) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue