add token checking

This commit is contained in:
sam 2024-03-14 02:35:42 +01:00
parent 97a64296cd
commit dd6e7cf73f
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
8 changed files with 94 additions and 41 deletions

View file

@ -1,7 +1,9 @@
from quart import Quart, make_response, jsonify from quart import Quart, make_response, jsonify, request, g
from quart_schema import QuartSchema, RequestSchemaValidationError from quart_schema import QuartSchema, RequestSchemaValidationError
from .blueprints import users_blueprint from .blueprints import users_blueprint
from .db.aio import async_session
from .db.util import validate_token
from .exceptions import ExpectedError from .exceptions import ExpectedError
app = Quart(__name__) app = Quart(__name__)
@ -9,9 +11,6 @@ QuartSchema(app)
app.register_blueprint(users_blueprint) app.register_blueprint(users_blueprint)
for route in app.url_map.iter_rules():
print(route, route.host)
@app.errorhandler(RequestSchemaValidationError) @app.errorhandler(RequestSchemaValidationError)
async def handle_request_validation_error(error: RequestSchemaValidationError): async def handle_request_validation_error(error: RequestSchemaValidationError):
@ -22,3 +21,19 @@ async def handle_request_validation_error(error: RequestSchemaValidationError):
@app.errorhandler(ExpectedError) @app.errorhandler(ExpectedError)
async def handle_expected_error(error: ExpectedError): async def handle_expected_error(error: ExpectedError):
return {"code": error.type, "message": error.msg}, error.status_code return {"code": error.type, "message": error.msg}, error.status_code
@app.before_request
async def get_user_from_token():
"""Get the current user from a token given in the `Authorization` header or the `pronounscc-token` cookie.
If no token is set, does nothing; if an invalid token is set, raises an error."""
token = request.headers.get("Authorization", None) or request.cookies.get(
"pronounscc-token", None
)
if not token:
return
async with async_session() as session:
token, user = await validate_token(session, token)
g.token = token
g.user = user

View file

@ -6,7 +6,7 @@ from foxnouns.db.aio import async_session
from foxnouns.db.util import user_from_ref from foxnouns.db.util import user_from_ref
from foxnouns.exceptions import NotFoundError, ErrorCode from foxnouns.exceptions import NotFoundError, ErrorCode
from foxnouns.models.user import UserModel, check_username from foxnouns.models.user import UserModel, check_username
from foxnouns.settings import BASE_DOMAIN, SHORT_DOMAIN from foxnouns.settings import BASE_DOMAIN
bp = Blueprint("users_v2", __name__) bp = Blueprint("users_v2", __name__)
@ -14,18 +14,12 @@ bp = Blueprint("users_v2", __name__)
@bp.get("/api/v2/users/<user_ref>", host=BASE_DOMAIN) @bp.get("/api/v2/users/<user_ref>", host=BASE_DOMAIN)
@validate_response(UserModel, 200) @validate_response(UserModel, 200)
async def get_user(user_ref: str): async def get_user(user_ref: str):
print(request.host)
async with async_session() as session: async with async_session() as session:
user = await user_from_ref(session, user_ref) user = await user_from_ref(session, user_ref)
if not user: if not user:
raise NotFoundError("User not found", type=ErrorCode.UserNotFound) raise NotFoundError("User not found", type=ErrorCode.UserNotFound)
return UserModel.model_validate(user) return UserModel.model_validate(user)
@bp.get("/api/v2/users/<user_ref>", host=SHORT_DOMAIN)
async def hello(user_ref):
return {"hello": f"from {SHORT_DOMAIN}"}
class EditUserRequest(BaseModel): class EditUserRequest(BaseModel):
username: str | None = Field( username: str | None = Field(

View file

@ -1,6 +1,15 @@
from sqlalchemy import URL
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from foxnouns.settings import ASYNC_DATABASE_URL from foxnouns.settings import DATABASE
ASYNC_DATABASE_URL = URL.create(
"postgresql+asyncpg",
username=DATABASE["USER"],
password=DATABASE["PASSWORD"],
host=DATABASE["HOST"],
database=DATABASE["NAME"],
)
engine = create_async_engine(ASYNC_DATABASE_URL) engine = create_async_engine(ASYNC_DATABASE_URL)
async_session = async_sessionmaker(engine, expire_on_commit=False) async_session = async_sessionmaker(engine, expire_on_commit=False)

View file

@ -1,5 +1,13 @@
from sqlalchemy import create_engine from sqlalchemy import URL, create_engine
from foxnouns.settings import SYNC_DATABASE_URL from foxnouns.settings import DATABASE
SYNC_DATABASE_URL = URL.create(
"postgresql+psycopg",
username=DATABASE["USER"],
password=DATABASE["PASSWORD"],
host=DATABASE["HOST"],
database=DATABASE["NAME"],
)
engine = create_engine(SYNC_DATABASE_URL) engine = create_engine(SYNC_DATABASE_URL)

View file

@ -1,7 +1,6 @@
from datetime import datetime from datetime import datetime
import enum import enum
from itsdangerous.url_safe import URLSafeTimedSerializer
from sqlalchemy import Text, Integer, BigInteger, ForeignKey, DateTime from sqlalchemy import Text, Integer, BigInteger, ForeignKey, DateTime
from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
@ -58,9 +57,6 @@ class Token(Base):
# This way, both `user` and `user.edit` tokens will grant access to `user.edit` endpoints. # This way, both `user` and `user.edit` tokens will grant access to `user.edit` endpoints.
return scope in self.scopes or scope.split(".")[0] in self.scopes return scope in self.scopes or scope.split(".")[0] in self.scopes
def token_str(self):
...
class AuthType(enum.IntEnum): class AuthType(enum.IntEnum):
DISCORD = 1 DISCORD = 1

View file

@ -1,9 +1,14 @@
import datetime
from itsdangerous import BadSignature
from itsdangerous.url_safe import URLSafeTimedSerializer
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select from sqlalchemy import select, insert
from quart import g from quart import g
from .user import User from .user import User, Token
from foxnouns.exceptions import ForbiddenError from foxnouns.exceptions import ForbiddenError, ErrorCode
from foxnouns.settings import SECRET_KEY
async def user_from_ref(session: AsyncSession, user_ref: str): async def user_from_ref(session: AsyncSession, user_ref: str):
@ -25,3 +30,39 @@ async def user_from_ref(session: AsyncSession, user_ref: str):
query = query.where(User.username == user_ref) query = query.where(User.username == user_ref)
return await session.scalar(query) return await session.scalar(query)
serializer = URLSafeTimedSerializer(SECRET_KEY)
def generate_token(token: Token):
return serializer.dumps(token.id)
async def create_token(session: AsyncSession, user: User, scopes: list[str] = ["*"]):
expires = datetime.datetime.now() + datetime.timedelta(days=90)
query = (
insert(Token)
.values(user_id=user.id, expires_at=expires, scopes=scopes)
.returning(Token)
)
return await session.scalar(query)
async def validate_token(session: AsyncSession, header: str) -> (Token, User):
try:
token_id = serializer.loads(header)
except BadSignature:
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
row = (await session.execute(
select(Token, User).join(Token.user).where(Token.id == token_id)
)).first()
if not row or not row.Token:
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
if row.Token.expires_at < datetime.datetime.now():
raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken)
return (row.Token, row.User)

View file

@ -1,10 +1,11 @@
from environs import Env from environs import Env
from sqlalchemy import URL
# read .env file # read .env file
env = Env() env = Env()
env.read_env() env.read_env()
# Format: postgresql+{driver}//{user}:{password}@{host}/{name}
# Note that the driver is set by the application.
DATABASE = { DATABASE = {
"USER": env("DATABASE_USER"), "USER": env("DATABASE_USER"),
"PASSWORD": env("DATABASE_PASSWORD"), "PASSWORD": env("DATABASE_PASSWORD"),
@ -12,22 +13,6 @@ DATABASE = {
"NAME": env("DATABASE_NAME"), "NAME": env("DATABASE_NAME"),
} }
SYNC_DATABASE_URL = URL.create(
"postgresql+psycopg",
username=DATABASE["USER"],
password=DATABASE["PASSWORD"],
host=DATABASE["HOST"],
database=DATABASE["NAME"],
)
ASYNC_DATABASE_URL = URL.create(
"postgresql+asyncpg",
username=DATABASE["USER"],
password=DATABASE["PASSWORD"],
host=DATABASE["HOST"],
database=DATABASE["NAME"],
)
# The base domain the API is served on. This must be set. # The base domain the API is served on. This must be set.
BASE_DOMAIN = env("BASE_DOMAIN") BASE_DOMAIN = env("BASE_DOMAIN")
# The base domain for short URLs. # The base domain for short URLs.

View file

@ -27,9 +27,14 @@ optional = true
pytest = "^8.0.2" pytest = "^8.0.2"
pytest-asyncio = "^0.23.5.post1" pytest-asyncio = "^0.23.5.post1"
[build-system] [tool.poe.tasks]
requires = ["poetry-core"] server = "uvicorn 'foxnouns.app:app'"
build-backend = "poetry.core.masonry.api" migrate = "alembic upgrade head"
test = "pytest"
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = ["--import-mode=importlib"] addopts = ["--import-mode=importlib"]
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"