feat: distribute new local posts to followers via websocket
This commit is contained in:
		
							parent
							
								
									5c6da51234
								
							
						
					
					
						commit
						97b05d787f
					
				
					 6 changed files with 107 additions and 75 deletions
				
			
		|  | @ -85,10 +85,16 @@ func (s *BlogStore) LocalCount(ctx context.Context) (count int64, err error) { | |||
| 	return count, nil | ||||
| } | ||||
| 
 | ||||
| func (s *BlogStore) Followers(ctx context.Context, id ulid.ULID) ([]ulid.ULID, error) { | ||||
| 	q := sqlf.Sprintf("SELECT account_id FROM account_follows WHERE blog_id = %s", id) | ||||
| type BlogFollower struct { | ||||
| 	AccountID ulid.ULID | ||||
| 	IsLocal   bool | ||||
| } | ||||
| 
 | ||||
| 	follows, err := Select[ulid.ULID](ctx, s.q, q) | ||||
| func (s *BlogStore) Followers(ctx context.Context, id ulid.ULID) ([]BlogFollower, error) { | ||||
| 	q := sqlf.Sprintf(`SELECT f.account_id, a.domain IS NOT NULL AS is_local | ||||
| 	FROM account_follows f JOIN accounts a ON a.id = f.account_id WHERE f.blog_id = %s`, id) | ||||
| 
 | ||||
| 	follows, err := Select[BlogFollower](ctx, s.q, q) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "executing query") | ||||
| 	} | ||||
|  |  | |||
|  | @ -1,9 +1,14 @@ | |||
| package processor | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"git.sleepycat.moe/sam/mercury/internal/database" | ||||
| 	"git.sleepycat.moe/sam/mercury/internal/database/sql" | ||||
| 	"git.sleepycat.moe/sam/mercury/internal/streaming" | ||||
| 	"git.sleepycat.moe/sam/mercury/web/api" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
| 
 | ||||
