maybe (barely) working websocket now?

This commit is contained in:
sam 2024-01-22 02:18:34 +01:00
parent f07333e358
commit 0858d4893a
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
5 changed files with 253 additions and 41 deletions

View file

@ -8,38 +8,55 @@ use axum::{
response::Response,
Extension,
};
use foxchat::s2s::{Dispatch, Payload};
use futures::stream::{SplitSink, SplitStream, StreamExt};
use tokio::sync::{broadcast, mpsc};
use eyre::Result;
use foxchat::{
s2s::{Dispatch, Payload},
signature::{parse_date, verify_signature},
};
use futures::{
stream::{SplitSink, SplitStream, StreamExt},
SinkExt,
};
use tokio::sync::{broadcast, mpsc, RwLock};
use tracing::error;
use crate::app_state::AppState;
use crate::{app_state::AppState, model::identity_instance::IdentityInstance};
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
ws.on_upgrade(|socket| handle_socket(state, socket))
}
struct SocketState {
instance_id: Option<String>,
instance: Option<IdentityInstance>,
}
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
let (mut sender, mut receiver) = socket.split();
let (sender, receiver) = socket.split();
let (tx, rx) = mpsc::channel::<Payload>(10);
let dispatch = app_state.broadcast.subscribe();
let socket_state = Arc::new(SocketState {
instance_id: None, // Filled out after IDENTIFY
});
let socket_state = Arc::new(RwLock::new(SocketState {
instance: None, // Filled out after IDENTIFY
}));
tokio::spawn(merge_channels(app_state.clone(), socket_state.clone(), dispatch, tx.clone()));
tokio::spawn(write(app_state.clone(), socket_state.clone(), rx, sender));
tokio::spawn(read(app_state.clone(), socket_state.clone(), tx.clone(), receiver));
tokio::spawn(merge_channels(
app_state.clone(),
socket_state.clone(),
dispatch,
tx.clone(),
));
tokio::spawn(write(rx, sender));
tokio::spawn(read(
app_state.clone(),
socket_state.clone(),
tx.clone(),
receiver,
));
}
async fn read(
app_state: Arc<AppState>,
socket_state: Arc<SocketState>,
socket_state: Arc<RwLock<SocketState>>,
tx: mpsc::Sender<Payload>,
mut receiver: SplitStream<WebSocket>,
) {
@ -48,26 +65,105 @@ async fn read(
if let Ok(msg) = msg.into_text() {
if let Ok(pl) = serde_json::from_str::<Payload>(&msg) {
match pl {
Payload::Identify { host, date, server, signature } => {
// TODO: identify
},
Payload::Identify {
host,
date,
server,
signature,
} => {
let instance = match IdentityInstance::get(app_state, &server).await {
Ok(i) => i,
Err(e) => {
error!("getting instance {}: {}", server, e);
tx.send(Payload::Error {
message: "Unknown instance".into(),
})
.await
.ok();
return;
}
};
let public_key = match instance.parse_public_key() {
Ok(k) => k,
Err(e) => {
error!("parsing public key for instance: {}", e);
tx.send(Payload::Error {
message: "Internal server error".into(),
})
.await
.ok();
return;
}
};
let date = match parse_date(&date) {
Ok(d) => d,
Err(_) => {
tx.send(Payload::Error {
message: "Invalid date".into(),
})
.await
.ok();
return;
}
};
if let Err(e) = verify_signature(
&public_key,
signature,
date,
&host,
"/_fox/chat/ws",
None,
None,
) {
error!("Verifying signature in websocket from {}: {}", &server, e);
tx.send(Payload::Error {
message: "Invalid signature".into(),
})
.await
.ok();
return;
}
// Everything is good, store instance
socket_state.write().await.instance = Some(instance);
}
_ => {
tx.send(Payload::Error { message: "First payload was not IDENTIFY".into() }).await.ok();
tx.send(Payload::Error {
message: "First payload was not IDENTIFY".into(),
})
.await
.ok();
return;
}
}
} else {
tx.send(Payload::Error { message: "Invalid JSON payload".into() }).await.ok();
tx.send(Payload::Error {
message: "Invalid JSON payload".into(),
})
.await
.ok();
return;
}
} else {
tx.send(Payload::Error {
message: "First payload was not IDENTIFY".into(),
})
.await
.ok();
return;
}
} else {
// Websocket closed, return
return;
}
// Send hello
tx.send(Payload::Hello {}).await.ok();
while let Some(msg) = receiver.next().await {
let msg = if let Ok(msg) = msg {
msg
@ -79,41 +175,124 @@ async fn read(
}
}
async fn write(
app_state: Arc<AppState>,
socket_state: Arc<SocketState>,
mut rx: mpsc::Receiver<Payload>,
mut sender: SplitSink<WebSocket, Message>,
) {
async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) {
loop {
if let Some(ev) = rx.recv().await {
match &ev {
Payload::Error { message: _ } => {
let msg = match to_json(ev) {
Ok(m) => m,
Err(e) => {
error!("error serializing payload to JSON: {}", e);
return;
}
};
match sender.send(msg).await {
Ok(_) => {}
// Pretty sure we can assume the websocket was closed, so just return
Err(e) => {
error!("error writing to websocket: {}", e);
return;
}
}
match sender.close().await {
Ok(_) => {}
Err(e) => {
error!("error closing websocket: {}", e)
}
}
return;
}
_ => {
// Forward payload
let msg = match to_json(ev) {
Ok(m) => m,
Err(e) => {
error!("error serializing payload to JSON: {}", e);
return;
}
};
match sender.send(msg).await {
Ok(_) => {}
// Pretty sure we can assume the websocket was closed, so just return
Err(e) => {
error!("error writing to websocket: {}", e);
return;
}
}
}
}
}
}
}
async fn merge_channels(
app_state: Arc<AppState>,
socket_state: Arc<SocketState>,
socket_state: Arc<RwLock<SocketState>>,
mut rx: broadcast::Receiver<Dispatch>,
tx: mpsc::Sender<Payload>,
) {
loop {
let msg = rx.recv().await;
match msg {
Ok(p) => {
// TODO: filter users
Ok(evt) => {
let (send, recipients) =
match filter_events(app_state.clone(), socket_state.clone(), &evt).await {
Ok(v) => v,
Err(e) => {
error!("Checking whether to send an event: {}", e);
tx.send(Payload::Error {
message: "Internal server error".into(),
})
.await
.ok();
continue;
}
};
if !send {
continue;
}
tx.send(Payload::Dispatch {
event: p,
recipients: vec![],
event: evt,
recipients,
})
.await
.ok();
},
}
Err(e) => match e {
broadcast::error::RecvError::Closed => {
error!("Broadcast channel was closed, this is not supposed to happen");
return;
},
}
broadcast::error::RecvError::Lagged(i) => {
error!("Broadcast receive lagged by {i}");
}
}
},
}
}
}
fn to_json(p: Payload) -> Result<Message> {
Ok(Message::Text(serde_json::to_string(&p)?))
}
/// Returns whether or not an event should be sent at all, and if it should be, which users to send it to.
async fn filter_events(
app_state: Arc<AppState>,
socket_state: Arc<RwLock<SocketState>>,
evt: &Dispatch,
) -> Result<(bool, Vec<String>)> {
// If we're not authenticated yet, don't send anything
if socket_state.read().await.instance.is_none() {
return Ok((false, vec![]));
}
Ok((true, vec![]))
}

