feat: build out chat server websocket more
This commit is contained in:
parent
ad247ca0f4
commit
fd027aee5c
2 changed files with 117 additions and 14 deletions
|
@ -1,4 +1,4 @@
|
||||||
use std::sync::Arc;
|
use std::{sync::Arc, time::Duration};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{
|
extract::{
|
||||||
|
@ -8,8 +8,9 @@ use axum::{
|
||||||
response::Response,
|
response::Response,
|
||||||
Extension,
|
Extension,
|
||||||
};
|
};
|
||||||
use eyre::{eyre, Error, Result};
|
use eyre::{eyre, Result};
|
||||||
use foxchat::{
|
use foxchat::{
|
||||||
|
model::User,
|
||||||
s2s::{Dispatch, Payload},
|
s2s::{Dispatch, Payload},
|
||||||
signature::{parse_date, verify_signature},
|
signature::{parse_date, verify_signature},
|
||||||
};
|
};
|
||||||
|
@ -17,7 +18,11 @@ use futures::{
|
||||||
stream::{SplitSink, SplitStream, StreamExt},
|
stream::{SplitSink, SplitStream, StreamExt},
|
||||||
SinkExt,
|
SinkExt,
|
||||||
};
|
};
|
||||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
use rand::Rng;
|
||||||
|
use tokio::{
|
||||||
|
sync::{broadcast, mpsc, RwLock},
|
||||||
|
time::timeout,
|
||||||
|
};
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
use crate::{app_state::AppState, model::identity_instance::IdentityInstance};
|
use crate::{app_state::AppState, model::identity_instance::IdentityInstance};
|
||||||
|
@ -28,6 +33,7 @@ pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppSt
|
||||||
|
|
||||||
struct SocketState {
|
struct SocketState {
|
||||||
instance: Option<IdentityInstance>,
|
instance: Option<IdentityInstance>,
|
||||||
|
last_heartbeat: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
||||||
|
@ -36,16 +42,21 @@ async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
||||||
let (tx, rx) = mpsc::channel::<Payload>(10);
|
let (tx, rx) = mpsc::channel::<Payload>(10);
|
||||||
let dispatch = app_state.broadcast.subscribe();
|
let dispatch = app_state.broadcast.subscribe();
|
||||||
let socket_state = Arc::new(RwLock::new(SocketState {
|
let socket_state = Arc::new(RwLock::new(SocketState {
|
||||||
instance: None, // Filled out after IDENTIFY
|
// These are filled out after IDENTIFY
|
||||||
|
instance: None,
|
||||||
|
last_heartbeat: 0,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
// Function that merges the socket-local and broadcast streams for sending.
|
||||||
tokio::spawn(merge_channels(
|
tokio::spawn(merge_channels(
|
||||||
app_state.clone(),
|
app_state.clone(),
|
||||||
socket_state.clone(),
|
socket_state.clone(),
|
||||||
dispatch,
|
dispatch,
|
||||||
tx.clone(),
|
tx.clone(),
|
||||||
));
|
));
|
||||||
|
// Function that writes all payloads to the socket.
|
||||||
tokio::spawn(write(rx, sender));
|
tokio::spawn(write(rx, sender));
|
||||||
|
// Function that reads from the socket.
|
||||||
tokio::spawn(read(
|
tokio::spawn(read(
|
||||||
app_state.clone(),
|
app_state.clone(),
|
||||||
socket_state.clone(),
|
socket_state.clone(),
|
||||||
|
@ -71,7 +82,8 @@ async fn read(
|
||||||
server,
|
server,
|
||||||
signature,
|
signature,
|
||||||
} => {
|
} => {
|
||||||
let instance = match IdentityInstance::get(app_state, &server).await {
|
let instance = match IdentityInstance::get(app_state.clone(), &server).await
|
||||||
|
{
|
||||||
Ok(i) => i,
|
Ok(i) => i,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("getting instance {}: {}", server, e);
|
error!("getting instance {}: {}", server, e);
|
||||||
|
@ -161,8 +173,17 @@ async fn read(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send hello
|
// Generate heartbeat interval and send it in the Hello payload
|
||||||
tx.send(Payload::Hello {}).await.ok();
|
let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000);
|
||||||
|
tx.send(Payload::Hello { heartbeat_interval }).await.ok();
|
||||||
|
// Start the heartbeat loop
|
||||||
|
let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10);
|
||||||
|
tokio::spawn(heartbeat(
|
||||||
|
socket_state.clone(),
|
||||||
|
heartbeat_interval,
|
||||||
|
heartbeat_rx,
|
||||||
|
tx.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
while let Some(msg) = receiver.next().await {
|
while let Some(msg) = receiver.next().await {
|
||||||
let Ok(msg) = msg else {
|
let Ok(msg) = msg else {
|
||||||
|
@ -188,13 +209,69 @@ async fn read(
|
||||||
};
|
};
|
||||||
|
|
||||||
match msg {
|
match msg {
|
||||||
Payload::Connect { user_id } => {}
|
Payload::Connect { user_id } => {
|
||||||
|
match collect_ready(app_state.clone(), socket_state.clone(), &user_id).await {
|
||||||
|
Ok(p) => {
|
||||||
|
tx.send(p).await.ok();
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Error collecting ready event data for {}: {}", user_id, e);
|
||||||
|
tx.send(Payload::Error {
|
||||||
|
message: "Internal server error".into(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Payload::Heartbeat { timestamp } => {
|
||||||
|
// Heartbeats are handled in another function
|
||||||
|
heartbeat_tx.send(timestamp).await.ok();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
tx.send(Payload::Error {
|
||||||
|
message: "Invalid send event".into(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: handle incoming payloads
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn heartbeat(
|
||||||
|
socket_state: Arc<RwLock<SocketState>>,
|
||||||
|
heartbeat_interval: u64,
|
||||||
|
mut rx: mpsc::Receiver<u64>,
|
||||||
|
tx: mpsc::Sender<Payload>,
|
||||||
|
) {
|
||||||
|
// The timeout is twice the heartbeat interval. If no heartbeat is received by then, close the connection with an error.
|
||||||
|
while let Ok(i) = timeout(Duration::from_millis(heartbeat_interval * 2), rx.recv()).await {
|
||||||
|
match i {
|
||||||
|
// ACK the heartbeat with the same timestamp. TODO: validate the timestamp to make sure we aren't too out of sync with the identity server.
|
||||||
|
Some(timestamp) => {
|
||||||
|
socket_state.write().await.last_heartbeat = timestamp;
|
||||||
|
tx.send(Payload::HeartbeatAck { timestamp }).await.ok();
|
||||||
|
}
|
||||||
|
// If the channel returns None, that means it's been closed, which means the socket as a whole was closed
|
||||||
|
None => {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send an error, which will automatically close the socket
|
||||||
|
tx.send(Payload::Error {
|
||||||
|
message: format!(
|
||||||
|
"Did not receive a heartbeat after {}ms",
|
||||||
|
heartbeat_interval * 2
|
||||||
|
),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) {
|
async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) {
|
||||||
loop {
|
loop {
|
||||||
if let Some(ev) = rx.recv().await {
|
if let Some(ev) = rx.recv().await {
|
||||||
|
@ -355,19 +432,32 @@ async fn filter_events(
|
||||||
async fn collect_ready(
|
async fn collect_ready(
|
||||||
app_state: Arc<AppState>,
|
app_state: Arc<AppState>,
|
||||||
socket_state: Arc<RwLock<SocketState>>,
|
socket_state: Arc<RwLock<SocketState>>,
|
||||||
user_id: String,
|
user_id: &str,
|
||||||
) -> Result<Payload> {
|
) -> Result<Payload> {
|
||||||
let Some(instance) = &socket_state.read().await.instance else {
|
let Some(instance) = &socket_state.read().await.instance else {
|
||||||
return Err(eyre!("instance was None when it shouldn't be"));
|
return Err(eyre!("instance was None when it shouldn't be"));
|
||||||
};
|
};
|
||||||
|
|
||||||
let user = sqlx::query!(
|
let user = sqlx::query!(
|
||||||
"SELECT * FROM users WHERE instance_id = $1 AND remote_user_id = $2",
|
r#"SELECT u.*, i.domain FROM users u
|
||||||
|
JOIN identity_instances i ON i.id = u.instance_id
|
||||||
|
WHERE u.instance_id = $1 AND u.remote_user_id = $2"#,
|
||||||
instance.id,
|
instance.id,
|
||||||
user_id
|
user_id
|
||||||
)
|
)
|
||||||
.fetch_one(&app_state.pool)
|
.fetch_one(&app_state.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
todo!()
|
Ok(Payload::Dispatch {
|
||||||
|
event: Dispatch::Ready {
|
||||||
|
user: User {
|
||||||
|
id: user.id,
|
||||||
|
username: user.username,
|
||||||
|
instance: user.domain,
|
||||||
|
avatar_url: None,
|
||||||
|
},
|
||||||
|
guilds: vec![],
|
||||||
|
},
|
||||||
|
recipients: vec![user.remote_user_id],
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ use super::Dispatch;
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
|
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
|
||||||
pub enum Payload {
|
pub enum Payload {
|
||||||
|
#[serde(rename = "D")]
|
||||||
Dispatch {
|
Dispatch {
|
||||||
#[serde(rename = "e")]
|
#[serde(rename = "e")]
|
||||||
event: Dispatch,
|
event: Dispatch,
|
||||||
|
@ -15,7 +16,9 @@ pub enum Payload {
|
||||||
message: String,
|
message: String,
|
||||||
},
|
},
|
||||||
/// Hello message, sent after authentication succeeds
|
/// Hello message, sent after authentication succeeds
|
||||||
Hello {},
|
Hello {
|
||||||
|
heartbeat_interval: u64,
|
||||||
|
},
|
||||||
/// S2S authentication. Fields correspond to headers (Host, Date, X-Foxchat-Server, X-Foxchat-Signature)
|
/// S2S authentication. Fields correspond to headers (Host, Date, X-Foxchat-Server, X-Foxchat-Signature)
|
||||||
Identify {
|
Identify {
|
||||||
host: String,
|
host: String,
|
||||||
|
@ -26,5 +29,15 @@ pub enum Payload {
|
||||||
/// Sent when a user connects to the identity server's gateway, to signal the chat server to send READY for that user
|
/// Sent when a user connects to the identity server's gateway, to signal the chat server to send READY for that user
|
||||||
Connect {
|
Connect {
|
||||||
user_id: String,
|
user_id: String,
|
||||||
|
},
|
||||||
|
/// Sent on a regular interval by the connecting server, to keep the connection alive.
|
||||||
|
Heartbeat {
|
||||||
|
#[serde(rename = "t")]
|
||||||
|
timestamp: u64,
|
||||||
|
},
|
||||||
|
/// Sent in response to a Heartbeat.
|
||||||
|
HeartbeatAck {
|
||||||
|
#[serde(rename = "t")]
|
||||||
|
timestamp: u64,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue