add token checking
This commit is contained in:
		
							parent
							
								
									97a64296cd
								
							
						
					
					
						commit
						dd6e7cf73f
					
				
					 8 changed files with 94 additions and 41 deletions
				
			
		| 
						 | 
				
			
			@ -1,7 +1,9 @@
 | 
			
		|||
from quart import Quart, make_response, jsonify
 | 
			
		||||
from quart import Quart, make_response, jsonify, request, g
 | 
			
		||||
from quart_schema import QuartSchema, RequestSchemaValidationError
 | 
			
		||||
 | 
			
		||||
from .blueprints import users_blueprint
 | 
			
		||||
from .db.aio import async_session
 | 
			
		||||
from .db.util import validate_token
 | 
			
		||||
from .exceptions import ExpectedError
 | 
			
		||||
 | 
			
		||||
app = Quart(__name__)
 | 
			
		||||
| 
						 | 
				
			
			@ -9,9 +11,6 @@ QuartSchema(app)
 | 
			
		|||
 | 
			
		||||
app.register_blueprint(users_blueprint)
 | 
			
		||||
 | 
			
		||||
for route in app.url_map.iter_rules():
 | 
			
		||||
    print(route, route.host)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.errorhandler(RequestSchemaValidationError)
 | 
			
		||||
async def handle_request_validation_error(error: RequestSchemaValidationError):
 | 
			
		||||
| 
						 | 
				
			
			@ -22,3 +21,19 @@ async def handle_request_validation_error(error: RequestSchemaValidationError):
 | 
			
		|||
@app.errorhandler(ExpectedError)
 | 
			
		||||
async def handle_expected_error(error: ExpectedError):
 | 
			
		||||
    return {"code": error.type, "message": error.msg}, error.status_code
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.before_request
 | 
			
		||||
async def get_user_from_token():
 | 
			
		||||
    """Get the current user from a token given in the `Authorization` header or the `pronounscc-token` cookie.
 | 
			
		||||
    If no token is set, does nothing; if an invalid token is set, raises an error."""
 | 
			
		||||
    token = request.headers.get("Authorization", None) or request.cookies.get(
 | 
			
		||||
        "pronounscc-token", None
 | 
			
		||||
    )
 | 
			
		||||
    if not token:
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    async with async_session() as session:
 | 
			
		||||
        token, user = await validate_token(session, token)
 | 
			
		||||
        g.token = token
 | 
			
		||||
        g.user = user
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,7 +6,7 @@ from foxnouns.db.aio import async_session
 | 
			
		|||
from foxnouns.db.util import user_from_ref
 | 
			
		||||
from foxnouns.exceptions import NotFoundError, ErrorCode
 | 
			
		||||
from foxnouns.models.user import UserModel, check_username
 | 
			
		||||
from foxnouns.settings import BASE_DOMAIN, SHORT_DOMAIN
 | 
			
		||||
from foxnouns.settings import BASE_DOMAIN
 | 
			
		||||
 | 
			
		||||
bp = Blueprint("users_v2", __name__)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -14,18 +14,12 @@ bp = Blueprint("users_v2", __name__)
 | 
			
		|||
@bp.get("/api/v2/users/<user_ref>", host=BASE_DOMAIN)
 | 
			
		||||
@validate_response(UserModel, 200)
 | 
			
		||||
async def get_user(user_ref: str):
 | 
			
		||||
    print(request.host)
 | 
			
		||||
 | 
			
		||||
    async with async_session() as session:
 | 
			
		||||
        user = await user_from_ref(session, user_ref)
 | 
			
		||||
        if not user:
 | 
			
		||||
            raise NotFoundError("User not found", type=ErrorCode.UserNotFound)
 | 
			
		||||
        return UserModel.model_validate(user)
 | 
			
		||||
 | 
			
		||||
@bp.get("/api/v2/users/<user_ref>", host=SHORT_DOMAIN)
 | 
			
		||||
