foxchat/chat/src/http/ws/mod.rs
2024-02-25 22:26:24 +01:00

373 lines
11 KiB
Rust

use std::sync::Arc;
use axum::{
extract::{
ws::{Message, WebSocket},
WebSocketUpgrade,
},
response::Response,
Extension,
};
use eyre::{eyre, Error, Result};
use foxchat::{
s2s::{Dispatch, Payload},
signature::{parse_date, verify_signature},
};
use futures::{
stream::{SplitSink, SplitStream, StreamExt},
SinkExt,
};
use tokio::sync::{broadcast, mpsc, RwLock};
use tracing::error;
use crate::{app_state::AppState, model::identity_instance::IdentityInstance};
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
ws.on_upgrade(|socket| handle_socket(state, socket))
}
/// Per-connection state shared (behind an `Arc<RwLock<_>>`) between this socket's
/// reader task and its dispatch-forwarding task.
struct SocketState {
    /// The identity instance on the other end of the socket; `None` until its
    /// IDENTIFY payload has been verified by `read`.
    instance: Option<IdentityInstance>,
}
/// Wires up one upgraded websocket connection.
///
/// Three tasks are spawned and left detached:
/// - `merge_channels` filters app-wide [`Dispatch`] broadcasts into this socket's outbound queue,
/// - `write` drains the outbound queue into the websocket sink,
/// - `read` performs the IDENTIFY handshake and then consumes inbound payloads.
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
    let (ws_sink, ws_stream) = socket.split();
    // Every payload destined for the peer goes through this bounded queue.
    let (outbound_tx, outbound_rx) = mpsc::channel::<Payload>(10);
    let dispatches = app_state.broadcast.subscribe();
    // `instance` is populated by `read` once IDENTIFY has been verified.
    let socket_state = Arc::new(RwLock::new(SocketState { instance: None }));

    let merger = merge_channels(
        Arc::clone(&app_state),
        Arc::clone(&socket_state),
        dispatches,
        outbound_tx.clone(),
    );
    tokio::spawn(merger);

    tokio::spawn(write(outbound_rx, ws_sink));

    let reader = read(
        Arc::clone(&app_state),
        Arc::clone(&socket_state),
        outbound_tx.clone(),
        ws_stream,
    );
    tokio::spawn(reader);
}
async fn read(
app_state: Arc<AppState>,
socket_state: Arc<RwLock<SocketState>>,
tx: mpsc::Sender<Payload>,
mut receiver: SplitStream<WebSocket>,
) {
let msg = receiver.next().await;
if let Some(Ok(msg)) = msg {
if let Ok(msg) = msg.into_text() {
if let Ok(pl) = serde_json::from_str::<Payload>(&msg) {
match pl {
Payload::Identify {
host,
date,
server,
signature,
} => {
let instance = match IdentityInstance::get(app_state, &server).await {
Ok(i) => i,
Err(e) => {
error!("getting instance {}: {}", server, e);
tx.send(Payload::Error {
message: "Unknown instance".into(),
})
.await
.ok();
return;
}
};
let public_key = match instance.parse_public_key() {
Ok(k) => k,
Err(e) => {
error!("parsing public key for instance: {}", e);
tx.send(Payload::Error {
message: "Internal server error".into(),
})
.await
.ok();
return;
}
};
let date = match parse_date(&date) {
Ok(d) => d,
Err(_) => {
tx.send(Payload::Error {
message: "Invalid date".into(),
})
.await
.ok();
return;
}
};
if let Err(e) = verify_signature(
&public_key,
signature,
date,
&host,
"/_fox/chat/ws",
None,
None,
) {
error!("Verifying signature in websocket from {}: {}", &server, e);
tx.send(Payload::Error {
message: "Invalid signature".into(),
})
.await
.ok();
return;
}
// Everything is good, store instance
socket_state.write().await.instance = Some(instance);
}
_ => {
tx.send(Payload::Error {
message: "First payload was not IDENTIFY".into(),
})
.await
.ok();
return;
}
}
} else {
tx.send(Payload::Error {
message: "Invalid JSON payload".into(),
})
.await
.ok();
return;
}
} else {
tx.send(Payload::Error {
message: "First payload was not IDENTIFY".into(),
})
.await
.ok();
return;
}
} else {
// Websocket closed, return
return;
}
// Send hello
tx.send(Payload::Hello {}).await.ok();
while let Some(msg) = receiver.next().await {
let Ok(msg) = msg else {
return;
};
let Ok(msg) = msg.to_text() else {
tx.send(Payload::Error {
message: "Invalid message".into(),
})
.await
.ok();
return;
};
let Ok(msg) = serde_json::from_str::<Payload>(msg) else {
tx.send(Payload::Error {
message: "Invalid message".into(),
})
.await
.ok();
return;
};
match msg {
Payload::Connect { user_id } => {}
}
// TODO: handle incoming payloads
}
}
async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) {
loop {
if let Some(ev) = rx.recv().await {
match &ev {
Payload::Error { message: _ } => {
let msg = match to_json(ev) {
Ok(m) => m,
Err(e) => {
error!("error serializing payload to JSON: {}", e);
return;
}
};
match sender.send(msg).await {
Ok(_) => {}
// Pretty sure we can assume the websocket was closed, so just return
Err(e) => {
error!("error writing to websocket: {}", e);
return;
}
}
match sender.close().await {
Ok(_) => {}
Err(e) => {
error!("error closing websocket: {}", e)
}
}
return;
}
_ => {
// Forward payload
let msg = match to_json(ev) {
Ok(m) => m,
Err(e) => {
error!("error serializing payload to JSON: {}", e);
return;
}
};
match sender.send(msg).await {
Ok(_) => {}
// Pretty sure we can assume the websocket was closed, so just return
Err(e) => {
error!("error writing to websocket: {}", e);
return;
}
}
}
}
}
}
}
async fn merge_channels(
app_state: Arc<AppState>,
socket_state: Arc<RwLock<SocketState>>,
mut rx: broadcast::Receiver<Dispatch>,
tx: mpsc::Sender<Payload>,
) {
loop {
let msg = rx.recv().await;
match msg {
Ok(evt) => {
let (send, recipients) =
match filter_events(app_state.clone(), socket_state.clone(), &evt).await {
Ok(v) => v,
Err(e) => {
error!("Checking whether to send an event: {}", e);
tx.send(Payload::Error {
message: "Internal server error".into(),
})
.await
.ok();
continue;
}
};
if !send {
continue;
}
tx.send(Payload::Dispatch {
event: evt,
recipients,
})
.await
.ok();
}
Err(e) => match e {
broadcast::error::RecvError::Closed => {
error!("Broadcast channel was closed, this is not supposed to happen");
return;
}
broadcast::error::RecvError::Lagged(i) => {
error!("Broadcast receive lagged by {i}");
}
},
}
}
}
/// Serializes a payload into a websocket text frame.
///
/// # Errors
/// Returns the `serde_json` serialization error, converted into an eyre report.
fn to_json(p: Payload) -> Result<Message> {
    let text = serde_json::to_string(&p)?;
    Ok(Message::Text(text))
}
/// Returns whether or not an event should be sent at all, and if it should be, which users to send it to.
///
/// The tuple is `(send, recipients)`, where `recipients` are `remote_user_id`s of
/// users belonging to this socket's instance. Nothing is sent before the socket
/// has completed IDENTIFY.
///
/// # Errors
/// Propagates database errors from the recipient lookups.
async fn filter_events(
    app_state: Arc<AppState>,
    socket_state: Arc<RwLock<SocketState>>,
    evt: &Dispatch,
) -> Result<(bool, Vec<String>)> {
    let Some(instance) = &socket_state.read().await.instance else {
        return Ok((false, vec![]));
    };
    match evt {
        Dispatch::MessageCreate { guild_id, .. } => {
            // Members of the guild that belong to this socket's instance.
            let row = sqlx::query!(
                r#"SELECT ARRAY(
SELECT u.remote_user_id FROM users u
JOIN guilds_users gu ON gu.user_id = u.id
WHERE u.instance_id = $1 AND gu.guild_id = $2
)"#,
                instance.id,
                guild_id
            )
            .fetch_one(&app_state.pool)
            .await?;
            let users = row.array.unwrap_or_default();
            Ok((!users.is_empty(), users))
        }
        Dispatch::Ready { user, .. } => {
            // READY must only reach the instance the user belongs to; users of other
            // instances (or unknown users) are simply not relevant to this socket.
            let row = sqlx::query!(
                "SELECT remote_user_id FROM users WHERE id = $1 AND instance_id = $2",
                user.id,
                instance.id
            )
            .fetch_optional(&app_state.pool)
            .await?;
            Ok(match row {
                Some(row) => (true, vec![row.remote_user_id]),
                None => (false, vec![]),
            })
        }
    }
}
/// Unfinished stub, presumably meant to build the READY payload for `user_id` on
/// this socket's identified instance.
///
/// # Errors
/// Fails if the socket has no identified instance yet, or if the user lookup fails.
///
/// # Panics
/// Always panics (`todo!()`) once the user row has been fetched.
async fn collect_ready(
    app_state: Arc<AppState>,
    socket_state: Arc<RwLock<SocketState>>,
    user_id: String,
) -> Result<Payload> {
    let state = socket_state.read().await;
    let instance = state
        .instance
        .as_ref()
        .ok_or_else(|| eyre!("instance was None when it shouldn't be"))?;
    let _user = sqlx::query!(
        "SELECT * FROM users WHERE instance_id = $1 AND remote_user_id = $2",
        instance.id,
        user_id
    )
    .fetch_one(&app_state.pool)
    .await?;
    todo!()
}