feat: build out chat server websocket more, start identity websocket
This commit is contained in:
		
							parent
							
								
									18b644d24b
								
							
						
					
					
						commit
						f7494034d5
					
				
					 13 changed files with 300 additions and 24 deletions
				
			
		
							
								
								
									
										38
									
								
								Cargo.lock
									
										
									
										generated
									
									
									
								
							
							
						
						
									
										38
									
								
								Cargo.lock
									
										
									
										generated
									
									
									
								
							|  | @ -1172,6 +1172,7 @@ dependencies = [ | |||
|  "color-eyre", | ||||
|  "eyre", | ||||
|  "foxchat", | ||||
|  "futures", | ||||
|  "rand", | ||||
|  "reqwest", | ||||
|  "rsa", | ||||
|  | @ -1180,6 +1181,7 @@ dependencies = [ | |||
|  "sha256", | ||||
|  "sqlx", | ||||
|  "tokio", | ||||
|  "tokio-tungstenite", | ||||
|  "toml", | ||||
|  "tower-http", | ||||
|  "tracing", | ||||
|  | @ -1833,10 +1835,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" | |||
| checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" | ||||
| dependencies = [ | ||||
|  "ring", | ||||
|  "rustls-webpki", | ||||
|  "rustls-webpki 0.101.7", | ||||
|  "sct", | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "rustls" | ||||
| version = "0.22.2" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" | ||||
| dependencies = [ | ||||
|  "log", | ||||
|  "ring", | ||||
|  "rustls-pki-types", | ||||
|  "rustls-webpki 0.102.1", | ||||
|  "subtle", | ||||
|  "zeroize", | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "rustls-pemfile" | ||||
| version = "1.0.4" | ||||
|  | @ -1846,6 +1862,12 @@ dependencies = [ | |||
|  "base64", | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "rustls-pki-types" | ||||
| version = "1.3.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "048a63e5b3ac996d78d402940b5fa47973d2d080c6c6fffa1d0f19c4445310b7" | ||||
| 
 | ||||
| [[package]] | ||||
| name = "rustls-webpki" | ||||
| version = "0.101.7" | ||||
|  | @ -1856,6 +1878,17 @@ dependencies = [ | |||
|  "untrusted", | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "rustls-webpki" | ||||
| version = "0.102.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "ef4ca26037c909dedb327b48c3327d0ba91d3dd3c4e05dad328f210ffb68e95b" | ||||
| dependencies = [ | ||||
|  "ring", | ||||
|  "rustls-pki-types", | ||||
|  "untrusted", | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "rustversion" | ||||
| version = "1.0.14" | ||||
|  | @ -2135,7 +2168,7 @@ dependencies = [ | |||
|  "once_cell", | ||||
|  "paste", | ||||
|  "percent-encoding", | ||||
|  "rustls", | ||||
|  "rustls 0.21.10", | ||||
|  "rustls-pemfile", | ||||
|  "serde", | ||||
|  "serde_json", | ||||
|  | @ -2484,6 +2517,7 @@ checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" | |||
| dependencies = [ | ||||
|  "futures-util", | ||||
|  "log", | ||||
|  "rustls 0.22.2", | ||||
|  "tokio", | ||||
|  "tungstenite", | ||||
| ] | ||||
|  |  | |||
|  | @ -3,7 +3,7 @@ use std::sync::Arc; | |||
| use axum::{Extension, Json}; | ||||
| use foxchat::{ | ||||
|     http::ApiError, | ||||
|     model::{http::guild::CreateGuildParams, user::PartialUser, Guild, channel::PartialChannel}, | ||||
|     model::{channel::PartialChannel, http::guild::CreateGuildParams, user::PartialUser, Guild}, | ||||
|     FoxError, | ||||
| }; | ||||
| 
 | ||||
|  | @ -39,8 +39,12 @@ pub async fn post_guilds( | |||
|             instance: user.instance.domain, | ||||
|         }, | ||||
|         default_channel: PartialChannel { | ||||
|             id: channel.id.0.clone(), | ||||
|             name: channel.name.clone(), | ||||
|         }, | ||||
|         channels: Some(vec![PartialChannel { | ||||
|             id: channel.id.0, | ||||
|             name: channel.name, | ||||
|         } | ||||
|         }]), | ||||
|     })) | ||||
| } | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| use std::sync::Arc; | ||||
| use std::{sync::Arc, time::Duration}; | ||||
| 
 | ||||
| use axum::{ | ||||
|     extract::{ | ||||
|  | @ -8,8 +8,9 @@ use axum::{ | |||
|     response::Response, | ||||
|     Extension, | ||||
| }; | ||||
| use eyre::Result; | ||||
| use eyre::{eyre, Result}; | ||||
| use foxchat::{ | ||||
|     model::User, | ||||
|     s2s::{Dispatch, Payload}, | ||||
|     signature::{parse_date, verify_signature}, | ||||
| }; | ||||
|  | @ -17,7 +18,11 @@ use futures::{ | |||
|     stream::{SplitSink, SplitStream, StreamExt}, | ||||
|     SinkExt, | ||||
| }; | ||||
| use tokio::sync::{broadcast, mpsc, RwLock}; | ||||
| use rand::Rng; | ||||
| use tokio::{ | ||||
|     sync::{broadcast, mpsc, RwLock}, | ||||
|     time::timeout, | ||||
| }; | ||||
| use tracing::error; | ||||
| 
 | ||||
| use crate::{app_state::AppState, model::identity_instance::IdentityInstance}; | ||||
|  | @ -28,6 +33,7 @@ pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppSt | |||
| 
 | ||||
| struct SocketState { | ||||
|     instance: Option<IdentityInstance>, | ||||
|     last_heartbeat: u64, | ||||
| } | ||||
| 
 | ||||
| async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) { | ||||
|  | @ -36,16 +42,21 @@ async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) { | |||
|     let (tx, rx) = mpsc::channel::<Payload>(10); | ||||
|     let dispatch = app_state.broadcast.subscribe(); | ||||
|     let socket_state = Arc::new(RwLock::new(SocketState { | ||||
|         instance: None, // Filled out after IDENTIFY
 | ||||
|         // These are filled out after IDENTIFY
 | ||||
|         instance: None, | ||||
|         last_heartbeat: 0, | ||||
|     })); | ||||
| 
 | ||||
|     // Function that merges the socket-local and broadcast streams for sending.
 | ||||
|     tokio::spawn(merge_channels( | ||||
|         app_state.clone(), | ||||
|         socket_state.clone(), | ||||
|         dispatch, | ||||
|         tx.clone(), | ||||
|     )); | ||||
|     // Function that writes all payloads to the socket.
 | ||||
|     tokio::spawn(write(rx, sender)); | ||||
|     // Function that reads from the socket.
 | ||||
|     tokio::spawn(read( | ||||
|         app_state.clone(), | ||||
|         socket_state.clone(), | ||||
|  | @ -71,7 +82,8 @@ async fn read( | |||
|                         server, | ||||
|                         signature, | ||||
|                     } => { | ||||
|                         let instance = match IdentityInstance::get(app_state, &server).await { | ||||
|                         let instance = match IdentityInstance::get(app_state.clone(), &server).await | ||||
|                         { | ||||
|                             Ok(i) => i, | ||||
|                             Err(e) => { | ||||
|                                 error!("getting instance {}: {}", server, e); | ||||
|  | @ -161,20 +173,105 @@ async fn read( | |||
|         return; | ||||
|     } | ||||
| 
 | ||||
|     // Send hello
 | ||||
|     tx.send(Payload::Hello {}).await.ok(); | ||||
|     // Generate heartbeat interval and send it in the Hello payload
 | ||||
|     let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000); | ||||
|     tx.send(Payload::Hello { heartbeat_interval }).await.ok(); | ||||
|     // Start the heartbeat loop
 | ||||
|     let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10); | ||||
|     tokio::spawn(heartbeat( | ||||
|         socket_state.clone(), | ||||
|         heartbeat_interval, | ||||
|         heartbeat_rx, | ||||
|         tx.clone(), | ||||
|     )); | ||||
| 
 | ||||
|     while let Some(msg) = receiver.next().await { | ||||
|         let msg = if let Ok(msg) = msg { | ||||
|             msg | ||||
|         } else { | ||||
|         let Ok(msg) = msg else { | ||||
|             return; | ||||
|         }; | ||||
| 
 | ||||
|         // TODO: handle incoming payloads
 | ||||
|         let Ok(msg) = msg.to_text() else { | ||||
|             tx.send(Payload::Error { | ||||
|                 message: "Invalid message".into(), | ||||
|             }) | ||||
|             .await | ||||
|             .ok(); | ||||
|             return; | ||||
|         }; | ||||
| 
 | ||||
|         let Ok(msg) = serde_json::from_str::<Payload>(msg) else { | ||||
|             tx.send(Payload::Error { | ||||
|                 message: "Invalid message".into(), | ||||
|             }) | ||||
|             .await | ||||
|             .ok(); | ||||
|             return; | ||||
|         }; | ||||
| 
 | ||||
|         match msg { | ||||
|             Payload::Connect { user_id } => { | ||||
|                 match collect_ready(app_state.clone(), socket_state.clone(), &user_id).await { | ||||
|                     Ok(p) => { | ||||
|                         tx.send(p).await.ok(); | ||||
|                     } | ||||
|                     Err(e) => { | ||||
|                         error!("Error collecting ready event data for {}: {}", user_id, e); | ||||
|                         tx.send(Payload::Error { | ||||
|                             message: "Internal server error".into(), | ||||
|                         }) | ||||
|                         .await | ||||
|                         .ok(); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             Payload::Heartbeat { timestamp } => { | ||||
|                 // Heartbeats are handled in another function
 | ||||
|                 heartbeat_tx.send(timestamp).await.ok(); | ||||
|             } | ||||
|             _ => { | ||||
|                 tx.send(Payload::Error { | ||||
|                     message: "Invalid send event".into(), | ||||
|                 }) | ||||
|                 .await | ||||
|                 .ok(); | ||||
|                 return; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| async fn heartbeat( | ||||
|     socket_state: Arc<RwLock<SocketState>>, | ||||
|     heartbeat_interval: u64, | ||||
|     mut rx: mpsc::Receiver<u64>, | ||||
|     tx: mpsc::Sender<Payload>, | ||||
| ) { | ||||
|     // The timeout is twice the heartbeat interval. If no heartbeat is received by then, close the connection with an error.
 | ||||
|     while let Ok(i) = timeout(Duration::from_millis(heartbeat_interval * 2), rx.recv()).await { | ||||
|         match i { | ||||
|             // ACK the heartbeat with the same timestamp. TODO: validate the timestamp to make sure we aren't too out of sync with the identity server.
 | ||||
|             Some(timestamp) => { | ||||
|                 socket_state.write().await.last_heartbeat = timestamp; | ||||
|                 tx.send(Payload::HeartbeatAck { timestamp }).await.ok(); | ||||
|             } | ||||
|             // If the channel returns None, that means it's been closed, which means the socket as a whole was closed
 | ||||
|             None => { | ||||
|                 return; | ||||
|             } | ||||
|         }; | ||||
|     } | ||||
| 
 | ||||
|     // Send an error, which will automatically close the socket
 | ||||
|     tx.send(Payload::Error { | ||||
|         message: format!( | ||||
|             "Did not receive a heartbeat after {}ms", | ||||
|             heartbeat_interval * 2 | ||||
|         ), | ||||
|     }) | ||||
|     .await | ||||
|     .ok(); | ||||
| } | ||||
| 
 | ||||
| async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) { | ||||
|     loop { | ||||
|         if let Some(ev) = rx.recv().await { | ||||
|  | @ -289,10 +386,78 @@ async fn filter_events( | |||
|     socket_state: Arc<RwLock<SocketState>>, | ||||
|     evt: &Dispatch, | ||||
| ) -> Result<(bool, Vec<String>)> { | ||||
|     // If we're not authenticated yet, don't send anything
 | ||||
|     if socket_state.read().await.instance.is_none() { | ||||
|     let Some(instance) = &socket_state.read().await.instance else { | ||||
|         return Ok((false, vec![])); | ||||
|     } | ||||
|     }; | ||||
| 
 | ||||
|     Ok((true, vec![])) | ||||
|     match evt { | ||||
|         Dispatch::MessageCreate { | ||||
|             id: _, | ||||
|             channel_id: _, | ||||
|             guild_id, | ||||
|             author: _, | ||||
|             content: _, | ||||
|             created_at: _, | ||||
|         } => { | ||||
|             let users = sqlx::query!( | ||||
|                 r#"SELECT ARRAY(
 | ||||
|                 SELECT u.remote_user_id FROM users u | ||||
|                 JOIN guilds_users gu ON gu.user_id = u.id | ||||
|                 WHERE u.instance_id = $1 AND gu.guild_id = $2 | ||||
|             )"#,
 | ||||
|                 instance.id, | ||||
|                 guild_id | ||||
|             ) | ||||
|             .fetch_one(&app_state.pool) | ||||
|             .await?; | ||||
| 
 | ||||
|             if let Some(users) = users.array { | ||||
|                 return Ok((users.len() > 0, users)); | ||||
|             } | ||||
|             return Ok((false, vec![])); | ||||
|         } | ||||
|         Dispatch::Ready { user, guilds: _ } => { | ||||
|             let user = sqlx::query!( | ||||
|                 "SELECT remote_user_id FROM users WHERE id = $1", | ||||
|                 user.id.clone() | ||||
|             ) | ||||
|             .fetch_one(&app_state.pool) | ||||
|             .await?; | ||||
| 
 | ||||
|             return Ok((true, vec![user.remote_user_id])); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| async fn collect_ready( | ||||
|     app_state: Arc<AppState>, | ||||
|     socket_state: Arc<RwLock<SocketState>>, | ||||
|     user_id: &str, | ||||
| ) -> Result<Payload> { | ||||
|     let Some(instance) = &socket_state.read().await.instance else { | ||||
|         return Err(eyre!("instance was None when it shouldn't be")); | ||||
|     }; | ||||
| 
 | ||||
|     let user = sqlx::query!( | ||||
|         r#"SELECT u.*, i.domain FROM users u
 | ||||
|         JOIN identity_instances i ON i.id = u.instance_id | ||||
|         WHERE u.instance_id = $1 AND u.remote_user_id = $2"#,
 | ||||
|         instance.id, | ||||
|         user_id | ||||
|     ) | ||||
|     .fetch_one(&app_state.pool) | ||||
|     .await?; | ||||
| 
 | ||||
|     Ok(Payload::Dispatch { | ||||
|         event: Dispatch::Ready { | ||||
|             user: User { | ||||
|                 id: user.id, | ||||
|                 username: user.username, | ||||
|                 instance: user.domain, | ||||
|                 avatar_url: None, | ||||
|             }, | ||||
|             guilds: vec![], | ||||
|         }, | ||||
|         recipients: vec![user.remote_user_id], | ||||
|     }) | ||||
| } | ||||
|  |  | |||
							
								
								
									
										25
									
								
								foxchat/src/c2s/event.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								foxchat/src/c2s/event.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,25 @@ | |||
| use serde::{Deserialize, Serialize}; | ||||
| 
 | ||||
| use crate::s2s::Dispatch; | ||||
| 
 | ||||
| #[derive(Debug, Serialize, Deserialize)] | ||||
| #[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] | ||||
| pub enum Payload { | ||||
|     #[serde(rename = "D")] | ||||
|     Dispatch { | ||||
|         #[serde(rename = "e")] | ||||
|         event: Dispatch, | ||||
|         #[serde(rename = "s")] | ||||
|         server_id: String, | ||||
|     }, | ||||
|     Error { | ||||
|         message: String, | ||||
|     }, | ||||
|     /// Hello message, sent after authentication succeeds
 | ||||
|     Hello { | ||||
|         guilds: Vec<String>, | ||||
|     }, | ||||
|     Identify { | ||||
|         token: String, | ||||
|     }, | ||||
| } | ||||
							
								
								
									
										3
									
								
								foxchat/src/c2s/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								foxchat/src/c2s/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,3 @@ | |||
| pub mod event; | ||||
| 
 | ||||
| pub use event::Payload; | ||||
|  | @ -3,6 +3,7 @@ pub mod fed; | |||
| pub mod http; | ||||
| pub mod model; | ||||
| pub mod s2s; | ||||
| pub mod c2s; | ||||
| pub mod id; | ||||
| 
 | ||||
| pub use error::FoxError; | ||||
|  |  | |||
|  | @ -8,7 +8,7 @@ pub struct Channel { | |||
|     pub topic: Option<String>, | ||||
| } | ||||
| 
 | ||||
| #[derive(Serialize, Deserialize, Debug)] | ||||
| #[derive(Serialize, Deserialize, Debug, Clone)] | ||||
| pub struct PartialChannel { | ||||
|     pub id: String, | ||||
|     pub name: String, | ||||
|  |  | |||
|  | @ -1,11 +1,12 @@ | |||
| use serde::{Serialize, Deserialize}; | ||||
| 
 | ||||
| use super::{user::PartialUser, channel::PartialChannel}; | ||||
| use super::{channel::PartialChannel, user::PartialUser}; | ||||
| 
 | ||||
| #[derive(Serialize, Deserialize, Debug)] | ||||
| #[derive(Serialize, Deserialize, Debug, Clone)] | ||||
| pub struct Guild { | ||||
|     pub id: String, | ||||
|     pub name: String, | ||||
|     pub owner: PartialUser, | ||||
|     pub default_channel: PartialChannel, | ||||
|     pub channels: Option<Vec<PartialChannel>>, | ||||
| } | ||||
|  |  | |||
|  | @ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; | |||
| 
 | ||||
| use crate::{ | ||||
|     id::GuildType, | ||||
|     model::{user::PartialUser, Message}, | ||||
|     model::{user::PartialUser, Guild, Message, User}, | ||||
|     Id, | ||||
| }; | ||||
| 
 | ||||
|  | @ -18,6 +18,10 @@ pub enum Dispatch { | |||
|         content: Option<String>, | ||||
|         created_at: DateTime<Utc>, | ||||
|     }, | ||||
|     Ready { | ||||
|         user: User, | ||||
|         guilds: Vec<Guild>, | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Dispatch { | ||||
|  |  | |||
|  | @ -5,6 +5,7 @@ use super::Dispatch; | |||
| #[derive(Debug, Serialize, Deserialize)] | ||||
| #[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] | ||||
| pub enum Payload { | ||||
|     #[serde(rename = "D")] | ||||
|     Dispatch { | ||||
|         #[serde(rename = "e")] | ||||
|         event: Dispatch, | ||||
|  | @ -15,7 +16,9 @@ pub enum Payload { | |||
|         message: String, | ||||
|     }, | ||||
|     /// Hello message, sent after authentication succeeds
 | ||||
|     Hello {}, | ||||
|     Hello { | ||||
|         heartbeat_interval: u64, | ||||
|     }, | ||||
|     /// S2S authentication. Fields correspond to headers (Host, Date, X-Foxchat-Server, X-Foxchat-Signature)
 | ||||
|     Identify { | ||||
|         host: String, | ||||
|  | @ -26,5 +29,15 @@ pub enum Payload { | |||
|     /// Sent when a user connects to the identity server's gateway, to signal the chat server to send READY for that user
 | ||||
|     Connect { | ||||
|         user_id: String, | ||||
|     }, | ||||
|     /// Sent on a regular interval by the connecting server, to keep the connection alive.
 | ||||
|     Heartbeat { | ||||
|         #[serde(rename = "t")] | ||||
|         timestamp: u64, | ||||
|     }, | ||||
|     /// Sent in response to a Heartbeat.
 | ||||
|     HeartbeatAck { | ||||
|         #[serde(rename = "t")] | ||||
|         timestamp: u64, | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -27,3 +27,5 @@ base64 = "0.21.7" | |||
| sha256 = "1.5.0" | ||||
| reqwest = { version = "0.11.23", features = ["json", "gzip", "brotli", "multipart"] } | ||||
| chrono = "0.4.31" | ||||
| futures = "0.3.30" | ||||
| tokio-tungstenite = { version = "0.21.0", features = ["rustls"] } | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ mod account; | |||
| mod auth; | ||||
| mod node; | ||||
| mod proxy; | ||||
| mod ws; | ||||
| 
 | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										23
									
								
								identity/src/http/ws/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								identity/src/http/ws/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,23 @@ | |||
| use std::sync::Arc; | ||||
| 
 | ||||
| use axum::{extract::{ws::WebSocket, WebSocketUpgrade}, response::Response, Extension}; | ||||
| use futures::{ | ||||
|     stream::{SplitSink, SplitStream, StreamExt}, | ||||
|     SinkExt, | ||||
| }; | ||||
| 
 | ||||
| use crate::{app_state::AppState, model::account::Account}; | ||||
| 
 | ||||
| pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response { | ||||
|     ws.on_upgrade(|socket| handle_socket(state, socket)) | ||||
| } | ||||
| 
 | ||||
| struct SocketState { | ||||
|     user: Option<Account>, | ||||
| } | ||||
| 
 | ||||
| async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) { | ||||
|     let (sender, receiver) = socket.split(); | ||||
| 
 | ||||
|     
 | ||||
| } | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue