feat: make identity server store user instance/guild IDs

This commit is contained in:
sam 2024-03-04 15:59:07 +01:00
parent 42abd70184
commit 60a158495b
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
19 changed files with 245 additions and 28 deletions

View file

@ -1,3 +1,7 @@
mod rest;
pub use rest::post_rest_event;
use std::sync::Arc; use std::sync::Arc;
use axum::{ use axum::{

38
chat/src/fed/rest.rs Normal file
View file

@ -0,0 +1,38 @@
use std::sync::Arc;
use eyre::Result;
use foxchat::{fed, s2s::http::RestEvent};
use tracing::{debug, error};
use crate::{app_state::AppState, model::identity_instance::IdentityInstance};
/// Posts an event to a remote instance's inbox.
pub async fn post_rest_event(
state: Arc<AppState>,
instance: &IdentityInstance,
event: RestEvent,
) -> Result<()> {
debug!("Sending {:?} event to {}'s inbox", &event, &instance.domain);
match fed::post::<RestEvent, ()>(
&state.private_key,
&state.config.domain,
&instance.domain,
"/_fox/ident/inbox",
None,
&event,
)
.await
{
Ok(_) => Ok(()),
Err(e) => {
error!(
"Error sending {:?} event to {}'s inbox: {}",
&event, &instance.domain, e
);
return Err(e);
}
}
}

View file

@ -4,13 +4,17 @@ use axum::{Extension, Json};
use foxchat::{ use foxchat::{
http::ApiError, http::ApiError,
model::{channel::PartialChannel, http::guild::CreateGuildParams, Guild}, model::{channel::PartialChannel, http::guild::CreateGuildParams, Guild},
s2s::http::RestEvent,
FoxError, FoxError,
}; };
use crate::{ use crate::{
app_state::AppState, app_state::AppState,
db::{channel::create_channel, guild::{create_guild, join_guild}}, db::{
fed::FoxRequestData, channel::create_channel,
guild::{create_guild, join_guild},
},
fed::{post_rest_event, FoxRequestData},
model::user::User, model::user::User,
}; };
@ -27,9 +31,20 @@ pub async fn post_guilds(
let channel = create_channel(&mut *tx, &guild.id, "general", None).await?; let channel = create_channel(&mut *tx, &guild.id, "general", None).await?;
join_guild(&mut *tx, &guild.id, &user.id).await?; join_guild(&mut *tx, &guild.id, &user.id).await?;
tx.commit().await?; tx.commit().await?;
// Send an event to the user's instance that they joined a guild
post_rest_event(
state,
&request.instance,
RestEvent::GuildJoin {
guild_id: guild.id.0.clone(),
user_id: user.remote_user_id.clone(),
},
)
.await
.ok();
Ok(Json(Guild { Ok(Json(Guild {
id: guild.id.0, id: guild.id.0,
name: guild.name, name: guild.name,

View file

@ -3,6 +3,7 @@ use std::sync::Arc;
use axum::{Extension, Json}; use axum::{Extension, Json};
use eyre::Context; use eyre::Context;
use foxchat::{ use foxchat::{
error::ToFoxError,
fed, fed,
http::ApiError, http::ApiError,
s2s::http::{HelloRequest, HelloResponse, NodeResponse}, s2s::http::{HelloRequest, HelloResponse, NodeResponse},
@ -30,7 +31,8 @@ pub async fn post_hello(
"/_fox/ident/node", "/_fox/ident/node",
None, None,
) )
.await?; .await?
.to_fox_error()?;
let public_key = let public_key =
RsaPublicKey::from_pkcs1_pem(&node.public_key).wrap_err("parsing remote public key")?; RsaPublicKey::from_pkcs1_pem(&node.public_key).wrap_err("parsing remote public key")?;

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use eyre::Result; use eyre::Result;
use foxchat::{fed, model::User as HttpUser, Id, id::UserType}; use foxchat::{error::ToFoxError, fed, id::UserType, model::User as HttpUser, Id};
use ulid::Ulid; use ulid::Ulid;
use crate::app_state::AppState; use crate::app_state::AppState;
@ -48,7 +48,8 @@ impl User {
&format!("/_fox/ident/users/{}", remote_id), &format!("/_fox/ident/users/{}", remote_id),
None, None,
) )
.await?; .await?
.to_fox_error()?;
let user = sqlx::query!( let user = sqlx::query!(
"insert into users (id, instance_id, remote_user_id, username, avatar) values ($1, $2, $3, $4, $5) returning *", "insert into users (id, instance_id, remote_user_id, username, avatar) values ($1, $2, $3, $4, $5) returning *",

View file

@ -18,7 +18,8 @@ pub enum Payload {
/// Hello message, sent after authentication succeeds /// Hello message, sent after authentication succeeds
Hello { Hello {
heartbeat_interval: u64, heartbeat_interval: u64,
guilds: Vec<String>, guilds: Vec<GuildInstance>,
instances: Vec<InstanceId>,
}, },
Identify { Identify {
token: String, token: String,
@ -34,3 +35,15 @@ pub enum Payload {
timestamp: u64, timestamp: u64,
} }
} }
#[derive(Debug, Serialize, Deserialize)]
pub struct InstanceId {
pub id: String,
pub domain: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GuildInstance {
pub guild_id: String,
pub instance_id: String,
}

View file

@ -1,5 +1,6 @@
use axum::http::header::ToStrError; use axum::http::header::ToStrError;
use thiserror::Error; use thiserror::Error;
use eyre::Result;
#[derive(Error, Debug, Copy, Clone)] #[derive(Error, Debug, Copy, Clone)]
pub enum FoxError { pub enum FoxError {
@ -31,6 +32,8 @@ pub enum FoxError {
ChannelNotFound, ChannelNotFound,
#[error("internal server error while proxying")] #[error("internal server error while proxying")]
ProxyInternalServerError, ProxyInternalServerError,
#[error("invalid RestEvent for this instance type")]
InvalidRestEvent,
} }
impl From<ToStrError> for FoxError { impl From<ToStrError> for FoxError {
@ -44,3 +47,16 @@ impl From<chrono::ParseError> for FoxError {
Self::InvalidDate Self::InvalidDate
} }
} }
pub trait ToFoxError<T> {
fn to_fox_error(self) -> Result<T>;
}
impl<T> ToFoxError<T> for Option<T> {
fn to_fox_error(self) -> Result<T> {
match self {
Some(t) => Ok(t),
None => Err(FoxError::ResponseNotOk.into())
}
}
}

View file

@ -34,7 +34,7 @@ pub async fn get<R: DeserializeOwned>(
host: &str, host: &str,
path: &str, path: &str,
user_id: Option<String>, user_id: Option<String>,
) -> Result<R> { ) -> Result<Option<R>> {
let (signature, date) = build_signature(private_key, host, path, None, user_id.clone()); let (signature, date) = build_signature(private_key, host, path, None, user_id.clone());
let mut req = CLIENT let mut req = CLIENT
@ -60,7 +60,7 @@ pub async fn post<T: Serialize, R: DeserializeOwned>(
path: &str, path: &str,
user_id: Option<String>, user_id: Option<String>,
body: &T, body: &T,
) -> Result<R> { ) -> Result<Option<R>> {
let body = serde_json::to_string(body)?; let body = serde_json::to_string(body)?;
let (signature, date) = let (signature, date) =
@ -85,7 +85,11 @@ pub async fn post<T: Serialize, R: DeserializeOwned>(
handle_response(resp).await.wrap_err("handling response") handle_response(resp).await.wrap_err("handling response")
} }
async fn handle_response<R: DeserializeOwned>(resp: Response) -> Result<R, ResponseError> { async fn handle_response<R: DeserializeOwned>(resp: Response) -> Result<Option<R>, ResponseError> {
if resp.status() == StatusCode::NO_CONTENT {
return Ok(None);
}
if resp.status() != StatusCode::OK { if resp.status() != StatusCode::OK {
let status = resp.status().as_u16(); let status = resp.status().as_u16();
@ -105,7 +109,7 @@ async fn handle_response<R: DeserializeOwned>(resp: Response) -> Result<R, Respo
.json::<R>() .json::<R>()
.await .await
.map_err(|_| ResponseError::JsonError)?; .map_err(|_| ResponseError::JsonError)?;
Ok(parsed) Ok(Some(parsed))
} }
#[derive(thiserror::Error, Debug, Clone)] #[derive(thiserror::Error, Debug, Clone)]

View file

@ -45,8 +45,6 @@ fn plaintext_string(
raw_time, host, request_path, raw_content_length, raw_user_id raw_time, host, request_path, raw_content_length, raw_user_id
); );
println!("{}", s);
s s
} }

View file

@ -44,6 +44,7 @@ pub enum ErrorCode {
MissingSignature, MissingSignature,
GuildNotFound, GuildNotFound,
Unauthorized, Unauthorized,
InvalidRestEvent,
} }
impl From<sqlx::Error> for ApiError { impl From<sqlx::Error> for ApiError {
@ -152,7 +153,12 @@ impl From<FoxError> for ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR, status: StatusCode::INTERNAL_SERVER_ERROR,
code: ErrorCode::InternalServerError, code: ErrorCode::InternalServerError,
message: "Internal server error".into(), message: "Internal server error".into(),
} },
FoxError::InvalidRestEvent => ApiError {
status: StatusCode::BAD_REQUEST,
code: ErrorCode::InvalidRestEvent,
message: "Invalid RestEvent for this instance type".into(),
},
} }
} }
} }

View file

@ -1,3 +1,5 @@
mod hello; mod hello;
mod rest_event;
pub use hello::{HelloRequest, HelloResponse, NodeResponse}; pub use hello::{HelloRequest, HelloResponse, NodeResponse};
pub use rest_event::RestEvent;

View file

@ -0,0 +1,18 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum RestEvent {
/// Sent when a user creates or joins a guild (chat -> identity)
GuildJoin { guild_id: String, user_id: String },
/// Sent when a user leaves or is removed from a guild (chat -> identity)
GuildLeave { guild_id: String, user_id: String },
/// Sent when a user updates their profile (identity -> chat)
UserUpdate {
user_id: String,
username: Option<String>,
avatar: Option<String>,
},
/// Sent when a user deletes their account (identity -> chat)
UserDelete { user_id: String },
}

View file

@ -2,4 +2,4 @@
fn main() { fn main() {
// trigger recompilation when a new migration is added // trigger recompilation when a new migration is added
println!("cargo:rerun-if-changed=migrations"); println!("cargo:rerun-if-changed=migrations");
} }

View file

@ -0,0 +1,7 @@
create table guilds_accounts (
chat_instance_id text not null references chat_instances (id) on delete cascade,
guild_id text not null,
account_id text not null references accounts (id) on delete cascade,
primary key (chat_instance_id, guild_id, account_id)
);

View file

@ -0,0 +1,59 @@
use std::sync::Arc;
use axum::{http::StatusCode, Extension, Json};
use foxchat::{http::ApiError, s2s::http::RestEvent, FoxError};
use tracing::info;
use crate::{app_state::AppState, fed::FoxRequestData};
pub async fn post_inbox(
Extension(state): Extension<Arc<AppState>>,
request: FoxRequestData,
Json(evt): Json<RestEvent>,
) -> Result<(StatusCode, &'static str), ApiError> {
match evt {
RestEvent::GuildJoin { guild_id, user_id } => {
info!(
"received GUILD_JOIN event from {}/{} for user {}/guild {}",
&request.instance.id, &request.instance.domain, &user_id, &guild_id,
);
let mut tx = state.pool.begin().await?;
sqlx::query!(
r#"INSERT INTO chat_instance_accounts
(account_id, chat_instance_id) VALUES ($1, $2)
ON CONFLICT (account_id, chat_instance_id) DO NOTHING"#,
&user_id,
&request.instance.id
)
.execute(&mut *tx)
.await?;
sqlx::query!(
r#"INSERT INTO guilds_accounts (chat_instance_id, guild_id, account_id)
VALUES ($1, $2, $3) ON CONFLICT (chat_instance_id, guild_id, account_id) DO NOTHING"#,
&request.instance.id,
&guild_id,
&user_id
)
.execute(&mut *tx)
.await?;
tx.commit().await?;
}
RestEvent::GuildLeave { guild_id, user_id } => {
sqlx::query!(
"DELETE FROM guilds_accounts WHERE chat_instance_id = $1 AND guild_id = $2 AND account_id = $3",
&request.instance.id,
&guild_id,
&user_id,
).execute(&state.pool).await?;
}
_ => {
return Err(FoxError::InvalidServer.into());
}
}
Ok((StatusCode::NO_CONTENT, ""))
}

View file

@ -1,12 +1,13 @@
mod account; mod account;
mod auth; mod auth;
mod inbox;
mod node; mod node;
mod proxy; mod proxy;
mod ws; mod ws;
use std::sync::Arc; use std::sync::Arc;
use axum::{routing::get, Extension, Router}; use axum::{routing::{get, post}, Extension, Router};
use sqlx::{Pool, Postgres}; use sqlx::{Pool, Postgres};
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
@ -26,6 +27,7 @@ pub fn new(pool: Pool<Postgres>, config: Config, instance: Instance) -> Router {
.route("/_fox/ident/node", get(node::get_node)) .route("/_fox/ident/node", get(node::get_node))
.route("/_fox/ident/node/:domain", get(node::get_chat_node)) .route("/_fox/ident/node/:domain", get(node::get_chat_node))
.route("/_fox/ident/ws", get(ws::handler)) .route("/_fox/ident/ws", get(ws::handler))
.route("/_fox/ident/inbox", post(inbox::post_inbox))
.layer(TraceLayer::new_for_http()) .layer(TraceLayer::new_for_http())
.layer(Extension(app_state)); .layer(Extension(app_state));

View file

@ -37,7 +37,6 @@ async fn proxy_get<R: Serialize + DeserializeOwned>(
let proxy_path = original let proxy_path = original
.strip_prefix("/_fox/proxy/") .strip_prefix("/_fox/proxy/")
.wrap_err("invalid url")?; .wrap_err("invalid url")?;
println!("{}", proxy_path);
let resp = fed::get::<R>( let resp = fed::get::<R>(
&state.private_key, &state.private_key,
@ -49,7 +48,10 @@ async fn proxy_get<R: Serialize + DeserializeOwned>(
.await; .await;
match resp { match resp {
Ok(r) => return Ok(Json(r)), Ok(r) => match r {
Some(r) => Ok(Json(r)),
None => Err(FoxError::ResponseNotOk.into()),
},
Err(e) => { Err(e) => {
if let Some(e) = e.downcast_ref::<ResponseError>() { if let Some(e) = e.downcast_ref::<ResponseError>() {
match e { match e {
@ -102,7 +104,10 @@ async fn proxy_post<B: Serialize, R: Serialize + DeserializeOwned>(
.await; .await;
match resp { match resp {
Ok(r) => return Ok(Json(r)), Ok(r) => match r {
Some(r) => Ok(Json(r)),
None => Err(FoxError::ResponseNotOk.into()),
},
Err(e) => { Err(e) => {
if let Some(e) = e.downcast_ref::<ResponseError>() { if let Some(e) = e.downcast_ref::<ResponseError>() {
match e { match e {

View file

@ -9,13 +9,16 @@ use axum::{
Extension, Extension,
}; };
use eyre::Result; use eyre::Result;
use foxchat::{c2s::Payload, FoxError}; use foxchat::{c2s::{Payload, event::{InstanceId, GuildInstance}}, FoxError};
use futures::{ use futures::{
stream::{SplitSink, SplitStream, StreamExt}, stream::{SplitSink, SplitStream, StreamExt},
SinkExt, SinkExt,
}; };
use rand::Rng; use rand::Rng;
use tokio::{sync::mpsc::{self, Receiver, Sender}, time::timeout}; use tokio::{
sync::mpsc::{self, Receiver, Sender},
time::timeout,
};
use tracing::error; use tracing::error;
use crate::{app_state::AppState, db::check_token, model::account::Account}; use crate::{app_state::AppState, db::check_token, model::account::Account};
@ -126,15 +129,41 @@ async fn read(
// Send HELLO event // Send HELLO event
// TODO: fetch guild IDs // TODO: fetch guild IDs
let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000); let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000);
tx.send(Payload::Hello { heartbeat_interval, guilds: vec![] }).await.ok(); // Fetch the instance ID/domain pairs
let instances = match sqlx::query_as!(InstanceId, r#"SELECT i.id, i.domain FROM chat_instances i
JOIN chat_instance_accounts ia ON ia.chat_instance_id = i.id
WHERE ia.account_id = $1"#, &user.id).fetch_all(&app_state.pool).await {
Ok(i) => i,
Err(e) => {
error!("error getting instances for user {}: {}", &user.id, e);
tx.send(Payload::Error { message: "Internal server error".into() }).await.ok();
return;
}
};
// Fetch guild ID/instance ID pairs
let guilds = match sqlx::query_as!(GuildInstance, r#"SELECT guild_id, chat_instance_id AS instance_id
FROM guilds_accounts WHERE account_id = $1"#, &user.id).fetch_all(&app_state.pool).await {
Ok(g) => g,
Err(e) => {
error!("error getting guilds for user {}: {}", &user.id, e);
tx.send(Payload::Error { message: "Internal server error".into() }).await.ok();
return;
}
};
tx.send(Payload::Hello {
heartbeat_interval,
guilds,
instances,
})
.await
.ok();
// Start the heartbeat loop // Start the heartbeat loop
let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10); let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10);
tokio::spawn(heartbeat(heartbeat_interval, heartbeat_rx, tx.clone())); tokio::spawn(heartbeat(heartbeat_interval, heartbeat_rx, tx.clone()));
// Fire off ready event
tokio::spawn(collect_ready(app_state.clone(), user.clone(), tx.clone()));
// Start listening for events // Start listening for events
while let Some(msg) = receiver.next().await { while let Some(msg) = receiver.next().await {
let Ok(msg) = msg else { let Ok(msg) = msg else {
@ -237,8 +266,6 @@ async fn write(
} }
} }
async fn collect_ready(app_state: Arc<AppState>, user: Arc<Account>, tx: Sender<Payload>) {}
async fn heartbeat( async fn heartbeat(
heartbeat_interval: u64, heartbeat_interval: u64,
mut rx: mpsc::Receiver<u64>, mut rx: mpsc::Receiver<u64>,

View file

@ -58,7 +58,7 @@ impl ChatInstance {
host: state.config.domain.clone(), host: state.config.domain.clone(),
}, },
) )
.await?; .await?.expect("");
if resp.host != domain { if resp.host != domain {
return Err(FoxError::InvalidServer.into()); return Err(FoxError::InvalidServer.into());