From f07333e35826ef8809bf1ce43b1b3b5cf320cc81 Mon Sep 17 00:00:00 2001 From: sam Date: Sun, 21 Jan 2024 22:30:56 +0100 Subject: [PATCH] very very unfinished websocket --- Cargo.lock | 29 +++++++++ Cargo.toml | 3 - chat/Cargo.toml | 1 + chat/src/app_state.rs | 3 + chat/src/http/mod.rs | 9 ++- chat/src/http/ws/mod.rs | 119 ++++++++++++++++++++++++++++++++++++ foxchat/src/s2s/dispatch.rs | 2 +- foxchat/src/s2s/event.rs | 8 ++- 8 files changed, 168 insertions(+), 6 deletions(-) create mode 100644 chat/src/http/ws/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 30f3071..e92aa5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -389,6 +389,7 @@ dependencies = [ "color-eyre", "eyre", "foxchat", + "futures", "rand", "reqwest", "rsa", @@ -762,6 +763,21 @@ dependencies = [ "uuid", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -806,6 +822,17 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -824,8 +851,10 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", diff --git a/Cargo.toml b/Cargo.toml index 61e16dc..23e83db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,3 @@ members = [ "chat" ] resolver = "2" - -[profile.dev.package.num-bigint-dig] -opt-level = 3 diff --git a/chat/Cargo.toml b/chat/Cargo.toml index 2f0efdf..1ac60a5 100644 --- a/chat/Cargo.toml +++ b/chat/Cargo.toml @@ -24,3 +24,4 @@ tracing = "0.1.40" tower-http = { version = "0.5.1", features = ["trace"] } chrono = "0.4.31" reqwest = { version = "0.11.23", features = ["json"] } +futures = "0.3.30" diff --git a/chat/src/app_state.rs b/chat/src/app_state.rs index 07bd75d..d14dd1c 100644 --- a/chat/src/app_state.rs +++ b/chat/src/app_state.rs @@ -1,5 +1,7 @@ +use foxchat::s2s::Dispatch; use rsa::{RsaPublicKey, RsaPrivateKey}; use sqlx::{Pool, Postgres}; +use tokio::sync::broadcast::Sender; use crate::config::Config; @@ -8,4 +10,5 @@ pub struct AppState { pub config: Config, pub public_key: RsaPublicKey, pub private_key: RsaPrivateKey, + pub broadcast: Sender, } diff --git a/chat/src/http/mod.rs b/chat/src/http/mod.rs index 00fb7c4..fe03027 100644 --- a/chat/src/http/mod.rs +++ b/chat/src/http/mod.rs @@ -1,22 +1,29 @@ mod api; mod hello; +mod ws; use crate::{app_state::AppState, config::Config, model::instance::Instance}; -use axum::{routing::post, Extension, Router}; +use axum::{routing::{post, get}, Extension, Router}; +use foxchat::s2s::Dispatch; use sqlx::{Pool, Postgres}; +use tokio::sync::broadcast; use std::sync::Arc; use tower_http::trace::TraceLayer; pub fn new(pool: Pool, config: Config, instance: Instance) -> Router { + let (broadcast, _) = broadcast::channel::(1024); + let app_state = Arc::new(AppState { pool, config, public_key: instance.public_key, private_key: instance.private_key, + broadcast, }); let app = Router::new() .route("/_fox/chat/hello", post(hello::post_hello)) + .route("/_fox/chat/ws", get(ws::handler)) .merge(api::router()) .layer(TraceLayer::new_for_http()) .layer(Extension(app_state)); diff --git a/chat/src/http/ws/mod.rs b/chat/src/http/ws/mod.rs new file mode 100644 index 0000000..28643ba --- /dev/null +++ b/chat/src/http/ws/mod.rs @@ -0,0 +1,119 @@ +use std::sync::Arc; + +use axum::{ + extract::{ + ws::{Message, WebSocket}, + WebSocketUpgrade, + }, + response::Response, + Extension, +}; +use foxchat::s2s::{Dispatch, Payload}; +use futures::stream::{SplitSink, SplitStream, StreamExt}; +use tokio::sync::{broadcast, mpsc}; +use tracing::error; + +use crate::app_state::AppState; + +pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension>) -> Response { + ws.on_upgrade(|socket| handle_socket(state, socket)) +} + +struct SocketState { + instance_id: Option, +} + +async fn handle_socket(app_state: Arc, socket: WebSocket) { + let (mut sender, mut receiver) = socket.split(); + + let (tx, rx) = mpsc::channel::(10); + let dispatch = app_state.broadcast.subscribe(); + let socket_state = Arc::new(SocketState { + instance_id: None, // Filled out after IDENTIFY + }); + + tokio::spawn(merge_channels(app_state.clone(), socket_state.clone(), dispatch, tx.clone())); + tokio::spawn(write(app_state.clone(), socket_state.clone(), rx, sender)); + tokio::spawn(read(app_state.clone(), socket_state.clone(), tx.clone(), receiver)); +} + +async fn read( + app_state: Arc, + socket_state: Arc, + tx: mpsc::Sender, + mut receiver: SplitStream, +) { + let msg = receiver.next().await; + if let Some(Ok(msg)) = msg { + if let Ok(msg) = msg.into_text() { + if let Ok(pl) = serde_json::from_str::(&msg) { + match pl { + Payload::Identify { host, date, server, signature } => { + // TODO: identify + }, + _ => { + tx.send(Payload::Error { message: "First payload was not IDENTIFY".into() }).await.ok(); + return; + } + } + } else { + tx.send(Payload::Error { message: "Invalid JSON payload".into() }).await.ok(); + return; + } + } else { + + } + } else { + // Websocket closed, return + return; + } + + while let Some(msg) = receiver.next().await { + let msg = if let Ok(msg) = msg { + msg + } else { + return; + }; + + // TODO: handle incoming payloads + } +} + +async fn write( + app_state: Arc, + socket_state: Arc, + mut rx: mpsc::Receiver, + mut sender: SplitSink, +) { +} + +async fn merge_channels( + app_state: Arc, + socket_state: Arc, + mut rx: broadcast::Receiver, + tx: mpsc::Sender, +) { + loop { + let msg = rx.recv().await; + match msg { + Ok(p) => { + // TODO: filter users + tx.send(Payload::Dispatch { + event: p, + recipients: vec![], + }) + .await + .ok(); + }, + Err(e) => match e { + broadcast::error::RecvError::Closed => { + error!("Broadcast channel was closed, this is not supposed to happen"); + return; + }, + broadcast::error::RecvError::Lagged(i) => { + error!("Broadcast receive lagged by {i}"); + } + } + } + } +} diff --git a/foxchat/src/s2s/dispatch.rs b/foxchat/src/s2s/dispatch.rs index 0213be1..ffbc441 100644 --- a/foxchat/src/s2s/dispatch.rs +++ b/foxchat/src/s2s/dispatch.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] #[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] pub enum Dispatch { MessageCreate diff --git a/foxchat/src/s2s/event.rs b/foxchat/src/s2s/event.rs index 1ba5cba..3a0af2b 100644 --- a/foxchat/src/s2s/event.rs +++ b/foxchat/src/s2s/event.rs @@ -13,6 +13,12 @@ pub enum Payload { }, Hello, Identify { - token: String, + host: String, + date: String, + server: String, + signature: String, }, + Error { + message: String, + } }