feat: add auth
This commit is contained in:
parent
3d7217ec69
commit
8e752689df
8 changed files with 96 additions and 15 deletions
|
@ -34,6 +34,10 @@ async def get_user_from_token():
|
||||||
return
|
return
|
||||||
|
|
||||||
async with async_session() as session:
|
async with async_session() as session:
|
||||||
|
try:
|
||||||
token, user = await validate_token(session, token)
|
token, user = await validate_token(session, token)
|
||||||
g.token = token
|
g.token = token
|
||||||
g.user = user
|
g.user = user
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
raise
|
||||||
|
|
23
foxnouns/auth.py
Normal file
23
foxnouns/auth.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
from functools import wraps
|
||||||
|
from quart import g
|
||||||
|
|
||||||
|
from foxnouns.exceptions import ForbiddenError, ErrorCode
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(*, scope: str | None = None):
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
if not ("user" in g) or not ("token" in g):
|
||||||
|
raise ForbiddenError("Not authenticated", type=ErrorCode.Forbidden)
|
||||||
|
|
||||||
|
if scope and not g.token.has_scope(scope):
|
||||||
|
raise ForbiddenError(
|
||||||
|
f"Missing scope '{scope}'", type=ErrorCode.MissingScope
|
||||||
|
)
|
||||||
|
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
|
@ -1,11 +1,15 @@
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
from quart import Blueprint
|
from quart import Blueprint, g
|
||||||
from quart_schema import validate_response, validate_request
|
from quart_schema import validate_response, validate_request
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from foxnouns.auth import require_auth
|
||||||
|
from foxnouns.db import User
|
||||||
from foxnouns.db.aio import async_session
|
from foxnouns.db.aio import async_session
|
||||||
from foxnouns.db.util import user_from_ref
|
from foxnouns.db.util import user_from_ref, is_self
|
||||||
from foxnouns.exceptions import NotFoundError, ErrorCode
|
from foxnouns.exceptions import NotFoundError, ErrorCode
|
||||||
from foxnouns.models.user import UserModel, check_username
|
from foxnouns.models import BasePatchModel
|
||||||
|
from foxnouns.models.user import UserModel, SelfUserModel, check_username
|
||||||
from foxnouns.settings import BASE_DOMAIN
|
from foxnouns.settings import BASE_DOMAIN
|
||||||
|
|
||||||
bp = Blueprint("users_v2", __name__)
|
bp = Blueprint("users_v2", __name__)
|
||||||
|
@ -18,14 +22,20 @@ async def get_user(user_ref: str):
|
||||||
user = await user_from_ref(session, user_ref)
|
user = await user_from_ref(session, user_ref)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError("User not found", type=ErrorCode.UserNotFound)
|
raise NotFoundError("User not found", type=ErrorCode.UserNotFound)
|
||||||
return UserModel.model_validate(user)
|
|
||||||
|
return (
|
||||||
|
SelfUserModel.model_validate(user)
|
||||||
|
if is_self(user)
|
||||||
|
else UserModel.model_validate(user)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EditUserRequest(BaseModel):
|
class EditUserRequest(BasePatchModel):
|
||||||
username: str | None = Field(
|
username: str | None = Field(
|
||||||
min_length=2, max_length=40, pattern=r"^[\w\-\.]{2,40}$", default=None
|
min_length=2, max_length=40, pattern=r"^[\w\-\.]{2,40}$", default=None
|
||||||
)
|
)
|
||||||
display_name: str | None = Field(min_length=2, max_length=100, default=None)
|
display_name: str | None = Field(max_length=100, default=None)
|
||||||
|
bio: str | None = Field(max_length=1024, default=None)
|
||||||
|
|
||||||
@field_validator("username")
|
@field_validator("username")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -34,6 +44,20 @@ class EditUserRequest(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@bp.patch("/api/v2/users/@me", host=BASE_DOMAIN)
|
@bp.patch("/api/v2/users/@me", host=BASE_DOMAIN)
|
||||||
|
@require_auth(scope="user.update")
|
||||||
@validate_request(EditUserRequest)
|
@validate_request(EditUserRequest)
|
||||||
|
@validate_response(SelfUserModel, 200)
|
||||||
async def edit_user(data: EditUserRequest):
|
async def edit_user(data: EditUserRequest):
|
||||||
return data
|
async with async_session() as session:
|
||||||
|
user = await session.scalar(select(User).where(User.id == g.user.id))
|
||||||
|
|
||||||
|
if data.username:
|
||||||
|
user.username = data.username
|
||||||
|
if "display_name" in data.model_fields_set:
|
||||||
|
user.display_name = data.display_name
|
||||||
|
if data.is_set("bio"):
|
||||||
|
user.bio = data.bio
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return SelfUserModel.model_validate(user)
|
||||||
|
|
|
@ -19,7 +19,12 @@ async def user_from_ref(session: AsyncSession, user_ref: str):
|
||||||
|
|
||||||
if user_ref == "@me":
|
if user_ref == "@me":
|
||||||
if "user" in g:
|
if "user" in g:
|
||||||
|
if g.token.has_scope("user.read"):
|
||||||
query = query.where(User.id == g.user.id)
|
query = query.where(User.id == g.user.id)
|
||||||
|
else:
|
||||||
|
raise ForbiddenError(
|
||||||
|
f"Missing scope 'user.read'", type=ErrorCode.MissingScope
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ForbiddenError("Not authenticated")
|
raise ForbiddenError("Not authenticated")
|
||||||
else:
|
else:
|
||||||
|
@ -55,9 +60,11 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User):
|
||||||
except BadSignature:
|
except BadSignature:
|
||||||
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
|
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
|
||||||
|
|
||||||
row = (await session.execute(
|
row = (
|
||||||
|
await session.execute(
|
||||||
select(Token, User).join(Token.user).where(Token.id == token_id)
|
select(Token, User).join(Token.user).where(Token.id == token_id)
|
||||||
)).first()
|
)
|
||||||
|
).first()
|
||||||
|
|
||||||
if not row or not row.Token:
|
if not row or not row.Token:
|
||||||
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
|
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
|
||||||
|
@ -66,3 +73,7 @@ async def validate_token(session: AsyncSession, header: str) -> (Token, User):
|
||||||
raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken)
|
raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken)
|
||||||
|
|
||||||
return (row.Token, row.User)
|
return (row.Token, row.User)
|
||||||
|
|
||||||
|
|
||||||
|
def is_self(user: User) -> bool:
|
||||||
|
return "user" in g and g.user.id == user.id
|
||||||
|
|
|
@ -8,6 +8,7 @@ class ErrorCode(enum.IntEnum):
|
||||||
MethodNotAllowed = 405
|
MethodNotAllowed = 405
|
||||||
TooManyRequests = 429
|
TooManyRequests = 429
|
||||||
InternalServerError = 500 # catch-all code for unknown errors
|
InternalServerError = 500 # catch-all code for unknown errors
|
||||||
|
|
||||||
# Login/authorize error codes
|
# Login/authorize error codes
|
||||||
InvalidState = 1001
|
InvalidState = 1001
|
||||||
InvalidOAuthCode = 1002
|
InvalidOAuthCode = 1002
|
||||||
|
@ -28,19 +29,24 @@ class ErrorCode(enum.IntEnum):
|
||||||
1016 # unlinking provider would leave account with no authentication method
|
1016 # unlinking provider would leave account with no authentication method
|
||||||
)
|
)
|
||||||
InvalidCaptcha = 1017 # invalid or missing captcha response
|
InvalidCaptcha = 1017 # invalid or missing captcha response
|
||||||
|
MissingScope = 1018 # missing the required scope for this endpoint
|
||||||
|
|
||||||
# User-related error codes
|
# User-related error codes
|
||||||
UserNotFound = 2001
|
UserNotFound = 2001
|
||||||
MemberListPrivate = 2002
|
MemberListPrivate = 2002
|
||||||
FlagLimitReached = 2003
|
FlagLimitReached = 2003
|
||||||
RerollingTooQuickly = 2004
|
RerollingTooQuickly = 2004
|
||||||
|
|
||||||
# Member-related error codes
|
# Member-related error codes
|
||||||
MemberNotFound = 3001
|
MemberNotFound = 3001
|
||||||
MemberLimitReached = 3002
|
MemberLimitReached = 3002
|
||||||
MemberNameInUse = 3003
|
MemberNameInUse = 3003
|
||||||
NotOwnMember = 3004
|
NotOwnMember = 3004
|
||||||
|
|
||||||
# General request error codes
|
# General request error codes
|
||||||
RequestTooBig = 4001
|
RequestTooBig = 4001
|
||||||
MissingPermissions = 4002
|
MissingPermissions = 4002
|
||||||
|
|
||||||
# Moderation related error codes
|
# Moderation related error codes
|
||||||
ReportAlreadyHandled = 5001
|
ReportAlreadyHandled = 5001
|
||||||
NotSelfDelete = 5002
|
NotSelfDelete = 5002
|
||||||
|
|
|
@ -1,5 +1,13 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
class BasePatchModel(BaseModel):
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
def is_set(self, key: str) -> bool:
|
||||||
|
return key in self.model_fields_set
|
||||||
|
|
||||||
|
|
||||||
class BaseSnowflakeModel(BaseModel):
|
class BaseSnowflakeModel(BaseModel):
|
||||||
"""A base model with a Snowflake ID that is serialized as a string.
|
"""A base model with a Snowflake ID that is serialized as a string.
|
||||||
|
|
|
@ -9,6 +9,10 @@ class UserModel(BaseSnowflakeModel):
|
||||||
bio: str | None
|
bio: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class SelfUserModel(UserModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def check_username(value):
|
def check_username(value):
|
||||||
if not value:
|
if not value:
|
||||||
return value
|
return value
|
||||||
|
|
|
@ -28,6 +28,7 @@ pytest = "^8.0.2"
|
||||||
pytest-asyncio = "^0.23.5.post1"
|
pytest-asyncio = "^0.23.5.post1"
|
||||||
|
|
||||||
[tool.poe.tasks]
|
[tool.poe.tasks]
|
||||||
|
dev = "env QUART_APP=foxnouns.app:app quart --debug run --reload"
|
||||||
server = "uvicorn 'foxnouns.app:app'"
|
server = "uvicorn 'foxnouns.app:app'"
|
||||||
migrate = "alembic upgrade head"
|
migrate = "alembic upgrade head"
|
||||||
test = "pytest"
|
test = "pytest"
|
||||||
|
|
Loading…
Reference in a new issue