feat: add auth
This commit is contained in:
		
							parent
							
								
									3d7217ec69
								
							
						
					
					
						commit
						8e752689df
					
				
					 8 changed files with 96 additions and 15 deletions
				
			
		|  | @ -34,6 +34,10 @@ async def get_user_from_token(): | |||
|         return | ||||
| 
 | ||||
|     async with async_session() as session: | ||||
|         token, user = await validate_token(session, token) | ||||
|         g.token = token | ||||
|         g.user = user | ||||
|         try: | ||||
|             token, user = await validate_token(session, token) | ||||
|             g.token = token | ||||
|             g.user = user | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             raise | ||||
|  |  | |||
							
								
								
									
										23
									
								
								foxnouns/auth.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								foxnouns/auth.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,23 @@ | |||
| from functools import wraps | ||||
| from quart import g | ||||
| 
 | ||||
| from foxnouns.exceptions import ForbiddenError, ErrorCode | ||||
| 
 | ||||
| 
 | ||||
| def require_auth(*, scope: str | None = None): | ||||
|     def decorator(func): | ||||
|         @wraps(func) | ||||
|         async def wrapper(*args, **kwargs): | ||||
|             if not ("user" in g) or not ("token" in g): | ||||
|                 raise ForbiddenError("Not authenticated", type=ErrorCode.Forbidden) | ||||
| 
 | ||||
|             if scope and not g.token.has_scope(scope): | ||||
|                 raise ForbiddenError( | ||||
|                     f"Missing scope '{scope}'", type=ErrorCode.MissingScope | ||||
|                 ) | ||||
| 
 | ||||
|             return await func(*args, **kwargs) | ||||
| 
 | ||||
|         return wrapper | ||||
| 
 | ||||
|     return decorator | ||||
|  | @ -1,11 +1,15 @@ | |||
| from pydantic import BaseModel, Field, field_validator | ||||
| from quart import Blueprint | ||||
| from pydantic import Field, field_validator | ||||
| from quart import Blueprint, g | ||||
| from quart_schema import validate_response, validate_request | ||||
| from sqlalchemy import select | ||||
| 
 | ||||
| from foxnouns.auth import require_auth | ||||
| from foxnouns.db import User | ||||
| from foxnouns.db.aio import async_session | ||||
| from foxnouns.db.util import user_from_ref | ||||
| from foxnouns.db.util import user_from_ref, is_self | ||||
| from foxnouns.exceptions import NotFoundError, ErrorCode | ||||
| from foxnouns.models.user import UserModel, check_username | ||||
| from foxnouns.models import BasePatchModel | ||||
| from foxnouns.models.user import UserModel, SelfUserModel, check_username | ||||
| from foxnouns.settings import BASE_DOMAIN | ||||
| 
 | ||||
| bp = Blueprint("users_v2", __name__) | ||||
|  | @ -18,14 +22,20 @@ async def get_user(user_ref: str): | |||
|         user = await user_from_ref(session, user_ref) | ||||
|         if not user: | ||||
|             raise NotFoundError("User not found", type=ErrorCode.UserNotFound) | ||||
|         return UserModel.model_validate(user) | ||||
| 
 | ||||
|         return ( | ||||
|             SelfUserModel.model_validate(user) | ||||
|             if is_self(user) | ||||
|             else UserModel.model_validate(user) | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class EditUserRequest(BaseModel): | ||||
| class EditUserRequest(BasePatchModel): | ||||
|     username: str | None = Field( | ||||
|         min_length=2, max_length=40, pattern=r"^[\w\-\.]{2,40}$", default=None | ||||
|     ) | ||||
|     display_name: str | None = Field(min_length=2, max_length=100, default=None) | ||||
|     display_name: str | None = Field(max_length=100, default=None) | ||||
|     bio: str | None = Field(max_length=1024, default=None) | ||||
| 
 | ||||
|     @field_validator("username") | ||||
|     @classmethod | ||||
|  | @ -34,6 +44,20 @@ class EditUserRequest(BaseModel): | |||
| 
 | ||||
| 
 | ||||
| @bp.patch("/api/v2/users/@me", host=BASE_DOMAIN) | ||||
| @require_auth(scope="user.update") | ||||
| @validate_request(EditUserRequest) | ||||
| @validate_response(SelfUserModel, 200) | ||||
| async def edit_user(data: EditUserRequest): | ||||
|     return data | ||||
|     async with async_session() as session: | ||||
|         user = await session.scalar(select(User).where(User.id == g.user.id)) | ||||
| 
 | ||||
|         if data.username: | ||||
|             user.username = data.username | ||||
|         if "display_name" in data.model_fields_set: | ||||
|             user.display_name = data.display_name | ||||
|         if data.is_set("bio"): | ||||
|             user.bio = data.bio | ||||
| 
 | ||||
|         await session.commit() | ||||
| 
 | ||||
|         return SelfUserModel.model_validate(user) | ||||
|  |  | |||
|  | @ -19,7 +19,12 @@ async def user_from_ref(session: AsyncSession, user_ref: str): | |||
| 
 | ||||
|     if user_ref == "@me": | ||||
|         if "user" in g: | ||||
|             query = query.where(User.id == g.user.id) | ||||
|             if g.token.has_scope("user.read"): | ||||
|                 query = query.where(User.id == g.user.id) | ||||
|             else: | ||||
|                 raise ForbiddenError( | ||||
|                     f"Missing scope 'user.read'", type=ErrorCode.MissingScope | ||||
|                 ) | ||||
|         else: | ||||
|             raise ForbiddenError("Not authenticated") | ||||
|     else: | ||||
|  | @ -55,9 +60,11 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User): | |||
|     except BadSignature: | ||||
|         raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken) | ||||
| 
 | ||||
