Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 42 additions & 10 deletions backend/api/depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,72 @@

from fastapi import Depends, Header
from sqlalchemy import select
from sqlalchemy.sql.expression import any_
from sqlalchemy.ext.asyncio import AsyncSession

from backend.api.entity import BaseUser
from common.exception import CommonException
from common.trace_info import TraceInfo
from backend.db import session_context
from backend.db.entity import ResourceDB as ResourceDB, Namespace as NamespaceDB
from backend.db import entity as db


def get_trace_info(trace_id: Annotated[str | None, Header()] = None) -> TraceInfo:
return TraceInfo(trace_id=trace_id)
def get_trace_info(x_trace_id: Annotated[str | None, Header()] = None) -> TraceInfo:
return TraceInfo(trace_id=x_trace_id)


async def get_session() -> AsyncSession:
async with session_context() as session:
yield session

def _mock_user() -> db.User:
return db.User(
user_id="mock_user_id",
username="mock_username",
email="mock_email@example.com",
nickname="mock_nickname",
password="mock_password",
role={},
api_keys=[]
)

async def _get_user() -> BaseUser:
return BaseUser(user_id="0" * 22)
async def _get_user() -> db.User:
return _mock_user()


async def _get_namespace_by_name(namespace: str, session: AsyncSession = Depends(get_session)) -> NamespaceDB:
async def _get_user_by_api_key(authorization: Annotated[str | None, Header()] = None, session: AsyncSession = Depends(get_session)) -> db.User:
if not authorization:
raise CommonException(code=401, error="Authorization required")
return _mock_user()
username, api_key = authorization.split(",")

with session.no_autoflush:
query = select(db.User).where(db.User.username == username, db.User.deleted_at.is_(None))
result = await session.execute(query)
user: db.User = result.scalar()
if not user:
raise CommonException(code=401, error="User not found")

api_key_query = select(db.APIKey).where(db.APIKey.api_key == api_key, db.APIKey.user_id == user.user_id)
api_key_result = await session.execute(api_key_query)
api_key_orm: db.APIKey = api_key_result.scalar()
if not api_key_orm:
raise CommonException(code=401, error="Invalid API key")

return user

async def _get_namespace_by_name(namespace: str, session: AsyncSession = Depends(get_session)) -> db.Namespace:
with session.no_autoflush:
query = select(NamespaceDB).where(NamespaceDB.name == namespace, NamespaceDB.deleted_at.is_(None))
query = select(db.Namespace).where(db.Namespace.name == namespace, db.Namespace.deleted_at.is_(None))
result = await session.execute(query)
namespace_orm: NamespaceDB = result.scalar()
namespace_orm: db.Namespace = result.scalar()
if not namespace_orm:
raise CommonException(code=404, error="Namespace not found")
return namespace_orm


async def _get_resource(resource_id: str, session: AsyncSession = Depends(get_session)) -> ResourceDB:
resource_orm: ResourceDB = await session.get(ResourceDB, resource_id) # noqa
async def _get_resource(resource_id: str, session: AsyncSession = Depends(get_session)) -> db.Resource:
resource_orm: db.Resource = await session.get(db.Resource, resource_id) # noqa
if not resource_orm:
raise CommonException(code=404, error="Resource not found")
return resource_orm
28 changes: 28 additions & 0 deletions backend/api/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,35 @@ class Resource(BaseDBModel):
resource_id: Optional[str] = Field(default=None, alias="id")
child_count: Optional[int] = Field(default=None)

attrs: Optional[dict] = Field(default=None)


class BaseUser(BaseAPIModel):
user_id: Optional[str] = Field(default=None, alias="id")
username: Optional[str] = Field(default=None)
api_keys: Optional[List[dict]] = Field(default=None)


class Task(BaseAPIModel):
task_id: str
priority: int

namespace_id: str
user_id: str

function: str
input: dict
payload: dict | None = Field(default=None, description="Task payload, would pass through to the webhook")

output: dict | None = None
exception: dict | None = None

started_at: datetime | None = None
ended_at: datetime | None = None
canceled_at: datetime | None = None

concurrency_threshold: int = Field(description="Concurrency threshold")

created_at: datetime
updated_at: datetime | None = None
deleted_at: datetime | None = None
26 changes: 13 additions & 13 deletions backend/api/namespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from sqlalchemy.ext.asyncio import AsyncSession

from backend.api.depends import get_session, _get_user
from backend.api.entity import BaseUser, IDResponse, BaseAPIModel
from backend.api.entity import IDResponse, BaseAPIModel
from common.exception import CommonException
from backend.db.entity import Namespace as NamespaceDB, ResourceDB
from backend.db import entity as db

router_namespaces = APIRouter(prefix="/namespaces", tags=["namespaces"])

Expand All @@ -25,11 +25,11 @@ class NamespaceResponse(BaseAPIModel):
@router_namespaces.get("", response_model=List[NamespaceResponse])
async def get_namespaces(
session: AsyncSession = Depends(get_session),
user: BaseUser = Depends(_get_user)
user: db.User = Depends(_get_user)
):
result = await session.execute(select(NamespaceDB).where(
NamespaceDB.deleted_at.is_(None),
or_(NamespaceDB.owner_id == user.user_id, NamespaceDB.collaborators.any(user.user_id))
result = await session.execute(select(db.Namespace).where(
db.Namespace.deleted_at.is_(None),
or_(db.Namespace.owner_id == user.user_id, db.Namespace.collaborators.any(user.user_id))
))
return result.scalars().all()

Expand All @@ -38,18 +38,18 @@ async def get_namespaces(
async def create_namespace(
namespace: NamespaceCreate,
session: AsyncSession = Depends(get_session),
user: BaseUser = Depends(_get_user)
user: db.User = Depends(_get_user)
):
new_namespace_db = NamespaceDB(owner_id=user.user_id, **namespace.model_dump())
new_namespace_db = db.Namespace(owner_id=user.user_id, **namespace.model_dump())
session.add(new_namespace_db)

await session.commit()
await session.refresh(new_namespace_db)

parameter = {"namespace_id": new_namespace_db.namespace_id, "user_id": user.user_id, "resource_type": "folder"}

teamspace_root = ResourceDB(space_type="teamspace", **parameter)
private_root = ResourceDB(space_type="private", **parameter)
teamspace_root = db.Resource(space_type="teamspace", **parameter)
private_root = db.Resource(space_type="private", **parameter)

session.add(teamspace_root)
session.add(private_root)
Expand All @@ -62,16 +62,16 @@ async def create_namespace(
async def delete_namespace(
namespace_id: str,
session: AsyncSession = Depends(get_session),
user: BaseUser = Depends(_get_user)
user: db.User = Depends(_get_user)
):
"""
Delete a namespace
"""
result = await session.execute(select(NamespaceDB).where(NamespaceDB.namespace_id == namespace_id))
result = await session.execute(select(db.Namespace).where(db.Namespace.namespace_id == namespace_id))
existing_namespace = result.scalar_one_or_none()

if not existing_namespace or existing_namespace.owner_id != user.user_id:
raise CommonException(code=status.HTTP_404_NOT_FOUND, error="Namespace not found")

await session.execute(delete(NamespaceDB).where(NamespaceDB.namespace_id == namespace_id))
await session.execute(delete(db.Namespace).where(db.Namespace.namespace_id == namespace_id))
await session.commit()
Loading