add token checking
This commit is contained in:
parent
97a64296cd
commit
dd6e7cf73f
8 changed files with 94 additions and 41 deletions
|
@ -1,7 +1,9 @@
|
||||||
from quart import Quart, make_response, jsonify
|
from quart import Quart, make_response, jsonify, request, g
|
||||||
from quart_schema import QuartSchema, RequestSchemaValidationError
|
from quart_schema import QuartSchema, RequestSchemaValidationError
|
||||||
|
|
||||||
from .blueprints import users_blueprint
|
from .blueprints import users_blueprint
|
||||||
|
from .db.aio import async_session
|
||||||
|
from .db.util import validate_token
|
||||||
from .exceptions import ExpectedError
|
from .exceptions import ExpectedError
|
||||||
|
|
||||||
app = Quart(__name__)
|
app = Quart(__name__)
|
||||||
|
@ -9,9 +11,6 @@ QuartSchema(app)
|
||||||
|
|
||||||
app.register_blueprint(users_blueprint)
|
app.register_blueprint(users_blueprint)
|
||||||
|
|
||||||
for route in app.url_map.iter_rules():
|
|
||||||
print(route, route.host)
|
|
||||||
|
|
||||||
|
|
||||||
@app.errorhandler(RequestSchemaValidationError)
|
@app.errorhandler(RequestSchemaValidationError)
|
||||||
async def handle_request_validation_error(error: RequestSchemaValidationError):
|
async def handle_request_validation_error(error: RequestSchemaValidationError):
|
||||||
|
@ -22,3 +21,19 @@ async def handle_request_validation_error(error: RequestSchemaValidationError):
|
||||||
@app.errorhandler(ExpectedError)
|
@app.errorhandler(ExpectedError)
|
||||||
async def handle_expected_error(error: ExpectedError):
|
async def handle_expected_error(error: ExpectedError):
|
||||||
return {"code": error.type, "message": error.msg}, error.status_code
|
return {"code": error.type, "message": error.msg}, error.status_code
|
||||||
|
|
||||||
|
|
||||||
|
@app.before_request
|
||||||
|
async def get_user_from_token():
|
||||||
|
"""Get the current user from a token given in the `Authorization` header or the `pronounscc-token` cookie.
|
||||||
|
If no token is set, does nothing; if an invalid token is set, raises an error."""
|
||||||
|
token = request.headers.get("Authorization", None) or request.cookies.get(
|
||||||
|
"pronounscc-token", None
|
||||||
|
)
|
||||||
|
if not token:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
token, user = await validate_token(session, token)
|
||||||
|
g.token = token
|
||||||
|
g.user = user
|
||||||
|
|
|
@ -6,7 +6,7 @@ from foxnouns.db.aio import async_session
|
||||||
from foxnouns.db.util import user_from_ref
|
from foxnouns.db.util import user_from_ref
|
||||||
from foxnouns.exceptions import NotFoundError, ErrorCode
|
from foxnouns.exceptions import NotFoundError, ErrorCode
|
||||||
from foxnouns.models.user import UserModel, check_username
|
from foxnouns.models.user import UserModel, check_username
|
||||||
from foxnouns.settings import BASE_DOMAIN, SHORT_DOMAIN
|
from foxnouns.settings import BASE_DOMAIN
|
||||||
|
|
||||||
bp = Blueprint("users_v2", __name__)
|
bp = Blueprint("users_v2", __name__)
|
||||||
|
|
||||||
|
@ -14,18 +14,12 @@ bp = Blueprint("users_v2", __name__)
|
||||||
@bp.get("/api/v2/users/<user_ref>", host=BASE_DOMAIN)
|
@bp.get("/api/v2/users/<user_ref>", host=BASE_DOMAIN)
|
||||||
@validate_response(UserModel, 200)
|
@validate_response(UserModel, 200)
|
||||||
async def get_user(user_ref: str):
|
async def get_user(user_ref: str):
|
||||||
print(request.host)
|
|
||||||
|
|
||||||
async with async_session() as session:
|
async with async_session() as session:
|
||||||
user = await user_from_ref(session, user_ref)
|
user = await user_from_ref(session, user_ref)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError("User not found", type=ErrorCode.UserNotFound)
|
raise NotFoundError("User not found", type=ErrorCode.UserNotFound)
|
||||||
return UserModel.model_validate(user)
|
return UserModel.model_validate(user)
|
||||||
|
|
||||||
@bp.get("/api/v2/users/<user_ref>", host=SHORT_DOMAIN)
|
|
||||||
async def hello(user_ref):
|
|
||||||
return {"hello": f"from {SHORT_DOMAIN}"}
|
|
||||||
|
|
||||||
|
|
||||||
class EditUserRequest(BaseModel):
|
class EditUserRequest(BaseModel):
|
||||||
username: str | None = Field(
|
username: str | None = Field(
|
||||||
|
|
|
@ -1,6 +1,15 @@
|
||||||
|
from sqlalchemy import URL
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
|
|
||||||
from foxnouns.settings import ASYNC_DATABASE_URL
|
from foxnouns.settings import DATABASE
|
||||||
|
|
||||||
|
ASYNC_DATABASE_URL = URL.create(
|
||||||
|
"postgresql+asyncpg",
|
||||||
|
username=DATABASE["USER"],
|
||||||
|
password=DATABASE["PASSWORD"],
|
||||||
|
host=DATABASE["HOST"],
|
||||||
|
database=DATABASE["NAME"],
|
||||||
|
)
|
||||||
|
|
||||||
engine = create_async_engine(ASYNC_DATABASE_URL)
|
engine = create_async_engine(ASYNC_DATABASE_URL)
|
||||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
|
@ -1,5 +1,13 @@
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import URL, create_engine
|
||||||
|
|
||||||
from foxnouns.settings import SYNC_DATABASE_URL
|
from foxnouns.settings import DATABASE
|
||||||
|
|
||||||
|
SYNC_DATABASE_URL = URL.create(
|
||||||
|
"postgresql+psycopg",
|
||||||
|
username=DATABASE["USER"],
|
||||||
|
password=DATABASE["PASSWORD"],
|
||||||
|
host=DATABASE["HOST"],
|
||||||
|
database=DATABASE["NAME"],
|
||||||
|
)
|
||||||
|
|
||||||
engine = create_engine(SYNC_DATABASE_URL)
|
engine = create_engine(SYNC_DATABASE_URL)
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import enum
|
import enum
|
||||||
|
|
||||||
from itsdangerous.url_safe import URLSafeTimedSerializer
|
|
||||||
from sqlalchemy import Text, Integer, BigInteger, ForeignKey, DateTime
|
from sqlalchemy import Text, Integer, BigInteger, ForeignKey, DateTime
|
||||||
from sqlalchemy.dialects.postgresql import ARRAY
|
from sqlalchemy.dialects.postgresql import ARRAY
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
@ -58,9 +57,6 @@ class Token(Base):
|
||||||
# This way, both `user` and `user.edit` tokens will grant access to `user.edit` endpoints.
|
# This way, both `user` and `user.edit` tokens will grant access to `user.edit` endpoints.
|
||||||
return scope in self.scopes or scope.split(".")[0] in self.scopes
|
return scope in self.scopes or scope.split(".")[0] in self.scopes
|
||||||
|
|
||||||
def token_str(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class AuthType(enum.IntEnum):
|
class AuthType(enum.IntEnum):
|
||||||
DISCORD = 1
|
DISCORD = 1
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from itsdangerous import BadSignature
|
||||||
|
from itsdangerous.url_safe import URLSafeTimedSerializer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select, insert
|
||||||
from quart import g
|
from quart import g
|
||||||
|
|
||||||
from .user import User
|
from .user import User, Token
|
||||||
from foxnouns.exceptions import ForbiddenError
|
from foxnouns.exceptions import ForbiddenError, ErrorCode
|
||||||
|
from foxnouns.settings import SECRET_KEY
|
||||||
|
|
||||||
|
|
||||||
async def user_from_ref(session: AsyncSession, user_ref: str):
|
async def user_from_ref(session: AsyncSession, user_ref: str):
|
||||||
|
@ -25,3 +30,39 @@ async def user_from_ref(session: AsyncSession, user_ref: str):
|
||||||
query = query.where(User.username == user_ref)
|
query = query.where(User.username == user_ref)
|
||||||
|
|
||||||
return await session.scalar(query)
|
return await session.scalar(query)
|
||||||
|
|
||||||
|
|
||||||
|
serializer = URLSafeTimedSerializer(SECRET_KEY)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_token(token: Token):
|
||||||
|
return serializer.dumps(token.id)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_token(session: AsyncSession, user: User, scopes: list[str] = ["*"]):
|
||||||
|
expires = datetime.datetime.now() + datetime.timedelta(days=90)
|
||||||
|
query = (
|
||||||
|
insert(Token)
|
||||||
|
.values(user_id=user.id, expires_at=expires, scopes=scopes)
|
||||||
|
.returning(Token)
|
||||||
|
)
|
||||||
|
return await session.scalar(query)
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_token(session: AsyncSession, header: str) -> (Token, User):
|
||||||
|
try:
|
||||||
|
token_id = serializer.loads(header)
|
||||||
|
except BadSignature:
|
||||||
|
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
|
||||||
|
|
||||||
|
row = (await session.execute(
|
||||||
|
select(Token, User).join(Token.user).where(Token.id == token_id)
|
||||||
|
)).first()
|
||||||
|
|
||||||
|
if not row or not row.Token:
|
||||||
|
raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
|
||||||
|
|
||||||
|
if row.Token.expires_at < datetime.datetime.now():
|
||||||
|
raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken)
|
||||||
|
|
||||||
|
return (row.Token, row.User)
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
from environs import Env
|
from environs import Env
|
||||||
from sqlalchemy import URL
|
|
||||||
|
|
||||||
# read .env file
|
# read .env file
|
||||||
env = Env()
|
env = Env()
|
||||||
env.read_env()
|
env.read_env()
|
||||||
|
|
||||||
|
# Format: postgresql+{driver}//{user}:{password}@{host}/{name}
|
||||||
|
# Note that the driver is set by the application.
|
||||||
DATABASE = {
|
DATABASE = {
|
||||||
"USER": env("DATABASE_USER"),
|
"USER": env("DATABASE_USER"),
|
||||||
"PASSWORD": env("DATABASE_PASSWORD"),
|
"PASSWORD": env("DATABASE_PASSWORD"),
|
||||||
|
@ -12,22 +13,6 @@ DATABASE = {
|
||||||
"NAME": env("DATABASE_NAME"),
|
"NAME": env("DATABASE_NAME"),
|
||||||
}
|
}
|
||||||
|
|
||||||
SYNC_DATABASE_URL = URL.create(
|
|
||||||
"postgresql+psycopg",
|
|
||||||
username=DATABASE["USER"],
|
|
||||||
password=DATABASE["PASSWORD"],
|
|
||||||
host=DATABASE["HOST"],
|
|
||||||
database=DATABASE["NAME"],
|
|
||||||
)
|
|
||||||
|
|
||||||
ASYNC_DATABASE_URL = URL.create(
|
|
||||||
"postgresql+asyncpg",
|
|
||||||
username=DATABASE["USER"],
|
|
||||||
password=DATABASE["PASSWORD"],
|
|
||||||
host=DATABASE["HOST"],
|
|
||||||
database=DATABASE["NAME"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# The base domain the API is served on. This must be set.
|
# The base domain the API is served on. This must be set.
|
||||||
BASE_DOMAIN = env("BASE_DOMAIN")
|
BASE_DOMAIN = env("BASE_DOMAIN")
|
||||||
# The base domain for short URLs.
|
# The base domain for short URLs.
|
||||||
|
|
|
@ -27,9 +27,14 @@ optional = true
|
||||||
pytest = "^8.0.2"
|
pytest = "^8.0.2"
|
||||||
pytest-asyncio = "^0.23.5.post1"
|
pytest-asyncio = "^0.23.5.post1"
|
||||||
|
|
||||||
[build-system]
|
[tool.poe.tasks]
|
||||||
requires = ["poetry-core"]
|
server = "uvicorn 'foxnouns.app:app'"
|
||||||
build-backend = "poetry.core.masonry.api"
|
migrate = "alembic upgrade head"
|
||||||
|
test = "pytest"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = ["--import-mode=importlib"]
|
addopts = ["--import-mode=importlib"]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
Loading…
Reference in a new issue