diff --git a/Cargo.lock b/Cargo.lock index e92aa5a..d5a7535 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1172,6 +1172,7 @@ dependencies = [ "color-eyre", "eyre", "foxchat", + "futures", "rand", "reqwest", "rsa", @@ -1180,6 +1181,7 @@ dependencies = [ "sha256", "sqlx", "tokio", + "tokio-tungstenite", "toml", "tower-http", "tracing", @@ -1833,10 +1835,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ "ring", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] +[[package]] +name = "rustls" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" +dependencies = [ + "log", + "ring", + "rustls-pki-types", + "rustls-webpki 0.102.1", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -1846,6 +1862,12 @@ dependencies = [ "base64", ] +[[package]] +name = "rustls-pki-types" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "048a63e5b3ac996d78d402940b5fa47973d2d080c6c6fffa1d0f19c4445310b7" + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -1856,6 +1878,17 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustls-webpki" +version = "0.102.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef4ca26037c909dedb327b48c3327d0ba91d3dd3c4e05dad328f210ffb68e95b" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -2135,7 +2168,7 @@ dependencies = [ "once_cell", "paste", "percent-encoding", - "rustls", + "rustls 0.21.10", "rustls-pemfile", "serde", "serde_json", @@ -2484,6 +2517,7 @@ checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" dependencies = [ "futures-util", "log", + "rustls 0.22.2", "tokio", "tungstenite", ] diff --git a/chat/src/http/api/guilds/create_guild.rs b/chat/src/http/api/guilds/create_guild.rs index d506707..c98065a 100644 --- a/chat/src/http/api/guilds/create_guild.rs +++ b/chat/src/http/api/guilds/create_guild.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use axum::{Extension, Json}; use foxchat::{ http::ApiError, - model::{http::guild::CreateGuildParams, user::PartialUser, Guild, channel::PartialChannel}, + model::{channel::PartialChannel, http::guild::CreateGuildParams, user::PartialUser, Guild}, FoxError, }; @@ -39,8 +39,12 @@ pub async fn post_guilds( instance: user.instance.domain, }, default_channel: PartialChannel { + id: channel.id.0.clone(), + name: channel.name.clone(), + }, + channels: Some(vec![PartialChannel { id: channel.id.0, name: channel.name, - } + }]), })) } diff --git a/chat/src/http/ws/mod.rs b/chat/src/http/ws/mod.rs index 1590a7d..e38eb4c 100644 --- a/chat/src/http/ws/mod.rs +++ b/chat/src/http/ws/mod.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use axum::{ extract::{ @@ -8,8 +8,9 @@ use axum::{ response::Response, Extension, }; -use eyre::Result; +use eyre::{eyre, Result}; use foxchat::{ + model::User, s2s::{Dispatch, Payload}, signature::{parse_date, verify_signature}, }; @@ -17,7 +18,11 @@ use futures::{ stream::{SplitSink, SplitStream, StreamExt}, SinkExt, }; -use tokio::sync::{broadcast, mpsc, RwLock}; +use rand::Rng; +use tokio::{ + sync::{broadcast, mpsc, RwLock}, + time::timeout, +}; use tracing::error; use crate::{app_state::AppState, model::identity_instance::IdentityInstance}; @@ -28,6 +33,7 @@ pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension, + last_heartbeat: u64, } async fn handle_socket(app_state: Arc, socket: WebSocket) { @@ -36,16 +42,21 @@ async fn handle_socket(app_state: Arc, socket: WebSocket) { let (tx, rx) = mpsc::channel::(10); let dispatch = app_state.broadcast.subscribe(); let socket_state = Arc::new(RwLock::new(SocketState { - instance: None, // Filled out after IDENTIFY + // These are filled out after IDENTIFY + instance: None, + last_heartbeat: 0, })); + // Function that merges the socket-local and broadcast streams for sending. tokio::spawn(merge_channels( app_state.clone(), socket_state.clone(), dispatch, tx.clone(), )); + // Function that writes all payloads to the socket. tokio::spawn(write(rx, sender)); + // Function that reads from the socket. tokio::spawn(read( app_state.clone(), socket_state.clone(), @@ -71,7 +82,8 @@ async fn read( server, signature, } => { - let instance = match IdentityInstance::get(app_state, &server).await { + let instance = match IdentityInstance::get(app_state.clone(), &server).await + { Ok(i) => i, Err(e) => { error!("getting instance {}: {}", server, e); @@ -161,20 +173,105 @@ async fn read( return; } - // Send hello - tx.send(Payload::Hello {}).await.ok(); + // Generate heartbeat interval and send it in the Hello payload + let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000); + tx.send(Payload::Hello { heartbeat_interval }).await.ok(); + // Start the heartbeat loop + let (heartbeat_tx, heartbeat_rx) = mpsc::channel::(10); + tokio::spawn(heartbeat( + socket_state.clone(), + heartbeat_interval, + heartbeat_rx, + tx.clone(), + )); while let Some(msg) = receiver.next().await { - let msg = if let Ok(msg) = msg { - msg - } else { + let Ok(msg) = msg else { return; }; - // TODO: handle incoming payloads + let Ok(msg) = msg.to_text() else { + tx.send(Payload::Error { + message: "Invalid message".into(), + }) + .await + .ok(); + return; + }; + + let Ok(msg) = serde_json::from_str::(msg) else { + tx.send(Payload::Error { + message: "Invalid message".into(), + }) + .await + .ok(); + return; + }; + + match msg { + Payload::Connect { user_id } => { + match collect_ready(app_state.clone(), socket_state.clone(), &user_id).await { + Ok(p) => { + tx.send(p).await.ok(); + } + Err(e) => { + error!("Error collecting ready event data for {}: {}", user_id, e); + tx.send(Payload::Error { + message: "Internal server error".into(), + }) + .await + .ok(); + } + } + } + Payload::Heartbeat { timestamp } => { + // Heartbeats are handled in another function + heartbeat_tx.send(timestamp).await.ok(); + } + _ => { + tx.send(Payload::Error { + message: "Invalid send event".into(), + }) + .await + .ok(); + return; + } + } } } +async fn heartbeat( + socket_state: Arc>, + heartbeat_interval: u64, + mut rx: mpsc::Receiver, + tx: mpsc::Sender, +) { + // The timeout is twice the heartbeat interval. If no heartbeat is received by then, close the connection with an error. + while let Ok(i) = timeout(Duration::from_millis(heartbeat_interval * 2), rx.recv()).await { + match i { + // ACK the heartbeat with the same timestamp. TODO: validate the timestamp to make sure we aren't too out of sync with the identity server. + Some(timestamp) => { + socket_state.write().await.last_heartbeat = timestamp; + tx.send(Payload::HeartbeatAck { timestamp }).await.ok(); + } + // If the channel returns None, that means it's been closed, which means the socket as a whole was closed + None => { + return; + } + }; + } + + // Send an error, which will automatically close the socket + tx.send(Payload::Error { + message: format!( + "Did not receive a heartbeat after {}ms", + heartbeat_interval * 2 + ), + }) + .await + .ok(); +} + async fn write(mut rx: mpsc::Receiver, mut sender: SplitSink) { loop { if let Some(ev) = rx.recv().await { @@ -289,10 +386,78 @@ async fn filter_events( socket_state: Arc>, evt: &Dispatch, ) -> Result<(bool, Vec)> { - // If we're not authenticated yet, don't send anything - if socket_state.read().await.instance.is_none() { + let Some(instance) = &socket_state.read().await.instance else { return Ok((false, vec![])); - } + }; - Ok((true, vec![])) + match evt { + Dispatch::MessageCreate { + id: _, + channel_id: _, + guild_id, + author: _, + content: _, + created_at: _, + } => { + let users = sqlx::query!( + r#"SELECT ARRAY( + SELECT u.remote_user_id FROM users u + JOIN guilds_users gu ON gu.user_id = u.id + WHERE u.instance_id = $1 AND gu.guild_id = $2 + )"#, + instance.id, + guild_id + ) + .fetch_one(&app_state.pool) + .await?; + + if let Some(users) = users.array { + return Ok((users.len() > 0, users)); + } + return Ok((false, vec![])); + } + Dispatch::Ready { user, guilds: _ } => { + let user = sqlx::query!( + "SELECT remote_user_id FROM users WHERE id = $1", + user.id.clone() + ) + .fetch_one(&app_state.pool) + .await?; + + return Ok((true, vec![user.remote_user_id])); + } + } +} + +async fn collect_ready( + app_state: Arc, + socket_state: Arc>, + user_id: &str, +) -> Result { + let Some(instance) = &socket_state.read().await.instance else { + return Err(eyre!("instance was None when it shouldn't be")); + }; + + let user = sqlx::query!( + r#"SELECT u.*, i.domain FROM users u + JOIN identity_instances i ON i.id = u.instance_id + WHERE u.instance_id = $1 AND u.remote_user_id = $2"#, + instance.id, + user_id + ) + .fetch_one(&app_state.pool) + .await?; + + Ok(Payload::Dispatch { + event: Dispatch::Ready { + user: User { + id: user.id, + username: user.username, + instance: user.domain, + avatar_url: None, + }, + guilds: vec![], + }, + recipients: vec![user.remote_user_id], + }) } diff --git a/foxchat/src/c2s/event.rs b/foxchat/src/c2s/event.rs new file mode 100644 index 0000000..f68cb28 --- /dev/null +++ b/foxchat/src/c2s/event.rs @@ -0,0 +1,25 @@ +use serde::{Deserialize, Serialize}; + +use crate::s2s::Dispatch; + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] +pub enum Payload { + #[serde(rename = "D")] + Dispatch { + #[serde(rename = "e")] + event: Dispatch, + #[serde(rename = "s")] + server_id: String, + }, + Error { + message: String, + }, + /// Hello message, sent after authentication succeeds + Hello { + guilds: Vec, + }, + Identify { + token: String, + }, +} diff --git a/foxchat/src/c2s/mod.rs b/foxchat/src/c2s/mod.rs new file mode 100644 index 0000000..47b0373 --- /dev/null +++ b/foxchat/src/c2s/mod.rs @@ -0,0 +1,3 @@ +pub mod event; + +pub use event::Payload; diff --git a/foxchat/src/lib.rs b/foxchat/src/lib.rs index 03ad1c1..b30451f 100644 --- a/foxchat/src/lib.rs +++ b/foxchat/src/lib.rs @@ -3,6 +3,7 @@ pub mod fed; pub mod http; pub mod model; pub mod s2s; +pub mod c2s; pub mod id; pub use error::FoxError; diff --git a/foxchat/src/model/channel.rs b/foxchat/src/model/channel.rs index 65be52a..e5d5c8f 100644 --- a/foxchat/src/model/channel.rs +++ b/foxchat/src/model/channel.rs @@ -8,7 +8,7 @@ pub struct Channel { pub topic: Option, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct PartialChannel { pub id: String, pub name: String, diff --git a/foxchat/src/model/guild.rs b/foxchat/src/model/guild.rs index 11c0159..7518f5e 100644 --- a/foxchat/src/model/guild.rs +++ b/foxchat/src/model/guild.rs @@ -1,11 +1,12 @@ use serde::{Serialize, Deserialize}; -use super::{user::PartialUser, channel::PartialChannel}; +use super::{channel::PartialChannel, user::PartialUser}; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct Guild { pub id: String, pub name: String, pub owner: PartialUser, pub default_channel: PartialChannel, + pub channels: Option>, } diff --git a/foxchat/src/s2s/dispatch.rs b/foxchat/src/s2s/dispatch.rs index 3718331..7b30f0d 100644 --- a/foxchat/src/s2s/dispatch.rs +++ b/foxchat/src/s2s/dispatch.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use crate::{ id::GuildType, - model::{user::PartialUser, Message}, + model::{user::PartialUser, Guild, Message, User}, Id, }; @@ -18,6 +18,10 @@ pub enum Dispatch { content: Option, created_at: DateTime, }, + Ready { + user: User, + guilds: Vec, + } } impl Dispatch { diff --git a/foxchat/src/s2s/event.rs b/foxchat/src/s2s/event.rs index 110269c..fb866a6 100644 --- a/foxchat/src/s2s/event.rs +++ b/foxchat/src/s2s/event.rs @@ -5,6 +5,7 @@ use super::Dispatch; #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] pub enum Payload { + #[serde(rename = "D")] Dispatch { #[serde(rename = "e")] event: Dispatch, @@ -15,7 +16,9 @@ pub enum Payload { message: String, }, /// Hello message, sent after authentication succeeds - Hello {}, + Hello { + heartbeat_interval: u64, + }, /// S2S authentication. Fields correspond to headers (Host, Date, X-Foxchat-Server, X-Foxchat-Signature) Identify { host: String, @@ -26,5 +29,15 @@ pub enum Payload { /// Sent when a user connects to the identity server's gateway, to signal the chat server to send READY for that user Connect { user_id: String, + }, + /// Sent on a regular interval by the connecting server, to keep the connection alive. + Heartbeat { + #[serde(rename = "t")] + timestamp: u64, + }, + /// Sent in response to a Heartbeat. + HeartbeatAck { + #[serde(rename = "t")] + timestamp: u64, } } diff --git a/identity/Cargo.toml b/identity/Cargo.toml index da31cf7..e9b6254 100644 --- a/identity/Cargo.toml +++ b/identity/Cargo.toml @@ -27,3 +27,5 @@ base64 = "0.21.7" sha256 = "1.5.0" reqwest = { version = "0.11.23", features = ["json", "gzip", "brotli", "multipart"] } chrono = "0.4.31" +futures = "0.3.30" +tokio-tungstenite = { version = "0.21.0", features = ["rustls"] } diff --git a/identity/src/http/mod.rs b/identity/src/http/mod.rs index 926283b..7490413 100644 --- a/identity/src/http/mod.rs +++ b/identity/src/http/mod.rs @@ -2,6 +2,7 @@ mod account; mod auth; mod node; mod proxy; +mod ws; use std::sync::Arc; diff --git a/identity/src/http/ws/mod.rs b/identity/src/http/ws/mod.rs new file mode 100644 index 0000000..c716906 --- /dev/null +++ b/identity/src/http/ws/mod.rs @@ -0,0 +1,23 @@ +use std::sync::Arc; + +use axum::{extract::{ws::WebSocket, WebSocketUpgrade}, response::Response, Extension}; +use futures::{ + stream::{SplitSink, SplitStream, StreamExt}, + SinkExt, +}; + +use crate::{app_state::AppState, model::account::Account}; + +pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension>) -> Response { + ws.on_upgrade(|socket| handle_socket(state, socket)) +} + +struct SocketState { + user: Option, +} + +async fn handle_socket(app_state: Arc, socket: WebSocket) { + let (sender, receiver) = socket.split(); + + +}