diff --git a/chat/src/http/ws/mod.rs b/chat/src/http/ws/mod.rs index 28643ba..1590a7d 100644 --- a/chat/src/http/ws/mod.rs +++ b/chat/src/http/ws/mod.rs @@ -8,38 +8,55 @@ use axum::{ response::Response, Extension, }; -use foxchat::s2s::{Dispatch, Payload}; -use futures::stream::{SplitSink, SplitStream, StreamExt}; -use tokio::sync::{broadcast, mpsc}; +use eyre::Result; +use foxchat::{ + s2s::{Dispatch, Payload}, + signature::{parse_date, verify_signature}, +}; +use futures::{ + stream::{SplitSink, SplitStream, StreamExt}, + SinkExt, +}; +use tokio::sync::{broadcast, mpsc, RwLock}; use tracing::error; -use crate::app_state::AppState; +use crate::{app_state::AppState, model::identity_instance::IdentityInstance}; pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension>) -> Response { ws.on_upgrade(|socket| handle_socket(state, socket)) } struct SocketState { - instance_id: Option, + instance: Option, } async fn handle_socket(app_state: Arc, socket: WebSocket) { - let (mut sender, mut receiver) = socket.split(); + let (sender, receiver) = socket.split(); let (tx, rx) = mpsc::channel::(10); let dispatch = app_state.broadcast.subscribe(); - let socket_state = Arc::new(SocketState { - instance_id: None, // Filled out after IDENTIFY - }); + let socket_state = Arc::new(RwLock::new(SocketState { + instance: None, // Filled out after IDENTIFY + })); - tokio::spawn(merge_channels(app_state.clone(), socket_state.clone(), dispatch, tx.clone())); - tokio::spawn(write(app_state.clone(), socket_state.clone(), rx, sender)); - tokio::spawn(read(app_state.clone(), socket_state.clone(), tx.clone(), receiver)); + tokio::spawn(merge_channels( + app_state.clone(), + socket_state.clone(), + dispatch, + tx.clone(), + )); + tokio::spawn(write(rx, sender)); + tokio::spawn(read( + app_state.clone(), + socket_state.clone(), + tx.clone(), + receiver, + )); } async fn read( app_state: Arc, - socket_state: Arc, + socket_state: Arc>, tx: mpsc::Sender, mut receiver: SplitStream, ) { @@ -48,26 +65,105 @@ async fn read( if let Ok(msg) = msg.into_text() { if let Ok(pl) = serde_json::from_str::(&msg) { match pl { - Payload::Identify { host, date, server, signature } => { - // TODO: identify - }, + Payload::Identify { + host, + date, + server, + signature, + } => { + let instance = match IdentityInstance::get(app_state, &server).await { + Ok(i) => i, + Err(e) => { + error!("getting instance {}: {}", server, e); + tx.send(Payload::Error { + message: "Unknown instance".into(), + }) + .await + .ok(); + return; + } + }; + + let public_key = match instance.parse_public_key() { + Ok(k) => k, + Err(e) => { + error!("parsing public key for instance: {}", e); + tx.send(Payload::Error { + message: "Internal server error".into(), + }) + .await + .ok(); + return; + } + }; + + let date = match parse_date(&date) { + Ok(d) => d, + Err(_) => { + tx.send(Payload::Error { + message: "Invalid date".into(), + }) + .await + .ok(); + return; + } + }; + + if let Err(e) = verify_signature( + &public_key, + signature, + date, + &host, + "/_fox/chat/ws", + None, + None, + ) { + error!("Verifying signature in websocket from {}: {}", &server, e); + + tx.send(Payload::Error { + message: "Invalid signature".into(), + }) + .await + .ok(); + return; + } + + // Everything is good, store instance + socket_state.write().await.instance = Some(instance); + } _ => { - tx.send(Payload::Error { message: "First payload was not IDENTIFY".into() }).await.ok(); + tx.send(Payload::Error { + message: "First payload was not IDENTIFY".into(), + }) + .await + .ok(); return; } } } else { - tx.send(Payload::Error { message: "Invalid JSON payload".into() }).await.ok(); + tx.send(Payload::Error { + message: "Invalid JSON payload".into(), + }) + .await + .ok(); return; } } else { - + tx.send(Payload::Error { + message: "First payload was not IDENTIFY".into(), + }) + .await + .ok(); + return; } } else { // Websocket closed, return return; } + // Send hello + tx.send(Payload::Hello {}).await.ok(); + while let Some(msg) = receiver.next().await { let msg = if let Ok(msg) = msg { msg @@ -79,41 +175,124 @@ async fn read( } } -async fn write( - app_state: Arc, - socket_state: Arc, - mut rx: mpsc::Receiver, - mut sender: SplitSink, -) { +async fn write(mut rx: mpsc::Receiver, mut sender: SplitSink) { + loop { + if let Some(ev) = rx.recv().await { + match &ev { + Payload::Error { message: _ } => { + let msg = match to_json(ev) { + Ok(m) => m, + Err(e) => { + error!("error serializing payload to JSON: {}", e); + return; + } + }; + + match sender.send(msg).await { + Ok(_) => {} + // Pretty sure we can assume the websocket was closed, so just return + Err(e) => { + error!("error writing to websocket: {}", e); + return; + } + } + + match sender.close().await { + Ok(_) => {} + Err(e) => { + error!("error closing websocket: {}", e) + } + } + + return; + } + _ => { + // Forward payload + let msg = match to_json(ev) { + Ok(m) => m, + Err(e) => { + error!("error serializing payload to JSON: {}", e); + return; + } + }; + + match sender.send(msg).await { + Ok(_) => {} + // Pretty sure we can assume the websocket was closed, so just return + Err(e) => { + error!("error writing to websocket: {}", e); + return; + } + } + } + } + } + } } async fn merge_channels( app_state: Arc, - socket_state: Arc, + socket_state: Arc>, mut rx: broadcast::Receiver, tx: mpsc::Sender, ) { loop { let msg = rx.recv().await; match msg { - Ok(p) => { - // TODO: filter users + Ok(evt) => { + let (send, recipients) = + match filter_events(app_state.clone(), socket_state.clone(), &evt).await { + Ok(v) => v, + Err(e) => { + error!("Checking whether to send an event: {}", e); + tx.send(Payload::Error { + message: "Internal server error".into(), + }) + .await + .ok(); + + continue; + } + }; + + if !send { + continue; + } + tx.send(Payload::Dispatch { - event: p, - recipients: vec![], + event: evt, + recipients, }) .await .ok(); - }, + } Err(e) => match e { broadcast::error::RecvError::Closed => { error!("Broadcast channel was closed, this is not supposed to happen"); return; - }, + } broadcast::error::RecvError::Lagged(i) => { error!("Broadcast receive lagged by {i}"); } - } + }, } } } + +fn to_json(p: Payload) -> Result { + Ok(Message::Text(serde_json::to_string(&p)?)) +} + +/// Returns whether or not an event should be sent at all, and if it should be, which users to send it to. +async fn filter_events( + app_state: Arc, + socket_state: Arc>, + evt: &Dispatch, +) -> Result<(bool, Vec)> { + // If we're not authenticated yet, don't send anything + if socket_state.read().await.instance.is_none() { + return Ok((false, vec![])); + } + + Ok((true, vec![])) +} diff --git a/foxchat/src/model/message.rs b/foxchat/src/model/message.rs index 05ec899..987664c 100644 --- a/foxchat/src/model/message.rs +++ b/foxchat/src/model/message.rs @@ -3,7 +3,7 @@ use serde::{Serialize, Deserialize}; use super::user::PartialUser; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct Message { pub id: String, pub channel_id: String, diff --git a/foxchat/src/model/user.rs b/foxchat/src/model/user.rs index 6268a5e..78ebef4 100644 --- a/foxchat/src/model/user.rs +++ b/foxchat/src/model/user.rs @@ -1,6 +1,6 @@ -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct User { pub id: String, pub username: String, @@ -8,7 +8,7 @@ pub struct User { pub avatar_url: Option, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct PartialUser { pub id: String, pub username: String, diff --git a/foxchat/src/s2s/dispatch.rs b/foxchat/src/s2s/dispatch.rs index ffbc441..3718331 100644 --- a/foxchat/src/s2s/dispatch.rs +++ b/foxchat/src/s2s/dispatch.rs @@ -1,7 +1,34 @@ +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use crate::{ + id::GuildType, + model::{user::PartialUser, Message}, + Id, +}; + #[derive(Debug, Serialize, Deserialize, Clone)] #[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] pub enum Dispatch { - MessageCreate + MessageCreate { + id: String, + channel_id: String, + guild_id: String, + author: PartialUser, + content: Option, + created_at: DateTime, + }, +} + +impl Dispatch { + pub fn message_create(m: Message, guild_id: Id) -> Self { + Self::MessageCreate { + id: m.id, + channel_id: m.channel_id, + guild_id: guild_id.0, + author: m.author, + content: m.content, + created_at: m.created_at, + } + } } diff --git a/foxchat/src/s2s/event.rs b/foxchat/src/s2s/event.rs index 3a0af2b..110269c 100644 --- a/foxchat/src/s2s/event.rs +++ b/foxchat/src/s2s/event.rs @@ -11,14 +11,20 @@ pub enum Payload { #[serde(rename = "r")] recipients: Vec, }, - Hello, + Error { + message: String, + }, + /// Hello message, sent after authentication succeeds + Hello {}, + /// S2S authentication. Fields correspond to headers (Host, Date, X-Foxchat-Server, X-Foxchat-Signature) Identify { host: String, date: String, server: String, signature: String, }, - Error { - message: String, + /// Sent when a user connects to the identity server's gateway, to signal the chat server to send READY for that user + Connect { + user_id: String, } }