Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions tests/test_stateful_tool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,22 @@ def test_stateful_tool_env_add_tool_skips_args(self, mock_stateful_tool_env):

@pytest.mark.asyncio
async def test_tool_env_tool_invalid_json_arguments(
self, mock_stateful_tool_env, mock_openai_client, sample_chat_dataset
self, mock_openai_client, sample_chat_dataset
):
"""Test that ToolEnv stops rollout when tool call is not JSON-parsable."""
"""Test that StatefulToolEnv stops rollout when tool call is not JSON-parsable."""

class ParseErrorStatefulToolEnv(vf.StatefulToolEnv):
def __init__(self, **kwargs):
super().__init__(tools=[], stop_errors=[vf.ToolParseError], **kwargs)

def update_tool_args(self, tool_name, tool_args, messages, state, **kwargs):
return tool_args

env = ParseErrorStatefulToolEnv(
client=mock_openai_client,
model="test-model",
dataset=sample_chat_dataset,
)

# Create a tool call with invalid JSON arguments
from openai.types.chat.chat_completion_message_tool_call import (
Expand All @@ -112,7 +125,7 @@ async def test_tool_env_tool_invalid_json_arguments(
tool_calls=[tool_call_with_invalid_json_arguments],
)

state = await mock_stateful_tool_env.rollout(
state = await env.rollout(
input=RolloutInput(
prompt=[{"role": "user", "content": "Square 4"}],
answer="",
Expand Down
35 changes: 31 additions & 4 deletions verifiers/envs/stateful_tool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ async def env_response(
tool_messages = []
last_msg = cast(ChatCompletionAssistantMessageParam, messages[-1])
for tool_call in last_msg.get("tool_calls", []):
tool_call_id: str = tool_call.get("id", "")
try:
tool_name: str = tool_call.get("function", {}).get("name", "")
parsed_args = json.loads(
Expand All @@ -121,16 +122,42 @@ async def env_response(
)
tool_args: dict = parsed_args
except Exception as e:
raise vf.ToolParseError(cause=e)
tool_call_id: str = tool_call.get("id", "")
err = vf.ToolParseError(cause=e)
if self._should_stop_for_error(err):
raise err
tool_messages.append(
cast(
vf.Message,
{
"role": "tool",
"content": self.error_formatter(e),
"tool_call_id": tool_call_id,
},
)
)
continue

tool_args = self.update_tool_args(
tool_name, tool_args, messages, state, **kwargs
)
try:
tool_message: vf.Message = await self.call_tool(
tool_name, tool_args, tool_call_id
)
tool_messages.append(tool_message)
except Exception as e:
raise vf.ToolCallError(cause=e)
tool_messages.append(tool_message)
err = vf.ToolCallError(cause=e)
if self._should_stop_for_error(err):
raise err
tool_messages.append(
cast(
vf.Message,
{
"role": "tool",
"content": self.error_formatter(e),
"tool_call_id": tool_call_id,
},
)
)

return tool_messages
Loading