From 8e752689df6a124dd762fd7cc09c1467f16f246d Mon Sep 17 00:00:00 2001 From: sam Date: Wed, 20 Mar 2024 03:37:11 +0100 Subject: [PATCH] feat: add auth --- foxnouns/app.py | 10 ++++++--- foxnouns/auth.py | 23 +++++++++++++++++++ foxnouns/blueprints/v2/users.py | 40 ++++++++++++++++++++++++++------- foxnouns/db/util.py | 19 ++++++++++++---- foxnouns/exceptions.py | 6 +++++ foxnouns/models/__init__.py | 8 +++++++ foxnouns/models/user.py | 4 ++++ pyproject.toml | 1 + 8 files changed, 96 insertions(+), 15 deletions(-) create mode 100644 foxnouns/auth.py diff --git a/foxnouns/app.py b/foxnouns/app.py index 6729c51..25745eb 100644 --- a/foxnouns/app.py +++ b/foxnouns/app.py @@ -34,6 +34,10 @@ async def get_user_from_token(): return async with async_session() as session: - token, user = await validate_token(session, token) - g.token = token - g.user = user + try: + token, user = await validate_token(session, token) + g.token = token + g.user = user + except Exception as e: + print(e) + raise diff --git a/foxnouns/auth.py b/foxnouns/auth.py new file mode 100644 index 0000000..be240e3 --- /dev/null +++ b/foxnouns/auth.py @@ -0,0 +1,23 @@ +from functools import wraps +from quart import g + +from foxnouns.exceptions import ForbiddenError, ErrorCode + + +def require_auth(*, scope: str | None = None): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + if not ("user" in g) or not ("token" in g): + raise ForbiddenError("Not authenticated", type=ErrorCode.Forbidden) + + if scope and not g.token.has_scope(scope): + raise ForbiddenError( + f"Missing scope '{scope}'", type=ErrorCode.MissingScope + ) + + return await func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/foxnouns/blueprints/v2/users.py b/foxnouns/blueprints/v2/users.py index fdb5c56..7d3d890 100644 --- a/foxnouns/blueprints/v2/users.py +++ b/foxnouns/blueprints/v2/users.py @@ -1,11 +1,15 @@ -from pydantic import BaseModel, Field, field_validator -from quart import Blueprint +from pydantic import Field, field_validator +from quart import Blueprint, g from quart_schema import validate_response, validate_request +from sqlalchemy import select +from foxnouns.auth import require_auth +from foxnouns.db import User from foxnouns.db.aio import async_session -from foxnouns.db.util import user_from_ref +from foxnouns.db.util import user_from_ref, is_self from foxnouns.exceptions import NotFoundError, ErrorCode -from foxnouns.models.user import UserModel, check_username +from foxnouns.models import BasePatchModel +from foxnouns.models.user import UserModel, SelfUserModel, check_username from foxnouns.settings import BASE_DOMAIN bp = Blueprint("users_v2", __name__) @@ -18,14 +22,20 @@ async def get_user(user_ref: str): user = await user_from_ref(session, user_ref) if not user: raise NotFoundError("User not found", type=ErrorCode.UserNotFound) - return UserModel.model_validate(user) + + return ( + SelfUserModel.model_validate(user) + if is_self(user) + else UserModel.model_validate(user) + ) -class EditUserRequest(BaseModel): +class EditUserRequest(BasePatchModel): username: str | None = Field( min_length=2, max_length=40, pattern=r"^[\w\-\.]{2,40}$", default=None ) - display_name: str | None = Field(min_length=2, max_length=100, default=None) + display_name: str | None = Field(max_length=100, default=None) + bio: str | None = Field(max_length=1024, default=None) @field_validator("username") @classmethod @@ -34,6 +44,20 @@ class EditUserRequest(BaseModel): @bp.patch("/api/v2/users/@me", host=BASE_DOMAIN) +@require_auth(scope="user.update") @validate_request(EditUserRequest) +@validate_response(SelfUserModel, 200) async def edit_user(data: EditUserRequest): - return data + async with async_session() as session: + user = await session.scalar(select(User).where(User.id == g.user.id)) + + if data.username: + user.username = data.username + if "display_name" in data.model_fields_set: + user.display_name = data.display_name + if data.is_set("bio"): + user.bio = data.bio + + await session.commit() + + return SelfUserModel.model_validate(user) diff --git a/foxnouns/db/util.py b/foxnouns/db/util.py index bccf084..617db48 100644 --- a/foxnouns/db/util.py +++ b/foxnouns/db/util.py @@ -19,7 +19,12 @@ async def user_from_ref(session: AsyncSession, user_ref: str): if user_ref == "@me": if "user" in g: - query = query.where(User.id == g.user.id) + if g.token.has_scope("user.read"): + query = query.where(User.id == g.user.id) + else: + raise ForbiddenError( + f"Missing scope 'user.read'", type=ErrorCode.MissingScope + ) else: raise ForbiddenError("Not authenticated") else: @@ -55,9 +60,11 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User): except BadSignature: raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken) - row = (await session.execute( - select(Token, User).join(Token.user).where(Token.id == token_id) - )).first() + row = ( + await session.execute( + select(Token, User).join(Token.user).where(Token.id == token_id) + ) + ).first() if not row or not row.Token: raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken) @@ -66,3 +73,7 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User): raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken) return (row.Token, row.User) + + +def is_self(user: User) -> bool: + return "user" in g and g.user.id == user.id diff --git a/foxnouns/exceptions.py b/foxnouns/exceptions.py index cd5ac39..63739cc 100644 --- a/foxnouns/exceptions.py +++ b/foxnouns/exceptions.py @@ -8,6 +8,7 @@ class ErrorCode(enum.IntEnum): MethodNotAllowed = 405 TooManyRequests = 429 InternalServerError = 500 # catch-all code for unknown errors + # Login/authorize error codes InvalidState = 1001 InvalidOAuthCode = 1002 @@ -28,19 +29,24 @@ class ErrorCode(enum.IntEnum): 1016 # unlinking provider would leave account with no authentication method ) InvalidCaptcha = 1017 # invalid or missing captcha response + MissingScope = 1018 # missing the required scope for this endpoint + # User-related error codes UserNotFound = 2001 MemberListPrivate = 2002 FlagLimitReached = 2003 RerollingTooQuickly = 2004 + # Member-related error codes MemberNotFound = 3001 MemberLimitReached = 3002 MemberNameInUse = 3003 NotOwnMember = 3004 + # General request error codes RequestTooBig = 4001 MissingPermissions = 4002 + # Moderation related error codes ReportAlreadyHandled = 5001 NotSelfDelete = 5002 diff --git a/foxnouns/models/__init__.py b/foxnouns/models/__init__.py index 391a4e3..4a59cbf 100644 --- a/foxnouns/models/__init__.py +++ b/foxnouns/models/__init__.py @@ -1,5 +1,13 @@ +from typing import Any + from pydantic import BaseModel, field_validator +class BasePatchModel(BaseModel): + model_config = {"from_attributes": True} + + def is_set(self, key: str) -> bool: + return key in self.model_fields_set + class BaseSnowflakeModel(BaseModel): """A base model with a Snowflake ID that is serialized as a string. diff --git a/foxnouns/models/user.py b/foxnouns/models/user.py index d640145..2ba7ecf 100644 --- a/foxnouns/models/user.py +++ b/foxnouns/models/user.py @@ -9,6 +9,10 @@ class UserModel(BaseSnowflakeModel): bio: str | None +class SelfUserModel(UserModel): + pass + + def check_username(value): if not value: return value diff --git a/pyproject.toml b/pyproject.toml index 9753157..7801283 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ pytest = "^8.0.2" pytest-asyncio = "^0.23.5.post1" [tool.poe.tasks] +dev = "env QUART_APP=foxnouns.app:app quart --debug run --reload" server = "uvicorn 'foxnouns.app:app'" migrate = "alembic upgrade head" test = "pytest"