add basic guild create + message create endpoints

This commit is contained in:
sam 2024-01-20 16:43:03 +01:00
parent 5b23095520
commit e57bff00c2
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
27 changed files with 367 additions and 36 deletions

2
Cargo.lock generated
View file

@ -413,6 +413,7 @@ dependencies = [
"iana-time-zone", "iana-time-zone",
"js-sys", "js-sys",
"num-traits", "num-traits",
"serde",
"wasm-bindgen", "wasm-bindgen",
"windows-targets 0.48.5", "windows-targets 0.48.5",
] ]
@ -757,6 +758,7 @@ dependencies = [
"sqlx", "sqlx",
"thiserror", "thiserror",
"tracing", "tracing",
"ulid",
"uuid", "uuid",
] ]

View file

@ -31,6 +31,8 @@ create table guilds_users (
guild_id text not null references guilds (id) on delete cascade, guild_id text not null references guilds (id) on delete cascade,
user_id text not null references users (id) on delete cascade, user_id text not null references users (id) on delete cascade,
joined_at timestamptz not null default now(),
primary key (guild_id, user_id) primary key (guild_id, user_id)
); );

47
chat/src/db/channel.rs Normal file
View file

@ -0,0 +1,47 @@
use eyre::Result;
use foxchat::FoxError;
use sqlx::PgExecutor;
use ulid::Ulid;
pub struct Channel {
pub id: String,
pub guild_id: String,
pub name: String,
pub topic: Option<String>,
}
pub async fn create_channel(
executor: impl PgExecutor<'_>,
guild_id: &str,
name: &str,
topic: Option<String>,
) -> Result<Channel> {
let channel = sqlx::query_as!(
Channel,
"insert into channels (id, guild_id, name, topic) values ($1, $2, $3, $4) returning *",
Ulid::new().to_string(),
guild_id,
name,
topic
)
.fetch_one(executor)
.await?;
Ok(channel)
}
pub async fn get_channel(executor: impl PgExecutor<'_>, channel_id: &str) -> Result<Channel, FoxError> {
let channel = sqlx::query_as!(Channel, "select * from channels where id = $1", channel_id)
.fetch_one(executor)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => FoxError::NotInGuild,
_ => {
tracing::error!("database error: {}", e);
return FoxError::DatabaseError;
}
})?;
Ok(channel)
}

71
chat/src/db/guild.rs Normal file
View file

@ -0,0 +1,71 @@
use eyre::Result;
use foxchat::FoxError;
use sqlx::PgExecutor;
use ulid::Ulid;
pub struct Guild {
pub id: String,
pub owner_id: String,
pub name: String,
}
pub async fn create_guild(
executor: impl PgExecutor<'_>,
owner_id: &str,
name: &str,
) -> Result<Guild> {
let guild = sqlx::query_as!(
Guild,
"insert into guilds (id, owner_id, name) values ($1, $2, $3) returning *",
Ulid::new().to_string(),
owner_id,
name
)
.fetch_one(executor)
.await?;
Ok(guild)
}
pub async fn join_guild(
executor: impl PgExecutor<'_>,
guild_id: &str,
user_id: &str,
) -> Result<()> {
sqlx::query!(
"insert into guilds_users (guild_id, user_id) values ($1, $2) on conflict (guild_id, user_id) do nothing",
guild_id,
user_id
)
.execute(executor)
.await?;
Ok(())
}
pub async fn get_guild(
executor: impl PgExecutor<'_>,
guild_id: &str,
user_id: &str,
) -> Result<Guild, FoxError> {
println!("guild id: {}, user id: {}", guild_id, user_id);
let guild = sqlx::query_as!(
Guild,
"select g.* from guilds_users u join guilds g on u.guild_id = g.id where u.guild_id = $1 and u.user_id = $2",
guild_id,
user_id
)
.fetch_one(executor)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => FoxError::NotInGuild,
_ => {
tracing::error!("database error: {}", e);
return FoxError::DatabaseError;
}
})?;
Ok(guild)
}

33
chat/src/db/message.rs Normal file
View file

