Skip to content

Commit befe19d

Browse files
vrtnisseratch
andauthored
enhancement: Add tool_name to ToolContext to support shared tool handlers (#1043)
This adds a `tool_name` field to `ToolContext`, which gets passed into the `on_invoke_tool` handler. Helpful for scenarios where we dynamically register multiple tools that all share a single generic handler e.g.in multi-agent setups. As such, by including the name of the tool that was invoked, the handler can now easily branch logic or route requests accordingly. Resolves #1030 All tests pass. and here is a script to test it out https://gist.github.com/vrtnis/ca354244f7a5ecd9a73c0a2d34cb194b --------- Co-authored-by: Kazuhiro Sera <seratch@openai.com>
1 parent 741da67 commit befe19d

File tree

7 files changed

+61
-21
lines changed

7 files changed

+61
-21
lines changed

docs/ref/tool_context.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# `Tool context`
2+
3+
::: agents.tool_context

docs/tools.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ Sometimes, you don't want to use a Python function as a tool. You can directly c
180180
- `name`
181181
- `description`
182182
- `params_json_schema`, which is the JSON schema for the arguments
183-
- `on_invoke_tool`, which is an async function that receives the context and the arguments as a JSON string, and must return the tool output as a string.
183+
- `on_invoke_tool`, which is an async function that receives a [`ToolContext`][agents.tool_context.ToolContext] and the arguments as a JSON string, and must return the tool output as a string.
184184

185185
```python
186186
from typing import Any

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ plugins:
9292
- ref/lifecycle.md
9393
- ref/items.md
9494
- ref/run_context.md
95+
- ref/tool_context.md
9596
- ref/usage.md
9697
- ref/exceptions.md
9798
- ref/guardrail.md

src/agents/_run_impl.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,11 @@ async def run_single_tool(
548548
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
549549
) -> Any:
550550
with function_span(func_tool.name) as span_fn:
551-
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
551+
tool_context = ToolContext.from_agent_context(
552+
context_wrapper,
553+
tool_call.call_id,
554+
tool_call=tool_call,
555+
)
552556
if config.trace_include_sensitive_data:
553557
span_fn.span_data.input = tool_call.arguments
554558
try:

src/agents/tool_context.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from dataclasses import dataclass, field, fields
2-
from typing import Any
2+
from typing import Any, Optional
3+
4+
from openai.types.responses import ResponseFunctionToolCall
35

46
from .run_context import RunContextWrapper, TContext
57

@@ -8,16 +10,26 @@ def _assert_must_pass_tool_call_id() -> str:
810
raise ValueError("tool_call_id must be passed to ToolContext")
911

1012

13+
def _assert_must_pass_tool_name() -> str:
14+
raise ValueError("tool_name must be passed to ToolContext")
15+
16+
1117
@dataclass
1218
class ToolContext(RunContextWrapper[TContext]):
1319
"""The context of a tool call."""
1420

21+
tool_name: str = field(default_factory=_assert_must_pass_tool_name)
22+
"""The name of the tool being invoked."""
23+
1524
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
1625
"""The ID of the tool call."""
1726

1827
@classmethod
1928
def from_agent_context(
20-
cls, context: RunContextWrapper[TContext], tool_call_id: str
29+
cls,
30+
context: RunContextWrapper[TContext],
31+
tool_call_id: str,
32+
tool_call: Optional[ResponseFunctionToolCall] = None,
2133
) -> "ToolContext":
2234
"""
2335
Create a ToolContext from a RunContextWrapper.
@@ -26,4 +38,5 @@ def from_agent_context(
2638
base_values: dict[str, Any] = {
2739
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
2840
}
29-
return cls(tool_call_id=tool_call_id, **base_values)
41+
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name()
42+
return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values)

tests/test_function_tool.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ async def test_argless_function():
2626
tool = function_tool(argless_function)
2727
assert tool.name == "argless_function"
2828

29-
result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
29+
result = await tool.on_invoke_tool(
30+
ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), ""
31+
)
3032
assert result == "ok"
3133

3234

@@ -39,11 +41,13 @@ async def test_argless_with_context():
3941
tool = function_tool(argless_with_context)
4042
assert tool.name == "argless_with_context"
4143

42-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
44+
result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
4345
assert result == "ok"
4446

4547
# Extra JSON should not raise an error
46-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
48+
result = await tool.on_invoke_tool(
49+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
50+
)
4751
assert result == "ok"
4852

4953

@@ -56,15 +60,19 @@ async def test_simple_function():
5660
tool = function_tool(simple_function, failure_error_function=None)
5761
assert tool.name == "simple_function"
5862

59-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
63+
result = await tool.on_invoke_tool(
64+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
65+
)
6066
assert result == 6
6167

62-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
68+
result = await tool.on_invoke_tool(
69+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}'
70+
)
6371
assert result == 3
6472

6573
# Missing required argument should raise an error
6674
with pytest.raises(ModelBehaviorError):
67-
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
75+
await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
6876

6977

7078
class Foo(BaseModel):
@@ -92,7 +100,9 @@ async def test_complex_args_function():
92100
"bar": Bar(x="hello", y=10),
93101
}
94102
)
95-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
103+
result = await tool.on_invoke_tool(
104+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
105+
)
96106
assert result == "6 hello10 hello"
97107

98108
valid_json = json.dumps(
@@ -101,7 +111,9 @@ async def test_complex_args_function():
101111
"bar": Bar(x="hello", y=10),
102112
}
103113
)
104-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
114+
result = await tool.on_invoke_tool(
115+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
116+
)
105117
assert result == "3 hello10 hello"
106118

107119
valid_json = json.dumps(
@@ -111,12 +123,16 @@ async def test_complex_args_function():
111123
"baz": "world",
112124
}
113125
)
114-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
126+
result = await tool.on_invoke_tool(
127+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
128+
)
115129
assert result == "3 hello10 world"
116130

117131
# Missing required argument should raise an error
118132
with pytest.raises(ModelBehaviorError):
119-
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')
133+
await tool.on_invoke_tool(
134+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}'
135+
)
120136

121137

122138
def test_function_config_overrides():
@@ -176,7 +192,9 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
176192
assert tool.params_json_schema[key] == value
177193
assert tool.strict_json_schema
178194

179-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
195+
result = await tool.on_invoke_tool(
196+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}'
197+
)
180198
assert result == "hello_done"
181199

182200
tool_not_strict = FunctionTool(
@@ -191,7 +209,8 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
191209
assert "additionalProperties" not in tool_not_strict.params_json_schema
192210

193211
result = await tool_not_strict.on_invoke_tool(
194-
ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
212+
ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"),
213+
'{"data": "hello", "bar": "baz"}',
195214
)
196215
assert result == "hello_done"
197216

@@ -202,7 +221,7 @@ def my_func(a: int, b: int = 5):
202221
raise ValueError("test")
203222

204223
tool = function_tool(my_func)
205-
ctx = ToolContext(None, tool_call_id="1")
224+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
206225

207226
result = await tool.on_invoke_tool(ctx, "")
208227
assert "Invalid JSON" in str(result)
@@ -226,7 +245,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
226245
return f"error_{error.__class__.__name__}"
227246

228247
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
229-
ctx = ToolContext(None, tool_call_id="1")
248+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
230249

231250
result = await tool.on_invoke_tool(ctx, "")
232251
assert result == "error_ModelBehaviorError"
@@ -250,7 +269,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
250269
return f"error_{error.__class__.__name__}"
251270

252271
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
253-
ctx = ToolContext(None, tool_call_id="1")
272+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
254273

255274
result = await tool.on_invoke_tool(ctx, "")
256275
assert result == "error_ModelBehaviorError"

tests/test_function_tool_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self):
1616

1717

1818
def ctx_wrapper() -> ToolContext[DummyContext]:
19-
return ToolContext(context=DummyContext(), tool_call_id="1")
19+
return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1")
2020

2121

2222
@function_tool

0 commit comments

Comments
 (0)