forked from run-llama/llama-hub
Neo4j Schema Query Builder Integration (run-llama#520)
--------- Co-authored-by: shahafpariente <shahaf.pariente@stargo.co>
Showing 5 changed files with 213 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Neo4j Schema Query Builder | ||
|
||
The `Neo4jQueryToolSpec` class queries a Neo4j graph database using the database's schema as context. It uses a language model to generate Cypher queries from user questions, and it can recover from a Cypher syntax error by sending the error back to the model and retrying once. | ||
|
||
## Table of Contents | ||
|
||
- [Usage](#usage) | ||
- [Initialization](#initialization) | ||
- [Running a Query](#running-a-query) | ||
- [Features](#features) | ||
|
||
## Usage | ||
|
||
### Initialization | ||
|
||
Initialize the `Neo4jQueryToolSpec` class with: | ||
|
||
```python | ||
from llama_hub.tools.neo4j_db.base import Neo4jQueryToolSpec | ||
from llama_index.llms import OpenAI | ||
from llama_index.agent import OpenAIAgent | ||
|
||
llm = OpenAI(model="gpt-4", | ||
openai_api_key="XXXX-XXXX", | ||
temperature=0 | ||
) | ||
|
||
gds_db = Neo4jQueryToolSpec( | ||
url="neo4j-url", | ||
user="neo4j-user", | ||
password="neo4j-password", | ||
llm=llm, | ||
database='neo4j' | ||
) | ||
|
||
tools = gds_db.to_tool_list() | ||
agent = OpenAIAgent.from_tools(tools, verbose=True) | ||
|
||
``` | ||
|
||
Where: | ||
|
||
- `url`: Connection string for the Neo4j database. | ||
- `user`: Username for the Neo4j database. | ||
- `password`: Password for the Neo4j database. | ||
- `llm`: A language model for generating Cypher queries (any type of LLM). | ||
- `database`: The database name. | ||
|
||
### Running a Query | ||
|
||
To use the agent: | ||
|
||
```python | ||
# use agent | ||
agent.chat("Where is JFK airport located?") | ||
``` | ||
|
||
``` | ||
Generated Cypher: | ||
MATCH (p:Port {port_code: 'JFK'}) | ||
RETURN p.location_name_wo_diacritics AS Location | ||
Final answer: | ||
'The port code JFK is located in New York, United States.' | ||
``` | ||
|
||
|
||
## Features | ||
|
||
- **Schema-Based Querying**: The class extracts the Neo4j database schema to guide the Cypher query generation. | ||
- **Self-Healing**: On a Cypher syntax error, the class sends the failed query and the error message back to the language model and retries once with the regenerated query. | ||
- **Language Model Integration**: Uses a language model to translate natural-language questions into Cypher queries. |
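The self-healing behaviour listed under Features can be illustrated without a database or a model. The following is a minimal sketch, not the class's actual code: `run_with_healing`, `fake_generate`, `fake_execute`, and `FakeSyntaxError` are hypothetical names invented for this example, and the stubs stand in for the language model and the Neo4j driver. It assumes a single retry, as `run_request` does.

```python
from typing import Callable, List, Optional, Tuple

# (role, content) pairs; a simplified stand-in for chat messages.
History = List[Tuple[str, str]]


class FakeSyntaxError(Exception):
    """Hypothetical stand-in for neo4j.exceptions.CypherSyntaxError."""


def run_with_healing(
    question: str,
    generate: Callable[[str, Optional[History]], str],
    execute: Callable[[str], list],
    history: Optional[History] = None,
    retry: bool = True,
):
    """Generate a query, run it, and retry once with error feedback."""
    cypher = generate(question, history)
    try:
        return execute(cypher)
    except FakeSyntaxError as exc:
        if not retry:
            # Out of retries: report failure instead of raising.
            return "Invalid Cypher syntax"
        # Feed the failed query and the error back to the generator.
        feedback = [
            ("assistant", cypher),
            ("system", f"This query returns an error: {exc}"),
        ]
        return run_with_healing(question, generate, execute, feedback, retry=False)


def fake_generate(question: str, history: Optional[History]) -> str:
    """Stub model: the first attempt is malformed, the second is valid."""
    if history is None:
        return "MATCH (p:Port {port_code: 'JFK'} RETURN p"
    return "MATCH (p:Port {port_code: 'JFK'}) RETURN p"


def fake_execute(cypher: str) -> list:
    """Stub database: rejects queries with unbalanced parentheses."""
    if cypher.count("(") != cypher.count(")"):
        raise FakeSyntaxError("unbalanced parentheses")
    return [["p"], ["JFK"]]  # made-up result rows for illustration


result = run_with_healing("Where is JFK airport located?", fake_generate, fake_execute)
print(result)
```

Because the retry call passes `retry=False`, a second failure ends the loop with an error string rather than another round trip to the model.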
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Init file.""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
from llama_index.graph_stores import Neo4jGraphStore | ||
from llama_index.llms.base import LLM, ChatMessage, MessageRole | ||
from llama_index.tools.tool_spec.base import BaseToolSpec | ||
|
||
|
||
class Neo4jQueryToolSpec(BaseToolSpec): | ||
""" | ||
This class is responsible for querying a Neo4j graph database based on a provided schema definition. | ||
""" | ||
|
||
spec_functions = ["run_request"] | ||
|
||
def __init__(self, url, user, password, database, llm: LLM): | ||
""" | ||
Initializes the Neo4jQueryToolSpec object. | ||
Args: | ||
url (str): The connection string for the Neo4j database. | ||
user (str): Username for the Neo4j database. | ||
password (str): Password for the Neo4j database. | ||
database (str): Name of the Neo4j database to query. | ||
llm (obj): A language model for generating Cypher queries. | ||
""" | ||
try: | ||
from neo4j import GraphDatabase  # noqa: F401 (imported only to verify the driver is installed) | ||
|
||
except ImportError: | ||
raise ImportError( | ||
"`neo4j` package not found, please run `pip install neo4j`" | ||
) | ||
|
||
self.graph_store = Neo4jGraphStore(url=url, username=user, password=password, database=database) | ||
self.llm = llm | ||
|
||
def get_system_message(self): | ||
""" | ||
Generates a system message detailing the task and schema. | ||
Returns: | ||
str: The system message. | ||
""" | ||
return f""" | ||
Task: Generate Cypher queries to query a Neo4j graph database based on the provided schema definition. | ||
Instructions: | ||
Use only the provided relationship types and properties. | ||
Do not use any other relationship types or properties that are not provided. | ||
If you cannot generate a Cypher statement based on the provided schema, explain the reason to the user. | ||
Schema: | ||
{self.graph_store.schema} | ||
Note: Do not include any explanations or apologies in your responses. | ||
""" | ||
|
||
def query_graph_db(self, neo4j_query, params=None): | ||
""" | ||
Queries the Neo4j database. | ||
Args: | ||
neo4j_query (str): The Cypher query to be executed. | ||
params (dict, optional): Parameters for the Cypher query. Defaults to None. | ||
Returns: | ||
list: The query results. | ||
""" | ||
if params is None: | ||
params = {} | ||
with self.graph_store.client.session() as session: | ||
result = session.run(neo4j_query, params) | ||
output = [r.values() for r in result] | ||
output.insert(0, list(result.keys())) | ||
return output | ||
|
||
def construct_cypher_query(self, question, history=None): | ||
""" | ||
Constructs a Cypher query based on a given question and history. | ||
Args: | ||
question (str): The question to construct the Cypher query for. | ||
history (list, optional): A list of previous interactions for context. Defaults to None. | ||
Returns: | ||
str: The constructed Cypher query. | ||
""" | ||
messages = [ | ||
ChatMessage(role=MessageRole.SYSTEM, content=self.get_system_message()), | ||
ChatMessage(role=MessageRole.USER, content=question), | ||
] | ||
# Used for Cypher healing flows | ||
if history: | ||
messages.extend(history) | ||
|
||
completions = self.llm.chat(messages) | ||
return completions.message.content | ||
|
||
def run_request(self, question, history=None, retry=True): | ||
""" | ||
Executes a Cypher query based on a given question. | ||
Args: | ||
question (str): The question to execute the Cypher query for. | ||
history (list, optional): A list of previous interactions for context. Defaults to None. | ||
retry (bool, optional): Whether to retry in case of a syntax error. Defaults to True. | ||
Returns: | ||
list/str: The query results or an error message. | ||
""" | ||
from neo4j.exceptions import CypherSyntaxError | ||
|
||
# Construct Cypher statement | ||
cypher = self.construct_cypher_query(question, history) | ||
print(cypher) | ||
try: | ||
return self.query_graph_db(cypher) | ||
# Self-healing flow | ||
except CypherSyntaxError as e: | ||
# If out of retries | ||
if not retry: | ||
return "Invalid Cypher syntax" | ||
# Self-healing Cypher flow: pass the failed query | ||
# and the specific error back to the LLM | ||
print("Retrying") | ||
return self.run_request( | ||
question, | ||
[ | ||
ChatMessage(role=MessageRole.ASSISTANT, content=cypher), | ||
ChatMessage(role=MessageRole.SYSTEM, content=f"This query returns an error: {str(e)}\n" | ||
"Give me an improved query that works without any explanations or apologies"), | ||
], | ||
retry=False | ||
) |
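`query_graph_db` returns a list whose first element is the list of column names and whose remaining elements are the value rows. A caller that prefers one dictionary per record could reshape that output as sketched below. `rows_to_dicts` is a hypothetical helper, not part of this commit, and the sample data is made up for illustration.

```python
from typing import Any, Dict, List


def rows_to_dicts(output: List[List[Any]]) -> List[Dict[str, Any]]:
    """Convert [header, row1, row2, ...] into a list of {column: value} dicts."""
    if not output:
        return []
    header, *rows = output
    return [dict(zip(header, row)) for row in rows]


# Made-up sample in the shape produced by query_graph_db:
# column names first, then one list of values per record.
sample = [["Location"], ["New York"]]
records = rows_to_dicts(sample)
print(records)
```

Keeping the header as the first row is compact for passing results to a language model; the dictionary form is usually easier to consume in application code.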
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
neo4j |