1+ # flake8: noqa
2+ """Tools for interacting with a MongoDB database."""
3+ from typing import Any , Dict , Optional
4+
5+ from langchain .pydantic_v1 import BaseModel , Extra , Field , root_validator
6+
7+ from langchain .schema .language_model import BaseLanguageModel
8+ from langchain .callbacks .manager import (
9+ AsyncCallbackManagerForToolRun ,
10+ CallbackManagerForToolRun ,
11+ )
12+ from langchain .chains .llm import LLMChain
13+ from langchain .prompts import PromptTemplate
14+ from langchain .utilities .mongo_database import MongoDBDatabase
15+ from langchain .tools .base import BaseTool
16+ from langchain .tools .mongo_database .prompt import QUERY_CHECKER
17+
18+
19+ class BaseMongoDBTool (BaseModel ):
20+ """Base tool for interacting with a MongoDB database."""
21+
22+ db : MongoDBDatabase = Field (exclude = True )
23+
24+ class Config (BaseTool .Config ):
25+ pass
26+
27+
28+ class QueryMongoDBTool (BaseMongoDBTool , BaseTool ):
29+ """Tool for querying a MongoDB database."""
30+
31+ name : str = "mongo_db_query"
32+ description : str = """
33+ Input to this tool is a detailed and correct MongoDB query, output is a result from the database.
34+ If the query is not correct, an error message will be returned.
35+ If an error is returned, rewrite the query, check the query, and try again.
36+ """
37+
38+ def _run (
39+ self ,
40+ query : str ,
41+ run_manager : Optional [CallbackManagerForToolRun ] = None ,
42+ ) -> str :
43+ """Execute the query, return the results or an error message."""
44+ return self .db .run (query )
45+
46+
47+ class InfoMongoDBTool (BaseMongoDBTool , BaseTool ):
48+ """Tool for getting metadata about a MongoDB database."""
49+
50+ name : str = "mongo_db_schema"
51+ description : str = """
52+ Input to this tool is a comma-separated list of collections, output is the schema and sample documents for those collections.
53+
54+ Example Input: "collection1, collection2, collection3"
55+ """
56+
57+ def _run (
58+ self ,
59+ collection_names : str ,
60+ run_manager : Optional [CallbackManagerForToolRun ] = None ,
61+ ) -> str :
62+ """Get information about specified collections."""
63+ return self .db .get_document_info (collection_names = collection_names )
64+
65+
66+ class ListMongoDBTool (BaseMongoDBTool , BaseTool ):
67+ """Tool for listing collections in a MongoDB database."""
68+
69+ name : str = "mongo_db_list"
70+ description : str = """
71+ Output of this tool is a list of collections in the database.
72+ """
73+
74+ def _run (
75+ self ,
76+ run_manager : Optional [CallbackManagerForToolRun ] = None ,
77+ ) -> str :
78+ """Get a list of collections in the database."""
79+ return self .db .collection_info ()
80+
81+
82+ class QueryMongoDBCheckerTool (BaseMongoDBTool , BaseTool ):
83+ """Use an LLM to check if a query is correct"""
84+
85+ template : str = QUERY_CHECKER
86+ llm : BaseLanguageModel
87+ llm_chain : LLMChain = Field (init = False )
88+ name : str = "mongo_db_query_checker"
89+ description : str = """
90+ Use this tool to double check a MongoDB query for common mistakes.
91+ """
92+
93+ @root_validator (pre = True )
94+ def _init_llm_chain (cls , values : Dict [str , Any ]) -> Dict [str , Any ]:
95+ """Initialize the LLM chain."""
96+ if "llm_chain" not in values :
97+ values ["llm_chain" ] = LLMChain (
98+ llm = values .get ("llm" ),
99+ prompt = PromptTemplate (
100+ template = QUERY_CHECKER , input_variables = ["client" , "query" ]
101+ ),
102+ )
103+
104+ if values ["llm_chain" ].prompt .input_variables != ["client" , "query" ]:
105+ raise ValueError (
106+ "LLM chain for QueryCheckerTool must have input variables ['query', 'client']"
107+ )
108+
109+ return values
110+
111+ def _run (
112+ self ,
113+ query : str ,
114+ run_manager : Optional [CallbackManagerForToolRun ] = None ,
115+ ) -> str :
116+ """Use the LLM to check the query."""
117+ return self .llm_chain .predict (
118+ query = query ,
119+ client = self .db .client ,
120+ callbacks = run_manager .get_child () if run_manager else None ,
121+ )
122+
123+ async def _arun (
124+ self ,
125+ query : str ,
126+ run_manager : Optional [AsyncCallbackManagerForToolRun ] = None ,
127+ ) -> str :
128+ return await self .llm_chain .apredict (
129+ query = query ,
130+ client = self .db .client ,
131+ callbacks = run_manager .get_child () if run_manager else None ,
132+ )
0 commit comments