async def hello(user_ref):
 | 
			
		||||
    return {"hello": f"from {SHORT_DOMAIN}"}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EditUserRequest(BaseModel):
 | 
			
		||||
    username: str | None = Field(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,15 @@
 | 
			
		|||
from sqlalchemy import URL
 | 
			
		||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 | 
			
		||||
 | 
			
		||||
from foxnouns.settings import ASYNC_DATABASE_URL
 | 
			
		||||
from foxnouns.settings import DATABASE
 | 
			
		||||
 | 
			
		||||
ASYNC_DATABASE_URL = URL.create(
 | 
			
		||||
    "postgresql+asyncpg",
 | 
			
		||||
    username=DATABASE["USER"],
 | 
			
		||||
    password=DATABASE["PASSWORD"],
 | 
			
		||||
    host=DATABASE["HOST"],
 | 
			
		||||
    database=DATABASE["NAME"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
engine = create_async_engine(ASYNC_DATABASE_URL)
 | 
			
		||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,13 @@
 | 
			
		|||
from sqlalchemy import create_engine
 | 
			
		||||
from sqlalchemy import URL, create_engine
 | 
			
		||||
 | 
			
		||||
from foxnouns.settings import SYNC_DATABASE_URL
 | 
			
		||||
from foxnouns.settings import DATABASE
 | 
			
		||||
 | 
			
		||||
SYNC_DATABASE_URL = URL.create(
 | 
			
		||||
    "postgresql+psycopg",
 | 
			
		||||
    username=DATABASE["USER"],
 | 
			
		||||
    password=DATABASE["PASSWORD"],
 | 
			
		||||
    host=DATABASE["HOST"],
 | 
			
		||||
    database=DATABASE["NAME"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
engine = create_engine(SYNC_DATABASE_URL)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,7 +1,6 @@
 | 
			
		|||
from datetime import datetime
 | 
			
		||||
import enum
 | 
			
		||||
 | 
			
		||||
from itsdangerous.url_safe import URLSafeTimedSerializer
 | 
			
		||||
from sqlalchemy import Text, Integer, BigInteger, ForeignKey, DateTime
 | 
			
		||||
from sqlalchemy.dialects.postgresql import ARRAY
 | 
			
		||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
 | 
			
		||||
| 
						 | 
				
			
			@ -58,9 +57,6 @@ class Token(Base):
 | 
			
		|||
        # This way, both `user` and `user.edit` tokens will grant access to `user.edit` endpoints.
 | 
			
		||||
        return scope in self.scopes or scope.split(".")[0] in self.scopes
 | 
			
		||||
 | 
			
		||||
    def token_str(self):
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AuthType(enum.IntEnum):
 | 
			
		||||
    DISCORD = 1
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,9 +1,14 @@
 | 
			
		|||
import datetime
 | 
			
		||||
 | 
			
		||||
from itsdangerous import BadSignature
 | 
			
		||||
from itsdangerous.url_safe import URLSafeTimedSerializer
 | 
			
		||||
from sqlalchemy.ext.asyncio import AsyncSession
 | 
			
		||||
from sqlalchemy import select
 | 
			
		||||
from sqlalchemy import select, insert
 | 
			
		||||
from quart import g
 | 
			
		||||
 | 
			
		||||
from .user import User
 | 
			
		||||
from foxnouns.exceptions import ForbiddenError
 | 
			
		||||
from .user import User, Token
 | 
			
		||||
from foxnouns.exceptions import ForbiddenError, ErrorCode
 | 
			
		||||
from foxnouns.settings import SECRET_KEY
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def user_from_ref(session: AsyncSession, user_ref: str):
 | 
			
		||||
| 
						 | 
				
			
			@ -25,3 +30,39 @@ async def user_from_ref(session: AsyncSession, user_ref: str):
 | 
			
		|||
            query = query.where(User.username == user_ref)
 | 
			
		||||
 | 
			
		||||
    return await session.scalar(query)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
serializer = URLSafeTimedSerializer(SECRET_KEY)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_token(token: Token):
 | 
			
		||||
    return serializer.dumps(token.id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def create_token(session: AsyncSession, user: User, scopes: list[str] = ["*"]):
 | 
			
		||||
    expires = datetime.datetime.now() + datetime.timedelta(days=90)
 | 
			
		||||
    query = (
 | 
			
		||||
        insert(Token)
 | 
			
		||||
        .values(user_id=user.id, expires_at=expires, scopes=scopes)
 | 
			
		||||
        .returning(Token)
 | 
			
		||||
    )
 | 
			
		||||
    return await session.scalar(query)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def validate_token(session: AsyncSession, header: str) -> (Token, User):
 | 
			
		||||
    try:
 | 
			
		||||
        token_id = serializer.loads(header)
 | 
			
		||||
    except BadSignature:
 | 
			
		||||
        raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
 | 
			
		||||
 | 
			
		||||
    row = (await session.execute(
 | 
			
		||||
        select(Token, User).join(Token.user).where(Token.id == token_id)
 | 
			
		||||
    )).first()
 | 
			
		||||
 | 
			
		||||
    if not row or not row.Token:
 | 
			
		||||
        raise ForbiddenError("Invalid token", type=ErrorCode.InvalidToken)
 | 
			
		||||
 | 
			
		||||
    if row.Token.expires_at < datetime.datetime.now():
 | 
			
		||||
        raise ForbiddenError("Token has expired", type=ErrorCode.InvalidToken)
 | 
			
		||||
 | 
			
		||||
    return (row.Token, row.User)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,10 +1,11 @@
 | 
			
		|||
from environs import Env
 | 
			
		||||
from sqlalchemy import URL
 | 
			
		||||
 | 
			
		||||
# read .env file
 | 
			
		||||
env = Env()
 | 
			
		||||
env.read_env()
 | 
			
		||||
 | 
			
		||||
# Format: postgresql+{driver}//{user}:{password}@{host}/{name}
 | 
			
		||||
# Note that the driver is set by the application.
 | 
			
		||||
DATABASE = {
 | 
			
		||||
    "USER": env("DATABASE_USER"),
 | 
			
		||||
    "PASSWORD": env("DATABASE_PASSWORD"),
 | 
			
		||||
| 
						 | 
				
			
			@ -12,22 +13,6 @@ DATABASE = {
 | 
			
		|||
    "NAME": env("DATABASE_NAME"),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
SYNC_DATABASE_URL = URL.create(
 | 
			
		||||
    "postgresql+psycopg",
 | 
			
		||||
    username=DATABASE["USER"],
 | 
			
		||||
    password=DATABASE["PASSWORD"],
 | 
			
		||||
    host=DATABASE["HOST"],
 | 
			
		||||
    database=DATABASE["NAME"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
ASYNC_DATABASE_URL = URL.create(
 | 
			
		||||
    "postgresql+asyncpg",
 | 
			
		||||
    username=DATABASE["USER"],
 | 
			
		||||
    password=DATABASE["PASSWORD"],
 | 
			
		||||
    host=DATABASE["HOST"],
 | 
			
		||||
    database=DATABASE["NAME"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# The base domain the API is served on. This must be set.
 | 
			
		||||
BASE_DOMAIN = env("BASE_DOMAIN")
 | 
			
		||||
# The base domain for short URLs.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,9 +27,14 @@ optional = true
 | 
			
		|||
pytest = "^8.0.2"
 | 
			
		||||
pytest-asyncio = "^0.23.5.post1"
 | 
			
		||||
 | 
			
		||||
[build-system]
 | 
			
		||||
requires = ["poetry-core"]
 | 
			
		||||
build-backend = "poetry.core.masonry.api"
 | 
			
		||||
[tool.poe.tasks]
 | 
			
		||||
server = "uvicorn 'foxnouns.app:app'"
 | 
			
		||||
migrate = "alembic upgrade head"
 | 
			
		||||
test = "pytest"
 | 
			
		||||
 | 
			
		||||
[tool.pytest.ini_options]
 | 
			
		||||
addopts = ["--import-mode=importlib"]
 | 
			
		||||
 | 
			
		||||
[build-system]
 | 
			
		||||
requires = ["poetry-core"]
 | 
			
		||||
build-backend = "poetry.core.masonry.api"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue