@@ -94,12 +94,48 @@ def fake_react_code_llm_error(messages, stop_sequences=None) -> str:
94
94
"""
95
95
96
96
97
+ def fake_react_code_functiondef (messages , stop_sequences = None ) -> str :
98
+ prompt = str (messages )
99
+ if "special_marker" not in prompt :
100
+ return """
101
+ Thought: Let's define the function. special_marker
102
+ Code:
103
+ ```py
104
+ import numpy as np
105
+
106
+ def moving_average(x, w):
107
+ return np.convolve(x, np.ones(w), 'valid') / w
108
+ ```<end_code>
109
+ """
110
+ else : # We're at step 2
111
+ return """
112
+ Thought: I can now answer the initial question
113
+ Code:
114
+ ```py
115
+ x, w = [0, 1, 2, 3, 4, 5], 2
116
+ res = moving_average(x, w)
117
+ final_answer(res)
118
+ ```<end_code>
119
+ """
120
+
121
+
97
122
def fake_code_llm_oneshot (messages , stop_sequences = None ) -> str :
98
123
return """
99
124
Thought: I should multiply 2 by 3.6452. special_marker
100
125
Code:
101
126
```py
102
127
result = python_interpreter(code="2*3.6452")
128
+ final_answer(result)
129
+ ```
130
+ """
131
+
132
+
133
+ def fake_code_llm_no_return (messages , stop_sequences = None ) -> str :
134
+ return """
135
+ Thought: I should multiply 2 by 3.6452. special_marker
136
+ Code:
137
+ ```py
138
+ result = python_interpreter(code="2*3.6452")
103
139
print(result)
104
140
```
105
141
"""
@@ -135,8 +171,8 @@ def test_fake_react_json_agent(self):
135
171
def test_fake_react_code_agent (self ):
136
172
agent = ReactCodeAgent (tools = [PythonInterpreterTool ()], llm_engine = fake_react_code_llm )
137
173
output = agent .run ("What is 2 multiplied by 3.6452?" )
138
- assert isinstance (output , AgentText )
139
- assert output == " 7.2904"
174
+ assert isinstance (output , float )
175
+ assert output == 7.2904
140
176
assert agent .logs [0 ]["task" ] == "What is 2 multiplied by 3.6452?"
141
177
assert float (agent .logs [1 ]["observation" ].strip ()) - 12.511648 < 1e-6
142
178
assert agent .logs [2 ]["tool_call" ] == {
@@ -157,7 +193,7 @@ def test_setup_agent_with_empty_toolbox(self):
157
193
def test_react_fails_max_iterations (self ):
158
194
agent = ReactCodeAgent (
159
195
tools = [PythonInterpreterTool ()],
160
- llm_engine = fake_code_llm_oneshot , # use this callable because it never ends
196
+ llm_engine = fake_code_llm_no_return , # use this callable because it never ends
161
197
max_iterations = 5 ,
162
198
)
163
199
agent .run ("What is 2 multiplied by 3.6452?" )
@@ -192,3 +228,10 @@ def test_init_agent_with_different_toolsets(self):
192
228
# check that python_interpreter base tool does not get added to code agents
193
229
agent = ReactCodeAgent (tools = [], llm_engine = fake_react_code_llm , add_base_tools = True )
194
230
assert len (agent .toolbox .tools ) == 6 # added final_answer tool + 5 base tools (excluding interpreter)
231
+
232
+ def test_function_persistence_across_steps (self ):
233
+ agent = ReactCodeAgent (
234
+ tools = [], llm_engine = fake_react_code_functiondef , max_iterations = 2 , additional_authorized_imports = ["numpy" ]
235
+ )
236
+ res = agent .run ("ok" )
237
+ assert res [0 ] == 0.5
0 commit comments