Skip to content

Commit

Permalink
refactor(kg): modify kg query response to prompt LLM to include query…
Browse files Browse the repository at this point in the history
… statement (#200)

* fix some bugs in reflexion agent

* fix an vectorstore issue

* introduce use_reflexion to kg RagAgent

* prompt LLM not to generate cypher marker

* fix test error

* change vectorstore collection names to apply previous fix of collection field max length issue

* phrasing, formatting, parameter explanations

* fix type hinting

* prompt LLM even if kg query result is empty

* prompt LLM to quote query in answer

* fix an error in test

* adjust wording

* quotation marks -> code block

---------

Co-authored-by: fengsh <shaohong.feng.78@gmail.com>
Co-authored-by: slobentanzer <sebastian.lobentanzer@gmail.com>
  • Loading branch information
3 people authored Aug 21, 2024
1 parent 1fae6e3 commit 2e5ed29
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 57 deletions.
71 changes: 42 additions & 29 deletions biochatter/database_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Callable
import json
from typing import Dict, List, Optional

from langchain.schema import Document
import neo4j_utils as nu
Expand Down Expand Up @@ -81,6 +82,44 @@ def _generate_query(self, query: str):
results = self.driver.query(query=query)
return query, results

def _build_response(
self,
results: List[Dict],
cypher_query: str,
results_num: Optional[int] = 3,
) -> List[Document]:
if len(results) == 0:
return [
Document(
page_content=(
"I didn't find any result in knowledge graph, "
f"but here is the query I used: {cypher_query}. "
"You can ask user to refine the question. "
"Note: please ensure to include the query in a code "
"block in your response so that the user can refine "
"their question effectively."
),
metadata={"cypher_query": cypher_query},
)
]

clipped_results = results[:results_num] if results_num > 0 else results
results_dump = json.dumps(clipped_results)

return [
Document(
page_content=(
"The results retrieved from knowledge graph are: "
f"{results_dump}. "
f"The query used is: {cypher_query}. "
"Note: please ensure to include the query in a code block "
"in your response so that the user can refine "
"their question effectively."
),
metadata={"cypher_query": cypher_query},
)
]

def get_query_results(self, query: str, k: int = 3) -> list[Document]:
"""
Generate a query using the prompt engine and return the results.
Expand Down Expand Up @@ -109,40 +148,14 @@ def get_query_results(self, query: str, k: int = 3) -> list[Document]:
else:
results = self.driver.query(query=cypher_query)

documents = []
# return first k results
# returned nodes can have any formatting, and can also be empty or fewer
# than k
if results is None or len(results) == 0 or results[0] is None:
return []
if len(results[0]) == 0:
return [
Document(
page_content = (
"I didn't find any result in knowledge graph, "
f"but here is the query I used: {cypher_query}. "
"You can ask user to refine the question, "
"but don't make up anything."
),
metadata={
"cypher_query": cypher_query,
},
)
]

for result in results[0]:
documents.append(
Document(
page_content=json.dumps(result),
metadata={
"cypher_query": cypher_query,
},
)
)
if len(documents) == k:
break

return documents
return self._build_response(
results=results[0], cypher_query=cypher_query, results_num=k
)

def get_description(self):
result = self.driver.query("MATCH (n:Schema_info) RETURN n LIMIT 1")
Expand Down
44 changes: 16 additions & 28 deletions test/test_database_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,14 @@ def test_get_query_results_with_reflexion():
result = db_agent.get_query_results("test_query", 3)

# Check if the result is as expected
expected_result = [
Document(
page_content='{"key": "value"}',
metadata={"cypher_query": "test_query"},
),
Document(
page_content='{"key": "value"}',
metadata={"cypher_query": "test_query"},
),
Document(
page_content='{"key": "value"}',
metadata={"cypher_query": "test_query"},
),
]
expected_result = db_agent._build_response(
[
{"key": "value"},
{"key": "value"},
{"key": "value"},
],
"test_query",
)
assert result == expected_result


Expand Down Expand Up @@ -98,18 +92,12 @@ def test_get_query_results_without_reflexion():
result = db_agent.get_query_results("test_query", 3)

# Check if the result is as expected
expected_result = [
Document(
page_content='{"key": "value"}',
metadata={"cypher_query": "test_query"},
),
Document(
page_content='{"key": "value"}',
metadata={"cypher_query": "test_query"},
),
Document(
page_content='{"key": "value"}',
metadata={"cypher_query": "test_query"},
),
]
expected_result = db_agent._build_response(
[
{"key": "value"},
{"key": "value"},
{"key": "value"},
],
"test_query",
)
assert result == expected_result

0 comments on commit 2e5ed29

Please sign in to comment.