feat: build out chat server websocket more, start identity websocket
This commit is contained in:
parent
18b644d24b
commit
f7494034d5
13 changed files with 300 additions and 24 deletions
|
@ -3,7 +3,7 @@ use std::sync::Arc;
|
|||
use axum::{Extension, Json};
|
||||
use foxchat::{
|
||||
http::ApiError,
|
||||
model::{http::guild::CreateGuildParams, user::PartialUser, Guild, channel::PartialChannel},
|
||||
model::{channel::PartialChannel, http::guild::CreateGuildParams, user::PartialUser, Guild},
|
||||
FoxError,
|
||||
};
|
||||
|
||||
|
@ -39,8 +39,12 @@ pub async fn post_guilds(
|
|||
instance: user.instance.domain,
|
||||
},
|
||||
default_channel: PartialChannel {
|
||||
id: channel.id.0.clone(),
|
||||
name: channel.name.clone(),
|
||||
},
|
||||
channels: Some(vec![PartialChannel {
|
||||
id: channel.id.0,
|
||||
name: channel.name,
|
||||
}
|
||||
}]),
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::sync::Arc;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
|
@ -8,8 +8,9 @@ use axum::{
|
|||
response::Response,
|
||||
Extension,
|
||||
};
|
||||
use eyre::Result;
|
||||
use eyre::{eyre, Result};
|
||||
use foxchat::{
|
||||
model::User,
|
||||
s2s::{Dispatch, Payload},
|
||||
signature::{parse_date, verify_signature},
|
||||
};
|
||||
|
@ -17,7 +18,11 @@ use futures::{
|
|||
stream::{SplitSink, SplitStream, StreamExt},
|
||||
SinkExt,
|
||||
};
|
||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||
use rand::Rng;
|
||||
use tokio::{
|
||||
sync::{broadcast, mpsc, RwLock},
|
||||
time::timeout,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use crate::{app_state::AppState, model::identity_instance::IdentityInstance};
|
||||
|
@ -28,6 +33,7 @@ pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppSt
|
|||
|
||||
struct SocketState {
|
||||
instance: Option<IdentityInstance>,
|
||||
last_heartbeat: u64,
|
||||
}
|
||||
|
||||
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
||||
|
@ -36,16 +42,21 @@ async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
|||
let (tx, rx) = mpsc::channel::<Payload>(10);
|
||||
let dispatch = app_state.broadcast.subscribe();
|
||||
let socket_state = Arc::new(RwLock::new(SocketState {
|
||||
instance: None, // Filled out after IDENTIFY
|
||||
// These are filled out after IDENTIFY
|
||||
instance: None,
|
||||
last_heartbeat: 0,
|
||||
}));
|
||||
|
||||
// Function that merges the socket-local and broadcast streams for sending.
|
||||
tokio::spawn(merge_channels(
|
||||
app_state.clone(),
|
||||
socket_state.clone(),
|
||||
dispatch,
|
||||
tx.clone(),
|
||||
));
|
||||
// Function that writes all payloads to the socket.
|
||||
tokio::spawn(write(rx, sender));
|
||||
// Function that reads from the socket.
|
||||
tokio::spawn(read(
|
||||
app_state.clone(),
|
||||
socket_state.clone(),
|
||||
|
@ -71,7 +82,8 @@ async fn read(
|
|||
server,
|
||||
signature,
|
||||
} => {
|
||||
let instance = match IdentityInstance::get(app_state, &server).await {
|
||||
let instance = match IdentityInstance::get(app_state.clone(), &server).await
|
||||
{
|
||||
Ok(i) => i,
|
||||
Err(e) => {
|
||||
error!("getting instance {}: {}", server, e);
|
||||
|
@ -161,20 +173,105 @@ async fn read(
|
|||
return;
|
||||
}
|
||||
|
||||
// Send hello
|
||||
tx.send(Payload::Hello {}).await.ok();
|
||||
// Generate heartbeat interval and send it in the Hello payload
|
||||
let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000);
|
||||
tx.send(Payload::Hello { heartbeat_interval }).await.ok();
|
||||
// Start the heartbeat loop
|
||||
let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10);
|
||||
tokio::spawn(heartbeat(
|
||||
socket_state.clone(),
|
||||
heartbeat_interval,
|
||||
heartbeat_rx,
|
||||
tx.clone(),
|
||||
));
|
||||
|
||||
while let Some(msg) = receiver.next().await {
|
||||
let msg = if let Ok(msg) = msg {
|
||||
msg
|
||||
} else {
|
||||
let Ok(msg) = msg else {
|
||||
return;
|
||||
};
|
||||
|
||||
// TODO: handle incoming payloads
|
||||
let Ok(msg) = msg.to_text() else {
|
||||
tx.send(Payload::Error {
|
||||
message: "Invalid message".into(),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(msg) = serde_json::from_str::<Payload>(msg) else {
|
||||
tx.send(Payload::Error {
|
||||
message: "Invalid message".into(),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return;
|
||||
};
|
||||
|
||||
match msg {
|
||||
Payload::Connect { user_id } => {
|
||||
match collect_ready(app_state.clone(), socket_state.clone(), &user_id).await {
|
||||
Ok(p) => {
|
||||
tx.send(p).await.ok();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error collecting ready event data for {}: {}", user_id, e);
|
||||
tx.send(Payload::Error {
|
||||
message: "Internal server error".into(),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
Payload::Heartbeat { timestamp } => {
|
||||
// Heartbeats are handled in another function
|
||||
heartbeat_tx.send(timestamp).await.ok();
|
||||
}
|
||||
_ => {
|
||||
tx.send(Payload::Error {
|
||||
message: "Invalid send event".into(),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn heartbeat(
|
||||
socket_state: Arc<RwLock<SocketState>>,
|
||||
heartbeat_interval: u64,
|
||||
mut rx: mpsc::Receiver<u64>,
|
||||
tx: mpsc::Sender<Payload>,
|
||||
) {
|
||||
// The timeout is twice the heartbeat interval. If no heartbeat is received by then, close the connection with an error.
|
||||
while let Ok(i) = timeout(Duration::from_millis(heartbeat_interval * 2), rx.recv()).await {
|
||||
match i {
|
||||
// ACK the heartbeat with the same timestamp. TODO: validate the timestamp to make sure we aren't too out of sync with the identity server.
|
||||
Some(timestamp) => {
|
||||
socket_state.write().await.last_heartbeat = timestamp;
|
||||
tx.send(Payload::HeartbeatAck { timestamp }).await.ok();
|
||||
}
|
||||
// If the channel returns None, that means it's been closed, which means the socket as a whole was closed
|
||||
None => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Send an error, which will automatically close the socket
|
||||
tx.send(Payload::Error {
|
||||
message: format!(
|
||||
"Did not receive a heartbeat after {}ms",
|
||||
heartbeat_interval * 2
|
||||
),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
|
||||
async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) {
|
||||
loop {
|
||||
if let Some(ev) = rx.recv().await {
|
||||
|
@ -289,10 +386,78 @@ async fn filter_events(
|
|||
socket_state: Arc<RwLock<SocketState>>,
|
||||
evt: &Dispatch,
|
||||
) -> Result<(bool, Vec<String>)> {
|
||||
// If we're not authenticated yet, don't send anything
|
||||
if socket_state.read().await.instance.is_none() {
|
||||
let Some(instance) = &socket_state.read().await.instance else {
|
||||
return Ok((false, vec![]));
|
||||
}
|
||||
};
|
||||
|
||||
Ok((true, vec![]))
|
||||
match evt {
|
||||
Dispatch::MessageCreate {
|
||||
id: _,
|
||||
channel_id: _,
|
||||
guild_id,
|
||||
author: _,
|
||||
content: _,
|
||||
created_at: _,
|
||||
} => {
|
||||
let users = sqlx::query!(
|
||||
r#"SELECT ARRAY(
|
||||
SELECT u.remote_user_id FROM users u
|
||||
JOIN guilds_users gu ON gu.user_id = u.id
|
||||
WHERE u.instance_id = $1 AND gu.guild_id = $2
|
||||
)"#,
|
||||
instance.id,
|
||||
guild_id
|
||||
)
|
||||
.fetch_one(&app_state.pool)
|
||||
.await?;
|
||||
|
||||
if let Some(users) = users.array {
|
||||
return Ok((users.len() > 0, users));
|
||||
}
|
||||
return Ok((false, vec![]));
|
||||
}
|
||||
Dispatch::Ready { user, guilds: _ } => {
|
||||
let user = sqlx::query!(
|
||||
"SELECT remote_user_id FROM users WHERE id = $1",
|
||||
user.id.clone()
|
||||
)
|
||||
.fetch_one(&app_state.pool)
|
||||
.await?;
|
||||
|
||||
return Ok((true, vec![user.remote_user_id]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_ready(
|
||||
app_state: Arc<AppState>,
|
||||
socket_state: Arc<RwLock<SocketState>>,
|
||||
user_id: &str,
|
||||
) -> Result<Payload> {
|
||||
let Some(instance) = &socket_state.read().await.instance else {
|
||||
return Err(eyre!("instance was None when it shouldn't be"));
|
||||
};
|
||||
|
||||
let user = sqlx::query!(
|
||||
r#"SELECT u.*, i.domain FROM users u
|
||||
JOIN identity_instances i ON i.id = u.instance_id
|
||||
WHERE u.instance_id = $1 AND u.remote_user_id = $2"#,
|
||||
instance.id,
|
||||
user_id
|
||||
)
|
||||
.fetch_one(&app_state.pool)
|
||||
.await?;
|
||||
|
||||
Ok(Payload::Dispatch {
|
||||
event: Dispatch::Ready {
|
||||
user: User {
|
||||
id: user.id,
|
||||
username: user.username,
|
||||
instance: user.domain,
|
||||
avatar_url: None,
|
||||
},
|
||||
guilds: vec![],
|
||||
},
|
||||
recipients: vec![user.remote_user_id],
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue