feat: add auth
This commit is contained in:
parent
3d7217ec69
commit
8e752689df
8 changed files with 96 additions and 15 deletions
|
@ -34,6 +34,10 @@ async def get_user_from_token():
|
|||
return
|
||||
|
||||
async with async_session() as session:
|
||||
try:
|
||||
token, user = await validate_token(session, token)
|
||||
g.token = token
|
||||
g.user = user
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise
|
||||
|
|
23
foxnouns/auth.py
Normal file
23
foxnouns/auth.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
from functools import wraps
|
||||
from quart import g
|
||||
|
||||
from foxnouns.exceptions import ForbiddenError, ErrorCode
|
||||
|
||||
|
||||
def require_auth(*, scope: str | None = None):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
if not ("user" in g) or not ("token" in g):
|
||||
raise ForbiddenError("Not authenticated", type=ErrorCode.Forbidden)
|
||||
|
||||
if scope and not g.token.has_scope(scope):
|
||||
raise ForbiddenError(
|
||||
f"Missing scope '{scope}'", type=ErrorCode.MissingScope
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
|
@ -1,11 +1,15 @@
|
|||
from pydantic import BaseModel, Field, field_validator
|
||||
from quart import Blueprint
|
||||
from pydantic import Field, field_validator
|
||||
from quart import Blueprint, g
|
||||
from quart_schema import validate_response, validate_request
|
||||
from sqlalchemy import select
|
||||
|
||||
from foxnouns.auth import require_auth
|
||||
from foxnouns.db import User
|
||||
from foxnouns.db.aio import async_session
|
||||
from foxnouns.db.util import user_from_ref
|
||||
from foxnouns.db.util import user_from_ref, is_self
|
||||
from foxnouns.exceptions import NotFoundError, ErrorCode
|
||||
from foxnouns.models.user import UserModel, check_username
|
||||
from foxnouns.models import BasePatchModel
|
||||
from foxnouns.models.user import UserModel, SelfUserModel, check_username
|
||||
from foxnouns.settings import BASE_DOMAIN
|
||||
|
||||
bp = Blueprint("users_v2", __name__)
|
||||
|
@ -18,14 +22,20 @@ async def get_user(user_ref: str):
|
|||
user = await user_from_ref(session, user_ref)
|
||||
if not user:
|
||||
raise NotFoundError("User not found", type=ErrorCode.UserNotFound)
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
return (
|
||||
SelfUserModel.model_validate(user)
|
||||
if is_self(user)
|
||||
else UserModel.model_validate(user)
|
||||
)
|
||||
|
||||
|
||||
class EditUserRequest(BaseModel):
|
||||
class EditUserRequest(BasePatchModel):
|
||||
username: str | None = Field(
|
||||
min_length=2, max_length=40, pattern=r"^[\w\-\.]{2,40}$", default=None
|
||||
)
|
||||
display_name: str | None = Field(min_length=2, max_length=100, default=None)
|
||||
display_name: str | None = Field(max_length=100, default=None)
|
||||
bio: str | None = Field(max_length=1024, default=None)
|
||||
|
||||
@field_validator("username")
|
||||
@classmethod
|
||||
|
@ -34,6 +44,20 @@ class EditUserRequest(BaseModel):
|
|||
|
||||
|
||||
@bp.patch("/api/v2/users/@me", host=BASE_DOMAIN)
|
||||
@require_auth(scope="user.update")
|
||||
@validate_request(EditUserRequest)
|
||||
@validate_response(SelfUserModel, 200)
|
||||
async def edit_user(data: EditUserRequest):
|
||||
return data
|
||||
async with async_session() as session:
|
||||
user = await session.scalar(select(User).where(User.id == g.user.id))
|
||||
|
||||
if data.username:
|
||||
user.username = data.username
|
||||
if "display_name" in data.model_fields_set:
|
||||
user.display_name = data.display_name
|
||||
if data.is_set("bio"):
|
||||
user.bio = data.bio
|
||||
|
||||
await session.commit()
|
||||
|
||||
return SelfUserModel.model_validate(user)
|
||||
|
|
|
@ -19,7 +19,12 @@ async def user_from_ref(session: AsyncSession, user_ref: str):
|
|||
|
||||
if user_ref == "@me":
|
||||
if "user" in g:
|
||||
if g.token.has_scope("user.read"):
|
||||
query = query.where(User.id == g.user.id)
|
||||
else:
|
||||
raise ForbiddenError(
|
||||
f"Missing scope 'user.read'", type=ErrorCode.MissingScope
|
||||
)
|
||||
else:
|
||||
raise ForbiddenError("Not authenticated")
|
||||
else:
|
||||
|
@ -55,9 +60,11 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User):
|
|||
except BadSignature:
|
||||
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
|
||||
|
||||
row = (await session.execute(
|
||||
row = (
|
||||
await session.execute(
|
||||
select(Token, User).join(Token.user).where(Token.id == token_id)
|
||||
)).first()
|
||||
)
|
||||
).first()
|
||||
|
||||
if not row or not row.Token:
|
||||
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
|
||||
|
@ -66,3 +73,7 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User):
|
|||
raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken)
|
||||
|
||||
return (row.Token, row.User)
|
||||
|
||||
|
||||
def is_self(user: User) -> bool:
|
||||
return "user" in g and g.user.id == user.id
|
||||
|
|
|
@ -8,6 +8,7 @@ class ErrorCode(enum.IntEnum):
|
|||
MethodNotAllowed = 405
|
||||
TooManyRequests = 429
|
||||
InternalServerError = 500 # catch-all code for unknown errors
|
||||
|
||||
# Login/authorize error codes
|
||||
InvalidState = 1001
|
||||
InvalidOAuthCode = 1002
|
||||
|
@ -28,19 +29,24 @@ class ErrorCode(enum.IntEnum):
|
|||
1016 # unlinking provider would leave account with no authentication method
|
||||
)
|
||||
InvalidCaptcha = 1017 # invalid or missing captcha response
|
||||
MissingScope = 1018 # missing the required scope for this endpoint
|
||||
|
||||
# User-related error codes
|
||||
UserNotFound = 2001
|
||||
MemberListPrivate = 2002
|
||||
FlagLimitReached = 2003
|
||||
RerollingTooQuickly = 2004
|
||||
|
||||
# Member-related error codes
|
||||
MemberNotFound = 3001
|
||||
MemberLimitReached = 3002
|
||||
MemberNameInUse = 3003
|
||||
NotOwnMember = 3004
|
||||
|
||||
# General request error codes
|
||||
RequestTooBig = 4001
|
||||
MissingPermissions = 4002
|
||||
|
||||
# Moderation related error codes
|
||||
ReportAlreadyHandled = 5001
|
||||
NotSelfDelete = 5002
|
||||
|
|
|
@ -1,5 +1,13 @@
|
|||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
class BasePatchModel(BaseModel):
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
def is_set(self, key: str) -> bool:
|
||||
return key in self.model_fields_set
|
||||
|
||||
|
||||
class BaseSnowflakeModel(BaseModel):
|
||||
"""A base model with a Snowflake ID that is serialized as a string.
|
||||
|
|
|
@ -9,6 +9,10 @@ class UserModel(BaseSnowflakeModel):
|
|||
bio: str | None
|
||||
|
||||
|
||||
class SelfUserModel(UserModel):
|
||||
pass
|
||||
|
||||
|
||||
def check_username(value):
|
||||
if not value:
|
||||
return value
|
||||
|
|
|
@ -28,6 +28,7 @@ pytest = "^8.0.2"
|
|||
pytest-asyncio = "^0.23.5.post1"
|
||||
|
||||
[tool.poe.tasks]
|
||||
dev = "env QUART_APP=foxnouns.app:app quart --debug run --reload"
|
||||
server = "uvicorn 'foxnouns.app:app'"
|
||||
migrate = "alembic upgrade head"
|
||||
test = "pytest"
|
||||
|
|
Loading…
Reference in a new issue