maybe (barely) working websocket now?

This commit is contained in:
sam 2024-01-22 02:18:34 +01:00
parent f07333e358
commit 0858d4893a
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
5 changed files with 253 additions and 41 deletions

View file

@ -8,38 +8,55 @@ use axum::{
response::Response, response::Response,
Extension, Extension,
}; };
use foxchat::s2s::{Dispatch, Payload}; use eyre::Result;
use futures::stream::{SplitSink, SplitStream, StreamExt}; use foxchat::{
use tokio::sync::{broadcast, mpsc}; s2s::{Dispatch, Payload},
signature::{parse_date, verify_signature},
};
use futures::{
stream::{SplitSink, SplitStream, StreamExt},
SinkExt,
};
use tokio::sync::{broadcast, mpsc, RwLock};
use tracing::error; use tracing::error;
use crate::app_state::AppState; use crate::{app_state::AppState, model::identity_instance::IdentityInstance};
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response { pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
ws.on_upgrade(|socket| handle_socket(state, socket)) ws.on_upgrade(|socket| handle_socket(state, socket))
} }
struct SocketState { struct SocketState {
instance_id: Option<String>, instance: Option<IdentityInstance>,
} }
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) { async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
let (mut sender, mut receiver) = socket.split(); let (sender, receiver) = socket.split();
let (tx, rx) = mpsc::channel::<Payload>(10); let (tx, rx) = mpsc::channel::<Payload>(10);
let dispatch = app_state.broadcast.subscribe(); let dispatch = app_state.broadcast.subscribe();
let socket_state = Arc::new(SocketState { let socket_state = Arc::new(RwLock::new(SocketState {
instance_id: None, // Filled out after IDENTIFY instance: None, // Filled out after IDENTIFY
}); }));
tokio::spawn(merge_channels(app_state.clone(), socket_state.clone(), dispatch, tx.clone())); tokio::spawn(merge_channels(
tokio::spawn(write(app_state.clone(), socket_state.clone(), rx, sender)); app_state.clone(),
tokio::spawn(read(app_state.clone(), socket_state.clone(), tx.clone(), receiver)); socket_state.clone(),
dispatch,
tx.clone(),
));
tokio::spawn(write(rx, sender));
tokio::spawn(read(
app_state.clone(),
socket_state.clone(),
tx.clone(),
receiver,
));
} }
async fn read( async fn read(
app_state: Arc<AppState>, app_state: Arc<AppState>,
socket_state: Arc<SocketState>, socket_state: Arc<RwLock<SocketState>>,
tx: mpsc::Sender<Payload>, tx: mpsc::Sender<Payload>,
mut receiver: SplitStream<WebSocket>, mut receiver: SplitStream<WebSocket>,
) { ) {
@ -48,26 +65,105 @@ async fn read(
if let Ok(msg) = msg.into_text() { if let Ok(msg) = msg.into_text() {
if let Ok(pl) = serde_json::from_str::<Payload>(&msg) { if let Ok(pl) = serde_json::from_str::<Payload>(&msg) {
match pl { match pl {
Payload::Identify { host, date, server, signature } => { Payload::Identify {
// TODO: identify host,
}, date,
_ => { server,
tx.send(Payload::Error { message: "First payload was not IDENTIFY".into() }).await.ok(); signature,
} => {
let instance = match IdentityInstance::get(app_state, &server).await {
Ok(i) => i,
Err(e) => {
error!("getting instance {}: {}", server, e);
tx.send(Payload::Error {
message: "Unknown instance".into(),
})
.await
.ok();
return; return;
} }
} };
} else {
tx.send(Payload::Error { message: "Invalid JSON payload".into() }).await.ok();
return;
}
} else {
let public_key = match instance.parse_public_key() {
Ok(k) => k,
Err(e) => {
error!("parsing public key for instance: {}", e);
tx.send(Payload::Error {
message: "Internal server error".into(),
})
.await
.ok();
return;
}
};
let date = match parse_date(&date) {
Ok(d) => d,
Err(_) => {
tx.send(Payload::Error {
message: "Invalid date".into(),
})
.await
.ok();
return;
}
};
if let Err(e) = verify_signature(
&public_key,
signature,
date,
&host,
"/_fox/chat/ws",
None,
None,
) {
error!("Verifying signature in websocket from {}: {}", &server, e);
tx.send(Payload::Error {
message: "Invalid signature".into(),
})
.await
.ok();
return;
}
// Everything is good, store instance
socket_state.write().await.instance = Some(instance);
}
_ => {
tx.send(Payload::Error {
message: "First payload was not IDENTIFY".into(),
})
.await
.ok();
return;
}
}
} else {
tx.send(Payload::Error {
message: "Invalid JSON payload".into(),
})
.await
.ok();
return;
}
} else {
tx.send(Payload::Error {
message: "First payload was not IDENTIFY".into(),
})
.await
.ok();
return;
} }
} else { } else {
// Websocket closed, return // Websocket closed, return
return; return;
} }
// Send hello
tx.send(Payload::Hello {}).await.ok();
while let Some(msg) = receiver.next().await { while let Some(msg) = receiver.next().await {
let msg = if let Ok(msg) = msg { let msg = if let Ok(msg) = msg {
msg msg
@ -79,41 +175,124 @@ async fn read(
} }
} }
async fn write( async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) {
app_state: Arc<AppState>, loop {
socket_state: Arc<SocketState>, if let Some(ev) = rx.recv().await {
mut rx: mpsc::Receiver<Payload>, match &ev {
mut sender: SplitSink<WebSocket, Message>, Payload::Error { message: _ } => {
) { let msg = match to_json(ev) {
Ok(m) => m,
Err(e) => {
error!("error serializing payload to JSON: {}", e);
return;
}
};
match sender.send(msg).await {
Ok(_) => {}
// Pretty sure we can assume the websocket was closed, so just return
Err(e) => {
error!("error writing to websocket: {}", e);
return;
}
}
match sender.close().await {
Ok(_) => {}
Err(e) => {
error!("error closing websocket: {}", e)
}
}
return;
}
_ => {
// Forward payload
let msg = match to_json(ev) {
Ok(m) => m,
Err(e) => {
error!("error serializing payload to JSON: {}", e);
return;
}
};
match sender.send(msg).await {
Ok(_) => {}
// Pretty sure we can assume the websocket was closed, so just return
Err(e) => {
error!("error writing to websocket: {}", e);
return;
}
}
}
}
}
}
} }
async fn merge_channels( async fn merge_channels(
app_state: Arc<AppState>, app_state: Arc<AppState>,
socket_state: Arc<SocketState>, socket_state: Arc<RwLock<SocketState>>,
mut rx: broadcast::Receiver<Dispatch>, mut rx: broadcast::Receiver<Dispatch>,
tx: mpsc::Sender<Payload>, tx: mpsc::Sender<Payload>,
) { ) {
loop { loop {
let msg = rx.recv().await; let msg = rx.recv().await;
match msg { match msg {
Ok(p) => { Ok(evt) => {
// TODO: filter users let (send, recipients) =
tx.send(Payload::Dispatch { match filter_events(app_state.clone(), socket_state.clone(), &evt).await {
event: p, Ok(v) => v,
recipients: vec![], Err(e) => {
error!("Checking whether to send an event: {}", e);
tx.send(Payload::Error {
message: "Internal server error".into(),
}) })
.await .await
.ok(); .ok();
},
continue;
}
};
if !send {
continue;
}
tx.send(Payload::Dispatch {
event: evt,
recipients,
})
.await
.ok();
}
Err(e) => match e { Err(e) => match e {
broadcast::error::RecvError::Closed => { broadcast::error::RecvError::Closed => {
error!("Broadcast channel was closed, this is not supposed to happen"); error!("Broadcast channel was closed, this is not supposed to happen");
return; return;
}, }
broadcast::error::RecvError::Lagged(i) => { broadcast::error::RecvError::Lagged(i) => {
error!("Broadcast receive lagged by {i}"); error!("Broadcast receive lagged by {i}");
} }
},
} }
} }
} }
fn to_json(p: Payload) -> Result<Message> {
Ok(Message::Text(serde_json::to_string(&p)?))
}
/// Returns whether or not an event should be sent at all, and if it should be, which users to send it to.
async fn filter_events(
app_state: Arc<AppState>,
socket_state: Arc<RwLock<SocketState>>,
evt: &Dispatch,
) -> Result<(bool, Vec<String>)> {
// If we're not authenticated yet, don't send anything
if socket_state.read().await.instance.is_none() {
return Ok((false, vec![]));
}
Ok((true, vec![]))
} }

View file

@ -3,7 +3,7 @@ use serde::{Serialize, Deserialize};
use super::user::PartialUser; use super::user::PartialUser;
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message { pub struct Message {
pub id: String, pub id: String,
pub channel_id: String, pub channel_id: String,

View file

@ -1,6 +1,6 @@
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct User { pub struct User {
pub id: String, pub id: String,
pub username: String, pub username: String,
@ -8,7 +8,7 @@ pub struct User {
pub avatar_url: Option<String>, pub avatar_url: Option<String>,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PartialUser { pub struct PartialUser {
pub id: String, pub id: String,
pub username: String, pub username: String,

View file

@ -1,7 +1,34 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{
id::GuildType,
model::{user::PartialUser, Message},
Id,
};
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] #[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Dispatch { pub enum Dispatch {
MessageCreate MessageCreate {
id: String,
channel_id: String,
guild_id: String,
author: PartialUser,
content: Option<String>,
created_at: DateTime<Utc>,
},
}
impl Dispatch {
pub fn message_create(m: Message, guild_id: Id<GuildType>) -> Self {
Self::MessageCreate {
id: m.id,
channel_id: m.channel_id,
guild_id: guild_id.0,
author: m.author,
content: m.content,
created_at: m.created_at,
}
}
} }

View file

@ -11,14 +11,20 @@ pub enum Payload {
#[serde(rename = "r")] #[serde(rename = "r")]
recipients: Vec<String>, recipients: Vec<String>,
}, },
Hello, Error {
message: String,
},
/// Hello message, sent after authentication succeeds
Hello {},
/// S2S authentication. Fields correspond to headers (Host, Date, X-Foxchat-Server, X-Foxchat-Signature)
Identify { Identify {
host: String, host: String,
date: String, date: String,
server: String, server: String,
signature: String, signature: String,
}, },
Error { /// Sent when a user connects to the identity server's gateway, to signal the chat server to send READY for that user
message: String, Connect {
user_id: String,
} }
} }