@ -0,0 +1,33 @@
use chrono::{DateTime, Utc};
use eyre::Result;
use foxchat::model::http::channel::CreateMessageParams;
use sqlx::PgExecutor;
use ulid::Ulid;
pub struct Message {
pub id: String,
pub channel_id: String,
pub author_id: String,
pub updated_at: DateTime<Utc>,
pub content: String,
}
pub async fn create_message(
executor: impl PgExecutor<'_>,
channel_id: &str,
user_id: &str,
params: CreateMessageParams,
) -> Result<Message> {
let message = sqlx::query_as!(
Message,
"insert into messages (id, channel_id, author_id, content) values ($1, $2, $3, $4) returning *",
Ulid::new().to_string(),
channel_id,
user_id,
params.content,
)
.fetch_one(executor)
.await?;
Ok(message)
}

View file

@ -1,3 +1,7 @@
pub mod guild;
pub mod channel;
pub mod message;
use eyre::{OptionExt, Result}; use eyre::{OptionExt, Result};
use rsa::pkcs1::{EncodeRsaPrivateKey, EncodeRsaPublicKey, LineEnding}; use rsa::pkcs1::{EncodeRsaPrivateKey, EncodeRsaPublicKey, LineEnding};
use rsa::{RsaPrivateKey, RsaPublicKey}; use rsa::{RsaPrivateKey, RsaPublicKey};

View file

@ -0,0 +1,48 @@
use std::sync::Arc;
use axum::{extract::Path, Extension, Json};
use eyre::Result;
use foxchat::{
http::ApiError,
model::{http::channel::CreateMessageParams, Message, user::PartialUser},
FoxError, ulid_timestamp,
};
use crate::{
app_state::AppState,
db::{channel::get_channel, guild::get_guild, message::create_message},
fed::FoxRequestData,
model::user::User,
};
pub async fn post_messages(
Extension(state): Extension<Arc<AppState>>,
request: FoxRequestData,
Path(channel_id): Path<String>,
Json(params): Json<CreateMessageParams>,
) -> Result<Json<Message>, ApiError> {
let user_id = request.user_id.ok_or(FoxError::MissingUser)?;
let user = User::get(&state, &request.instance, &user_id).await?;
let mut tx = state.pool.begin().await?;
let channel = get_channel(&mut *tx, &channel_id).await?;
let _guild = get_guild(&mut *tx, &channel.guild_id, &user.id).await?;
let message = create_message(&mut *tx, &channel.id, &user.id, params).await?;
tx.commit().await?;
// TODO: dispatch message create event
Ok(Json(Message {
id: message.id.clone(),
channel_id: channel.id,
author: PartialUser {
id: user.id,
username: user.username,
instance: request.instance.domain,
},
content: Some(message.content),
created_at: ulid_timestamp(&message.id),
}))
}

View file

@ -0,0 +1,7 @@
mod messages;
use axum::{routing::post, Router};
pub fn router() -> Router {
Router::new().route("/_fox/chat/channels/:id/messages", post(messages::post_messages))
}

View file

