Skip to content

Commit

Permalink
update retrieval (#76)
Browse files Browse the repository at this point in the history
Update retrieval mode to accommodate multiple retrieved chunks.
  • Loading branch information
ccurme authored Mar 21, 2024
1 parent 63966c1 commit f400fa4
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 43 deletions.
48 changes: 24 additions & 24 deletions backend/server/extraction_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,29 +62,6 @@ class ExtractResponse(TypedDict):
data: List[Any]


def _deduplicate(
extract_responses: Sequence[ExtractResponse],
) -> ExtractResponse:
"""Deduplicate the results.
The deduplication is done by comparing the serialized JSON of each of the results
and only keeping the unique ones.
"""
unique_extracted = []
seen = set()
for response in extract_responses:
for data_item in response["data"]:
# Serialize the data item for comparison purposes
serialized = json.dumps(data_item, sort_keys=True)
if serialized not in seen:
seen.add(serialized)
unique_extracted.append(data_item)

return {
"data": unique_extracted,
}


def _cast_example_to_dict(example: Example) -> Dict[str, Any]:
"""Cast example record to dictionary."""
return {
Expand Down Expand Up @@ -147,6 +124,29 @@ def _make_prompt_template(
# PUBLIC API


def deduplicate(
extract_responses: Sequence[ExtractResponse],
) -> ExtractResponse:
"""Deduplicate the results.
The deduplication is done by comparing the serialized JSON of each of the results
and only keeping the unique ones.
"""
unique_extracted = []
seen = set()
for response in extract_responses:
for data_item in response["data"]:
# Serialize the data item for comparison purposes
serialized = json.dumps(data_item, sort_keys=True)
if serialized not in seen:
seen.add(serialized)
unique_extracted.append(data_item)

return {
"data": unique_extracted,
}


def get_examples_from_extractor(extractor: Extractor) -> List[Dict[str, Any]]:
"""Get examples from an extractor."""
return [_cast_example_to_dict(example) for example in extractor.examples]
Expand Down Expand Up @@ -206,4 +206,4 @@ async def extract_entire_document(
extraction_requests, {"max_concurrency": 1}
)
# Deduplicate the results
return _deduplicate(extract_responses)
return deduplicate(extract_responses)
22 changes: 8 additions & 14 deletions backend/server/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,22 @@

from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda
from langchain_openai import OpenAIEmbeddings

from db.models import Extractor
from server.extraction_runnable import (
ExtractRequest,
ExtractResponse,
deduplicate,
extraction_runnable,
get_examples_from_extractor,
)


def _get_top_doc_content(docs: List[Document]) -> str:
if docs:
return docs[0].page_content
else:
return ""


def _make_extract_request(input_dict: Dict[str, Any]) -> ExtractRequest:
return ExtractRequest(**input_dict)
def _make_extract_requests(input_dict: Dict[str, Any]) -> List[ExtractRequest]:
docs = input_dict.pop("text")
return [ExtractRequest(text=doc.page_content, **input_dict) for doc in docs]


async def extract_from_content(
Expand All @@ -50,14 +44,14 @@ async def extract_from_content(

runnable = (
{
"text": itemgetter("query") | retriever | _get_top_doc_content,
"text": itemgetter("query") | retriever,
"schema": itemgetter("schema"),
"instructions": lambda x: x.get("instructions"),
"examples": lambda x: x.get("examples"),
"model_name": lambda x: x.get("model_name"),
}
| RunnableLambda(_make_extract_request)
| extraction_runnable
| RunnableLambda(_make_extract_requests)
| extraction_runnable.abatch
)
schema = extractor.schema
examples = get_examples_from_extractor(extractor)
Expand All @@ -71,4 +65,4 @@ async def extract_from_content(
"model_name": model_name,
}
)
return result
return deduplicate(result)
10 changes: 5 additions & 5 deletions backend/tests/unit_tests/test_deduplication.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from server.extraction_runnable import ExtractResponse, _deduplicate
from server.extraction_runnable import ExtractResponse, deduplicate


async def test_deduplication_different_resutls() -> None:
"""Test deduplication of extraction results."""
result = _deduplicate(
result = deduplicate(
[
{"data": [{"name": "Chester", "age": 42}]},
{"data": [{"name": "Jane", "age": 42}]},
Expand All @@ -17,7 +17,7 @@ async def test_deduplication_different_resutls() -> None:
)
assert expected == result

result = _deduplicate(
result = deduplicate(
[
{
"data": [
Expand All @@ -44,11 +44,11 @@ async def test_deduplication_different_resutls() -> None:
assert expected == result

# Test with data being a list of strings
result = _deduplicate([{"data": ["1", "2"]}, {"data": ["1", "3"]}])
result = deduplicate([{"data": ["1", "2"]}, {"data": ["1", "3"]}])
expected = ExtractResponse(data=["1", "2", "3"])
assert expected == result

# Test with data being a mix of integer and string
result = _deduplicate([{"data": [1, "2"]}, {"data": ["1", "3"]}])
result = deduplicate([{"data": [1, "2"]}, {"data": ["1", "3"]}])
expected = ExtractResponse(data=[1, "2", "1", "3"])
assert expected == result

0 comments on commit f400fa4

Please sign in to comment.