very very unfinished websocket

This commit is contained in:
sam 2024-01-21 22:30:56 +01:00
parent ce543e7ee1
commit f07333e358
8 changed files with 168 additions and 6 deletions

29
Cargo.lock generated
View file

@ -389,6 +389,7 @@ dependencies = [
"color-eyre", "color-eyre",
"eyre", "eyre",
"foxchat", "foxchat",
"futures",
"rand", "rand",
"reqwest", "reqwest",
"rsa", "rsa",
@ -762,6 +763,21 @@ dependencies = [
"uuid", "uuid",
] ]
[[package]]
name = "futures"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.30" version = "0.3.30"
@ -806,6 +822,17 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
[[package]]
name = "futures-macro"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
]
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.30" version = "0.3.30"
@ -824,8 +851,10 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
dependencies = [ dependencies = [
"futures-channel",
"futures-core", "futures-core",
"futures-io", "futures-io",
"futures-macro",
"futures-sink", "futures-sink",
"futures-task", "futures-task",
"memchr", "memchr",

View file

@ -5,6 +5,3 @@ members = [
"chat" "chat"
] ]
resolver = "2" resolver = "2"
[profile.dev.package.num-bigint-dig]
opt-level = 3

View file

@ -24,3 +24,4 @@ tracing = "0.1.40"
tower-http = { version = "0.5.1", features = ["trace"] } tower-http = { version = "0.5.1", features = ["trace"] }
chrono = "0.4.31" chrono = "0.4.31"
reqwest = { version = "0.11.23", features = ["json"] } reqwest = { version = "0.11.23", features = ["json"] }
futures = "0.3.30"

View file

@ -1,5 +1,7 @@
use foxchat::s2s::Dispatch;
use rsa::{RsaPublicKey, RsaPrivateKey}; use rsa::{RsaPublicKey, RsaPrivateKey};
use sqlx::{Pool, Postgres}; use sqlx::{Pool, Postgres};
use tokio::sync::broadcast::Sender;
use crate::config::Config; use crate::config::Config;
@ -8,4 +10,5 @@ pub struct AppState {
pub config: Config, pub config: Config,
pub public_key: RsaPublicKey, pub public_key: RsaPublicKey,
pub private_key: RsaPrivateKey, pub private_key: RsaPrivateKey,
pub broadcast: Sender<Dispatch>,
} }

View file

@ -1,22 +1,29 @@
mod api; mod api;
mod hello; mod hello;
mod ws;
use crate::{app_state::AppState, config::Config, model::instance::Instance}; use crate::{app_state::AppState, config::Config, model::instance::Instance};
use axum::{routing::post, Extension, Router}; use axum::{routing::{post, get}, Extension, Router};
use foxchat::s2s::Dispatch;
use sqlx::{Pool, Postgres}; use sqlx::{Pool, Postgres};
use tokio::sync::broadcast;
use std::sync::Arc; use std::sync::Arc;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
pub fn new(pool: Pool<Postgres>, config: Config, instance: Instance) -> Router { pub fn new(pool: Pool<Postgres>, config: Config, instance: Instance) -> Router {
let (broadcast, _) = broadcast::channel::<Dispatch>(1024);
let app_state = Arc::new(AppState { let app_state = Arc::new(AppState {
pool, pool,
config, config,
public_key: instance.public_key, public_key: instance.public_key,
private_key: instance.private_key, private_key: instance.private_key,
broadcast,
}); });
let app = Router::new() let app = Router::new()
.route("/_fox/chat/hello", post(hello::post_hello)) .route("/_fox/chat/hello", post(hello::post_hello))
.route("/_fox/chat/ws", get(ws::handler))
.merge(api::router()) .merge(api::router())
.layer(TraceLayer::new_for_http()) .layer(TraceLayer::new_for_http())
.layer(Extension(app_state)); .layer(Extension(app_state));

119
chat/src/http/ws/mod.rs Normal file
View file

@ -0,0 +1,119 @@
use std::sync::Arc;
use axum::{
extract::{
ws::{Message, WebSocket},
WebSocketUpgrade,
},
response::Response,
Extension,
};
use foxchat::s2s::{Dispatch, Payload};
use futures::stream::{SplitSink, SplitStream, StreamExt};
use tokio::sync::{broadcast, mpsc};
use tracing::error;
use crate::app_state::AppState;
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
ws.on_upgrade(|socket| handle_socket(state, socket))
}
struct SocketState {
instance_id: Option<String>,
}
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
let (mut sender, mut receiver) = socket.split();
let (tx, rx) = mpsc::channel::<Payload>(10);
let dispatch = app_state.broadcast.subscribe();
let socket_state = Arc::new(SocketState {
instance_id: None, // Filled out after IDENTIFY
});
tokio::spawn(merge_channels(app_state.clone(), socket_state.clone(), dispatch, tx.clone()));
tokio::spawn(write(app_state.clone(), socket_state.clone(), rx, sender));
tokio::spawn(read(app_state.clone(), socket_state.clone(), tx.clone(), receiver));
}
async fn read(
app_state: Arc<AppState>,
socket_state: Arc<SocketState>,
tx: mpsc::Sender<Payload>,
mut receiver: SplitStream<WebSocket>,
) {
let msg = receiver.next().await;
if let Some(Ok(msg)) = msg {
if let Ok(msg) = msg.into_text() {
if let Ok(pl) = serde_json::from_str::<Payload>(&msg) {
match pl {
Payload::Identify { host, date, server, signature } => {
// TODO: identify
},
_ => {
tx.send(Payload::Error { message: "First payload was not IDENTIFY".into() }).await.ok();
return;
}
}
} else {
tx.send(Payload::Error { message: "Invalid JSON payload".into() }).await.ok();
return;
}
} else {
}
} else {
// Websocket closed, return
return;
}
while let Some(msg) = receiver.next().await {
let msg = if let Ok(msg) = msg {
msg
} else {
return;
};
// TODO: handle incoming payloads
}
}
async fn write(
app_state: Arc<AppState>,
socket_state: Arc<SocketState>,
mut rx: mpsc::Receiver<Payload>,
mut sender: SplitSink<WebSocket, Message>,
) {
}
async fn merge_channels(
app_state: Arc<AppState>,
socket_state: Arc<SocketState>,
mut rx: broadcast::Receiver<Dispatch>,
tx: mpsc::Sender<Payload>,
) {
loop {
let msg = rx.recv().await;
match msg {
Ok(p) => {
// TODO: filter users
tx.send(Payload::Dispatch {
event: p,
recipients: vec![],
})
.await
.ok();
},
Err(e) => match e {
broadcast::error::RecvError::Closed => {
error!("Broadcast channel was closed, this is not supposed to happen");
return;
},
broadcast::error::RecvError::Lagged(i) => {
error!("Broadcast receive lagged by {i}");
}
}
}
}
}

View file

@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] #[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Dispatch { pub enum Dispatch {
MessageCreate MessageCreate

View file

@ -13,6 +13,12 @@ pub enum Payload {
}, },
Hello, Hello,
Identify { Identify {
token: String, host: String,
date: String,
server: String,
signature: String,
}, },
Error {
message: String,
}
} }