Skip to content

Commit 8fc76a1

Browse files
authored
Merge pull request #2 from Haoming-jpg/Haoming-tool.py
Re-write the MongoDB query checker
2 parents 079d20c + 53b1d82 commit 8fc76a1

File tree

2 files changed

+12
-14
lines changed

2 files changed

+12
-14
lines changed

libs/langchain/langchain/tools/mongo_database/prompt.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# flake8: noqa
22
QUERY_CHECKER = """
33
{query}
4-
Double check the {client} query above for common mistakes, including:
5-
- Improper use of $nin operator with null values
6-
- Using $merge instead of $concat for combining arrays
7-
- Incorrect use of $not or $ne for exclusive ranges
8-
- Data type mismatch in query conditions
9-
- Properly referencing field names in queries
10-
- Using the correct syntax for aggregation functions
11-
- Casting to the correct BSON data type
12-
- Using the proper fields for $lookup in aggregations
4+
Double check the MongoDB query above for common mistakes, including:
5+
- Correct syntax for query operators (e.g., $match, $group, $project)
6+
- Properly matching nested fields in the documents
7+
- Using the appropriate array operators (e.g., $elemMatch)
8+
- Utilizing indexes for performance optimization
9+
- Handling data type mismatch in queries
10+
- Ensuring proper field names and key names in queries
11+
- Using the correct projection operators for desired output
12+
- Properly structuring aggregation pipelines if applicable
1313
1414
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
1515

libs/langchain/langchain/tools/mongo_database/tool.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,13 @@ def _init_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
9797
values["llm_chain"] = LLMChain(
9898
llm=values.get("llm"),
9999
prompt=PromptTemplate(
100-
template=QUERY_CHECKER, input_variables=["client", "query"]
100+
template=QUERY_CHECKER, input_variables=["query"]
101101
),
102102
)
103103

104-
if values["llm_chain"].prompt.input_variables != ["client", "query"]:
104+
if values["llm_chain"].prompt.input_variables != ["query"]:
105105
raise ValueError(
106-
"LLM chain for QueryCheckerTool must have input variables ['query', 'client']"
106+
"LLM chain for QueryCheckerTool must have input variables ['query']"
107107
)
108108

109109
return values
@@ -116,7 +116,6 @@ def _run(
116116
"""Use the LLM to check the query."""
117117
return self.llm_chain.predict(
118118
query=query,
119-
client=self.db.client,
120119
callbacks=run_manager.get_child() if run_manager else None,
121120
)
122121

@@ -127,6 +126,5 @@ async def _arun(
127126
) -> str:
128127
return await self.llm_chain.apredict(
129128
query=query,
130-
client=self.db.client,
131129
callbacks=run_manager.get_child() if run_manager else None,
132130
)

0 commit comments

Comments
 (0)