Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
RunShowStorageInfoTool,
RunShowTriggersTool,
RunShowConstraintInfoTool,
RunNodeNeighborhoodTool,
RunNodeVectorSearchTool,
)


Expand Down Expand Up @@ -79,4 +81,6 @@ def get_tools(self) -> List[BaseTool]:
RunShowStorageInfoTool(db=self.db),
RunShowTriggersTool(db=self.db),
RunShowConstraintInfoTool(db=self.db),
RunNodeNeighborhoodTool(db=self.db),
RunNodeVectorSearchTool(db=self.db),
]
86 changes: 86 additions & 0 deletions integrations/langchain-memgraph/langchain_memgraph/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from memgraph_toolbox.tools.index import ShowIndexInfoTool
from memgraph_toolbox.tools.betweenness_centrality import BetweennessCentralityTool
from memgraph_toolbox.tools.constraint import ShowConstraintInfoTool
from memgraph_toolbox.tools.node_neighborhood import NodeNeighborhoodTool
from memgraph_toolbox.tools.node_vector_search import NodeVectorSearchTool
from memgraph_toolbox.utils.logging import logger_init


Expand Down Expand Up @@ -284,3 +286,87 @@ def _run(
return BetweennessCentralityTool(
db=self.db,
).call({"isDirectionIgnored": isDirectionIgnored, "limit": limit})


class _NodeNeighborhoodToolInput(BaseModel):
"""
Input schema for the Node Neighborhood Memgraph tool.
"""

node_id: str = Field(
...,
description="The ID of the starting node to find neighborhood around",
)
max_distance: int = Field(
1,
description="Maximum distance (hops) to search from the starting node. Default is 1.",
)
limit: int = Field(
100, description="Maximum number of nodes to return. Default is 100."
)


class RunNodeNeighborhoodTool(BaseMemgraphTool, BaseTool):
"""Tool for finding nodes within a specified neighborhood distance in Memgraph."""

name: str = NodeNeighborhoodTool(db=None).get_name()
"""The name that is passed to the model when performing tool calling."""

description: str = NodeNeighborhoodTool(db=None).get_description()
"""The description that is passed to the model when performing tool calling."""

args_schema: Type[BaseModel] = _NodeNeighborhoodToolInput
"""The schema that is passed to the model when performing tool calling."""

def _run(
self,
node_id: str,
max_distance: int = 1,
limit: int = 100,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> List[Dict[str, Any]]:
return NodeNeighborhoodTool(
db=self.db,
).call({"node_id": node_id, "max_distance": max_distance, "limit": limit})


class _NodeVectorSearchToolInput(BaseModel):
"""
Input schema for the Node Vector Search Memgraph tool.
"""

index_name: str = Field(
...,
description="Name of the index to use for the vector search",
)
query_vector: List[float] = Field(
...,
description="Query vector to search for similarity",
)
limit: int = Field(
10, description="Number of similar nodes to return. Default is 10."
)


class RunNodeVectorSearchTool(BaseMemgraphTool, BaseTool):
"""Tool for performing vector similarity search on nodes in Memgraph."""

name: str = NodeVectorSearchTool(db=None).get_name()
"""The name that is passed to the model when performing tool calling."""

description: str = NodeVectorSearchTool(db=None).get_description()
"""The description that is passed to the model when performing tool calling."""

args_schema: Type[BaseModel] = _NodeVectorSearchToolInput
"""The schema that is passed to the model when performing tool calling."""

def _run(
self,
index_name: str,
query_vector: List[float],
limit: int = 10,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> List[Dict[str, Any]]:
return NodeVectorSearchTool(
db=self.db,
).call({"index_name": index_name, "query_vector": query_vector, "limit": limit})
4 changes: 2 additions & 2 deletions integrations/langchain-memgraph/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "langchain-memgraph"
version = "0.1.6"
version = "0.1.7"
description = "An integration package connecting Memgraph and LangChain"
authors = [{ name = "Ante Javor", email = "ante.javor@memgraph.com" }]
readme = "README.md"
Expand All @@ -17,7 +17,7 @@ classifiers = [
dependencies = [
"langchain-core>=0.3.15",
"neo4j>=5.28.1",
"memgraph-toolbox>=0.1.4",
"memgraph-toolbox>=0.1.5",
"langchain>=0.3.25",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
RunShowConfigTool,
RunShowTriggersTool,
RunBetweennessCentralityTool,
RunNodeNeighborhoodTool,
RunNodeVectorSearchTool,
)
from memgraph_toolbox.api.memgraph import Memgraph

Expand Down Expand Up @@ -157,3 +159,31 @@ def tool_constructor_params(self) -> dict:
@property
def tool_invoke_params_example(self) -> dict:
return {"isDirectionIgnored": True, "limit": 5}


class TestNodeNeighborhoodIntegration(ToolsIntegrationTests):
@property
def tool_constructor(self) -> Type[RunNodeNeighborhoodTool]:
return RunNodeNeighborhoodTool

@property
def tool_constructor_params(self) -> dict:
return {"db": Memgraph("bolt://localhost:7687", "", "")}

@property
def tool_invoke_params_example(self) -> dict:
return {"node_id": "1", "max_distance": 2, "limit": 10}


class TestNodeVectorSearchIntegration(ToolsIntegrationTests):
@property
def tool_constructor(self) -> Type[RunNodeVectorSearchTool]:
return RunNodeVectorSearchTool

@property
def tool_constructor_params(self) -> dict:
return {"db": Memgraph("bolt://localhost:7687", "", "")}

@property
def tool_invoke_params_example(self) -> dict:
return {"index_name": "test_index", "query_vector": [1.0, 2.0, 3.0], "limit": 5}
10 changes: 5 additions & 5 deletions integrations/langchain-memgraph/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "memgraph-ai"
version = "0.1.3"
version = "0.1.4"
description = "Memgraph AI Toolkit"
readme = "README.md"
requires-python = ">=3.10"
Expand Down