feat: add c2s websocket to identity server
This commit is contained in:
parent
809af7e637
commit
42abd70184
3 changed files with 271 additions and 8 deletions
|
@ -17,9 +17,20 @@ pub enum Payload {
|
||||||
},
|
},
|
||||||
/// Hello message, sent after authentication succeeds
|
/// Hello message, sent after authentication succeeds
|
||||||
Hello {
|
Hello {
|
||||||
|
heartbeat_interval: u64,
|
||||||
guilds: Vec<String>,
|
guilds: Vec<String>,
|
||||||
},
|
},
|
||||||
Identify {
|
Identify {
|
||||||
token: String,
|
token: String,
|
||||||
},
|
},
|
||||||
|
/// Sent on a regular interval by the client, to keep the connection alive.
|
||||||
|
Heartbeat {
|
||||||
|
#[serde(rename = "t")]
|
||||||
|
timestamp: u64,
|
||||||
|
},
|
||||||
|
/// Sent in response to a Heartbeat.
|
||||||
|
HeartbeatAck {
|
||||||
|
#[serde(rename = "t")]
|
||||||
|
timestamp: u64,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ pub fn new(pool: Pool<Postgres>, config: Config, instance: Instance) -> Router {
|
||||||
.nest("/_fox/proxy", proxy::router())
|
.nest("/_fox/proxy", proxy::router())
|
||||||
.route("/_fox/ident/node", get(node::get_node))
|
.route("/_fox/ident/node", get(node::get_node))
|
||||||
.route("/_fox/ident/node/:domain", get(node::get_chat_node))
|
.route("/_fox/ident/node/:domain", get(node::get_chat_node))
|
||||||
|
.route("/_fox/ident/ws", get(ws::handler))
|
||||||
.layer(TraceLayer::new_for_http())
|
.layer(TraceLayer::new_for_http())
|
||||||
.layer(Extension(app_state));
|
.layer(Extension(app_state));
|
||||||
|
|
||||||
|
|
|
@ -1,23 +1,274 @@
|
||||||
use std::sync::Arc;
|
use std::{sync::Arc, time::Duration};
|
||||||
|
|
||||||
use axum::{extract::{ws::WebSocket, WebSocketUpgrade}, response::Response, Extension};
|
use axum::{
|
||||||
|
extract::{
|
||||||
|
ws::{Message, WebSocket},
|
||||||
|
WebSocketUpgrade,
|
||||||
|
},
|
||||||
|
response::Response,
|
||||||
|
Extension,
|
||||||
|
};
|
||||||
|
use eyre::Result;
|
||||||
|
use foxchat::{c2s::Payload, FoxError};
|
||||||
use futures::{
|
use futures::{
|
||||||
stream::{SplitSink, SplitStream, StreamExt},
|
stream::{SplitSink, SplitStream, StreamExt},
|
||||||
SinkExt,
|
SinkExt,
|
||||||
};
|
};
|
||||||
|
use rand::Rng;
|
||||||
|
use tokio::{sync::mpsc::{self, Receiver, Sender}, time::timeout};
|
||||||
|
use tracing::error;
|
||||||
|
|
||||||
use crate::{app_state::AppState, model::account::Account};
|
use crate::{app_state::AppState, db::check_token, model::account::Account};
|
||||||
|
|
||||||
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
|
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
|
||||||
ws.on_upgrade(|socket| handle_socket(state, socket))
|
ws.on_upgrade(|socket| handle_socket(state, socket))
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SocketState {
|
|
||||||
user: Option<Account>,
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
||||||
let (sender, receiver) = socket.split();
|
let (sender, receiver) = socket.split();
|
||||||
|
|
||||||
|
let (tx, rx) = mpsc::channel::<Payload>(10);
|
||||||
|
|
||||||
|
tokio::spawn(read(app_state, sender, receiver, tx, rx));
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn read(
|
||||||
|
app_state: Arc<AppState>,
|
||||||
|
mut sender: SplitSink<WebSocket, Message>,
|
||||||
|
mut receiver: SplitStream<WebSocket>,
|
||||||
|
tx: Sender<Payload>,
|
||||||
|
rx: Receiver<Payload>,
|
||||||
|
) {
|
||||||
|
let msg = receiver.next().await;
|
||||||
|
let Some(msg) = msg else {
|
||||||
|
// Websocket was closed, so return
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Ok(msg) = msg else {
|
||||||
|
let Ok(p) = to_json(Payload::Error {
|
||||||
|
message: "Internal server error".into(),
|
||||||
|
}) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
sender.send(p).await.ok();
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert Message into a string. This can fail if the payload is not valid UTF-8.
|
||||||
|
let Ok(msg) = msg.into_text() else {
|
||||||
|
let Ok(p) = to_json(Payload::Error {
|
||||||
|
message: "Invalid event".into(),
|
||||||
|
}) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
sender.send(p).await.ok();
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse payload JSON
|
||||||
|
let Ok(pl) = serde_json::from_str::<Payload>(&msg) else {
|
||||||
|
let Ok(p) = to_json(Payload::Error {
|
||||||
|
message: "Invalid event".into(),
|
||||||
|
}) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
sender.send(p).await.ok();
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Payload::Identify { token } = pl else {
|
||||||
|
let Ok(p) = to_json(Payload::Error {
|
||||||
|
message: "First event was not IDENTIFY".into(),
|
||||||
|
}) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
sender.send(p).await.ok();
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check the token. `check_token` can return an error if the token is invalid (or a database error occurs), so handle that.
|
||||||
|
let user = match check_token(&app_state.pool, token).await {
|
||||||
|
Ok(u) => u,
|
||||||
|
Err(e) => {
|
||||||
|
// The only FoxError this function can return is NotFound
|
||||||
|
if let Some(_) = e.downcast_ref::<FoxError>() {
|
||||||
|
let Ok(p) = to_json(Payload::Error {
|
||||||
|
message: "Invalid token".into(),
|
||||||
|
}) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
sender.send(p).await.ok();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Ok(p) = to_json(Payload::Error {
|
||||||
|
message: "Internal server error".into(),
|
||||||
|
}) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
sender.send(p).await.ok();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Put the user in a comfy little box, because it'll be used by multiple tasks
|
||||||
|
let user = Arc::new(user);
|
||||||
|
// Spawn the `write` task, all writes will now go through tx.send()
|
||||||
|
tokio::spawn(write(app_state.clone(), user.clone(), sender, rx));
|
||||||
|
|
||||||
|
// Send HELLO event
|
||||||
|
// TODO: fetch guild IDs
|
||||||
|
let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000);
|
||||||
|
tx.send(Payload::Hello { heartbeat_interval, guilds: vec![] }).await.ok();
|
||||||
|
|
||||||
|
// Start the heartbeat loop
|
||||||
|
let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10);
|
||||||
|
tokio::spawn(heartbeat(heartbeat_interval, heartbeat_rx, tx.clone()));
|
||||||
|
|
||||||
|
// Fire off ready event
|
||||||
|
tokio::spawn(collect_ready(app_state.clone(), user.clone(), tx.clone()));
|
||||||
|
|
||||||
|
// Start listening for events
|
||||||
|
while let Some(msg) = receiver.next().await {
|
||||||
|
let Ok(msg) = msg else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Ok(msg) = msg.to_text() else {
|
||||||
|
tx.send(Payload::Error {
|
||||||
|
message: "Invalid message".into(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Ok(msg) = serde_json::from_str::<Payload>(msg) else {
|
||||||
|
tx.send(Payload::Error {
|
||||||
|
message: "Invalid message".into(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
match msg {
|
||||||
|
Payload::Heartbeat { timestamp } => {
|
||||||
|
// Heartbeats are handled in another function
|
||||||
|
heartbeat_tx.send(timestamp).await.ok();
|
||||||
|
}
|
||||||
|
// TODO: handle other events
|
||||||
|
_ => {
|
||||||
|
tx.send(Payload::Error {
|
||||||
|
message: "Invalid send event".into(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn write(
|
||||||
|
app_state: Arc<AppState>,
|
||||||
|
user: Arc<Account>,
|
||||||
|
mut sender: SplitSink<WebSocket, Message>,
|
||||||
|
mut rx: Receiver<Payload>,
|
||||||
|
) {
|
||||||
|
loop {
|
||||||
|
if let Some(ev) = rx.recv().await {
|
||||||
|
match &ev {
|
||||||
|
Payload::Error { message: _ } => {
|
||||||
|
let msg = match to_json(ev) {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
error!("error serializing payload to JSON: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match sender.send(msg).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
// Pretty sure we can assume the websocket was closed, so just return
|
||||||
|
Err(e) => {
|
||||||
|
error!("error writing to websocket: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match sender.close().await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => {
|
||||||
|
error!("error closing websocket: {}", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Forward payload
|
||||||
|
let msg = match to_json(ev) {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
error!("error serializing payload to JSON: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match sender.send(msg).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
// Pretty sure we can assume the websocket was closed, so just return
|
||||||
|
Err(e) => {
|
||||||
|
error!("error writing to websocket: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn collect_ready(app_state: Arc<AppState>, user: Arc<Account>, tx: Sender<Payload>) {}
|
||||||
|
|
||||||
|
async fn heartbeat(
|
||||||
|
heartbeat_interval: u64,
|
||||||
|
mut rx: mpsc::Receiver<u64>,
|
||||||
|
tx: mpsc::Sender<Payload>,
|
||||||
|
) {
|
||||||
|
// The timeout is twice the heartbeat interval. If no heartbeat is received by then, close the connection with an error.
|
||||||
|
while let Ok(i) = timeout(Duration::from_millis(heartbeat_interval * 2), rx.recv()).await {
|
||||||
|
match i {
|
||||||
|
// ACK the heartbeat with the same timestamp. TODO: validate the timestamp to make sure we aren't too out of sync with the identity server.
|
||||||
|
Some(timestamp) => {
|
||||||
|
tx.send(Payload::HeartbeatAck { timestamp }).await.ok();
|
||||||
|
}
|
||||||
|
// If the channel returns None, that means it's been closed, which means the socket as a whole was closed
|
||||||
|
None => {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send an error, which will automatically close the socket
|
||||||
|
tx.send(Payload::Error {
|
||||||
|
message: format!(
|
||||||
|
"Did not receive a heartbeat after {}ms",
|
||||||
|
heartbeat_interval * 2
|
||||||
|
),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_json(p: Payload) -> Result<Message> {
|
||||||
|
Ok(Message::Text(serde_json::to_string(&p)?))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue