Skip to content

Commit 1556025

Browse files
Code agent: allow function persistence between steps (huggingface#31769)
* Code agent: allow function persistence between steps
1 parent eef0507 commit 1556025

File tree

5 files changed: +63 -11 lines changed

src/transformers/agents/agent_types.py

+6-3
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,7 @@ def __init__(self, value, samplerate=16_000):
188188
self.samplerate = samplerate
189189
if isinstance(value, (str, pathlib.Path)):
190190
self._path = value
191-
elif isinstance(value, torch.Tensor):
191+
elif is_torch_available() and isinstance(value, torch.Tensor):
192192
self._tensor = value
193193
elif isinstance(value, tuple):
194194
self.samplerate = value[0]
@@ -232,7 +232,10 @@ def to_string(self):
232232

233233

234234
AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
235-
INSTANCE_TYPE_MAPPING = {str: AgentText, float: AgentText, int: AgentText, Tensor: AgentAudio, ImageType: AgentImage}
235+
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
236+
237+
if is_torch_available():
238+
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
236239

237240

238241
def handle_agent_inputs(*args, **kwargs):
@@ -251,4 +254,4 @@ def handle_agent_outputs(output, output_type=None):
251254
for _k, _v in INSTANCE_TYPE_MAPPING.items():
252255
if isinstance(output, _k):
253256
return _v(output)
254-
return AgentType(output)
257+
return output

src/transformers/agents/agents.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -856,6 +856,10 @@ def __init__(
856856
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
857857
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
858858
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
859+
self.available_tools = {
860+
**BASE_PYTHON_TOOLS.copy(),
861+
**self.toolbox.tools,
862+
} # This dict can be augmented by the code agent creating some new functions
859863

860864
def step(self):
861865
"""
@@ -905,10 +909,9 @@ def step(self):
905909
# Execute
906910
self.log_code_action(code_action)
907911
try:
908-
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
909912
result = self.python_evaluator(
910913
code_action,
911-
available_tools,
914+
tools=self.available_tools,
912915
state=self.state,
913916
authorized_imports=self.authorized_imports,
914917
)

src/transformers/agents/python_interpreter.py

+6-1
Original file line number | Diff line number | Diff line change
@@ -778,7 +778,10 @@ def evaluate_ast(
778778

779779

780780
def evaluate_python_code(
781-
code: str, tools: Optional[Dict[str, Callable]] = {}, state=None, authorized_imports: List[str] = LIST_SAFE_MODULES
781+
code: str,
782+
tools: Optional[Dict[str, Callable]] = None,
783+
state: Optional[Dict[str, Any]] = None,
784+
authorized_imports: List[str] = LIST_SAFE_MODULES,
782785
):
783786
"""
784787
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
@@ -803,6 +806,8 @@ def evaluate_python_code(
803806
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
804807
if state is None:
805808
state = {}
809+
if tools is None:
810+
tools = {}
806811
result = None
807812
global PRINT_OUTPUTS
808813
PRINT_OUTPUTS = ""

tests/agents/test_agents.py

+46-3
Original file line number | Diff line number | Diff line change
@@ -94,12 +94,48 @@ def fake_react_code_llm_error(messages, stop_sequences=None) -> str:
9494
"""
9595

9696

97+
def fake_react_code_functiondef(messages, stop_sequences=None) -> str:
98+
prompt = str(messages)
99+
if "special_marker" not in prompt:
100+
return """
101+
Thought: Let's define the function. special_marker
102+
Code:
103+
```py
104+
import numpy as np
105+
106+
def moving_average(x, w):
107+
return np.convolve(x, np.ones(w), 'valid') / w
108+
```<end_code>
109+
"""
110+
else: # We're at step 2
111+
return """
112+
Thought: I can now answer the initial question
113+
Code:
114+
```py
115+
x, w = [0, 1, 2, 3, 4, 5], 2
116+
res = moving_average(x, w)
117+
final_answer(res)
118+
```<end_code>
119+
"""
120+
121+
97122
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
98123
return """
99124
Thought: I should multiply 2 by 3.6452. special_marker
100125
Code:
101126
```py
102127
result = python_interpreter(code="2*3.6452")
128+
final_answer(result)
129+
```
130+
"""
131+
132+
133+
def fake_code_llm_no_return(messages, stop_sequences=None) -> str:
134+
return """
135+
Thought: I should multiply 2 by 3.6452. special_marker
136+
Code:
137+
```py
138+
result = python_interpreter(code="2*3.6452")
103139
print(result)
104140
```
105141
"""
@@ -135,8 +171,8 @@ def test_fake_react_json_agent(self):
135171
def test_fake_react_code_agent(self):
136172
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
137173
output = agent.run("What is 2 multiplied by 3.6452?")
138-
assert isinstance(output, AgentText)
139-
assert output == "7.2904"
174+
assert isinstance(output, float)
175+
assert output == 7.2904
140176
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
141177
assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6
142178
assert agent.logs[2]["tool_call"] == {
@@ -157,7 +193,7 @@ def test_setup_agent_with_empty_toolbox(self):
157193
def test_react_fails_max_iterations(self):
158194
agent = ReactCodeAgent(
159195
tools=[PythonInterpreterTool()],
160-
llm_engine=fake_code_llm_oneshot, # use this callable because it never ends
196+
llm_engine=fake_code_llm_no_return, # use this callable because it never ends
161197
max_iterations=5,
162198
)
163199
agent.run("What is 2 multiplied by 3.6452?")
@@ -192,3 +228,10 @@ def test_init_agent_with_different_toolsets(self):
192228
# check that python_interpreter base tool does not get added to code agents
193229
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
194230
assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter)
231+
232+
def test_function_persistence_across_steps(self):
233+
agent = ReactCodeAgent(
234+
tools=[], llm_engine=fake_react_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"]
235+
)
236+
res = agent.run("ok")
237+
assert res[0] == 0.5

tests/agents/test_python_interpreter.py

-2
Original file line number | Diff line number | Diff line change
@@ -660,7 +660,6 @@ def add_one(n, shift):
660660
"""
661661
state = {}
662662
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
663-
print(state)
664663
assert result == 2
665664

666665
# test returning None
@@ -672,5 +671,4 @@ def returns_none(a):
672671
"""
673672
state = {}
674673
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
675-
print(state)
676674
assert result is None

0 commit comments

Comments
 (0)