Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use with_structured_output #36

Merged
merged 6 commits
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions backend/extraction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ def _rm_titles(kv: dict) -> dict:
# PUBLIC API


def convert_json_schema_to_openai_schema(
def update_json_schema(
schema: dict,
*,
rm_titles: bool = True,
multi: bool = True,
) -> dict:
"""Convert JSON schema to a corresponding OpenAI function call."""
"""Add missing fields to JSON schema and add support for multiple records."""
if multi:
# Wrap the schema in an object called "Root" with a property called: "data"
# which will be a json array of the original schema.
Expand All @@ -43,10 +42,6 @@ def convert_json_schema_to_openai_schema(
else:
raise NotImplementedError("Only multi is supported for now.")

schema_.pop("definitions", None)

return {
"name": "extractor",
"description": "Extract information matching the given schema.",
"parameters": _rm_titles(schema_) if rm_titles else schema_,
}
schema_["title"] = "extractor"
schema_["description"] = "Extract information matching the given schema."
return schema_
17 changes: 7 additions & 10 deletions backend/server/extraction_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from fastapi import HTTPException
from jsonschema import Draft202012Validator, exceptions
from langchain.chains.openai_functions import create_openai_fn_runnable
from langchain.text_splitter import TokenTextSplitter
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
Expand All @@ -15,9 +14,7 @@
from typing_extensions import TypedDict

from db.models import Example, Extractor
from extraction.utils import (
convert_json_schema_to_openai_schema,
)
from extraction.utils import update_json_schema
from server.settings import CHUNK_SIZE, MODEL_NAME, get_model
from server.validators import validate_json_schema

Expand Down Expand Up @@ -161,22 +158,22 @@ def get_examples_from_extractor(extractor: Extractor) -> List[Dict[str, Any]]:
async def extraction_runnable(extraction_request: ExtractRequest) -> ExtractResponse:
"""An end point to extract content from a given text object."""
# TODO: Add validation for model context window size
schema = extraction_request.json_schema
schema = update_json_schema(extraction_request.json_schema)
try:
Draft202012Validator.check_schema(schema)
except exceptions.ValidationError as e:
raise HTTPException(status_code=422, detail=f"Invalid schema: {e.message}")

openai_function = convert_json_schema_to_openai_schema(schema)
function_name = openai_function["name"]
prompt = _make_prompt_template(
extraction_request.instructions,
extraction_request.examples,
function_name,
schema["title"],
)
runnable = create_openai_fn_runnable(
functions=[openai_function], llm=model, prompt=prompt
# N.B. method must be consistent with examples in _make_prompt_template
runnable = prompt | model.with_structured_output(
schema=schema, method="function_calling"
)

return await runnable.ainvoke({"text": extraction_request.text})


Expand Down
2 changes: 1 addition & 1 deletion backend/tests/integration_tests/test_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class Person(BaseModel):
json={
"input": {
"text": text,
"schema": Person(),
"schema": Person.schema(),
"instructions": "Redact all names using the characters `######`",
"examples": examples,
}
Expand Down
57 changes: 28 additions & 29 deletions backend/tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from langchain.pydantic_v1 import BaseModel, Field

from extraction.utils import (
convert_json_schema_to_openai_schema,
)
from extraction.utils import update_json_schema
from server.extraction_runnable import ExtractionExample, _make_prompt_template


def test_convert_json_schema_to_openai_schema() -> None:
"""Test converting a JSON schema to an OpenAI schema."""
def test_update_json_schema() -> None:
"""Test updating JSON schema."""

class Person(BaseModel):
name: str = Field(..., description="The name of the person.")
Expand All @@ -33,33 +31,34 @@ class Person(BaseModel):
"type": "object",
}

openai_schema = convert_json_schema_to_openai_schema(schema)
assert openai_schema == {
"description": "Extract information matching the given schema.",
"name": "extractor",
"parameters": {
"properties": {
"data": {
"items": {
"properties": {
"age": {
"description": "The age of the person.",
"type": "integer",
},
"name": {
"description": "The name of the person.",
"type": "string",
},
updated_schema = update_json_schema(schema)
assert updated_schema == {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"title": "Person",
"type": "object",
"properties": {
"name": {
"title": "Name",
"description": "The name of the person.",
"type": "string",
},
"age": {
"title": "Age",
"description": "The age of the person.",
"type": "integer",
},
"required": ["name", "age"],
"type": "object",
},
"type": "array",
}
},
"required": ["data"],
"type": "object",
"required": ["name", "age"],
},
}
},
"required": ["data"],
"title": "extractor",
"description": "Extract information matching the given schema.",
}


Expand Down
Loading