add request verification extractor
This commit is contained in:
		
							parent
							
								
									7a694623e5
								
							
						
					
					
						commit
						1e53661b0a
					
				
					 18 changed files with 482 additions and 32 deletions
				
			
		|  | @ -5,3 +5,6 @@ members = [ | |||
|   "chat" | ||||
| ] | ||||
| resolver = "2" | ||||
| 
 | ||||
| [profile.dev.package.num-bigint-dig] | ||||
| opt-level = 3 | ||||
|  |  | |||
|  | @ -14,7 +14,7 @@ create table users ( | |||
|     instance_id text not null references identity_instances (id) on delete cascade, | ||||
|     remote_user_id text not null, | ||||
|     username       text not null, | ||||
|      | ||||
| 
 | ||||
|     avatar text -- URL, not hash, as this is a remote file | ||||
| ); | ||||
| 
 | ||||
|  | @ -46,6 +46,14 @@ create table messages ( | |||
|     channel_id text not null references channels (id) on delete cascade, | ||||
|     author_id  text not null, | ||||
|     updated_at timestamptz not null default now(), | ||||
|      | ||||
| 
 | ||||
|     content text not null | ||||
| ); | ||||
| 
 | ||||
| create table instance ( | ||||
|     id integer not null primary key default 1, | ||||
|     public_key  text not null, | ||||
|     private_key text not null, | ||||
| 
 | ||||
|     constraint singleton check (id = 1) | ||||
| ); | ||||
|  |  | |||
							
								
								
									
										8
									
								
								chat/src/app_state.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								chat/src/app_state.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,8 @@ | |||
| use sqlx::{Pool, Postgres}; | ||||
| 
 | ||||
| use crate::config::Config; | ||||
| 
 | ||||
| pub struct AppState { | ||||
|     pub pool: Pool<Postgres>, | ||||
|     pub config: Config, | ||||
| } | ||||
							
								
								
									
										56
									
								
								chat/src/config.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								chat/src/config.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,56 @@ | |||
| use eyre::Result; | ||||
| use serde::Deserialize; | ||||
| use std::path::Path; | ||||
| use std::{env, fs}; | ||||
| use tracing::Level; | ||||
| 
 | ||||
| pub const CONFIG_FILE: &str = "config.chat.toml"; | ||||
| 
 | ||||
| #[derive(Deserialize)] | ||||
| pub struct Config { | ||||
|     pub database_url: String, | ||||
|     pub port: u16, | ||||
|     pub domain: String, | ||||
|     pub auto_migrate: Option<bool>, | ||||
|     pub log_level: Option<String>, | ||||
| } | ||||
| 
 | ||||
| impl Config { | ||||
|     pub fn load() -> Result<Self> { | ||||
|         let cwd = env::current_dir()?; | ||||
|         let config_file = Path::join(cwd.as_path(), Path::new(CONFIG_FILE)); | ||||
|         println!("config file: {}", config_file.display()); | ||||
| 
 | ||||
|         let s = fs::read_to_string(config_file)?; | ||||
|         let config = toml::from_str(s.as_str())?; | ||||
| 
 | ||||
|         Ok(config) | ||||
|     } | ||||
| 
 | ||||
|     pub fn tracing_level(&self) -> Option<Level> { | ||||
|         match self | ||||
|             .log_level | ||||
|             .as_deref() | ||||
|             .unwrap_or("INFO") | ||||
|         { | ||||
|             "TRACE" => Some(Level::TRACE), | ||||
|             "DEBUG" => Some(Level::DEBUG), | ||||
|             "INFO" => Some(Level::INFO), | ||||
|             "WARN" => Some(Level::WARN), | ||||
|             "ERROR" => Some(Level::ERROR), | ||||
|             _ => None | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Default for Config { | ||||
|     fn default() -> Self { | ||||
|         Config { | ||||
|             database_url: env::var("DATABASE_URL").unwrap_or("".into()), | ||||
|             domain: "".into(), | ||||
|             port: 3000, | ||||
|             auto_migrate: None, | ||||
|             log_level: None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
							
								
								
									
										50
									
								
								chat/src/db/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								chat/src/db/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,50 @@ | |||
| use eyre::{OptionExt, Result}; | ||||
| use rsa::pkcs1::{EncodeRsaPrivateKey, EncodeRsaPublicKey, LineEnding}; | ||||
| use rsa::{RsaPrivateKey, RsaPublicKey}; | ||||
| use sqlx::postgres::PgPoolOptions; | ||||
| use sqlx::{Pool, Postgres}; | ||||
| use std::time::Duration; | ||||
| 
 | ||||
| pub async fn init(dsn: &str) -> Result<Pool<Postgres>> { | ||||
|     let pool = PgPoolOptions::new() | ||||
|         .acquire_timeout(Duration::from_secs(2)) // Fail fast and don't hang
 | ||||
|         .max_connections(100) | ||||
|         .connect(dsn) | ||||
|         .await?; | ||||
| 
 | ||||
|     Ok(pool) | ||||
| } | ||||
| 
 | ||||
| const PRIVATE_KEY_BITS: usize = 2048; | ||||
| 
 | ||||
| pub async fn init_instance(pool: &Pool<Postgres>) -> Result<()> { | ||||
|     let mut tx = pool.begin().await?; | ||||
| 
 | ||||
|     // Check if we already have an instance configuration
 | ||||
|     let row = sqlx::query!("select exists(select * from instance)") | ||||
|         .fetch_one(&mut *tx) | ||||
|         .await?; | ||||
|     if row.exists.ok_or_eyre("exists was null")? { | ||||
|         return Ok(()); | ||||
|     } | ||||
| 
 | ||||
|     // Generate public/private key
 | ||||
|     let mut rng = rand::thread_rng(); | ||||
|     let priv_key = RsaPrivateKey::new(&mut rng, PRIVATE_KEY_BITS)?; | ||||
|     let pub_key = RsaPublicKey::from(&priv_key); | ||||
| 
 | ||||
|     let priv_key_string = priv_key.to_pkcs1_pem(LineEnding::LF)?; | ||||
|     let pub_key_string = pub_key.to_pkcs1_pem(LineEnding::LF)?; | ||||
| 
 | ||||
|     sqlx::query!( | ||||
|         "insert into instance (public_key, private_key) values ($1, $2)", | ||||
|         pub_key_string, | ||||
|         priv_key_string.to_string(), | ||||
|     ) | ||||
|     .execute(&mut *tx) | ||||
|     .await?; | ||||
| 
 | ||||
|     tx.commit().await?; | ||||
| 
 | ||||
|     Ok(()) | ||||
| } | ||||
							
								
								
									
										102
									
								
								chat/src/fed/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										102
									
								
								chat/src/fed/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,102 @@ | |||
| use std::sync::Arc; | ||||
| 
 | ||||
| use axum::{ | ||||
|     async_trait, | ||||
|     extract::FromRequestParts, | ||||
|     http::{ | ||||
|         header::{CONTENT_LENGTH, DATE, HOST}, | ||||
|         request::Parts, | ||||
|     }, | ||||
|     Extension, | ||||
| }; | ||||
| use foxchat::{ | ||||
|     fed::{SERVER_HEADER, SIGNATURE_HEADER, USER_HEADER}, | ||||
|     http::ApiError, | ||||
|     signature::{parse_date, verify_signature}, | ||||
|     FoxError, | ||||
| }; | ||||
| use tracing::error; | ||||
| 
 | ||||
| use crate::{app_state::AppState, model::identity_instance::IdentityInstance}; | ||||
| 
 | ||||
| pub struct FoxRequestData { | ||||
|     pub instance: IdentityInstance, | ||||
|     pub user_id: Option<String>, | ||||
| } | ||||
| 
 | ||||
| #[async_trait] | ||||
| impl<S> FromRequestParts<S> for FoxRequestData | ||||
| where | ||||
|     S: Send + Sync, | ||||
| { | ||||
|     type Rejection = ApiError; | ||||
| 
 | ||||
|     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { | ||||
|         let state: Extension<Arc<AppState>> = Extension::from_request_parts(parts, state) | ||||
|             .await | ||||
|             .expect("AppState was not added as an extension"); | ||||
| 
 | ||||
|         let domain = parts | ||||
|             .headers | ||||
|             .get(SERVER_HEADER) | ||||
|             .ok_or(FoxError::InvalidHeader)? | ||||
|             .to_str()?; | ||||
|         let instance = IdentityInstance::get(state.0, domain).await?; | ||||
|         let public_key = instance.parse_public_key()?; | ||||
| 
 | ||||
|         let date = parse_date( | ||||
|             parts | ||||
|                 .headers | ||||
|                 .get(DATE) | ||||
|                 .ok_or(FoxError::InvalidHeader)? | ||||
|                 .to_str()?, | ||||
|         )?; | ||||
|         let signature = parts | ||||
|             .headers | ||||
|             .get(SIGNATURE_HEADER) | ||||
|             .ok_or(FoxError::MissingSignature)? | ||||
|             .to_str()? | ||||
|             .to_string(); | ||||
|         let host = parts | ||||
|             .headers | ||||
|             .get(HOST) | ||||
|             .ok_or(FoxError::InvalidHeader)? | ||||
|             .to_str()?; | ||||
| 
 | ||||
|         let content_length = if let Some(raw_length) = parts.headers.get(CONTENT_LENGTH) { | ||||
|             Some(raw_length.to_str()?.parse::<usize>()?) | ||||
|         } else { | ||||
|             None | ||||
|         }; | ||||
| 
 | ||||
|         let user_id = if let Some(raw_id) = parts.headers.get(USER_HEADER) { | ||||
|             Some(raw_id.to_str()?) | ||||
|         } else { | ||||
|             None | ||||
|         }; | ||||
| 
 | ||||
|         if let Err(e) = verify_signature( | ||||
|             &public_key, | ||||
|             signature, | ||||
|             date, | ||||
|             host, | ||||
|             parts.uri.path(), | ||||
|             content_length, | ||||
|             user_id, | ||||
|         ) { | ||||
|             error!( | ||||
|                 "Verifying signature from request for {} from {}: {}", | ||||
|                 parts.uri.path(), | ||||
|                 domain, | ||||
|                 e | ||||
|             ); | ||||
| 
 | ||||
|             return Err(FoxError::InvalidSignature.into()); | ||||
|         } | ||||
| 
 | ||||
|         Ok(FoxRequestData { | ||||
|             instance, | ||||
|             user_id: user_id.map(|v| v.to_string()), | ||||
|         }) | ||||
|     } | ||||
| } | ||||
|  | @ -1,3 +1,65 @@ | |||
| fn main() { | ||||
|     println!("Hello, world!"); | ||||
| mod config; | ||||
| mod db; | ||||
| mod fed; | ||||
| mod app_state; | ||||
| mod model; | ||||
| 
 | ||||
| use crate::config::Config; | ||||
| use clap::{Parser, Subcommand}; | ||||
| use eyre::Result; | ||||
| use tracing::info; | ||||
| 
 | ||||
| #[derive(Debug, Parser)] | ||||
| struct Cli { | ||||
|     #[command(subcommand)] | ||||
|     command: Option<Command>, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Subcommand)] | ||||
| enum Command { | ||||
|     Serve, | ||||
|     Migrate, | ||||
| } | ||||
| 
 | ||||
| #[tokio::main] | ||||
| async fn main() -> Result<()> { | ||||
|     color_eyre::install()?; | ||||
| 
 | ||||
|     let config = Config::load()?; | ||||
|     let args = Cli::parse(); | ||||
| 
 | ||||
|     tracing_subscriber::fmt() | ||||
|         .with_max_level(config.tracing_level().unwrap_or(tracing::Level::INFO)) | ||||
|         .init(); | ||||
| 
 | ||||
|     match args.command.unwrap_or(Command::Serve) { | ||||
|         Command::Serve => main_web(config).await, | ||||
|         Command::Migrate => main_migrate(config).await, | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| async fn main_migrate(config: Config) -> Result<()> { | ||||
|     info!("Connecting to database"); | ||||
|     let pool = db::init(&config.database_url).await?; | ||||
|     info!("Migrating database"); | ||||
|     sqlx::migrate!().run(&pool).await?; | ||||
|     info!("Migrated database"); | ||||
|     Ok(()) | ||||
| } | ||||
| 
 | ||||
| async fn main_web(config: Config) -> Result<()> { | ||||
|     info!("Connecting to database"); | ||||
|     let pool = db::init(&config.database_url).await?; | ||||
| 
 | ||||
|     if config.auto_migrate.unwrap_or(false) { | ||||
|         info!("Auto-migrate is enabled, migrating database"); | ||||
|         sqlx::migrate!().run(&pool).await?; | ||||
|         info!("Migrated database"); | ||||
|     } | ||||
| 
 | ||||
|     info!("Initializing instance data"); | ||||
|     db::init_instance(&pool).await?; | ||||
|     info!("Initialized instance data!"); | ||||
| 
 | ||||
|     Ok(()) | ||||
| } | ||||
|  |  | |||
							
								
								
									
										52
									
								
								chat/src/model/identity_instance.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								chat/src/model/identity_instance.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,52 @@ | |||
| use std::sync::Arc; | ||||
| 
 | ||||
| use eyre::{Result, Context}; | ||||
| use foxchat::{fed::request::is_valid_domain, FoxError}; | ||||
| use rsa::{RsaPublicKey, pkcs1::DecodeRsaPublicKey}; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| 
 | ||||
| use crate::app_state::AppState; | ||||
| 
 | ||||
| #[derive(Serialize)] | ||||
| pub struct IdentityInstance { | ||||
|     pub id: String, | ||||
|     pub domain: String, | ||||
|     pub base_url: String, | ||||
|     pub public_key: String, | ||||
|     pub status: InstanceStatus, | ||||
|     pub reason: Option<String>, | ||||
| } | ||||
| 
 | ||||
| #[derive(Serialize, Deserialize, Debug, sqlx::Type)] | ||||
| #[sqlx(type_name = "instance_status", rename_all = "lowercase")] | ||||
| pub enum InstanceStatus { | ||||
|     Active, | ||||
|     Suspended, | ||||
| } | ||||
| 
 | ||||
| impl IdentityInstance { | ||||
|     pub async fn get(state: Arc<AppState>, domain: &str) -> Result<Self> { | ||||
|         if !is_valid_domain(domain) { | ||||
|             return Err(FoxError::InvalidServer.into()); | ||||
|         } | ||||
| 
 | ||||
|         if let Some(instance) = sqlx::query_as!( | ||||
|             Self, | ||||
|             r#"select id, domain, base_url, public_key,
 | ||||
|             status as "status: InstanceStatus", reason | ||||
|             from identity_instances where domain = $1"#,
 | ||||
|             domain | ||||
|         ) | ||||
|         .fetch_optional(&state.pool) | ||||
|         .await? | ||||
|         { | ||||
|             return Ok(instance); | ||||
|         } | ||||
| 
 | ||||
|         return Err(FoxError::InvalidServer.into()); | ||||
|     } | ||||
| 
 | ||||
|     pub fn parse_public_key(&self) -> Result<RsaPublicKey> { | ||||
|         RsaPublicKey::from_pkcs1_pem(&self.public_key).wrap_err("parsing identity instance public key") | ||||
|     } | ||||
| } | ||||
							
								
								
									
										27
									
								
								chat/src/model/instance.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								chat/src/model/instance.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,27 @@ | |||
| use eyre::Result; | ||||
| use sqlx::{Pool, Postgres}; | ||||
| use rsa::{RsaPrivateKey, RsaPublicKey, pkcs1::{DecodeRsaPublicKey, DecodeRsaPrivateKey}}; | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct Instance { | ||||
|     pub public_key: RsaPublicKey, | ||||
|     pub private_key: RsaPrivateKey, | ||||
| } | ||||
| 
 | ||||
| impl Instance { | ||||
|     /// Gets the instance's configuration.
 | ||||
|     /// This is a singleton row that is always present.
 | ||||
|     pub async fn get(pool: &Pool<Postgres>) -> Result<Self> { | ||||
|         let instance = sqlx::query!("SELECT * FROM instance WHERE id = 1") | ||||
|             .fetch_one(pool) | ||||
|             .await?; | ||||
| 
 | ||||
|         let public_key = RsaPublicKey::from_pkcs1_pem(&instance.public_key)?; | ||||
|         let private_key = RsaPrivateKey::from_pkcs1_pem(&instance.private_key)?; | ||||
| 
 | ||||
|         Ok(Self { | ||||
|             public_key, | ||||
|             private_key, | ||||
|         }) | ||||
|     } | ||||
| } | ||||
							
								
								
									
										2
									
								
								chat/src/model/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								chat/src/model/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,2 @@ | |||
| pub mod instance; | ||||
| pub mod identity_instance; | ||||
|  | @ -1,3 +1,4 @@ | |||
| database_url = "postgresql://foxchat:password@localhost/foxchat_chat_dev" | ||||
| port = 3001 | ||||
| port = 7610 | ||||
| log_level = "DEBUG" | ||||
| domain = "chat.fox.localhost" | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| database_url = "postgresql://foxchat:password@localhost/foxchat_ident_dev" | ||||
| port = 3000 | ||||
| port = 7611 | ||||
| log_level = "DEBUG" | ||||
| domain = "chat.foxchat.localhost" | ||||
| domain = "id.fox.localhost" | ||||
|  |  | |||
|  | @ -1,3 +1,4 @@ | |||
| use axum::http::header::ToStrError; | ||||
| use thiserror::Error; | ||||
| 
 | ||||
| #[derive(Error, Debug, Copy, Clone)] | ||||
|  | @ -10,4 +11,24 @@ pub enum FoxError { | |||
|     ResponseNotOk, | ||||
|     #[error("server is invalid")] | ||||
|     InvalidServer, | ||||
|     #[error("invalid header")] | ||||
|     InvalidHeader, | ||||
|     #[error("invalid date format")] | ||||
|     InvalidDate, | ||||
|     #[error("missing signature")] | ||||
|     MissingSignature, | ||||
|     #[error("invalid signature")] | ||||
|     InvalidSignature, | ||||
| } | ||||
| 
 | ||||
| impl From<ToStrError> for FoxError { | ||||
|     fn from(_: ToStrError) -> Self { | ||||
|         Self::InvalidHeader | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<chrono::ParseError> for FoxError { | ||||
|     fn from(_: chrono::ParseError) -> Self { | ||||
|         Self::InvalidDate | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -2,3 +2,7 @@ pub mod request; | |||
| pub mod signature; | ||||
| 
 | ||||
| pub use request::{get, post}; | ||||
| 
 | ||||
| pub const SERVER_HEADER: &'static str = "X-Foxchat-Server"; | ||||
| pub const SIGNATURE_HEADER: &'static str = "X-Foxchat-Signature"; | ||||
| pub const USER_HEADER: &'static str = "X-Foxchat-User"; | ||||
|  |  | |||
|  | @ -1,14 +1,17 @@ | |||
| use eyre::Result; | ||||
| use once_cell::sync::Lazy; | ||||
| use reqwest::{Response, StatusCode}; | ||||
| use reqwest::{Response, StatusCode, header::{CONTENT_TYPE, CONTENT_LENGTH, DATE}}; | ||||
| use rsa::RsaPrivateKey; | ||||
| use serde::{de::DeserializeOwned, Serialize}; | ||||
| use tracing::error; | ||||
| 
 | ||||
| use crate::{ | ||||
|     signature::{build_signature, format_date}, | ||||
|     FoxError, | ||||
| }; | ||||
| 
 | ||||
| use super::{SIGNATURE_HEADER, USER_HEADER, SERVER_HEADER}; | ||||
| 
 | ||||
| static APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); | ||||
| 
 | ||||
| static CLIENT: Lazy<reqwest::Client> = Lazy::new(|| { | ||||
|  | @ -27,7 +30,7 @@ pub async fn get<R: DeserializeOwned>( | |||
|     self_domain: String, | ||||
|     host: String, | ||||
|     path: String, | ||||
|     user_id: Option<String>, | ||||
|     user_id: Option<&str>, | ||||
| ) -> Result<R> { | ||||
|     let (signature, date) = build_signature( | ||||
|         private_key, | ||||
|  | @ -39,12 +42,12 @@ pub async fn get<R: DeserializeOwned>( | |||
| 
 | ||||
|     let mut req = CLIENT | ||||
|         .get(format!("https://{host}{path}")) | ||||
|         .header("Date", format_date(date)) | ||||
|         .header("X-Foxchat-Signature", signature) | ||||
|         .header("X-Foxchat-Server", self_domain); | ||||
|         .header(DATE, format_date(date)) | ||||
|         .header(SIGNATURE_HEADER, signature) | ||||
|         .header(SERVER_HEADER, self_domain); | ||||
| 
 | ||||
|     req = if let Some(id) = user_id { | ||||
|         req.header("X-Foxchat-User", id) | ||||
|         req.header(USER_HEADER, id) | ||||
|     } else { | ||||
|         req | ||||
|     }; | ||||
|  | @ -58,7 +61,7 @@ pub async fn post<T: Serialize, R: DeserializeOwned>( | |||
|     self_domain: String, | ||||
|     host: String, | ||||
|     path: String, | ||||
|     user_id: Option<String>, | ||||
|     user_id: Option<&str>, | ||||
|     body: &T, | ||||
| ) -> Result<R> { | ||||
|     let body = serde_json::to_string(body)?; | ||||
|  | @ -73,15 +76,15 @@ pub async fn post<T: Serialize, R: DeserializeOwned>( | |||
| 
 | ||||
|     let mut req = CLIENT | ||||
|         .post(format!("https://{host}{path}")) | ||||
|         .header("Date", format_date(date)) | ||||
|         .header("X-Foxchat-Signature", signature) | ||||
|         .header("X-Foxchat-Server", self_domain) | ||||
|         .header("Content-Type", "application/json; charset=utf-8") | ||||
|         .header("Content-Length", body.len().to_string()) | ||||
|         .header(DATE, format_date(date)) | ||||
|         .header(SIGNATURE_HEADER, signature) | ||||
|         .header(SERVER_HEADER, self_domain) | ||||
|         .header(CONTENT_TYPE, "application/json; charset=utf-8") | ||||
|         .header(CONTENT_LENGTH, body.len().to_string()) | ||||
|         .body(body); | ||||
| 
 | ||||
|     req = if let Some(id) = user_id { | ||||
|         req.header("X-Foxchat-User", id) | ||||
|         req.header(USER_HEADER, id) | ||||
|     } else { | ||||
|         req | ||||
|     }; | ||||
|  | @ -92,6 +95,7 @@ pub async fn post<T: Serialize, R: DeserializeOwned>( | |||
| 
 | ||||
| async fn handle_response<R: DeserializeOwned>(resp: Response) -> Result<R> { | ||||
|     if resp.status() != StatusCode::OK { | ||||
|         error!("federation request failed with status code {}", resp.status()); | ||||
|         return Err(FoxError::ResponseNotOk.into()); | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| use base64::prelude::{Engine, BASE64_URL_SAFE}; | ||||
| use chrono::{DateTime, Utc, Duration}; | ||||
| use chrono::{DateTime, Utc, Duration, NaiveDateTime}; | ||||
| use eyre::Result; | ||||
| use rsa::pkcs1v15::{SigningKey, VerifyingKey, Signature}; | ||||
| use rsa::sha2::Sha256; | ||||
|  | @ -13,13 +13,13 @@ pub fn build_signature( | |||
|     host: String, | ||||
|     request_path: String, | ||||
|     content_length: Option<usize>, | ||||
|     user_id: Option<String>, | ||||
|     user_id: Option<&str>, | ||||
| ) -> (String, DateTime<Utc>) { | ||||
|     let mut rng = rand::thread_rng(); | ||||
|     let signing_key = SigningKey::<Sha256>::new(private_key.clone()); | ||||
|     let time = Utc::now(); | ||||
| 
 | ||||
|     let plaintext = plaintext_string(time, host, request_path, content_length, user_id); | ||||
|     let plaintext = plaintext_string(time, &host, &request_path, content_length, user_id); | ||||
|     let signature = signing_key.sign_with_rng(&mut rng, plaintext.as_bytes()); | ||||
| 
 | ||||
|     let str = BASE64_URL_SAFE.encode(signature.to_bytes()); | ||||
|  | @ -29,16 +29,16 @@ pub fn build_signature( | |||
| 
 | ||||
| fn plaintext_string( | ||||
|     time: DateTime<Utc>, | ||||
|     host: String, | ||||
|     request_path: String, | ||||
|     host: &str, | ||||
|     request_path: &str, | ||||
|     content_length: Option<usize>, | ||||
|     user_id: Option<String>, | ||||
|     user_id: Option<&str>, | ||||
| ) -> String { | ||||
|     let raw_time = format_date(time); | ||||
|     let raw_content_length = content_length | ||||
|         .map(|i| i.to_string()) | ||||
|         .unwrap_or("".to_owned()); | ||||
|     let raw_user_id = user_id.unwrap_or("".to_owned()); | ||||
|     let raw_user_id = user_id.unwrap_or(""); | ||||
| 
 | ||||
|     format!( | ||||
|         "{}:{}:{}:{}:{}", | ||||
|  | @ -50,14 +50,18 @@ pub fn format_date(time: DateTime<Utc>) -> String { | |||
|     time.format("%a, %d %b %Y %T GMT").to_string() | ||||
| } | ||||
| 
 | ||||
| pub fn parse_date(input: &str) -> Result<DateTime<Utc>, chrono::ParseError> { | ||||
|     Ok(NaiveDateTime::parse_from_str(input, "%a, %d %b %Y %T GMT")?.and_utc()) | ||||
| } | ||||
| 
 | ||||
| pub fn verify_signature( | ||||
|     public_key: &RsaPublicKey, | ||||
|     encoded_signature: String, | ||||
|     time: DateTime<Utc>, // from Date header
 | ||||
|     host: String, // from Host header, verify that it's actually your host
 | ||||
|     request_path: String, // from router
 | ||||
|     host: &str, // from Host header, verify that it's actually your host
 | ||||
|     request_path: &str, // from router
 | ||||
|     content_length: Option<usize>, // from Content-Length header
 | ||||
|     user_id: Option<String>, // from X-Foxchat-User header
 | ||||
|     user_id: Option<&str>, // from X-Foxchat-User header
 | ||||
| ) -> Result<bool> { | ||||
|     let verifying_key = VerifyingKey::<Sha256>::new(public_key.clone()); | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,5 +1,7 @@ | |||
| use std::num::ParseIntError; | ||||
| 
 | ||||
| use axum::{ | ||||
|     http::StatusCode, | ||||
|     http::{StatusCode, header::ToStrError}, | ||||
|     response::{IntoResponse, Response}, | ||||
|     Json, | ||||
| }; | ||||
|  | @ -36,6 +38,10 @@ pub enum ErrorCode { | |||
|     InternalServerError, | ||||
|     ObjectNotFound, | ||||
|     InvalidServer, | ||||
|     InvalidHeader, | ||||
|     InvalidDate, | ||||
|     InvalidSignature, | ||||
|     MissingSignature, | ||||
| } | ||||
| 
 | ||||
| impl From<sqlx::Error> for ApiError { | ||||
|  | @ -85,6 +91,16 @@ impl From<FoxError> for ApiError { | |||
|                 code: ErrorCode::InvalidServer, | ||||
|                 message: "Invalid domain or server".into(), | ||||
|             }, | ||||
|             FoxError::MissingSignature => ApiError { | ||||
|                 status: StatusCode::BAD_REQUEST, | ||||
|                 code: ErrorCode::MissingSignature, | ||||
|                 message: "Missing signature".into(), | ||||
|             }, | ||||
|             FoxError::InvalidSignature => ApiError { | ||||
|                 status: StatusCode::BAD_REQUEST, | ||||
|                 code: ErrorCode::InvalidSignature, | ||||
|                 message: "Invalid signature".into(), | ||||
|             }, | ||||
|             _ => ApiError { | ||||
|                 status: StatusCode::INTERNAL_SERVER_ERROR, | ||||
|                 code: ErrorCode::InternalServerError, | ||||
|  | @ -93,3 +109,33 @@ impl From<FoxError> for ApiError { | |||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<ToStrError> for ApiError { | ||||
|     fn from(_: ToStrError) -> Self { | ||||
|         ApiError { | ||||
|             status: StatusCode::BAD_REQUEST, | ||||
|             code: ErrorCode::InvalidHeader, | ||||
|             message: "Invalid header value".into(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<chrono::ParseError> for ApiError { | ||||
|     fn from(_: chrono::ParseError) -> Self { | ||||
|         ApiError { | ||||
|             status: StatusCode::BAD_REQUEST, | ||||
|             code: ErrorCode::InvalidDate, | ||||
|             message: "Invalid date header value".into(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<ParseIntError> for ApiError { | ||||
|     fn from(_: ParseIntError) -> Self { | ||||
|         ApiError { | ||||
|             status: StatusCode::BAD_REQUEST, | ||||
|             code: ErrorCode::InvalidHeader, | ||||
|             message: "Invalid content length value".into(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -18,7 +18,7 @@ color-eyre = "0.6.2" | |||
| rsa = { version = "0.9.6", features = ["serde", "sha2"] } | ||||
| rand = "0.8.5" | ||||
| toml = "0.8.8" | ||||
| tokio = { version = "1.35.1", features = ["macros", "rt-multi-thread"] } | ||||
| tokio = { version = "1.35.1", features = ["macros", "rt-multi-thread", "sync"] } | ||||
| tracing-subscriber = "0.3.18" | ||||
| tracing = "0.1.40" | ||||
| tower-http = { version = "0.5.1", features = ["trace"] } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue