feat: add c2s websocket to identity server
This commit is contained in:
parent
809af7e637
commit
42abd70184
3 changed files with 271 additions and 8 deletions
|
@ -17,9 +17,20 @@ pub enum Payload {
|
|||
},
|
||||
/// Hello message, sent after authentication succeeds
|
||||
Hello {
|
||||
heartbeat_interval: u64,
|
||||
guilds: Vec<String>,
|
||||
},
|
||||
Identify {
|
||||
token: String,
|
||||
},
|
||||
/// Sent on a regular interval by the client, to keep the connection alive.
|
||||
Heartbeat {
|
||||
#[serde(rename = "t")]
|
||||
timestamp: u64,
|
||||
},
|
||||
/// Sent in response to a Heartbeat.
|
||||
HeartbeatAck {
|
||||
#[serde(rename = "t")]
|
||||
timestamp: u64,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ pub fn new(pool: Pool<Postgres>, config: Config, instance: Instance) -> Router {
|
|||
.nest("/_fox/proxy", proxy::router())
|
||||
.route("/_fox/ident/node", get(node::get_node))
|
||||
.route("/_fox/ident/node/:domain", get(node::get_chat_node))
|
||||
.route("/_fox/ident/ws", get(ws::handler))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(Extension(app_state));
|
||||
|
||||
|
|
|
@ -1,23 +1,274 @@
|
|||
use std::sync::Arc;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use axum::{extract::{ws::WebSocket, WebSocketUpgrade}, response::Response, Extension};
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket},
|
||||
WebSocketUpgrade,
|
||||
},
|
||||
response::Response,
|
||||
Extension,
|
||||
};
|
||||
use eyre::Result;
|
||||
use foxchat::{c2s::Payload, FoxError};
|
||||
use futures::{
|
||||
stream::{SplitSink, SplitStream, StreamExt},
|
||||
SinkExt,
|
||||
};
|
||||
use rand::Rng;
|
||||
use tokio::{sync::mpsc::{self, Receiver, Sender}, time::timeout};
|
||||
use tracing::error;
|
||||
|
||||
use crate::{app_state::AppState, model::account::Account};
|
||||
use crate::{app_state::AppState, db::check_token, model::account::Account};
|
||||
|
||||
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
|
||||
ws.on_upgrade(|socket| handle_socket(state, socket))
|
||||
}
|
||||
|
||||
struct SocketState {
|
||||
user: Option<Account>,
|
||||
}
|
||||
|
||||
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
||||
let (sender, receiver) = socket.split();
|
||||
|
||||
let (tx, rx) = mpsc::channel::<Payload>(10);
|
||||
|
||||
tokio::spawn(read(app_state, sender, receiver, tx, rx));
|
||||
}
|
||||
|
||||
async fn read(
|
||||
app_state: Arc<AppState>,
|
||||
mut sender: SplitSink<WebSocket, Message>,
|
||||
mut receiver: SplitStream<WebSocket>,
|
||||
tx: Sender<Payload>,
|
||||
rx: Receiver<Payload>,
|
||||
) {
|
||||
let msg = receiver.next().await;
|
||||
let Some(msg) = msg else {
|
||||
// Websocket was closed, so return
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(msg) = msg else {
|
||||
let Ok(p) = to_json(Payload::Error {
|
||||
message: "Internal server error".into(),
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
|
||||
sender.send(p).await.ok();
|
||||
return;
|
||||
};
|
||||
|
||||
// Convert Message into a string. This can fail if the payload is not valid UTF-8.
|
||||
let Ok(msg) = msg.into_text() else {
|
||||
let Ok(p) = to_json(Payload::Error {
|
||||
message: "Invalid event".into(),
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
|
||||
sender.send(p).await.ok();
|
||||
return;
|
||||
};
|
||||
|
||||
// Parse payload JSON
|
||||
let Ok(pl) = serde_json::from_str::<Payload>(&msg) else {
|
||||
let Ok(p) = to_json(Payload::Error {
|
||||
message: "Invalid event".into(),
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
|
||||
sender.send(p).await.ok();
|
||||
return;
|
||||
};
|
||||
|
||||
let Payload::Identify { token } = pl else {
|
||||
let Ok(p) = to_json(Payload::Error {
|
||||
message: "First event was not IDENTIFY".into(),
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
|
||||
sender.send(p).await.ok();
|
||||
return;
|
||||
};
|
||||
|
||||
// Check the token. `check_token` can return an error if the token is invalid (or a database error occurs), so handle that.
|
||||
let user = match check_token(&app_state.pool, token).await {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
// The only FoxError this function can return is NotFound
|
||||
if let Some(_) = e.downcast_ref::<FoxError>() {
|
||||
let Ok(p) = to_json(Payload::Error {
|
||||
message: "Invalid token".into(),
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
|
||||
sender.send(p).await.ok();
|
||||
return;
|
||||
}
|
||||
|
||||
let Ok(p) = to_json(Payload::Error {
|
||||
message: "Internal server error".into(),
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
|
||||
sender.send(p).await.ok();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Put the user in a comfy little box, because it'll be used by multiple tasks
|
||||
let user = Arc::new(user);
|
||||
// Spawn the `write` task, all writes will now go through tx.send()
|
||||
tokio::spawn(write(app_state.clone(), user.clone(), sender, rx));
|
||||
|
||||
// Send HELLO event
|
||||
// TODO: fetch guild IDs
|
||||
let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000);
|
||||
tx.send(Payload::Hello { heartbeat_interval, guilds: vec![] }).await.ok();
|
||||
|
||||
// Start the heartbeat loop
|
||||
let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10);
|
||||
tokio::spawn(heartbeat(heartbeat_interval, heartbeat_rx, tx.clone()));
|
||||
|
||||
// Fire off ready event
|
||||
tokio::spawn(collect_ready(app_state.clone(), user.clone(), tx.clone()));
|
||||
|
||||
// Start listening for events
|
||||
while let Some(msg) = receiver.next().await {
|
||||
let Ok(msg) = msg else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(msg) = msg.to_text() else {
|
||||
tx.send(Payload::Error {
|
||||
message: "Invalid message".into(),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(msg) = serde_json::from_str::<Payload>(msg) else {
|
||||
tx.send(Payload::Error {
|
||||
message: "Invalid message".into(),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return;
|
||||
};
|
||||
|
||||
match msg {
|
||||
Payload::Heartbeat { timestamp } => {
|
||||
// Heartbeats are handled in another function
|
||||
heartbeat_tx.send(timestamp).await.ok();
|
||||
}
|
||||
// TODO: handle other events
|
||||
_ => {
|
||||
tx.send(Payload::Error {
|
||||
message: "Invalid send event".into(),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn write(
|
||||
app_state: Arc<AppState>,
|
||||
user: Arc<Account>,
|
||||
mut sender: SplitSink<WebSocket, Message>,
|
||||
mut rx: Receiver<Payload>,
|
||||
) {
|
||||
loop {
|
||||
if let Some(ev) = rx.recv().await {
|
||||
match &ev {
|
||||
Payload::Error { message: _ } => {
|
||||
let msg = match to_json(ev) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
error!("error serializing payload to JSON: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match sender.send(msg).await {
|
||||
Ok(_) => {}
|
||||
// Pretty sure we can assume the websocket was closed, so just return
|
||||
Err(e) => {
|
||||
error!("error writing to websocket: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
match sender.close().await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
error!("error closing websocket: {}", e)
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
_ => {
|
||||
// Forward payload
|
||||
let msg = match to_json(ev) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
error!("error serializing payload to JSON: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match sender.send(msg).await {
|
||||
Ok(_) => {}
|
||||
// Pretty sure we can assume the websocket was closed, so just return
|
||||
Err(e) => {
|
||||
error!("error writing to websocket: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_ready(app_state: Arc<AppState>, user: Arc<Account>, tx: Sender<Payload>) {}
|
||||
|
||||
async fn heartbeat(
|
||||
heartbeat_interval: u64,
|
||||
mut rx: mpsc::Receiver<u64>,
|
||||
tx: mpsc::Sender<Payload>,
|
||||
) {
|
||||
// The timeout is twice the heartbeat interval. If no heartbeat is received by then, close the connection with an error.
|
||||
while let Ok(i) = timeout(Duration::from_millis(heartbeat_interval * 2), rx.recv()).await {
|
||||
match i {
|
||||
// ACK the heartbeat with the same timestamp. TODO: validate the timestamp to make sure we aren't too out of sync with the identity server.
|
||||
Some(timestamp) => {
|
||||
tx.send(Payload::HeartbeatAck { timestamp }).await.ok();
|
||||
}
|
||||
// If the channel returns None, that means it's been closed, which means the socket as a whole was closed
|
||||
None => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Send an error, which will automatically close the socket
|
||||
tx.send(Payload::Error {
|
||||
message: format!(
|
||||
"Did not receive a heartbeat after {}ms",
|
||||
heartbeat_interval * 2
|
||||
),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn to_json(p: Payload) -> Result<Message> {
|
||||
Ok(Message::Text(serde_json::to_string(&p)?))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue