init
This commit is contained in:
commit
97a64296cd
28 changed files with 2288 additions and 0 deletions
2
foxnouns/db/__init__.py
Normal file
2
foxnouns/db/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from .base import Base
|
||||
from .user import User, Token, AuthMethod, FediverseApp
|
6
foxnouns/db/aio.py
Normal file
6
foxnouns/db/aio.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
|
||||
from foxnouns.settings import ASYNC_DATABASE_URL
|
||||
|
||||
engine = create_async_engine(ASYNC_DATABASE_URL)
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
4
foxnouns/db/base.py
Normal file
4
foxnouns/db/base.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
110
foxnouns/db/snowflake.py
Normal file
110
foxnouns/db/snowflake.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
import os
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from random import randrange
|
||||
|
||||
_local = threading.local()
|
||||
|
||||
|
||||
def _get_increment() -> int:
|
||||
if not hasattr(_local, "increment"):
|
||||
_local.increment = randrange(0, 4095)
|
||||
|
||||
increment = _local.increment
|
||||
_local.increment += 1
|
||||
return increment
|
||||
|
||||
|
||||
class Snowflake:
|
||||
"""A Snowflake ID (https://en.wikipedia.org/wiki/Snowflake_ID).
|
||||
This class wraps an integer and adds convenience functions."""
|
||||
|
||||
EPOCH = 1_640_995_200_000 # 2022-01-01 at 00:00:00 UTC
|
||||
|
||||
_raw: int
|
||||
|
||||
def __init__(self, src: int):
|
||||
self._raw = src
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Snowflake<{self.id}, {self.process}, {self.thread}, {self.increment}, {self.timestamp}>"
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self._raw
|
||||
|
||||
def __float__(self) -> float:
|
||||
return float(self._raw)
|
||||
|
||||
def __lt__(self, y: "Snowflake"):
|
||||
return self.id < y.id
|
||||
|
||||
def __le__(self, y: "Snowflake"):
|
||||
return self.id <= y.id
|
||||
|
||||
def __eq__(self, y: "Snowflake"):
|
||||
return self.id == y.id
|
||||
|
||||
def __ne__(self, y: "Snowflake"):
|
||||
return self.id != y.id
|
||||
|
||||
def __gt__(self, y: "Snowflake"):
|
||||
return self.id > y.id
|
||||
|
||||
def __ge__(self, y: "Snowflake"):
|
||||
return self.id >= y.id
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
"""The raw integer value of the snowflake."""
|
||||
return self._raw
|
||||
|
||||
@property
|
||||
def time(self) -> datetime:
|
||||
"""The time embedded into the snowflake."""
|
||||
return datetime.fromtimestamp(self.timestamp, tz=timezone.utc)
|
||||
|
||||
@property
|
||||
def timestamp(self) -> float:
|
||||
"""The unix timestamp embedded into the snowflake."""
|
||||
return ((self._raw >> 22) + self.EPOCH) / 1000
|
||||
|
||||
@property
|
||||
def process(self) -> int:
|
||||
"""The process ID embedded into the snowflake."""
|
||||
return (self._raw & 0x3E0000) >> 17
|
||||
|
||||
@property
|
||||
def thread(self) -> int:
|
||||
"""The thread ID embedded into the snowflake."""
|
||||
return (self._raw & 0x1F000) >> 12
|
||||
|
||||
@property
|
||||
def increment(self) -> int:
|
||||
"""The increment embedded into the snowflake."""
|
||||
return self._raw & 0xFFF
|
||||
|
||||
@classmethod
|
||||
def generate(cls, time: datetime | None = None):
|
||||
"""Generates a new snowflake.
|
||||
If `time` is set, use that time for the snowflake, otherwise, use the current time.
|
||||
"""
|
||||
|
||||
process_id = os.getpid()
|
||||
thread_id = threading.get_native_id()
|
||||
increment = _get_increment()
|
||||
now = time if time else datetime.now(tz=timezone.utc)
|
||||
timestamp = round(now.timestamp() * 1000) - cls.EPOCH
|
||||
|
||||
return cls(
|
||||
timestamp << 22
|
||||
| (process_id % 32) << 17
|
||||
| (thread_id % 32) << 12
|
||||
| (increment % 4096)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_int(cls, time: datetime | None = None):
|
||||
return cls.generate(time).id
|
5
foxnouns/db/sync.py
Normal file
5
foxnouns/db/sync.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
from sqlalchemy import create_engine
|
||||
|
||||
from foxnouns.settings import SYNC_DATABASE_URL
|
||||
|
||||
engine = create_engine(SYNC_DATABASE_URL)
|
109
foxnouns/db/user.py
Normal file
109
foxnouns/db/user.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
from datetime import datetime
|
||||
import enum
|
||||
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer
|
||||
from sqlalchemy import Text, Integer, BigInteger, ForeignKey, DateTime
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .base import Base
|
||||
from .snowflake import Snowflake
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
BigInteger(), primary_key=True, default=Snowflake.generate_int
|
||||
)
|
||||
username: Mapped[str] = mapped_column(Text(), unique=True, nullable=False)
|
||||
display_name: Mapped[str | None] = mapped_column(Text(), nullable=True)
|
||||
bio: Mapped[str | None] = mapped_column(Text(), nullable=True)
|
||||
|
||||
tokens: Mapped[list["Token"]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
auth_methods: Mapped[list["AuthMethod"]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"User(id={self.id!r}, username={self.username!r})"
|
||||
|
||||
|
||||
class Token(Base):
|
||||
__tablename__ = "tokens"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
BigInteger(), primary_key=True, default=Snowflake.generate_int
|
||||
)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(), nullable=False)
|
||||
scopes: Mapped[list[str]] = mapped_column(ARRAY(Text), nullable=False)
|
||||
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
||||
user: Mapped[User] = relationship(back_populates="tokens")
|
||||
|
||||
def __repr__(self):
|
||||
return f"Token(id={self.id!r}, user={self.user_id!r})"
|
||||
|
||||
def has_scope(self, scope: str):
|
||||
"""Returns whether this token can be used for the given scope."""
|
||||
|
||||
# `*` is a special scope for site tokens, which grants access to all endpoints.
|
||||
if "*" in self.scopes:
|
||||
return True
|
||||
|
||||
# Some scopes have sub-scopes, indicated by a `.` (i.e. `user.edit` is contained in `user`)
|
||||
# Tokens can have these narrower scopes given to them, or the wider, more privileged scopes
|
||||
# This way, both `user` and `user.edit` tokens will grant access to `user.edit` endpoints.
|
||||
return scope in self.scopes or scope.split(".")[0] in self.scopes
|
||||
|
||||
def token_str(self):
|
||||
...
|
||||
|
||||
|
||||
class AuthType(enum.IntEnum):
|
||||
DISCORD = 1
|
||||
GOOGLE = 2
|
||||
TUMBLR = 3
|
||||
FEDIVERSE = 4
|
||||
EMAIL = 5
|
||||
|
||||
|
||||
class AuthMethod(Base):
|
||||
__tablename__ = "auth_methods"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
BigInteger(), primary_key=True, default=Snowflake.generate_int
|
||||
)
|
||||
auth_type: Mapped[AuthType] = mapped_column(Integer(), nullable=False)
|
||||
|
||||
remote_id: Mapped[str] = mapped_column(Text(), nullable=False)
|
||||
remote_username: Mapped[str | None] = mapped_column(Text(), nullable=True)
|
||||
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
||||
user: Mapped[User] = relationship(back_populates="auth_methods")
|
||||
|
||||
fediverse_app_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("fediverse_apps.id"), nullable=True
|
||||
)
|
||||
fediverse_app: Mapped["FediverseApp"] = relationship()
|
||||
|
||||
|
||||
class FediverseInstanceType(enum.IntEnum):
|
||||
MASTODON_API = 1
|
||||
MISSKEY_API = 2
|
||||
|
||||
|
||||
class FediverseApp(Base):
|
||||
__tablename__ = "fediverse_apps"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
BigInteger(), primary_key=True, default=Snowflake.generate_int
|
||||
)
|
||||
instance: Mapped[str] = mapped_column(Text(), unique=True, nullable=False)
|
||||
client_id: Mapped[str] = mapped_column(Text(), nullable=False)
|
||||
client_secret: Mapped[str] = mapped_column(Text(), nullable=False)
|
||||
instance_type: Mapped[FediverseInstanceType] = mapped_column(
|
||||
Integer(), nullable=False
|
||||
)
|
27
foxnouns/db/util.py
Normal file
27
foxnouns/db/util.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from quart import g
|
||||
|
||||
from .user import User
|
||||
from foxnouns.exceptions import ForbiddenError
|
||||
|
||||
|
||||
async def user_from_ref(session: AsyncSession, user_ref: str):
|
||||
"""Returns a user from a `user_ref` value. If `user_ref` is `@me`, returns the current user.
|
||||
Otherwise, tries to convert the user to a snowflake ID and queries that. Otherwise, returns a user with that username.
|
||||
"""
|
||||
query = select(User)
|
||||
|
||||
if user_ref == "@me":
|
||||
if "user" in g:
|
||||
query = query.where(User.id == g.user.id)
|
||||
else:
|
||||
raise ForbiddenError("Not authenticated")
|
||||
else:
|
||||
try:
|
||||
id = int(user_ref)
|
||||
query = query.where(User.id == id)
|
||||
except ValueError:
|
||||
query = query.where(User.username == user_ref)
|
||||
|
||||
return await session.scalar(query)
|
Loading…
Add table
Add a link
Reference in a new issue