Skip to content

Commit d5b4393

Browse files
hwchase17 and vbarda authored
Harrison/llm math (#1808)
Co-authored-by: Vadym Barda <vadim.barda@gmail.com>
1 parent 7b6ff7f commit d5b4393

File tree

2 files changed

+310
-13
lines changed

2 files changed

+310
-13
lines changed
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "a4734146",
6+
"metadata": {},
7+
"source": [
8+
"# LLM Math\n",
9+
"\n",
10+
"Evaluating chains that know how to do math."
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": 6,
16+
"id": "fdd7afae",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"# Comment this out if you are NOT using tracing\n",
21+
"import os\n",
22+
"os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\""
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 7,
28+
"id": "ce05ffea",
29+
"metadata": {},
30+
"outputs": [
31+
{
32+
"data": {
33+
"application/vnd.jupyter.widget-view+json": {
34+
"model_id": "d028a511cede4de2b845b9a9954d6bea",
35+
"version_major": 2,
36+
"version_minor": 0
37+
},
38+
"text/plain": [
39+
"Downloading readme: 0%| | 0.00/21.0 [00:00<?, ?B/s]"
40+
]
41+
},
42+
"metadata": {},
43+
"output_type": "display_data"
44+
},
45+
{
46+
"name": "stdout",
47+
"output_type": "stream",
48+
"text": [
49+
"Downloading and preparing dataset json/LangChainDatasets--llm-math to /Users/harrisonchase/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--llm-math-509b11d101165afa/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51...\n"
50+
]
51+
},
52+
{
53+
"data": {
54+
"application/vnd.jupyter.widget-view+json": {
55+
"model_id": "a71c8e5a21dd4da5a20a354b544f7a58",
56+
"version_major": 2,
57+
"version_minor": 0
58+
},
59+
"text/plain": [
60+
"Downloading data files: 0%| | 0/1 [00:00<?, ?it/s]"
61+
]
62+
},
63+
"metadata": {},
64+
"output_type": "display_data"
65+
},
66+
{
67+
"data": {
68+
"application/vnd.jupyter.widget-view+json": {
69+
"model_id": "ae530ca624154a1a934075c47d1093a6",
70+
"version_major": 2,
71+
"version_minor": 0
72+
},
73+
"text/plain": [
74+
"Downloading data: 0%| | 0.00/631 [00:00<?, ?B/s]"
75+
]
76+
},
77+
"metadata": {},
78+
"output_type": "display_data"
79+
},
80+
{
81+
"data": {
82+
"application/vnd.jupyter.widget-view+json": {
83+
"model_id": "7a4968df05d84bc483aa2c5039aecafe",
84+
"version_major": 2,
85+
"version_minor": 0
86+
},
87+
"text/plain": [
88+
"Extracting data files: 0%| | 0/1 [00:00<?, ?it/s]"
89+
]
90+
},
91+
"metadata": {},
92+
"output_type": "display_data"
93+
},
94+
{
95+
"data": {
96+
"application/vnd.jupyter.widget-view+json": {
97+
"model_id": "",
98+
"version_major": 2,
99+
"version_minor": 0
100+
},
101+
"text/plain": [
102+
"Generating train split: 0 examples [00:00, ? examples/s]"
103+
]
104+
},
105+
"metadata": {},
106+
"output_type": "display_data"
107+
},
108+
{
109+
"name": "stdout",
110+
"output_type": "stream",
111+
"text": [
112+
"Dataset json downloaded and prepared to /Users/harrisonchase/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--llm-math-509b11d101165afa/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51. Subsequent calls will reuse this data.\n"
113+
]
114+
},
115+
{
116+
"data": {
117+
"application/vnd.jupyter.widget-view+json": {
118+
"model_id": "9a2caed96225410fb1cc0f8f155eb766",
119+
"version_major": 2,
120+
"version_minor": 0
121+
},
122+
"text/plain": [
123+
" 0%| | 0/1 [00:00<?, ?it/s]"
124+
]
125+
},
126+
"metadata": {},
127+
"output_type": "display_data"
128+
}
129+
],
130+
"source": [
131+
"from langchain.evaluation.loading import load_dataset\n",
132+
"dataset = load_dataset(\"llm-math\")"
133+
]
134+
},
135+
{
136+
"cell_type": "markdown",
137+
"id": "8a998d6f",
138+
"metadata": {},
139+
"source": [
140+
"## Setting up a chain\n",
141+
"Now we need to create some pipelines for doing math."
142+
]
143+
},
144+
{
145+
"cell_type": "code",
146+
"execution_count": 10,
147+
"id": "7078f7f8",
148+
"metadata": {},
149+
"outputs": [],
150+
"source": [
151+
"from langchain.llms import OpenAI\n",
152+
"from langchain.chains import LLMMathChain"
153+
]
154+
},
155+
{
156+
"cell_type": "code",
157+
"execution_count": 9,
158+
"id": "2bd70c46",
159+
"metadata": {},
160+
"outputs": [],
161+
"source": [
162+
"llm = OpenAI()"
163+
]
164+
},
165+
{
166+
"cell_type": "code",
167+
"execution_count": 11,
168+
"id": "954c3270",
169+
"metadata": {},
170+
"outputs": [],
171+
"source": [
172+
"chain = LLMMathChain(llm=llm)"
173+
]
174+
},
175+
{
176+
"cell_type": "code",
177+
"execution_count": 13,
178+
"id": "f252027e",
179+
"metadata": {},
180+
"outputs": [],
181+
"source": [
182+
"predictions = chain.apply(dataset)"
183+
]
184+
},
185+
{
186+
"cell_type": "code",
187+
"execution_count": 22,
188+
"id": "c8af7041",
189+
"metadata": {},
190+
"outputs": [],
191+
"source": [
192+
"numeric_output = [float(p['answer'].strip().strip(\"Answer: \")) for p in predictions]"
193+
]
194+
},
195+
{
196+
"cell_type": "code",
197+
"execution_count": 23,
198+
"id": "cc09ffe4",
199+
"metadata": {},
200+
"outputs": [],
201+
"source": [
202+
"correct = [example['answer'] == numeric_output[i] for i, example in enumerate(dataset)]"
203+
]
204+
},
205+
{
206+
"cell_type": "code",
207+
"execution_count": 24,
208+
"id": "585244e4",
209+
"metadata": {},
210+
"outputs": [
211+
{
212+
"data": {
213+
"text/plain": [
214+
"1.0"
215+
]
216+
},
217+
"execution_count": 24,
218+
"metadata": {},
219+
"output_type": "execute_result"
220+
}
221+
],
222+
"source": [
223+
"sum(correct) / len(correct)"
224+
]
225+
},
226+
{
227+
"cell_type": "code",
228+
"execution_count": 25,
229+
"id": "0d14ac78",
230+
"metadata": {},
231+
"outputs": [
232+
{
233+
"name": "stdout",
234+
"output_type": "stream",
235+
"text": [
236+
"input: 5\n",
237+
"expected output : 5.0\n",
238+
"prediction: 5.0\n",
239+
"input: 5 + 3\n",
240+
"expected output : 8.0\n",
241+
"prediction: 8.0\n",
242+
"input: 2^3.171\n",
243+
"expected output : 9.006708689094099\n",
244+
"prediction: 9.006708689094099\n",
245+
"input: 2 ^3.171 \n",
246+
"expected output : 9.006708689094099\n",
247+
"prediction: 9.006708689094099\n",
248+
"input: two to the power of three point one hundred seventy one\n",
249+
"expected output : 9.006708689094099\n",
250+
"prediction: 9.006708689094099\n",
251+
"input: five + three squared minus 1\n",
252+
"expected output : 13.0\n",
253+
"prediction: 13.0\n",
254+
"input: 2097 times 27.31\n",
255+
"expected output : 57269.07\n",
256+
"prediction: 57269.07\n",
257+
"input: two thousand ninety seven times twenty seven point thirty one\n",
258+
"expected output : 57269.07\n",
259+
"prediction: 57269.07\n",
260+
"input: 209758 / 2714\n",
261+
"expected output : 77.28739867354459\n",
262+
"prediction: 77.28739867354459\n",
263+
"input: 209758.857 divided by 2714.31\n",
264+
"expected output : 77.27888745205964\n",
265+
"prediction: 77.27888745205964\n"
266+
]
267+
}
268+
],
269+
"source": [
270+
"for i, example in enumerate(dataset):\n",
271+
" print(\"input: \", example[\"question\"])\n",
272+
" print(\"expected output :\", example[\"answer\"])\n",
273+
" print(\"prediction: \", numeric_output[i])"
274+
]
275+
},
276+
{
277+
"cell_type": "code",
278+
"execution_count": null,
279+
"id": "b9021ffd",
280+
"metadata": {},
281+
"outputs": [],
282+
"source": []
283+
}
284+
],
285+
"metadata": {
286+
"kernelspec": {
287+
"display_name": "Python 3 (ipykernel)",
288+
"language": "python",
289+
"name": "python3"
290+
},
291+
"language_info": {
292+
"codemirror_mode": {
293+
"name": "ipython",
294+
"version": 3
295+
},
296+
"file_extension": ".py",
297+
"mimetype": "text/x-python",
298+
"name": "python",
299+
"nbconvert_exporter": "python",
300+
"pygments_lexer": "ipython3",
301+
"version": "3.9.1"
302+
}
303+
},
304+
"nbformat": 4,
305+
"nbformat_minor": 5
306+
}

langchain/chains/llm_math/prompt.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,17 @@
11
# flake8: noqa
22
from langchain.prompts.prompt import PromptTemplate
33

4-
_PROMPT_TEMPLATE = """You are GPT-3, and you can't do math.
4+
_PROMPT_TEMPLATE = """Translate a math problem into Python code that can be executed in Python 3 REPL. Use the output of running this code to answer the question.
55
6-
You can do basic math, and your memorization abilities are impressive, but you can't do any complex calculations that a human could not do in their head. You also have an annoying tendency to just make up highly specific, but wrong, answers.
7-
8-
So we hooked you up to a Python 3 kernel, and now you can execute code. If anyone gives you a hard math problem, just use this format and we’ll take care of the rest:
9-
10-
Question: ${{Question with hard calculation.}}
6+
Question: ${{Question with math problem.}}
117
```python
12-
${{Code that prints what you need to know}}
8+
${{Code that solves the problem and prints the solution}}
139
```
1410
```output
15-
${{Output of your code}}
11+
${{Output of running the code}}
1612
```
1713
Answer: ${{Answer}}
1814
19-
Otherwise, use this simpler format:
20-
21-
Question: ${{Question without hard calculation}}
22-
Answer: ${{Answer}}
23-
2415
Begin.
2516
2617
Question: What is 37593 * 67?

0 commit comments

Comments (0)