oh no
This commit is contained in:
parent
8c1db3fadb
commit
e05419ebe9
8 changed files with 75 additions and 37 deletions
18
.gitignore
vendored
18
.gitignore
vendored
|
@ -1,9 +1,9 @@
|
|||
__pycache__/
|
||||
.pytest_cache/
|
||||
.env
|
||||
node_modules
|
||||
build
|
||||
.svelte-kit
|
||||
package
|
||||
vite.config.js.timestamp-*
|
||||
vite.config.ts.timestamp-*
|
||||
__pycache__/
|
||||
.pytest_cache/
|
||||
.env
|
||||
node_modules
|
||||
build
|
||||
.svelte-kit
|
||||
package
|
||||
vite.config.js.timestamp-*
|
||||
vite.config.ts.timestamp-*
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
from functools import wraps
|
||||
|
||||
from quart import g
|
||||
|
||||
from foxnouns.exceptions import ErrorCode, ForbiddenError
|
||||
|
||||
|
||||
def require_auth(*, scope: str | None = None):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
if "user" not in g or "token" not in g:
|
||||
raise ForbiddenError("Not authenticated", type=ErrorCode.Forbidden)
|
||||
|
||||
if scope and not g.token.has_scope(scope):
|
||||
raise ForbiddenError(
|
||||
f"Missing scope '{scope}'", type=ErrorCode.MissingScope
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
|
@ -1,14 +1,16 @@
|
|||
from quart import Blueprint
|
||||
from quart_schema import validate_request, validate_response
|
||||
|
||||
from foxnouns.settings import BASE_DOMAIN
|
||||
from foxnouns import settings
|
||||
from foxnouns.decorators import require_config_key
|
||||
|
||||
from . import BaseCallbackResponse, OAuthCallbackRequest
|
||||
|
||||
bp = Blueprint("discord_v2", __name__)
|
||||
|
||||
|
||||
@bp.post("/api/v2/auth/discord/callback", host=BASE_DOMAIN)
|
||||
@bp.post("/api/v2/auth/discord/callback", host=settings.BASE_DOMAIN)
|
||||
@require_config_key(keys=[settings.DISCORD_CLIENT_ID, settings.DISCORD_CLIENT_SECRET])
|
||||
@validate_request(OAuthCallbackRequest)
|
||||
@validate_response(BaseCallbackResponse)
|
||||
async def discord_callback(data: OAuthCallbackRequest):
|
||||
|
|
|
@ -3,7 +3,7 @@ from quart import Blueprint, g
|
|||
from quart_schema import validate_request, validate_response
|
||||
|
||||
from foxnouns import tasks
|
||||
from foxnouns.auth import require_auth
|
||||
from foxnouns.decorators import require_auth
|
||||
from foxnouns.db import Member
|
||||
from foxnouns.db.aio import async_session
|
||||
from foxnouns.db.util import user_from_ref
|
||||
|
|
|
@ -4,7 +4,7 @@ from quart_schema import validate_request, validate_response
|
|||
from sqlalchemy import select
|
||||
|
||||
from foxnouns import tasks
|
||||
from foxnouns.auth import require_auth
|
||||
from foxnouns.decorators import require_auth
|
||||
from foxnouns.db import User
|
||||
from foxnouns.db.aio import async_session
|
||||
from foxnouns.db.snowflake import Snowflake
|
||||
|
|
45
foxnouns/decorators.py
Normal file
45
foxnouns/decorators.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
from typing import Any
|
||||
from functools import wraps
|
||||
|
||||
from quart import g
|
||||
|
||||
from foxnouns.exceptions import ErrorCode, ForbiddenError, UnsupportedEndpointError
|
||||
|
||||
|
||||
def require_auth(*, scope: str | None = None):
|
||||
"""Decorator that requires a token with the given scopes.
|
||||
If no token is given or the required scopes aren't set on it, execution is aborted."""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
if "user" not in g or "token" not in g:
|
||||
raise ForbiddenError("Not authenticated", type=ErrorCode.Forbidden)
|
||||
|
||||
if scope and not g.token.has_scope(scope):
|
||||
raise ForbiddenError(
|
||||
f"Missing scope '{scope}'", type=ErrorCode.MissingScope
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_config_key(*, keys: list[Any]):
|
||||
"""Decorator that requires one or more config keys to be set.
|
||||
If any of them are None, execution is aborted."""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
for key in keys:
|
||||
if not key:
|
||||
raise UnsupportedEndpointError()
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
|
@ -80,3 +80,14 @@ class ForbiddenError(ExpectedError):
|
|||
def __init__(self, msg: str, type=ErrorCode.Forbidden):
|
||||
self.type = type
|
||||
super().__init__(msg, type)
|
||||
|
||||
|
||||
class UnsupportedEndpointError(ExpectedError):
|
||||
status_code = 404
|
||||
|
||||
def __init__(self):
|
||||
self.type = ErrorCode.NotFound
|
||||
super().__init__(
|
||||
"Endpoint is not supported on this instance",
|
||||
type=ErrorCode.NotFound,
|
||||
)
|
||||
|
|
|
@ -27,6 +27,10 @@ with env.prefixed("MINIO_"):
|
|||
"REGION": env("REGION", "auto"),
|
||||
}
|
||||
|
||||
# Discord OAuth credentials. If these are not set the Discord OAuth endpoints will not work.
|
||||
DISCORD_CLIENT_ID = env("DISCORD_CLIENT_ID", None)
|
||||
DISCORD_CLIENT_SECRET = env("DISCORD_CLIENT_SECRET", None)
|
||||
|
||||
# The base domain the API is served on. This must be set.
|
||||
BASE_DOMAIN = env("BASE_DOMAIN")
|
||||
# The base domain for short URLs.
|
||||
|
|
Loading…
Reference in a new issue