@ -3,30 +3,32 @@ use std::sync::Arc;
use axum::{Extension, Json}; use axum::{Extension, Json};
use foxchat::{ use foxchat::{
http::ApiError, http::ApiError,
model::{http::guild::CreateGuildParams, Guild, user::PartialUser}, model::{http::guild::CreateGuildParams, user::PartialUser, Guild, channel::PartialChannel},
FoxError, FoxError,
}; };
use ulid::Ulid;
use crate::{app_state::AppState, fed::FoxRequestData, model::user::User}; use crate::{
app_state::AppState,
db::{channel::create_channel, guild::{create_guild, join_guild}},
fed::FoxRequestData,
model::user::User,
};
pub async fn create_guild( pub async fn post_guilds(
Extension(state): Extension<Arc<AppState>>, Extension(state): Extension<Arc<AppState>>,
request: FoxRequestData, request: FoxRequestData,
Json(params): Json<CreateGuildParams>, Json(params): Json<CreateGuildParams>,
) -> Result<Json<Guild>, ApiError> { ) -> Result<Json<Guild>, ApiError> {
let user_id = request.user_id.ok_or(FoxError::MissingUser)?; let user_id = request.user_id.ok_or(FoxError::MissingUser)?;
let user = User::get(&state, &request.instance, &user_id).await?;
let user = User::get(&state, &request.instance, user_id).await?; let mut tx = state.pool.begin().await?;
let guild = create_guild(&mut *tx, &user.id, &params.name).await?;
let channel = create_channel(&mut *tx, &guild.id, "general", None).await?;
let guild = sqlx::query!( join_guild(&mut *tx, &guild.id, &user.id).await?;
"insert into guilds (id, owner_id, name) values ($1, $2, $3) returning *",
Ulid::new().to_string(), tx.commit().await?;
user.id,
params.name
)
.fetch_one(&state.pool)
.await?;
Ok(Json(Guild { Ok(Json(Guild {
id: guild.id, id: guild.id,
@ -34,7 +36,11 @@ pub async fn create_guild(
owner: PartialUser { owner: PartialUser {
id: user.id, id: user.id,
username: user.username, username: user.username,
instance: request.instance.domain, instance: user.instance.domain,
},
default_channel: PartialChannel {
id: channel.id,
name: channel.name,
} }
})) }))
} }

View file

@ -4,5 +4,5 @@ use axum::{Router, routing::post};
pub fn router() -> Router { pub fn router() -> Router {
Router::new() Router::new()
.route("/_fox/chat/guilds", post(create_guild::create_guild)) .route("/_fox/chat/guilds", post(create_guild::post_guilds))
} }

View file

@ -1,7 +1,10 @@
use axum::Router; use axum::Router;
pub mod channels;
pub mod guilds; pub mod guilds;
pub fn router() -> Router { pub fn router() -> Router {
Router::new().merge(guilds::router()) Router::new()
.merge(guilds::router())
.merge(channels::router())
} }

View file

@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use crate::app_state::AppState; use crate::app_state::AppState;
#[derive(Serialize)] #[derive(Serialize, Clone, Debug)]
pub struct IdentityInstance { pub struct IdentityInstance {
pub id: String, pub id: String,
pub domain: String, pub domain: String,
@ -17,7 +17,7 @@ pub struct IdentityInstance {
pub reason: Option<String>, pub reason: Option<String>,
} }
#[derive(Serialize, Deserialize, Debug, sqlx::Type)] #[derive(Serialize, Deserialize, Clone, Debug, sqlx::Type)]
#[sqlx(type_name = "instance_status", rename_all = "lowercase")] #[sqlx(type_name = "instance_status", rename_all = "lowercase")]
pub enum InstanceStatus { pub enum InstanceStatus {
Active, Active,

View file

@ -11,6 +11,7 @@ use super::identity_instance::IdentityInstance;
pub struct User { pub struct User {
pub id: String, pub id: String,
pub instance_id: String, pub instance_id: String,
pub instance: IdentityInstance,
pub remote_user_id: String, pub remote_user_id: String,
pub username: String, pub username: String,
pub avatar: Option<String>, pub avatar: Option<String>,
@ -20,10 +21,9 @@ impl User {
pub async fn get( pub async fn get(
state: &Arc<AppState>, state: &Arc<AppState>,
instance: &IdentityInstance, instance: &IdentityInstance,
remote_id: String, remote_id: &str,
) -> Result<User> { ) -> Result<User> {
if let Some(user) = sqlx::query_as!( if let Some(user) = sqlx::query!(
User,
"select * from users where instance_id = $1 and remote_user_id = $2", "select * from users where instance_id = $1 and remote_user_id = $2",
instance.id, instance.id,
remote_id remote_id
@ -31,7 +31,14 @@ impl User {
.fetch_optional(&state.pool) .fetch_optional(&state.pool)
.await? .await?
{ {
return Ok(user); return Ok(User {
id: user.id,
username: user.username,
instance_id: user.instance_id,
instance: instance.to_owned(),
remote_user_id: user.remote_user_id,
avatar: user.avatar,
});
} }
let http_user = fed::get::<HttpUser>( let http_user = fed::get::<HttpUser>(
@ -43,8 +50,7 @@ impl User {
) )
.await?; .await?;
let user = sqlx::query_as!( let user = sqlx::query!(
User,
"insert into users (id, instance_id, remote_user_id, username, avatar) values ($1, $2, $3, $4, $5) returning *", "insert into users (id, instance_id, remote_user_id, username, avatar) values ($1, $2, $3, $4, $5) returning *",
Ulid::new().to_string(), Ulid::new().to_string(),
instance.id, instance.id,
@ -55,6 +61,13 @@ impl User {
.fetch_one(&state.pool) .fetch_one(&state.pool)
.await?; .await?;
Ok(user) Ok(User {
id: user.id,
username: user.username,
instance_id: user.instance_id,
instance: instance.to_owned(),
remote_user_id: user.remote_user_id,
avatar: user.avatar,
})
} }
} }

View file

@ -9,7 +9,7 @@ edition = "2021"
addr = "0.15.6" addr = "0.15.6"
axum = "0.7.4" axum = "0.7.4"
base64 = "0.21.7" base64 = "0.21.7"
chrono = "0.4.31" chrono = { version = "0.4.31", features = ["serde"] }
eyre = "0.6.11" eyre = "0.6.11"
once_cell = "1.19.0" once_cell = "1.19.0"
rand = "0.8.5" rand = "0.8.5"
@ -20,4 +20,5 @@ serde_json = "1.0.111"
sqlx = "0.7.3" sqlx = "0.7.3"
thiserror = "1.0.56" thiserror = "1.0.56"
tracing = "0.1.40" tracing = "0.1.40"
ulid = "1.1.0"
uuid = { version = "1.6.1", features = ["v7"] } uuid = { version = "1.6.1", features = ["v7"] }

View file

@ -7,6 +7,8 @@ pub enum FoxError {
NotFound, NotFound,
#[error("date for signature out of range")] #[error("date for signature out of range")]
SignatureDateOutOfRange(&'static str), SignatureDateOutOfRange(&'static str),
#[error("database error")]
DatabaseError,
#[error("non-200 response to federation request")] #[error("non-200 response to federation request")]
ResponseNotOk, ResponseNotOk,
#[error("server is invalid")] #[error("server is invalid")]
@ -23,6 +25,10 @@ pub enum FoxError {
Unauthorized, Unauthorized,
#[error("missing target user ID")] #[error("missing target user ID")]
MissingUser, MissingUser,
#[error("user is not in guild or guild doesn't exist")]
NotInGuild,
#[error("channel not found")]
ChannelNotFound,
} }
impl From<ToStrError> for FoxError { impl From<ToStrError> for FoxError {

View file

@ -42,6 +42,7 @@ pub enum ErrorCode {
InvalidDate, InvalidDate,
InvalidSignature, InvalidSignature,
MissingSignature, MissingSignature,
GuildNotFound,
Unauthorized, Unauthorized,
} }
@ -87,6 +88,11 @@ impl From<FoxError> for ApiError {
code: ErrorCode::ObjectNotFound, code: ErrorCode::ObjectNotFound,
message: "Object not found".into(), message: "Object not found".into(),
}, },
FoxError::DatabaseError => ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: ErrorCode::InternalServerError,
message: "Database error".into(),
},
FoxError::SignatureDateOutOfRange(s) => ApiError { FoxError::SignatureDateOutOfRange(s) => ApiError {
status: StatusCode::BAD_REQUEST, status: StatusCode::BAD_REQUEST,
code: ErrorCode::InvalidSignature, code: ErrorCode::InvalidSignature,
@ -127,10 +133,20 @@ impl From<FoxError> for ApiError {
code: ErrorCode::InvalidHeader, code: ErrorCode::InvalidHeader,
message: "Missing user header".into(), message: "Missing user header".into(),
}, },
FoxError::NotInGuild => ApiError {
status: StatusCode::NOT_FOUND,
code: ErrorCode::GuildNotFound,
message: "Channel or guild not found".into(),
},
FoxError::Unauthorized => ApiError { FoxError::Unauthorized => ApiError {
status: StatusCode::UNAUTHORIZED, status: StatusCode::UNAUTHORIZED,
code: ErrorCode::Unauthorized, code: ErrorCode::Unauthorized,
message: "Missing or invalid token".into(), message: "Missing or invalid token".into(),
},
FoxError::ChannelNotFound => ApiError {
status: StatusCode::NOT_FOUND,
code: ErrorCode::GuildNotFound,
message: "Channel or guild not found".into(),
} }
} }
} }

View file

@ -6,3 +6,20 @@ pub mod s2s;
pub use error::FoxError; pub use error::FoxError;
pub use fed::signature; pub use fed::signature;
use chrono::{DateTime, Utc};
use ulid::Ulid;
/// Extracts a DateTime from a ULID.
/// This function should only be used on valid ULIDs (such as those used as primary keys), else it will fail and panic!
pub fn ulid_timestamp(id: &str) -> DateTime<Utc> {
let ts = Ulid::from_string(id).expect("invalid ULID").timestamp_ms();
let (secs, rem) = (ts / 1000, ts % 1000);
let nsecs = rem * 1000000;
DateTime::from_timestamp(
secs.try_into().expect("converting secs to i64"),
nsecs.try_into().expect("converting nsecs to i32"),
)
.expect("converting timestamp to DateTime")
}

View file

@ -0,0 +1,15 @@
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize, Debug)]
pub struct Channel {
pub id: String,
pub guild_id: String,
pub name: String,
pub topic: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct PartialChannel {
pub id: String,
pub name: String,
}

View file

@ -1,10 +1,11 @@
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use super::user::PartialUser; use super::{user::PartialUser, channel::PartialChannel};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct Guild { pub struct Guild {
pub id: String, pub id: String,
pub name: String, pub name: String,
pub owner: PartialUser, pub owner: PartialUser,
pub default_channel: PartialChannel,
} }

View file

@ -0,0 +1,6 @@
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize, Debug)]
pub struct CreateMessageParams {
pub content: String,
}

View file

@ -1 +1,2 @@
pub mod guild; pub mod guild;
pub mod channel;

View file

@ -0,0 +1,13 @@
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
use super::user::PartialUser;
#[derive(Serialize, Deserialize, Debug)]
pub struct Message {
pub id: String,
pub channel_id: String,
pub author: PartialUser,
pub content: Option<String>,
pub created_at: DateTime<Utc>,
}

View file

@ -1,6 +1,10 @@
pub mod channel;
pub mod guild; pub mod guild;
pub mod user;
pub mod http; pub mod http;
pub mod message;
pub mod user;
pub use channel::Channel;
pub use guild::Guild; pub use guild::Guild;
pub use message::Message;
pub use user::User; pub use user::User;

View file

@ -0,0 +1,7 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Dispatch {
MessageCreate
}

View file

@ -1,11 +1,13 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::Dispatch;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")] #[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Payload { pub enum Payload {
Dispatch { Dispatch {
#[serde(rename = "e")] #[serde(rename = "e")]
event: DispatchEvent, event: Dispatch,
#[serde(rename = "r")] #[serde(rename = "r")]
recipients: Vec<String>, recipients: Vec<String>,
}, },
@ -14,7 +16,3 @@ pub enum Payload {
token: String, token: String,
}, },
} }
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "t", content = "d", rename_all = "SCREAMING_SNAKE_CASE")]
pub enum DispatchEvent {}

View file

@ -1,4 +1,6 @@
mod dispatch;
mod event; mod event;
pub mod http; pub mod http;
pub use event::{DispatchEvent, Payload}; pub use event::Payload;
pub use dispatch::Dispatch;

View file

@ -5,7 +5,10 @@ use eyre::ContextCompat;
use foxchat::{ use foxchat::{
fed, fed,
http::ApiError, http::ApiError,
model::{http::guild::CreateGuildParams, Guild}, model::{
http::{channel::CreateMessageParams, guild::CreateGuildParams},
Guild, Message,
},
}; };
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use tracing::debug; use tracing::debug;
@ -15,7 +18,12 @@ use crate::{app_state::AppState, fed::ProxyServerHeader};
use super::auth::AuthUser; use super::auth::AuthUser;
pub fn router() -> Router { pub fn router() -> Router {
Router::new().route("/guilds", post(proxy_post::<CreateGuildParams, Guild>)) Router::new()
.route("/guilds", post(proxy_post::<CreateGuildParams, Guild>))
.route(
"/channels/:id/messages",
post(proxy_post::<CreateMessageParams, Message>),
)
} }
async fn proxy_get<R: Serialize + DeserializeOwned>( async fn proxy_get<R: Serialize + DeserializeOwned>(