Skip to content

Commit

Permalink
keep track of response over refine chunks (run-llama#1062)
Browse files Browse the repository at this point in the history
Co-authored-by: Jerry Liu <jerryjliu98@gmail.com>
  • Loading branch information
logan-markewich and jerryjliu authored Apr 5, 2023
1 parent a4d7144 commit 9205e1f
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 23 deletions.
11 changes: 10 additions & 1 deletion gpt_index/indices/response/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ def refine_response_single(
refine_template,
context_msg=cur_text_chunk,
)
refine_template = self.refine_template.partial_format(
query_str=query_str, existing_answer=response
)

self._log_prompt_and_response(
formatted_prompt, response, log_prefix="Refined"
)
Expand Down Expand Up @@ -244,8 +248,13 @@ def _get_response_compact(
) -> RESPONSE_TEXT_TYPE:
"""Get compact response."""
# use prompt helper to fix compact text_chunks under the prompt limitation
# TODO: This is a temporary fix - reason it's temporary is that
# the refine template does not account for size of previous answer.
text_qa_template = self.text_qa_template.partial_format(query_str=query_str)
refine_template = self.refine_template.partial_format(query_str=query_str)

max_prompt = self._service_context.prompt_helper.get_biggest_prompt(
[self.text_qa_template, self.refine_template]
[text_qa_template, refine_template]
)
with temp_set_attrs(
self._service_context.prompt_helper, use_chunk_size_limit=False
Expand Down
11 changes: 10 additions & 1 deletion tests/indices/knowledge_graph/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,16 @@ def test_query(
index = GPTKnowledgeGraphIndex.from_documents(documents)
response = index.query("foo")
# when include_text is True, the first node is the raw text
assert str(response) == "foo:(foo, is, bar)"
# the second node is the query
rel_initial_text = (
"The following are knowledge triplets "
"in the form of (subset, predicate, object):"
)
expected_response = (
"foo:(foo, is, bar):" + rel_initial_text + ":('foo', 'is', 'bar')"
)

assert str(response) == expected_response
assert response.extra_info is not None
assert response.extra_info["kg_rel_map"] == {
"foo": [("bar", "is")],
Expand Down
27 changes: 21 additions & 6 deletions tests/indices/list/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,13 @@ def test_query(

query_str = "What is?"
response = index.query(query_str, mode="default", **query_kwargs)
assert str(response) == ("What is?:Hello world.")
expected_answer = (
"What is?:Hello world.:"
"This is a test.:"
"This is another test.:"
"This is a test v2."
)
assert str(response) == expected_answer
node_info = (
response.source_nodes[0].node.node_info
if response.source_nodes[0].node.node_info
Expand Down Expand Up @@ -323,11 +329,14 @@ def test_query_with_keywords(
query_str = "What is?"
query_kwargs.update({"required_keywords": ["test"]})
response = index.query(query_str, mode="default", **query_kwargs)
assert str(response) == ("What is?:This is a test.")
expected_answer = (
"What is?:This is a test.:" "This is another test.:" "This is a test v2."
)
assert str(response) == expected_answer

query_kwargs.update({"exclude_keywords": ["Hello"]})
query_kwargs.update({"exclude_keywords": ["test"], "required_keywords": []})
response = index.query(query_str, mode="default", **query_kwargs)
assert str(response) == ("What is?:This is a test.")
assert str(response) == ("What is?:Hello world.")


@patch_common
Expand Down Expand Up @@ -455,7 +464,13 @@ def test_async_query(
query_str = "What is?"
task = index.aquery(query_str, mode="default", **query_kwargs)
response = asyncio.run(task)
assert str(response) == ("What is?:Hello world.")
expected_answer = (
"What is?:Hello world.:"
"This is a test.:"
"This is another test.:"
"This is a test v2."
)
assert str(response) == expected_answer
node_info = (
response.source_nodes[0].node.node_info
if response.source_nodes[0].node.node_info
Expand All @@ -470,7 +485,7 @@ def test_async_query(
query_kwargs_copy["response_mode"] = "tree_summarize"
task = index.aquery(query_str, mode="default", **query_kwargs_copy)
response = asyncio.run(task)
assert str(response) == ("What is?:Hello world.")
assert str(response) == expected_answer
node_info = (
response.source_nodes[0].node.node_info
if response.source_nodes[0].node.node_info
Expand Down
22 changes: 13 additions & 9 deletions tests/indices/query/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ def test_recursive_query_list_tree(
# query should first pick the left root node, then pick list1
# within list1, it should go through the first document and second document
response = graph.query(query_str, query_configs=query_configs)
assert str(response) == ("What is?:What is?:This is a test v2.")
assert str(response) == (
"What is?:What is?:This is a test v2.:This is another test."
)


@patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline)
Expand Down Expand Up @@ -197,7 +199,9 @@ def test_recursive_query_tree_list(
# query should first pick the left root node, then pick list1
# within list1, it should go through the first document and second document
response = graph.query(query_str, query_configs=query_configs)
assert str(response) == ("What is?:What is?:This is a test.")
assert str(response) == (
"What is?:What is?:This is a test.:What is?:This is a test v2."
)


@patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline)
Expand Down Expand Up @@ -231,18 +235,18 @@ def test_recursive_query_table_list(
assert isinstance(graph, ComposableGraph)
query_str = "World?"
response = graph.query(query_str, query_configs=query_configs)
assert str(response) == ("World?:World?:Hello world.")
assert str(response) == ("World?:World?:Hello world.:None")

query_str = "Test?"
response = graph.query(query_str, query_configs=query_configs)
assert str(response) == ("Test?:Test?:This is a test.")
assert str(response) == ("Test?:Test?:This is a test.:Test?:This is a test.")

# test serialize and then back
with TemporaryDirectory() as tmpdir:
graph.save_to_disk(str(Path(tmpdir) / "tmp.json"))
graph = ComposableGraph.load_from_disk(str(Path(tmpdir) / "tmp.json"))
response = graph.query(query_str, query_configs=query_configs)
assert str(response) == ("Test?:Test?:This is a test.")
assert str(response) == ("Test?:Test?:This is a test.:Test?:This is a test.")


@patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline)
Expand Down Expand Up @@ -284,21 +288,21 @@ def test_recursive_query_list_table(
assert isinstance(graph, ComposableGraph)
query_str = "Foo?"
response = graph.query(query_str, query_configs=query_configs)
assert str(response) == ("Foo?:Foo?:This is a test v2.")
assert str(response) == ("Foo?:Foo?:This is a test v2.:This is another test.")
query_str = "Orange?"
response = graph.query(query_str, query_configs=query_configs)
assert str(response) == ("Orange?:Orange?:This is a test.")
assert str(response) == ("Orange?:Orange?:This is a test.:Hello world.")
query_str = "Cat?"
response = graph.query(query_str, query_configs=query_configs)
assert str(response) == ("Cat?:Cat?:This is another test.")
assert str(response) == ("Cat?:Cat?:This is another test.:This is a test v2.")

# test serialize and then back
# use composable graph struct
with TemporaryDirectory() as tmpdir:
graph.save_to_disk(str(Path(tmpdir) / "tmp.json"))
graph = ComposableGraph.load_from_disk(str(Path(tmpdir) / "tmp.json"))
response = graph.query(query_str, query_configs=query_configs)
assert str(response) == ("Cat?:Cat?:This is another test.")
assert str(response) == ("Cat?:Cat?:This is another test.:This is a test v2.")


@patch.object(LLMChain, "predict", side_effect=mock_llmchain_predict)
Expand Down
5 changes: 4 additions & 1 deletion tests/indices/struct_store/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,10 @@ def test_sql_index_with_index_context(
sql_context_container = context_builder.build_context_container(
ignore_db_schema=True
)
assert context_response == "Context query?:table_name: test_table"
print(context_response)
assert (
context_response == "Context query?:table_name: test_table:test_table_context"
)
assert sql_context_container.context_str == context_response

index = GPTSQLStructStoreIndex.from_documents(
Expand Down
11 changes: 8 additions & 3 deletions tests/indices/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,14 @@ def test_give_response(
texts=[TextChunk(documents[0].get_text())],
)
response = builder.get_response(query_str)
assert str(response) == "What is?:Hello world."
expected_answer = (
"What is?:"
"Hello world.:"
"This is a test.:"
"This is another test.:"
"This is a test v2."
)
assert str(response) == expected_answer


@patch.object(LLMPredictor, "total_tokens_used", return_value=0)
Expand Down Expand Up @@ -106,8 +113,6 @@ def test_compact_response(
texts = [
TextChunk("This\n\nis\n\na\n\nbar"),
TextChunk("This\n\nis\n\na\n\ntest"),
TextChunk("This\n\nis\n\nanother\n\ntest"),
TextChunk("This\n\nis\n\na\n\nfoo"),
]

builder = ResponseBuilder(
Expand Down
5 changes: 4 additions & 1 deletion tests/indices/tree/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,10 @@ def test_summarize_query(
}
# TODO: fix unit test later
response = tree.query(query_str, mode="summarize", **query_kwargs)
assert str(response) == ("What is?:Hello world.")
print(str(response))
assert str(response) == (
"What is?:Hello world.:This is a test.:This is another test.:This is a test v2."
)

# test that default query fails
with pytest.raises(ValueError):
Expand Down
2 changes: 1 addition & 1 deletion tests/mock_utils/mock_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _mock_answer(prompt_args: Dict) -> str:

def _mock_refine(prompt_args: Dict) -> str:
"""Mock refine."""
return prompt_args["existing_answer"]
return prompt_args["existing_answer"] + ":" + prompt_args["context_msg"]


def _mock_keyword_extract(prompt_args: Dict) -> str:
Expand Down

0 comments on commit 9205e1f

Please sign in to comment.