very very unfinished websocket
This commit is contained in:
parent
ce543e7ee1
commit
f07333e358
8 changed files with 168 additions and 6 deletions
|
@ -24,3 +24,4 @@ tracing = "0.1.40"
|
|||
tower-http = { version = "0.5.1", features = ["trace"] }
|
||||
chrono = "0.4.31"
|
||||
reqwest = { version = "0.11.23", features = ["json"] }
|
||||
futures = "0.3.30"
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use foxchat::s2s::Dispatch;
|
||||
use rsa::{RsaPublicKey, RsaPrivateKey};
|
||||
use sqlx::{Pool, Postgres};
|
||||
use tokio::sync::broadcast::Sender;
|
||||
|
||||
use crate::config::Config;
|
||||
|
||||
|
@ -8,4 +10,5 @@ pub struct AppState {
|
|||
pub config: Config,
|
||||
pub public_key: RsaPublicKey,
|
||||
pub private_key: RsaPrivateKey,
|
||||
pub broadcast: Sender<Dispatch>,
|
||||
}
|
||||
|
|
|
@ -1,22 +1,29 @@
|
|||
mod api;
|
||||
mod hello;
|
||||
mod ws;
|
||||
|
||||
use crate::{app_state::AppState, config::Config, model::instance::Instance};
|
||||
use axum::{routing::post, Extension, Router};
|
||||
use axum::{routing::{post, get}, Extension, Router};
|
||||
use foxchat::s2s::Dispatch;
|
||||
use sqlx::{Pool, Postgres};
|
||||
use tokio::sync::broadcast;
|
||||
use std::sync::Arc;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
pub fn new(pool: Pool<Postgres>, config: Config, instance: Instance) -> Router {
|
||||
let (broadcast, _) = broadcast::channel::<Dispatch>(1024);
|
||||
|
||||
let app_state = Arc::new(AppState {
|
||||
pool,
|
||||
config,
|
||||
public_key: instance.public_key,
|
||||
private_key: instance.private_key,
|
||||
broadcast,
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.route("/_fox/chat/hello", post(hello::post_hello))
|
||||
.route("/_fox/chat/ws", get(ws::handler))
|
||||
.merge(api::router())
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(Extension(app_state));
|
||||
|
|
119
chat/src/http/ws/mod.rs
Normal file
119
chat/src/http/ws/mod.rs
Normal file
|
@ -0,0 +1,119 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket},
|
||||
WebSocketUpgrade,
|
||||
},
|
||||
response::Response,
|
||||
Extension,
|
||||
};
|
||||
use foxchat::s2s::{Dispatch, Payload};
|
||||
use futures::stream::{SplitSink, SplitStream, StreamExt};
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use tracing::error;
|
||||
|
||||
use crate::app_state::AppState;
|
||||
|
||||
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
|
||||
ws.on_upgrade(|socket| handle_socket(state, socket))
|
||||
}
|
||||
|
||||
struct SocketState {
|
||||
instance_id: Option<String>,
|
||||
}
|
||||
|
||||
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
let (tx, rx) = mpsc::channel::<Payload>(10);
|
||||
let dispatch = app_state.broadcast.subscribe();
|
||||
let socket_state = Arc::new(SocketState {
|
||||
instance_id: None, // Filled out after IDENTIFY
|
||||
});
|
||||
|
||||
tokio::spawn(merge_channels(app_state.clone(), socket_state.clone(), dispatch, tx.clone()));
|
||||
tokio::spawn(write(app_state.clone(), socket_state.clone(), rx, sender));
|
||||
tokio::spawn(read(app_state.clone(), socket_state.clone(), tx.clone(), receiver));
|
||||
}
|
||||
|
||||
async fn read(
|
||||
app_state: Arc<AppState>,
|
||||
socket_state: Arc<SocketState>,
|
||||
tx: mpsc::Sender<Payload>,
|
||||
mut receiver: SplitStream<WebSocket>,
|
||||
) {
|
||||
let msg = receiver.next().await;
|
||||
if let Some(Ok(msg)) = msg {
|
||||
if let Ok(msg) = msg.into_text() {
|
||||
if let Ok(pl) = serde_json::from_str::<Payload>(&msg) {
|
||||
match pl {
|
||||
Payload::Identify { host, date, server, signature } => {
|
||||
// TODO: identify
|
||||
},
|
||||
_ => {
|
||||
tx.send(Payload::Error { message: "First payload was not IDENTIFY".into() }).await.ok();
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tx.send(Payload::Error { message: "Invalid JSON payload".into() }).await.ok();
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
|
||||
}
|
||||
} else {
|
||||
// Websocket closed, return
|
||||
return;
|
||||
}
|
||||
|
||||
while let Some(msg) = receiver.next().await {
|
||||
let msg = if let Ok(msg) = msg {
|
||||
msg
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
|
||||
// TODO: handle incoming payloads
|
||||
}
|
||||
}
|
||||
|
||||
async fn write(
|
||||
app_state: Arc<AppState>,
|
||||
socket_state: Arc<SocketState>,
|
||||
mut rx: mpsc::Receiver<Payload>,
|
||||
mut sender: SplitSink<WebSocket, Message>,
|
||||
) {
|
||||
}
|
||||
|
||||
async fn merge_channels(
|
||||
app_state: Arc<AppState>,
|
||||
socket_state: Arc<SocketState>,
|
||||
mut rx: broadcast::Receiver<Dispatch>,
|
||||
tx: mpsc::Sender<Payload>,
|
||||
) {
|
||||
loop {
|
||||
let msg = rx.recv().await;
|
||||
match msg {
|
||||
Ok(p) => {
|
||||
// TODO: filter users
|
||||
tx.send(Payload::Dispatch {
|
||||
event: p,
|
||||
recipients: vec![],
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
},
|
||||
Err(e) => match e {
|
||||
broadcast::error::RecvError::Closed => {
|
||||
error!("Broadcast channel was closed, this is not supposed to happen");
|
||||
return;
|
||||
},
|
||||
broadcast::error::RecvError::Lagged(i) => {
|
||||
error!("Broadcast receive lagged by {i}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue