diff --git a/foxnouns/app.py b/foxnouns/app.py index b87c416..1ba7148 100644 --- a/foxnouns/app.py +++ b/foxnouns/app.py @@ -1,7 +1,9 @@ -from quart import Quart, make_response, jsonify +from quart import Quart, make_response, jsonify, request, g from quart_schema import QuartSchema, RequestSchemaValidationError from .blueprints import users_blueprint +from .db.aio import async_session +from .db.util import validate_token from .exceptions import ExpectedError app = Quart(__name__) @@ -9,9 +11,6 @@ QuartSchema(app) app.register_blueprint(users_blueprint) -for route in app.url_map.iter_rules(): - print(route, route.host) - @app.errorhandler(RequestSchemaValidationError) async def handle_request_validation_error(error: RequestSchemaValidationError): @@ -22,3 +21,19 @@ async def handle_request_validation_error(error: RequestSchemaValidationError): @app.errorhandler(ExpectedError) async def handle_expected_error(error: ExpectedError): return {"code": error.type, "message": error.msg}, error.status_code + + +@app.before_request +async def get_user_from_token(): + """Get the current user from a token given in the `Authorization` header or the `pronounscc-token` cookie. + If no token is set, does nothing; if an invalid token is set, raises an error.""" + token = request.headers.get("Authorization", None) or request.cookies.get( + "pronounscc-token", None + ) + if not token: + return + + async with async_session() as session: + token, user = await validate_token(session, token) + g.token = token + g.user = user diff --git a/foxnouns/blueprints/v2/users.py b/foxnouns/blueprints/v2/users.py index 1031083..5451ac5 100644 --- a/foxnouns/blueprints/v2/users.py +++ b/foxnouns/blueprints/v2/users.py @@ -6,7 +6,7 @@ from foxnouns.db.aio import async_session from foxnouns.db.util import user_from_ref from foxnouns.exceptions import NotFoundError, ErrorCode from foxnouns.models.user import UserModel, check_username -from foxnouns.settings import BASE_DOMAIN, SHORT_DOMAIN +from foxnouns.settings import BASE_DOMAIN bp = Blueprint("users_v2", __name__) @@ -14,18 +14,12 @@ bp = Blueprint("users_v2", __name__) @bp.get("/api/v2/users/", host=BASE_DOMAIN) @validate_response(UserModel, 200) async def get_user(user_ref: str): - print(request.host) - async with async_session() as session: user = await user_from_ref(session, user_ref) if not user: raise NotFoundError("User not found", type=ErrorCode.UserNotFound) return UserModel.model_validate(user) -@bp.get("/api/v2/users/", host=SHORT_DOMAIN) -async def hello(user_ref): - return {"hello": f"from {SHORT_DOMAIN}"} - class EditUserRequest(BaseModel): username: str | None = Field( diff --git a/foxnouns/db/aio.py b/foxnouns/db/aio.py index e48eb71..060b651 100644 --- a/foxnouns/db/aio.py +++ b/foxnouns/db/aio.py @@ -1,6 +1,15 @@ +from sqlalchemy import URL from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker -from foxnouns.settings import ASYNC_DATABASE_URL +from foxnouns.settings import DATABASE + +ASYNC_DATABASE_URL = URL.create( + "postgresql+asyncpg", + username=DATABASE["USER"], + password=DATABASE["PASSWORD"], + host=DATABASE["HOST"], + database=DATABASE["NAME"], +) engine = create_async_engine(ASYNC_DATABASE_URL) async_session = async_sessionmaker(engine, expire_on_commit=False) diff --git a/foxnouns/db/sync.py b/foxnouns/db/sync.py index d44a0dd..c3e2006 100644 --- a/foxnouns/db/sync.py +++ b/foxnouns/db/sync.py @@ -1,5 +1,13 @@ -from sqlalchemy import create_engine +from sqlalchemy import URL, create_engine -from foxnouns.settings import SYNC_DATABASE_URL +from foxnouns.settings import DATABASE + +SYNC_DATABASE_URL = URL.create( + "postgresql+psycopg", + username=DATABASE["USER"], + password=DATABASE["PASSWORD"], + host=DATABASE["HOST"], + database=DATABASE["NAME"], +) engine = create_engine(SYNC_DATABASE_URL) diff --git a/foxnouns/db/user.py b/foxnouns/db/user.py index 3162314..5ef9591 100644 --- a/foxnouns/db/user.py +++ b/foxnouns/db/user.py @@ -1,7 +1,6 @@ from datetime import datetime import enum -from itsdangerous.url_safe import URLSafeTimedSerializer from sqlalchemy import Text, Integer, BigInteger, ForeignKey, DateTime from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -58,9 +57,6 @@ class Token(Base): # This way, both `user` and `user.edit` tokens will grant access to `user.edit` endpoints. return scope in self.scopes or scope.split(".")[0] in self.scopes - def token_str(self): - ... - class AuthType(enum.IntEnum): DISCORD = 1 diff --git a/foxnouns/db/util.py b/foxnouns/db/util.py index ad70187..bccf084 100644 --- a/foxnouns/db/util.py +++ b/foxnouns/db/util.py @@ -1,9 +1,14 @@ +import datetime + +from itsdangerous import BadSignature +from itsdangerous.url_safe import URLSafeTimedSerializer from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select +from sqlalchemy import select, insert from quart import g -from .user import User -from foxnouns.exceptions import ForbiddenError +from .user import User, Token +from foxnouns.exceptions import ForbiddenError, ErrorCode +from foxnouns.settings import SECRET_KEY async def user_from_ref(session: AsyncSession, user_ref: str): @@ -25,3 +30,39 @@ async def user_from_ref(session: AsyncSession, user_ref: str): query = query.where(User.username == user_ref) return await session.scalar(query) + + +serializer = URLSafeTimedSerializer(SECRET_KEY) + + +def generate_token(token: Token): + return serializer.dumps(token.id) + + +async def create_token(session: AsyncSession, user: User, scopes: list[str] = ["*"]): + expires = datetime.datetime.now() + datetime.timedelta(days=90) + query = ( + insert(Token) + .values(user_id=user.id, expires_at=expires, scopes=scopes) + .returning(Token) + ) + return await session.scalar(query) + + +async def validate_token(session: AsyncSession, header: str) -> (Token, User): + try: + token_id = serializer.loads(header) + except BadSignature: + raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken) + + row = (await session.execute( + select(Token, User).join(Token.user).where(Token.id == token_id) + )).first() + + if not row or not row.Token: + raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken) + + if row.Token.expires_at < datetime.datetime.now(): + raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken) + + return (row.Token, row.User) diff --git a/foxnouns/settings.py b/foxnouns/settings.py index b65c93a..290b5da 100644 --- a/foxnouns/settings.py +++ b/foxnouns/settings.py @@ -1,10 +1,11 @@ from environs import Env -from sqlalchemy import URL # read .env file env = Env() env.read_env() +# Format: postgresql+{driver}//{user}:{password}@{host}/{name} +# Note that the driver is set by the application. DATABASE = { "USER": env("DATABASE_USER"), "PASSWORD": env("DATABASE_PASSWORD"), @@ -12,22 +13,6 @@ DATABASE = { "NAME": env("DATABASE_NAME"), } -SYNC_DATABASE_URL = URL.create( - "postgresql+psycopg", - username=DATABASE["USER"], - password=DATABASE["PASSWORD"], - host=DATABASE["HOST"], - database=DATABASE["NAME"], -) - -ASYNC_DATABASE_URL = URL.create( - "postgresql+asyncpg", - username=DATABASE["USER"], - password=DATABASE["PASSWORD"], - host=DATABASE["HOST"], - database=DATABASE["NAME"], -) - # The base domain the API is served on. This must be set. BASE_DOMAIN = env("BASE_DOMAIN") # The base domain for short URLs. diff --git a/pyproject.toml b/pyproject.toml index 6672108..9753157 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,9 +27,14 @@ optional = true pytest = "^8.0.2" pytest-asyncio = "^0.23.5.post1" -[build-system] -requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" +[tool.poe.tasks] +server = "uvicorn 'foxnouns.app:app'" +migrate = "alembic upgrade head" +test = "pytest" [tool.pytest.ini_options] addopts = ["--import-mode=importlib"] + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api"