Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/neo4j_graphrag/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ class ToolParameter(BaseModel):
def model_dump_tool(self) -> Dict[str, Any]:
"""Convert the parameter to a dictionary format for tool usage."""
result: Dict[str, Any] = {"type": self.type, "description": self.description}
if self.required:
result["required"] = True
return result

@classmethod
Expand Down Expand Up @@ -183,8 +181,8 @@ def model_dump_tool(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]
if self.required_properties and "required" not in exclude:
result["required"] = self.required_properties

if not self.additional_properties and "additional_properties" not in exclude:
result["additionalProperties"] = False
if "additional_properties" not in exclude:
result["additionalProperties"] = self.additional_properties

return result

Expand Down
103 changes: 103 additions & 0 deletions tests/unit/retrievers/test_retriever_parameter_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Tests for retriever parameter inference and convert_to_tool functionality.
"""

import pytest
from unittest.mock import MagicMock, patch
from typing import Optional, Any, Dict

Expand Down Expand Up @@ -468,3 +469,105 @@ def get_search_results(self, test_param: str) -> RawSearchResult:

# Should use fallback description
assert properties["test_param"].description == "Parameter test_param"


class TestOpenAICompatibilityFix:
"""Test the specific fixes for OpenAI API compatibility."""

@patch("neo4j_graphrag.retrievers.base.get_version")
def test_text2cypher_retriever_openai_schema_compatibility(self, mock_get_version):
"""Test that Text2CypherRetriever generates OpenAI-compatible schema.

This test specifically covers the bug that was causing:
'Invalid schema for function 't2c_retriever': True is not of type 'array''
"""
mock_get_version.return_value = ((5, 20, 0), False, False)

driver = create_mock_driver()
llm = create_mock_llm()
retriever = Text2CypherRetriever(
driver=driver, llm=llm, neo4j_schema="(Person)-[:KNOWS]->(Person)"
)

# Convert to tool (this is where the original bug occurred)
tool = retriever.convert_to_tool(
name="t2c_retriever",
description="Use this tool when no other tool can help. It will directly try to build a Cypher query to query the graph.",
)

# Get the tool parameters schema
schema = tool.get_parameters()

# Verify JSON Schema structure is correct for OpenAI
assert schema["type"] == "object"
assert "properties" in schema
assert "required" in schema
assert "additionalProperties" in schema

# Check that required is an array, not a boolean
assert isinstance(schema["required"], list)
assert "query_text" in schema["required"]

# Check individual properties don't have 'required' field
for prop_name, prop_schema in schema["properties"].items():
assert (
"required" not in prop_schema
), f"Property {prop_name} should not have 'required' field"

# Check the specific property that was causing issues
prompt_params_schema = schema["properties"]["prompt_params"]
assert prompt_params_schema["type"] == "object"
assert "additionalProperties" in prompt_params_schema
assert prompt_params_schema["additionalProperties"] is True

# Ensure the schema is valid JSON Schema format
import json

try:
# This should not raise any exceptions
json_str = json.dumps(schema)
parsed = json.loads(json_str)
assert parsed == schema
except (TypeError, ValueError) as e:
pytest.fail(f"Schema is not JSON serializable: {e}")

@patch("neo4j_graphrag.retrievers.base.get_version")
def test_tools_retriever_with_t2c_tool_integration(self, mock_get_version):
"""Integration test showing the full ToolsRetriever + Text2CypherRetriever workflow."""
mock_get_version.return_value = ((5, 20, 0), False, False)

driver = create_mock_driver()
llm = create_mock_llm()

# Create a Text2CypherRetriever
t2c_retriever = Text2CypherRetriever(
driver=driver, llm=llm, neo4j_schema="(Movie)-[:ACTED_IN]-(Person)"
)

# Convert it to a tool (this was failing before the fix)
t2c_tool = t2c_retriever.convert_to_tool(
name="t2c_retriever",
description="Generate Cypher queries from natural language",
)

# Create ToolsRetriever with the t2c_tool
tools_retriever = ToolsRetriever(driver=driver, llm=llm, tools=[t2c_tool])

# Verify that the tools_retriever was created successfully
assert len(tools_retriever._tools) == 1
assert tools_retriever._tools[0].get_name() == "t2c_retriever"

# Get the tool's parameters to verify schema structure
tool_params = t2c_tool.get_parameters()

# This should have the correct structure that OpenAI expects
assert tool_params["type"] == "object"
assert isinstance(tool_params["required"], list)
assert "additionalProperties" in tool_params

# All nested objects should also have additionalProperties
for prop_name, prop_schema in tool_params["properties"].items():
if prop_schema.get("type") == "object":
assert (
"additionalProperties" in prop_schema
), f"Nested object {prop_name} missing additionalProperties"
147 changes: 138 additions & 9 deletions tests/unit/tool/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def test_string_parameter() -> None:
d = param.model_dump_tool()
assert d["type"] == ParameterType.STRING
assert d["enum"] == ["a", "b"]
assert d["required"] is True
# Note: 'required' is handled at the object level, not individual parameter level
assert "required" not in d


def test_integer_parameter() -> None:
Expand Down Expand Up @@ -141,38 +142,166 @@ def test_from_dict() -> None:


def test_required_parameter() -> None:
# Test that required=True is included in model_dump_tool output for different parameter types
# Test that individual parameters don't include 'required' field (it's handled at object level)
string_param = StringParameter(description="Required string", required=True)
assert string_param.model_dump_tool()["required"] is True
assert "required" not in string_param.model_dump_tool()

integer_param = IntegerParameter(description="Required integer", required=True)
assert integer_param.model_dump_tool()["required"] is True
assert "required" not in integer_param.model_dump_tool()

number_param = NumberParameter(description="Required number", required=True)
assert number_param.model_dump_tool()["required"] is True
assert "required" not in number_param.model_dump_tool()

boolean_param = BooleanParameter(description="Required boolean", required=True)
assert boolean_param.model_dump_tool()["required"] is True
assert "required" not in boolean_param.model_dump_tool()

array_param = ArrayParameter(
description="Required array",
items=StringParameter(description="item"),
required=True,
)
assert array_param.model_dump_tool()["required"] is True
assert "required" not in array_param.model_dump_tool()

object_param = ObjectParameter(
description="Required object",
properties={"prop": StringParameter(description="property")},
required=True,
)
assert object_param.model_dump_tool()["required"] is True
assert "required" not in object_param.model_dump_tool()

# Test that required=False doesn't include the required field
# Test that optional parameters also don't include the required field
optional_param = StringParameter(description="Optional string", required=False)
assert "required" not in optional_param.model_dump_tool()


def test_object_parameter_additional_properties_always_present() -> None:
"""Test that additionalProperties is always present in ObjectParameter schema, fixing OpenAI compatibility."""

# Test additionalProperties=True (default)
obj_param_true = ObjectParameter(
description="Object with additional properties",
properties={"prop": StringParameter(description="A property")},
additional_properties=True,
)
schema_true = obj_param_true.model_dump_tool()
assert "additionalProperties" in schema_true
assert schema_true["additionalProperties"] is True

# Test additionalProperties=False
obj_param_false = ObjectParameter(
description="Object without additional properties",
properties={"prop": StringParameter(description="A property")},
additional_properties=False,
)
schema_false = obj_param_false.model_dump_tool()
assert "additionalProperties" in schema_false
assert schema_false["additionalProperties"] is False


def test_json_schema_compatibility() -> None:
"""Test that the generated schema is compatible with JSON Schema specification."""

# Create a complex object with nested properties and required fields
nested_obj = ObjectParameter(
description="Nested object",
properties={
"nested_prop": StringParameter(description="Nested string"),
},
additional_properties=True,
)

main_obj = ObjectParameter(
description="Main object",
properties={
"required_string": StringParameter(description="Required string"),
"optional_number": NumberParameter(description="Optional number"),
"nested_object": nested_obj,
},
required_properties=["required_string"],
additional_properties=False,
)

schema = main_obj.model_dump_tool()

# Verify JSON Schema structure
assert schema["type"] == "object"
assert "properties" in schema
assert "required" in schema
assert "additionalProperties" in schema

# Check required is an array (not boolean on individual properties)
assert isinstance(schema["required"], list)
assert "required_string" in schema["required"]
assert len(schema["required"]) == 1

# Check individual properties don't have 'required' field
for prop_name, prop_schema in schema["properties"].items():
assert "required" not in prop_schema

# Check additionalProperties is properly set at all levels
assert schema["additionalProperties"] is False
assert schema["properties"]["nested_object"]["additionalProperties"] is True


def test_text2cypher_retriever_schema_compatibility() -> None:
"""Test the specific schema structure that caused the OpenAI API error."""

# Simulate the Text2CypherRetriever parameter structure
prompt_params = ObjectParameter(
description="Parameter prompt_params",
properties={},
additional_properties=True, # This was missing in the original bug
)

t2c_params = ObjectParameter(
description="Parameters for Text2CypherRetriever",
properties={
"query_text": StringParameter(description="Parameter query_text"),
"prompt_params": prompt_params,
},
required_properties=["query_text"],
additional_properties=False,
)

schema = t2c_params.model_dump_tool()

# Verify the fix: prompt_params should have additionalProperties
prompt_params_schema = schema["properties"]["prompt_params"]
assert "additionalProperties" in prompt_params_schema
assert prompt_params_schema["additionalProperties"] is True

# Verify query_text doesn't have individual 'required' field
query_text_schema = schema["properties"]["query_text"]
assert "required" not in query_text_schema

# Verify required array at object level
assert schema["required"] == ["query_text"]


def test_exclude_parameter_in_object_schema() -> None:
"""Test that exclude parameter works correctly in ObjectParameter.model_dump_tool()."""

obj_param = ObjectParameter(
description="Test object",
properties={
"prop1": StringParameter(description="Property 1"),
"prop2": IntegerParameter(description="Property 2"),
},
required_properties=["prop1"],
additional_properties=True,
)

# Test excluding required field
schema_no_required = obj_param.model_dump_tool(exclude=["required"])
assert "required" not in schema_no_required
assert "additionalProperties" in schema_no_required # Should still be present

# Test excluding additionalProperties field
schema_no_additional = obj_param.model_dump_tool(exclude=["additional_properties"])
assert "additionalProperties" not in schema_no_additional
assert "required" in schema_no_additional # Should still be present


def test_tool_class() -> None:
def dummy_func(**kwargs: Any) -> dict[str, Any]:
return kwargs
Expand Down