View file

@ -3,7 +3,7 @@ use serde::{Serialize, Deserialize};
use super::user::PartialUser;
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
pub id: String,
pub channel_id: String,

View file

@ -1,6 +1,6 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct User {
pub id: String,
pub username: String,
@ -8,7 +8,7 @@ pub struct User {
pub avatar_url: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PartialUser {
pub id: String,
pub username: String,

View file

@ -1,7 +1,34 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::{
id::GuildType,
model::{user::PartialUser, Message},
Id,
};
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Dispatch {
MessageCreate
MessageCreate {
id: String,
channel_id: String,
guild_id: String,
author: PartialUser,
content: Option<String>,
created_at: DateTime<Utc>,
},
}
impl Dispatch {
pub fn message_create(m: Message, guild_id: Id<GuildType>) -> Self {
Self::MessageCreate {
id: m.id,
channel_id: m.channel_id,
guild_id: guild_id.0,
author: m.author,
content: m.content,
created_at: m.created_at,
}
}
}

View file

@ -11,14 +11,20 @@ pub enum Payload {
#[serde(rename = "r")]
recipients: Vec<String>,
},
Hello,
Error {
message: String,
},
/// Hello message, sent after authentication succeeds
Hello {},
/// S2S authentication. Fields correspond to headers (Host, Date, X-Foxchat-Server, X-Foxchat-Signature)
Identify {
host: String,
date: String,
server: String,
signature: String,
},
Error {
message: String,
/// Sent when a user connects to the identity server's gateway, to signal the chat server to send READY for that user
Connect {
user_id: String,
}
}