feat: make identity server store user instance/guild IDs
This commit is contained in:
		
							parent
							
								
									42abd70184
								
							
						
					
					
						commit
						fd77dd01fa
					
				
					 19 changed files with 245 additions and 29 deletions
				
			
		|  | @ -1,3 +1,7 @@ | |||
| mod rest; | ||||
| 
 | ||||
| pub use rest::post_rest_event; | ||||
| 
 | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| use axum::{ | ||||
|  |  | |||
							
								
								
									
										38
									
								
								chat/src/fed/rest.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								chat/src/fed/rest.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,38 @@ | |||
| use std::sync::Arc; | ||||
| 
 | ||||
| use eyre::Result; | ||||
| 
 | ||||
| use foxchat::{fed, s2s::http::RestEvent}; | ||||
| use tracing::{debug, error}; | ||||
| 
 | ||||
| use crate::{app_state::AppState, model::identity_instance::IdentityInstance}; | ||||
| 
 | ||||
| /// Posts an event to a remote instance's inbox.
 | ||||
| pub async fn post_rest_event( | ||||
|     state: Arc<AppState>, | ||||
|     instance: &IdentityInstance, | ||||
|     event: RestEvent, | ||||
| ) -> Result<()> { | ||||
|     debug!("Sending {:?} event to {}'s inbox", &event, &instance.domain); | ||||
| 
 | ||||
|     match fed::post::<RestEvent, ()>( | ||||
|         &state.private_key, | ||||
|         &state.config.domain, | ||||
|         &instance.domain, | ||||
|         "/_fox/ident/inbox", | ||||
|         None, | ||||
|         &event, | ||||
|     ) | ||||
|     .await | ||||
|     { | ||||
|         Ok(_) => Ok(()), | ||||
|         Err(e) => { | ||||
|             error!( | ||||
|                 "Error sending {:?} event to {}'s inbox: {}", | ||||
|                 &event, &instance.domain, e | ||||
|             ); | ||||
| 
 | ||||
|             return Err(e); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | @ -4,13 +4,17 @@ use axum::{Extension, Json}; | |||
| use foxchat::{ | ||||
|     http::ApiError, | ||||
|     model::{channel::PartialChannel, http::guild::CreateGuildParams, Guild}, | ||||
|     s2s::http::RestEvent, | ||||
|     FoxError, | ||||
| }; | ||||
| 
 | ||||
| use crate::{ | ||||
|     app_state::AppState, | ||||
|     db::{channel::create_channel, guild::{create_guild, join_guild}}, | ||||
|     fed::FoxRequestData, | ||||
|     db::{ | ||||
|         channel::create_channel, | ||||
|         guild::{create_guild, join_guild}, | ||||
|     }, | ||||
|     fed::{post_rest_event, FoxRequestData}, | ||||
|     model::user::User, | ||||
| }; | ||||
| 
 | ||||
|  | @ -27,9 +31,20 @@ pub async fn post_guilds( | |||
|     let channel = create_channel(&mut *tx, &guild.id, "general", None).await?; | ||||
| 
 | ||||
|     join_guild(&mut *tx, &guild.id, &user.id).await?; | ||||
| 
 | ||||
|     tx.commit().await?; | ||||
| 
 | ||||
|     // Send an event to the user's instance that they joined a guild
 | ||||
|     post_rest_event( | ||||
|         state, | ||||
|         &request.instance, | ||||
|         RestEvent::GuildJoin { | ||||
|             guild_id: guild.id.0.clone(), | ||||
|             user_id: user.remote_user_id.clone(), | ||||
|         }, | ||||
|     ) | ||||
|     .await | ||||
|     .ok(); | ||||
| 
 | ||||
|     Ok(Json(Guild { | ||||
|         id: guild.id.0, | ||||
|         name: guild.name, | ||||
|  |  | |||
|  | @ -3,6 +3,7 @@ use std::sync::Arc; | |||
| use axum::{Extension, Json}; | ||||
| use eyre::Context; | ||||
| use foxchat::{ | ||||
|     error::ToFoxError, | ||||
|     fed, | ||||
|     http::ApiError, | ||||
|     s2s::http::{HelloRequest, HelloResponse, NodeResponse}, | ||||
|  | @ -30,7 +31,8 @@ pub async fn post_hello( | |||
|         "/_fox/ident/node", | ||||
|         None, | ||||
|     ) | ||||
|     .await?; | ||||
|     .await? | ||||
|     .to_fox_error()?; | ||||
|     let public_key = | ||||
|         RsaPublicKey::from_pkcs1_pem(&node.public_key).wrap_err("parsing remote public key")?; | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,7 +1,7 @@ | |||
| use std::sync::Arc; | ||||
| 
 | ||||
| use eyre::Result; | ||||
| use foxchat::{fed, model::User as HttpUser, Id, id::UserType}; | ||||
| use foxchat::{error::ToFoxError, fed, id::UserType, model::User as HttpUser, Id}; | ||||
| use ulid::Ulid; | ||||
| 
 | ||||
| use crate::app_state::AppState; | ||||
|  | @ -48,7 +48,8 @@ impl User { | |||
|             &format!("/_fox/ident/users/{}", remote_id), | ||||
|             None, | ||||
|         ) | ||||
|         .await?; | ||||
|         .await? | ||||
|         .to_fox_error()?; | ||||
| 
 | ||||
|         let user = sqlx::query!( | ||||
|             "insert into users (id, instance_id, remote_user_id, username, avatar) values ($1, $2, $3, $4, $5) returning *", | ||||
|  |  | |||
|  | @ -18,7 +18,8 @@ pub enum Payload { | |||
|     /// Hello message, sent after authentication succeeds
 | ||||
|     Hello { | ||||
|         heartbeat_interval: u64, | ||||
|         guilds: Vec<String>, | ||||
|         guilds: Vec<GuildInstance>, | ||||
|         instances: Vec<InstanceId>, | ||||
|     }, | ||||
|     Identify { | ||||
|         token: String, | ||||
|  | @ -34,3 +35,15 @@ pub enum Payload { | |||
|         timestamp: u64, | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Serialize, Deserialize)] | ||||
| pub struct InstanceId { | ||||
|     pub id: String, | ||||
|     pub domain: String, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Serialize, Deserialize)] | ||||
| pub struct GuildInstance { | ||||
|     pub guild_id: String, | ||||
|     pub instance_id: String, | ||||
| } | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| use axum::http::header::ToStrError; | ||||
| use thiserror::Error; | ||||
| use eyre::Result; | ||||
| 
 | ||||
| #[derive(Error, Debug, Copy, Clone)] | ||||
| pub enum FoxError { | ||||
|  | @ -31,6 +32,8 @@ pub enum FoxError { | |||
|     ChannelNotFound, | ||||
|     #[error("internal server error while proxying")] | ||||
|     ProxyInternalServerError, | ||||
|     #[error("invalid RestEvent for this instance type")] | ||||
|     InvalidRestEvent, | ||||
| } | ||||
| 
 | ||||
| impl From<ToStrError> for FoxError { | ||||
|  | @ -44,3 +47,16 @@ impl From<chrono::ParseError> for FoxError { | |||
|         Self::InvalidDate | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub trait ToFoxError<T> { | ||||
|     fn to_fox_error(self) -> Result<T>; | ||||
| } | ||||
| 
 | ||||
| impl<T> ToFoxError<T> for Option<T> { | ||||
|     fn to_fox_error(self) -> Result<T> { | ||||
|         match self { | ||||
|             Some(t) => Ok(t), | ||||
|             None => Err(FoxError::ResponseNotOk.into()) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -34,7 +34,7 @@ pub async fn get<R: DeserializeOwned>( | |||
|     host: &str, | ||||
|     path: &str, | ||||
|     user_id: Option<String>, | ||||
| ) -> Result<R> { | ||||
| ) -> Result<Option<R>> { | ||||
|     let (signature, date) = build_signature(private_key, host, path, None, user_id.clone()); | ||||
| 
 | ||||
|     let mut req = CLIENT | ||||
|  | @ -60,7 +60,7 @@ pub async fn post<T: Serialize, R: DeserializeOwned>( | |||
|     path: &str, | ||||
|     user_id: Option<String>, | ||||
|     body: &T, | ||||
| ) -> Result<R> { | ||||
| ) -> Result<Option<R>> { | ||||
|     let body = serde_json::to_string(body)?; | ||||
| 
 | ||||
|     let (signature, date) = | ||||
|  | @ -85,7 +85,11 @@ pub async fn post<T: Serialize, R: DeserializeOwned>( | |||
|     handle_response(resp).await.wrap_err("handling response") | ||||
| } | ||||
| 
 | ||||
| async fn handle_response<R: DeserializeOwned>(resp: Response) -> Result<R, ResponseError> { | ||||
| async fn handle_response<R: DeserializeOwned>(resp: Response) -> Result<Option<R>, ResponseError> { | ||||
|     if resp.status() == StatusCode::NO_CONTENT { | ||||
|         return Ok(None); | ||||
|     } | ||||
| 
 | ||||
|     if resp.status() != StatusCode::OK { | ||||
|         let status = resp.status().as_u16(); | ||||
| 
 | ||||
|  | @ -105,7 +109,7 @@ async fn handle_response<R: DeserializeOwned>(resp: Response) -> Result<R, Respo | |||
|         .json::<R>() | ||||
|         .await | ||||
|         .map_err(|_| ResponseError::JsonError)?; | ||||
|     Ok(parsed) | ||||
|     Ok(Some(parsed)) | ||||
| } | ||||
| 
 | ||||
| #[derive(thiserror::Error, Debug, Clone)] | ||||
|  |  | |||
|  | @ -45,8 +45,6 @@ fn plaintext_string( | |||
|         raw_time, host, request_path, raw_content_length, raw_user_id | ||||
|     ); | ||||
| 
 | ||||
|     println!("{}", s); | ||||
| 
 | ||||
|     s | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -44,6 +44,7 @@ pub enum ErrorCode { | |||
|     MissingSignature, | ||||
|     GuildNotFound, | ||||
|     Unauthorized, | ||||
|     InvalidRestEvent, | ||||
| } | ||||
| 
 | ||||
| impl From<sqlx::Error> for ApiError { | ||||
|  | @ -152,7 +153,12 @@ impl From<FoxError> for ApiError { | |||
|                 status: StatusCode::INTERNAL_SERVER_ERROR, | ||||
|                 code: ErrorCode::InternalServerError, | ||||
|                 message: "Internal server error".into(), | ||||
|             } | ||||
|             }, | ||||
|             FoxError::InvalidRestEvent => ApiError { | ||||
|                 status: StatusCode::BAD_REQUEST, | ||||
|                 code: ErrorCode::InvalidRestEvent, | ||||
|                 message: "Invalid RestEvent for this instance type".into(), | ||||
|             }, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -1,3 +1,5 @@ | |||
| mod hello; | ||||
| mod rest_event; | ||||
| 
 | ||||
| pub use hello::{HelloRequest, HelloResponse, NodeResponse}; | ||||
| pub use rest_event::RestEvent; | ||||
|  |  | |||
							
								
								
									
										18
									
								
								foxchat/src/s2s/http/rest_event.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								foxchat/src/s2s/http/rest_event.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,18 @@ | |||
| use serde::{Deserialize, Serialize}; | ||||
| 
 | ||||
| #[derive(Debug, Serialize, Deserialize)] | ||||
| #[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] | ||||
| pub enum RestEvent { | ||||
|     /// Sent when a user creates or joins a guild (chat -> identity)
 | ||||
|     GuildJoin { guild_id: String, user_id: String }, | ||||
|     /// Sent when a user leaves or is removed from a guild (chat -> identity)
 | ||||
|     GuildLeave { guild_id: String, user_id: String }, | ||||
|     /// Sent when a user updates their profile (identity -> chat)
 | ||||
|     UserUpdate { | ||||
|         user_id: String, | ||||
|         username: Option<String>, | ||||
|         avatar: Option<String>, | ||||
|     }, | ||||
|     /// Sent when a user deletes their account (identity -> chat)
 | ||||
|     UserDelete { user_id: String }, | ||||
| } | ||||
|  | @ -2,4 +2,4 @@ | |||
| fn main() { | ||||
|     // trigger recompilation when a new migration is added
 | ||||
|     println!("cargo:rerun-if-changed=migrations"); | ||||
| } | ||||
| } | ||||
|  |  | |||
							
								
								
									
										7
									
								
								identity/migrations/20240228151314_store_guilds.sql
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								identity/migrations/20240228151314_store_guilds.sql
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,7 @@ | |||
| create table guilds_accounts ( | ||||
|     chat_instance_id text not null references chat_instances (id) on delete cascade, | ||||
|     guild_id text not null, | ||||
|     account_id text not null references accounts (id) on delete cascade, | ||||
| 
 | ||||
|     primary key (chat_instance_id, guild_id, account_id) | ||||
| ); | ||||
							
								
								
									
										59
									
								
								identity/src/http/inbox.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								identity/src/http/inbox.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,59 @@ | |||
| use std::sync::Arc; | ||||
| 
 | ||||
| use axum::{http::StatusCode, Extension, Json}; | ||||
| use foxchat::{http::ApiError, s2s::http::RestEvent, FoxError}; | ||||
| use tracing::info; | ||||
| 
 | ||||
| use crate::{app_state::AppState, fed::FoxRequestData}; | ||||
| 
 | ||||
| pub async fn post_inbox( | ||||
|     Extension(state): Extension<Arc<AppState>>, | ||||
|     request: FoxRequestData, | ||||
|     Json(evt): Json<RestEvent>, | ||||
| ) -> Result<(StatusCode, &'static str), ApiError> { | ||||
|     match evt { | ||||
|         RestEvent::GuildJoin { guild_id, user_id } => { | ||||
|             info!( | ||||
|                 "received GUILD_JOIN event from {}/{} for user {}/guild {}", | ||||
|                 &request.instance.id, &request.instance.domain, &user_id, &guild_id, | ||||
|             ); | ||||
| 
 | ||||
|             let mut tx = state.pool.begin().await?; | ||||
| 
 | ||||
|             sqlx::query!( | ||||
|                 r#"INSERT INTO chat_instance_accounts
 | ||||
|                 (account_id, chat_instance_id) VALUES ($1, $2) | ||||
|                 ON CONFLICT (account_id, chat_instance_id) DO NOTHING"#,
 | ||||
|                 &user_id, | ||||
|                 &request.instance.id | ||||
|             ) | ||||
|             .execute(&mut *tx) | ||||
|             .await?; | ||||
| 
 | ||||
|             sqlx::query!( | ||||
|                 r#"INSERT INTO guilds_accounts (chat_instance_id, guild_id, account_id)
 | ||||
|                 VALUES ($1, $2, $3) ON CONFLICT (chat_instance_id, guild_id, account_id) DO NOTHING"#,
 | ||||
|                 &request.instance.id, | ||||
|                 &guild_id, | ||||
|                 &user_id | ||||
|             ) | ||||
|             .execute(&mut *tx) | ||||
|             .await?; | ||||
| 
 | ||||
|             tx.commit().await?; | ||||
|         } | ||||
|         RestEvent::GuildLeave { guild_id, user_id } => { | ||||
|             sqlx::query!( | ||||
|                 "DELETE FROM guilds_accounts WHERE chat_instance_id = $1 AND guild_id = $2 AND account_id = $3", | ||||
|                 &request.instance.id, | ||||
|                 &guild_id, | ||||
|                 &user_id, | ||||
|             ).execute(&state.pool).await?; | ||||
|         } | ||||
|         _ => { | ||||
|             return Err(FoxError::InvalidServer.into()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     Ok((StatusCode::NO_CONTENT, "")) | ||||
| } | ||||
|  | @ -1,12 +1,13 @@ | |||
| mod account; | ||||
| mod auth; | ||||
| mod inbox; | ||||
| mod node; | ||||
| mod proxy; | ||||
| mod ws; | ||||
| 
 | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| use axum::{routing::get, Extension, Router}; | ||||
| use axum::{routing::{get, post}, Extension, Router}; | ||||
| use sqlx::{Pool, Postgres}; | ||||
| use tower_http::trace::TraceLayer; | ||||
| 
 | ||||
|  | @ -26,6 +27,7 @@ pub fn new(pool: Pool<Postgres>, config: Config, instance: Instance) -> Router { | |||
|         .route("/_fox/ident/node", get(node::get_node)) | ||||
|         .route("/_fox/ident/node/:domain", get(node::get_chat_node)) | ||||
|         .route("/_fox/ident/ws", get(ws::handler)) | ||||
|         .route("/_fox/ident/inbox", post(inbox::post_inbox)) | ||||
|         .layer(TraceLayer::new_for_http()) | ||||
|         .layer(Extension(app_state)); | ||||
| 
 | ||||
|  |  | |||
|  | @ -37,7 +37,6 @@ async fn proxy_get<R: Serialize + DeserializeOwned>( | |||
|     let proxy_path = original | ||||
|         .strip_prefix("/_fox/proxy/") | ||||
|         .wrap_err("invalid url")?; | ||||
|     println!("{}", proxy_path); | ||||
| 
 | ||||
|     let resp = fed::get::<R>( | ||||
|         &state.private_key, | ||||
|  | @ -49,7 +48,10 @@ async fn proxy_get<R: Serialize + DeserializeOwned>( | |||
|     .await; | ||||
| 
 | ||||
|     match resp { | ||||
|         Ok(r) => return Ok(Json(r)), | ||||
|         Ok(r) => match r { | ||||
|             Some(r) => Ok(Json(r)), | ||||
|             None => Err(FoxError::ResponseNotOk.into()), | ||||
|         }, | ||||
|         Err(e) => { | ||||
|             if let Some(e) = e.downcast_ref::<ResponseError>() { | ||||
|                 match e { | ||||
|  | @ -102,7 +104,10 @@ async fn proxy_post<B: Serialize, R: Serialize + DeserializeOwned>( | |||
|     .await; | ||||
| 
 | ||||
|     match resp { | ||||
|         Ok(r) => return Ok(Json(r)), | ||||
|         Ok(r) => match r { | ||||
|             Some(r) => Ok(Json(r)), | ||||
|             None => Err(FoxError::ResponseNotOk.into()), | ||||
|         }, | ||||
|         Err(e) => { | ||||
|             if let Some(e) = e.downcast_ref::<ResponseError>() { | ||||
|                 match e { | ||||
|  |  | |||
|  | @ -9,13 +9,16 @@ use axum::{ | |||
|     Extension, | ||||
| }; | ||||
| use eyre::Result; | ||||
| use foxchat::{c2s::Payload, FoxError}; | ||||
| use foxchat::{c2s::{Payload, event::{InstanceId, GuildInstance}}, FoxError}; | ||||
| use futures::{ | ||||
|     stream::{SplitSink, SplitStream, StreamExt}, | ||||
|     SinkExt, | ||||
| }; | ||||
| use rand::Rng; | ||||
| use tokio::{sync::mpsc::{self, Receiver, Sender}, time::timeout}; | ||||
| use tokio::{ | ||||
|     sync::mpsc::{self, Receiver, Sender}, | ||||
|     time::timeout, | ||||
| }; | ||||
| use tracing::error; | ||||
| 
 | ||||
| use crate::{app_state::AppState, db::check_token, model::account::Account}; | ||||
|  | @ -124,17 +127,42 @@ async fn read( | |||
|     tokio::spawn(write(app_state.clone(), user.clone(), sender, rx)); | ||||
| 
 | ||||
|     // Send HELLO event
 | ||||
|     // TODO: fetch guild IDs
 | ||||
|     let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000); | ||||
|     tx.send(Payload::Hello { heartbeat_interval, guilds: vec![] }).await.ok(); | ||||
|     // Fetch the instance ID/domain pairs
 | ||||
|     let instances = match sqlx::query_as!(InstanceId, r#"SELECT i.id, i.domain FROM chat_instances i
 | ||||
|     JOIN chat_instance_accounts ia ON ia.chat_instance_id = i.id | ||||
|     WHERE ia.account_id = $1"#, &user.id).fetch_all(&app_state.pool).await {
 | ||||
|         Ok(i) => i, | ||||
|         Err(e) => { | ||||
|             error!("error getting instances for user {}: {}", &user.id, e); | ||||
|             tx.send(Payload::Error { message: "Internal server error".into() }).await.ok(); | ||||
|             return; | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
|     // Fetch guild ID/instance ID pairs
 | ||||
|     let guilds = match sqlx::query_as!(GuildInstance, r#"SELECT guild_id, chat_instance_id AS instance_id
 | ||||
|     FROM guilds_accounts WHERE account_id = $1"#, &user.id).fetch_all(&app_state.pool).await {
 | ||||
|         Ok(g) => g, | ||||
|         Err(e) => { | ||||
|             error!("error getting guilds for user {}: {}", &user.id, e); | ||||
|             tx.send(Payload::Error { message: "Internal server error".into() }).await.ok(); | ||||
|             return; | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
|     tx.send(Payload::Hello { | ||||
|         heartbeat_interval, | ||||
|         guilds, | ||||
|         instances, | ||||
|     }) | ||||
|     .await | ||||
|     .ok(); | ||||
| 
 | ||||
|     // Start the heartbeat loop
 | ||||
|     let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10); | ||||
|     tokio::spawn(heartbeat(heartbeat_interval, heartbeat_rx, tx.clone())); | ||||
| 
 | ||||
|     // Fire off ready event
 | ||||
|     tokio::spawn(collect_ready(app_state.clone(), user.clone(), tx.clone())); | ||||
| 
 | ||||
|     // Start listening for events
 | ||||
|     while let Some(msg) = receiver.next().await { | ||||
|         let Ok(msg) = msg else { | ||||
|  | @ -237,8 +265,6 @@ async fn write( | |||
|     } | ||||
| } | ||||
| 
 | ||||
| async fn collect_ready(app_state: Arc<AppState>, user: Arc<Account>, tx: Sender<Payload>) {} | ||||
| 
 | ||||
| async fn heartbeat( | ||||
|     heartbeat_interval: u64, | ||||
|     mut rx: mpsc::Receiver<u64>, | ||||
|  |  | |||
|  | @ -58,7 +58,7 @@ impl ChatInstance { | |||
|                 host: state.config.domain.clone(), | ||||
|             }, | ||||
|         ) | ||||
|         .await?; | ||||
|         .await?.expect(""); | ||||
| 
 | ||||
|         if resp.host != domain { | ||||
|             return Err(FoxError::InvalidServer.into()); | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue