This commit is contained in:
sam 2024-03-13 17:03:18 +01:00
commit 97a64296cd
Signed by: sam
GPG key ID: B4EF20DDE721CAA1
28 changed files with 2288 additions and 0 deletions

2
foxnouns/db/__init__.py Normal file
View file

@ -0,0 +1,2 @@
from .base import Base
from .user import User, Token, AuthMethod, FediverseApp

6
foxnouns/db/aio.py Normal file
View file

@ -0,0 +1,6 @@
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from foxnouns.settings import ASYNC_DATABASE_URL
engine = create_async_engine(ASYNC_DATABASE_URL)
async_session = async_sessionmaker(engine, expire_on_commit=False)

4
foxnouns/db/base.py Normal file
View file

@ -0,0 +1,4 @@
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
pass

110
foxnouns/db/snowflake.py Normal file
View file

@ -0,0 +1,110 @@
import os
import threading
from datetime import datetime, timezone
from random import randrange
_local = threading.local()
def _get_increment() -> int:
if not hasattr(_local, "increment"):
_local.increment = randrange(0, 4095)
increment = _local.increment
_local.increment += 1
return increment
class Snowflake:
"""A Snowflake ID (https://en.wikipedia.org/wiki/Snowflake_ID).
This class wraps an integer and adds convenience functions."""
EPOCH = 1_640_995_200_000 # 2022-01-01 at 00:00:00 UTC
_raw: int
def __init__(self, src: int):
self._raw = src
def __str__(self) -> str:
return str(self.id)
def __repr__(self) -> str:
return f"Snowflake<{self.id}, {self.process}, {self.thread}, {self.increment}, {self.timestamp}>"
def __int__(self) -> int:
return self._raw
def __float__(self) -> float:
return float(self._raw)
def __lt__(self, y: "Snowflake"):
return self.id < y.id
def __le__(self, y: "Snowflake"):
return self.id <= y.id
def __eq__(self, y: "Snowflake"):
return self.id == y.id
def __ne__(self, y: "Snowflake"):
return self.id != y.id
def __gt__(self, y: "Snowflake"):
return self.id > y.id
def __ge__(self, y: "Snowflake"):
return self.id >= y.id
@property
def id(self) -> int:
"""The raw integer value of the snowflake."""
return self._raw
@property
def time(self) -> datetime:
"""The time embedded into the snowflake."""
return datetime.fromtimestamp(self.timestamp, tz=timezone.utc)
@property
def timestamp(self) -> float:
"""The unix timestamp embedded into the snowflake."""
return ((self._raw >> 22) + self.EPOCH) / 1000
@property
def process(self) -> int:
"""The process ID embedded into the snowflake."""
return (self._raw & 0x3E0000) >> 17
@property
def thread(self) -> int:
"""The thread ID embedded into the snowflake."""
return (self._raw & 0x1F000) >> 12
@property
def increment(self) -> int:
"""The increment embedded into the snowflake."""
return self._raw & 0xFFF
@classmethod
def generate(cls, time: datetime | None = None):
"""Generates a new snowflake.
If `time` is set, use that time for the snowflake, otherwise, use the current time.
"""
process_id = os.getpid()
thread_id = threading.get_native_id()
increment = _get_increment()
now = time if time else datetime.now(tz=timezone.utc)
timestamp = round(now.timestamp() * 1000) - cls.EPOCH
return cls(
timestamp << 22
| (process_id % 32) << 17
| (thread_id % 32) << 12
| (increment % 4096)
)
@classmethod
def generate_int(cls, time: datetime | None = None):
return cls.generate(time).id

5
foxnouns/db/sync.py Normal file
View file

@ -0,0 +1,5 @@
from sqlalchemy import create_engine
from foxnouns.settings import SYNC_DATABASE_URL
engine = create_engine(SYNC_DATABASE_URL)

109
foxnouns/db/user.py Normal file
View file

@ -0,0 +1,109 @@
from datetime import datetime
import enum
from itsdangerous.url_safe import URLSafeTimedSerializer
from sqlalchemy import Text, Integer, BigInteger, ForeignKey, DateTime
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .base import Base
from .snowflake import Snowflake
class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(
BigInteger(), primary_key=True, default=Snowflake.generate_int
)
username: Mapped[str] = mapped_column(Text(), unique=True, nullable=False)
display_name: Mapped[str | None] = mapped_column(Text(), nullable=True)
bio: Mapped[str | None] = mapped_column(Text(), nullable=True)
tokens: Mapped[list["Token"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
auth_methods: Mapped[list["AuthMethod"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
def __repr__(self):
return f"User(id={self.id!r}, username={self.username!r})"
class Token(Base):
__tablename__ = "tokens"
id: Mapped[int] = mapped_column(
BigInteger(), primary_key=True, default=Snowflake.generate_int
)
expires_at: Mapped[datetime] = mapped_column(DateTime(), nullable=False)
scopes: Mapped[list[str]] = mapped_column(ARRAY(Text), nullable=False)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
user: Mapped[User] = relationship(back_populates="tokens")
def __repr__(self):
return f"Token(id={self.id!r}, user={self.user_id!r})"
def has_scope(self, scope: str):
"""Returns whether this token can be used for the given scope."""
# `*` is a special scope for site tokens, which grants access to all endpoints.
if "*" in self.scopes:
return True
# Some scopes have sub-scopes, indicated by a `.` (i.e. `user.edit` is contained in `user`)
# Tokens can have these narrower scopes given to them, or the wider, more privileged scopes
# This way, both `user` and `user.edit` tokens will grant access to `user.edit` endpoints.
return scope in self.scopes or scope.split(".")[0] in self.scopes
def token_str(self):
...
class AuthType(enum.IntEnum):
DISCORD = 1
GOOGLE = 2
TUMBLR = 3
FEDIVERSE = 4
EMAIL = 5
class AuthMethod(Base):
__tablename__ = "auth_methods"
id: Mapped[int] = mapped_column(
BigInteger(), primary_key=True, default=Snowflake.generate_int
)
auth_type: Mapped[AuthType] = mapped_column(Integer(), nullable=False)
remote_id: Mapped[str] = mapped_column(Text(), nullable=False)
remote_username: Mapped[str | None] = mapped_column(Text(), nullable=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
user: Mapped[User] = relationship(back_populates="auth_methods")
fediverse_app_id: Mapped[int] = mapped_column(
ForeignKey("fediverse_apps.id"), nullable=True
)
fediverse_app: Mapped["FediverseApp"] = relationship()
class FediverseInstanceType(enum.IntEnum):
MASTODON_API = 1
MISSKEY_API = 2
class FediverseApp(Base):
__tablename__ = "fediverse_apps"
id: Mapped[int] = mapped_column(
BigInteger(), primary_key=True, default=Snowflake.generate_int
)
instance: Mapped[str] = mapped_column(Text(), unique=True, nullable=False)
client_id: Mapped[str] = mapped_column(Text(), nullable=False)
client_secret: Mapped[str] = mapped_column(Text(), nullable=False)
instance_type: Mapped[FediverseInstanceType] = mapped_column(
Integer(), nullable=False
)

27
foxnouns/db/util.py Normal file
View file

@ -0,0 +1,27 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from quart import g
from .user import User
from foxnouns.exceptions import ForbiddenError
async def user_from_ref(session: AsyncSession, user_ref: str):
"""Returns a user from a `user_ref` value. If `user_ref` is `@me`, returns the current user.
Otherwise, tries to convert the user to a snowflake ID and queries that. Otherwise, returns a user with that username.
"""
query = select(User)
if user_ref == "@me":
if "user" in g:
query = query.where(User.id == g.user.id)
else:
raise ForbiddenError("Not authenticated")
else:
try:
id = int(user_ref)
query = query.where(User.id == id)
except ValueError:
query = query.where(User.username == user_ref)
return await session.scalar(query)