feat: make identity server store user instance/guild IDs

This commit is contained in:
sam 2024-03-04 15:59:07 +01:00
parent 42abd70184
commit fd77dd01fa
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
19 changed files with 245 additions and 29 deletions

View file

@ -2,4 +2,4 @@
fn main() {
// trigger recompilation when a new migration is added
println!("cargo:rerun-if-changed=migrations");
}
}

View file

@ -0,0 +1,7 @@
create table guilds_accounts (
chat_instance_id text not null references chat_instances (id) on delete cascade,
guild_id text not null,
account_id text not null references accounts (id) on delete cascade,
primary key (chat_instance_id, guild_id, account_id)
);

View file

@ -0,0 +1,59 @@
use std::sync::Arc;
use axum::{http::StatusCode, Extension, Json};
use foxchat::{http::ApiError, s2s::http::RestEvent, FoxError};
use tracing::info;
use crate::{app_state::AppState, fed::FoxRequestData};
pub async fn post_inbox(
Extension(state): Extension<Arc<AppState>>,
request: FoxRequestData,
Json(evt): Json<RestEvent>,
) -> Result<(StatusCode, &'static str), ApiError> {
match evt {
RestEvent::GuildJoin { guild_id, user_id } => {
info!(
"received GUILD_JOIN event from {}/{} for user {}/guild {}",
&request.instance.id, &request.instance.domain, &user_id, &guild_id,
);
let mut tx = state.pool.begin().await?;
sqlx::query!(
r#"INSERT INTO chat_instance_accounts
(account_id, chat_instance_id) VALUES ($1, $2)
ON CONFLICT (account_id, chat_instance_id) DO NOTHING"#,
&user_id,
&request.instance.id
)
.execute(&mut *tx)
.await?;
sqlx::query!(
r#"INSERT INTO guilds_accounts (chat_instance_id, guild_id, account_id)
VALUES ($1, $2, $3) ON CONFLICT (chat_instance_id, guild_id, account_id) DO NOTHING"#,
&request.instance.id,
&guild_id,
&user_id
)
.execute(&mut *tx)
.await?;
tx.commit().await?;
}
RestEvent::GuildLeave { guild_id, user_id } => {
sqlx::query!(
"DELETE FROM guilds_accounts WHERE chat_instance_id = $1 AND guild_id = $2 AND account_id = $3",
&request.instance.id,
&guild_id,
&user_id,
).execute(&state.pool).await?;
}
_ => {
return Err(FoxError::InvalidServer.into());
}
}
Ok((StatusCode::NO_CONTENT, ""))
}

View file

@ -1,12 +1,13 @@
mod account;
mod auth;
mod inbox;
mod node;
mod proxy;
mod ws;
use std::sync::Arc;
use axum::{routing::get, Extension, Router};
use axum::{routing::{get, post}, Extension, Router};
use sqlx::{Pool, Postgres};
use tower_http::trace::TraceLayer;
@ -26,6 +27,7 @@ pub fn new(pool: Pool<Postgres>, config: Config, instance: Instance) -> Router {
.route("/_fox/ident/node", get(node::get_node))
.route("/_fox/ident/node/:domain", get(node::get_chat_node))
.route("/_fox/ident/ws", get(ws::handler))
.route("/_fox/ident/inbox", post(inbox::post_inbox))
.layer(TraceLayer::new_for_http())
.layer(Extension(app_state));

View file

@ -37,7 +37,6 @@ async fn proxy_get<R: Serialize + DeserializeOwned>(
let proxy_path = original
.strip_prefix("/_fox/proxy/")
.wrap_err("invalid url")?;
println!("{}", proxy_path);
let resp = fed::get::<R>(
&state.private_key,
@ -49,7 +48,10 @@ async fn proxy_get<R: Serialize + DeserializeOwned>(
.await;
match resp {
Ok(r) => return Ok(Json(r)),
Ok(r) => match r {
Some(r) => Ok(Json(r)),
None => Err(FoxError::ResponseNotOk.into()),
},
Err(e) => {
if let Some(e) = e.downcast_ref::<ResponseError>() {
match e {
@ -102,7 +104,10 @@ async fn proxy_post<B: Serialize, R: Serialize + DeserializeOwned>(
.await;
match resp {
Ok(r) => return Ok(Json(r)),
Ok(r) => match r {
Some(r) => Ok(Json(r)),
None => Err(FoxError::ResponseNotOk.into()),
},
Err(e) => {
if let Some(e) = e.downcast_ref::<ResponseError>() {
match e {

View file

@ -9,13 +9,16 @@ use axum::{
Extension,
};
use eyre::Result;
use foxchat::{c2s::Payload, FoxError};
use foxchat::{c2s::{Payload, event::{InstanceId, GuildInstance}}, FoxError};
use futures::{
stream::{SplitSink, SplitStream, StreamExt},
SinkExt,
};
use rand::Rng;
use tokio::{sync::mpsc::{self, Receiver, Sender}, time::timeout};
use tokio::{
sync::mpsc::{self, Receiver, Sender},
time::timeout,
};
use tracing::error;
use crate::{app_state::AppState, db::check_token, model::account::Account};
@ -124,17 +127,42 @@ async fn read(
tokio::spawn(write(app_state.clone(), user.clone(), sender, rx));
// Send HELLO event
// TODO: fetch guild IDs
let heartbeat_interval = rand::thread_rng().gen_range(45_000..70_000);
tx.send(Payload::Hello { heartbeat_interval, guilds: vec![] }).await.ok();
// Fetch the instance ID/domain pairs
let instances = match sqlx::query_as!(InstanceId, r#"SELECT i.id, i.domain FROM chat_instances i
JOIN chat_instance_accounts ia ON ia.chat_instance_id = i.id
WHERE ia.account_id = $1"#, &user.id).fetch_all(&app_state.pool).await {
Ok(i) => i,
Err(e) => {
error!("error getting instances for user {}: {}", &user.id, e);
tx.send(Payload::Error { message: "Internal server error".into() }).await.ok();
return;
}
};
// Fetch guild ID/instance ID pairs
let guilds = match sqlx::query_as!(GuildInstance, r#"SELECT guild_id, chat_instance_id AS instance_id
FROM guilds_accounts WHERE account_id = $1"#, &user.id).fetch_all(&app_state.pool).await {
Ok(g) => g,
Err(e) => {
error!("error getting guilds for user {}: {}", &user.id, e);
tx.send(Payload::Error { message: "Internal server error".into() }).await.ok();
return;
}
};
tx.send(Payload::Hello {
heartbeat_interval,
guilds,
instances,
})
.await
.ok();
// Start the heartbeat loop
let (heartbeat_tx, heartbeat_rx) = mpsc::channel::<u64>(10);
tokio::spawn(heartbeat(heartbeat_interval, heartbeat_rx, tx.clone()));
// Fire off ready event
tokio::spawn(collect_ready(app_state.clone(), user.clone(), tx.clone()));
// Start listening for events
while let Some(msg) = receiver.next().await {
let Ok(msg) = msg else {
@ -237,8 +265,6 @@ async fn write(
}
}
async fn collect_ready(app_state: Arc<AppState>, user: Arc<Account>, tx: Sender<Payload>) {}
async fn heartbeat(
heartbeat_interval: u64,
mut rx: mpsc::Receiver<u64>,

View file

@ -58,7 +58,7 @@ impl ChatInstance {
host: state.config.domain.clone(),
},
)
.await?;
.await?.expect("");
if resp.host != domain {
return Err(FoxError::InvalidServer.into());