feat: add auth

This commit is contained in:
sam 2024-03-20 03:37:11 +01:00
parent 3d7217ec69
commit 8e752689df
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
8 changed files with 96 additions and 15 deletions

View file

@ -34,6 +34,10 @@ async def get_user_from_token():
return
async with async_session() as session:
token, user = await validate_token(session, token)
g.token = token
g.user = user
try:
token, user = await validate_token(session, token)
g.token = token
g.user = user
except Exception as e:
print(e)
raise

23
foxnouns/auth.py Normal file
View file

@ -0,0 +1,23 @@
from functools import wraps
from quart import g
from foxnouns.exceptions import ForbiddenError, ErrorCode
def require_auth(*, scope: str | None = None):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
if not ("user" in g) or not ("token" in g):
raise ForbiddenError("Not authenticated", type=ErrorCode.Forbidden)
if scope and not g.token.has_scope(scope):
raise ForbiddenError(
f"Missing scope '{scope}'", type=ErrorCode.MissingScope
)
return await func(*args, **kwargs)
return wrapper
return decorator

View file

@ -1,11 +1,15 @@
from pydantic import BaseModel, Field, field_validator
from quart import Blueprint
from pydantic import Field, field_validator
from quart import Blueprint, g
from quart_schema import validate_response, validate_request
from sqlalchemy import select
from foxnouns.auth import require_auth
from foxnouns.db import User
from foxnouns.db.aio import async_session
from foxnouns.db.util import user_from_ref
from foxnouns.db.util import user_from_ref, is_self
from foxnouns.exceptions import NotFoundError, ErrorCode
from foxnouns.models.user import UserModel, check_username
from foxnouns.models import BasePatchModel
from foxnouns.models.user import UserModel, SelfUserModel, check_username
from foxnouns.settings import BASE_DOMAIN
bp = Blueprint("users_v2", __name__)
@ -18,14 +22,20 @@ async def get_user(user_ref: str):
user = await user_from_ref(session, user_ref)
if not user:
raise NotFoundError("User not found", type=ErrorCode.UserNotFound)
return UserModel.model_validate(user)
return (
SelfUserModel.model_validate(user)
if is_self(user)
else UserModel.model_validate(user)
)
class EditUserRequest(BaseModel):
class EditUserRequest(BasePatchModel):
username: str | None = Field(
min_length=2, max_length=40, pattern=r"^[\w\-\.]{2,40}$", default=None
)
display_name: str | None = Field(min_length=2, max_length=100, default=None)
display_name: str | None = Field(max_length=100, default=None)
bio: str | None = Field(max_length=1024, default=None)
@field_validator("username")
@classmethod
@ -34,6 +44,20 @@ class EditUserRequest(BaseModel):
@bp.patch("/api/v2/users/@me", host=BASE_DOMAIN)
@require_auth(scope="user.update")
@validate_request(EditUserRequest)
@validate_response(SelfUserModel, 200)
async def edit_user(data: EditUserRequest):
return data
async with async_session() as session:
user = await session.scalar(select(User).where(User.id == g.user.id))
if data.username:
user.username = data.username
if "display_name" in data.model_fields_set:
user.display_name = data.display_name
if data.is_set("bio"):
user.bio = data.bio
await session.commit()
return SelfUserModel.model_validate(user)

View file

@ -19,7 +19,12 @@ async def user_from_ref(session: AsyncSession, user_ref: str):
if user_ref == "@me":
if "user" in g:
query = query.where(User.id == g.user.id)
if g.token.has_scope("user.read"):
query = query.where(User.id == g.user.id)
else:
raise ForbiddenError(
f"Missing scope 'user.read'", type=ErrorCode.MissingScope
)
else:
raise ForbiddenError("Not authenticated")
else:
@ -55,9 +60,11 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User):
except BadSignature:
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
row = (await session.execute(
select(Token, User).join(Token.user).where(Token.id == token_id)
)).first()
row = (
await session.execute(
select(Token, User).join(Token.user).where(Token.id == token_id)
)
).first()
if not row or not row.Token:
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
@ -66,3 +73,7 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User):
raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken)
return (row.Token, row.User)
def is_self(user: User) -> bool:
return "user" in g and g.user.id == user.id

View file

@ -8,6 +8,7 @@ class ErrorCode(enum.IntEnum):
MethodNotAllowed = 405
TooManyRequests = 429
InternalServerError = 500 # catch-all code for unknown errors
# Login/authorize error codes
InvalidState = 1001
InvalidOAuthCode = 1002
@ -28,19 +29,24 @@ class ErrorCode(enum.IntEnum):
1016 # unlinking provider would leave account with no authentication method
)
InvalidCaptcha = 1017 # invalid or missing captcha response
MissingScope = 1018 # missing the required scope for this endpoint
# User-related error codes
UserNotFound = 2001
MemberListPrivate = 2002
FlagLimitReached = 2003
RerollingTooQuickly = 2004
# Member-related error codes
MemberNotFound = 3001
MemberLimitReached = 3002
MemberNameInUse = 3003
NotOwnMember = 3004
# General request error codes
RequestTooBig = 4001
MissingPermissions = 4002
# Moderation related error codes
ReportAlreadyHandled = 5001
NotSelfDelete = 5002

View file

@ -1,5 +1,13 @@
from typing import Any
from pydantic import BaseModel, field_validator
class BasePatchModel(BaseModel):
model_config = {"from_attributes": True}
def is_set(self, key: str) -> bool:
return key in self.model_fields_set
class BaseSnowflakeModel(BaseModel):
"""A base model with a Snowflake ID that is serialized as a string.

View file

@ -9,6 +9,10 @@ class UserModel(BaseSnowflakeModel):
bio: str | None
class SelfUserModel(UserModel):
pass
def check_username(value):
if not value:
return value

View file

@ -28,6 +28,7 @@ pytest = "^8.0.2"
pytest-asyncio = "^0.23.5.post1"
[tool.poe.tasks]
dev = "env QUART_APP=foxnouns.app:app quart --debug run --reload"
server = "uvicorn 'foxnouns.app:app'"
migrate = "alembic upgrade head"
test = "pytest"