| type Processor struct { | ||||
|  | @ -14,7 +19,7 @@ type Processor struct { | |||
| 
 | ||||
| func New(db *sql.Base) *Processor { | ||||
| 	p := &Processor{ | ||||
| 		SocketHolder: &streaming.SocketHolder{}, | ||||
| 		SocketHolder: streaming.NewSocketHolder(), | ||||
| 
 | ||||
| 		db: db, | ||||
| 	} | ||||
|  | @ -23,5 +28,54 @@ func New(db *sql.Base) *Processor { | |||
| } | ||||
| 
 | ||||
| func (p *Processor) HandlePost(post database.Post) { | ||||
| 	// this function is spun off in a separate goroutine, so we shouldn't use the same context as the parent request | ||||
| 	// TODO: make timeout configurable? | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), time.Minute) | ||||
| 	defer cancel() | ||||
| 
 | ||||
| 	conn, err := p.db.Acquire(ctx) | ||||
| 	if err != nil { | ||||
| 		log.Err(err).Msg("acquiring database") | ||||
| 		return | ||||
| 	} | ||||
| 	defer conn.Release() | ||||
| 
 | ||||
| 	// get the blog and their followers | ||||
| 	blog, err := sql.NewBlogStore(conn).ByID(ctx, post.BlogID) | ||||
| 	if err != nil { | ||||
| 		log.Err(err).Msg("getting blog") | ||||
| 		return | ||||
| 	} | ||||
| 	acct, err := sql.NewAccountStore(conn).ByID(ctx, blog.AccountID) | ||||
| 	if err != nil { | ||||
| 		log.Err(err).Msg("getting account") | ||||
| 		return | ||||
| 	} | ||||
| 	followers, err := sql.NewBlogStore(conn).Followers(ctx, blog.ID) | ||||
| 	if err != nil { | ||||
| 		log.Err(err).Msg("getting followers") | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	p.handlePostLocal(ctx, post, blog, acct, followers) | ||||
| } | ||||
| 
 | ||||
| // handlePostLocal handles a post on the local side--mostly sending notifications to websockets. | ||||
| // All posts go through this function. | ||||
| func (p *Processor) handlePostLocal( | ||||
| 	ctx context.Context, post database.Post, | ||||
| 	blog database.Blog, acct database.Account, | ||||
| 	followers []sql.BlogFollower, | ||||
| ) { | ||||
| 	// send to self | ||||
| 	apiPost := api.DBPostToPost(post, blog, acct) | ||||
| 	p.SocketHolder.Send(acct.ID, streaming.EventTypePost, apiPost) | ||||
| 
 | ||||
| 	for _, follower := range followers { | ||||
| 		if !follower.IsLocal { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		p.SocketHolder.Send(follower.AccountID, streaming.EventTypePost, apiPost) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -5,10 +5,20 @@ import "encoding/json" | |||
| type EventType int8 | ||||
| 
 | ||||
| const ( | ||||
| 	// Sent when an error occurs | ||||
| 	// Payload: {code:int, message:string} | ||||
| 	EventTypeError EventType = 1 | ||||
| 	EventTypePost  EventType = 2 | ||||
| 	// Sent when echoing back received messages | ||||
| 	// Payload: <received message> | ||||
| 	EventTypeEcho EventType = 2 | ||||
| 	// Sent on a new post being created | ||||
| 	// Payload: <post object> | ||||
| 	EventTypePost EventType = 3 | ||||
| 
 | ||||
| 	// Receive events | ||||
| 	// Subscribe to a new event | ||||
| 	EventTypeSubscribe EventType = 126 | ||||
| 	// Unsubscribe from an event | ||||
| 	EventTypeUnsubscribe EventType = 127 | ||||
| ) | ||||
| 
 | ||||
|  | @ -16,6 +26,8 @@ func (et EventType) Valid() bool { | |||
| 	switch et { | ||||
| 	case EventTypeError: | ||||
| 		return true | ||||
| 	case EventTypeEcho: | ||||
| 		return true | ||||
| 	case EventTypePost: | ||||
| 		return true | ||||
| 	case EventTypeSubscribe: | ||||
|  |  | |||
|  | @ -13,6 +13,12 @@ type SocketHolder struct { | |||
| 	mu      sync.Mutex | ||||
| } | ||||
| 
 | ||||
| func NewSocketHolder() *SocketHolder { | ||||
| 	return &SocketHolder{ | ||||
| 		sockets: make(map[ulid.ULID]*userSockets), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (sh *SocketHolder) Send(acctID ulid.ULID, et EventType, data any) { | ||||
| 	userSockets := sh.SocketsFor(acctID) | ||||
| 
 | ||||
|  | @ -22,7 +28,7 @@ func (sh *SocketHolder) Send(acctID ulid.ULID, et EventType, data any) { | |||
| 	userSockets.mu.Unlock() | ||||
| 
 | ||||
| 	for _, s := range sockets { | ||||
| 		if s.willAcceptEvent(et) { | ||||
| 		if s.WillAcceptEvent(et) { | ||||
| 			// the socket might block for a bit, so spin this off into a separate goroutine | ||||
| 			go func(s *Socket) { | ||||
| 				s.ch <- Event{Type: et, Data: data} | ||||
|  | @ -71,7 +77,7 @@ type Socket struct { | |||
| 	mu sync.RWMutex | ||||
| } | ||||
| 
 | ||||
| func (s *Socket) willAcceptEvent(mt EventType) bool { | ||||
| func (s *Socket) WillAcceptEvent(mt EventType) bool { | ||||
| 	if mt == EventTypeError { | ||||
| 		return true | ||||
| 	} | ||||
|  | @ -109,6 +115,8 @@ func NewSocket(ctx context.Context, cancel context.CancelFunc) *Socket { | |||
| 		ctx:    ctx, | ||||
| 		cancel: cancel, | ||||
| 		ch:     make(chan Event), | ||||
| 		types:  make(map[EventType]struct{}), | ||||
| 		types: map[EventType]struct{}{ | ||||
| 			EventTypeEcho: {}, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -1,64 +0,0 @@ | |||
| package streaming | ||||
| 
 | ||||
| import "encoding/json" | ||||
| 
 | ||||
| type EventType int8 | ||||
| 
 | ||||
| const ( | ||||
| 	EventTypeError EventType = 1 | ||||
| 	EventTypePost  EventType = 2 | ||||
| 
 | ||||
| 	EventTypeSubscribe   EventType = 126 | ||||
| 	EventTypeUnsubscribe EventType = 127 | ||||
| ) | ||||
| 
 | ||||
| func (et EventType) Valid() bool { | ||||
| 	switch et { | ||||
| 	case EventTypeError: | ||||
| 		return true | ||||
| 	case EventTypePost: | ||||
| 		return true | ||||
| 	case EventTypeSubscribe: | ||||
| 		return true | ||||
| 	case EventTypeUnsubscribe: | ||||
| 		return true | ||||
| 	default: | ||||
| 		return false | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Returns true if this event can be subscribed to/unsubscribed from | ||||
| func (et EventType) ValidReceive() bool { | ||||
| 	if !et.Valid() { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	switch et { | ||||
| 	case EventTypeError, EventTypeSubscribe, EventTypeUnsubscribe: | ||||
| 		return false | ||||
| 	default: | ||||
| 		return true | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type Event struct { | ||||
| 	Type EventType `json:"t"` | ||||
| 	Data any       `json:"d"` | ||||
| } | ||||
| 
 | ||||
| func newEvent(mt EventType, data any) *Event { | ||||
| 	return &Event{ | ||||
| 		Type: mt, | ||||
| 		Data: data, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type ErrorEvent struct { | ||||
| 	Code    int    `json:"code"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
| 
 | ||||
| type IncomingEvent struct { | ||||
| 	Type EventType       `json:"t"` | ||||
| 	Data json.RawMessage `json:"d"` // this is a RawMessage so we can easily unmarshal it later | ||||
| } | ||||
|  | @ -41,7 +41,10 @@ func (app *App) Streaming(w http.ResponseWriter, r *http.Request) error { | |||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	socket, ok := app.Processor.SocketHolder.SocketsFor(token.UserID).NewSocket(ctx, cancel) | ||||
| 	if !ok { | ||||
| 		err := conn.WriteJSON(newEvent(EventTypeError, ErrorEvent{Code: api.ErrTooManyStreams, Message: "Too many streams open"})) | ||||
| 		err := conn.WriteJSON(streaming.Event{ | ||||
| 			Type: streaming.EventTypeError, | ||||
| 			Data: streaming.ErrorEvent{Code: api.ErrTooManyStreams, Message: "Too many streams open"}, | ||||
| 		}) | ||||
| 		if err != nil { | ||||
| 			log.Err(err).Msg("writing stream rejection message to socket") | ||||
| 		} | ||||
|  | @ -100,11 +103,24 @@ func (app *App) readStream(conn *websocket.Conn, socket *streaming.Socket) { | |||
| 				} | ||||
| 
 | ||||
| 				if !et.ValidReceive() { | ||||
| 					log.Debug().Int("event", int(et)).Msg("invalid event for subscription") | ||||
| 
 | ||||
| 					// if it's not a valid event, ignore silently | ||||
| 					continue | ||||
| 				} | ||||
| 
 | ||||
| 				socket.SetEvent(et, e.Type != streaming.EventTypeUnsubscribe) | ||||
| 				socket.SetEvent(et, e.Type == streaming.EventTypeSubscribe) | ||||
| 				log.Debug(). | ||||
| 					Int("event", int(et)). | ||||
| 					Bool("subscribed", e.Type == streaming.EventTypeSubscribe). | ||||
| 					Msg("toggled subscription status for event") | ||||
| 			} | ||||
| 
 | ||||
| 			if socket.WillAcceptEvent(streaming.EventTypeEcho) { | ||||
| 				conn.WriteJSON(streaming.Event{ | ||||
| 					Type: streaming.EventTypeEcho, | ||||
| 					Data: e, | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue