Skip to content

Commit e266729

Browse files
committed
Basic structure of MongoDB tool
1 parent 227c65b commit e266729

File tree

3 files changed

+151
-3
lines changed

3 files changed

+151
-3
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# flake8: noqa
2+
QUERY_CHECKER = """
3+
{query}
4+
Double check the {client} query above for common mistakes, including:
5+
- Improper use of $nin operator with null values
6+
- Using $merge instead of $concat for combining arrays
7+
- Incorrect use of $not or $ne for exclusive ranges
8+
- Data type mismatch in query conditions
9+
- Properly referencing field names in queries
10+
- Using the correct syntax for aggregation functions
11+
- Casting to the correct BSON data type
12+
- Using the proper fields for $lookup in aggregations
13+
14+
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
15+
16+
MongoDB Query: """
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# flake8: noqa
2+
"""Tools for interacting with a MongoDB database."""
3+
from typing import Any, Dict, Optional
4+
5+
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
6+
7+
from langchain.schema.language_model import BaseLanguageModel
8+
from langchain.callbacks.manager import (
9+
AsyncCallbackManagerForToolRun,
10+
CallbackManagerForToolRun,
11+
)
12+
from langchain.chains.llm import LLMChain
13+
from langchain.prompts import PromptTemplate
14+
from langchain.utilities.mongo_database import MongoDBDatabase
15+
from langchain.tools.base import BaseTool
16+
from langchain.tools.mongo_database.prompt import QUERY_CHECKER
17+
18+
19+
class BaseMongoDBTool(BaseModel):
20+
"""Base tool for interacting with a MongoDB database."""
21+
22+
db: MongoDBDatabase = Field(exclude=True)
23+
24+
class Config(BaseTool.Config):
25+
pass
26+
27+
28+
class QueryMongoDBTool(BaseMongoDBTool, BaseTool):
29+
"""Tool for querying a MongoDB database."""
30+
31+
name: str = "mongo_db_query"
32+
description: str = """
33+
Input to this tool is a detailed and correct MongoDB query, output is a result from the database.
34+
If the query is not correct, an error message will be returned.
35+
If an error is returned, rewrite the query, check the query, and try again.
36+
"""
37+
38+
def _run(
39+
self,
40+
query: str,
41+
run_manager: Optional[CallbackManagerForToolRun] = None,
42+
) -> str:
43+
"""Execute the query, return the results or an error message."""
44+
return self.db.run(query)
45+
46+
47+
class InfoMongoDBTool(BaseMongoDBTool, BaseTool):
48+
"""Tool for getting metadata about a MongoDB database."""
49+
50+
name: str = "mongo_db_schema"
51+
description: str = """
52+
Input to this tool is a comma-separated list of collections, output is the schema and sample documents for those collections.
53+
54+
Example Input: "collection1, collection2, collection3"
55+
"""
56+
57+
def _run(
58+
self,
59+
collection_names: str,
60+
run_manager: Optional[CallbackManagerForToolRun] = None,
61+
) -> str:
62+
"""Get information about specified collections."""
63+
return self.db.get_document_info(collection_names=collection_names)
64+
65+
66+
class ListMongoDBTool(BaseMongoDBTool, BaseTool):
67+
"""Tool for listing collections in a MongoDB database."""
68+
69+
name: str = "mongo_db_list"
70+
description: str = """
71+
Output of this tool is a list of collections in the database.
72+
"""
73+
74+
def _run(
75+
self,
76+
run_manager: Optional[CallbackManagerForToolRun] = None,
77+
) -> str:
78+
"""Get a list of collections in the database."""
79+
return self.db.collection_info()
80+
81+
82+
class QueryMongoDBCheckerTool(BaseMongoDBTool, BaseTool):
83+
"""Use an LLM to check if a query is correct"""
84+
85+
template: str = QUERY_CHECKER
86+
llm: BaseLanguageModel
87+
llm_chain: LLMChain = Field(init=False)
88+
name: str = "mongo_db_query_checker"
89+
description: str = """
90+
Use this tool to double check a MongoDB query for common mistakes.
91+
"""
92+
93+
@root_validator(pre=True)
94+
def _init_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
95+
"""Initialize the LLM chain."""
96+
if "llm_chain" not in values:
97+
values["llm_chain"] = LLMChain(
98+
llm=values.get("llm"),
99+
prompt=PromptTemplate(
100+
template=QUERY_CHECKER, input_variables=["client", "query"]
101+
),
102+
)
103+
104+
if values["llm_chain"].prompt.input_variables != ["client", "query"]:
105+
raise ValueError(
106+
"LLM chain for QueryCheckerTool must have input variables ['query', 'client']"
107+
)
108+
109+
return values
110+
111+
def _run(
112+
self,
113+
query: str,
114+
run_manager: Optional[CallbackManagerForToolRun] = None,
115+
) -> str:
116+
"""Use the LLM to check the query."""
117+
return self.llm_chain.predict(
118+
query=query,
119+
client=self.db.client,
120+
callbacks=run_manager.get_child() if run_manager else None,
121+
)
122+
123+
async def _arun(
124+
self,
125+
query: str,
126+
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
127+
) -> str:
128+
return await self.llm_chain.apredict(
129+
query=query,
130+
client=self.db.client,
131+
callbacks=run_manager.get_child() if run_manager else None,
132+
)

libs/langchain/langchain/tools/sql_database/tool.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
"""Tools for interacting with a SQL database."""
33
from typing import Any, Dict, Optional
44

5-
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
5+
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
66

7-
from langchain_core.language_models import BaseLanguageModel
7+
from langchain.schema.language_model import BaseLanguageModel
88
from langchain.callbacks.manager import (
99
AsyncCallbackManagerForToolRun,
1010
CallbackManagerForToolRun,
1111
)
1212
from langchain.chains.llm import LLMChain
13-
from langchain_core.prompts import PromptTemplate
13+
from langchain.prompts import PromptTemplate
1414
from langchain.utilities.sql_database import SQLDatabase
1515
from langchain.tools.base import BaseTool
1616
from langchain.tools.sql_database.prompt import QUERY_CHECKER

0 commit comments

Comments
 (0)