feat: build out chat server websocket more, start identity websocket

This commit is contained in:
sam 2024-02-23 22:04:32 +01:00
parent 18b644d24b
commit f7494034d5
13 changed files with 300 additions and 24 deletions

38
Cargo.lock generated
View file

@ -1172,6 +1172,7 @@ dependencies = [
"color-eyre", "color-eyre",
"eyre", "eyre",
"foxchat", "foxchat",
"futures",
"rand", "rand",
"reqwest", "reqwest",
"rsa", "rsa",
@ -1180,6 +1181,7 @@ dependencies = [
"sha256", "sha256",
"sqlx", "sqlx",
"tokio", "tokio",
"tokio-tungstenite",
"toml", "toml",
"tower-http", "tower-http",
"tracing", "tracing",
@ -1833,10 +1835,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba"
dependencies = [ dependencies = [
"ring", "ring",
"rustls-webpki", "rustls-webpki 0.101.7",
"sct", "sct",
] ]
[[package]]
name = "rustls"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41"
dependencies = [
"log",
"ring",
"rustls-pki-types",
"rustls-webpki 0.102.1",
"subtle",
"zeroize",
]
[[package]] [[package]]
name = "rustls-pemfile" name = "rustls-pemfile"
version = "1.0.4" version = "1.0.4"
@ -1846,6 +1862,12 @@ dependencies = [
"base64", "base64",
] ]
[[package]]
name = "rustls-pki-types"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "048a63e5b3ac996d78d402940b5fa47973d2d080c6c6fffa1d0f19c4445310b7"
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
version = "0.101.7" version = "0.101.7"
@ -1856,6 +1878,17 @@ dependencies = [
"untrusted", "untrusted",
] ]
[[package]]
name = "rustls-webpki"
version = "0.102.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef4ca26037c909dedb327b48c3327d0ba91d3dd3c4e05dad328f210ffb68e95b"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
]
[[package]] [[package]]
name = "rustversion" name = "rustversion"
version = "1.0.14" version = "1.0.14"
@ -2135,7 +2168,7 @@ dependencies = [
"once_cell", "once_cell",
"paste", "paste",
"percent-encoding", "percent-encoding",
"rustls", "rustls 0.21.10",
"rustls-pemfile", "rustls-pemfile",
"serde", "serde",
"serde_json", "serde_json",
@ -2484,6 +2517,7 @@ checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [ dependencies = [
"futures-util", "futures-util",
"log", "log",
"rustls 0.22.2",
"tokio", "tokio",
"tungstenite", "tungstenite",
] ]

View file

@ -3,7 +3,7 @@ use std::sync::Arc;
use axum::{Extension, Json}; use axum::{Extension, Json};
use foxchat::{ use foxchat::{
http::ApiError, http::ApiError,
model::{http::guild::CreateGuildParams, user::PartialUser, Guild, channel::PartialChannel}, model::{channel::PartialChannel, http::guild::CreateGuildParams, user::PartialUser, Guild},
FoxError, FoxError,
}; };
@ -39,8 +39,12 @@ pub async fn post_guilds(
instance: user.instance.domain, instance: user.instance.domain,
}, },
default_channel: PartialChannel { default_channel: PartialChannel {
id: channel.id.0.clone(),
name: channel.name.clone(),
},
channels: Some(vec![PartialChannel {
id: channel.id.0, id: channel.id.0,
name: channel.name, name: channel.name,
} }]),
})) }))
} }

View file

