add request verification extractor

This commit is contained in:
sam 2024-01-18 16:34:40 +01:00
parent 7a694623e5
commit 1e53661b0a
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
18 changed files with 482 additions and 32 deletions

View file

@ -5,3 +5,6 @@ members = [
"chat"
]
resolver = "2"
[profile.dev.package.num-bigint-dig]
opt-level = 3

View file

@ -14,7 +14,7 @@ create table users (
instance_id text not null references identity_instances (id) on delete cascade,
remote_user_id text not null,
username text not null,
avatar text -- URL, not hash, as this is a remote file
);
@ -46,6 +46,14 @@ create table messages (
channel_id text not null references channels (id) on delete cascade,
author_id text not null,
updated_at timestamptz not null default now(),
content text not null
);
create table instance (
id integer not null primary key default 1,
public_key text not null,
private_key text not null,
constraint singleton check (id = 1)
);

8
chat/src/app_state.rs Normal file
View file

@ -0,0 +1,8 @@
use sqlx::{Pool, Postgres};
use crate::config::Config;
pub struct AppState {
pub pool: Pool<Postgres>,
pub config: Config,
}

56
chat/src/config.rs Normal file
View file

@ -0,0 +1,56 @@
use eyre::Result;
use serde::Deserialize;
use std::path::Path;
use std::{env, fs};
use tracing::Level;
pub const CONFIG_FILE: &str = "config.chat.toml";
#[derive(Deserialize)]
pub struct Config {
pub database_url: String,
pub port: u16,
pub domain: String,
pub auto_migrate: Option<bool>,
pub log_level: Option<String>,
}
impl Config {
pub fn load() -> Result<Self> {
let cwd = env::current_dir()?;
let config_file = Path::join(cwd.as_path(), Path::new(CONFIG_FILE));
println!("config file: {}", config_file.display());
let s = fs::read_to_string(config_file)?;
let config = toml::from_str(s.as_str())?;
Ok(config)
}
pub fn tracing_level(&self) -> Option<Level> {
match self
.log_level
.as_deref()
.unwrap_or("INFO")
{
"TRACE" => Some(Level::TRACE),
"DEBUG" => Some(Level::DEBUG),
"INFO" => Some(Level::INFO),
"WARN" => Some(Level::WARN),
"ERROR" => Some(Level::ERROR),
_ => None
}
}
}
impl Default for Config {
fn default() -> Self {
Config {
database_url: env::var("DATABASE_URL").unwrap_or("".into()),
domain: "".into(),
port: 3000,
auto_migrate: None,
log_level: None,
}
}
}

50
chat/src/db/mod.rs Normal file
View file

@ -0,0 +1,50 @@
use eyre::{OptionExt, Result};
use rsa::pkcs1::{EncodeRsaPrivateKey, EncodeRsaPublicKey, LineEnding};
use rsa::{RsaPrivateKey, RsaPublicKey};
use sqlx::postgres::PgPoolOptions;
use sqlx::{Pool, Postgres};
use std::time::Duration;
pub async fn init(dsn: &str) -> Result<Pool<Postgres>> {
let pool = PgPoolOptions::new()
.acquire_timeout(Duration::from_secs(2)) // Fail fast and don't hang
.max_connections(100)
.connect(dsn)
.await?;
Ok(pool)
}
const PRIVATE_KEY_BITS: usize = 2048;
pub async fn init_instance(pool: &Pool<Postgres>) -> Result<()> {
let mut tx = pool.begin().await?;
// Check if we already have an instance configuration
let row = sqlx::query!("select exists(select * from instance)")
.fetch_one(&mut *tx)
.await?;
if row.exists.ok_or_eyre("exists was null")? {
return Ok(());
}
// Generate public/private key
let mut rng = rand::thread_rng();
let priv_key = RsaPrivateKey::new(&mut rng, PRIVATE_KEY_BITS)?;
let pub_key = RsaPublicKey::from(&priv_key);
let priv_key_string = priv_key.to_pkcs1_pem(LineEnding::LF)?;
let pub_key_string = pub_key.to_pkcs1_pem(LineEnding::LF)?;
sqlx::query!(
"insert into instance (public_key, private_key) values ($1, $2)",
pub_key_string,
priv_key_string.to_string(),
)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(())
}

102
chat/src/fed/mod.rs Normal file
View file

@ -0,0 +1,102 @@
use std::sync::Arc;
use axum::{
async_trait,
extract::FromRequestParts,
http::{
header::{CONTENT_LENGTH, DATE, HOST},
request::Parts,
},
Extension,
};
use foxchat::{
fed::{SERVER_HEADER, SIGNATURE_HEADER, USER_HEADER},
http::ApiError,
signature::{parse_date, verify_signature},
FoxError,
};
use tracing::error;
use crate::{app_state::AppState, model::identity_instance::IdentityInstance};
pub struct FoxRequestData {
pub instance: IdentityInstance,
pub user_id: Option<String>,
}
#[async_trait]
impl<S> FromRequestParts<S> for FoxRequestData
where
S: Send + Sync,
{
type Rejection = ApiError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let state: Extension<Arc<AppState>> = Extension::from_request_parts(parts, state)
.await
.expect("AppState was not added as an extension");
let domain = parts
.headers
.get(SERVER_HEADER)
.ok_or(FoxError::InvalidHeader)?
.to_str()?;
let instance = IdentityInstance::get(state.0, domain).await?;
let public_key = instance.parse_public_key()?;
let date = parse_date(
parts
.headers
.get(DATE)
.ok_or(FoxError::InvalidHeader)?
.to_str()?,
)?;
let signature = parts
.headers
.get(SIGNATURE_HEADER)
.ok_or(FoxError::MissingSignature)?
.to_str()?
.to_string();
let host = parts
.headers
.get(HOST)
.ok_or(FoxError::InvalidHeader)?
.to_str()?;
let content_length = if let Some(raw_length) = parts.headers.get(CONTENT_LENGTH) {
Some(raw_length.to_str()?.parse::<usize>()?)
} else {
None
};
let user_id = if let Some(raw_id) = parts.headers.get(USER_HEADER) {
Some(raw_id.to_str()?)
} else {
None
};
if let Err(e) = verify_signature(
&public_key,
signature,
date,
host,
parts.uri.path(),
content_length,
user_id,
) {
error!(
"Verifying signature from request for {} from {}: {}",
parts.uri.path(),
domain,
e
);
return Err(FoxError::InvalidSignature.into());
}
Ok(FoxRequestData {
instance,
user_id: user_id.map(|v| v.to_string()),
})
}
}

View file

@ -1,3 +1,65 @@
fn main() {
println!("Hello, world!");
mod config;
mod db;
mod fed;
mod app_state;
mod model;
use crate::config::Config;
use clap::{Parser, Subcommand};
use eyre::Result;
use tracing::info;
#[derive(Debug, Parser)]
struct Cli {
#[command(subcommand)]
command: Option<Command>,
}
#[derive(Debug, Subcommand)]
enum Command {
Serve,
Migrate,
}
#[tokio::main]
async fn main() -> Result<()> {
color_eyre::install()?;
let config = Config::load()?;
let args = Cli::parse();
tracing_subscriber::fmt()
.with_max_level(config.tracing_level().unwrap_or(tracing::Level::INFO))
.init();
match args.command.unwrap_or(Command::Serve) {
Command::Serve => main_web(config).await,
Command::Migrate => main_migrate(config).await,
}
}
async fn main_migrate(config: Config) -> Result<()> {
info!("Connecting to database");
let pool = db::init(&config.database_url).await?;
info!("Migrating database");
sqlx::migrate!().run(&pool).await?;
info!("Migrated database");
Ok(())
}
async fn main_web(config: Config) -> Result<()> {
info!("Connecting to database");
let pool = db::init(&config.database_url).await?;
if config.auto_migrate.unwrap_or(false) {
info!("Auto-migrate is enabled, migrating database");
sqlx::migrate!().run(&pool).await?;
info!("Migrated database");
}
info!("Initializing instance data");
db::init_instance(&pool).await?;
info!("Initialized instance data!");
Ok(())
}

View file

@ -0,0 +1,52 @@
use std::sync::Arc;
use eyre::{Result, Context};
use foxchat::{fed::request::is_valid_domain, FoxError};
use rsa::{RsaPublicKey, pkcs1::DecodeRsaPublicKey};
use serde::{Deserialize, Serialize};
use crate::app_state::AppState;
#[derive(Serialize)]
pub struct IdentityInstance {
pub id: String,
pub domain: String,
pub base_url: String,
pub public_key: String,
pub status: InstanceStatus,
pub reason: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, sqlx::Type)]
#[sqlx(type_name = "instance_status", rename_all = "lowercase")]
pub enum InstanceStatus {
Active,
Suspended,
}
impl IdentityInstance {
pub async fn get(state: Arc<AppState>, domain: &str) -> Result<Self> {
if !is_valid_domain(domain) {
return Err(FoxError::InvalidServer.into());
}
if let Some(instance) = sqlx::query_as!(
Self,
r#"select id, domain, base_url, public_key,
status as "status: InstanceStatus", reason
from identity_instances where domain = $1"#,
domain
)
.fetch_optional(&state.pool)
.await?
{
return Ok(instance);
}
return Err(FoxError::InvalidServer.into());
}
pub fn parse_public_key(&self) -> Result<RsaPublicKey> {
RsaPublicKey::from_pkcs1_pem(&self.public_key).wrap_err("parsing identity instance public key")
}
}

View file

@ -0,0 +1,27 @@
use eyre::Result;
use sqlx::{Pool, Postgres};
use rsa::{RsaPrivateKey, RsaPublicKey, pkcs1::{DecodeRsaPublicKey, DecodeRsaPrivateKey}};
#[derive(Debug, Clone)]
pub struct Instance {
pub public_key: RsaPublicKey,
pub private_key: RsaPrivateKey,
}
impl Instance {
/// Gets the instance's configuration.
/// This is a singleton row that is always present.
pub async fn get(pool: &Pool<Postgres>) -> Result<Self> {
let instance = sqlx::query!("SELECT * FROM instance WHERE id = 1")
.fetch_one(pool)
.await?;
let public_key = RsaPublicKey::from_pkcs1_pem(&instance.public_key)?;
let private_key = RsaPrivateKey::from_pkcs1_pem(&instance.private_key)?;
Ok(Self {
public_key,
private_key,
})
}
}

2
chat/src/model/mod.rs Normal file
View file

@ -0,0 +1,2 @@
pub mod instance;
pub mod identity_instance;

View file

@ -1,3 +1,4 @@
database_url = "postgresql://foxchat:password@localhost/foxchat_chat_dev"
port = 3001
port = 7610
log_level = "DEBUG"
domain = "chat.fox.localhost"

View file

@ -1,4 +1,4 @@
database_url = "postgresql://foxchat:password@localhost/foxchat_ident_dev"
port = 3000
port = 7611
log_level = "DEBUG"
domain = "chat.foxchat.localhost"
domain = "id.fox.localhost"

View file

@ -1,3 +1,4 @@
use axum::http::header::ToStrError;
use thiserror::Error;
#[derive(Error, Debug, Copy, Clone)]
@ -10,4 +11,24 @@ pub enum FoxError {
ResponseNotOk,
#[error("server is invalid")]
InvalidServer,
#[error("invalid header")]
InvalidHeader,
#[error("invalid date format")]
InvalidDate,
#[error("missing signature")]
MissingSignature,
#[error("invalid signature")]
InvalidSignature,
}
impl From<ToStrError> for FoxError {
fn from(_: ToStrError) -> Self {
Self::InvalidHeader
}
}
impl From<chrono::ParseError> for FoxError {
fn from(_: chrono::ParseError) -> Self {
Self::InvalidDate
}
}

View file

@ -2,3 +2,7 @@ pub mod request;
pub mod signature;
pub use request::{get, post};
pub const SERVER_HEADER: &'static str = "X-Foxchat-Server";
pub const SIGNATURE_HEADER: &'static str = "X-Foxchat-Signature";
pub const USER_HEADER: &'static str = "X-Foxchat-User";

View file

@ -1,14 +1,17 @@
use eyre::Result;
use once_cell::sync::Lazy;
use reqwest::{Response, StatusCode};
use reqwest::{Response, StatusCode, header::{CONTENT_TYPE, CONTENT_LENGTH, DATE}};
use rsa::RsaPrivateKey;
use serde::{de::DeserializeOwned, Serialize};
use tracing::error;
use crate::{
signature::{build_signature, format_date},
FoxError,
};
use super::{SIGNATURE_HEADER, USER_HEADER, SERVER_HEADER};
static APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
static CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
@ -27,7 +30,7 @@ pub async fn get<R: DeserializeOwned>(
self_domain: String,
host: String,
path: String,
user_id: Option<String>,
user_id: Option<&str>,
) -> Result<R> {
let (signature, date) = build_signature(
private_key,
@ -39,12 +42,12 @@ pub async fn get<R: DeserializeOwned>(
let mut req = CLIENT
.get(format!("https://{host}{path}"))
.header("Date", format_date(date))
.header("X-Foxchat-Signature", signature)
.header("X-Foxchat-Server", self_domain);
.header(DATE, format_date(date))
.header(SIGNATURE_HEADER, signature)
.header(SERVER_HEADER, self_domain);
req = if let Some(id) = user_id {
req.header("X-Foxchat-User", id)
req.header(USER_HEADER, id)
} else {
req
};
@ -58,7 +61,7 @@ pub async fn post<T: Serialize, R: DeserializeOwned>(
self_domain: String,
host: String,
path: String,
user_id: Option<String>,
user_id: Option<&str>,
body: &T,
) -> Result<R> {
let body = serde_json::to_string(body)?;
@ -73,15 +76,15 @@ pub async fn post<T: Serialize, R: DeserializeOwned>(
let mut req = CLIENT
.post(format!("https://{host}{path}"))
.header("Date", format_date(date))
.header("X-Foxchat-Signature", signature)
.header("X-Foxchat-Server", self_domain)
.header("Content-Type", "application/json; charset=utf-8")
.header("Content-Length", body.len().to_string())
.header(DATE, format_date(date))
.header(SIGNATURE_HEADER, signature)
.header(SERVER_HEADER, self_domain)
.header(CONTENT_TYPE, "application/json; charset=utf-8")
.header(CONTENT_LENGTH, body.len().to_string())
.body(body);
req = if let Some(id) = user_id {
req.header("X-Foxchat-User", id)
req.header(USER_HEADER, id)
} else {
req
};
@ -92,6 +95,7 @@ pub async fn post<T: Serialize, R: DeserializeOwned>(
async fn handle_response<R: DeserializeOwned>(resp: Response) -> Result<R> {
if resp.status() != StatusCode::OK {
error!("federation request failed with status code {}", resp.status());
return Err(FoxError::ResponseNotOk.into());
}

View file

@ -1,5 +1,5 @@
use base64::prelude::{Engine, BASE64_URL_SAFE};
use chrono::{DateTime, Utc, Duration};
use chrono::{DateTime, Utc, Duration, NaiveDateTime};
use eyre::Result;
use rsa::pkcs1v15::{SigningKey, VerifyingKey, Signature};
use rsa::sha2::Sha256;
@ -13,13 +13,13 @@ pub fn build_signature(
host: String,
request_path: String,
content_length: Option<usize>,
user_id: Option<String>,
user_id: Option<&str>,
) -> (String, DateTime<Utc>) {
let mut rng = rand::thread_rng();
let signing_key = SigningKey::<Sha256>::new(private_key.clone());
let time = Utc::now();
let plaintext = plaintext_string(time, host, request_path, content_length, user_id);
let plaintext = plaintext_string(time, &host, &request_path, content_length, user_id);
let signature = signing_key.sign_with_rng(&mut rng, plaintext.as_bytes());
let str = BASE64_URL_SAFE.encode(signature.to_bytes());
@ -29,16 +29,16 @@ pub fn build_signature(
fn plaintext_string(
time: DateTime<Utc>,
host: String,
request_path: String,
host: &str,
request_path: &str,
content_length: Option<usize>,
user_id: Option<String>,
user_id: Option<&str>,
) -> String {
let raw_time = format_date(time);
let raw_content_length = content_length
.map(|i| i.to_string())
.unwrap_or("".to_owned());
let raw_user_id = user_id.unwrap_or("".to_owned());
let raw_user_id = user_id.unwrap_or("");
format!(
"{}:{}:{}:{}:{}",
@ -50,14 +50,18 @@ pub fn format_date(time: DateTime<Utc>) -> String {
time.format("%a, %d %b %Y %T GMT").to_string()
}
pub fn parse_date(input: &str) -> Result<DateTime<Utc>, chrono::ParseError> {
Ok(NaiveDateTime::parse_from_str(input, "%a, %d %b %Y %T GMT")?.and_utc())
}
pub fn verify_signature(
public_key: &RsaPublicKey,
encoded_signature: String,
time: DateTime<Utc>, // from Date header
host: String, // from Host header, verify that it's actually your host
request_path: String, // from router
host: &str, // from Host header, verify that it's actually your host
request_path: &str, // from router
content_length: Option<usize>, // from Content-Length header
user_id: Option<String>, // from X-Foxchat-User header
user_id: Option<&str>, // from X-Foxchat-User header
) -> Result<bool> {
let verifying_key = VerifyingKey::<Sha256>::new(public_key.clone());

View file

@ -1,5 +1,7 @@
use std::num::ParseIntError;
use axum::{
http::StatusCode,
http::{StatusCode, header::ToStrError},
response::{IntoResponse, Response},
Json,
};
@ -36,6 +38,10 @@ pub enum ErrorCode {
InternalServerError,
ObjectNotFound,
InvalidServer,
InvalidHeader,
InvalidDate,
InvalidSignature,
MissingSignature,
}
impl From<sqlx::Error> for ApiError {
@ -85,6 +91,16 @@ impl From<FoxError> for ApiError {
code: ErrorCode::InvalidServer,
message: "Invalid domain or server".into(),
},
FoxError::MissingSignature => ApiError {
status: StatusCode::BAD_REQUEST,
code: ErrorCode::MissingSignature,
message: "Missing signature".into(),
},
FoxError::InvalidSignature => ApiError {
status: StatusCode::BAD_REQUEST,
code: ErrorCode::InvalidSignature,
message: "Invalid signature".into(),
},
_ => ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: ErrorCode::InternalServerError,
@ -93,3 +109,33 @@ impl From<FoxError> for ApiError {
}
}
}
impl From<ToStrError> for ApiError {
fn from(_: ToStrError) -> Self {
ApiError {
status: StatusCode::BAD_REQUEST,
code: ErrorCode::InvalidHeader,
message: "Invalid header value".into(),
}
}
}
impl From<chrono::ParseError> for ApiError {
fn from(_: chrono::ParseError) -> Self {
ApiError {
status: StatusCode::BAD_REQUEST,
code: ErrorCode::InvalidDate,
message: "Invalid date header value".into(),
}
}
}
impl From<ParseIntError> for ApiError {
fn from(_: ParseIntError) -> Self {
ApiError {
status: StatusCode::BAD_REQUEST,
code: ErrorCode::InvalidHeader,
message: "Invalid content length value".into(),
}
}
}

View file

@ -18,7 +18,7 @@ color-eyre = "0.6.2"
rsa = { version = "0.9.6", features = ["serde", "sha2"] }
rand = "0.8.5"
toml = "0.8.8"
tokio = { version = "1.35.1", features = ["macros", "rt-multi-thread"] }
tokio = { version = "1.35.1", features = ["macros", "rt-multi-thread", "sync"] }
tracing-subscriber = "0.3.18"
tracing = "0.1.40"
tower-http = { version = "0.5.1", features = ["trace"] }