diff --git a/foxchat/src/c2s/event.rs b/foxchat/src/c2s/event.rs index f68cb28..72f316a 100644 --- a/foxchat/src/c2s/event.rs +++ b/foxchat/src/c2s/event.rs @@ -17,9 +17,20 @@ pub enum Payload { }, /// Hello message, sent after authentication succeeds Hello { + heartbeat_interval: u64, guilds: Vec, }, Identify { token: String, }, + /// Sent on a regular interval by the client, to keep the connection alive. + Heartbeat { + #[serde(rename = "t")] + timestamp: u64, + }, + /// Sent in response to a Heartbeat. + HeartbeatAck { + #[serde(rename = "t")] + timestamp: u64, + } } diff --git a/identity/src/http/mod.rs b/identity/src/http/mod.rs index 7490413..1936d8b 100644 --- a/identity/src/http/mod.rs +++ b/identity/src/http/mod.rs @@ -25,6 +25,7 @@ pub fn new(pool: Pool, config: Config, instance: Instance) -> Router { .nest("/_fox/proxy", proxy::router()) .route("/_fox/ident/node", get(node::get_node)) .route("/_fox/ident/node/:domain", get(node::get_chat_node)) + .route("/_fox/ident/ws", get(ws::handler)) .layer(TraceLayer::new_for_http()) .layer(Extension(app_state)); diff --git a/identity/src/http/ws/mod.rs b/identity/src/http/ws/mod.rs index c716906..a745e36 100644 --- a/identity/src/http/ws/mod.rs +++ b/identity/src/http/ws/mod.rs @@ -1,23 +1,274 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; -use axum::{extract::{ws::WebSocket, WebSocketUpgrade}, response::Response, Extension}; +use axum::{ + extract::{ + ws::{Message, WebSocket}, + WebSocketUpgrade, + }, + response::Response, + Extension, +}; +use eyre::Result; +use foxchat::{c2s::Payload, FoxError}; use futures::{ stream::{SplitSink, SplitStream, StreamExt}, SinkExt, }; +use rand::Rng; +use tokio::{sync::mpsc::{self, Receiver, Sender}, time::timeout}; +use tracing::error; -use crate::{app_state::AppState, model::account::Account}; +use crate::{app_state::AppState, db::check_token, model::account::Account}; pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension>) -> Response { ws.on_upgrade(|socket| handle_socket(state, socket)) } -struct SocketState { - user: Option, -} - async fn handle_socket(app_state: Arc, socket: WebSocket) { let (sender, receiver) = socket.split(); - + let (tx, rx) = mpsc::channel::(10); + + tokio::spawn(read(app_state, sender, receiver, tx, rx)); +} + +async fn read( + app_state: Arc, + mut sender: SplitSink, + mut receiver: SplitStream, + tx: Sender, + rx: Receiver, +) { + let msg = receiver.next().await; + let Some(msg) = msg else { + // Websocket was closed, so return + return; + }; + + let Ok(msg) = msg else { + let Ok(p) = to_json(Payload::Error { + message: "Internal server error".into(), + }) else { + return; + }; + + sender.send(p).await.ok(); + return; + }; + + // Convert Message into a string. This can fail if the payload is not valid UTF-8. + let Ok(msg) = msg.into_text() else { + let Ok(p) = to_json(Payload::Error { + message: "Invalid event".into(), + }) else { + return; + }; + + sender.send(p).await.ok(); + return; + }; + + // Parse payload JSON + let Ok(pl) = serde_json::from_str::(&msg) else { + let Ok(p) = to_json(Payload::Error { + message: "Invalid event".into(), + }) else { + return; + }; + + sender.send(p).await.ok(); + return; + }; + + let Payload::Identify { token } = pl else { + let Ok(p) = to_json(Payload::Error { + message: "First event was not IDENTIFY".into(), + }) else { + return; + }; + + sender.send(p).await.ok(); + return; + }; + + // Check the token. `check_token` can return an error if the token is invalid (or a database error occurs), so handle that. + let user = match check_token(&app_state.pool, token).await { + Ok(u) => u, + Err(e) => { + // The only FoxError this function can return is NotFound + if let Some(_) = e.downcast_ref::() { + let Ok(p) = to_json(Payload::Error { + message: "Invalid token".into(), + }) else { + return; + }; + + sender.send(p).await.ok(); + return; + } + + let Ok(p) = to_json(Payload::Error { + message: "Internal server error".into(), + }) else { + return; + }; + + sender.send(p).await.ok(); + return; + } + }; + + // Put the user in a comfy little box, because it'll be used by multiple tasks + let user = Arc::new(user); + // Spawn the `write` task, all writes will now go through tx.send() + tokio::spawn(write(app_state.clone(), user.clone(), sender, rx)); + + // Send HELLO event + // TODO: fetch guild IDs + let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000); + tx.send(Payload::Hello { heartbeat_interval, guilds: vec![] }).await.ok(); + + // Start the heartbeat loop + let (heartbeat_tx, heartbeat_rx) = mpsc::channel::(10); + tokio::spawn(heartbeat(heartbeat_interval, heartbeat_rx, tx.clone())); + + // Fire off ready event + tokio::spawn(collect_ready(app_state.clone(), user.clone(), tx.clone())); + + // Start listening for events + while let Some(msg) = receiver.next().await { + let Ok(msg) = msg else { + return; + }; + + let Ok(msg) = msg.to_text() else { + tx.send(Payload::Error { + message: "Invalid message".into(), + }) + .await + .ok(); + return; + }; + + let Ok(msg) = serde_json::from_str::(msg) else { + tx.send(Payload::Error { + message: "Invalid message".into(), + }) + .await + .ok(); + return; + }; + + match msg { + Payload::Heartbeat { timestamp } => { + // Heartbeats are handled in another function + heartbeat_tx.send(timestamp).await.ok(); + } + // TODO: handle other events + _ => { + tx.send(Payload::Error { + message: "Invalid send event".into(), + }) + .await + .ok(); + return; + } + } + } +} + +async fn write( + app_state: Arc, + user: Arc, + mut sender: SplitSink, + mut rx: Receiver, +) { + loop { + if let Some(ev) = rx.recv().await { + match &ev { + Payload::Error { message: _ } => { + let msg = match to_json(ev) { + Ok(m) => m, + Err(e) => { + error!("error serializing payload to JSON: {}", e); + return; + } + }; + + match sender.send(msg).await { + Ok(_) => {} + // Pretty sure we can assume the websocket was closed, so just return + Err(e) => { + error!("error writing to websocket: {}", e); + return; + } + } + + match sender.close().await { + Ok(_) => {} + Err(e) => { + error!("error closing websocket: {}", e) + } + } + + return; + } + _ => { + // Forward payload + let msg = match to_json(ev) { + Ok(m) => m, + Err(e) => { + error!("error serializing payload to JSON: {}", e); + return; + } + }; + + match sender.send(msg).await { + Ok(_) => {} + // Pretty sure we can assume the websocket was closed, so just return + Err(e) => { + error!("error writing to websocket: {}", e); + return; + } + } + } + } + } + } +} + +async fn collect_ready(app_state: Arc, user: Arc, tx: Sender) {} + +async fn heartbeat( + heartbeat_interval: u64, + mut rx: mpsc::Receiver, + tx: mpsc::Sender, +) { + // The timeout is twice the heartbeat interval. If no heartbeat is received by then, close the connection with an error. + while let Ok(i) = timeout(Duration::from_millis(heartbeat_interval * 2), rx.recv()).await { + match i { + // ACK the heartbeat with the same timestamp. TODO: validate the timestamp to make sure we aren't too out of sync with the identity server. + Some(timestamp) => { + tx.send(Payload::HeartbeatAck { timestamp }).await.ok(); + } + // If the channel returns None, that means it's been closed, which means the socket as a whole was closed + None => { + return; + } + }; + } + + // Send an error, which will automatically close the socket + tx.send(Payload::Error { + message: format!( + "Did not receive a heartbeat after {}ms", + heartbeat_interval * 2 + ), + }) + .await + .ok(); +} + +fn to_json(p: Payload) -> Result { + Ok(Message::Text(serde_json::to_string(&p)?)) }