120 lines
3.3 KiB
Rust
120 lines
3.3 KiB
Rust
|
use std::sync::Arc;
|
||
|
|
||
|
use axum::{
|
||
|
extract::{
|
||
|
ws::{Message, WebSocket},
|
||
|
WebSocketUpgrade,
|
||
|
},
|
||
|
response::Response,
|
||
|
Extension,
|
||
|
};
|
||
|
use foxchat::s2s::{Dispatch, Payload};
|
||
|
use futures::stream::{SplitSink, SplitStream, StreamExt};
use futures::SinkExt;
|
||
|
use tokio::sync::{broadcast, mpsc};
|
||
|
use tracing::error;
|
||
|
|
||
|
use crate::app_state::AppState;
|
||
|
|
||
|
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
|
||
|
ws.on_upgrade(|socket| handle_socket(state, socket))
|
||
|
}
|
||
|
|
||
|
/// State shared between the tasks that serve a single websocket connection.
struct SocketState {
    /// Identifier of the remote instance.
    ///
    /// Created as `None`; intended to be filled in once the peer has sent IDENTIFY.
    instance_id: Option<String>,
}
|
||
|
|
||
|
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
||
|
let (mut sender, mut receiver) = socket.split();
|
||
|
|
||
|
let (tx, rx) = mpsc::channel::<Payload>(10);
|
||
|
let dispatch = app_state.broadcast.subscribe();
|
||
|
let socket_state = Arc::new(SocketState {
|
||
|
instance_id: None, // Filled out after IDENTIFY
|
||
|
});
|
||
|
|
||
|
tokio::spawn(merge_channels(app_state.clone(), socket_state.clone(), dispatch, tx.clone()));
|
||
|
tokio::spawn(write(app_state.clone(), socket_state.clone(), rx, sender));
|
||
|
tokio::spawn(read(app_state.clone(), socket_state.clone(), tx.clone(), receiver));
|
||
|
}
|
||
|
|
||
|
async fn read(
|
||
|
app_state: Arc<AppState>,
|
||
|
socket_state: Arc<SocketState>,
|
||
|
tx: mpsc::Sender<Payload>,
|
||
|
mut receiver: SplitStream<WebSocket>,
|
||
|
) {
|
||
|
let msg = receiver.next().await;
|
||
|
if let Some(Ok(msg)) = msg {
|
||
|
if let Ok(msg) = msg.into_text() {
|
||
|
if let Ok(pl) = serde_json::from_str::<Payload>(&msg) {
|
||
|
match pl {
|
||
|
Payload::Identify { host, date, server, signature } => {
|
||
|
// TODO: identify
|
||
|
},
|
||
|
_ => {
|
||
|
tx.send(Payload::Error { message: "First payload was not IDENTIFY".into() }).await.ok();
|
||
|
return;
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
tx.send(Payload::Error { message: "Invalid JSON payload".into() }).await.ok();
|
||
|
return;
|
||
|
}
|
||
|
} else {
|
||
|
|
||
|
}
|
||
|
} else {
|
||
|
// Websocket closed, return
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
while let Some(msg) = receiver.next().await {
|
||
|
let msg = if let Ok(msg) = msg {
|
||
|
msg
|
||
|
} else {
|
||
|
return;
|
||
|
};
|
||
|
|
||
|
// TODO: handle incoming payloads
|
||
|
}
|
||
|
}
|
||
|
|
||
|
async fn write(
|
||
|
app_state: Arc<AppState>,
|
||
|
socket_state: Arc<SocketState>,
|
||
|
mut rx: mpsc::Receiver<Payload>,
|
||
|
mut sender: SplitSink<WebSocket, Message>,
|
||
|
) {
|
||
|
}
|
||
|
|
||
|
async fn merge_channels(
|
||
|
app_state: Arc<AppState>,
|
||
|
socket_state: Arc<SocketState>,
|
||
|
mut rx: broadcast::Receiver<Dispatch>,
|
||
|
tx: mpsc::Sender<Payload>,
|
||
|
) {
|
||
|
loop {
|
||
|
let msg = rx.recv().await;
|
||
|
match msg {
|
||
|
Ok(p) => {
|
||
|
// TODO: filter users
|
||
|
tx.send(Payload::Dispatch {
|
||
|
event: p,
|
||
|
recipients: vec![],
|
||
|
})
|
||
|
.await
|
||
|
.ok();
|
||
|
},
|
||
|
Err(e) => match e {
|
||
|
broadcast::error::RecvError::Closed => {
|
||
|
error!("Broadcast channel was closed, this is not supposed to happen");
|
||
|
return;
|
||
|
},
|
||
|
broadcast::error::RecvError::Lagged(i) => {
|
||
|
error!("Broadcast receive lagged by {i}");
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|