foxchat/chat/src/http/ws/mod.rs

120 lines
3.3 KiB
Rust
Raw Normal View History

2024-01-21 22:30:56 +01:00
use std::sync::Arc;
use axum::{
extract::{
ws::{Message, WebSocket},
WebSocketUpgrade,
},
response::Response,
Extension,
};
use foxchat::s2s::{Dispatch, Payload};
use futures::stream::{SplitSink, SplitStream, StreamExt};
use tokio::sync::{broadcast, mpsc};
use tracing::error;
use crate::app_state::AppState;
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
ws.on_upgrade(|socket| handle_socket(state, socket))
}
/// Per-connection state, shared (via `Arc`) between the tasks spawned for one socket.
struct SocketState {
    /// Identifier of the remote instance; starts out as `None`.
    ///
    /// NOTE(review): this struct is shared through an `Arc` and has no interior
    /// mutability, so the field cannot be updated after construction. Filling it in
    /// after IDENTIFY will need a lock (or similar) — TODO.
    instance_id: Option<String>,
}
/// Drives a single upgraded websocket connection.
///
/// The socket is split into its sink and stream halves, and three detached tasks are spawned:
/// - `merge_channels` receives a fresh subscription to `app_state.broadcast` plus a sender for
///   the per-connection `Payload` queue,
/// - `write` receives the receiving end of that queue and the socket sink,
/// - `read` receives another sender for the queue and the socket stream.
///
/// All three share the same `SocketState`.
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
    // Both halves are moved into their tasks, so neither binding needs to be mutable here.
    let (sender, receiver) = socket.split();
    // Bounded per-connection queue of outgoing payloads; `send().await` waits while it is full.
    let (tx, rx) = mpsc::channel::<Payload>(10);
    let dispatch = app_state.broadcast.subscribe();
    // NOTE(review): `SocketState` is shared immutably through `Arc`, so this cannot actually be
    // updated later without interior mutability.
    let socket_state = Arc::new(SocketState {
        instance_id: None, // Filled out after IDENTIFY
    });

    tokio::spawn(merge_channels(
        app_state.clone(),
        socket_state.clone(),
        dispatch,
        tx.clone(),
    ));
    tokio::spawn(write(app_state.clone(), socket_state.clone(), rx, sender));
    // Last use: hand over the remaining handles instead of cloning them once more.
    tokio::spawn(read(app_state, socket_state, tx, receiver));
}
async fn read(
app_state: Arc<AppState>,
socket_state: Arc<SocketState>,
tx: mpsc::Sender<Payload>,
mut receiver: SplitStream<WebSocket>,
) {
let msg = receiver.next().await;
if let Some(Ok(msg)) = msg {
if let Ok(msg) = msg.into_text() {
if let Ok(pl) = serde_json::from_str::<Payload>(&msg) {
match pl {
Payload::Identify { host, date, server, signature } => {
// TODO: identify
},
_ => {
tx.send(Payload::Error { message: "First payload was not IDENTIFY".into() }).await.ok();
return;
}
}
} else {
tx.send(Payload::Error { message: "Invalid JSON payload".into() }).await.ok();
return;
}
} else {
}
} else {
// Websocket closed, return
return;
}
while let Some(msg) = receiver.next().await {
let msg = if let Ok(msg) = msg {
msg
} else {
return;
};
// TODO: handle incoming payloads
}
}
/// Outgoing half of a connection: takes the receiving end of the per-connection `Payload`
/// queue (`rx`) and the websocket sink (`sender`).
///
/// NOTE(review): the body is still a stub. `rx` and `sender` are dropped as soon as this task
/// runs, which closes the queue. Every later `tx.send(..)` then fails, and no queued payload
/// (errors from `read`, dispatches from `merge_channels`) ever reaches the client.
/// Presumably this should loop on `rx.recv()`, encode each payload and send it through
/// `sender` — TODO.
async fn write(
app_state: Arc<AppState>,
socket_state: Arc<SocketState>,
mut rx: mpsc::Receiver<Payload>,
mut sender: SplitSink<WebSocket, Message>,
) {
}
async fn merge_channels(
app_state: Arc<AppState>,
socket_state: Arc<SocketState>,
mut rx: broadcast::Receiver<Dispatch>,
tx: mpsc::Sender<Payload>,
) {
loop {
let msg = rx.recv().await;
match msg {
Ok(p) => {
// TODO: filter users
tx.send(Payload::Dispatch {
event: p,
recipients: vec![],
})
.await
.ok();
},
Err(e) => match e {
broadcast::error::RecvError::Closed => {
error!("Broadcast channel was closed, this is not supposed to happen");
return;
},
broadcast::error::RecvError::Lagged(i) => {
error!("Broadcast receive lagged by {i}");
}
}
}
}
}