use std::sync::Arc;

use axum::{
    extract::{
        ws::{Message, WebSocket},
        WebSocketUpgrade,
    },
    response::Response,
    Extension,
};
use eyre::{eyre, Error, Result};
use foxchat::{
    s2s::{Dispatch, Payload},
    signature::{parse_date, verify_signature},
};
use futures::{
    stream::{SplitSink, SplitStream, StreamExt},
    SinkExt,
};
use tokio::sync::{broadcast, mpsc, RwLock};
use tracing::error;

use crate::{app_state::AppState, model::identity_instance::IdentityInstance};

pub async fn handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<AppState>>) -> Response {
    ws.on_upgrade(|socket| handle_socket(state, socket))
}

/// Per-connection state shared between the reader task and the event forwarder.
struct SocketState {
    /// The remote instance on the other end of this socket; `None` until its IDENTIFY
    /// payload has been verified.
    instance: Option<IdentityInstance>,
}

/// Drives one server-to-server websocket connection.
///
/// The socket is split into halves and three tasks are spawned around a shared mpsc queue:
/// - `merge_channels` forwards broadcast events relevant to the peer,
/// - `write` serializes every queued payload onto the socket,
/// - `read` consumes payloads from the peer, starting with `IDENTIFY`.
///
/// NOTE(review): the spawned tasks are detached — nothing here joins or aborts them when the
/// peer disconnects, so `merge_channels` and `write` only stop if they notice the closed
/// connection themselves. Consider awaiting `read` here and aborting the forwarder afterwards.
async fn handle_socket(app_state: Arc<AppState>, socket: WebSocket) {
    // Filled out by the reader once IDENTIFY has been verified.
    let socket_state = Arc::new(RwLock::new(SocketState { instance: None }));
    let (queue_tx, queue_rx) = mpsc::channel::<Payload>(10);
    let events = app_state.broadcast.subscribe();
    let (sink, stream) = socket.split();

    tokio::spawn(merge_channels(
        Arc::clone(&app_state),
        Arc::clone(&socket_state),
        events,
        queue_tx.clone(),
    ));
    tokio::spawn(write(queue_rx, sink));
    tokio::spawn(read(app_state, socket_state, queue_tx, stream));
}

async fn read(
    app_state: Arc<AppState>,
    socket_state: Arc<RwLock<SocketState>>,
    tx: mpsc::Sender<Payload>,
    mut receiver: SplitStream<WebSocket>,
) {
    let msg = receiver.next().await;
    if let Some(Ok(msg)) = msg {
        if let Ok(msg) = msg.into_text() {
            if let Ok(pl) = serde_json::from_str::<Payload>(&msg) {
                match pl {
                    Payload::Identify {
                        host,
                        date,
                        server,
                        signature,
                    } => {
                        let instance = match IdentityInstance::get(app_state, &server).await {
                            Ok(i) => i,
                            Err(e) => {
                                error!("getting instance {}: {}", server, e);
                                tx.send(Payload::Error {
                                    message: "Unknown instance".into(),
                                })
                                .await
                                .ok();
                                return;
                            }
                        };

                        let public_key = match instance.parse_public_key() {
                            Ok(k) => k,
                            Err(e) => {
                                error!("parsing public key for instance: {}", e);
                                tx.send(Payload::Error {
                                    message: "Internal server error".into(),
                                })
                                .await
                                .ok();
                                return;
                            }
                        };

                        let date = match parse_date(&date) {
                            Ok(d) => d,
                            Err(_) => {
                                tx.send(Payload::Error {
                                    message: "Invalid date".into(),
                                })
                                .await
                                .ok();
                                return;
                            }
                        };

                        if let Err(e) = verify_signature(
                            &public_key,
                            signature,
                            date,
                            &host,
                            "/_fox/chat/ws",
                            None,
                            None,
                        ) {
                            error!("Verifying signature in websocket from {}: {}", &server, e);

                            tx.send(Payload::Error {
                                message: "Invalid signature".into(),
                            })
                            .await
                            .ok();
                            return;
                        }

                        // Everything is good, store instance
                        socket_state.write().await.instance = Some(instance);
                    }
                    _ => {
                        tx.send(Payload::Error {
                            message: "First payload was not IDENTIFY".into(),
                        })
                        .await
                        .ok();
                        return;
                    }
                }
            } else {
                tx.send(Payload::Error {
                    message: "Invalid JSON payload".into(),
                })
                .await
                .ok();
                return;
            }
        } else {
            tx.send(Payload::Error {
                message: "First payload was not IDENTIFY".into(),
            })
            .await
            .ok();
            return;
        }
    } else {
        // Websocket closed, return
        return;
    }

    // Send hello
    tx.send(Payload::Hello {}).await.ok();

    while let Some(msg) = receiver.next().await {
        let Ok(msg) = msg else {
            return;
        };

        let Ok(msg) = msg.to_text() else {
            tx.send(Payload::Error {
                message: "Invalid message".into(),
            })
            .await
            .ok();
            return;
        };

        let Ok(msg) = serde_json::from_str::<Payload>(msg) else {
            tx.send(Payload::Error {
                message: "Invalid message".into(),
            })
            .await
            .ok();
            return;
        };

        match msg {
            Payload::Connect { user_id } => {}
        }

        // TODO: handle incoming payloads
    }
}

async fn write(mut rx: mpsc::Receiver<Payload>, mut sender: SplitSink<WebSocket, Message>) {
    loop {
        if let Some(ev) = rx.recv().await {
            match &ev {
                Payload::Error { message: _ } => {
                    let msg = match to_json(ev) {
                        Ok(m) => m,
                        Err(e) => {
                            error!("error serializing payload to JSON: {}", e);
                            return;
                        }
                    };

                    match sender.send(msg).await {
                        Ok(_) => {}
                        // Pretty sure we can assume the websocket was closed, so just return
                        Err(e) => {
                            error!("error writing to websocket: {}", e);
                            return;
                        }
                    }

                    match sender.close().await {
                        Ok(_) => {}
                        Err(e) => {
                            error!("error closing websocket: {}", e)
                        }
                    }

                    return;
                }
                _ => {
                    // Forward payload
                    let msg = match to_json(ev) {
                        Ok(m) => m,
                        Err(e) => {
                            error!("error serializing payload to JSON: {}", e);
                            return;
                        }
                    };

                    match sender.send(msg).await {
                        Ok(_) => {}
                        // Pretty sure we can assume the websocket was closed, so just return
                        Err(e) => {
                            error!("error writing to websocket: {}", e);
                            return;
                        }
                    }
                }
            }
        }
    }
}

async fn merge_channels(
    app_state: Arc<AppState>,
    socket_state: Arc<RwLock<SocketState>>,
    mut rx: broadcast::Receiver<Dispatch>,
    tx: mpsc::Sender<Payload>,
) {
    loop {
        let msg = rx.recv().await;
        match msg {
            Ok(evt) => {
                let (send, recipients) =
                    match filter_events(app_state.clone(), socket_state.clone(), &evt).await {
                        Ok(v) => v,
                        Err(e) => {
                            error!("Checking whether to send an event: {}", e);
                            tx.send(Payload::Error {
                                message: "Internal server error".into(),
                            })
                            .await
                            .ok();

                            continue;
                        }
                    };

                if !send {
                    continue;
                }

                tx.send(Payload::Dispatch {
                    event: evt,
                    recipients,
                })
                .await
                .ok();
            }
            Err(e) => match e {
                broadcast::error::RecvError::Closed => {
                    error!("Broadcast channel was closed, this is not supposed to happen");
                    return;
                }
                broadcast::error::RecvError::Lagged(i) => {
                    error!("Broadcast receive lagged by {i}");
                }
            },
        }
    }
}

fn to_json(p: Payload) -> Result<Message> {
    Ok(Message::Text(serde_json::to_string(&p)?))
}

/// Returns whether or not an event should be sent at all, and if it should be, which users to
/// send it to.
///
/// Recipients are the `remote_user_id`s of users belonging to the instance connected on this
/// socket. Nothing is forwarded before the socket has identified.
///
/// # Errors
///
/// Returns an error if a database query fails.
async fn filter_events(
    app_state: Arc<AppState>,
    socket_state: Arc<RwLock<SocketState>>,
    evt: &Dispatch,
) -> Result<(bool, Vec<String>)> {
    let Some(instance) = &socket_state.read().await.instance else {
        return Ok((false, vec![]));
    };

    match evt {
        Dispatch::MessageCreate { guild_id, .. } => {
            // Every member of the guild whose account lives on the connected instance.
            let users = sqlx::query!(
                r#"SELECT ARRAY(
                SELECT u.remote_user_id FROM users u
                JOIN guilds_users gu ON gu.user_id = u.id
                WHERE u.instance_id = $1 AND gu.guild_id = $2
            )"#,
                instance.id,
                guild_id
            )
            .fetch_one(&app_state.pool)
            .await?;

            let users = users.array.unwrap_or_default();
            Ok((!users.is_empty(), users))
        }
        Dispatch::Ready { user, .. } => {
            // READY carries the user's private state, so only the instance that owns the
            // user may receive it; every other connected instance must be skipped.
            let user = sqlx::query!(
                "SELECT remote_user_id FROM users WHERE id = $1 AND instance_id = $2",
                user.id.clone(),
                instance.id
            )
            .fetch_optional(&app_state.pool)
            .await?;

            Ok(match user {
                Some(user) => (true, vec![user.remote_user_id]),
                None => (false, vec![]),
            })
        }
    }
}

/// Builds the READY payload for `user_id` on the instance connected to this socket.
///
/// NOTE(review): unfinished — after loading the user row this panics via `todo!()`.
///
/// # Errors
///
/// Returns an error if the socket has not identified yet, or if the user lookup fails.
async fn collect_ready(
    app_state: Arc<AppState>,
    socket_state: Arc<RwLock<SocketState>>,
    user_id: String,
) -> Result<Payload> {
    let state = socket_state.read().await;
    let instance = state
        .instance
        .as_ref()
        .ok_or_else(|| eyre!("instance was None when it shouldn't be"))?;

    let _user = sqlx::query!(
        "SELECT * FROM users WHERE instance_id = $1 AND remote_user_id = $2",
        instance.id,
        user_id
    )
    .fetch_one(&app_state.pool)
    .await?;

    todo!()
}