Skip to content

Tests for new data, all tests passing #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,22 +1,4 @@
{
"pgsql.connections": [
{
"id": "92CC1089-BAD0-44A4-B071-A50A6EC12B67",
"groupId": "F3347CD6-9995-4EE9-8D98-D88DB010FA5B",
"authenticationType": "SqlLogin",
"connectTimeout": 15,
"applicationName": "vscode-pgsql",
"clientEncoding": "utf8",
"sslmode": "prefer",
"server": "localhost",
"user": "admin",
"password": "",
"savePassword": true,
"database": "postgres",
"profileName": "local-pg",
"expiresOn": 0
}
],
"python.testing.pytestArgs": [
"tests"
],
Expand Down
3 changes: 2 additions & 1 deletion convert_csv_json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
import csv
import json
from typing import Any

# Read CSV file - Using the correct dialect to handle quotes properly
with open("pittsburgh_restaurants.csv", encoding="utf-8") as csv_file:
Expand All @@ -15,7 +16,7 @@
item = {}
for i in range(len(header)):
if i < len(row): # Ensure we don't go out of bounds
value = row[i].strip()
value: Any = row[i].strip()
# Check if the value looks like a JSON array
if value.startswith("[") and value.endswith("]"):
try:
Expand Down
21 changes: 10 additions & 11 deletions evals/generate_ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,17 @@ def source_retriever() -> Generator[str, None, None]:
DATABASE_URI = f"postgresql://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}"
engine = create_engine(DATABASE_URI, echo=False)
with Session(engine) as session:
# Fetch all products for a particular type
item_types = session.scalars(select(Item.type).distinct())
for item_type in item_types:
records = list(session.scalars(select(Item).filter(Item.type == item_type).order_by(Item.id)))
logger.info(f"Processing database records for type: {item_type}")
yield "\n\n".join([f"## Product ID: [{record.id}]\n" + record.to_str_for_rag() for record in records])
# Fetch all products for a particular type - depends on the database columns
# item_types = session.scalars(select(Item.type).distinct())
# for item_type in item_types:
# records = list(session.scalars(select(Item).filter(Item.type == item_type).order_by(Item.id)))
# logger.info(f"Processing database records for type: {item_type}")
# yield "\n\n".join([f"## Product ID: [{record.id}]\n" + record.to_str_for_rag() for record in records])
# Fetch each item individually
# records = list(session.scalars(select(Item).order_by(Item.id)))
# for record in records:
# logger.info(f"Processing database record: {record.name}")
# yield f"## Product ID: [{record.id}]\n" + record.to_str_for_rag()
# await self.openai_chat_client.chat.completions.create(
records = list(session.scalars(select(Item).order_by(Item.id)))
for record in records:
logger.info(f"Processing database record: {record.name}")
yield f"## Product ID: [{record.id}]\n" + record.to_str_for_rag()


def source_to_text(source) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/backend/fastapi_app/api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ItemPublic(BaseModel):
id: str
name: str
cuisine: str
rating: int
rating: float
price_level: int
review_count: int
description: str
Expand Down
2 changes: 1 addition & 1 deletion src/backend/fastapi_app/postgres_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Item(Base):
id: Mapped[str] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column()
cuisine: Mapped[str] = mapped_column()
rating: Mapped[int] = mapped_column()
rating: Mapped[float] = mapped_column()
price_level: Mapped[int] = mapped_column()
review_count: Mapped[int] = mapped_column()
description: Mapped[str] = mapped_column()
Expand Down
4 changes: 2 additions & 2 deletions src/backend/fastapi_app/routes/api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def format_as_ndjson(r: AsyncGenerator[RetrievalResponseDelta, None]) -> A


@router.get("/items/{id}", response_model=ItemPublic)
async def item_handler(database_session: DBSession, id: int) -> ItemPublic:
async def item_handler(database_session: DBSession, id: str) -> ItemPublic:
"""A simple API to get an item by ID."""
item = (await database_session.scalars(select(Item).where(Item.id == id))).first()
if not item:
Expand All @@ -55,7 +55,7 @@ async def item_handler(database_session: DBSession, id: int) -> ItemPublic:

@router.get("/similar", response_model=list[ItemWithDistance])
async def similar_handler(
context: CommonDeps, database_session: DBSession, id: int, n: int = 5
context: CommonDeps, database_session: DBSession, id: str, n: int = 5
) -> list[ItemWithDistance]:
"""A similarity API to find items similar to items with given ID."""
item = (await database_session.scalars(select(Item).where(Item.id == id))).first()
Expand Down
Loading
Loading