feat: add auth

This commit is contained in:
sam 2024-03-20 03:37:11 +01:00
parent 3d7217ec69
commit 8e752689df
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
8 changed files with 96 additions and 15 deletions

View file

@ -34,6 +34,10 @@ async def get_user_from_token():
return return
async with async_session() as session: async with async_session() as session:
try:
token, user = await validate_token(session, token) token, user = await validate_token(session, token)
g.token = token g.token = token
g.user = user g.user = user
except Exception as e:
print(e)
raise

23
foxnouns/auth.py Normal file
View file

@ -0,0 +1,23 @@
from functools import wraps
from quart import g
from foxnouns.exceptions import ForbiddenError, ErrorCode
def require_auth(*, scope: str | None = None):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
if not ("user" in g) or not ("token" in g):
raise ForbiddenError("Not authenticated", type=ErrorCode.Forbidden)
if scope and not g.token.has_scope(scope):
raise ForbiddenError(
f"Missing scope '{scope}'", type=ErrorCode.MissingScope
)
return await func(*args, **kwargs)
return wrapper
return decorator

View file

@ -1,11 +1,15 @@
from pydantic import BaseModel, Field, field_validator from pydantic import Field, field_validator
from quart import Blueprint from quart import Blueprint, g
from quart_schema import validate_response, validate_request from quart_schema import validate_response, validate_request
from sqlalchemy import select
from foxnouns.auth import require_auth
from foxnouns.db import User
from foxnouns.db.aio import async_session from foxnouns.db.aio import async_session
from foxnouns.db.util import user_from_ref from foxnouns.db.util import user_from_ref, is_self
from foxnouns.exceptions import NotFoundError, ErrorCode from foxnouns.exceptions import NotFoundError, ErrorCode
from foxnouns.models.user import UserModel, check_username from foxnouns.models import BasePatchModel
from foxnouns.models.user import UserModel, SelfUserModel, check_username
from foxnouns.settings import BASE_DOMAIN from foxnouns.settings import BASE_DOMAIN
bp = Blueprint("users_v2", __name__) bp = Blueprint("users_v2", __name__)
@ -18,14 +22,20 @@ async def get_user(user_ref: str):
user = await user_from_ref(session, user_ref) user = await user_from_ref(session, user_ref)
if not user: if not user:
raise NotFoundError("User not found", type=ErrorCode.UserNotFound) raise NotFoundError("User not found", type=ErrorCode.UserNotFound)
return UserModel.model_validate(user)
return (
SelfUserModel.model_validate(user)
if is_self(user)
else UserModel.model_validate(user)
)
class EditUserRequest(BaseModel): class EditUserRequest(BasePatchModel):
username: str | None = Field( username: str | None = Field(
min_length=2, max_length=40, pattern=r"^[\w\-\.]{2,40}$", default=None min_length=2, max_length=40, pattern=r"^[\w\-\.]{2,40}$", default=None
) )
display_name: str | None = Field(min_length=2, max_length=100, default=None) display_name: str | None = Field(max_length=100, default=None)
bio: str | None = Field(max_length=1024, default=None)
@field_validator("username") @field_validator("username")
@classmethod @classmethod
@ -34,6 +44,20 @@ class EditUserRequest(BaseModel):
@bp.patch("/api/v2/users/@me", host=BASE_DOMAIN) @bp.patch("/api/v2/users/@me", host=BASE_DOMAIN)
@require_auth(scope="user.update")
@validate_request(EditUserRequest) @validate_request(EditUserRequest)
@validate_response(SelfUserModel, 200)
async def edit_user(data: EditUserRequest): async def edit_user(data: EditUserRequest):
return data async with async_session() as session:
user = await session.scalar(select(User).where(User.id == g.user.id))
if data.username:
user.username = data.username
if "display_name" in data.model_fields_set:
user.display_name = data.display_name
if data.is_set("bio"):
user.bio = data.bio
await session.commit()
return SelfUserModel.model_validate(user)

View file

@ -19,7 +19,12 @@ async def user_from_ref(session: AsyncSession, user_ref: str):
if user_ref == "@me": if user_ref == "@me":
if "user" in g: if "user" in g:
if g.token.has_scope("user.read"):
query = query.where(User.id == g.user.id) query = query.where(User.id == g.user.id)
else:
raise ForbiddenError(
f"Missing scope 'user.read'", type=ErrorCode.MissingScope
)
else: else:
raise ForbiddenError("Not authenticated") raise ForbiddenError("Not authenticated")
else: else:
@ -55,9 +60,11 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User):
except BadSignature: except BadSignature:
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken) raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
row = (await session.execute( row = (
await session.execute(
select(Token, User).join(Token.user).where(Token.id == token_id) select(Token, User).join(Token.user).where(Token.id == token_id)
)).first() )
).first()
if not row or not row.Token: if not row or not row.Token:
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken) raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
@ -66,3 +73,7 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User):
raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken) raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken)
return (row.Token, row.User) return (row.Token, row.User)
def is_self(user: User) -> bool:
return "user" in g and g.user.id == user.id

View file

@ -8,6 +8,7 @@ class ErrorCode(enum.IntEnum):
MethodNotAllowed = 405 MethodNotAllowed = 405
TooManyRequests = 429 TooManyRequests = 429
InternalServerError = 500 # catch-all code for unknown errors InternalServerError = 500 # catch-all code for unknown errors
# Login/authorize error codes # Login/authorize error codes
InvalidState = 1001 InvalidState = 1001
InvalidOAuthCode = 1002 InvalidOAuthCode = 1002
@ -28,19 +29,24 @@ class ErrorCode(enum.IntEnum):
1016 # unlinking provider would leave account with no authentication method 1016 # unlinking provider would leave account with no authentication method
) )
InvalidCaptcha = 1017 # invalid or missing captcha response InvalidCaptcha = 1017 # invalid or missing captcha response
MissingScope = 1018 # missing the required scope for this endpoint
# User-related error codes # User-related error codes
UserNotFound = 2001 UserNotFound = 2001
MemberListPrivate = 2002 MemberListPrivate = 2002
FlagLimitReached = 2003 FlagLimitReached = 2003
RerollingTooQuickly = 2004 RerollingTooQuickly = 2004
# Member-related error codes # Member-related error codes
MemberNotFound = 3001 MemberNotFound = 3001
MemberLimitReached = 3002 MemberLimitReached = 3002
MemberNameInUse = 3003 MemberNameInUse = 3003
NotOwnMember = 3004 NotOwnMember = 3004
# General request error codes # General request error codes
RequestTooBig = 4001 RequestTooBig = 4001
MissingPermissions = 4002 MissingPermissions = 4002
# Moderation related error codes # Moderation related error codes
ReportAlreadyHandled = 5001 ReportAlreadyHandled = 5001
NotSelfDelete = 5002 NotSelfDelete = 5002

View file

@ -1,5 +1,13 @@
from typing import Any
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
class BasePatchModel(BaseModel):
model_config = {"from_attributes": True}
def is_set(self, key: str) -> bool:
return key in self.model_fields_set
class BaseSnowflakeModel(BaseModel): class BaseSnowflakeModel(BaseModel):
"""A base model with a Snowflake ID that is serialized as a string. """A base model with a Snowflake ID that is serialized as a string.

View file

@ -9,6 +9,10 @@ class UserModel(BaseSnowflakeModel):
bio: str | None bio: str | None
class SelfUserModel(UserModel):
pass
def check_username(value): def check_username(value):
if not value: if not value:
return value return value

View file

@ -28,6 +28,7 @@ pytest = "^8.0.2"
pytest-asyncio = "^0.23.5.post1" pytest-asyncio = "^0.23.5.post1"
[tool.poe.tasks] [tool.poe.tasks]
dev = "env QUART_APP=foxnouns.app:app quart --debug run --reload"
server = "uvicorn 'foxnouns.app:app'" server = "uvicorn 'foxnouns.app:app'"
migrate = "alembic upgrade head" migrate = "alembic upgrade head"
test = "pytest" test = "pytest"