|     row = (await session.execute( | ||||
|         select(Token, User).join(Token.user).where(Token.id == token_id) | ||||
|     )).first() | ||||
|     row = ( | ||||
|         await session.execute( | ||||
|             select(Token, User).join(Token.user).where(Token.id == token_id) | ||||
|         ) | ||||
|     ).first() | ||||
| 
 | ||||
|     if not row or not row.Token: | ||||
|         raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken) | ||||
|  | @ -66,3 +73,7 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User): | |||
|         raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken) | ||||
| 
 | ||||
|     return (row.Token, row.User) | ||||
| 
 | ||||
| 
 | ||||
| def is_self(user: User) -> bool: | ||||
|     return "user" in g and g.user.id == user.id | ||||
|  |  | |||
|  | @ -8,6 +8,7 @@ class ErrorCode(enum.IntEnum): | |||
|     MethodNotAllowed = 405 | ||||
|     TooManyRequests = 429 | ||||
|     InternalServerError = 500  # catch-all code for unknown errors | ||||
| 
 | ||||
|     # Login/authorize error codes | ||||
|     InvalidState = 1001 | ||||
|     InvalidOAuthCode = 1002 | ||||
|  | @ -28,19 +29,24 @@ class ErrorCode(enum.IntEnum): | |||
|         1016  # unlinking provider would leave account with no authentication method | ||||
|     ) | ||||
|     InvalidCaptcha = 1017  # invalid or missing captcha response | ||||
|     MissingScope = 1018  # missing the required scope for this endpoint | ||||
| 
 | ||||
|     # User-related error codes | ||||
|     UserNotFound = 2001 | ||||
|     MemberListPrivate = 2002 | ||||
|     FlagLimitReached = 2003 | ||||
|     RerollingTooQuickly = 2004 | ||||
| 
 | ||||
|     # Member-related error codes | ||||
|     MemberNotFound = 3001 | ||||
|     MemberLimitReached = 3002 | ||||
|     MemberNameInUse = 3003 | ||||
|     NotOwnMember = 3004 | ||||
| 
 | ||||
|     # General request error codes | ||||
|     RequestTooBig = 4001 | ||||
|     MissingPermissions = 4002 | ||||
| 
 | ||||
|     # Moderation related error codes | ||||
|     ReportAlreadyHandled = 5001 | ||||
|     NotSelfDelete = 5002 | ||||
|  |  | |||
|  | @ -1,5 +1,13 @@ | |||
| from typing import Any | ||||
| 
 | ||||
| from pydantic import BaseModel, field_validator | ||||
| 
 | ||||
| class BasePatchModel(BaseModel): | ||||
|     model_config = {"from_attributes": True} | ||||
| 
 | ||||
|     def is_set(self, key: str) -> bool: | ||||
|         return key in self.model_fields_set | ||||
| 
 | ||||
| 
 | ||||
| class BaseSnowflakeModel(BaseModel): | ||||
|     """A base model with a Snowflake ID that is serialized as a string. | ||||
|  |  | |||
|  | @ -9,6 +9,10 @@ class UserModel(BaseSnowflakeModel): | |||
|     bio: str | None | ||||
| 
 | ||||
| 
 | ||||
| class SelfUserModel(UserModel): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| def check_username(value): | ||||
|     if not value: | ||||
|         return value | ||||
|  |  | |||
|  | @ -28,6 +28,7 @@ pytest = "^8.0.2" | |||
| pytest-asyncio = "^0.23.5.post1" | ||||
| 
 | ||||
| [tool.poe.tasks] | ||||
| dev = "env QUART_APP=foxnouns.app:app quart --debug run --reload" | ||||
| server = "uvicorn 'foxnouns.app:app'" | ||||
| migrate = "alembic upgrade head" | ||||
| test = "pytest" | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue