update retrieval #76

Merged
merged 2 commits on Mar 21, 2024
48 changes: 24 additions & 24 deletions backend/server/extraction_runnable.py
@@ -62,29 +62,6 @@ class ExtractResponse(TypedDict):
data: List[Any]


def _deduplicate(
extract_responses: Sequence[ExtractResponse],
) -> ExtractResponse:
"""Deduplicate the results.

The deduplication is done by comparing the serialized JSON of each of the results
and only keeping the unique ones.
"""
unique_extracted = []
seen = set()
for response in extract_responses:
for data_item in response["data"]:
# Serialize the data item for comparison purposes
serialized = json.dumps(data_item, sort_keys=True)
if serialized not in seen:
seen.add(serialized)
unique_extracted.append(data_item)

return {
"data": unique_extracted,
}


def _cast_example_to_dict(example: Example) -> Dict[str, Any]:
"""Cast example record to dictionary."""
return {
@@ -147,6 +124,29 @@ def _make_prompt_template(
# PUBLIC API


def deduplicate(
extract_responses: Sequence[ExtractResponse],
) -> ExtractResponse:
"""Deduplicate the results.

The deduplication is done by comparing the serialized JSON of each of the results
and only keeping the unique ones.
"""
unique_extracted = []
seen = set()
for response in extract_responses:
for data_item in response["data"]:
# Serialize the data item for comparison purposes
serialized = json.dumps(data_item, sort_keys=True)
if serialized not in seen:
seen.add(serialized)
unique_extracted.append(data_item)

return {
"data": unique_extracted,
}


def get_examples_from_extractor(extractor: Extractor) -> List[Dict[str, Any]]:
"""Get examples from an extractor."""
return [_cast_example_to_dict(example) for example in extractor.examples]
@@ -206,4 +206,4 @@ async def extract_entire_document(
extraction_requests, {"max_concurrency": 1}
)
# Deduplicate the results
return _deduplicate(extract_responses)
return deduplicate(extract_responses)
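
Note: the helper's body is unchanged; dropping the leading underscore moves deduplicate into the public API section so that retrieval.py can import it. A minimal usage sketch, with made-up records rather than data from the repo:

from server.extraction_runnable import deduplicate

# Two extraction responses that overlap on one record (key order differs on purpose).
responses = [
    {"data": [{"name": "Chester", "age": 42}]},
    {"data": [{"age": 42, "name": "Chester"}, {"name": "Jane", "age": 42}]},
]

# Each record is serialized with json.dumps(..., sort_keys=True), so key order is
# irrelevant; only the first occurrence of each serialization is kept, in input order.
merged = deduplicate(responses)
assert merged == {"data": [{"name": "Chester", "age": 42}, {"name": "Jane", "age": 42}]}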
22 changes: 8 additions & 14 deletions backend/server/retrieval.py
@@ -3,28 +3,22 @@

from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda
from langchain_openai import OpenAIEmbeddings

from db.models import Extractor
from server.extraction_runnable import (
ExtractRequest,
ExtractResponse,
deduplicate,
extraction_runnable,
get_examples_from_extractor,
)


def _get_top_doc_content(docs: List[Document]) -> str:
if docs:
return docs[0].page_content
else:
return ""


def _make_extract_request(input_dict: Dict[str, Any]) -> ExtractRequest:
return ExtractRequest(**input_dict)
def _make_extract_requests(input_dict: Dict[str, Any]) -> List[ExtractRequest]:
docs = input_dict.pop("text")
return [ExtractRequest(text=doc.page_content, **input_dict) for doc in docs]


async def extract_from_content(
@@ -50,14 +44,14 @@ async def extract_from_content(

runnable = (
{
"text": itemgetter("query") | retriever | _get_top_doc_content,
"text": itemgetter("query") | retriever,
"schema": itemgetter("schema"),
"instructions": lambda x: x.get("instructions"),
"examples": lambda x: x.get("examples"),
"model_name": lambda x: x.get("model_name"),
}
| RunnableLambda(_make_extract_request)
| extraction_runnable
| RunnableLambda(_make_extract_requests)
| extraction_runnable.abatch
)
schema = extractor.schema
examples = get_examples_from_extractor(extractor)
@@ -71,4 +65,4 @@ async def extract_from_content(
"model_name": model_name,
}
)
return result
return deduplicate(result)
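
Note: before this change the chain extracted from only the top retrieved document via _get_top_doc_content; now every retrieved document becomes its own ExtractRequest, the requests run through extraction_runnable.abatch, and the per-document responses are merged with deduplicate. A self-contained sketch of that fan-out pattern, using stand-in components (the fake retriever and extractor below are placeholders, not code from the repo):

import asyncio

from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda

# Stand-ins for the FAISS retriever and the extraction chain.
fake_retriever = RunnableLambda(
    lambda query: [Document(page_content="chunk one"), Document(page_content="chunk two")]
)
fake_extractor = RunnableLambda(lambda request: {"data": [request["text"]]})


def make_requests(input_dict):
    # Fan out: one request per retrieved document, mirroring _make_extract_requests.
    docs = input_dict.pop("text")
    return [{"text": doc.page_content, **input_dict} for doc in docs]


pipeline = (
    {"text": fake_retriever, "schema": lambda x: {}}
    | RunnableLambda(make_requests)
    | fake_extractor.abatch  # one extraction call per retrieved document
)

print(asyncio.run(pipeline.ainvoke("some query")))
# [{'data': ['chunk one']}, {'data': ['chunk two']}]; deduplicate() then merges these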
10 changes: 5 additions & 5 deletions backend/tests/unit_tests/test_deduplication.py
@@ -1,9 +1,9 @@
from server.extraction_runnable import ExtractResponse, _deduplicate
from server.extraction_runnable import ExtractResponse, deduplicate


async def test_deduplication_different_resutls() -> None:
"""Test deduplication of extraction results."""
result = _deduplicate(
result = deduplicate(
[
{"data": [{"name": "Chester", "age": 42}]},
{"data": [{"name": "Jane", "age": 42}]},
@@ -17,7 +17,7 @@ async def test_deduplication_different_resutls() -> None:
)
assert expected == result

result = _deduplicate(
result = deduplicate(
[
{
"data": [
@@ -44,11 +44,11 @@ async def test_deduplication_different_resutls() -> None:
assert expected == result

# Test with data being a list of strings
result = _deduplicate([{"data": ["1", "2"]}, {"data": ["1", "3"]}])
result = deduplicate([{"data": ["1", "2"]}, {"data": ["1", "3"]}])
expected = ExtractResponse(data=["1", "2", "3"])
assert expected == result

# Test with data being a mix of integer and string
result = _deduplicate([{"data": [1, "2"]}, {"data": ["1", "3"]}])
result = deduplicate([{"data": [1, "2"]}, {"data": ["1", "3"]}])
expected = ExtractResponse(data=[1, "2", "1", "3"])
assert expected == result