use std::sync::Arc; use axum::{ extract::{ ws::{Message, WebSocket}, WebSocketUpgrade, }, response::Response, Extension, }; use eyre::{eyre, Error, Result}; use foxchat::{ s2s::{Dispatch, Payload}, signature::{parse_date, verify_signature}, }; use futures::{ stream::{SplitSink, SplitStream, StreamExt}, SinkExt, }; use tokio::sync::{broadcast, mpsc, RwLock}; use tracing::error; use crate::{app_state::AppState, model::identity_instance::IdentityInstance}; pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response { ws.on_upgrade(|socket| handle_socket(state, socket)) } struct SocketState { instance: Option<IdentityInstance>, } async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) { let (sender, receiver) = socket.split(); let (tx, rx) = mpsc::channel::<Payload>(10); let dispatch = app_state.broadcast.subscribe(); let socket_state = Arc::new(RwLock::new(SocketState { instance: None, // Filled out after IDENTIFY })); tokio::spawn(merge_channels( app_state.clone(), socket_state.clone(), dispatch, tx.clone(), )); tokio::spawn(write(rx, sender)); tokio::spawn(read( app_state.clone(), socket_state.clone(), tx.clone(), receiver, )); } async fn read( app_state: Arc<AppState>, socket_state: Arc<RwLock<SocketState>>, tx: mpsc::Sender<Payload>, mut receiver: SplitStream<WebSocket>, ) { let msg = receiver.next().await; if let Some(Ok(msg)) = msg { if let Ok(msg) = msg.into_text() { if let Ok(pl) = serde_json::from_str::<Payload>(&msg) { match pl { Payload::Identify { host, date, server, signature, } => { let instance = match IdentityInstance::get(app_state, &server).await { Ok(i) => i, Err(e) => { error!("getting instance {}: {}", server, e); tx.send(Payload::Error { message: "Unknown instance".into(), }) .await .ok(); return; } }; let public_key = match instance.parse_public_key() { Ok(k) => k, Err(e) => { error!("parsing public key for instance: {}", e); tx.send(Payload::Error { message: "Internal server error".into(), }) .await .ok(); return; } }; let date = match parse_date(&date) { Ok(d) => d, Err(_) => { tx.send(Payload::Error { message: "Invalid date".into(), }) .await .ok(); return; } }; if let Err(e) = verify_signature( &public_key, signature, date, &host, "/_fox/chat/ws", None, None, ) { error!("Verifying signature in websocket from {}: {}", &server, e); tx.send(Payload::Error { message: "Invalid signature".into(), }) .await .ok(); return; } // Everything is good, store instance socket_state.write().await.instance = Some(instance); } _ => { tx.send(Payload::Error { message: "First payload was not IDENTIFY".into(), }) .await .ok(); return; } } } else { tx.send(Payload::Error { message: "Invalid JSON payload".into(), }) .await .ok(); return; } } else { tx.send(Payload::Error { message: "First payload was not IDENTIFY".into(), }) .await .ok(); return; } } else { // Websocket closed, return return; } // Send hello tx.send(Payload::Hello {}).await.ok(); while let Some(msg) = receiver.next().await { let Ok(msg) = msg else { return; }; let Ok(msg) = msg.to_text() else { tx.send(Payload::Error { message: "Invalid message".into(), }) .await .ok(); return; }; let Ok(msg) = serde_json::from_str::<Payload>(msg) else { tx.send(Payload::Error { message: "Invalid message".into(), }) .await .ok(); return; }; match msg { Payload::Connect { user_id } => {} } // TODO: handle incoming payloads } } async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) { loop { if let Some(ev) = rx.recv().await { match &ev { Payload::Error { message: _ } => { let msg = match to_json(ev) { Ok(m) => m, Err(e) => { error!("error serializing payload to JSON: {}", e); return; } }; match sender.send(msg).await { Ok(_) => {} // Pretty sure we can assume the websocket was closed, so just return Err(e) => { error!("error writing to websocket: {}", e); return; } } match sender.close().await { Ok(_) => {} Err(e) => { error!("error closing websocket: {}", e) } } return; } _ => { // Forward payload let msg = match to_json(ev) { Ok(m) => m, Err(e) => { error!("error serializing payload to JSON: {}", e); return; } }; match sender.send(msg).await { Ok(_) => {} // Pretty sure we can assume the websocket was closed, so just return Err(e) => { error!("error writing to websocket: {}", e); return; } } } } } } } async fn merge_channels( app_state: Arc<AppState>, socket_state: Arc<RwLock<SocketState>>, mut rx: broadcast::Receiver<Dispatch>, tx: mpsc::Sender<Payload>, ) { loop { let msg = rx.recv().await; match msg { Ok(evt) => { let (send, recipients) = match filter_events(app_state.clone(), socket_state.clone(), &evt).await { Ok(v) => v, Err(e) => { error!("Checking whether to send an event: {}", e); tx.send(Payload::Error { message: "Internal server error".into(), }) .await .ok(); continue; } }; if !send { continue; } tx.send(Payload::Dispatch { event: evt, recipients, }) .await .ok(); } Err(e) => match e { broadcast::error::RecvError::Closed => { error!("Broadcast channel was closed, this is not supposed to happen"); return; } broadcast::error::RecvError::Lagged(i) => { error!("Broadcast receive lagged by {i}"); } }, } } } fn to_json(p: Payload) -> Result<Message> { Ok(Message::Text(serde_json::to_string(&p)?)) } /// Returns whether or not an event should be sent at all, and if it should be, which users to send it to. async fn filter_events( app_state: Arc<AppState>, socket_state: Arc<RwLock<SocketState>>, evt: &Dispatch, ) -> Result<(bool, Vec<String>)> { let Some(instance) = &socket_state.read().await.instance else { return Ok((false, vec![])); }; match evt { Dispatch::MessageCreate { id: _, channel_id: _, guild_id, author: _, content: _, created_at: _, } => { let users = sqlx::query!( r#"SELECT ARRAY( SELECT u.remote_user_id FROM users u JOIN guilds_users gu ON gu.user_id = u.id WHERE u.instance_id = $1 AND gu.guild_id = $2 )"#, instance.id, guild_id ) .fetch_one(&app_state.pool) .await?; if let Some(users) = users.array { return Ok((users.len() > 0, users)); } return Ok((false, vec![])); } Dispatch::Ready { user, guilds: _ } => { let user = sqlx::query!( "SELECT remote_user_id FROM users WHERE id = $1", user.id.clone() ) .fetch_one(&app_state.pool) .await?; return Ok((true, vec![user.remote_user_id])); } } } async fn collect_ready( app_state: Arc<AppState>, socket_state: Arc<RwLock<SocketState>>, user_id: String, ) -> Result<Payload> { let Some(instance) = &socket_state.read().await.instance else { return Err(eyre!("instance was None when it shouldn't be")); }; let user = sqlx::query!( "SELECT * FROM users WHERE instance_id = $1 AND remote_user_id = $2", instance.id, user_id ) .fetch_one(&app_state.pool) .await?; todo!() }