feat: proxy response bodies

This commit is contained in:
sam 2024-02-15 15:11:04 +01:00
parent 0858d4893a
commit 18b644d24b
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
4 changed files with 121 additions and 36 deletions

View file

@ -29,6 +29,8 @@ pub enum FoxError {
NotInGuild, NotInGuild,
#[error("channel not found")] #[error("channel not found")]
ChannelNotFound, ChannelNotFound,
#[error("internal server error while proxying")]
ProxyInternalServerError,
} }
impl From<ToStrError> for FoxError { impl From<ToStrError> for FoxError {

View file

@ -1,16 +1,19 @@
use eyre::Result; use eyre::{Context, Result};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use reqwest::{Response, StatusCode, header::{CONTENT_TYPE, CONTENT_LENGTH, DATE}}; use reqwest::{
header::{CONTENT_LENGTH, CONTENT_TYPE, DATE},
Response, StatusCode,
};
use rsa::RsaPrivateKey; use rsa::RsaPrivateKey;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tracing::error; use tracing::error;
use crate::{ use crate::{
http::ErrorCode,
signature::{build_signature, format_date}, signature::{build_signature, format_date},
FoxError,
}; };
use super::{SIGNATURE_HEADER, USER_HEADER, SERVER_HEADER}; use super::{SERVER_HEADER, SIGNATURE_HEADER, USER_HEADER};
static APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); static APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
@ -32,13 +35,7 @@ pub async fn get<R: DeserializeOwned>(
path: &str, path: &str,
user_id: Option<String>, user_id: Option<String>,
) -> Result<R> { ) -> Result<R> {
let (signature, date) = build_signature( let (signature, date) = build_signature(private_key, host, path, None, user_id.clone());
private_key,
host,
path,
None,
user_id.clone(),
);
let mut req = CLIENT let mut req = CLIENT
.get(format!("https://{host}{path}")) .get(format!("https://{host}{path}"))
@ -53,7 +50,7 @@ pub async fn get<R: DeserializeOwned>(
}; };
let resp = req.send().await?; let resp = req.send().await?;
handle_response(resp).await handle_response(resp).await.wrap_err("handling response")
} }
pub async fn post<T: Serialize, R: DeserializeOwned>( pub async fn post<T: Serialize, R: DeserializeOwned>(
@ -66,13 +63,8 @@ pub async fn post<T: Serialize, R: DeserializeOwned>(
) -> Result<R> { ) -> Result<R> {
let body = serde_json::to_string(body)?; let body = serde_json::to_string(body)?;
let (signature, date) = build_signature( let (signature, date) =
private_key, build_signature(private_key, host, path, Some(body.len()), user_id.clone());
host,
path,
Some(body.len()),
user_id.clone(),
);
let mut req = CLIENT let mut req = CLIENT
.post(format!("https://{host}{path}")) .post(format!("https://{host}{path}"))
@ -90,15 +82,46 @@ pub async fn post<T: Serialize, R: DeserializeOwned>(
}; };
let resp = req.send().await?; let resp = req.send().await?;
handle_response(resp).await handle_response(resp).await.wrap_err("handling response")
} }
async fn handle_response<R: DeserializeOwned>(resp: Response) -> Result<R> { async fn handle_response<R: DeserializeOwned>(resp: Response) -> Result<R, ResponseError> {
if resp.status() != StatusCode::OK { if resp.status() != StatusCode::OK {
error!("federation request failed with status code {}", resp.status()); let status = resp.status().as_u16();
return Err(FoxError::ResponseNotOk.into());
let error_json = resp
.json::<ApiError>()
.await
.map_err(|_| ResponseError::JsonError)?;
return Err(ResponseError::NotOk {
status,
code: error_json.code,
message: error_json.message,
});
} }
let parsed = resp.json::<R>().await?; let parsed = resp
.json::<R>()
.await
.map_err(|_| ResponseError::JsonError)?;
Ok(parsed) Ok(parsed)
} }
#[derive(thiserror::Error, Debug, Clone)]
pub enum ResponseError {
#[error("non-200 status code")]
NotOk {
status: u16,
code: ErrorCode,
message: String,
},
#[error("error deserializing JSON")]
JsonError,
}
#[derive(Deserialize)]
struct ApiError {
pub code: ErrorCode,
pub message: String,
}

View file

@ -6,16 +6,16 @@ use axum::{
Json, Json,
}; };
use eyre::Report; use eyre::Report;
use serde::Serialize; use serde::{Serialize, Deserialize};
use serde_json::json; use serde_json::json;
use tracing::error; use tracing::error;
use crate::FoxError; use crate::FoxError;
pub struct ApiError { pub struct ApiError {
status: StatusCode, pub status: StatusCode,
code: ErrorCode, pub code: ErrorCode,
message: String, pub message: String,
} }
impl IntoResponse for ApiError { impl IntoResponse for ApiError {
@ -32,7 +32,7 @@ impl IntoResponse for ApiError {
} }
} }
#[derive(Serialize)] #[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ErrorCode { pub enum ErrorCode {
InternalServerError, InternalServerError,
@ -147,6 +147,11 @@ impl From<FoxError> for ApiError {
status: StatusCode::NOT_FOUND, status: StatusCode::NOT_FOUND,
code: ErrorCode::GuildNotFound, code: ErrorCode::GuildNotFound,
message: "Channel or guild not found".into(), message: "Channel or guild not found".into(),
},
FoxError::ProxyInternalServerError => ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: ErrorCode::InternalServerError,
message: "Internal server error".into(),
} }
} }
} }

View file

@ -1,14 +1,15 @@
use std::sync::Arc; use std::sync::Arc;
use axum::{extract::OriginalUri, routing::post, Extension, Json, Router}; use axum::{extract::OriginalUri, http::StatusCode, routing::post, Extension, Json, Router};
use eyre::ContextCompat; use eyre::ContextCompat;
use foxchat::{ use foxchat::{
fed, fed::{self, request::ResponseError},
http::ApiError, http::ApiError,
model::{ model::{
http::{channel::CreateMessageParams, guild::CreateGuildParams}, http::{channel::CreateMessageParams, guild::CreateGuildParams},
Guild, Message, Guild, Message,
}, },
FoxError,
}; };
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use tracing::debug; use tracing::debug;
@ -45,9 +46,36 @@ async fn proxy_get<R: Serialize + DeserializeOwned>(
&format!("/_fox/chat/{}", proxy_path), &format!("/_fox/chat/{}", proxy_path),
Some(user.id), Some(user.id),
) )
.await?; .await;
Ok(Json(resp)) match resp {
Ok(r) => return Ok(Json(r)),
Err(e) => {
if let Some(e) = e.downcast_ref::<ResponseError>() {
match e {
ResponseError::JsonError => {
return Err(FoxError::ProxyInternalServerError.into())
}
ResponseError::NotOk {
status,
code,
message,
} => {
return Err(ApiError {
status: StatusCode::from_u16(status.to_owned()).map_err(
|_| -> ApiError { FoxError::ProxyInternalServerError.into() },
)?,
code: code.to_owned(),
message: message.to_owned(),
})
}
}
} else {
tracing::error!("proxying GET request: {}", e);
return Err(FoxError::ProxyInternalServerError.into());
}
}
}
} }
async fn proxy_post<B: Serialize, R: Serialize + DeserializeOwned>( async fn proxy_post<B: Serialize, R: Serialize + DeserializeOwned>(
@ -71,7 +99,34 @@ async fn proxy_post<B: Serialize, R: Serialize + DeserializeOwned>(
Some(user.id), Some(user.id),
&body, &body,
) )
.await?; .await;
Ok(Json(resp)) match resp {
Ok(r) => return Ok(Json(r)),
Err(e) => {
if let Some(e) = e.downcast_ref::<ResponseError>() {
match e {
ResponseError::JsonError => {
return Err(FoxError::ProxyInternalServerError.into())
}
ResponseError::NotOk {
status,
code,
message,
} => {
return Err(ApiError {
status: StatusCode::from_u16(status.to_owned()).map_err(
|_| -> ApiError { FoxError::ProxyInternalServerError.into() },
)?,
code: code.to_owned(),
message: message.to_owned(),
})
}
}
} else {
tracing::error!("proxying POST request: {}", e);
return Err(FoxError::ProxyInternalServerError.into());
}
}
}
} }