feat: build out chat server websocket more, start identity websocket
This commit is contained in:
parent
18b644d24b
commit
f7494034d5
13 changed files with 300 additions and 24 deletions
38
Cargo.lock
generated
38
Cargo.lock
generated
|
@ -1172,6 +1172,7 @@ dependencies = [
|
|||
"color-eyre",
|
||||
"eyre",
|
||||
"foxchat",
|
||||
"futures",
|
||||
"rand",
|
||||
"reqwest",
|
||||
"rsa",
|
||||
|
@ -1180,6 +1181,7 @@ dependencies = [
|
|||
"sha256",
|
||||
"sqlx",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"toml",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
|
@ -1833,10 +1835,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"rustls-webpki",
|
||||
"rustls-webpki 0.101.7",
|
||||
"sct",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.22.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki 0.102.1",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "1.0.4"
|
||||
|
@ -1846,6 +1862,12 @@ dependencies = [
|
|||
"base64",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "048a63e5b3ac996d78d402940b5fa47973d2d080c6c6fffa1d0f19c4445310b7"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.101.7"
|
||||
|
@ -1856,6 +1878,17 @@ dependencies = [
|
|||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.102.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef4ca26037c909dedb327b48c3327d0ba91d3dd3c4e05dad328f210ffb68e95b"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.14"
|
||||
|
@ -2135,7 +2168,7 @@ dependencies = [
|
|||
"once_cell",
|
||||
"paste",
|
||||
"percent-encoding",
|
||||
"rustls",
|
||||
"rustls 0.21.10",
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -2484,6 +2517,7 @@ checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
|
|||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"rustls 0.22.2",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
|
|
@ -3,7 +3,7 @@ use std::sync::Arc;
|
|||
use axum::{Extension, Json};
|
||||
use foxchat::{
|
||||
http::ApiError,
|
||||
model::{http::guild::CreateGuildParams, user::PartialUser, Guild, channel::PartialChannel},
|
||||
model::{channel::PartialChannel, http::guild::CreateGuildParams, user::PartialUser, Guild},
|
||||
FoxError,
|
||||
};
|
||||
|
||||
|
@ -39,8 +39,12 @@ pub async fn post_guilds(
|
|||
instance: user.instance.domain,
|
||||
},
|
||||
default_channel: PartialChannel {
|
||||
id: channel.id.0.clone(),
|
||||
name: channel.name.clone(),
|
||||
},
|
||||
channels: Some(vec![PartialChannel {
|
||||
id: channel.id.0,
|
||||
name: channel.name,
|
||||
}
|
||||
}]),
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::sync::Arc;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
|
@ -8,8 +8,9 @@ use axum::{
|
|||
response::Response,
|
||||
Extension,
|
||||
};
|
||||
use eyre::Result;
|
||||
use eyre::{eyre, Result};
|
||||
use foxchat::{
|
||||
model::User,
|
||||
s2s::{Dispatch, Payload},
|
||||
signature::{parse_date, verify_signature},
|
||||
};
|
||||
|
@ -17,7 +18,11 @@ use futures::{
|
|||
stream::{SplitSink, SplitStream, StreamExt},
|
||||
SinkExt,
|
||||
};
|
||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||
use rand::Rng;
|
||||
use tokio::{
|
||||
sync::{broadcast, mpsc, RwLock},
|
||||
time::timeout,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use crate::{app_state::AppState, model::identity_instance::IdentityInstance};
|
||||
|
@ -28,6 +33,7 @@ pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppSt
|
|||
|
||||
struct SocketState {
|
||||
instance: Option<IdentityInstance>,
|
||||
last_heartbeat: u64,
|
||||
}
|
||||
|
||||
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
||||
|
@ -36,16 +42,21 @@ async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
|||
let (tx, rx) = mpsc::channel::<Payload>(10);
|
||||
let dispatch = app_state.broadcast.subscribe();
|
||||
let socket_state = Arc::new(RwLock::new(SocketState {
|
||||
instance: None, // Filled out after IDENTIFY
|
||||
// These are filled out after IDENTIFY
|
||||
instance: None,
|
||||
last_heartbeat: 0,
|
||||
}));
|
||||
|
||||
// Function that merges the socket-local and broadcast streams for sending.
|
||||
tokio::spawn(merge_channels(
|
||||
app_state.clone(),
|
||||
socket_state.clone(),
|
||||
dispatch,
|
||||
tx.clone(),
|
||||
));
|
||||
// Function that writes all payloads to the socket.
|
||||
tokio::spawn(write(rx, sender));
|
||||
// Function that reads from the socket.
|
||||
tokio::spawn(read(
|
||||
app_state.clone(),
|
||||
socket_state.clone(),
|
||||
|
@ -71,7 +82,8 @@ async fn read(
|
|||
server,
|
||||
signature,
|
||||
} => {
|
||||
let instance = match IdentityInstance::get(app_state, &server).await {
|
||||
let instance = match IdentityInstance::get(app_state.clone(), &server).await
|
||||
{
|
||||
Ok(i) => i,
|
||||
Err(e) => {
|
||||
error!("getting instance {}: {}", server, e);
|
||||
|
@ -161,20 +173,105 @@ async fn read(
|
|||
return;
|
||||
}
|
||||
|
||||
// Send hello
|
||||
tx.send(Payload::Hello {}).await.ok();
|
||||
// Generate heartbeat interval and send it in the Hello payload
|
||||
let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000);
|
||||
tx.send(Payload::Hello { heartbeat_interval }).await.ok();
|
||||
// Start the heartbeat loop
|
||||
let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10);
|
||||
tokio::spawn(heartbeat(
|
||||
socket_state.clone(),
|
||||
heartbeat_interval,
|
||||
heartbeat_rx,
|
||||
tx.clone(),
|
||||
));
|
||||
|
||||
while let Some(msg) = receiver.next().await {
|
||||
let msg = if let Ok(msg) = msg {
|
||||
msg
|
||||
} else {
|
||||
let Ok(msg) = msg else {
|
||||
return;
|
||||
};
|
||||
|
||||
// TODO: handle incoming payloads
|
||||
let Ok(msg) = msg.to_text() else {
|
||||
tx.send(Payload::Error {
|
||||
message: "Invalid message".into(),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(msg) = serde_json::from_str::<Payload>(msg) else {
|
||||
tx.send(Payload::Error {
|
||||
message: "Invalid message".into(),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return;
|
||||
};
|
||||
|
||||
match msg {
|
||||
Payload::Connect { user_id } => {
|
||||
match collect_ready(app_state.clone(), socket_state.clone(), &user_id).await {
|
||||
Ok(p) => {
|
||||
tx.send(p).await.ok();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error collecting ready event data for {}: {}", user_id, e);
|
||||
tx.send(Payload::Error {
|
||||
message: "Internal server error".into(),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
Payload::Heartbeat { timestamp } => {
|
||||
// Heartbeats are handled in another function
|
||||
heartbeat_tx.send(timestamp).await.ok();
|
||||
}
|
||||
_ => {
|
||||
tx.send(Payload::Error {
|
||||
message: "Invalid send event".into(),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn heartbeat(
|
||||
socket_state: Arc<RwLock<SocketState>>,
|
||||
heartbeat_interval: u64,
|
||||
mut rx: mpsc::Receiver<u64>,
|
||||
tx: mpsc::Sender<Payload>,
|
||||
) {
|
||||
// The timeout is twice the heartbeat interval. If no heartbeat is received by then, close the connection with an error.
|
||||
while let Ok(i) = timeout(Duration::from_millis(heartbeat_interval * 2), rx.recv()).await {
|
||||
match i {
|
||||
// ACK the heartbeat with the same timestamp. TODO: validate the timestamp to make sure we aren't too out of sync with the identity server.
|
||||
Some(timestamp) => {
|
||||
socket_state.write().await.last_heartbeat = timestamp;
|
||||
tx.send(Payload::HeartbeatAck { timestamp }).await.ok();
|
||||
}
|
||||
// If the channel returns None, that means it's been closed, which means the socket as a whole was closed
|
||||
None => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Send an error, which will automatically close the socket
|
||||
tx.send(Payload::Error {
|
||||
message: format!(
|
||||
"Did not receive a heartbeat after {}ms",
|
||||
heartbeat_interval * 2
|
||||
),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
|
||||
async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) {
|
||||
loop {
|
||||
if let Some(ev) = rx.recv().await {
|
||||
|
@ -289,10 +386,78 @@ async fn filter_events(
|
|||
socket_state: Arc<RwLock<SocketState>>,
|
||||
evt: &Dispatch,
|
||||
) -> Result<(bool, Vec<String>)> {
|
||||
// If we're not authenticated yet, don't send anything
|
||||
if socket_state.read().await.instance.is_none() {
|
||||
let Some(instance) = &socket_state.read().await.instance else {
|
||||
return Ok((false, vec![]));
|
||||
}
|
||||
};
|
||||
|
||||
Ok((true, vec![]))
|
||||
match evt {
|
||||
Dispatch::MessageCreate {
|
||||
id: _,
|
||||
channel_id: _,
|
||||
guild_id,
|
||||
author: _,
|
||||
content: _,
|
||||
created_at: _,
|
||||
} => {
|
||||
let users = sqlx::query!(
|
||||
r#"SELECT ARRAY(
|
||||
SELECT u.remote_user_id FROM users u
|
||||
JOIN guilds_users gu ON gu.user_id = u.id
|
||||
WHERE u.instance_id = $1 AND gu.guild_id = $2
|
||||
)"#,
|
||||
instance.id,
|
||||
guild_id
|
||||
)
|
||||
.fetch_one(&app_state.pool)
|
||||
.await?;
|
||||
|
||||
if let Some(users) = users.array {
|
||||
return Ok((users.len() > 0, users));
|
||||
}
|
||||
return Ok((false, vec![]));
|
||||
}
|
||||
Dispatch::Ready { user, guilds: _ } => {
|
||||
let user = sqlx::query!(
|
||||
"SELECT remote_user_id FROM users WHERE id = $1",
|
||||
user.id.clone()
|
||||
)
|
||||
.fetch_one(&app_state.pool)
|
||||
.await?;
|
||||
|
||||
return Ok((true, vec![user.remote_user_id]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_ready(
|
||||
app_state: Arc<AppState>,
|
||||
socket_state: Arc<RwLock<SocketState>>,
|
||||
user_id: &str,
|
||||
) -> Result<Payload> {
|
||||
let Some(instance) = &socket_state.read().await.instance else {
|
||||
return Err(eyre!("instance was None when it shouldn't be"));
|
||||
};
|
||||
|
||||
let user = sqlx::query!(
|
||||
r#"SELECT u.*, i.domain FROM users u
|
||||
JOIN identity_instances i ON i.id = u.instance_id
|
||||
WHERE u.instance_id = $1 AND u.remote_user_id = $2"#,
|
||||
instance.id,
|
||||
user_id
|
||||
)
|
||||
.fetch_one(&app_state.pool)
|
||||
.await?;
|
||||
|
||||
Ok(Payload::Dispatch {
|
||||
event: Dispatch::Ready {
|
||||
user: User {
|
||||
id: user.id,
|
||||
username: user.username,
|
||||
instance: user.domain,
|
||||
avatar_url: None,
|
||||
},
|
||||
guilds: vec![],
|
||||
},
|
||||
recipients: vec![user.remote_user_id],
|
||||
})
|
||||
}
|
||||
|
|
25
foxchat/src/c2s/event.rs
Normal file
25
foxchat/src/c2s/event.rs
Normal file
|
@ -0,0 +1,25 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::s2s::Dispatch;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum Payload {
|
||||
#[serde(rename = "D")]
|
||||
Dispatch {
|
||||
#[serde(rename = "e")]
|
||||
event: Dispatch,
|
||||
#[serde(rename = "s")]
|
||||
server_id: String,
|
||||
},
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
/// Hello message, sent after authentication succeeds
|
||||
Hello {
|
||||
guilds: Vec<String>,
|
||||
},
|
||||
Identify {
|
||||
token: String,
|
||||
},
|
||||
}
|
3
foxchat/src/c2s/mod.rs
Normal file
3
foxchat/src/c2s/mod.rs
Normal file
|
@ -0,0 +1,3 @@
|
|||
pub mod event;
|
||||
|
||||
pub use event::Payload;
|
|
@ -3,6 +3,7 @@ pub mod fed;
|
|||
pub mod http;
|
||||
pub mod model;
|
||||
pub mod s2s;
|
||||
pub mod c2s;
|
||||
pub mod id;
|
||||
|
||||
pub use error::FoxError;
|
||||
|
|
|
@ -8,7 +8,7 @@ pub struct Channel {
|
|||
pub topic: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct PartialChannel {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use super::{user::PartialUser, channel::PartialChannel};
|
||||
use super::{channel::PartialChannel, user::PartialUser};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Guild {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub owner: PartialUser,
|
||||
pub default_channel: PartialChannel,
|
||||
pub channels: Option<Vec<PartialChannel>>,
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
|
|||
|
||||
use crate::{
|
||||
id::GuildType,
|
||||
model::{user::PartialUser, Message},
|
||||
model::{user::PartialUser, Guild, Message, User},
|
||||
Id,
|
||||
};
|
||||
|
||||
|
@ -18,6 +18,10 @@ pub enum Dispatch {
|
|||
content: Option<String>,
|
||||
created_at: DateTime<Utc>,
|
||||
},
|
||||
Ready {
|
||||
user: User,
|
||||
guilds: Vec<Guild>,
|
||||
}
|
||||
}
|
||||
|
||||
impl Dispatch {
|
||||
|
|
|
@ -5,6 +5,7 @@ use super::Dispatch;
|
|||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum Payload {
|
||||
#[serde(rename = "D")]
|
||||
Dispatch {
|
||||
#[serde(rename = "e")]
|
||||
event: Dispatch,
|
||||
|
@ -15,7 +16,9 @@ pub enum Payload {
|
|||
message: String,
|
||||
},
|
||||
/// Hello message, sent after authentication succeeds
|
||||
Hello {},
|
||||
Hello {
|
||||
heartbeat_interval: u64,
|
||||
},
|
||||
/// S2S authentication. Fields correspond to headers (Host, Date, X-Foxchat-Server, X-Foxchat-Signature)
|
||||
Identify {
|
||||
host: String,
|
||||
|
@ -26,5 +29,15 @@ pub enum Payload {
|
|||
/// Sent when a user connects to the identity server's gateway, to signal the chat server to send READY for that user
|
||||
Connect {
|
||||
user_id: String,
|
||||
},
|
||||
/// Sent on a regular interval by the connecting server, to keep the connection alive.
|
||||
Heartbeat {
|
||||
#[serde(rename = "t")]
|
||||
timestamp: u64,
|
||||
},
|
||||
/// Sent in response to a Heartbeat.
|
||||
HeartbeatAck {
|
||||
#[serde(rename = "t")]
|
||||
timestamp: u64,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,3 +27,5 @@ base64 = "0.21.7"
|
|||
sha256 = "1.5.0"
|
||||
reqwest = { version = "0.11.23", features = ["json", "gzip", "brotli", "multipart"] }
|
||||
chrono = "0.4.31"
|
||||
futures = "0.3.30"
|
||||
tokio-tungstenite = { version = "0.21.0", features = ["rustls"] }
|
||||
|
|
|
@ -2,6 +2,7 @@ mod account;
|
|||
mod auth;
|
||||
mod node;
|
||||
mod proxy;
|
||||
mod ws;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
|
|
23
identity/src/http/ws/mod.rs
Normal file
23
identity/src/http/ws/mod.rs
Normal file
|
@ -0,0 +1,23 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use axum::{extract::{ws::WebSocket, WebSocketUpgrade}, response::Response, Extension};
|
||||
use futures::{
|
||||
stream::{SplitSink, SplitStream, StreamExt},
|
||||
SinkExt,
|
||||
};
|
||||
|
||||
use crate::{app_state::AppState, model::account::Account};
|
||||
|
||||
pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
|
||||
ws.on_upgrade(|socket| handle_socket(state, socket))
|
||||
}
|
||||
|
||||
struct SocketState {
|
||||
user: Option<Account>,
|
||||
}
|
||||
|
||||
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
|
||||
let (sender, receiver) = socket.split();
|
||||
|
||||
|
||||
}
|
Loading…
Reference in a new issue