Skip to content

Commit

Permalink
Neo4j Schema Query Builder Integration (run-llama#520)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: shahafpariente <shahaf.pariente@stargo.co>
  • Loading branch information
shahafp and shahafpariente authored Sep 24, 2023
1 parent 0e13be1 commit 8f24610
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 0 deletions.
9 changes: 9 additions & 0 deletions llama_hub/tools/library.json
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,14 @@
"ZapierToolSpec": {
"id": "tools/zapier",
"author": "ajhofmann"
},
"Neo4jQueryToolSpec": {
"id": "tools/neo4j_db",
"author": "shahafp",
"keywords": [
"graph",
"neo4j",
"cypher"
]
}
}
73 changes: 73 additions & 0 deletions llama_hub/tools/neo4j_db/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Neo4j Schema Query Builder

The `Neo4jQueryToolSpec` class provides a way to query a Neo4j graph database based on a provided schema definition. The class uses a language model to generate Cypher queries from user questions and has the capability to recover from Cypher syntax errors through a self-healing mechanism.

## Table of Contents

- [Usage](#usage)
- [Initialization](#initialization)
- [Running a Query](#running-a-query)
- [Features](#features)

## Usage

### Initialization

Initialize the `Neo4jQueryToolSpec` class with:

```python
from llama_hub.tools.neo4j_db.base import Neo4jQueryToolSpec
from llama_index.llms import OpenAI
from llama_index.agent import OpenAIAgent

llm = OpenAI(model="gpt-4",
openai_api_key="XXXX-XXXX",
temperature=0
)

gds_db = Neo4jQueryToolSpec(
url="neo4j-url",
user="neo4j-user",
password="neo4j=password",
llm=llm,
database='neo4j'
)

tools = gds_db.to_tool_list()
agent = OpenAIAgent.from_tools(tools, verbose=True)

```

Where:

- `url`: Connection string for the Neo4j database.
- `user`: Username for the Neo4j database.
- `password`: Password for the Neo4j database.
- `llm`: A language model for generating Cypher queries (any type of LLM).
- `database`: The database name.

### Running a Query

To use the agent:

```python
# use agent
agent.chat("Where is JFK airport is located?")
```

```
Generated Cypher:
MATCH (p:Port {port_code: 'JFK'})
RETURN p.location_name_wo_diacritics AS Location
Final answer:
'The port code JFK is located in New York, United States.'
```


## Features

- **Schema-Based Querying**: The class extracts the Neo4j database schema to guide the Cypher query generation.
- **Self-Healing**: On a Cypher syntax error, the class corrects itself to produce a valid query.
- **Language Model Integration**: Uses a language model for natural and accurate Cypher query generation.
1 change: 1 addition & 0 deletions llama_hub/tools/neo4j_db/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Init file."""
129 changes: 129 additions & 0 deletions llama_hub/tools/neo4j_db/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from llama_index.graph_stores import Neo4jGraphStore
from llama_index.llms.base import LLM, ChatMessage, MessageRole
from llama_index.tools.tool_spec.base import BaseToolSpec


class Neo4jQueryToolSpec(BaseToolSpec):
"""
This class is responsible for querying a Neo4j graph database based on a provided schema definition.
"""

spec_functions = ["run_request"]

def __init__(self, url, user, password, database, llm: LLM):
"""
Initializes the Neo4jSchemaWiseQuery object.
Args:
url (str): The connection string for the Neo4j database.
user (str): Username for the Neo4j database.
password (str): Password for the Neo4j database.
llm (obj): A language model for generating Cypher queries.
"""
try:
from neo4j import GraphDatabase

except ImportError:
raise ImportError(
"`neo4j` package not found, please run `pip install neo4j`"
)

self.graph_store = Neo4jGraphStore(url=url, username=user, password=password, database=database)
self.llm = llm

def get_system_message(self):
"""
Generates a system message detailing the task and schema.
Returns:
str: The system message.
"""
return f"""
Task: Generate Cypher queries to query a Neo4j graph database based on the provided schema definition.
Instructions:
Use only the provided relationship types and properties.
Do not use any other relationship types or properties that are not provided.
If you cannot generate a Cypher statement based on the provided schema, explain the reason to the user.
Schema:
{self.graph_store.schema}
Note: Do not include any explanations or apologies in your responses.
"""

def query_graph_db(self, neo4j_query, params=None):
"""
Queries the Neo4j database.
Args:
neo4j_query (str): The Cypher query to be executed.
params (dict, optional): Parameters for the Cypher query. Defaults to None.
Returns:
list: The query results.
"""
if params is None:
params = {}
with self.graph_store.client.session() as session:
result = session.run(neo4j_query, params)
output = [r.values() for r in result]
output.insert(0, list(result.keys()))
return output

def construct_cypher_query(self, question, history=None):
"""
Constructs a Cypher query based on a given question and history.
Args:
question (str): The question to construct the Cypher query for.
history (list, optional): A list of previous interactions for context. Defaults to None.
Returns:
str: The constructed Cypher query.
"""
messages = [
ChatMessage(role=MessageRole.SYSTEM, content=self.get_system_message()),
ChatMessage(role=MessageRole.USER, content=question),
]
# Used for Cypher healing flows
if history:
messages.extend(history)

completions = self.llm.chat(messages)
return completions.message.content

def run_request(self, question, history=None, retry=True):
"""
Executes a Cypher query based on a given question.
Args:
question (str): The question to execute the Cypher query for.
history (list, optional): A list of previous interactions for context. Defaults to None.
retry (bool, optional): Whether to retry in case of a syntax error. Defaults to True.
Returns:
list/str: The query results or an error message.
"""
from neo4j.exceptions import CypherSyntaxError

# Construct Cypher statement
cypher = self.construct_cypher_query(question, history)
print(cypher)
try:
return self.query_graph_db(cypher)
# Self-healing flow
except CypherSyntaxError as e:
# If out of retries
if not retry:
return "Invalid Cypher syntax"
# Self-healing Cypher flow by
# providing specific error to GPT-4
print("Retrying")
return self.run_request(
question,
[
ChatMessage(role=MessageRole.ASSISTANT, content=cypher),
ChatMessage(role=MessageRole.SYSTEM, conent=f"This query returns an error: {str(e)}\n"
"Give me a improved query that works without any explanations or apologies"),
],
retry=False
)
1 change: 1 addition & 0 deletions llama_hub/tools/neo4j_db/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
neo4j

0 comments on commit 8f24610

Please sign in to comment.