use std::sync::Arc; use axum::{ extract::{ ws::{Message, WebSocket}, WebSocketUpgrade, }, response::Response, Extension, }; use foxchat::s2s::{Dispatch, Payload}; use futures::stream::{SplitSink, SplitStream, StreamExt}; use tokio::sync::{broadcast, mpsc}; use tracing::error; use crate::app_state::AppState; pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension>) -> Response { ws.on_upgrade(|socket| handle_socket(state, socket)) } struct SocketState { instance_id: Option, } async fn handle_socket(app_state: Arc, socket: WebSocket) { let (mut sender, mut receiver) = socket.split(); let (tx, rx) = mpsc::channel::(10); let dispatch = app_state.broadcast.subscribe(); let socket_state = Arc::new(SocketState { instance_id: None, // Filled out after IDENTIFY }); tokio::spawn(merge_channels(app_state.clone(), socket_state.clone(), dispatch, tx.clone())); tokio::spawn(write(app_state.clone(), socket_state.clone(), rx, sender)); tokio::spawn(read(app_state.clone(), socket_state.clone(), tx.clone(), receiver)); } async fn read( app_state: Arc, socket_state: Arc, tx: mpsc::Sender, mut receiver: SplitStream, ) { let msg = receiver.next().await; if let Some(Ok(msg)) = msg { if let Ok(msg) = msg.into_text() { if let Ok(pl) = serde_json::from_str::(&msg) { match pl { Payload::Identify { host, date, server, signature } => { // TODO: identify }, _ => { tx.send(Payload::Error { message: "First payload was not IDENTIFY".into() }).await.ok(); return; } } } else { tx.send(Payload::Error { message: "Invalid JSON payload".into() }).await.ok(); return; } } else { } } else { // Websocket closed, return return; } while let Some(msg) = receiver.next().await { let msg = if let Ok(msg) = msg { msg } else { return; }; // TODO: handle incoming payloads } } async fn write( app_state: Arc, socket_state: Arc, mut rx: mpsc::Receiver, mut sender: SplitSink, ) { } async fn merge_channels( app_state: Arc, socket_state: Arc, mut rx: broadcast::Receiver, tx: mpsc::Sender, ) { loop { let msg = rx.recv().await; match msg { Ok(p) => { // TODO: filter users tx.send(Payload::Dispatch { event: p, recipients: vec![], }) .await .ok(); }, Err(e) => match e { broadcast::error::RecvError::Closed => { error!("Broadcast channel was closed, this is not supposed to happen"); return; }, broadcast::error::RecvError::Lagged(i) => { error!("Broadcast receive lagged by {i}"); } } } } }