use std::sync::Arc; use axum::{ extract::{ ws::{Message, WebSocket}, WebSocketUpgrade, }, response::Response, Extension, }; use eyre::Result; use foxchat::{ s2s::{Dispatch, Payload}, signature::{parse_date, verify_signature}, }; use futures::{ stream::{SplitSink, SplitStream, StreamExt}, SinkExt, }; use tokio::sync::{broadcast, mpsc, RwLock}; use tracing::error; use crate::{app_state::AppState, model::identity_instance::IdentityInstance}; pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension>) -> Response { ws.on_upgrade(|socket| handle_socket(state, socket)) } struct SocketState { instance: Option, } async fn handle_socket(app_state: Arc, socket: WebSocket) { let (sender, receiver) = socket.split(); let (tx, rx) = mpsc::channel::(10); let dispatch = app_state.broadcast.subscribe(); let socket_state = Arc::new(RwLock::new(SocketState { instance: None, // Filled out after IDENTIFY })); tokio::spawn(merge_channels( app_state.clone(), socket_state.clone(), dispatch, tx.clone(), )); tokio::spawn(write(rx, sender)); tokio::spawn(read( app_state.clone(), socket_state.clone(), tx.clone(), receiver, )); } async fn read( app_state: Arc, socket_state: Arc>, tx: mpsc::Sender, mut receiver: SplitStream, ) { let msg = receiver.next().await; if let Some(Ok(msg)) = msg { if let Ok(msg) = msg.into_text() { if let Ok(pl) = serde_json::from_str::(&msg) { match pl { Payload::Identify { host, date, server, signature, } => { let instance = match IdentityInstance::get(app_state, &server).await { Ok(i) => i, Err(e) => { error!("getting instance {}: {}", server, e); tx.send(Payload::Error { message: "Unknown instance".into(), }) .await .ok(); return; } }; let public_key = match instance.parse_public_key() { Ok(k) => k, Err(e) => { error!("parsing public key for instance: {}", e); tx.send(Payload::Error { message: "Internal server error".into(), }) .await .ok(); return; } }; let date = match parse_date(&date) { Ok(d) => d, Err(_) => { tx.send(Payload::Error { message: "Invalid date".into(), }) .await .ok(); return; } }; if let Err(e) = verify_signature( &public_key, signature, date, &host, "/_fox/chat/ws", None, None, ) { error!("Verifying signature in websocket from {}: {}", &server, e); tx.send(Payload::Error { message: "Invalid signature".into(), }) .await .ok(); return; } // Everything is good, store instance socket_state.write().await.instance = Some(instance); } _ => { tx.send(Payload::Error { message: "First payload was not IDENTIFY".into(), }) .await .ok(); return; } } } else { tx.send(Payload::Error { message: "Invalid JSON payload".into(), }) .await .ok(); return; } } else { tx.send(Payload::Error { message: "First payload was not IDENTIFY".into(), }) .await .ok(); return; } } else { // Websocket closed, return return; } // Send hello tx.send(Payload::Hello {}).await.ok(); while let Some(msg) = receiver.next().await { let msg = if let Ok(msg) = msg { msg } else { return; }; // TODO: handle incoming payloads } } async fn write(mut rx: mpsc::Receiver, mut sender: SplitSink) { loop { if let Some(ev) = rx.recv().await { match &ev { Payload::Error { message: _ } => { let msg = match to_json(ev) { Ok(m) => m, Err(e) => { error!("error serializing payload to JSON: {}", e); return; } }; match sender.send(msg).await { Ok(_) => {} // Pretty sure we can assume the websocket was closed, so just return Err(e) => { error!("error writing to websocket: {}", e); return; } } match sender.close().await { Ok(_) => {} Err(e) => { error!("error closing websocket: {}", e) } } return; } _ => { // Forward payload let msg = match to_json(ev) { Ok(m) => m, Err(e) => { error!("error serializing payload to JSON: {}", e); return; } }; match sender.send(msg).await { Ok(_) => {} // Pretty sure we can assume the websocket was closed, so just return Err(e) => { error!("error writing to websocket: {}", e); return; } } } } } } } async fn merge_channels( app_state: Arc, socket_state: Arc>, mut rx: broadcast::Receiver, tx: mpsc::Sender, ) { loop { let msg = rx.recv().await; match msg { Ok(evt) => { let (send, recipients) = match filter_events(app_state.clone(), socket_state.clone(), &evt).await { Ok(v) => v, Err(e) => { error!("Checking whether to send an event: {}", e); tx.send(Payload::Error { message: "Internal server error".into(), }) .await .ok(); continue; } }; if !send { continue; } tx.send(Payload::Dispatch { event: evt, recipients, }) .await .ok(); } Err(e) => match e { broadcast::error::RecvError::Closed => { error!("Broadcast channel was closed, this is not supposed to happen"); return; } broadcast::error::RecvError::Lagged(i) => { error!("Broadcast receive lagged by {i}"); } }, } } } fn to_json(p: Payload) -> Result { Ok(Message::Text(serde_json::to_string(&p)?)) } /// Returns whether or not an event should be sent at all, and if it should be, which users to send it to. async fn filter_events( app_state: Arc, socket_state: Arc>, evt: &Dispatch, ) -> Result<(bool, Vec)> { // If we're not authenticated yet, don't send anything if socket_state.read().await.instance.is_none() { return Ok((false, vec![])); } Ok((true, vec![])) }