2022-05-02 17:19:37 +02:00
|
|
|
package auth
|
|
|
|
|
|
|
|
import (
|
|
|
|
"net/http"
|
2022-05-17 22:35:26 +02:00
|
|
|
"os"
|
2022-05-02 17:19:37 +02:00
|
|
|
|
2022-06-17 15:18:44 +02:00
|
|
|
"codeberg.org/u1f320/pronouns.cc/backend/db"
|
2022-05-14 16:52:08 +02:00
|
|
|
"codeberg.org/u1f320/pronouns.cc/backend/log"
|
|
|
|
"codeberg.org/u1f320/pronouns.cc/backend/server"
|
2022-05-04 16:27:16 +02:00
|
|
|
"emperror.dev/errors"
|
2022-05-02 17:19:37 +02:00
|
|
|
"github.com/go-chi/chi/v5"
|
2022-05-04 16:27:16 +02:00
|
|
|
"github.com/go-chi/render"
|
2022-06-17 15:18:44 +02:00
|
|
|
"github.com/rs/xid"
|
2022-05-02 17:19:37 +02:00
|
|
|
)
|
|
|
|
|
|
|
|
type Server struct {
|
|
|
|
*server.Server
|
2022-05-17 22:35:26 +02:00
|
|
|
|
|
|
|
RequireInvite bool
|
2022-05-02 17:19:37 +02:00
|
|
|
}
|
|
|
|
|
2022-06-17 15:18:44 +02:00
|
|
|
type userResponse struct {
|
|
|
|
ID xid.ID `json:"id"`
|
|
|
|
Username string `json:"username"`
|
|
|
|
DisplayName *string `json:"display_name"`
|
|
|
|
Bio *string `json:"bio"`
|
|
|
|
AvatarURL *string `json:"avatar_url"`
|
|
|
|
Links []string `json:"links"`
|
|
|
|
|
|
|
|
Discord *string `json:"discord"`
|
|
|
|
DiscordUsername *string `json:"discord_username"`
|
|
|
|
}
|
|
|
|
|
|
|
|
func dbUserToUserResponse(u db.User) *userResponse {
|
|
|
|
return &userResponse{
|
|
|
|
ID: u.ID,
|
|
|
|
Username: u.Username,
|
|
|
|
DisplayName: u.DisplayName,
|
|
|
|
Bio: u.Bio,
|
|
|
|
AvatarURL: u.AvatarURL,
|
|
|
|
Links: u.Links,
|
|
|
|
Discord: u.Discord,
|
|
|
|
DiscordUsername: u.DiscordUsername,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-05-02 17:19:37 +02:00
|
|
|
func Mount(srv *server.Server, r chi.Router) {
|
2022-05-17 22:35:26 +02:00
|
|
|
s := &Server{
|
|
|
|
Server: srv,
|
|
|
|
RequireInvite: os.Getenv("REQUIRE_INVITE") == "true",
|
|
|
|
}
|
2022-05-02 17:19:37 +02:00
|
|
|
|
2022-05-04 16:27:16 +02:00
|
|
|
r.Route("/auth", func(r chi.Router) {
|
2022-05-17 22:35:26 +02:00
|
|
|
// check if username is taken
|
|
|
|
r.Get("/username", server.WrapHandler(s.usernameTaken))
|
|
|
|
|
2022-05-04 16:27:16 +02:00
|
|
|
// generate csrf token, returns all supported OAuth provider URLs
|
2022-05-12 16:41:32 +02:00
|
|
|
r.Post("/urls", server.WrapHandler(s.oauthURLs))
|
2022-05-02 17:19:37 +02:00
|
|
|
|
2022-05-04 16:27:16 +02:00
|
|
|
r.Route("/discord", func(r chi.Router) {
|
|
|
|
// takes code + state, validates it, returns token OR discord signup ticket
|
2022-05-12 16:41:32 +02:00
|
|
|
r.Post("/callback", server.WrapHandler(s.discordCallback))
|
2022-05-04 16:27:16 +02:00
|
|
|
// takes discord signup ticket to register account
|
|
|
|
r.Post("/signup", nil)
|
2022-05-02 17:19:37 +02:00
|
|
|
})
|
|
|
|
})
|
|
|
|
}
|
2022-05-04 16:27:16 +02:00
|
|
|
|
|
|
|
// oauthURLsRequest is the JSON request body decoded by the oauthURLs handler.
type oauthURLsRequest struct {
	// CallbackDomain is used as the prefix of the OAuth redirect URL: the
	// handler appends "/login/discord" to it without further processing.
	CallbackDomain string `json:"callback_domain"`
}
|
|
|
|
|
|
|
|
// oauthURLsResponse is the JSON body written by the oauthURLs handler. It has
// one field per supported OAuth provider.
type oauthURLsResponse struct {
	// Discord is the Discord authorization URL, including the CSRF state.
	Discord string `json:"discord"`
}
|
|
|
|
|
|
|
|
func (s *Server) oauthURLs(w http.ResponseWriter, r *http.Request) error {
|
|
|
|
req, err := Decode[oauthURLsRequest](r)
|
|
|
|
if err != nil {
|
2022-05-12 16:41:32 +02:00
|
|
|
log.Error(err)
|
|
|
|
|
2022-05-04 16:27:16 +02:00
|
|
|
return server.APIError{Code: server.ErrBadRequest}
|
|
|
|
}
|
|
|
|
|
|
|
|
// generate CSRF state
|
|
|
|
state, err := s.setCSRFState(r.Context())
|
|
|
|
if err != nil {
|
|
|
|
return errors.Wrap(err, "setting CSRF state")
|
|
|
|
}
|
|
|
|
|
|
|
|
// copy Discord config and set redirect url
|
|
|
|
discordCfg := discordOAuthConfig
|
2022-05-12 16:41:32 +02:00
|
|
|
discordCfg.RedirectURL = req.CallbackDomain + "/login/discord"
|
2022-05-04 16:27:16 +02:00
|
|
|
|
|
|
|
render.JSON(w, r, oauthURLsResponse{
|
|
|
|
Discord: discordCfg.AuthCodeURL(state),
|
|
|
|
})
|
|
|
|
return nil
|
|
|
|
}
|
2022-05-17 22:35:26 +02:00
|
|
|
|
|
|
|
func (s *Server) usernameTaken(w http.ResponseWriter, r *http.Request) error {
|
|
|
|
type Response struct {
|
|
|
|
Valid bool `json:"valid"`
|
|
|
|
Taken bool `json:"taken"`
|
|
|
|
}
|
|
|
|
|
|
|
|
name := r.FormValue("username")
|
|
|
|
if name == "" {
|
|
|
|
render.JSON(w, r, Response{
|
|
|
|
Valid: false,
|
|
|
|
})
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
valid, taken, err := s.DB.UsernameTaken(r.Context(), name)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
render.JSON(w, r, Response{
|
|
|
|
Valid: valid,
|
|
|
|
Taken: taken,
|
|
|
|
})
|
|
|
|
return nil
|
|
|
|
}
|