-
Notifications
You must be signed in to change notification settings - Fork 15.8k
/
toolkit.py
136 lines (108 loc) · 4.69 KB
/
toolkit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Toolkit for interacting with an SQL database."""
from typing import List
from langchain_core.caches import BaseCache as BaseCache
from langchain_core.callbacks import Callbacks as Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field
from langchain_community.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QuerySQLCheckerTool,
QuerySQLDataBaseTool,
)
from langchain_community.utilities.sql_database import SQLDatabase
class SQLDatabaseToolkit(BaseToolkit):
    """Toolkit that bundles the tools used to interact with a SQL database.

    Setup:
        Install ``langchain-community``.

        .. code-block:: bash

            pip install -U langchain-community

    Key init args:
        db: SQLDatabase
            The SQL database.
        llm: BaseLanguageModel
            The language model (for use with QuerySQLCheckerTool)

    Instantiate:
        .. code-block:: python

            from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
            from langchain_community.utilities.sql_database import SQLDatabase
            from langchain_openai import ChatOpenAI

            db = SQLDatabase.from_uri("sqlite:///Chinook.db")
            llm = ChatOpenAI(temperature=0)

            toolkit = SQLDatabaseToolkit(db=db, llm=llm)

    Tools:
        .. code-block:: python

            toolkit.get_tools()

    Use within an agent:
        .. code-block:: python

            from langchain import hub
            from langgraph.prebuilt import create_react_agent

            # Pull prompt (or define your own)
            prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
            system_message = prompt_template.format(dialect="SQLite", top_k=5)

            # Create agent
            agent_executor = create_react_agent(
                llm, toolkit.get_tools(), state_modifier=system_message
            )

            # Query agent
            example_query = "Which country's customers spent the most?"

            events = agent_executor.stream(
                {"messages": [("user", example_query)]},
                stream_mode="values",
            )
            for event in events:
                event["messages"][-1].pretty_print()
    """  # noqa: E501

    # SQLDatabase / BaseLanguageModel are not pydantic-native types, so
    # arbitrary types have to be allowed for the two fields below.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Both fields are declared with exclude=True.
    db: SQLDatabase = Field(exclude=True)
    llm: BaseLanguageModel = Field(exclude=True)

    @property
    def dialect(self) -> str:
        """SQL dialect of the wrapped database, as reported by ``db.dialect``."""
        return self.db.dialect

    def get_tools(self) -> List[BaseTool]:
        """Build and return the tools of this toolkit.

        Each tool's description refers to the previously built tool by name,
        so the tools are constructed in dependency order: list -> info ->
        query -> checker.

        Returns:
            The tools in the order: query tool, info tool, list tool,
            query-checker tool.
        """
        database = self.db

        lister = ListSQLDatabaseTool(db=database)

        info_description = (
            "Input to this tool is a comma-separated list of tables, output is the "
            "schema and sample rows for those tables. "
            "Be sure that the tables actually exist by calling "
            f"{lister.name} first! "
            "Example Input: table1, table2, table3"
        )
        inspector = InfoSQLDatabaseTool(db=database, description=info_description)

        query_description = (
            "Input to this tool is a detailed and correct SQL query, output is a "
            "result from the database. If the query is not correct, an error message "
            "will be returned. If an error is returned, rewrite the query, check the "
            "query, and try again. If you encounter an issue with Unknown column "
            f"'xxxx' in 'field list', use {inspector.name} "
            "to query the correct table fields."
        )
        runner = QuerySQLDataBaseTool(db=database, description=query_description)

        checker_description = (
            "Use this tool to double check if your query is correct before executing "
            "it. Always use this tool before executing a query with "
            f"{runner.name}!"
        )
        checker = QuerySQLCheckerTool(
            db=database, llm=self.llm, description=checker_description
        )

        return [runner, inspector, lister, checker]

    def get_context(self) -> dict:
        """Return the database context (from ``db.get_context()``) for agent prompts."""
        context = self.db.get_context()
        return context
# Rebuild the model at import time, once the class body has been processed.
# NOTE(review): presumably this is what lets forward references resolve (the
# `BaseCache` / `Callbacks` re-exports imported at the top of this module look
# like they exist for that purpose) — confirm before removing either.
SQLDatabaseToolkit.model_rebuild()