Skip to content

Commit

Permalink
use with_structured_output (#36)
Browse files Browse the repository at this point in the history
Here we pass JSON schema all the way through, relying on
`with_structured_output` to control extraction methodology.

#32
  • Loading branch information
ccurme authored Mar 18, 2024
1 parent 8cf26d7 commit bdfeea5
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 50 deletions.
15 changes: 5 additions & 10 deletions backend/extraction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ def _rm_titles(kv: dict) -> dict:
# PUBLIC API


def convert_json_schema_to_openai_schema(
def update_json_schema(
schema: dict,
*,
rm_titles: bool = True,
multi: bool = True,
) -> dict:
"""Convert JSON schema to a corresponding OpenAI function call."""
"""Add missing fields to JSON schema and add support for multiple records."""
if multi:
# Wrap the schema in an object called "Root" with a property called: "data"
# which will be a json array of the original schema.
Expand All @@ -43,10 +42,6 @@ def convert_json_schema_to_openai_schema(
else:
raise NotImplementedError("Only multi is supported for now.")

schema_.pop("definitions", None)

return {
"name": "extractor",
"description": "Extract information matching the given schema.",
"parameters": _rm_titles(schema_) if rm_titles else schema_,
}
schema_["title"] = "extractor"
schema_["description"] = "Extract information matching the given schema."
return schema_
17 changes: 7 additions & 10 deletions backend/server/extraction_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from fastapi import HTTPException
from jsonschema import Draft202012Validator, exceptions
from langchain.chains.openai_functions import create_openai_fn_runnable
from langchain.text_splitter import TokenTextSplitter
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
Expand All @@ -15,9 +14,7 @@
from typing_extensions import TypedDict

from db.models import Example, Extractor
from extraction.utils import (
convert_json_schema_to_openai_schema,
)
from extraction.utils import update_json_schema
from server.settings import CHUNK_SIZE, MODEL_NAME, get_model
from server.validators import validate_json_schema

Expand Down Expand Up @@ -161,22 +158,22 @@ def get_examples_from_extractor(extractor: Extractor) -> List[Dict[str, Any]]:
async def extraction_runnable(extraction_request: ExtractRequest) -> ExtractResponse:
"""An end point to extract content from a given text object."""
# TODO: Add validation for model context window size
schema = extraction_request.json_schema
schema = update_json_schema(extraction_request.json_schema)
try:
Draft202012Validator.check_schema(schema)
except exceptions.ValidationError as e:
raise HTTPException(status_code=422, detail=f"Invalid schema: {e.message}")

openai_function = convert_json_schema_to_openai_schema(schema)
function_name = openai_function["name"]
prompt = _make_prompt_template(
extraction_request.instructions,
extraction_request.examples,
function_name,
schema["title"],
)
runnable = create_openai_fn_runnable(
functions=[openai_function], llm=model, prompt=prompt
# N.B. method must be consistent with examples in _make_prompt_template
runnable = prompt | model.with_structured_output(
schema=schema, method="function_calling"
)

return await runnable.ainvoke({"text": extraction_request.text})


Expand Down
2 changes: 1 addition & 1 deletion backend/tests/integration_tests/test_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class Person(BaseModel):
json={
"input": {
"text": text,
"schema": Person(),
"schema": Person.schema(),
"instructions": "Redact all names using the characters `######`",
"examples": examples,
}
Expand Down
57 changes: 28 additions & 29 deletions backend/tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from langchain.pydantic_v1 import BaseModel, Field

from extraction.utils import (
convert_json_schema_to_openai_schema,
)
from extraction.utils import update_json_schema
from server.extraction_runnable import ExtractionExample, _make_prompt_template


def test_convert_json_schema_to_openai_schema() -> None:
"""Test converting a JSON schema to an OpenAI schema."""
def test_update_json_schema() -> None:
"""Test updating JSON schema."""

class Person(BaseModel):
name: str = Field(..., description="The name of the person.")
Expand All @@ -33,33 +31,34 @@ class Person(BaseModel):
"type": "object",
}

openai_schema = convert_json_schema_to_openai_schema(schema)
assert openai_schema == {
"description": "Extract information matching the given schema.",
"name": "extractor",
"parameters": {
"properties": {
"data": {
"items": {
"properties": {
"age": {
"description": "The age of the person.",
"type": "integer",
},
"name": {
"description": "The name of the person.",
"type": "string",
},
updated_schema = update_json_schema(schema)
assert updated_schema == {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"title": "Person",
"type": "object",
"properties": {
"name": {
"title": "Name",
"description": "The name of the person.",
"type": "string",
},
"age": {
"title": "Age",
"description": "The age of the person.",
"type": "integer",
},
"required": ["name", "age"],
"type": "object",
},
"type": "array",
}
},
"required": ["data"],
"type": "object",
"required": ["name", "age"],
},
}
},
"required": ["data"],
"title": "extractor",
"description": "Extract information matching the given schema.",
}


Expand Down

0 comments on commit bdfeea5

Please sign in to comment.