373 lines
11 KiB
Rust
373 lines
11 KiB
Rust
use std::sync::Arc;
|
|
|
|
use axum::{
|
|
extract::{
|
|
ws::{Message, WebSocket},
|
|
WebSocketUpgrade,
|
|
},
|
|
response::Response,
|
|
Extension,
|
|
};
|
|
use eyre::{eyre, Error, Result};
|
|
use foxchat::{
|
|
s2s::{Dispatch, Payload},
|
|
signature::{parse_date, verify_signature},
|
|
};
|
|
use futures::{
|
|
stream::{SplitSink, SplitStream, StreamExt},
|
|
SinkExt,
|
|
};
|
|
use tokio::sync::{broadcast, mpsc, RwLock};
|
|
use tracing::error;
|
|
|
|
use crate::{app_state::AppState, model::identity_instance::IdentityInstance};
|
|
|
|
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
|
|
ws.on_upgrade(|socket| handle_socket(state, socket))
|
|
}
|
|
|
|
struct SocketState {
|
|
instance: Option<IdentityInstance>,
|
|
}
|
|
|
|
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
|
let (sender, receiver) = socket.split();
|
|
|
|
let (tx, rx) = mpsc::channel::<Payload>(10);
|
|
let dispatch = app_state.broadcast.subscribe();
|
|
let socket_state = Arc::new(RwLock::new(SocketState {
|
|
instance: None, // Filled out after IDENTIFY
|
|
}));
|
|
|
|
tokio::spawn(merge_channels(
|
|
app_state.clone(),
|
|
socket_state.clone(),
|
|
dispatch,
|
|
tx.clone(),
|
|
));
|
|
tokio::spawn(write(rx, sender));
|
|
tokio::spawn(read(
|
|
app_state.clone(),
|
|
socket_state.clone(),
|
|
tx.clone(),
|
|
receiver,
|
|
));
|
|
}
|
|
|
|
async fn read(
|
|
app_state: Arc<AppState>,
|
|
socket_state: Arc<RwLock<SocketState>>,
|
|
tx: mpsc::Sender<Payload>,
|
|
mut receiver: SplitStream<WebSocket>,
|
|
) {
|
|
let msg = receiver.next().await;
|
|
if let Some(Ok(msg)) = msg {
|
|
if let Ok(msg) = msg.into_text() {
|
|
if let Ok(pl) = serde_json::from_str::<Payload>(&msg) {
|
|
match pl {
|
|
Payload::Identify {
|
|
host,
|
|
date,
|
|
server,
|
|
signature,
|
|
} => {
|
|
let instance = match IdentityInstance::get(app_state, &server).await {
|
|
Ok(i) => i,
|
|
Err(e) => {
|
|
error!("getting instance {}: {}", server, e);
|
|
tx.send(Payload::Error {
|
|
message: "Unknown instance".into(),
|
|
})
|
|
.await
|
|
.ok();
|
|
return;
|
|
}
|
|
};
|
|
|
|
let public_key = match instance.parse_public_key() {
|
|
Ok(k) => k,
|
|
Err(e) => {
|
|
error!("parsing public key for instance: {}", e);
|
|
tx.send(Payload::Error {
|
|
message: "Internal server error".into(),
|
|
})
|
|
.await
|
|
.ok();
|
|
return;
|
|
}
|
|
};
|
|
|
|
let date = match parse_date(&date) {
|
|
Ok(d) => d,
|
|
Err(_) => {
|
|
tx.send(Payload::Error {
|
|
message: "Invalid date".into(),
|
|
})
|
|
.await
|
|
.ok();
|
|
return;
|
|
}
|
|
};
|
|
|
|
if let Err(e) = verify_signature(
|
|
&public_key,
|
|
signature,
|
|
date,
|
|
&host,
|
|
"/_fox/chat/ws",
|
|
None,
|
|
None,
|
|
) {
|
|
error!("Verifying signature in websocket from {}: {}", &server, e);
|
|
|
|
tx.send(Payload::Error {
|
|
message: "Invalid signature".into(),
|
|
})
|
|
.await
|
|
.ok();
|
|
return;
|
|
}
|
|
|
|
// Everything is good, store instance
|
|
socket_state.write().await.instance = Some(instance);
|
|
}
|
|
_ => {
|
|
tx.send(Payload::Error {
|
|
message: "First payload was not IDENTIFY".into(),
|
|
})
|
|
.await
|
|
.ok();
|
|
return;
|
|
}
|
|
}
|
|
} else {
|
|
tx.send(Payload::Error {
|
|
message: "Invalid JSON payload".into(),
|
|
})
|
|
.await
|
|
.ok();
|
|
return;
|
|
}
|
|
} else {
|
|
tx.send(Payload::Error {
|
|
message: "First payload was not IDENTIFY".into(),
|
|
})
|
|
.await
|
|
.ok();
|
|
return;
|
|
}
|
|
} else {
|
|
// Websocket closed, return
|
|
return;
|
|
}
|
|
|
|
// Send hello
|
|
tx.send(Payload::Hello {}).await.ok();
|
|
|
|
while let Some(msg) = receiver.next().await {
|
|
let Ok(msg) = msg else {
|
|
return;
|
|
};
|
|
|
|
let Ok(msg) = msg.to_text() else {
|
|
tx.send(Payload::Error {
|
|
message: "Invalid message".into(),
|
|
})
|
|
.await
|
|
.ok();
|
|
return;
|
|
};
|
|
|
|
let Ok(msg) = serde_json::from_str::<Payload>(msg) else {
|
|
tx.send(Payload::Error {
|
|
message: "Invalid message".into(),
|
|
})
|
|
.await
|
|
.ok();
|
|
return;
|
|
};
|
|
|
|
match msg {
|
|
Payload::Connect { user_id } => {}
|
|
}
|
|
|
|
// TODO: handle incoming payloads
|
|
}
|
|
}
|
|
|
|
async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) {
|
|
loop {
|
|
if let Some(ev) = rx.recv().await {
|
|
match &ev {
|
|
Payload::Error { message: _ } => {
|
|
let msg = match to_json(ev) {
|
|
Ok(m) => m,
|
|
Err(e) => {
|
|
error!("error serializing payload to JSON: {}", e);
|
|
return;
|
|
}
|
|
};
|
|
|
|
match sender.send(msg).await {
|
|
Ok(_) => {}
|
|
// Pretty sure we can assume the websocket was closed, so just return
|
|
Err(e) => {
|
|
error!("error writing to websocket: {}", e);
|
|
return;
|
|
}
|
|
}
|
|
|
|
match sender.close().await {
|
|
Ok(_) => {}
|
|
Err(e) => {
|
|
error!("error closing websocket: {}", e)
|
|
}
|
|
}
|
|
|
|
return;
|
|
}
|
|
_ => {
|
|
// Forward payload
|
|
let msg = match to_json(ev) {
|
|
Ok(m) => m,
|
|
Err(e) => {
|
|
error!("error serializing payload to JSON: {}", e);
|
|
return;
|
|
}
|
|
};
|
|
|
|
match sender.send(msg).await {
|
|
Ok(_) => {}
|
|
// Pretty sure we can assume the websocket was closed, so just return
|
|
Err(e) => {
|
|
error!("error writing to websocket: {}", e);
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn merge_channels(
|
|
app_state: Arc<AppState>,
|
|
socket_state: Arc<RwLock<SocketState>>,
|
|
mut rx: broadcast::Receiver<Dispatch>,
|
|
tx: mpsc::Sender<Payload>,
|
|
) {
|
|
loop {
|
|
let msg = rx.recv().await;
|
|
match msg {
|
|
Ok(evt) => {
|
|
let (send, recipients) =
|
|
match filter_events(app_state.clone(), socket_state.clone(), &evt).await {
|
|
Ok(v) => v,
|
|
Err(e) => {
|
|
error!("Checking whether to send an event: {}", e);
|
|
tx.send(Payload::Error {
|
|
message: "Internal server error".into(),
|
|
})
|
|
.await
|
|
.ok();
|
|
|
|
continue;
|
|
}
|
|
};
|
|
|
|
if !send {
|
|
continue;
|
|
}
|
|
|
|
tx.send(Payload::Dispatch {
|
|
event: evt,
|
|
recipients,
|
|
})
|
|
.await
|
|
.ok();
|
|
}
|
|
Err(e) => match e {
|
|
broadcast::error::RecvError::Closed => {
|
|
error!("Broadcast channel was closed, this is not supposed to happen");
|
|
return;
|
|
}
|
|
broadcast::error::RecvError::Lagged(i) => {
|
|
error!("Broadcast receive lagged by {i}");
|
|
}
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
fn to_json(p: Payload) -> Result<Message> {
|
|
Ok(Message::Text(serde_json::to_string(&p)?))
|
|
}
|
|
|
|
/// Returns whether or not an event should be sent at all, and if it should be, which users to send it to.
|
|
async fn filter_events(
|
|
app_state: Arc<AppState>,
|
|
socket_state: Arc<RwLock<SocketState>>,
|
|
evt: &Dispatch,
|
|
) -> Result<(bool, Vec<String>)> {
|
|
let Some(instance) = &socket_state.read().await.instance else {
|
|
return Ok((false, vec![]));
|
|
};
|
|
|
|
match evt {
|
|
Dispatch::MessageCreate {
|
|
id: _,
|
|
channel_id: _,
|
|
guild_id,
|
|
author: _,
|
|
content: _,
|
|
created_at: _,
|
|
} => {
|
|
let users = sqlx::query!(
|
|
r#"SELECT ARRAY(
|
|
SELECT u.remote_user_id FROM users u
|
|
JOIN guilds_users gu ON gu.user_id = u.id
|
|
WHERE u.instance_id = $1 AND gu.guild_id = $2
|
|
)"#,
|
|
instance.id,
|
|
guild_id
|
|
)
|
|
.fetch_one(&app_state.pool)
|
|
.await?;
|
|
|
|
if let Some(users) = users.array {
|
|
return Ok((users.len() > 0, users));
|
|
}
|
|
return Ok((false, vec![]));
|
|
}
|
|
Dispatch::Ready { user, guilds: _ } => {
|
|
let user = sqlx::query!(
|
|
"SELECT remote_user_id FROM users WHERE id = $1",
|
|
user.id.clone()
|
|
)
|
|
.fetch_one(&app_state.pool)
|
|
.await?;
|
|
|
|
return Ok((true, vec![user.remote_user_id]));
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn collect_ready(
|
|
app_state: Arc<AppState>,
|
|
socket_state: Arc<RwLock<SocketState>>,
|
|
user_id: String,
|
|
) -> Result<Payload> {
|
|
let Some(instance) = &socket_state.read().await.instance else {
|
|
return Err(eyre!("instance was None when it shouldn't be"));
|
|
};
|
|
|
|
let user = sqlx::query!(
|
|
"SELECT * FROM users WHERE instance_id = $1 AND remote_user_id = $2",
|
|
instance.id,
|
|
user_id
|
|
)
|
|
.fetch_one(&app_state.pool)
|
|
.await?;
|
|
|
|
todo!()
|
|
}
|