@ -1,4 +1,4 @@
use std::sync::Arc; use std::{sync::Arc, time::Duration};
use axum::{ use axum::{
extract::{ extract::{
@ -8,8 +8,9 @@ use axum::{
response::Response, response::Response,
Extension, Extension,
}; };
use eyre::Result; use eyre::{eyre, Result};
use foxchat::{ use foxchat::{
model::User,
s2s::{Dispatch, Payload}, s2s::{Dispatch, Payload},
signature::{parse_date, verify_signature}, signature::{parse_date, verify_signature},
}; };
@ -17,7 +18,11 @@ use futures::{
stream::{SplitSink, SplitStream, StreamExt}, stream::{SplitSink, SplitStream, StreamExt},
SinkExt, SinkExt,
}; };
use tokio::sync::{broadcast, mpsc, RwLock}; use rand::Rng;
use tokio::{
sync::{broadcast, mpsc, RwLock},
time::timeout,
};
use tracing::error; use tracing::error;
use crate::{app_state::AppState, model::identity_instance::IdentityInstance}; use crate::{app_state::AppState, model::identity_instance::IdentityInstance};
@ -28,6 +33,7 @@ pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppSt
struct SocketState { struct SocketState {
instance: Option<IdentityInstance>, instance: Option<IdentityInstance>,
last_heartbeat: u64,
} }
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) { async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
@ -36,16 +42,21 @@ async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
let (tx, rx) = mpsc::channel::<Payload>(10); let (tx, rx) = mpsc::channel::<Payload>(10);
let dispatch = app_state.broadcast.subscribe(); let dispatch = app_state.broadcast.subscribe();
let socket_state = Arc::new(RwLock::new(SocketState { let socket_state = Arc::new(RwLock::new(SocketState {
instance: None, // Filled out after IDENTIFY // These are filled out after IDENTIFY
instance: None,
last_heartbeat: 0,
})); }));
// Function that merges the socket-local and broadcast streams for sending.
tokio::spawn(merge_channels( tokio::spawn(merge_channels(
app_state.clone(), app_state.clone(),
socket_state.clone(), socket_state.clone(),
dispatch, dispatch,
tx.clone(), tx.clone(),
)); ));
// Function that writes all payloads to the socket.
tokio::spawn(write(rx, sender)); tokio::spawn(write(rx, sender));
// Function that reads from the socket.
tokio::spawn(read( tokio::spawn(read(
app_state.clone(), app_state.clone(),
socket_state.clone(), socket_state.clone(),
@ -71,7 +82,8 @@ async fn read(
server, server,
signature, signature,
} => { } => {
let instance = match IdentityInstance::get(app_state, &server).await { let instance = match IdentityInstance::get(app_state.clone(), &server).await
{
Ok(i) => i, Ok(i) => i,
Err(e) => { Err(e) => {
error!("getting instance {}: {}", server, e); error!("getting instance {}: {}", server, e);
@ -161,18 +173,103 @@ async fn read(
return; return;
} }
// Send hello // Generate heartbeat interval and send it in the Hello payload
tx.send(Payload::Hello {}).await.ok(); let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000);
tx.send(Payload::Hello { heartbeat_interval }).await.ok();
// Start the heartbeat loop
let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10);
tokio::spawn(heartbeat(
socket_state.clone(),
heartbeat_interval,
heartbeat_rx,
tx.clone(),
));
while let Some(msg) = receiver.next().await { while let Some(msg) = receiver.next().await {
let msg = if let Ok(msg) = msg { let Ok(msg) = msg else {
msg
} else {
return; return;
}; };
// TODO: handle incoming payloads let Ok(msg) = msg.to_text() else {
tx.send(Payload::Error {
message: "Invalid message".into(),
})
.await
.ok();
return;
};
let Ok(msg) = serde_json::from_str::<Payload>(msg) else {
tx.send(Payload::Error {
message: "Invalid message".into(),
})
.await
.ok();
return;
};
match msg {
Payload::Connect { user_id } => {
match collect_ready(app_state.clone(), socket_state.clone(), &user_id).await {
Ok(p) => {
tx.send(p).await.ok();
} }
Err(e) => {
error!("Error collecting ready event data for {}: {}", user_id, e);
tx.send(Payload::Error {
message: "Internal server error".into(),
})
.await
.ok();
}
}
}
Payload::Heartbeat { timestamp } => {
// Heartbeats are handled in another function
heartbeat_tx.send(timestamp).await.ok();
}
_ => {
tx.send(Payload::Error {
message: "Invalid send event".into(),
})
.await
.ok();
return;
}
}
}
}
async fn heartbeat(
socket_state: Arc<RwLock<SocketState>>,
heartbeat_interval: u64,
mut rx: mpsc::Receiver<u64>,
tx: mpsc::Sender<Payload>,
) {
// The timeout is twice the heartbeat interval. If no heartbeat is received by then, close the connection with an error.
while let Ok(i) = timeout(Duration::from_millis(heartbeat_interval * 2), rx.recv()).await {
match i {
// ACK the heartbeat with the same timestamp. TODO: validate the timestamp to make sure we aren't too out of sync with the identity server.
Some(timestamp) => {
socket_state.write().await.last_heartbeat = timestamp;
tx.send(Payload::HeartbeatAck { timestamp }).await.ok();
}
// If the channel returns None, that means it's been closed, which means the socket as a whole was closed
None => {
return;
}
};
}
// Send an error, which will automatically close the socket
tx.send(Payload::Error {
message: format!(
"Did not receive a heartbeat after {}ms",
heartbeat_interval * 2
),
})
.await
.ok();
} }
async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) { async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) {
@ -289,10 +386,78 @@ async fn filter_events(
socket_state: Arc<RwLock<SocketState>>, socket_state: Arc<RwLock<SocketState>>,
evt: &Dispatch, evt: &Dispatch,
) -> Result<(bool, Vec<String>)> { ) -> Result<(bool, Vec<String>)> {
// If we're not authenticated yet, don't send anything let Some(instance) = &socket_state.read().await.instance else {
if socket_state.read().await.instance.is_none() { return Ok((false, vec![]));
};
match evt {
Dispatch::MessageCreate {
id: _,
channel_id: _,
guild_id,
author: _,
content: _,
created_at: _,
} => {
let users = sqlx::query!(
r#"SELECT ARRAY(
SELECT u.remote_user_id FROM users u
JOIN guilds_users gu ON gu.user_id = u.id
WHERE u.instance_id = $1 AND gu.guild_id = $2
)"#,
instance.id,
guild_id
)
.fetch_one(&app_state.pool)
.await?;
if let Some(users) = users.array {
return Ok((users.len() > 0, users));
}
return Ok((false, vec![])); return Ok((false, vec![]));
} }
Dispatch::Ready { user, guilds: _ } => {
let user = sqlx::query!(
"SELECT remote_user_id FROM users WHERE id = $1",
user.id.clone()
)
.fetch_one(&app_state.pool)
.await?;
Ok((true, vec![])) return Ok((true, vec![user.remote_user_id]));
}
}
}
async fn collect_ready(
app_state: Arc<AppState>,
socket_state: Arc<RwLock<SocketState>>,
user_id: &str,
) -> Result<Payload> {
let Some(instance) = &socket_state.read().await.instance else {
return Err(eyre!("instance was None when it shouldn't be"));
};
let user = sqlx::query!(
r#"SELECT u.*, i.domain FROM users u
JOIN identity_instances i ON i.id = u.instance_id
WHERE u.instance_id = $1 AND u.remote_user_id = $2"#,
instance.id,
user_id
)
.fetch_one(&app_state.pool)
.await?;
Ok(Payload::Dispatch {
event: Dispatch::Ready {
user: User {
id: user.id,
username: user.username,
instance: user.domain,
avatar_url: None,
},
guilds: vec![],
},
recipients: vec![user.remote_user_id],
})
} }

25
foxchat/src/c2s/event.rs Normal file
View file

@ -0,0 +1,25 @@
use serde::{Deserialize, Serialize};
use crate::s2s::Dispatch;
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Payload {
#[serde(rename = "D")]
Dispatch {
#[serde(rename = "e")]
event: Dispatch,
#[serde(rename = "s")]
server_id: String,
},
Error {
message: String,
},
/// Hello message, sent after authentication succeeds
Hello {
guilds: Vec<String>,
},
Identify {
token: String,
},
}

3
foxchat/src/c2s/mod.rs Normal file
View file

@ -0,0 +1,3 @@
pub mod event;
pub use event::Payload;

View file

@ -3,6 +3,7 @@ pub mod fed;
pub mod http; pub mod http;
pub mod model; pub mod model;
pub mod s2s; pub mod s2s;
pub mod c2s;
pub mod id; pub mod id;
pub use error::FoxError; pub use error::FoxError;

View file

@ -8,7 +8,7 @@ pub struct Channel {
pub topic: Option<String>, pub topic: Option<String>,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PartialChannel { pub struct PartialChannel {
pub id: String, pub id: String,
pub name: String, pub name: String,

View file

@ -1,11 +1,12 @@
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use super::{user::PartialUser, channel::PartialChannel}; use super::{channel::PartialChannel, user::PartialUser};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Guild { pub struct Guild {
pub id: String, pub id: String,
pub name: String, pub name: String,
pub owner: PartialUser, pub owner: PartialUser,
pub default_channel: PartialChannel, pub default_channel: PartialChannel,
pub channels: Option<Vec<PartialChannel>>,
} }

View file

@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
id::GuildType, id::GuildType,
model::{user::PartialUser, Message}, model::{user::PartialUser, Guild, Message, User},
Id, Id,
}; };
@ -18,6 +18,10 @@ pub enum Dispatch {
content: Option<String>, content: Option<String>,
created_at: DateTime<Utc>, created_at: DateTime<Utc>,
}, },
Ready {
user: User,
guilds: Vec<Guild>,
}
} }
impl Dispatch { impl Dispatch {

View file

@ -5,6 +5,7 @@ use super::Dispatch;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] #[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Payload { pub enum Payload {
#[serde(rename = "D")]
Dispatch { Dispatch {
#[serde(rename = "e")] #[serde(rename = "e")]
event: Dispatch, event: Dispatch,
@ -15,7 +16,9 @@ pub enum Payload {
message: String, message: String,
}, },
/// Hello message, sent after authentication succeeds /// Hello message, sent after authentication succeeds
Hello {}, Hello {
heartbeat_interval: u64,
},
/// S2S authentication. Fields correspond to headers (Host, Date, X-Foxchat-Server, X-Foxchat-Signature) /// S2S authentication. Fields correspond to headers (Host, Date, X-Foxchat-Server, X-Foxchat-Signature)
Identify { Identify {
host: String, host: String,
@ -26,5 +29,15 @@ pub enum Payload {
/// Sent when a user connects to the identity server's gateway, to signal the chat server to send READY for that user /// Sent when a user connects to the identity server's gateway, to signal the chat server to send READY for that user
Connect { Connect {
user_id: String, user_id: String,
},
/// Sent on a regular interval by the connecting server, to keep the connection alive.
Heartbeat {
#[serde(rename = "t")]
timestamp: u64,
},
/// Sent in response to a Heartbeat.
HeartbeatAck {
#[serde(rename = "t")]
timestamp: u64,
} }
} }

View file

@ -27,3 +27,5 @@ base64 = "0.21.7"
sha256 = "1.5.0" sha256 = "1.5.0"
reqwest = { version = "0.11.23", features = ["json", "gzip", "brotli", "multipart"] } reqwest = { version = "0.11.23", features = ["json", "gzip", "brotli", "multipart"] }
chrono = "0.4.31" chrono = "0.4.31"
futures = "0.3.30"
tokio-tungstenite = { version = "0.21.0", features = ["rustls"] }

View file

@ -2,6 +2,7 @@ mod account;
mod auth; mod auth;
mod node; mod node;
mod proxy; mod proxy;
mod ws;
use std::sync::Arc; use std::sync::Arc;

View file

@ -0,0 +1,23 @@
use std::sync::Arc;
use axum::{extract::{ws::WebSocket, WebSocketUpgrade}, response::Response, Extension};
use futures::{
stream::{SplitSink, SplitStream, StreamExt},
SinkExt,
};
use crate::{app_state::AppState, model::account::Account};
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
ws.on_upgrade(|socket| handle_socket(state, socket))
}
struct SocketState {
user: Option<Account>,
}
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
let (sender, receiver) = socket.split();
}