diff --git a/chat/src/http/ws/mod.rs b/chat/src/http/ws/mod.rs index a3557e1..e38eb4c 100644 --- a/chat/src/http/ws/mod.rs +++ b/chat/src/http/ws/mod.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use axum::{ extract::{ @@ -8,8 +8,9 @@ use axum::{ response::Response, Extension, }; -use eyre::{eyre, Error, Result}; +use eyre::{eyre, Result}; use foxchat::{ + model::User, s2s::{Dispatch, Payload}, signature::{parse_date, verify_signature}, }; @@ -17,7 +18,11 @@ use futures::{ stream::{SplitSink, SplitStream, StreamExt}, SinkExt, }; -use tokio::sync::{broadcast, mpsc, RwLock}; +use rand::Rng; +use tokio::{ + sync::{broadcast, mpsc, RwLock}, + time::timeout, +}; use tracing::error; use crate::{app_state::AppState, model::identity_instance::IdentityInstance}; @@ -28,6 +33,7 @@ pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension, + last_heartbeat: u64, } async fn handle_socket(app_state: Arc, socket: WebSocket) { @@ -36,16 +42,21 @@ async fn handle_socket(app_state: Arc, socket: WebSocket) { let (tx, rx) = mpsc::channel::(10); let dispatch = app_state.broadcast.subscribe(); let socket_state = Arc::new(RwLock::new(SocketState { - instance: None, // Filled out after IDENTIFY + // These are filled out after IDENTIFY + instance: None, + last_heartbeat: 0, })); + // Function that merges the socket-local and broadcast streams for sending. tokio::spawn(merge_channels( app_state.clone(), socket_state.clone(), dispatch, tx.clone(), )); + // Function that writes all payloads to the socket. tokio::spawn(write(rx, sender)); + // Function that reads from the socket. tokio::spawn(read( app_state.clone(), socket_state.clone(), @@ -71,7 +82,8 @@ async fn read( server, signature, } => { - let instance = match IdentityInstance::get(app_state, &server).await { + let instance = match IdentityInstance::get(app_state.clone(), &server).await + { Ok(i) => i, Err(e) => { error!("getting instance {}: {}", server, e); @@ -161,8 +173,17 @@ async fn read( return; } - // Send hello - tx.send(Payload::Hello {}).await.ok(); + // Generate heartbeat interval and send it in the Hello payload + let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000); + tx.send(Payload::Hello { heartbeat_interval }).await.ok(); + // Start the heartbeat loop + let (heartbeat_tx, heartbeat_rx) = mpsc::channel::(10); + tokio::spawn(heartbeat( + socket_state.clone(), + heartbeat_interval, + heartbeat_rx, + tx.clone(), + )); while let Some(msg) = receiver.next().await { let Ok(msg) = msg else { @@ -188,13 +209,69 @@ async fn read( }; match msg { - Payload::Connect { user_id } => {} + Payload::Connect { user_id } => { + match collect_ready(app_state.clone(), socket_state.clone(), &user_id).await { + Ok(p) => { + tx.send(p).await.ok(); + } + Err(e) => { + error!("Error collecting ready event data for {}: {}", user_id, e); + tx.send(Payload::Error { + message: "Internal server error".into(), + }) + .await + .ok(); + } + } + } + Payload::Heartbeat { timestamp } => { + // Heartbeats are handled in another function + heartbeat_tx.send(timestamp).await.ok(); + } + _ => { + tx.send(Payload::Error { + message: "Invalid send event".into(), + }) + .await + .ok(); + return; + } } - - // TODO: handle incoming payloads } } +async fn heartbeat( + socket_state: Arc>, + heartbeat_interval: u64, + mut rx: mpsc::Receiver, + tx: mpsc::Sender, +) { + // The timeout is twice the heartbeat interval. If no heartbeat is received by then, close the connection with an error. + while let Ok(i) = timeout(Duration::from_millis(heartbeat_interval * 2), rx.recv()).await { + match i { + // ACK the heartbeat with the same timestamp. TODO: validate the timestamp to make sure we aren't too out of sync with the identity server. + Some(timestamp) => { + socket_state.write().await.last_heartbeat = timestamp; + tx.send(Payload::HeartbeatAck { timestamp }).await.ok(); + } + // If the channel returns None, that means it's been closed, which means the socket as a whole was closed + None => { + return; + } + }; + } + + // Send an error, which will automatically close the socket + tx.send(Payload::Error { + message: format!( + "Did not receive a heartbeat after {}ms", + heartbeat_interval * 2 + ), + }) + .await + .ok(); +} + async fn write(mut rx: mpsc::Receiver, mut sender: SplitSink) { loop { if let Some(ev) = rx.recv().await { @@ -355,19 +432,32 @@ async fn filter_events( async fn collect_ready( app_state: Arc, socket_state: Arc>, - user_id: String, + user_id: &str, ) -> Result { let Some(instance) = &socket_state.read().await.instance else { return Err(eyre!("instance was None when it shouldn't be")); }; let user = sqlx::query!( - "SELECT * FROM users WHERE instance_id = $1 AND remote_user_id = $2", + r#"SELECT u.*, i.domain FROM users u + JOIN identity_instances i ON i.id = u.instance_id + WHERE u.instance_id = $1 AND u.remote_user_id = $2"#, instance.id, user_id ) .fetch_one(&app_state.pool) .await?; - todo!() + Ok(Payload::Dispatch { + event: Dispatch::Ready { + user: User { + id: user.id, + username: user.username, + instance: user.domain, + avatar_url: None, + }, + guilds: vec![], + }, + recipients: vec![user.remote_user_id], + }) } diff --git a/foxchat/src/s2s/event.rs b/foxchat/src/s2s/event.rs index 110269c..fb866a6 100644 --- a/foxchat/src/s2s/event.rs +++ b/foxchat/src/s2s/event.rs @@ -5,6 +5,7 @@ use super::Dispatch; #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] pub enum Payload { + #[serde(rename = "D")] Dispatch { #[serde(rename = "e")] event: Dispatch, @@ -15,7 +16,9 @@ pub enum Payload { message: String, }, /// Hello message, sent after authentication succeeds - Hello {}, + Hello { + heartbeat_interval: u64, + }, /// S2S authentication. Fields correspond to headers (Host, Date, X-Foxchat-Server, X-Foxchat-Signature) Identify { host: String, @@ -26,5 +29,15 @@ pub enum Payload { /// Sent when a user connects to the identity server's gateway, to signal the chat server to send READY for that user Connect { user_id: String, + }, + /// Sent on a regular interval by the connecting server, to keep the connection alive. + Heartbeat { + #[serde(rename = "t")] + timestamp: u64, + }, + /// Sent in response to a Heartbeat. + HeartbeatAck { + #[serde(rename = "t")] + timestamp: u64, } }