Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add owner_id to extractors #44

Merged
merged 14 commits into from
Mar 20, 2024
18 changes: 18 additions & 0 deletions backend/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ class Extractor(TimestampedModel):
server_default="",
comment="The name of the extractor.",
)
owner_id = Column(
UUID(as_uuid=True),
nullable=False,
comment="Owner uuid.",
)
schema = Column(
JSONB,
nullable=False,
Expand Down Expand Up @@ -168,3 +173,16 @@ class Extractor(TimestampedModel):

def __repr__(self) -> str:
return f"<Extractor(id={self.uuid}, description={self.description})>"


def validate_extractor_owner(
    session: Session, extractor_id: UUID, owner_id: UUID
) -> bool:
    """Check whether the given extractor exists and belongs to the given owner.

    Args:
        session: Database session used to run the lookup.
        extractor_id: UUID of the extractor to look up.
        owner_id: UUID that the extractor's ``owner_id`` column must match.

    Returns:
        True if a row matching both ``extractor_id`` and ``owner_id`` exists,
        False otherwise.
    """
    # NOTE(review): a reviewer suggested .scalar() or .one() instead of
    # .first() when a unique constraint guarantees at most one match. .first()
    # is kept here so the lookup never raises on duplicates -- confirm the
    # constraint before tightening it.
    extractor = (
        session.query(Extractor).filter_by(uuid=extractor_id, owner_id=owner_id).first()
    )
    return extractor is not None
20 changes: 16 additions & 4 deletions backend/server/api/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import Any, List
from uuid import UUID

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Cookie, Depends, HTTPException
from sqlalchemy.orm import Session
from typing_extensions import Annotated, TypedDict

from db.models import Example, get_session
from db.models import Example, get_session, validate_extractor_owner

router = APIRouter(
prefix="/examples",
Expand Down Expand Up @@ -36,8 +36,12 @@ def create(
create_request: CreateExample,
*,
session: Session = Depends(get_session),
owner_id: UUID = Cookie(...),
) -> CreateExampleResponse:
"""Endpoint to create an example."""
if not validate_extractor_owner(session, create_request["extractor_id"], owner_id):
raise HTTPException(status_code=404, detail="Extractor not found for owner.")

instance = Example(
extractor_id=create_request["extractor_id"],
content=create_request["content"],
Expand All @@ -55,8 +59,11 @@ def list(
limit: int = 10,
offset: int = 0,
session=Depends(get_session),
owner_id: UUID = Cookie(...),
) -> List[Any]:
"""Endpoint to get all examples."""
if not validate_extractor_owner(session, extractor_id, owner_id):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works

raise HTTPException(status_code=404, detail="Extractor not found for owner.")
return (
session.query(Example)
.filter(Example.extractor_id == extractor_id)
Expand All @@ -68,7 +75,12 @@ def list(


@router.delete("/{uuid}")
def delete(
    uuid: UUID, *, session: Session = Depends(get_session), owner_id: UUID = Cookie(...)
) -> None:
    """Endpoint to delete an example.

    The example is only deleted when the extractor it is attached to belongs
    to the caller, identified by the ``owner_id`` cookie.

    Args:
        uuid: UUID of the example to delete.
        session: Database session.
        owner_id: Owner UUID read from the ``owner_id`` cookie.

    Raises:
        HTTPException: 404 if no example has this UUID, or if the example's
            extractor is not owned by ``owner_id``.
    """
    example = session.query(Example).filter_by(uuid=str(uuid)).first()
    if example is None:
        # Without this guard, reading ``.extractor_id`` on a missing example
        # raises AttributeError and surfaces as a 500.
        raise HTTPException(status_code=404, detail="Example not found.")
    if not validate_extractor_owner(session, example.extractor_id, owner_id):
        raise HTTPException(status_code=404, detail="Extractor not found for owner.")
    session.query(Example).filter_by(uuid=str(uuid)).delete()
    session.commit()
14 changes: 10 additions & 4 deletions backend/server/api/extract.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Literal, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from fastapi import APIRouter, Cookie, Depends, File, Form, HTTPException, UploadFile
from sqlalchemy.orm import Session
from typing_extensions import Annotated

from db.models import Extractor, get_session
from extraction.parsing import parse_binary_input
from server.extraction_runnable import ExtractResponse, extract_entire_document
from server.models import DEFAULT_MODEL
from server.retrieval import extract_from_content

router = APIRouter(
Expand All @@ -24,8 +25,9 @@ async def extract_using_existing_extractor(
text: Optional[str] = Form(None),
mode: Literal["entire_document", "retrieval"] = Form("entire_document"),
file: Optional[UploadFile] = File(None),
model_name: Optional[str] = Form("default"),
model_name: Optional[str] = Form(DEFAULT_MODEL),
session: Session = Depends(get_session),
owner_id: UUID = Cookie(...),
) -> ExtractResponse:
"""Endpoint that is used with an existing extractor.

Expand All @@ -35,9 +37,13 @@ async def extract_using_existing_extractor(
if text is None and file is None:
raise HTTPException(status_code=422, detail="No text or file provided.")

extractor = session.query(Extractor).filter(Extractor.uuid == extractor_id).scalar()
extractor = (
session.query(Extractor)
.filter_by(uuid=extractor_id, owner_id=owner_id)
.scalar()
)
if extractor is None:
raise HTTPException(status_code=404, detail="Extractor not found.")
raise HTTPException(status_code=404, detail="Extractor not found for owner.")

if text:
text_ = text
Expand Down
38 changes: 29 additions & 9 deletions backend/server/api/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import Any, Dict, List
from uuid import UUID, uuid4

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Cookie, Depends, HTTPException
from pydantic import BaseModel, Field, validator
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from db.models import Extractor, SharedExtractors, get_session
from db.models import Extractor, SharedExtractors, get_session, validate_extractor_owner
from server.validators import validate_json_schema

router = APIRouter(
Expand Down Expand Up @@ -60,6 +60,7 @@ def share(
uuid: UUID,
*,
session: Session = Depends(get_session),
owner_id: UUID = Cookie(...),
) -> ShareExtractorResponse:
"""Endpoint to share an extractor.

Expand All @@ -73,6 +74,8 @@ def share(
Returns:
The UUID for the shared extractor.
"""
if not validate_extractor_owner(session, uuid, owner_id):
raise HTTPException(status_code=404, detail="Extractor not found for owner.")
# Check if the extractor is already shared
shared_extractor = (
session.query(SharedExtractors)
Expand Down Expand Up @@ -104,12 +107,16 @@ def share(

@router.post("")
def create(
create_request: CreateExtractor, *, session: Session = Depends(get_session)
create_request: CreateExtractor,
*,
session: Session = Depends(get_session),
owner_id: UUID = Cookie(...),
) -> CreateExtractorResponse:
"""Endpoint to create an extractor."""

instance = Extractor(
name=create_request.name,
owner_id=owner_id,
schema=create_request.json_schema,
description=create_request.description,
instruction=create_request.instruction,
Expand All @@ -120,11 +127,15 @@ def create(


@router.get("/{uuid}")
def get(uuid: UUID, *, session: Session = Depends(get_session)) -> Dict[str, Any]:
def get(
uuid: UUID, *, session: Session = Depends(get_session), owner_id: UUID = Cookie(...)
) -> Dict[str, Any]:
"""Endpoint to get an extractor."""
extractor = session.query(Extractor).filter(Extractor.uuid == str(uuid)).scalar()
extractor = (
session.query(Extractor).filter_by(uuid=str(uuid), owner_id=owner_id).scalar()
)
if extractor is None:
raise HTTPException(status_code=404, detail="Extractor not found.")
raise HTTPException(status_code=404, detail="Extractor not found for owner.")
return {
"uuid": extractor.uuid,
"name": extractor.name,
Expand All @@ -140,13 +151,22 @@ def list(
limit: int = 10,
offset: int = 0,
session=Depends(get_session),
owner_id: UUID = Cookie(...),
) -> List[Any]:
"""Endpoint to get all extractors."""
return session.query(Extractor).limit(limit).offset(offset).all()
return (
session.query(Extractor)
.filter_by(owner_id=owner_id)
.limit(limit)
.offset(offset)
.all()
)


@router.delete("/{uuid}")
def delete(
    uuid: UUID, *, session: Session = Depends(get_session), owner_id: UUID = Cookie(...)
) -> None:
    """Delete the extractor with this UUID, provided it belongs to the caller.

    The query is filtered on both the extractor UUID and ``owner_id`` (read
    from the ``owner_id`` cookie), so a request carrying a different owner's
    cookie matches no rows and removes nothing.
    """
    owned_match = session.query(Extractor).filter_by(
        uuid=str(uuid), owner_id=owner_id
    )
    owned_match.delete()
    session.commit()
4 changes: 2 additions & 2 deletions backend/server/extraction_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from db.models import Example, Extractor
from extraction.utils import update_json_schema
from server.models import get_chunk_size, get_model
from server.models import DEFAULT_MODEL, get_chunk_size, get_model
from server.validators import validate_json_schema


Expand Down Expand Up @@ -188,7 +188,7 @@ async def extract_entire_document(
text_splitter = TokenTextSplitter(
chunk_size=get_chunk_size(model_name),
chunk_overlap=20,
model_name=model_name,
model_name=DEFAULT_MODEL,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated bug. For now, we'll use the GPT-3.5 tokenizer everywhere.

)
texts = text_splitter.split_text(content)
extraction_requests = [
Expand Down
80 changes: 61 additions & 19 deletions backend/tests/unit_tests/api/test_api_defining_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ async def test_extractors_api() -> None:
"""This will test a few of the extractors API endpoints."""
# First verify that the database is empty
async with get_async_client() as client:
response = await client.get("/extractors")
owner_id = str(uuid.uuid4())
cookies = {"owner_id": owner_id}
response = await client.get("/extractors", cookies=cookies)
assert response.status_code == 200
assert response.json() == []

Expand All @@ -18,22 +20,39 @@ async def test_extractors_api() -> None:
"schema": {"type": "object"},
"instruction": "Test Instruction",
}
response = await client.post("/extractors", json=create_request)
response = await client.post(
"/extractors", json=create_request, cookies=cookies
)
assert response.status_code == 200

# Verify that the extractor was created
response = await client.get("/extractors")
response = await client.get("/extractors", cookies=cookies)
assert response.status_code == 200
get_response = response.json()
assert len(get_response) == 1

# Check cookies
bad_cookies = {"owner_id": str(uuid.uuid4())}
bad_response = await client.get("/extractors", cookies=bad_cookies)
assert bad_response.status_code == 200
assert len(bad_response.json()) == 0

# Check we need cookie to delete
uuid_str = get_response[0]["uuid"]
_ = uuid.UUID(uuid_str) # assert valid uuid
bad_response = await client.delete(
f"/extractors/{uuid_str}", cookies=bad_cookies
)
# Check extractor was not deleted
response = await client.get("/extractors", cookies=cookies)
assert len(response.json()) == 1

# Verify that we can delete an extractor
get_response = response.json()
uuid_str = get_response[0]["uuid"]
_ = uuid.UUID(uuid_str) # assert valid uuid
response = await client.delete(f"/extractors/{uuid_str}")
response = await client.delete(f"/extractors/{uuid_str}", cookies=cookies)
assert response.status_code == 200

get_response = await client.get("/extractors")
get_response = await client.get("/extractors", cookies=cookies)
assert get_response.status_code == 200
assert get_response.json() == []

Expand All @@ -43,39 +62,53 @@ async def test_extractors_api() -> None:
"schema": {"type": "object"},
"instruction": "Test Instruction",
}
response = await client.post("/extractors", json=create_request)
response = await client.post(
"/extractors", json=create_request, cookies=cookies
)
assert response.status_code == 200

# Verify that the extractor was created
response = await client.get("/extractors")
response = await client.get("/extractors", cookies=cookies)
assert response.status_code == 200
assert len(response.json()) == 1

# Verify that we can delete an extractor
get_response = response.json()
uuid_str = get_response[0]["uuid"]
_ = uuid.UUID(uuid_str) # assert valid uuid
response = await client.delete(f"/extractors/{uuid_str}")
response = await client.delete(f"/extractors/{uuid_str}", cookies=cookies)
assert response.status_code == 200

get_response = await client.get("/extractors")
get_response = await client.get("/extractors", cookies=cookies)
assert get_response.status_code == 200
assert get_response.json() == []

# Verify that we can create an extractor
# Verify that we can create an extractor, including other properties
owner_id = str(uuid.uuid4())
create_request = {
"name": "my extractor",
"description": "Test Description",
"schema": {"type": "object"},
"instruction": "Test Instruction",
}
response = await client.post("/extractors", json=create_request)
response = await client.post(
"/extractors", json=create_request, cookies=cookies
)
extractor_uuid = response.json()["uuid"]
assert response.status_code == 200
response = await client.get(f"/extractors/{extractor_uuid}", cookies=cookies)
response_data = response.json()
assert extractor_uuid == response_data["uuid"]
assert "my extractor" == response_data["name"]
assert "owner_id" not in response_data


async def test_sharing_extractor() -> None:
"""Test sharing an extractor."""
async with get_async_client() as client:
response = await client.get("/extractors")
owner_id = str(uuid.uuid4())
cookies = {"owner_id": owner_id}
response = await client.get("/extractors", cookies=cookies)
assert response.status_code == 200
assert response.json() == []
# Verify that we can create an extractor
Expand All @@ -85,23 +118,32 @@ async def test_sharing_extractor() -> None:
"schema": {"type": "object"},
"instruction": "Test Instruction",
}
response = await client.post("/extractors", json=create_request)
response = await client.post(
"/extractors", json=create_request, cookies=cookies
)
assert response.status_code == 200

uuid = response.json()["uuid"]
uuid_str = response.json()["uuid"]

# Verify that the extractor was created
response = await client.post(f"/extractors/{uuid}/share")
# Generate a share uuid
response = await client.post(f"/extractors/{uuid_str}/share", cookies=cookies)
assert response.status_code == 200
assert "share_uuid" in response.json()
share_uuid = response.json()["share_uuid"]

# Test idempotency
response = await client.post(f"/extractors/{uuid}/share")
response = await client.post(f"/extractors/{uuid_str}/share", cookies=cookies)
assert response.status_code == 200
assert "share_uuid" in response.json()
assert response.json()["share_uuid"] == share_uuid

# Check cookies
bad_cookies = {"owner_id": str(uuid.uuid4())}
response = await client.post(
f"/extractors/{uuid_str}/share", cookies=bad_cookies
)
assert response.status_code == 404

# Check that we can retrieve the shared extractor
response = await client.get(f"/s/{share_uuid}")
assert response.status_code == 200
Expand Down
Loading