feat: add c2s websocket to identity server

This commit is contained in:
sam 2024-02-27 03:52:39 +01:00
parent 809af7e637
commit 42abd70184
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
3 changed files with 271 additions and 8 deletions

View file

@ -17,9 +17,20 @@ pub enum Payload {
},
/// Hello message, sent after authentication succeeds
Hello {
heartbeat_interval: u64,
guilds: Vec<String>,
},
Identify {
token: String,
},
/// Sent on a regular interval by the client, to keep the connection alive.
Heartbeat {
#[serde(rename = "t")]
timestamp: u64,
},
/// Sent in response to a Heartbeat.
HeartbeatAck {
#[serde(rename = "t")]
timestamp: u64,
}
}

View file

@ -25,6 +25,7 @@ pub fn new(pool: Pool<Postgres>, config: Config, instance: Instance) -> Router {
.nest("/_fox/proxy", proxy::router())
.route("/_fox/ident/node", get(node::get_node))
.route("/_fox/ident/node/:domain", get(node::get_chat_node))
.route("/_fox/ident/ws", get(ws::handler))
.layer(TraceLayer::new_for_http())
.layer(Extension(app_state));

View file

@ -1,23 +1,274 @@
use std::sync::Arc;
use std::{sync::Arc, time::Duration};
use axum::{extract::{ws::WebSocket, WebSocketUpgrade}, response::Response, Extension};
use axum::{
extract::{
ws::{Message, WebSocket},
WebSocketUpgrade,
},
response::Response,
Extension,
};
use eyre::Result;
use foxchat::{c2s::Payload, FoxError};
use futures::{
stream::{SplitSink, SplitStream, StreamExt},
SinkExt,
};
use rand::Rng;
use tokio::{sync::mpsc::{self, Receiver, Sender}, time::timeout};
use tracing::error;
use crate::{app_state::AppState, model::account::Account};
use crate::{app_state::AppState, db::check_token, model::account::Account};
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
ws.on_upgrade(|socket| handle_socket(state, socket))
}
struct SocketState {
user: Option<Account>,
}
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
let (sender, receiver) = socket.split();
let (tx, rx) = mpsc::channel::<Payload>(10);
tokio::spawn(read(app_state, sender, receiver, tx, rx));
}
async fn read(
app_state: Arc<AppState>,
mut sender: SplitSink<WebSocket, Message>,
mut receiver: SplitStream<WebSocket>,
tx: Sender<Payload>,
rx: Receiver<Payload>,
) {
let msg = receiver.next().await;
let Some(msg) = msg else {
// Websocket was closed, so return
return;
};
let Ok(msg) = msg else {
let Ok(p) = to_json(Payload::Error {
message: "Internal server error".into(),
}) else {
return;
};
sender.send(p).await.ok();
return;
};
// Convert Message into a string. This can fail if the payload is not valid UTF-8.
let Ok(msg) = msg.into_text() else {
let Ok(p) = to_json(Payload::Error {
message: "Invalid event".into(),
}) else {
return;
};
sender.send(p).await.ok();
return;
};
// Parse payload JSON
let Ok(pl) = serde_json::from_str::<Payload>(&msg) else {
let Ok(p) = to_json(Payload::Error {
message: "Invalid event".into(),
}) else {
return;
};
sender.send(p).await.ok();
return;
};
let Payload::Identify { token } = pl else {
let Ok(p) = to_json(Payload::Error {
message: "First event was not IDENTIFY".into(),
}) else {
return;
};
sender.send(p).await.ok();
return;
};
// Check the token. `check_token` can return an error if the token is invalid (or a database error occurs), so handle that.
let user = match check_token(&app_state.pool, token).await {
Ok(u) => u,
Err(e) => {
// The only FoxError this function can return is NotFound
if let Some(_) = e.downcast_ref::<FoxError>() {
let Ok(p) = to_json(Payload::Error {
message: "Invalid token".into(),
}) else {
return;
};
sender.send(p).await.ok();
return;
}
let Ok(p) = to_json(Payload::Error {
message: "Internal server error".into(),
}) else {
return;
};
sender.send(p).await.ok();
return;
}
};
// Put the user in a comfy little box, because it'll be used by multiple tasks
let user = Arc::new(user);
// Spawn the `write` task, all writes will now go through tx.send()
tokio::spawn(write(app_state.clone(), user.clone(), sender, rx));
// Send HELLO event
// TODO: fetch guild IDs
let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000);
tx.send(Payload::Hello { heartbeat_interval, guilds: vec![] }).await.ok();
// Start the heartbeat loop
let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10);
tokio::spawn(heartbeat(heartbeat_interval, heartbeat_rx, tx.clone()));
// Fire off ready event
tokio::spawn(collect_ready(app_state.clone(), user.clone(), tx.clone()));
// Start listening for events
while let Some(msg) = receiver.next().await {
let Ok(msg) = msg else {
return;
};
let Ok(msg) = msg.to_text() else {
tx.send(Payload::Error {
message: "Invalid message".into(),
})
.await
.ok();
return;
};
let Ok(msg) = serde_json::from_str::<Payload>(msg) else {
tx.send(Payload::Error {
message: "Invalid message".into(),
})
.await
.ok();
return;
};
match msg {
Payload::Heartbeat { timestamp } => {
// Heartbeats are handled in another function
heartbeat_tx.send(timestamp).await.ok();
}
// TODO: handle other events
_ => {
tx.send(Payload::Error {
message: "Invalid send event".into(),
})
.await
.ok();
return;
}
}
}
}
async fn write(
app_state: Arc<AppState>,
user: Arc<Account>,
mut sender: SplitSink<WebSocket, Message>,
mut rx: Receiver<Payload>,
) {
loop {
if let Some(ev) = rx.recv().await {
match &ev {
Payload::Error { message: _ } => {
let msg = match to_json(ev) {
Ok(m) => m,
Err(e) => {
error!("error serializing payload to JSON: {}", e);
return;
}
};
match sender.send(msg).await {
Ok(_) => {}
// Pretty sure we can assume the websocket was closed, so just return
Err(e) => {
error!("error writing to websocket: {}", e);
return;
}
}
match sender.close().await {
Ok(_) => {}
Err(e) => {
error!("error closing websocket: {}", e)
}
}
return;
}
_ => {
// Forward payload
let msg = match to_json(ev) {
Ok(m) => m,
Err(e) => {
error!("error serializing payload to JSON: {}", e);
return;
}
};
match sender.send(msg).await {
Ok(_) => {}
// Pretty sure we can assume the websocket was closed, so just return
Err(e) => {
error!("error writing to websocket: {}", e);
return;
}
}
}
}
}
}
}
async fn collect_ready(app_state: Arc<AppState>, user: Arc<Account>, tx: Sender<Payload>) {}
async fn heartbeat(
heartbeat_interval: u64,
mut rx: mpsc::Receiver<u64>,
tx: mpsc::Sender<Payload>,
) {
// The timeout is twice the heartbeat interval. If no heartbeat is received by then, close the connection with an error.
while let Ok(i) = timeout(Duration::from_millis(heartbeat_interval * 2), rx.recv()).await {
match i {
// ACK the heartbeat with the same timestamp. TODO: validate the timestamp to make sure we aren't too out of sync with the identity server.
Some(timestamp) => {
tx.send(Payload::HeartbeatAck { timestamp }).await.ok();
}
// If the channel returns None, that means it's been closed, which means the socket as a whole was closed
None => {
return;
}
};
}
// Send an error, which will automatically close the socket
tx.send(Payload::Error {
message: format!(
"Did not receive a heartbeat after {}ms",
heartbeat_interval * 2
),
})
.await
.ok();
}
fn to_json(p: Payload) -> Result<Message> {
Ok(Message::Text(serde_json::to_string(&p)?))
}